fix: refactor code structure for improved readability and maintainability

This commit is contained in:
qiupinhua 2026-02-13 18:51:30 +08:00
parent 09c7e7aded
commit 1ae47058d9
4 changed files with 2406 additions and 17 deletions

View File

@ -16,6 +16,7 @@ from rich.table import Table
from rich.text import Text from rich.text import Text
from nanobot import __version__, __logo__ from nanobot import __version__, __logo__
from nanobot.config.schema import Config
app = typer.Typer( app = typer.Typer(
name="nanobot", name="nanobot",
@ -295,21 +296,33 @@ This file stores important information that should persist across sessions.
console.print(" [dim]Created memory/MEMORY.md[/dim]") console.print(" [dim]Created memory/MEMORY.md[/dim]")
def _make_provider(config): def _make_provider(config: Config):
"""Create LiteLLMProvider from config. Exits if no API key found.""" """Create LiteLLMProvider from config. Exits if no API key found."""
from nanobot.providers.litellm_provider import LiteLLMProvider from nanobot.providers.litellm_provider import LiteLLMProvider
p = config.get_provider() from nanobot.providers.openai_codex_provider import OpenAICodexProvider
model = config.agents.defaults.model model = config.agents.defaults.model
if not (p and p.api_key) and not model.startswith("bedrock/"): provider_name = config.get_provider_name(model)
p = config.get_provider(model)
# OpenAI Codex (OAuth): don't route via LiteLLM; use the dedicated implementation.
if provider_name == "openai_codex" or model.startswith("openai-codex/"):
return OpenAICodexProvider(
default_model=model,
api_base=p.api_base if p else None,
)
if not model.startswith("bedrock/") and not (p and p.api_key):
console.print("[red]Error: No API key configured.[/red]") console.print("[red]Error: No API key configured.[/red]")
console.print("Set one in ~/.nanobot/config.json under providers section") console.print("Set one in ~/.nanobot/config.json under providers section")
raise typer.Exit(1) raise typer.Exit(1)
return LiteLLMProvider( return LiteLLMProvider(
api_key=p.api_key if p else None, api_key=p.api_key if p else None,
api_base=config.get_api_base(), api_base=config.get_api_base(model),
default_model=model, default_model=model,
extra_headers=p.extra_headers if p else None, extra_headers=p.extra_headers if p else None,
provider_name=config.get_provider_name(), provider_name=provider_name,
) )

View File

@ -54,6 +54,9 @@ class LiteLLMProvider(LLMProvider):
spec = self._gateway or find_by_model(model) spec = self._gateway or find_by_model(model)
if not spec: if not spec:
return return
if not spec.env_key:
# OAuth/provider-only specs (for example: openai_codex)
return
# Gateway/local overrides existing env; standard provider doesn't # Gateway/local overrides existing env; standard provider doesn't
if self._gateway: if self._gateway:

View File

@ -77,7 +77,6 @@ class OpenAICodexProvider(LLMProvider):
def get_default_model(self) -> str: def get_default_model(self) -> str:
return self.default_model return self.default_model
def _strip_model_prefix(model: str) -> str: def _strip_model_prefix(model: str) -> str:
if model.startswith("openai-codex/"): if model.startswith("openai-codex/"):
return model.split("/", 1)[1] return model.split("/", 1)[1]
@ -95,7 +94,6 @@ def _build_headers(account_id: str, token: str) -> dict[str, str]:
"content-type": "application/json", "content-type": "application/json",
} }
async def _request_codex( async def _request_codex(
url: str, url: str,
headers: dict[str, str], headers: dict[str, str],
@ -109,7 +107,6 @@ async def _request_codex(
raise RuntimeError(_friendly_error(response.status_code, text.decode("utf-8", "ignore"))) raise RuntimeError(_friendly_error(response.status_code, text.decode("utf-8", "ignore")))
return await _consume_sse(response) return await _consume_sse(response)
def _convert_tools(tools: list[dict[str, Any]]) -> list[dict[str, Any]]: def _convert_tools(tools: list[dict[str, Any]]) -> list[dict[str, Any]]:
# Nanobot tool definitions already use the OpenAI function schema. # Nanobot tool definitions already use the OpenAI function schema.
converted: list[dict[str, Any]] = [] converted: list[dict[str, Any]] = []
@ -140,7 +137,6 @@ def _convert_tools(tools: list[dict[str, Any]]) -> list[dict[str, Any]]:
) )
return converted return converted
def _convert_messages(messages: list[dict[str, Any]]) -> tuple[str, list[dict[str, Any]]]: def _convert_messages(messages: list[dict[str, Any]]) -> tuple[str, list[dict[str, Any]]]:
system_prompt = "" system_prompt = ""
input_items: list[dict[str, Any]] = [] input_items: list[dict[str, Any]] = []
@ -200,7 +196,6 @@ def _convert_messages(messages: list[dict[str, Any]]) -> tuple[str, list[dict[st
return system_prompt, input_items return system_prompt, input_items
def _convert_user_message(content: Any) -> dict[str, Any]: def _convert_user_message(content: Any) -> dict[str, Any]:
if isinstance(content, str): if isinstance(content, str):
return {"role": "user", "content": [{"type": "input_text", "text": content}]} return {"role": "user", "content": [{"type": "input_text", "text": content}]}
@ -234,12 +229,10 @@ def _split_tool_call_id(tool_call_id: Any) -> tuple[str, str | None]:
return tool_call_id, None return tool_call_id, None
return "call_0", None return "call_0", None
def _prompt_cache_key(messages: list[dict[str, Any]]) -> str: def _prompt_cache_key(messages: list[dict[str, Any]]) -> str:
raw = json.dumps(messages, ensure_ascii=True, sort_keys=True) raw = json.dumps(messages, ensure_ascii=True, sort_keys=True)
return hashlib.sha256(raw.encode("utf-8")).hexdigest() return hashlib.sha256(raw.encode("utf-8")).hexdigest()
async def _iter_sse(response: httpx.Response) -> AsyncGenerator[dict[str, Any], None]: async def _iter_sse(response: httpx.Response) -> AsyncGenerator[dict[str, Any], None]:
buffer: list[str] = [] buffer: list[str] = []
async for line in response.aiter_lines(): async for line in response.aiter_lines():
@ -259,9 +252,6 @@ async def _iter_sse(response: httpx.Response) -> AsyncGenerator[dict[str, Any],
continue continue
buffer.append(line) buffer.append(line)
async def _consume_sse(response: httpx.Response) -> tuple[str, list[ToolCallRequest], str]: async def _consume_sse(response: httpx.Response) -> tuple[str, list[ToolCallRequest], str]:
content = "" content = ""
tool_calls: list[ToolCallRequest] = [] tool_calls: list[ToolCallRequest] = []
@ -318,7 +308,6 @@ async def _consume_sse(response: httpx.Response) -> tuple[str, list[ToolCallRequ
return content, tool_calls, finish_reason return content, tool_calls, finish_reason
def _map_finish_reason(status: str | None) -> str: def _map_finish_reason(status: str | None) -> str:
if not status: if not status:
return "stop" return "stop"
@ -330,7 +319,6 @@ def _map_finish_reason(status: str | None) -> str:
return "error" return "error"
return "stop" return "stop"
def _friendly_error(status_code: int, raw: str) -> str: def _friendly_error(status_code: int, raw: str) -> str:
if status_code == 429: if status_code == 429:
return "ChatGPT usage quota exceeded or rate limit triggered. Please try again later." return "ChatGPT usage quota exceeded or rate limit triggered. Please try again later."

2385
uv.lock generated Normal file

File diff suppressed because it is too large Load Diff