diff --git a/nanobot/config/loader.py b/nanobot/config/loader.py index fd7d1e8..560c1f5 100644 --- a/nanobot/config/loader.py +++ b/nanobot/config/loader.py @@ -2,7 +2,6 @@ import json from pathlib import Path -from typing import Any from nanobot.config.schema import Config @@ -21,43 +20,41 @@ def get_data_dir() -> Path: def load_config(config_path: Path | None = None) -> Config: """ Load configuration from file or create default. - + Args: config_path: Optional path to config file. Uses default if not provided. - + Returns: Loaded configuration object. """ path = config_path or get_config_path() - + if path.exists(): try: with open(path) as f: data = json.load(f) data = _migrate_config(data) - return Config.model_validate(convert_keys(data)) + return Config.model_validate(data) except (json.JSONDecodeError, ValueError) as e: print(f"Warning: Failed to load config from {path}: {e}") print("Using default configuration.") - + return Config() def save_config(config: Config, config_path: Path | None = None) -> None: """ Save configuration to file. - + Args: config: Configuration to save. config_path: Optional path to save to. Uses default if not provided. """ path = config_path or get_config_path() path.parent.mkdir(parents=True, exist_ok=True) - - # Convert to camelCase format - data = config.model_dump() - data = convert_to_camel(data) - + + data = config.model_dump(by_alias=True) + with open(path, "w") as f: json.dump(data, f, indent=2) @@ -70,37 +67,3 @@ def _migrate_config(data: dict) -> dict: if "restrictToWorkspace" in exec_cfg and "restrictToWorkspace" not in tools: tools["restrictToWorkspace"] = exec_cfg.pop("restrictToWorkspace") return data - - -def convert_keys(data: Any) -> Any: - """Convert camelCase keys to snake_case for Pydantic.""" - if isinstance(data, dict): - return {camel_to_snake(k): convert_keys(v) for k, v in data.items()} - if isinstance(data, list): - return [convert_keys(item) for item in data] - return data - - -def convert_to_camel(data: Any) -> Any: - """Convert snake_case keys to camelCase.""" - if isinstance(data, dict): - return {snake_to_camel(k): convert_to_camel(v) for k, v in data.items()} - if isinstance(data, list): - return [convert_to_camel(item) for item in data] - return data - - -def camel_to_snake(name: str) -> str: - """Convert camelCase to snake_case.""" - result = [] - for i, char in enumerate(name): - if char.isupper() and i > 0: - result.append("_") - result.append(char.lower()) - return "".join(result) - - -def snake_to_camel(name: str) -> str: - """Convert snake_case to camelCase.""" - components = name.split("_") - return components[0] + "".join(x.title() for x in components[1:]) diff --git a/nanobot/config/schema.py b/nanobot/config/schema.py index 64609ec..0786080 100644 --- a/nanobot/config/schema.py +++ b/nanobot/config/schema.py @@ -2,27 +2,39 @@ from pathlib import Path from pydantic import BaseModel, Field, ConfigDict +from pydantic.alias_generators import to_camel from pydantic_settings import BaseSettings -class WhatsAppConfig(BaseModel): +class Base(BaseModel): + """Base model that accepts both camelCase and snake_case keys.""" + + model_config = ConfigDict(alias_generator=to_camel, populate_by_name=True) + + +class WhatsAppConfig(Base): """WhatsApp channel configuration.""" + enabled: bool = False bridge_url: str = "ws://localhost:3001" bridge_token: str = "" # Shared token for bridge auth (optional, recommended) allow_from: list[str] = Field(default_factory=list) # Allowed phone numbers -class TelegramConfig(BaseModel): +class TelegramConfig(Base): """Telegram channel configuration.""" + enabled: bool = False token: str = "" # Bot token from @BotFather allow_from: list[str] = Field(default_factory=list) # Allowed user IDs or usernames - proxy: str | None = None # HTTP/SOCKS5 proxy URL, e.g. "http://127.0.0.1:7890" or "socks5://127.0.0.1:1080" + proxy: str | None = ( + None # HTTP/SOCKS5 proxy URL, e.g. "http://127.0.0.1:7890" or "socks5://127.0.0.1:1080" + ) -class FeishuConfig(BaseModel): +class FeishuConfig(Base): """Feishu/Lark channel configuration using WebSocket long connection.""" + enabled: bool = False app_id: str = "" # App ID from Feishu Open Platform app_secret: str = "" # App Secret from Feishu Open Platform @@ -31,24 +43,28 @@ class FeishuConfig(BaseModel): allow_from: list[str] = Field(default_factory=list) # Allowed user open_ids -class DingTalkConfig(BaseModel): +class DingTalkConfig(Base): """DingTalk channel configuration using Stream mode.""" + enabled: bool = False client_id: str = "" # AppKey client_secret: str = "" # AppSecret allow_from: list[str] = Field(default_factory=list) # Allowed staff_ids -class DiscordConfig(BaseModel): +class DiscordConfig(Base): """Discord channel configuration.""" + enabled: bool = False token: str = "" # Bot token from Discord Developer Portal allow_from: list[str] = Field(default_factory=list) # Allowed user IDs gateway_url: str = "wss://gateway.discord.gg/?v=10&encoding=json" intents: int = 37377 # GUILDS + GUILD_MESSAGES + DIRECT_MESSAGES + MESSAGE_CONTENT -class EmailConfig(BaseModel): + +class EmailConfig(Base): """Email channel configuration (IMAP inbound + SMTP outbound).""" + enabled: bool = False consent_granted: bool = False # Explicit owner permission to access mailbox data @@ -70,7 +86,9 @@ class EmailConfig(BaseModel): from_address: str = "" # Behavior - auto_reply_enabled: bool = True # If false, inbound email is read but no automatic reply is sent + auto_reply_enabled: bool = ( + True # If false, inbound email is read but no automatic reply is sent + ) poll_interval_seconds: int = 30 mark_seen: bool = True max_body_chars: int = 12000 @@ -78,18 +96,21 @@ class EmailConfig(BaseModel): allow_from: list[str] = Field(default_factory=list) # Allowed sender email addresses -class MochatMentionConfig(BaseModel): +class MochatMentionConfig(Base): """Mochat mention behavior configuration.""" + require_in_groups: bool = False -class MochatGroupRule(BaseModel): +class MochatGroupRule(Base): """Mochat per-group mention requirement.""" + require_mention: bool = False -class MochatConfig(BaseModel): +class MochatConfig(Base): """Mochat channel configuration.""" + enabled: bool = False base_url: str = "https://mochat.io" socket_url: str = "" @@ -114,15 +135,17 @@ class MochatConfig(BaseModel): reply_delay_ms: int = 120000 -class SlackDMConfig(BaseModel): +class SlackDMConfig(Base): """Slack DM policy configuration.""" + enabled: bool = True policy: str = "open" # "open" or "allowlist" allow_from: list[str] = Field(default_factory=list) # Allowed Slack user IDs -class SlackConfig(BaseModel): +class SlackConfig(Base): """Slack channel configuration.""" + enabled: bool = False mode: str = "socket" # "socket" supported webhook_path: str = "/slack/events" @@ -134,16 +157,20 @@ class SlackConfig(BaseModel): dm: SlackDMConfig = Field(default_factory=SlackDMConfig) -class QQConfig(BaseModel): +class QQConfig(Base): """QQ channel configuration using botpy SDK.""" + enabled: bool = False app_id: str = "" # 机器人 ID (AppID) from q.qq.com secret: str = "" # 机器人密钥 (AppSecret) from q.qq.com - allow_from: list[str] = Field(default_factory=list) # Allowed user openids (empty = public access) + allow_from: list[str] = Field( + default_factory=list + ) # Allowed user openids (empty = public access) -class ChannelsConfig(BaseModel): +class ChannelsConfig(Base): """Configuration for chat channels.""" + whatsapp: WhatsAppConfig = Field(default_factory=WhatsAppConfig) telegram: TelegramConfig = Field(default_factory=TelegramConfig) discord: DiscordConfig = Field(default_factory=DiscordConfig) @@ -155,8 +182,9 @@ class ChannelsConfig(BaseModel): qq: QQConfig = Field(default_factory=QQConfig) -class AgentDefaults(BaseModel): +class AgentDefaults(Base): """Default agent configuration.""" + workspace: str = "~/.nanobot/workspace" model: str = "anthropic/claude-opus-4-5" max_tokens: int = 8192 @@ -165,20 +193,23 @@ class AgentDefaults(BaseModel): memory_window: int = 50 -class AgentsConfig(BaseModel): +class AgentsConfig(Base): """Agent configuration.""" + defaults: AgentDefaults = Field(default_factory=AgentDefaults) -class ProviderConfig(BaseModel): +class ProviderConfig(Base): """LLM provider configuration.""" + api_key: str = "" api_base: str | None = None extra_headers: dict[str, str] | None = None # Custom headers (e.g. APP-Code for AiHubMix) -class ProvidersConfig(BaseModel): +class ProvidersConfig(Base): """Configuration for LLM providers.""" + custom: ProviderConfig = Field(default_factory=ProviderConfig) # Any OpenAI-compatible endpoint anthropic: ProviderConfig = Field(default_factory=ProviderConfig) openai: ProviderConfig = Field(default_factory=ProviderConfig) @@ -196,38 +227,44 @@ class ProvidersConfig(BaseModel): github_copilot: ProviderConfig = Field(default_factory=ProviderConfig) # Github Copilot (OAuth) -class GatewayConfig(BaseModel): +class GatewayConfig(Base): """Gateway/server configuration.""" + host: str = "0.0.0.0" port: int = 18790 -class WebSearchConfig(BaseModel): +class WebSearchConfig(Base): """Web search tool configuration.""" + api_key: str = "" # Brave Search API key max_results: int = 5 -class WebToolsConfig(BaseModel): +class WebToolsConfig(Base): """Web tools configuration.""" + search: WebSearchConfig = Field(default_factory=WebSearchConfig) -class ExecToolConfig(BaseModel): +class ExecToolConfig(Base): """Shell exec tool configuration.""" + timeout: int = 60 -class MCPServerConfig(BaseModel): +class MCPServerConfig(Base): """MCP server connection configuration (stdio or HTTP).""" + command: str = "" # Stdio: command to run (e.g. "npx") args: list[str] = Field(default_factory=list) # Stdio: command arguments env: dict[str, str] = Field(default_factory=dict) # Stdio: extra env vars url: str = "" # HTTP: streamable HTTP endpoint URL -class ToolsConfig(BaseModel): +class ToolsConfig(Base): """Tools configuration.""" + web: WebToolsConfig = Field(default_factory=WebToolsConfig) exec: ExecToolConfig = Field(default_factory=ExecToolConfig) restrict_to_workspace: bool = False # If true, restrict all tool access to workspace directory @@ -236,20 +273,24 @@ class ToolsConfig(BaseModel): class Config(BaseSettings): """Root configuration for nanobot.""" + agents: AgentsConfig = Field(default_factory=AgentsConfig) channels: ChannelsConfig = Field(default_factory=ChannelsConfig) providers: ProvidersConfig = Field(default_factory=ProvidersConfig) gateway: GatewayConfig = Field(default_factory=GatewayConfig) tools: ToolsConfig = Field(default_factory=ToolsConfig) - + @property def workspace_path(self) -> Path: """Get expanded workspace path.""" return Path(self.agents.defaults.workspace).expanduser() - - def _match_provider(self, model: str | None = None) -> tuple["ProviderConfig | None", str | None]: + + def _match_provider( + self, model: str | None = None + ) -> tuple["ProviderConfig | None", str | None]: """Match provider config and its registry name. Returns (config, spec_name).""" from nanobot.providers.registry import PROVIDERS + model_lower = (model or self.agents.defaults.model).lower() # Match by keyword (order follows PROVIDERS registry) @@ -283,10 +324,11 @@ class Config(BaseSettings): """Get API key for the given model. Falls back to first available key.""" p = self.get_provider(model) return p.api_key if p else None - + def get_api_base(self, model: str | None = None) -> str | None: """Get API base URL for the given model. Applies default URLs for known gateways.""" from nanobot.providers.registry import find_by_name + p, name = self._match_provider(model) if p and p.api_base: return p.api_base @@ -298,8 +340,5 @@ class Config(BaseSettings): if spec and spec.is_gateway and spec.default_api_base: return spec.default_api_base return None - - model_config = ConfigDict( - env_prefix="NANOBOT_", - env_nested_delimiter="__" - ) + + model_config = ConfigDict(env_prefix="NANOBOT_", env_nested_delimiter="__")