fix: use config key to detect provider, prevent a custom api_base from being misidentified as vLLM
This commit is contained in:
parent
2931694eb8
commit
eb2fbf80da
@ -16,7 +16,7 @@
|
|||||||
|
|
||||||
⚡️ Delivers core agent functionality in just **~4,000** lines of code — **99% smaller** than Clawdbot's 430k+ lines.
|
⚡️ Delivers core agent functionality in just **~4,000** lines of code — **99% smaller** than Clawdbot's 430k+ lines.
|
||||||
|
|
||||||
📏 Real-time line count: **3,437 lines** (run `bash core_agent_lines.sh` to verify anytime)
|
📏 Real-time line count: **3,448 lines** (run `bash core_agent_lines.sh` to verify anytime)
|
||||||
|
|
||||||
## 📢 News
|
## 📢 News
|
||||||
|
|
||||||
|
|||||||
@ -263,6 +263,7 @@ def _make_provider(config):
|
|||||||
api_base=config.get_api_base(),
|
api_base=config.get_api_base(),
|
||||||
default_model=model,
|
default_model=model,
|
||||||
extra_headers=p.extra_headers if p else None,
|
extra_headers=p.extra_headers if p else None,
|
||||||
|
provider_name=config.get_provider_name(),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -134,8 +134,8 @@ class Config(BaseSettings):
|
|||||||
"""Get expanded workspace path."""
|
"""Get expanded workspace path."""
|
||||||
return Path(self.agents.defaults.workspace).expanduser()
|
return Path(self.agents.defaults.workspace).expanduser()
|
||||||
|
|
||||||
def get_provider(self, model: str | None = None) -> ProviderConfig | None:
|
def _match_provider(self, model: str | None = None) -> tuple["ProviderConfig | None", str | None]:
|
||||||
"""Get matched provider config (api_key, api_base, extra_headers). Falls back to first available."""
|
"""Match provider config and its registry name. Returns (config, spec_name)."""
|
||||||
from nanobot.providers.registry import PROVIDERS
|
from nanobot.providers.registry import PROVIDERS
|
||||||
model_lower = (model or self.agents.defaults.model).lower()
|
model_lower = (model or self.agents.defaults.model).lower()
|
||||||
|
|
||||||
@ -143,14 +143,24 @@ class Config(BaseSettings):
|
|||||||
for spec in PROVIDERS:
|
for spec in PROVIDERS:
|
||||||
p = getattr(self.providers, spec.name, None)
|
p = getattr(self.providers, spec.name, None)
|
||||||
if p and any(kw in model_lower for kw in spec.keywords) and p.api_key:
|
if p and any(kw in model_lower for kw in spec.keywords) and p.api_key:
|
||||||
return p
|
return p, spec.name
|
||||||
|
|
||||||
# Fallback: gateways first, then others (follows registry order)
|
# Fallback: gateways first, then others (follows registry order)
|
||||||
for spec in PROVIDERS:
|
for spec in PROVIDERS:
|
||||||
p = getattr(self.providers, spec.name, None)
|
p = getattr(self.providers, spec.name, None)
|
||||||
if p and p.api_key:
|
if p and p.api_key:
|
||||||
return p
|
return p, spec.name
|
||||||
return None
|
return None, None
|
||||||
|
|
||||||
|
def get_provider(self, model: str | None = None) -> ProviderConfig | None:
|
||||||
|
"""Get matched provider config (api_key, api_base, extra_headers). Falls back to first available."""
|
||||||
|
p, _ = self._match_provider(model)
|
||||||
|
return p
|
||||||
|
|
||||||
|
def get_provider_name(self, model: str | None = None) -> str | None:
|
||||||
|
"""Get the registry name of the matched provider (e.g. "deepseek", "openrouter")."""
|
||||||
|
_, name = self._match_provider(model)
|
||||||
|
return name
|
||||||
|
|
||||||
def get_api_key(self, model: str | None = None) -> str | None:
|
def get_api_key(self, model: str | None = None) -> str | None:
|
||||||
"""Get API key for the given model. Falls back to first available key."""
|
"""Get API key for the given model. Falls back to first available key."""
|
||||||
@ -159,15 +169,16 @@ class Config(BaseSettings):
|
|||||||
|
|
||||||
def get_api_base(self, model: str | None = None) -> str | None:
|
def get_api_base(self, model: str | None = None) -> str | None:
|
||||||
"""Get API base URL for the given model. Applies default URLs for known gateways."""
|
"""Get API base URL for the given model. Applies default URLs for known gateways."""
|
||||||
from nanobot.providers.registry import PROVIDERS
|
from nanobot.providers.registry import find_by_name
|
||||||
p = self.get_provider(model)
|
p, name = self._match_provider(model)
|
||||||
if p and p.api_base:
|
if p and p.api_base:
|
||||||
return p.api_base
|
return p.api_base
|
||||||
# Only gateways get a default URL here. Standard providers (like Moonshot)
|
# Only gateways get a default api_base here. Standard providers
|
||||||
# handle their base URL via env vars in _setup_env, NOT via api_base —
|
# (like Moonshot) set their base URL via env vars in _setup_env
|
||||||
# otherwise find_gateway() would misdetect them as local/vLLM.
|
# to avoid polluting the global litellm.api_base.
|
||||||
for spec in PROVIDERS:
|
if name:
|
||||||
if spec.is_gateway and spec.default_api_base and p == getattr(self.providers, spec.name, None):
|
spec = find_by_name(name)
|
||||||
|
if spec and spec.is_gateway and spec.default_api_base:
|
||||||
return spec.default_api_base
|
return spec.default_api_base
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|||||||
@ -26,18 +26,16 @@ class LiteLLMProvider(LLMProvider):
|
|||||||
api_base: str | None = None,
|
api_base: str | None = None,
|
||||||
default_model: str = "anthropic/claude-opus-4-5",
|
default_model: str = "anthropic/claude-opus-4-5",
|
||||||
extra_headers: dict[str, str] | None = None,
|
extra_headers: dict[str, str] | None = None,
|
||||||
|
provider_name: str | None = None,
|
||||||
):
|
):
|
||||||
super().__init__(api_key, api_base)
|
super().__init__(api_key, api_base)
|
||||||
self.default_model = default_model
|
self.default_model = default_model
|
||||||
self.extra_headers = extra_headers or {}
|
self.extra_headers = extra_headers or {}
|
||||||
|
|
||||||
# Detect gateway / local deployment from api_key and api_base
|
# Detect gateway / local deployment.
|
||||||
self._gateway = find_gateway(api_key, api_base)
|
# provider_name (from config key) is the primary signal;
|
||||||
|
# api_key / api_base are fallback for auto-detection.
|
||||||
# Backwards-compatible flags (used by tests and possibly external code)
|
self._gateway = find_gateway(provider_name, api_key, api_base)
|
||||||
self.is_openrouter = bool(self._gateway and self._gateway.name == "openrouter")
|
|
||||||
self.is_aihubmix = bool(self._gateway and self._gateway.name == "aihubmix")
|
|
||||||
self.is_vllm = bool(self._gateway and self._gateway.is_local)
|
|
||||||
|
|
||||||
# Configure environment variables
|
# Configure environment variables
|
||||||
if api_key:
|
if api_key:
|
||||||
@ -51,23 +49,24 @@ class LiteLLMProvider(LLMProvider):
|
|||||||
|
|
||||||
def _setup_env(self, api_key: str, api_base: str | None, model: str) -> None:
|
def _setup_env(self, api_key: str, api_base: str | None, model: str) -> None:
|
||||||
"""Set environment variables based on detected provider."""
|
"""Set environment variables based on detected provider."""
|
||||||
if self._gateway:
|
spec = self._gateway or find_by_model(model)
|
||||||
# Gateway / local: direct set (not setdefault)
|
if not spec:
|
||||||
os.environ[self._gateway.env_key] = api_key
|
|
||||||
return
|
return
|
||||||
|
|
||||||
# Standard provider: match by model name
|
# Gateway/local overrides existing env; standard provider doesn't
|
||||||
spec = find_by_model(model)
|
if self._gateway:
|
||||||
if spec:
|
os.environ[spec.env_key] = api_key
|
||||||
|
else:
|
||||||
os.environ.setdefault(spec.env_key, api_key)
|
os.environ.setdefault(spec.env_key, api_key)
|
||||||
# Resolve env_extras placeholders:
|
|
||||||
# {api_key} → user's API key
|
# Resolve env_extras placeholders:
|
||||||
# {api_base} → user's api_base, falling back to spec.default_api_base
|
# {api_key} → user's API key
|
||||||
effective_base = api_base or spec.default_api_base
|
# {api_base} → user's api_base, falling back to spec.default_api_base
|
||||||
for env_name, env_val in spec.env_extras:
|
effective_base = api_base or spec.default_api_base
|
||||||
resolved = env_val.replace("{api_key}", api_key)
|
for env_name, env_val in spec.env_extras:
|
||||||
resolved = resolved.replace("{api_base}", effective_base)
|
resolved = env_val.replace("{api_key}", api_key)
|
||||||
os.environ.setdefault(env_name, resolved)
|
resolved = resolved.replace("{api_base}", effective_base)
|
||||||
|
os.environ.setdefault(env_name, resolved)
|
||||||
|
|
||||||
def _resolve_model(self, model: str) -> str:
|
def _resolve_model(self, model: str) -> str:
|
||||||
"""Resolve model name by applying provider/gateway prefixes."""
|
"""Resolve model name by applying provider/gateway prefixes."""
|
||||||
@ -131,7 +130,7 @@ class LiteLLMProvider(LLMProvider):
|
|||||||
# Apply model-specific overrides (e.g. kimi-k2.5 temperature)
|
# Apply model-specific overrides (e.g. kimi-k2.5 temperature)
|
||||||
self._apply_model_overrides(model, kwargs)
|
self._apply_model_overrides(model, kwargs)
|
||||||
|
|
||||||
# Pass api_base directly for custom endpoints (vLLM, etc.)
|
# Pass api_base for custom endpoints
|
||||||
if self.api_base:
|
if self.api_base:
|
||||||
kwargs["api_base"] = self.api_base
|
kwargs["api_base"] = self.api_base
|
||||||
|
|
||||||
|
|||||||
@ -241,11 +241,10 @@ PROVIDERS: tuple[ProviderSpec, ...] = (
|
|||||||
),
|
),
|
||||||
),
|
),
|
||||||
|
|
||||||
# === Local deployment (fallback: unknown api_base → assume local) ======
|
# === Local deployment (matched by config key, NOT by api_base) =========
|
||||||
|
|
||||||
# vLLM / any OpenAI-compatible local server.
|
# vLLM / any OpenAI-compatible local server.
|
||||||
# If api_base is set but doesn't match a known gateway, we land here.
|
# Detected when config key is "vllm" (provider_name="vllm").
|
||||||
# Placed before Groq so vLLM wins the fallback when both are configured.
|
|
||||||
ProviderSpec(
|
ProviderSpec(
|
||||||
name="vllm",
|
name="vllm",
|
||||||
keywords=("vllm",),
|
keywords=("vllm",),
|
||||||
@ -302,16 +301,34 @@ def find_by_model(model: str) -> ProviderSpec | None:
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
def find_gateway(api_key: str | None, api_base: str | None) -> ProviderSpec | None:
|
def find_gateway(
|
||||||
"""Detect gateway/local by api_key prefix or api_base substring.
|
provider_name: str | None = None,
|
||||||
Fallback: unknown api_base → treat as local (vLLM)."""
|
api_key: str | None = None,
|
||||||
|
api_base: str | None = None,
|
||||||
|
) -> ProviderSpec | None:
|
||||||
|
"""Detect gateway/local provider.
|
||||||
|
|
||||||
|
Priority:
|
||||||
|
1. provider_name — if it maps to a gateway/local spec, use it directly.
|
||||||
|
2. api_key prefix — e.g. "sk-or-" → OpenRouter.
|
||||||
|
3. api_base keyword — e.g. "aihubmix" in URL → AiHubMix.
|
||||||
|
|
||||||
|
A standard provider with a custom api_base (e.g. DeepSeek behind a proxy)
|
||||||
|
will NOT be mistaken for vLLM — the old fallback is gone.
|
||||||
|
"""
|
||||||
|
# 1. Direct match by config key
|
||||||
|
if provider_name:
|
||||||
|
spec = find_by_name(provider_name)
|
||||||
|
if spec and (spec.is_gateway or spec.is_local):
|
||||||
|
return spec
|
||||||
|
|
||||||
|
# 2. Auto-detect by api_key prefix / api_base keyword
|
||||||
for spec in PROVIDERS:
|
for spec in PROVIDERS:
|
||||||
if spec.detect_by_key_prefix and api_key and api_key.startswith(spec.detect_by_key_prefix):
|
if spec.detect_by_key_prefix and api_key and api_key.startswith(spec.detect_by_key_prefix):
|
||||||
return spec
|
return spec
|
||||||
if spec.detect_by_base_keyword and api_base and spec.detect_by_base_keyword in api_base:
|
if spec.detect_by_base_keyword and api_base and spec.detect_by_base_keyword in api_base:
|
||||||
return spec
|
return spec
|
||||||
if api_base:
|
|
||||||
return next((s for s in PROVIDERS if s.is_local), None)
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user