feat: fix API key matching by model name
This commit is contained in:
parent
f5a50d08eb
commit
760a369004
@ -122,30 +122,57 @@ class Config(BaseSettings):
|
||||
"""Get expanded workspace path."""
|
||||
return Path(self.agents.defaults.workspace).expanduser()
|
||||
|
||||
def get_api_key(self) -> str | None:
|
||||
"""Get API key in priority order: OpenRouter > DeepSeek > Anthropic > OpenAI > Gemini > Zhipu > Groq > Moonshot > vLLM."""
|
||||
return (
|
||||
self.providers.openrouter.api_key or
|
||||
self.providers.deepseek.api_key or
|
||||
self.providers.anthropic.api_key or
|
||||
self.providers.openai.api_key or
|
||||
self.providers.gemini.api_key or
|
||||
self.providers.zhipu.api_key or
|
||||
self.providers.groq.api_key or
|
||||
self.providers.moonshot.api_key or
|
||||
self.providers.vllm.api_key or
|
||||
None
|
||||
)
|
||||
def _match_provider(self, model: str | None = None) -> ProviderConfig | None:
|
||||
"""Match a provider based on model name."""
|
||||
model = (model or self.agents.defaults.model).lower()
|
||||
# Map of keywords to provider configs
|
||||
providers = {
|
||||
"openrouter": self.providers.openrouter,
|
||||
"deepseek": self.providers.deepseek,
|
||||
"anthropic": self.providers.anthropic,
|
||||
"claude": self.providers.anthropic,
|
||||
"openai": self.providers.openai,
|
||||
"gpt": self.providers.openai,
|
||||
"gemini": self.providers.gemini,
|
||||
"zhipu": self.providers.zhipu,
|
||||
"glm": self.providers.zhipu,
|
||||
"zai": self.providers.zhipu,
|
||||
"groq": self.providers.groq,
|
||||
"moonshot": self.providers.moonshot,
|
||||
"kimi": self.providers.moonshot,
|
||||
"vllm": self.providers.vllm,
|
||||
}
|
||||
for keyword, provider in providers.items():
|
||||
if keyword in model and provider.api_key:
|
||||
return provider
|
||||
return None
|
||||
|
||||
def get_api_key(self, model: str | None = None) -> str | None:
|
||||
"""Get API key for the given model (or default model). Falls back to first available key."""
|
||||
# Try matching by model name first
|
||||
matched = self._match_provider(model)
|
||||
if matched:
|
||||
return matched.api_key
|
||||
# Fallback: return first available key
|
||||
for provider in [
|
||||
self.providers.openrouter, self.providers.deepseek,
|
||||
self.providers.anthropic, self.providers.openai,
|
||||
self.providers.gemini, self.providers.zhipu,
|
||||
self.providers.moonshot, self.providers.vllm,
|
||||
self.providers.groq,
|
||||
]:
|
||||
if provider.api_key:
|
||||
return provider.api_key
|
||||
return None
|
||||
|
||||
def get_api_base(self) -> str | None:
|
||||
"""Get API base URL if using OpenRouter, Zhipu, Moonshot or vLLM."""
|
||||
if self.providers.openrouter.api_key:
|
||||
def get_api_base(self, model: str | None = None) -> str | None:
|
||||
"""Get API base URL based on model name."""
|
||||
model = (model or self.agents.defaults.model).lower()
|
||||
if "openrouter" in model:
|
||||
return self.providers.openrouter.api_base or "https://openrouter.ai/api/v1"
|
||||
if self.providers.zhipu.api_key:
|
||||
if any(k in model for k in ("zhipu", "glm", "zai")):
|
||||
return self.providers.zhipu.api_base
|
||||
if self.providers.moonshot.api_key:
|
||||
return self.providers.moonshot.api_base
|
||||
if self.providers.vllm.api_base:
|
||||
if "vllm" in model:
|
||||
return self.providers.vllm.api_base
|
||||
return None
|
||||
|
||||
|
||||
@ -31,15 +31,9 @@ class LiteLLMProvider(LLMProvider):
|
||||
(api_key and api_key.startswith("sk-or-")) or
|
||||
(api_base and "openrouter" in api_base)
|
||||
)
|
||||
|
||||
# Detect Moonshot by api_base or model name
|
||||
self.is_moonshot = (
|
||||
(api_base and "moonshot" in api_base) or
|
||||
("moonshot" in default_model or "kimi" in default_model)
|
||||
)
|
||||
|
||||
|
||||
# Track if using custom endpoint (vLLM, etc.)
|
||||
self.is_vllm = bool(api_base) and not self.is_openrouter and not self.is_moonshot
|
||||
self.is_vllm = bool(api_base) and not self.is_openrouter
|
||||
|
||||
# Configure LiteLLM based on provider
|
||||
if api_key:
|
||||
@ -63,10 +57,9 @@ class LiteLLMProvider(LLMProvider):
|
||||
os.environ.setdefault("GROQ_API_KEY", api_key)
|
||||
elif "moonshot" in default_model or "kimi" in default_model:
|
||||
os.environ.setdefault("MOONSHOT_API_KEY", api_key)
|
||||
if api_base:
|
||||
os.environ["MOONSHOT_API_BASE"] = api_base
|
||||
|
||||
if api_base and not self.is_moonshot:
|
||||
os.environ.setdefault("MOONSHOT_API_BASE", api_base or "https://api.moonshot.cn/v1")
|
||||
|
||||
if api_base:
|
||||
litellm.api_base = api_base
|
||||
|
||||
# Disable LiteLLM logging noise
|
||||
@ -123,17 +116,17 @@ class LiteLLMProvider(LLMProvider):
|
||||
if self.is_vllm:
|
||||
model = f"hosted_vllm/{model}"
|
||||
|
||||
# kimi-k2.5 only supports temperature=1.0
|
||||
if "kimi-k2.5" in model.lower():
|
||||
temperature = 1.0
|
||||
|
||||
kwargs: dict[str, Any] = {
|
||||
"model": model,
|
||||
"messages": messages,
|
||||
"max_tokens": max_tokens,
|
||||
"temperature": temperature,
|
||||
}
|
||||
|
||||
# kimi-k2.5 only supports temperature=1.0
|
||||
if "kimi-k2.5" in model.lower():
|
||||
kwargs["temperature"] = 1.0
|
||||
|
||||
|
||||
# Pass api_base directly for custom endpoints (vLLM, etc.)
|
||||
if self.api_base:
|
||||
kwargs["api_base"] = self.api_base
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user