diff --git a/nanobot/config/schema.py b/nanobot/config/schema.py index 5e8e46c..353ca4b 100644 --- a/nanobot/config/schema.py +++ b/nanobot/config/schema.py @@ -122,30 +122,57 @@ class Config(BaseSettings): """Get expanded workspace path.""" return Path(self.agents.defaults.workspace).expanduser() - def get_api_key(self) -> str | None: - """Get API key in priority order: OpenRouter > DeepSeek > Anthropic > OpenAI > Gemini > Zhipu > Groq > Moonshot > vLLM.""" - return ( - self.providers.openrouter.api_key or - self.providers.deepseek.api_key or - self.providers.anthropic.api_key or - self.providers.openai.api_key or - self.providers.gemini.api_key or - self.providers.zhipu.api_key or - self.providers.groq.api_key or - self.providers.moonshot.api_key or - self.providers.vllm.api_key or - None - ) + def _match_provider(self, model: str | None = None) -> ProviderConfig | None: + """Match a provider based on model name.""" + model = (model or self.agents.defaults.model).lower() + # Map of keywords to provider configs + providers = { + "openrouter": self.providers.openrouter, + "deepseek": self.providers.deepseek, + "anthropic": self.providers.anthropic, + "claude": self.providers.anthropic, + "openai": self.providers.openai, + "gpt": self.providers.openai, + "gemini": self.providers.gemini, + "zhipu": self.providers.zhipu, + "glm": self.providers.zhipu, + "zai": self.providers.zhipu, + "groq": self.providers.groq, + "moonshot": self.providers.moonshot, + "kimi": self.providers.moonshot, + "vllm": self.providers.vllm, + } + for keyword, provider in providers.items(): + if keyword in model and provider.api_key: + return provider + return None + + def get_api_key(self, model: str | None = None) -> str | None: + """Get API key for the given model (or default model). Falls back to first available key.""" + # Try matching by model name first + matched = self._match_provider(model) + if matched: + return matched.api_key + # Fallback: return first available key + for provider in [ + self.providers.openrouter, self.providers.deepseek, + self.providers.anthropic, self.providers.openai, + self.providers.gemini, self.providers.zhipu, + self.providers.moonshot, self.providers.vllm, + self.providers.groq, + ]: + if provider.api_key: + return provider.api_key + return None - def get_api_base(self) -> str | None: - """Get API base URL if using OpenRouter, Zhipu, Moonshot or vLLM.""" - if self.providers.openrouter.api_key: + def get_api_base(self, model: str | None = None) -> str | None: + """Get API base URL based on model name.""" + model = (model or self.agents.defaults.model).lower() + if "openrouter" in model: return self.providers.openrouter.api_base or "https://openrouter.ai/api/v1" - if self.providers.zhipu.api_key: + if any(k in model for k in ("zhipu", "glm", "zai")): return self.providers.zhipu.api_base - if self.providers.moonshot.api_key: - return self.providers.moonshot.api_base - if self.providers.vllm.api_base: + if "vllm" in model: return self.providers.vllm.api_base return None diff --git a/nanobot/providers/litellm_provider.py b/nanobot/providers/litellm_provider.py index c2cdda7..2125b15 100644 --- a/nanobot/providers/litellm_provider.py +++ b/nanobot/providers/litellm_provider.py @@ -31,15 +31,9 @@ class LiteLLMProvider(LLMProvider): (api_key and api_key.startswith("sk-or-")) or (api_base and "openrouter" in api_base) ) - - # Detect Moonshot by api_base or model name - self.is_moonshot = ( - (api_base and "moonshot" in api_base) or - ("moonshot" in default_model or "kimi" in default_model) - ) - + # Track if using custom endpoint (vLLM, etc.) - self.is_vllm = bool(api_base) and not self.is_openrouter and not self.is_moonshot + self.is_vllm = bool(api_base) and not self.is_openrouter # Configure LiteLLM based on provider if api_key: @@ -63,10 +57,9 @@ class LiteLLMProvider(LLMProvider): os.environ.setdefault("GROQ_API_KEY", api_key) elif "moonshot" in default_model or "kimi" in default_model: os.environ.setdefault("MOONSHOT_API_KEY", api_key) - if api_base: - os.environ["MOONSHOT_API_BASE"] = api_base - - if api_base and not self.is_moonshot: + os.environ.setdefault("MOONSHOT_API_BASE", api_base or "https://api.moonshot.cn/v1") + + if api_base: litellm.api_base = api_base # Disable LiteLLM logging noise @@ -123,17 +116,17 @@ class LiteLLMProvider(LLMProvider): if self.is_vllm: model = f"hosted_vllm/{model}" + # kimi-k2.5 only supports temperature=1.0 + if "kimi-k2.5" in model.lower(): + temperature = 1.0 + kwargs: dict[str, Any] = { "model": model, "messages": messages, "max_tokens": max_tokens, "temperature": temperature, } - - # kimi-k2.5 only supports temperature=1.0 - if "kimi-k2.5" in model.lower(): - kwargs["temperature"] = 1.0 - + # Pass api_base directly for custom endpoints (vLLM, etc.) if self.api_base: kwargs["api_base"] = self.api_base