fix: refactor code structure for improved readability and maintainability
This commit is contained in:
parent
09c7e7aded
commit
1ae47058d9
@ -16,6 +16,7 @@ from rich.table import Table
|
|||||||
from rich.text import Text
|
from rich.text import Text
|
||||||
|
|
||||||
from nanobot import __version__, __logo__
|
from nanobot import __version__, __logo__
|
||||||
|
from nanobot.config.schema import Config
|
||||||
|
|
||||||
app = typer.Typer(
|
app = typer.Typer(
|
||||||
name="nanobot",
|
name="nanobot",
|
||||||
@ -295,21 +296,33 @@ This file stores important information that should persist across sessions.
|
|||||||
console.print(" [dim]Created memory/MEMORY.md[/dim]")
|
console.print(" [dim]Created memory/MEMORY.md[/dim]")
|
||||||
|
|
||||||
|
|
||||||
def _make_provider(config):
|
def _make_provider(config: Config):
|
||||||
"""Create LiteLLMProvider from config. Exits if no API key found."""
|
"""Create LiteLLMProvider from config. Exits if no API key found."""
|
||||||
from nanobot.providers.litellm_provider import LiteLLMProvider
|
from nanobot.providers.litellm_provider import LiteLLMProvider
|
||||||
p = config.get_provider()
|
from nanobot.providers.openai_codex_provider import OpenAICodexProvider
|
||||||
|
|
||||||
model = config.agents.defaults.model
|
model = config.agents.defaults.model
|
||||||
if not (p and p.api_key) and not model.startswith("bedrock/"):
|
provider_name = config.get_provider_name(model)
|
||||||
|
p = config.get_provider(model)
|
||||||
|
|
||||||
|
# OpenAI Codex (OAuth): don't route via LiteLLM; use the dedicated implementation.
|
||||||
|
if provider_name == "openai_codex" or model.startswith("openai-codex/"):
|
||||||
|
return OpenAICodexProvider(
|
||||||
|
default_model=model,
|
||||||
|
api_base=p.api_base if p else None,
|
||||||
|
)
|
||||||
|
|
||||||
|
if not model.startswith("bedrock/") and not (p and p.api_key):
|
||||||
console.print("[red]Error: No API key configured.[/red]")
|
console.print("[red]Error: No API key configured.[/red]")
|
||||||
console.print("Set one in ~/.nanobot/config.json under providers section")
|
console.print("Set one in ~/.nanobot/config.json under providers section")
|
||||||
raise typer.Exit(1)
|
raise typer.Exit(1)
|
||||||
|
|
||||||
return LiteLLMProvider(
|
return LiteLLMProvider(
|
||||||
api_key=p.api_key if p else None,
|
api_key=p.api_key if p else None,
|
||||||
api_base=config.get_api_base(),
|
api_base=config.get_api_base(model),
|
||||||
default_model=model,
|
default_model=model,
|
||||||
extra_headers=p.extra_headers if p else None,
|
extra_headers=p.extra_headers if p else None,
|
||||||
provider_name=config.get_provider_name(),
|
provider_name=provider_name,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -54,6 +54,9 @@ class LiteLLMProvider(LLMProvider):
|
|||||||
spec = self._gateway or find_by_model(model)
|
spec = self._gateway or find_by_model(model)
|
||||||
if not spec:
|
if not spec:
|
||||||
return
|
return
|
||||||
|
if not spec.env_key:
|
||||||
|
# OAuth/provider-only specs (for example: openai_codex)
|
||||||
|
return
|
||||||
|
|
||||||
# Gateway/local overrides existing env; standard provider doesn't
|
# Gateway/local overrides existing env; standard provider doesn't
|
||||||
if self._gateway:
|
if self._gateway:
|
||||||
|
|||||||
@ -77,7 +77,6 @@ class OpenAICodexProvider(LLMProvider):
|
|||||||
def get_default_model(self) -> str:
|
def get_default_model(self) -> str:
|
||||||
return self.default_model
|
return self.default_model
|
||||||
|
|
||||||
|
|
||||||
def _strip_model_prefix(model: str) -> str:
|
def _strip_model_prefix(model: str) -> str:
|
||||||
if model.startswith("openai-codex/"):
|
if model.startswith("openai-codex/"):
|
||||||
return model.split("/", 1)[1]
|
return model.split("/", 1)[1]
|
||||||
@ -95,7 +94,6 @@ def _build_headers(account_id: str, token: str) -> dict[str, str]:
|
|||||||
"content-type": "application/json",
|
"content-type": "application/json",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
async def _request_codex(
|
async def _request_codex(
|
||||||
url: str,
|
url: str,
|
||||||
headers: dict[str, str],
|
headers: dict[str, str],
|
||||||
@ -109,7 +107,6 @@ async def _request_codex(
|
|||||||
raise RuntimeError(_friendly_error(response.status_code, text.decode("utf-8", "ignore")))
|
raise RuntimeError(_friendly_error(response.status_code, text.decode("utf-8", "ignore")))
|
||||||
return await _consume_sse(response)
|
return await _consume_sse(response)
|
||||||
|
|
||||||
|
|
||||||
def _convert_tools(tools: list[dict[str, Any]]) -> list[dict[str, Any]]:
|
def _convert_tools(tools: list[dict[str, Any]]) -> list[dict[str, Any]]:
|
||||||
# Nanobot tool definitions already use the OpenAI function schema.
|
# Nanobot tool definitions already use the OpenAI function schema.
|
||||||
converted: list[dict[str, Any]] = []
|
converted: list[dict[str, Any]] = []
|
||||||
@ -140,7 +137,6 @@ def _convert_tools(tools: list[dict[str, Any]]) -> list[dict[str, Any]]:
|
|||||||
)
|
)
|
||||||
return converted
|
return converted
|
||||||
|
|
||||||
|
|
||||||
def _convert_messages(messages: list[dict[str, Any]]) -> tuple[str, list[dict[str, Any]]]:
|
def _convert_messages(messages: list[dict[str, Any]]) -> tuple[str, list[dict[str, Any]]]:
|
||||||
system_prompt = ""
|
system_prompt = ""
|
||||||
input_items: list[dict[str, Any]] = []
|
input_items: list[dict[str, Any]] = []
|
||||||
@ -200,7 +196,6 @@ def _convert_messages(messages: list[dict[str, Any]]) -> tuple[str, list[dict[st
|
|||||||
|
|
||||||
return system_prompt, input_items
|
return system_prompt, input_items
|
||||||
|
|
||||||
|
|
||||||
def _convert_user_message(content: Any) -> dict[str, Any]:
|
def _convert_user_message(content: Any) -> dict[str, Any]:
|
||||||
if isinstance(content, str):
|
if isinstance(content, str):
|
||||||
return {"role": "user", "content": [{"type": "input_text", "text": content}]}
|
return {"role": "user", "content": [{"type": "input_text", "text": content}]}
|
||||||
@ -234,12 +229,10 @@ def _split_tool_call_id(tool_call_id: Any) -> tuple[str, str | None]:
|
|||||||
return tool_call_id, None
|
return tool_call_id, None
|
||||||
return "call_0", None
|
return "call_0", None
|
||||||
|
|
||||||
|
|
||||||
def _prompt_cache_key(messages: list[dict[str, Any]]) -> str:
|
def _prompt_cache_key(messages: list[dict[str, Any]]) -> str:
|
||||||
raw = json.dumps(messages, ensure_ascii=True, sort_keys=True)
|
raw = json.dumps(messages, ensure_ascii=True, sort_keys=True)
|
||||||
return hashlib.sha256(raw.encode("utf-8")).hexdigest()
|
return hashlib.sha256(raw.encode("utf-8")).hexdigest()
|
||||||
|
|
||||||
|
|
||||||
async def _iter_sse(response: httpx.Response) -> AsyncGenerator[dict[str, Any], None]:
|
async def _iter_sse(response: httpx.Response) -> AsyncGenerator[dict[str, Any], None]:
|
||||||
buffer: list[str] = []
|
buffer: list[str] = []
|
||||||
async for line in response.aiter_lines():
|
async for line in response.aiter_lines():
|
||||||
@ -259,9 +252,6 @@ async def _iter_sse(response: httpx.Response) -> AsyncGenerator[dict[str, Any],
|
|||||||
continue
|
continue
|
||||||
buffer.append(line)
|
buffer.append(line)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
async def _consume_sse(response: httpx.Response) -> tuple[str, list[ToolCallRequest], str]:
|
async def _consume_sse(response: httpx.Response) -> tuple[str, list[ToolCallRequest], str]:
|
||||||
content = ""
|
content = ""
|
||||||
tool_calls: list[ToolCallRequest] = []
|
tool_calls: list[ToolCallRequest] = []
|
||||||
@ -318,7 +308,6 @@ async def _consume_sse(response: httpx.Response) -> tuple[str, list[ToolCallRequ
|
|||||||
|
|
||||||
return content, tool_calls, finish_reason
|
return content, tool_calls, finish_reason
|
||||||
|
|
||||||
|
|
||||||
def _map_finish_reason(status: str | None) -> str:
|
def _map_finish_reason(status: str | None) -> str:
|
||||||
if not status:
|
if not status:
|
||||||
return "stop"
|
return "stop"
|
||||||
@ -330,7 +319,6 @@ def _map_finish_reason(status: str | None) -> str:
|
|||||||
return "error"
|
return "error"
|
||||||
return "stop"
|
return "stop"
|
||||||
|
|
||||||
|
|
||||||
def _friendly_error(status_code: int, raw: str) -> str:
|
def _friendly_error(status_code: int, raw: str) -> str:
|
||||||
if status_code == 429:
|
if status_code == 429:
|
||||||
return "ChatGPT usage quota exceeded or rate limit triggered. Please try again later."
|
return "ChatGPT usage quota exceeded or rate limit triggered. Please try again later."
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user