fix: use config key to detect provider, prevent api_base misidentifying as vLLM

This commit is contained in:
Re-bin 2026-02-08 19:31:25 +00:00
parent 2931694eb8
commit eb2fbf80da
5 changed files with 72 additions and 44 deletions

View File

@ -16,7 +16,7 @@
⚡️ Delivers core agent functionality in just **~4,000** lines of code — **99% smaller** than Clawdbot's 430k+ lines. ⚡️ Delivers core agent functionality in just **~4,000** lines of code — **99% smaller** than Clawdbot's 430k+ lines.
📏 Real-time line count: **3,437 lines** (run `bash core_agent_lines.sh` to verify anytime) 📏 Real-time line count: **3,448 lines** (run `bash core_agent_lines.sh` to verify anytime)
## 📢 News ## 📢 News

View File

@ -263,6 +263,7 @@ def _make_provider(config):
api_base=config.get_api_base(), api_base=config.get_api_base(),
default_model=model, default_model=model,
extra_headers=p.extra_headers if p else None, extra_headers=p.extra_headers if p else None,
provider_name=config.get_provider_name(),
) )

View File

@ -134,8 +134,8 @@ class Config(BaseSettings):
"""Get expanded workspace path.""" """Get expanded workspace path."""
return Path(self.agents.defaults.workspace).expanduser() return Path(self.agents.defaults.workspace).expanduser()
def get_provider(self, model: str | None = None) -> ProviderConfig | None: def _match_provider(self, model: str | None = None) -> tuple["ProviderConfig | None", str | None]:
"""Get matched provider config (api_key, api_base, extra_headers). Falls back to first available.""" """Match provider config and its registry name. Returns (config, spec_name)."""
from nanobot.providers.registry import PROVIDERS from nanobot.providers.registry import PROVIDERS
model_lower = (model or self.agents.defaults.model).lower() model_lower = (model or self.agents.defaults.model).lower()
@ -143,14 +143,24 @@ class Config(BaseSettings):
for spec in PROVIDERS: for spec in PROVIDERS:
p = getattr(self.providers, spec.name, None) p = getattr(self.providers, spec.name, None)
if p and any(kw in model_lower for kw in spec.keywords) and p.api_key: if p and any(kw in model_lower for kw in spec.keywords) and p.api_key:
return p return p, spec.name
# Fallback: gateways first, then others (follows registry order) # Fallback: gateways first, then others (follows registry order)
for spec in PROVIDERS: for spec in PROVIDERS:
p = getattr(self.providers, spec.name, None) p = getattr(self.providers, spec.name, None)
if p and p.api_key: if p and p.api_key:
return p return p, spec.name
return None return None, None
def get_provider(self, model: str | None = None) -> ProviderConfig | None:
"""Get matched provider config (api_key, api_base, extra_headers). Falls back to first available."""
p, _ = self._match_provider(model)
return p
def get_provider_name(self, model: str | None = None) -> str | None:
"""Get the registry name of the matched provider (e.g. "deepseek", "openrouter")."""
_, name = self._match_provider(model)
return name
def get_api_key(self, model: str | None = None) -> str | None: def get_api_key(self, model: str | None = None) -> str | None:
"""Get API key for the given model. Falls back to first available key.""" """Get API key for the given model. Falls back to first available key."""
@ -159,15 +169,16 @@ class Config(BaseSettings):
def get_api_base(self, model: str | None = None) -> str | None: def get_api_base(self, model: str | None = None) -> str | None:
"""Get API base URL for the given model. Applies default URLs for known gateways.""" """Get API base URL for the given model. Applies default URLs for known gateways."""
from nanobot.providers.registry import PROVIDERS from nanobot.providers.registry import find_by_name
p = self.get_provider(model) p, name = self._match_provider(model)
if p and p.api_base: if p and p.api_base:
return p.api_base return p.api_base
# Only gateways get a default URL here. Standard providers (like Moonshot) # Only gateways get a default api_base here. Standard providers
# handle their base URL via env vars in _setup_env, NOT via api_base — # (like Moonshot) set their base URL via env vars in _setup_env
# otherwise find_gateway() would misdetect them as local/vLLM. # to avoid polluting the global litellm.api_base.
for spec in PROVIDERS: if name:
if spec.is_gateway and spec.default_api_base and p == getattr(self.providers, spec.name, None): spec = find_by_name(name)
if spec and spec.is_gateway and spec.default_api_base:
return spec.default_api_base return spec.default_api_base
return None return None

View File

@ -26,18 +26,16 @@ class LiteLLMProvider(LLMProvider):
api_base: str | None = None, api_base: str | None = None,
default_model: str = "anthropic/claude-opus-4-5", default_model: str = "anthropic/claude-opus-4-5",
extra_headers: dict[str, str] | None = None, extra_headers: dict[str, str] | None = None,
provider_name: str | None = None,
): ):
super().__init__(api_key, api_base) super().__init__(api_key, api_base)
self.default_model = default_model self.default_model = default_model
self.extra_headers = extra_headers or {} self.extra_headers = extra_headers or {}
# Detect gateway / local deployment from api_key and api_base # Detect gateway / local deployment.
self._gateway = find_gateway(api_key, api_base) # provider_name (from config key) is the primary signal;
# api_key / api_base are fallback for auto-detection.
# Backwards-compatible flags (used by tests and possibly external code) self._gateway = find_gateway(provider_name, api_key, api_base)
self.is_openrouter = bool(self._gateway and self._gateway.name == "openrouter")
self.is_aihubmix = bool(self._gateway and self._gateway.name == "aihubmix")
self.is_vllm = bool(self._gateway and self._gateway.is_local)
# Configure environment variables # Configure environment variables
if api_key: if api_key:
@ -51,23 +49,24 @@ class LiteLLMProvider(LLMProvider):
def _setup_env(self, api_key: str, api_base: str | None, model: str) -> None: def _setup_env(self, api_key: str, api_base: str | None, model: str) -> None:
"""Set environment variables based on detected provider.""" """Set environment variables based on detected provider."""
if self._gateway: spec = self._gateway or find_by_model(model)
# Gateway / local: direct set (not setdefault) if not spec:
os.environ[self._gateway.env_key] = api_key
return return
# Standard provider: match by model name # Gateway/local overrides existing env; standard provider doesn't
spec = find_by_model(model) if self._gateway:
if spec: os.environ[spec.env_key] = api_key
else:
os.environ.setdefault(spec.env_key, api_key) os.environ.setdefault(spec.env_key, api_key)
# Resolve env_extras placeholders:
# {api_key} → user's API key # Resolve env_extras placeholders:
# {api_base} → user's api_base, falling back to spec.default_api_base # {api_key} → user's API key
effective_base = api_base or spec.default_api_base # {api_base} → user's api_base, falling back to spec.default_api_base
for env_name, env_val in spec.env_extras: effective_base = api_base or spec.default_api_base
resolved = env_val.replace("{api_key}", api_key) for env_name, env_val in spec.env_extras:
resolved = resolved.replace("{api_base}", effective_base) resolved = env_val.replace("{api_key}", api_key)
os.environ.setdefault(env_name, resolved) resolved = resolved.replace("{api_base}", effective_base)
os.environ.setdefault(env_name, resolved)
def _resolve_model(self, model: str) -> str: def _resolve_model(self, model: str) -> str:
"""Resolve model name by applying provider/gateway prefixes.""" """Resolve model name by applying provider/gateway prefixes."""
@ -131,7 +130,7 @@ class LiteLLMProvider(LLMProvider):
# Apply model-specific overrides (e.g. kimi-k2.5 temperature) # Apply model-specific overrides (e.g. kimi-k2.5 temperature)
self._apply_model_overrides(model, kwargs) self._apply_model_overrides(model, kwargs)
# Pass api_base directly for custom endpoints (vLLM, etc.) # Pass api_base for custom endpoints
if self.api_base: if self.api_base:
kwargs["api_base"] = self.api_base kwargs["api_base"] = self.api_base

View File

@ -241,11 +241,10 @@ PROVIDERS: tuple[ProviderSpec, ...] = (
), ),
), ),
# === Local deployment (fallback: unknown api_base → assume local) ====== # === Local deployment (matched by config key, NOT by api_base) =========
# vLLM / any OpenAI-compatible local server. # vLLM / any OpenAI-compatible local server.
# If api_base is set but doesn't match a known gateway, we land here. # Detected when config key is "vllm" (provider_name="vllm").
# Placed before Groq so vLLM wins the fallback when both are configured.
ProviderSpec( ProviderSpec(
name="vllm", name="vllm",
keywords=("vllm",), keywords=("vllm",),
@ -302,16 +301,34 @@ def find_by_model(model: str) -> ProviderSpec | None:
return None return None
def find_gateway(api_key: str | None, api_base: str | None) -> ProviderSpec | None: def find_gateway(
"""Detect gateway/local by api_key prefix or api_base substring. provider_name: str | None = None,
Fallback: unknown api_base → treat as local (vLLM).""" api_key: str | None = None,
api_base: str | None = None,
) -> ProviderSpec | None:
"""Detect gateway/local provider.
Priority:
1. provider_name — if it maps to a gateway/local spec, use it directly.
2. api_key prefix — e.g. "sk-or-" → OpenRouter.
3. api_base keyword — e.g. "aihubmix" in URL → AiHubMix.
A standard provider with a custom api_base (e.g. DeepSeek behind a proxy)
will NOT be mistaken for vLLM — the old fallback is gone.
"""
# 1. Direct match by config key
if provider_name:
spec = find_by_name(provider_name)
if spec and (spec.is_gateway or spec.is_local):
return spec
# 2. Auto-detect by api_key prefix / api_base keyword
for spec in PROVIDERS: for spec in PROVIDERS:
if spec.detect_by_key_prefix and api_key and api_key.startswith(spec.detect_by_key_prefix): if spec.detect_by_key_prefix and api_key and api_key.startswith(spec.detect_by_key_prefix):
return spec return spec
if spec.detect_by_base_keyword and api_base and spec.detect_by_base_keyword in api_base: if spec.detect_by_base_keyword and api_base and spec.detect_by_base_keyword in api_base:
return spec return spec
if api_base:
return next((s for s in PROVIDERS if s.is_local), None)
return None return None