diff --git a/.gitignore b/.gitignore index 36dbfc2..d7b930d 100644 --- a/.gitignore +++ b/.gitignore @@ -14,8 +14,9 @@ docs/ *.pywz *.pyzz .venv/ +venv/ __pycache__/ poetry.lock .pytest_cache/ -tests/ botpy.log +tests/ diff --git a/README.md b/README.md index fed25c8..47702c1 100644 --- a/README.md +++ b/README.md @@ -16,10 +16,12 @@ ⚑️ Delivers core agent functionality in just **~4,000** lines of code β€” **99% smaller** than Clawdbot's 430k+ lines. -πŸ“ Real-time line count: **3,510 lines** (run `bash core_agent_lines.sh` to verify anytime) +πŸ“ Real-time line count: **3,536 lines** (run `bash core_agent_lines.sh` to verify anytime) ## πŸ“’ News +- **2026-02-13** πŸŽ‰ Released v0.1.3.post7 β€” includes security hardening and multiple improvements. All users are recommended to upgrade to the latest version. See [release notes](https://github.com/HKUDS/nanobot/releases/tag/v0.1.3.post7) for more details. +- **2026-02-12** 🧠 Redesigned memory system β€” Less code, more reliable. Join the [discussion](https://github.com/HKUDS/nanobot/discussions/566) about it! - **2026-02-10** πŸŽ‰ Released v0.1.3.post6 with improvements! Check the updates [notes](https://github.com/HKUDS/nanobot/releases/tag/v0.1.3.post6) and our [roadmap](https://github.com/HKUDS/nanobot/discussions/431). - **2026-02-09** πŸ’¬ Added Slack, Email, and QQ support β€” nanobot now supports multiple chat platforms! - **2026-02-08** πŸ”§ Refactored Providersβ€”adding a new LLM provider now takes just 2 simple steps! Check [here](#providers). 
@@ -597,6 +599,7 @@ Config file: `~/.nanobot/config.json` | Provider | Purpose | Get API Key | |----------|---------|-------------| +| `custom` | Any OpenAI-compatible endpoint | β€” | | `openrouter` | LLM (recommended, access to all models) | [openrouter.ai](https://openrouter.ai) | | `anthropic` | LLM (Claude direct) | [console.anthropic.com](https://console.anthropic.com) | | `openai` | LLM (GPT direct) | [platform.openai.com](https://platform.openai.com) | @@ -610,6 +613,31 @@ Config file: `~/.nanobot/config.json` | `zhipu` | LLM (Zhipu GLM) | [open.bigmodel.cn](https://open.bigmodel.cn) | | `vllm` | LLM (local, any OpenAI-compatible server) | β€” | +
+Custom Provider (Any OpenAI-compatible API) + +If your provider is not listed above but exposes an **OpenAI-compatible API** (e.g. Together AI, Fireworks, Azure OpenAI, self-hosted endpoints), use the `custom` provider: + +```json +{ + "providers": { + "custom": { + "apiKey": "your-api-key", + "apiBase": "https://api.your-provider.com/v1" + } + }, + "agents": { + "defaults": { + "model": "your-model-name" + } + } +} +``` + +> The `custom` provider routes through LiteLLM's OpenAI-compatible path. It works with any endpoint that follows the OpenAI chat completions API format. The model name is passed directly to the endpoint without any prefix. + +
+
Adding a New Provider (Developer Guide) diff --git a/SECURITY.md b/SECURITY.md index ac15ba4..af3448c 100644 --- a/SECURITY.md +++ b/SECURITY.md @@ -95,8 +95,8 @@ File operations have path traversal protection, but: - Consider using a firewall to restrict outbound connections if needed **WhatsApp Bridge:** -- The bridge runs on `localhost:3001` by default -- If exposing to network, use proper authentication and TLS +- The bridge binds to `127.0.0.1:3001` (localhost only, not accessible from external network) +- Set `bridgeToken` in config to enable shared-secret authentication between Python and Node.js - Keep authentication data in `~/.nanobot/whatsapp-auth` secure (mode 0700) ### 6. Dependency Security @@ -224,7 +224,7 @@ If you suspect a security breach: βœ… **Secure Communication** - HTTPS for all external API calls - TLS for Telegram API -- WebSocket security for WhatsApp bridge +- WhatsApp bridge: localhost-only binding + optional token auth ## Known Limitations diff --git a/bridge/src/index.ts b/bridge/src/index.ts index 8db63ef..e8f3db9 100644 --- a/bridge/src/index.ts +++ b/bridge/src/index.ts @@ -25,11 +25,12 @@ import { join } from 'path'; const PORT = parseInt(process.env.BRIDGE_PORT || '3001', 10); const AUTH_DIR = process.env.AUTH_DIR || join(homedir(), '.nanobot', 'whatsapp-auth'); +const TOKEN = process.env.BRIDGE_TOKEN || undefined; console.log('🐈 nanobot WhatsApp Bridge'); console.log('========================\n'); -const server = new BridgeServer(PORT, AUTH_DIR); +const server = new BridgeServer(PORT, AUTH_DIR, TOKEN); // Handle graceful shutdown process.on('SIGINT', async () => { diff --git a/bridge/src/server.ts b/bridge/src/server.ts index c6fd599..7d48f5e 100644 --- a/bridge/src/server.ts +++ b/bridge/src/server.ts @@ -1,5 +1,6 @@ /** * WebSocket server for Python-Node.js bridge communication. + * Security: binds to 127.0.0.1 only; optional BRIDGE_TOKEN auth. 
*/ import { WebSocketServer, WebSocket } from 'ws'; @@ -21,12 +22,13 @@ export class BridgeServer { private wa: WhatsAppClient | null = null; private clients: Set = new Set(); - constructor(private port: number, private authDir: string) {} + constructor(private port: number, private authDir: string, private token?: string) {} async start(): Promise { - // Create WebSocket server - this.wss = new WebSocketServer({ port: this.port }); - console.log(`πŸŒ‰ Bridge server listening on ws://localhost:${this.port}`); + // Bind to localhost only β€” never expose to external network + this.wss = new WebSocketServer({ host: '127.0.0.1', port: this.port }); + console.log(`πŸŒ‰ Bridge server listening on ws://127.0.0.1:${this.port}`); + if (this.token) console.log('πŸ”’ Token authentication enabled'); // Initialize WhatsApp client this.wa = new WhatsAppClient({ @@ -38,35 +40,58 @@ export class BridgeServer { // Handle WebSocket connections this.wss.on('connection', (ws) => { - console.log('πŸ”— Python client connected'); - this.clients.add(ws); - - ws.on('message', async (data) => { - try { - const cmd = JSON.parse(data.toString()) as SendCommand; - await this.handleCommand(cmd); - ws.send(JSON.stringify({ type: 'sent', to: cmd.to })); - } catch (error) { - console.error('Error handling command:', error); - ws.send(JSON.stringify({ type: 'error', error: String(error) })); - } - }); - - ws.on('close', () => { - console.log('πŸ”Œ Python client disconnected'); - this.clients.delete(ws); - }); - - ws.on('error', (error) => { - console.error('WebSocket error:', error); - this.clients.delete(ws); - }); + if (this.token) { + // Require auth handshake as first message + const timeout = setTimeout(() => ws.close(4001, 'Auth timeout'), 5000); + ws.once('message', (data) => { + clearTimeout(timeout); + try { + const msg = JSON.parse(data.toString()); + if (msg.type === 'auth' && msg.token === this.token) { + console.log('πŸ”— Python client authenticated'); + this.setupClient(ws); + } else 
{ + ws.close(4003, 'Invalid token'); + } + } catch { + ws.close(4003, 'Invalid auth message'); + } + }); + } else { + console.log('πŸ”— Python client connected'); + this.setupClient(ws); + } }); // Connect to WhatsApp await this.wa.connect(); } + private setupClient(ws: WebSocket): void { + this.clients.add(ws); + + ws.on('message', async (data) => { + try { + const cmd = JSON.parse(data.toString()) as SendCommand; + await this.handleCommand(cmd); + ws.send(JSON.stringify({ type: 'sent', to: cmd.to })); + } catch (error) { + console.error('Error handling command:', error); + ws.send(JSON.stringify({ type: 'error', error: String(error) })); + } + }); + + ws.on('close', () => { + console.log('πŸ”Œ Python client disconnected'); + this.clients.delete(ws); + }); + + ws.on('error', (error) => { + console.error('WebSocket error:', error); + this.clients.delete(ws); + }); + } + private async handleCommand(cmd: SendCommand): Promise { if (cmd.type === 'send' && this.wa) { await this.wa.sendMessage(cmd.to, cmd.text); diff --git a/nanobot/agent/context.py b/nanobot/agent/context.py index d807854..f460f2b 100644 --- a/nanobot/agent/context.py +++ b/nanobot/agent/context.py @@ -73,7 +73,9 @@ Skills with available="false" need dependencies installed first - you can try in def _get_identity(self) -> str: """Get the core identity section.""" from datetime import datetime + import time as _time now = datetime.now().strftime("%Y-%m-%d %H:%M (%A)") + tz = _time.strftime("%Z") or "UTC" workspace_path = str(self.workspace.expanduser().resolve()) system = platform.system() runtime = f"{'macOS' if system == 'Darwin' else system} {platform.machine()}, Python {platform.python_version()}" @@ -88,23 +90,24 @@ You are nanobot, a helpful AI assistant. 
You have access to tools that allow you - Spawn subagents for complex background tasks ## Current Time -{now} +{now} ({tz}) ## Runtime {runtime} ## Workspace Your workspace is at: {workspace_path} -- Memory files: {workspace_path}/memory/MEMORY.md -- Daily notes: {workspace_path}/memory/YYYY-MM-DD.md +- Long-term memory: {workspace_path}/memory/MEMORY.md +- History log: {workspace_path}/memory/HISTORY.md (grep-searchable) - Custom skills: {workspace_path}/skills/{{skill-name}}/SKILL.md IMPORTANT: When responding to direct questions or conversations, reply directly with your text response. Only use the 'message' tool when you need to send a message to a specific chat channel (like WhatsApp). For normal conversation, just respond with text - do not call the message tool. -Always be helpful, accurate, and concise. When using tools, explain what you're doing. -When remembering something, write to {workspace_path}/memory/MEMORY.md""" +Always be helpful, accurate, and concise. When using tools, think step by step: what you know, what you need, and why you chose this tool. +When remembering something important, write to {workspace_path}/memory/MEMORY.md +To recall past events, grep {workspace_path}/memory/HISTORY.md""" def _load_bootstrap_files(self) -> str: """Load all bootstrap files from workspace.""" diff --git a/nanobot/agent/loop.py b/nanobot/agent/loop.py index b15803a..cc7a0d0 100644 --- a/nanobot/agent/loop.py +++ b/nanobot/agent/loop.py @@ -19,14 +19,15 @@ from nanobot.agent.tools.web import WebSearchTool, WebFetchTool from nanobot.agent.tools.message import MessageTool from nanobot.agent.tools.spawn import SpawnTool from nanobot.agent.tools.cron import CronTool +from nanobot.agent.memory import MemoryStore from nanobot.agent.subagent import SubagentManager -from nanobot.session.manager import SessionManager +from nanobot.session.manager import Session, SessionManager class AgentLoop: """ The agent loop is the core processing engine. - + It: 1. 
Receives messages from the bus 2. Builds context with history, memory, skills @@ -34,7 +35,7 @@ class AgentLoop: 4. Executes tool calls 5. Sends responses back """ - + def __init__( self, bus: MessageBus, @@ -42,6 +43,9 @@ class AgentLoop: workspace: Path, model: str | None = None, max_iterations: int = 20, + temperature: float = 0.7, + max_tokens: int = 4096, + memory_window: int = 50, brave_api_key: str | None = None, exec_config: "ExecToolConfig | None" = None, cron_service: "CronService | None" = None, @@ -56,11 +60,14 @@ class AgentLoop: self.workspace = workspace self.model = model or provider.get_default_model() self.max_iterations = max_iterations + self.temperature = temperature + self.max_tokens = max_tokens + self.memory_window = memory_window self.brave_api_key = brave_api_key self.exec_config = exec_config or ExecToolConfig() self.cron_service = cron_service self.restrict_to_workspace = restrict_to_workspace - + self.context = ContextBuilder(workspace) self.sessions = session_manager or SessionManager(workspace) self.tools = ToolRegistry() @@ -69,6 +76,8 @@ class AgentLoop: workspace=workspace, bus=bus, model=self.model, + temperature=self.temperature, + max_tokens=self.max_tokens, brave_api_key=brave_api_key, exec_config=self.exec_config, restrict_to_workspace=restrict_to_workspace, @@ -122,28 +131,96 @@ class AgentLoop: await self._mcp_stack.__aenter__() await connect_mcp_servers(self._mcp_servers, self.tools, self._mcp_stack) + def _set_tool_context(self, channel: str, chat_id: str) -> None: + """Update context for all tools that need routing info.""" + if message_tool := self.tools.get("message"): + if isinstance(message_tool, MessageTool): + message_tool.set_context(channel, chat_id) + + if spawn_tool := self.tools.get("spawn"): + if isinstance(spawn_tool, SpawnTool): + spawn_tool.set_context(channel, chat_id) + + if cron_tool := self.tools.get("cron"): + if isinstance(cron_tool, CronTool): + cron_tool.set_context(channel, chat_id) + + async def 
_run_agent_loop(self, initial_messages: list[dict]) -> tuple[str | None, list[str]]: + """ + Run the agent iteration loop. + + Args: + initial_messages: Starting messages for the LLM conversation. + + Returns: + Tuple of (final_content, list_of_tools_used). + """ + messages = initial_messages + iteration = 0 + final_content = None + tools_used: list[str] = [] + + while iteration < self.max_iterations: + iteration += 1 + + response = await self.provider.chat( + messages=messages, + tools=self.tools.get_definitions(), + model=self.model, + temperature=self.temperature, + max_tokens=self.max_tokens, + ) + + if response.has_tool_calls: + tool_call_dicts = [ + { + "id": tc.id, + "type": "function", + "function": { + "name": tc.name, + "arguments": json.dumps(tc.arguments) + } + } + for tc in response.tool_calls + ] + messages = self.context.add_assistant_message( + messages, response.content, tool_call_dicts, + reasoning_content=response.reasoning_content, + ) + + for tool_call in response.tool_calls: + tools_used.append(tool_call.name) + args_str = json.dumps(tool_call.arguments, ensure_ascii=False) + logger.info(f"Tool call: {tool_call.name}({args_str[:200]})") + result = await self.tools.execute(tool_call.name, tool_call.arguments) + messages = self.context.add_tool_result( + messages, tool_call.id, tool_call.name, result + ) + messages.append({"role": "user", "content": "Reflect on the results and decide next steps."}) + else: + final_content = response.content + break + + return final_content, tools_used + async def run(self) -> None: """Run the agent loop, processing messages from the bus.""" self._running = True await self._connect_mcp() logger.info("Agent loop started") - + while self._running: try: - # Wait for next message msg = await asyncio.wait_for( self.bus.consume_inbound(), timeout=1.0 ) - - # Process it try: response = await self._process_message(msg) if response: await self.bus.publish_outbound(response) except Exception as e: logger.error(f"Error 
processing message: {e}") - # Send error response await self.bus.publish_outbound(OutboundMessage( channel=msg.channel, chat_id=msg.chat_id, @@ -166,105 +243,70 @@ class AgentLoop: self._running = False logger.info("Agent loop stopping") - async def _process_message(self, msg: InboundMessage) -> OutboundMessage | None: + async def _process_message(self, msg: InboundMessage, session_key: str | None = None) -> OutboundMessage | None: """ Process a single inbound message. Args: msg: The inbound message to process. + session_key: Override session key (used by process_direct). Returns: The response message, or None if no response needed. """ - # Handle system messages (subagent announces) - # The chat_id contains the original "channel:chat_id" to route back to + # System messages route back via chat_id ("channel:chat_id") if msg.channel == "system": return await self._process_system_message(msg) preview = msg.content[:80] + "..." if len(msg.content) > 80 else msg.content logger.info(f"Processing message from {msg.channel}:{msg.sender_id}: {preview}") - # Get or create session - session = self.sessions.get_or_create(msg.session_key) + key = session_key or msg.session_key + session = self.sessions.get_or_create(key) - # Update tool contexts - message_tool = self.tools.get("message") - if isinstance(message_tool, MessageTool): - message_tool.set_context(msg.channel, msg.chat_id) + # Handle slash commands + cmd = msg.content.strip().lower() + if cmd == "/new": + # Capture messages before clearing (avoid race condition with background task) + messages_to_archive = session.messages.copy() + session.clear() + self.sessions.save(session) + self.sessions.invalidate(session.key) + + async def _consolidate_and_cleanup(): + temp_session = Session(key=session.key) + temp_session.messages = messages_to_archive + await self._consolidate_memory(temp_session, archive_all=True) + + asyncio.create_task(_consolidate_and_cleanup()) + return OutboundMessage(channel=msg.channel, 
chat_id=msg.chat_id, + content="New session started. Memory consolidation in progress.") + if cmd == "/help": + return OutboundMessage(channel=msg.channel, chat_id=msg.chat_id, + content="🐈 nanobot commands:\n/new β€” Start a new conversation\n/help β€” Show available commands") - spawn_tool = self.tools.get("spawn") - if isinstance(spawn_tool, SpawnTool): - spawn_tool.set_context(msg.channel, msg.chat_id) - - cron_tool = self.tools.get("cron") - if isinstance(cron_tool, CronTool): - cron_tool.set_context(msg.channel, msg.chat_id) - - # Build initial messages (use get_history for LLM-formatted messages) - messages = self.context.build_messages( - history=session.get_history(), + if len(session.messages) > self.memory_window: + asyncio.create_task(self._consolidate_memory(session)) + + self._set_tool_context(msg.channel, msg.chat_id) + initial_messages = self.context.build_messages( + history=session.get_history(max_messages=self.memory_window), current_message=msg.content, media=msg.media if msg.media else None, channel=msg.channel, chat_id=msg.chat_id, ) - - # Agent loop - iteration = 0 - final_content = None - - while iteration < self.max_iterations: - iteration += 1 - - # Call LLM - response = await self.provider.chat( - messages=messages, - tools=self.tools.get_definitions(), - model=self.model - ) - - # Handle tool calls - if response.has_tool_calls: - # Add assistant message with tool calls - tool_call_dicts = [ - { - "id": tc.id, - "type": "function", - "function": { - "name": tc.name, - "arguments": json.dumps(tc.arguments) # Must be JSON string - } - } - for tc in response.tool_calls - ] - messages = self.context.add_assistant_message( - messages, response.content, tool_call_dicts, - reasoning_content=response.reasoning_content, - ) - - # Execute tools - for tool_call in response.tool_calls: - args_str = json.dumps(tool_call.arguments, ensure_ascii=False) - logger.info(f"Tool call: {tool_call.name}({args_str[:200]})") - result = await 
self.tools.execute(tool_call.name, tool_call.arguments) - messages = self.context.add_tool_result( - messages, tool_call.id, tool_call.name, result - ) - else: - # No tool calls, we're done - final_content = response.content - break - + final_content, tools_used = await self._run_agent_loop(initial_messages) + if final_content is None: final_content = "I've completed processing but have no response to give." - # Log response preview preview = final_content[:120] + "..." if len(final_content) > 120 else final_content logger.info(f"Response to {msg.channel}:{msg.sender_id}: {preview}") - # Save to session session.add_message("user", msg.content) - session.add_message("assistant", final_content) + session.add_message("assistant", final_content, + tools_used=tools_used if tools_used else None) self.sessions.save(session) return OutboundMessage( @@ -293,76 +335,20 @@ class AgentLoop: origin_channel = "cli" origin_chat_id = msg.chat_id - # Use the origin session for context session_key = f"{origin_channel}:{origin_chat_id}" session = self.sessions.get_or_create(session_key) - - # Update tool contexts - message_tool = self.tools.get("message") - if isinstance(message_tool, MessageTool): - message_tool.set_context(origin_channel, origin_chat_id) - - spawn_tool = self.tools.get("spawn") - if isinstance(spawn_tool, SpawnTool): - spawn_tool.set_context(origin_channel, origin_chat_id) - - cron_tool = self.tools.get("cron") - if isinstance(cron_tool, CronTool): - cron_tool.set_context(origin_channel, origin_chat_id) - - # Build messages with the announce content - messages = self.context.build_messages( - history=session.get_history(), + self._set_tool_context(origin_channel, origin_chat_id) + initial_messages = self.context.build_messages( + history=session.get_history(max_messages=self.memory_window), current_message=msg.content, channel=origin_channel, chat_id=origin_chat_id, ) - - # Agent loop (limited for announce handling) - iteration = 0 - final_content = None - - while 
iteration < self.max_iterations: - iteration += 1 - - response = await self.provider.chat( - messages=messages, - tools=self.tools.get_definitions(), - model=self.model - ) - - if response.has_tool_calls: - tool_call_dicts = [ - { - "id": tc.id, - "type": "function", - "function": { - "name": tc.name, - "arguments": json.dumps(tc.arguments) - } - } - for tc in response.tool_calls - ] - messages = self.context.add_assistant_message( - messages, response.content, tool_call_dicts, - reasoning_content=response.reasoning_content, - ) - - for tool_call in response.tool_calls: - args_str = json.dumps(tool_call.arguments, ensure_ascii=False) - logger.info(f"Tool call: {tool_call.name}({args_str[:200]})") - result = await self.tools.execute(tool_call.name, tool_call.arguments) - messages = self.context.add_tool_result( - messages, tool_call.id, tool_call.name, result - ) - else: - final_content = response.content - break - + final_content, _ = await self._run_agent_loop(initial_messages) + if final_content is None: final_content = "Background task completed." - # Save to session (mark as system message in history) session.add_message("user", f"[System: {msg.sender_id}] {msg.content}") session.add_message("assistant", final_content) self.sessions.save(session) @@ -373,6 +359,85 @@ class AgentLoop: content=final_content ) + async def _consolidate_memory(self, session, archive_all: bool = False) -> None: + """Consolidate old messages into MEMORY.md + HISTORY.md. + + Args: + archive_all: If True, clear all messages and reset session (for /new command). + If False, only write to files without modifying session. 
+ """ + memory = MemoryStore(self.workspace) + + if archive_all: + old_messages = session.messages + keep_count = 0 + logger.info(f"Memory consolidation (archive_all): {len(session.messages)} total messages archived") + else: + keep_count = self.memory_window // 2 + if len(session.messages) <= keep_count: + logger.debug(f"Session {session.key}: No consolidation needed (messages={len(session.messages)}, keep={keep_count})") + return + + messages_to_process = len(session.messages) - session.last_consolidated + if messages_to_process <= 0: + logger.debug(f"Session {session.key}: No new messages to consolidate (last_consolidated={session.last_consolidated}, total={len(session.messages)})") + return + + old_messages = session.messages[session.last_consolidated:-keep_count] + if not old_messages: + return + logger.info(f"Memory consolidation started: {len(session.messages)} total, {len(old_messages)} new to consolidate, {keep_count} keep") + + lines = [] + for m in old_messages: + if not m.get("content"): + continue + tools = f" [tools: {', '.join(m['tools_used'])}]" if m.get("tools_used") else "" + lines.append(f"[{m.get('timestamp', '?')[:16]}] {m['role'].upper()}{tools}: {m['content']}") + conversation = "\n".join(lines) + current_memory = memory.read_long_term() + + prompt = f"""You are a memory consolidation agent. Process this conversation and return a JSON object with exactly two keys: + +1. "history_entry": A paragraph (2-5 sentences) summarizing the key events/decisions/topics. Start with a timestamp like [YYYY-MM-DD HH:MM]. Include enough detail to be useful when found by grep search later. + +2. "memory_update": The updated long-term memory content. Add any new facts: user location, preferences, personal info, habits, project context, technical decisions, tools/services used. If nothing new, return the existing content unchanged. 
+ +## Current Long-term Memory +{current_memory or "(empty)"} + +## Conversation to Process +{conversation} + +Respond with ONLY valid JSON, no markdown fences.""" + + try: + response = await self.provider.chat( + messages=[ + {"role": "system", "content": "You are a memory consolidation agent. Respond only with valid JSON."}, + {"role": "user", "content": prompt}, + ], + model=self.model, + ) + text = (response.content or "").strip() + if text.startswith("```"): + text = text.split("\n", 1)[-1].rsplit("```", 1)[0].strip() + result = json.loads(text) + + if entry := result.get("history_entry"): + memory.append_history(entry) + if update := result.get("memory_update"): + if update != current_memory: + memory.write_long_term(update) + + if archive_all: + session.last_consolidated = 0 + else: + session.last_consolidated = len(session.messages) - keep_count + logger.info(f"Memory consolidation done: {len(session.messages)} messages, last_consolidated={session.last_consolidated}") + except Exception as e: + logger.error(f"Memory consolidation failed: {e}") + async def process_direct( self, content: str, @@ -385,9 +450,9 @@ class AgentLoop: Args: content: The message content. - session_key: Session identifier. - channel: Source channel (for context). - chat_id: Source chat ID (for context). + session_key: Session identifier (overrides channel:chat_id for session lookup). + channel: Source channel (for tool context routing). + chat_id: Source chat ID (for tool context routing). Returns: The agent's response. 
@@ -400,5 +465,5 @@ class AgentLoop: content=content ) - response = await self._process_message(msg) + response = await self._process_message(msg, session_key=session_key) return response.content if response else "" diff --git a/nanobot/agent/memory.py b/nanobot/agent/memory.py index 453407e..29477c4 100644 --- a/nanobot/agent/memory.py +++ b/nanobot/agent/memory.py @@ -1,109 +1,30 @@ """Memory system for persistent agent memory.""" from pathlib import Path -from datetime import datetime -from nanobot.utils.helpers import ensure_dir, today_date +from nanobot.utils.helpers import ensure_dir class MemoryStore: - """ - Memory system for the agent. - - Supports daily notes (memory/YYYY-MM-DD.md) and long-term memory (MEMORY.md). - """ - + """Two-layer memory: MEMORY.md (long-term facts) + HISTORY.md (grep-searchable log).""" + def __init__(self, workspace: Path): - self.workspace = workspace self.memory_dir = ensure_dir(workspace / "memory") self.memory_file = self.memory_dir / "MEMORY.md" - - def get_today_file(self) -> Path: - """Get path to today's memory file.""" - return self.memory_dir / f"{today_date()}.md" - - def read_today(self) -> str: - """Read today's memory notes.""" - today_file = self.get_today_file() - if today_file.exists(): - return today_file.read_text(encoding="utf-8") - return "" - - def append_today(self, content: str) -> None: - """Append content to today's memory notes.""" - today_file = self.get_today_file() - - if today_file.exists(): - existing = today_file.read_text(encoding="utf-8") - content = existing + "\n" + content - else: - # Add header for new day - header = f"# {today_date()}\n\n" - content = header + content - - today_file.write_text(content, encoding="utf-8") - + self.history_file = self.memory_dir / "HISTORY.md" + def read_long_term(self) -> str: - """Read long-term memory (MEMORY.md).""" if self.memory_file.exists(): return self.memory_file.read_text(encoding="utf-8") return "" - + def write_long_term(self, content: str) -> 
None: - """Write to long-term memory (MEMORY.md).""" self.memory_file.write_text(content, encoding="utf-8") - - def get_recent_memories(self, days: int = 7) -> str: - """ - Get memories from the last N days. - - Args: - days: Number of days to look back. - - Returns: - Combined memory content. - """ - from datetime import timedelta - - memories = [] - today = datetime.now().date() - - for i in range(days): - date = today - timedelta(days=i) - date_str = date.strftime("%Y-%m-%d") - file_path = self.memory_dir / f"{date_str}.md" - - if file_path.exists(): - content = file_path.read_text(encoding="utf-8") - memories.append(content) - - return "\n\n---\n\n".join(memories) - - def list_memory_files(self) -> list[Path]: - """List all memory files sorted by date (newest first).""" - if not self.memory_dir.exists(): - return [] - - files = list(self.memory_dir.glob("????-??-??.md")) - return sorted(files, reverse=True) - + + def append_history(self, entry: str) -> None: + with open(self.history_file, "a", encoding="utf-8") as f: + f.write(entry.rstrip() + "\n\n") + def get_memory_context(self) -> str: - """ - Get memory context for the agent. - - Returns: - Formatted memory context including long-term and recent memories. 
- """ - parts = [] - - # Long-term memory long_term = self.read_long_term() - if long_term: - parts.append("## Long-term Memory\n" + long_term) - - # Today's notes - today = self.read_today() - if today: - parts.append("## Today's Notes\n" + today) - - return "\n\n".join(parts) if parts else "" + return f"## Long-term Memory\n{long_term}" if long_term else "" diff --git a/nanobot/agent/subagent.py b/nanobot/agent/subagent.py index 6113efb..203836a 100644 --- a/nanobot/agent/subagent.py +++ b/nanobot/agent/subagent.py @@ -12,7 +12,7 @@ from nanobot.bus.events import InboundMessage from nanobot.bus.queue import MessageBus from nanobot.providers.base import LLMProvider from nanobot.agent.tools.registry import ToolRegistry -from nanobot.agent.tools.filesystem import ReadFileTool, WriteFileTool, ListDirTool +from nanobot.agent.tools.filesystem import ReadFileTool, WriteFileTool, EditFileTool, ListDirTool from nanobot.agent.tools.shell import ExecTool from nanobot.agent.tools.web import WebSearchTool, WebFetchTool @@ -32,6 +32,8 @@ class SubagentManager: workspace: Path, bus: MessageBus, model: str | None = None, + temperature: float = 0.7, + max_tokens: int = 4096, brave_api_key: str | None = None, exec_config: "ExecToolConfig | None" = None, restrict_to_workspace: bool = False, @@ -41,6 +43,8 @@ class SubagentManager: self.workspace = workspace self.bus = bus self.model = model or provider.get_default_model() + self.temperature = temperature + self.max_tokens = max_tokens self.brave_api_key = brave_api_key self.exec_config = exec_config or ExecToolConfig() self.restrict_to_workspace = restrict_to_workspace @@ -101,6 +105,7 @@ class SubagentManager: allowed_dir = self.workspace if self.restrict_to_workspace else None tools.register(ReadFileTool(allowed_dir=allowed_dir)) tools.register(WriteFileTool(allowed_dir=allowed_dir)) + tools.register(EditFileTool(allowed_dir=allowed_dir)) tools.register(ListDirTool(allowed_dir=allowed_dir)) tools.register(ExecTool( 
working_dir=str(self.workspace), @@ -129,6 +134,8 @@ class SubagentManager: messages=messages, tools=tools.get_definitions(), model=self.model, + temperature=self.temperature, + max_tokens=self.max_tokens, ) if response.has_tool_calls: @@ -210,12 +217,17 @@ Summarize this naturally for the user. Keep it brief (1-2 sentences). Do not men def _build_subagent_prompt(self, task: str) -> str: """Build a focused system prompt for the subagent.""" + from datetime import datetime + import time as _time + now = datetime.now().strftime("%Y-%m-%d %H:%M (%A)") + tz = _time.strftime("%Z") or "UTC" + return f"""# Subagent -You are a subagent spawned by the main agent to complete a specific task. +## Current Time +{now} ({tz}) -## Your Task -{task} +You are a subagent spawned by the main agent to complete a specific task. ## Rules 1. Stay focused - complete only the assigned task, nothing else @@ -236,6 +248,7 @@ You are a subagent spawned by the main agent to complete a specific task. ## Workspace Your workspace is at: {self.workspace} +Skills are available at: {self.workspace}/skills/ (read SKILL.md files as needed) When you have completed the task, provide a clear summary of your findings or actions.""" diff --git a/nanobot/agent/tools/cron.py b/nanobot/agent/tools/cron.py index ec0d2cd..9f1ecdb 100644 --- a/nanobot/agent/tools/cron.py +++ b/nanobot/agent/tools/cron.py @@ -50,6 +50,10 @@ class CronTool(Tool): "type": "string", "description": "Cron expression like '0 9 * * *' (for scheduled tasks)" }, + "at": { + "type": "string", + "description": "ISO datetime for one-time execution (e.g. 
'2026-02-12T10:30:00')" + }, "job_id": { "type": "string", "description": "Job ID (for remove)" @@ -64,30 +68,38 @@ class CronTool(Tool): message: str = "", every_seconds: int | None = None, cron_expr: str | None = None, + at: str | None = None, job_id: str | None = None, **kwargs: Any ) -> str: if action == "add": - return self._add_job(message, every_seconds, cron_expr) + return self._add_job(message, every_seconds, cron_expr, at) elif action == "list": return self._list_jobs() elif action == "remove": return self._remove_job(job_id) return f"Unknown action: {action}" - def _add_job(self, message: str, every_seconds: int | None, cron_expr: str | None) -> str: + def _add_job(self, message: str, every_seconds: int | None, cron_expr: str | None, at: str | None) -> str: if not message: return "Error: message is required for add" if not self._channel or not self._chat_id: return "Error: no session context (channel/chat_id)" # Build schedule + delete_after = False if every_seconds: schedule = CronSchedule(kind="every", every_ms=every_seconds * 1000) elif cron_expr: schedule = CronSchedule(kind="cron", expr=cron_expr) + elif at: + from datetime import datetime + dt = datetime.fromisoformat(at) + at_ms = int(dt.timestamp() * 1000) + schedule = CronSchedule(kind="at", at_ms=at_ms) + delete_after = True else: - return "Error: either every_seconds or cron_expr is required" + return "Error: either every_seconds, cron_expr, or at is required" job = self._cron.add_job( name=message[:30], @@ -96,6 +108,7 @@ class CronTool(Tool): deliver=True, channel=self._channel, to=self._chat_id, + delete_after_run=delete_after, ) return f"Created job '{job.name}' (id: {job.id})" diff --git a/nanobot/channels/dingtalk.py b/nanobot/channels/dingtalk.py index 72d3afd..4a8cdd9 100644 --- a/nanobot/channels/dingtalk.py +++ b/nanobot/channels/dingtalk.py @@ -137,8 +137,15 @@ class DingTalkChannel(BaseChannel): logger.info("DingTalk bot started with Stream Mode") - # client.start() is an async 
infinite loop handling the websocket connection - await self._client.start() + # Reconnect loop: restart stream if SDK exits or crashes + while self._running: + try: + await self._client.start() + except Exception as e: + logger.warning(f"DingTalk stream error: {e}") + if self._running: + logger.info("Reconnecting DingTalk stream in 5 seconds...") + await asyncio.sleep(5) except Exception as e: logger.exception(f"Failed to start DingTalk channel: {e}") diff --git a/nanobot/channels/feishu.py b/nanobot/channels/feishu.py index 1c176a2..bc4a2b8 100644 --- a/nanobot/channels/feishu.py +++ b/nanobot/channels/feishu.py @@ -39,6 +39,53 @@ MSG_TYPE_MAP = { } +def _extract_post_text(content_json: dict) -> str: + """Extract plain text from Feishu post (rich text) message content. + + Supports two formats: + 1. Direct format: {"title": "...", "content": [...]} + 2. Localized format: {"zh_cn": {"title": "...", "content": [...]}} + """ + def extract_from_lang(lang_content: dict) -> str | None: + if not isinstance(lang_content, dict): + return None + title = lang_content.get("title", "") + content_blocks = lang_content.get("content", []) + if not isinstance(content_blocks, list): + return None + text_parts = [] + if title: + text_parts.append(title) + for block in content_blocks: + if not isinstance(block, list): + continue + for element in block: + if isinstance(element, dict): + tag = element.get("tag") + if tag == "text": + text_parts.append(element.get("text", "")) + elif tag == "a": + text_parts.append(element.get("text", "")) + elif tag == "at": + text_parts.append(f"@{element.get('user_name', 'user')}") + return " ".join(text_parts).strip() if text_parts else None + + # Try direct format first + if "content" in content_json: + result = extract_from_lang(content_json) + if result: + return result + + # Try localized format + for lang_key in ("zh_cn", "en_us", "ja_jp"): + lang_content = content_json.get(lang_key) + result = extract_from_lang(lang_content) + if result: + 
return result + + return "" + + class FeishuChannel(BaseChannel): """ Feishu/Lark channel using WebSocket long connection. @@ -98,12 +145,15 @@ class FeishuChannel(BaseChannel): log_level=lark.LogLevel.INFO ) - # Start WebSocket client in a separate thread + # Start WebSocket client in a separate thread with reconnect loop def run_ws(): - try: - self._ws_client.start() - except Exception as e: - logger.error(f"Feishu WebSocket error: {e}") + while self._running: + try: + self._ws_client.start() + except Exception as e: + logger.warning(f"Feishu WebSocket error: {e}") + if self._running: + import time; time.sleep(5) self._ws_thread = threading.Thread(target=run_ws, daemon=True) self._ws_thread.start() @@ -163,6 +213,10 @@ class FeishuChannel(BaseChannel): re.MULTILINE, ) + _HEADING_RE = re.compile(r"^(#{1,6})\s+(.+)$", re.MULTILINE) + + _CODE_BLOCK_RE = re.compile(r"(```[\s\S]*?```)", re.MULTILINE) + @staticmethod def _parse_md_table(table_text: str) -> dict | None: """Parse a markdown table into a Feishu table element.""" @@ -182,17 +236,52 @@ class FeishuChannel(BaseChannel): } def _build_card_elements(self, content: str) -> list[dict]: - """Split content into markdown + table elements for Feishu card.""" + """Split content into div/markdown + table elements for Feishu card.""" elements, last_end = [], 0 for m in self._TABLE_RE.finditer(content): - before = content[last_end:m.start()].strip() - if before: - elements.append({"tag": "markdown", "content": before}) + before = content[last_end:m.start()] + if before.strip(): + elements.extend(self._split_headings(before)) elements.append(self._parse_md_table(m.group(1)) or {"tag": "markdown", "content": m.group(1)}) last_end = m.end() - remaining = content[last_end:].strip() + remaining = content[last_end:] + if remaining.strip(): + elements.extend(self._split_headings(remaining)) + return elements or [{"tag": "markdown", "content": content}] + + def _split_headings(self, content: str) -> list[dict]: + """Split 
content by headings, converting headings to div elements.""" + protected = content + code_blocks = [] + for m in self._CODE_BLOCK_RE.finditer(content): + code_blocks.append(m.group(1)) + protected = protected.replace(m.group(1), f"\x00CODE{len(code_blocks)-1}\x00", 1) + + elements = [] + last_end = 0 + for m in self._HEADING_RE.finditer(protected): + before = protected[last_end:m.start()].strip() + if before: + elements.append({"tag": "markdown", "content": before}) + level = len(m.group(1)) + text = m.group(2).strip() + elements.append({ + "tag": "div", + "text": { + "tag": "lark_md", + "content": f"**{text}**", + }, + }) + last_end = m.end() + remaining = protected[last_end:].strip() if remaining: elements.append({"tag": "markdown", "content": remaining}) + + for i, cb in enumerate(code_blocks): + for el in elements: + if el.get("tag") == "markdown": + el["content"] = el["content"].replace(f"\x00CODE{i}\x00", cb) + return elements or [{"tag": "markdown", "content": content}] async def send(self, msg: OutboundMessage) -> None: @@ -284,6 +373,12 @@ class FeishuChannel(BaseChannel): content = json.loads(message.content).get("text", "") except json.JSONDecodeError: content = message.content or "" + elif msg_type == "post": + try: + content_json = json.loads(message.content) + content = _extract_post_text(content_json) + except (json.JSONDecodeError, TypeError): + content = message.content or "" else: content = MSG_TYPE_MAP.get(msg_type, f"[{msg_type}]") diff --git a/nanobot/channels/manager.py b/nanobot/channels/manager.py index 464fa97..e860d26 100644 --- a/nanobot/channels/manager.py +++ b/nanobot/channels/manager.py @@ -3,7 +3,7 @@ from __future__ import annotations import asyncio -from typing import Any, TYPE_CHECKING +from typing import Any from loguru import logger @@ -12,9 +12,6 @@ from nanobot.bus.queue import MessageBus from nanobot.channels.base import BaseChannel from nanobot.config.schema import Config -if TYPE_CHECKING: - from nanobot.session.manager 
import SessionManager - class ChannelManager: """ @@ -26,10 +23,9 @@ class ChannelManager: - Route outbound messages """ - def __init__(self, config: Config, bus: MessageBus, session_manager: "SessionManager | None" = None): + def __init__(self, config: Config, bus: MessageBus): self.config = config self.bus = bus - self.session_manager = session_manager self.channels: dict[str, BaseChannel] = {} self._dispatch_task: asyncio.Task | None = None @@ -46,7 +42,6 @@ class ChannelManager: self.config.channels.telegram, self.bus, groq_api_key=self.config.providers.groq.api_key, - session_manager=self.session_manager, ) logger.info("Telegram channel enabled") except ImportError as e: diff --git a/nanobot/channels/qq.py b/nanobot/channels/qq.py index 5964d30..0e8fe66 100644 --- a/nanobot/channels/qq.py +++ b/nanobot/channels/qq.py @@ -75,12 +75,15 @@ class QQChannel(BaseChannel): logger.info("QQ bot started (C2C private message)") async def _run_bot(self) -> None: - """Run the bot connection.""" - try: - await self._client.start(appid=self.config.app_id, secret=self.config.secret) - except Exception as e: - logger.error(f"QQ auth failed, check AppID/Secret at q.qq.com: {e}") - self._running = False + """Run the bot connection with auto-reconnect.""" + while self._running: + try: + await self._client.start(appid=self.config.app_id, secret=self.config.secret) + except Exception as e: + logger.warning(f"QQ bot error: {e}") + if self._running: + logger.info("Reconnecting QQ bot in 5 seconds...") + await asyncio.sleep(5) async def stop(self) -> None: """Stop the QQ bot.""" diff --git a/nanobot/channels/telegram.py b/nanobot/channels/telegram.py index ff46c86..32f8c67 100644 --- a/nanobot/channels/telegram.py +++ b/nanobot/channels/telegram.py @@ -4,20 +4,16 @@ from __future__ import annotations import asyncio import re -from typing import TYPE_CHECKING - from loguru import logger from telegram import BotCommand, Update from telegram.ext import Application, CommandHandler, 
MessageHandler, filters, ContextTypes +from telegram.request import HTTPXRequest from nanobot.bus.events import OutboundMessage from nanobot.bus.queue import MessageBus from nanobot.channels.base import BaseChannel from nanobot.config.schema import TelegramConfig -if TYPE_CHECKING: - from nanobot.session.manager import SessionManager - def _markdown_to_telegram_html(text: str) -> str: """ @@ -94,7 +90,7 @@ class TelegramChannel(BaseChannel): # Commands registered with Telegram's command menu BOT_COMMANDS = [ BotCommand("start", "Start the bot"), - BotCommand("reset", "Reset conversation history"), + BotCommand("new", "Start a new conversation"), BotCommand("help", "Show available commands"), ] @@ -103,12 +99,10 @@ class TelegramChannel(BaseChannel): config: TelegramConfig, bus: MessageBus, groq_api_key: str = "", - session_manager: SessionManager | None = None, ): super().__init__(config, bus) self.config: TelegramConfig = config self.groq_api_key = groq_api_key - self.session_manager = session_manager self._app: Application | None = None self._chat_ids: dict[str, int] = {} # Map sender_id to chat_id for replies self._typing_tasks: dict[str, asyncio.Task] = {} # chat_id -> typing loop task @@ -121,16 +115,18 @@ class TelegramChannel(BaseChannel): self._running = True - # Build the application - builder = Application.builder().token(self.config.token) + # Build the application with larger connection pool to avoid pool-timeout on long runs + req = HTTPXRequest(connection_pool_size=16, pool_timeout=5.0, connect_timeout=30.0, read_timeout=30.0) + builder = Application.builder().token(self.config.token).request(req).get_updates_request(req) if self.config.proxy: builder = builder.proxy(self.config.proxy).get_updates_proxy(self.config.proxy) self._app = builder.build() + self._app.add_error_handler(self._on_error) # Add command handlers self._app.add_handler(CommandHandler("start", self._on_start)) - self._app.add_handler(CommandHandler("reset", self._on_reset)) - 
self._app.add_handler(CommandHandler("help", self._on_help)) + self._app.add_handler(CommandHandler("new", self._forward_command)) + self._app.add_handler(CommandHandler("help", self._forward_command)) # Add message handler for text, photos, voice, documents self._app.add_handler( @@ -226,40 +222,15 @@ class TelegramChannel(BaseChannel): "Type /help to see available commands." ) - async def _on_reset(self, update: Update, context: ContextTypes.DEFAULT_TYPE) -> None: - """Handle /reset command β€” clear conversation history.""" + async def _forward_command(self, update: Update, context: ContextTypes.DEFAULT_TYPE) -> None: + """Forward slash commands to the bus for unified handling in AgentLoop.""" if not update.message or not update.effective_user: return - - chat_id = str(update.message.chat_id) - session_key = f"{self.name}:{chat_id}" - - if self.session_manager is None: - logger.warning("/reset called but session_manager is not available") - await update.message.reply_text("⚠️ Session management is not available.") - return - - session = self.session_manager.get_or_create(session_key) - msg_count = len(session.messages) - session.clear() - self.session_manager.save(session) - - logger.info(f"Session reset for {session_key} (cleared {msg_count} messages)") - await update.message.reply_text("πŸ”„ Conversation history cleared. Let's start fresh!") - - async def _on_help(self, update: Update, context: ContextTypes.DEFAULT_TYPE) -> None: - """Handle /help command β€” show available commands.""" - if not update.message: - return - - help_text = ( - "🐈 nanobot commands\n\n" - "/start β€” Start the bot\n" - "/reset β€” Reset conversation history\n" - "/help β€” Show this help message\n\n" - "Just send me a text message to chat!" 
+ await self._handle_message( + sender_id=str(update.effective_user.id), + chat_id=str(update.message.chat_id), + content=update.message.text, ) - await update.message.reply_text(help_text, parse_mode="HTML") async def _on_message(self, update: Update, context: ContextTypes.DEFAULT_TYPE) -> None: """Handle incoming messages (text, photos, voice, documents).""" @@ -386,6 +357,10 @@ class TelegramChannel(BaseChannel): except Exception as e: logger.debug(f"Typing indicator stopped for {chat_id}: {e}") + async def _on_error(self, update: object, context: ContextTypes.DEFAULT_TYPE) -> None: + """Log polling / handler errors instead of silently swallowing them.""" + logger.error(f"Telegram error: {context.error}") + def _get_extension(self, media_type: str, mime_type: str | None) -> str: """Get file extension based on media type.""" if mime_type: diff --git a/nanobot/channels/whatsapp.py b/nanobot/channels/whatsapp.py index 6e00e9d..0cf2dd7 100644 --- a/nanobot/channels/whatsapp.py +++ b/nanobot/channels/whatsapp.py @@ -42,6 +42,9 @@ class WhatsAppChannel(BaseChannel): try: async with websockets.connect(bridge_url) as ws: self._ws = ws + # Send auth token if configured + if self.config.bridge_token: + await ws.send(json.dumps({"type": "auth", "token": self.config.bridge_token})) self._connected = True logger.info("Connected to WhatsApp bridge") diff --git a/nanobot/cli/commands.py b/nanobot/cli/commands.py index cab4d41..34bfde8 100644 --- a/nanobot/cli/commands.py +++ b/nanobot/cli/commands.py @@ -155,7 +155,7 @@ def main( @app.command() def onboard(): """Initialize nanobot configuration and workspace.""" - from nanobot.config.loader import get_config_path, save_config + from nanobot.config.loader import get_config_path, load_config, save_config from nanobot.config.schema import Config from nanobot.utils.helpers import get_workspace_path @@ -163,17 +163,26 @@ def onboard(): if config_path.exists(): console.print(f"[yellow]Config already exists at 
{config_path}[/yellow]") - if not typer.confirm("Overwrite?"): - raise typer.Exit() - - # Create default config - config = Config() - save_config(config) - console.print(f"[green]βœ“[/green] Created config at {config_path}") + console.print(" [bold]y[/bold] = overwrite with defaults (existing values will be lost)") + console.print(" [bold]N[/bold] = refresh config, keeping existing values and adding new fields") + if typer.confirm("Overwrite?"): + config = Config() + save_config(config) + console.print(f"[green]βœ“[/green] Config reset to defaults at {config_path}") + else: + config = load_config() + save_config(config) + console.print(f"[green]βœ“[/green] Config refreshed at {config_path} (existing values preserved)") + else: + save_config(Config()) + console.print(f"[green]βœ“[/green] Created config at {config_path}") # Create workspace workspace = get_workspace_path() - console.print(f"[green]βœ“[/green] Created workspace at {workspace}") + + if not workspace.exists(): + workspace.mkdir(parents=True, exist_ok=True) + console.print(f"[green]βœ“[/green] Created workspace at {workspace}") # Create default bootstrap files _create_workspace_templates(workspace) @@ -200,7 +209,7 @@ You are a helpful AI assistant. Be concise, accurate, and friendly. - Always explain what you're doing before taking actions - Ask for clarification when the request is ambiguous - Use tools to help accomplish tasks -- Remember important information in your memory files +- Remember important information in memory/MEMORY.md; past events are logged in memory/HISTORY.md """, "SOUL.md": """# Soul @@ -258,6 +267,11 @@ This file stores important information that should persist across sessions. 
(Things to remember) """) console.print(" [dim]Created memory/MEMORY.md[/dim]") + + history_file = memory_dir / "HISTORY.md" + if not history_file.exists(): + history_file.write_text("") + console.print(" [dim]Created memory/HISTORY.md[/dim]") # Create skills directory for custom user skills skills_dir = workspace / "skills" @@ -323,7 +337,10 @@ def gateway( provider=provider, workspace=config.workspace_path, model=config.agents.defaults.model, + temperature=config.agents.defaults.temperature, + max_tokens=config.agents.defaults.max_tokens, max_iterations=config.agents.defaults.max_tool_iterations, + memory_window=config.agents.defaults.memory_window, brave_api_key=config.tools.web.search.api_key or None, exec_config=config.tools.exec, cron_service=cron, @@ -364,7 +381,7 @@ def gateway( ) # Create channel manager - channels = ChannelManager(config, bus, session_manager=session_manager) + channels = ChannelManager(config, bus) if channels.enabled_channels: console.print(f"[green]βœ“[/green] Channels enabled: {', '.join(channels.enabled_channels)}") @@ -407,7 +424,7 @@ def gateway( @app.command() def agent( message: str = typer.Option(None, "--message", "-m", help="Message to send to the agent"), - session_id: str = typer.Option("cli:default", "--session", "-s", help="Session ID"), + session_id: str = typer.Option("cli:direct", "--session", "-s", help="Session ID"), markdown: bool = typer.Option(True, "--markdown/--no-markdown", help="Render assistant output as Markdown"), logs: bool = typer.Option(False, "--logs/--no-logs", help="Show nanobot runtime logs during chat"), ): @@ -431,6 +448,11 @@ def agent( bus=bus, provider=provider, workspace=config.workspace_path, + model=config.agents.defaults.model, + temperature=config.agents.defaults.temperature, + max_tokens=config.agents.defaults.max_tokens, + max_iterations=config.agents.defaults.max_tool_iterations, + memory_window=config.agents.defaults.memory_window, brave_api_key=config.tools.web.search.api_key or None, 
exec_config=config.tools.exec, restrict_to_workspace=config.tools.restrict_to_workspace, @@ -635,14 +657,20 @@ def _get_bridge_dir() -> Path: def channels_login(): """Link device via QR code.""" import subprocess + from nanobot.config.loader import load_config + config = load_config() bridge_dir = _get_bridge_dir() console.print(f"{__logo__} Starting bridge...") console.print("Scan the QR code to connect.\n") + env = {**os.environ} + if config.channels.whatsapp.bridge_token: + env["BRIDGE_TOKEN"] = config.channels.whatsapp.bridge_token + try: - subprocess.run(["npm", "start"], cwd=bridge_dir, check=True) + subprocess.run(["npm", "start"], cwd=bridge_dir, check=True, env=env) except subprocess.CalledProcessError as e: console.print(f"[red]Bridge failed: {e}[/red]") except FileNotFoundError: diff --git a/nanobot/config/schema.py b/nanobot/config/schema.py index 2a206e1..0934aac 100644 --- a/nanobot/config/schema.py +++ b/nanobot/config/schema.py @@ -9,6 +9,7 @@ class WhatsAppConfig(BaseModel): """WhatsApp channel configuration.""" enabled: bool = False bridge_url: str = "ws://localhost:3001" + bridge_token: str = "" # Shared token for bridge auth (optional, recommended) allow_from: list[str] = Field(default_factory=list) # Allowed phone numbers @@ -161,6 +162,7 @@ class AgentDefaults(BaseModel): max_tokens: int = 8192 temperature: float = 0.7 max_tool_iterations: int = 20 + memory_window: int = 50 class AgentsConfig(BaseModel): @@ -177,6 +179,7 @@ class ProviderConfig(BaseModel): class ProvidersConfig(BaseModel): """Configuration for LLM providers.""" + custom: ProviderConfig = Field(default_factory=ProviderConfig) # Any OpenAI-compatible endpoint anthropic: ProviderConfig = Field(default_factory=ProviderConfig) openai: ProviderConfig = Field(default_factory=ProviderConfig) openrouter: ProviderConfig = Field(default_factory=ProviderConfig) diff --git a/nanobot/cron/service.py b/nanobot/cron/service.py index d1965a9..4da845a 100644 --- a/nanobot/cron/service.py +++ 
b/nanobot/cron/service.py @@ -4,6 +4,7 @@ import asyncio import json import time import uuid +from datetime import datetime from pathlib import Path from typing import Any, Callable, Coroutine @@ -30,9 +31,13 @@ def _compute_next_run(schedule: CronSchedule, now_ms: int) -> int | None: if schedule.kind == "cron" and schedule.expr: try: from croniter import croniter - cron = croniter(schedule.expr, time.time()) - next_time = cron.get_next() - return int(next_time * 1000) + from zoneinfo import ZoneInfo + base_time = time.time() + tz = ZoneInfo(schedule.tz) if schedule.tz else datetime.now().astimezone().tzinfo + base_dt = datetime.fromtimestamp(base_time, tz=tz) + cron = croniter(schedule.expr, base_dt) + next_dt = cron.get_next(datetime) + return int(next_dt.timestamp() * 1000) except Exception: return None diff --git a/nanobot/providers/litellm_provider.py b/nanobot/providers/litellm_provider.py index 7865139..a39893b 100644 --- a/nanobot/providers/litellm_provider.py +++ b/nanobot/providers/litellm_provider.py @@ -122,6 +122,10 @@ class LiteLLMProvider(LLMProvider): """ model = self._resolve_model(model or self.default_model) + # Clamp max_tokens to at least 1 β€” negative or zero values cause + # LiteLLM to reject the request with "max_tokens must be at least 1". + max_tokens = max(1, max_tokens) + kwargs: dict[str, Any] = { "model": model, "messages": messages, diff --git a/nanobot/providers/registry.py b/nanobot/providers/registry.py index fdd036e..b9071a0 100644 --- a/nanobot/providers/registry.py +++ b/nanobot/providers/registry.py @@ -62,6 +62,20 @@ class ProviderSpec: PROVIDERS: tuple[ProviderSpec, ...] = ( + # === Custom (user-provided OpenAI-compatible endpoint) ================= + # No auto-detection β€” only activates when user explicitly configures "custom". 
+ + ProviderSpec( + name="custom", + keywords=(), + env_key="OPENAI_API_KEY", + display_name="Custom", + litellm_prefix="openai", + skip_prefixes=("openai/",), + is_gateway=True, + strip_model_prefix=True, + ), + # === Gateways (detected by api_key / api_base, not model name) ========= # Gateways can route any model, so they win in fallback. diff --git a/nanobot/session/manager.py b/nanobot/session/manager.py index cd25019..bce12a1 100644 --- a/nanobot/session/manager.py +++ b/nanobot/session/manager.py @@ -15,15 +15,20 @@ from nanobot.utils.helpers import ensure_dir, safe_filename class Session: """ A conversation session. - + Stores messages in JSONL format for easy reading and persistence. + + Important: Messages are append-only for LLM cache efficiency. + The consolidation process writes summaries to MEMORY.md/HISTORY.md + but does NOT modify the messages list or get_history() output. """ - + key: str # channel:chat_id messages: list[dict[str, Any]] = field(default_factory=list) created_at: datetime = field(default_factory=datetime.now) updated_at: datetime = field(default_factory=datetime.now) metadata: dict[str, Any] = field(default_factory=dict) + last_consolidated: int = 0 # Number of messages already consolidated to files def add_message(self, role: str, content: str, **kwargs: Any) -> None: """Add a message to the session.""" @@ -36,35 +41,24 @@ class Session: self.messages.append(msg) self.updated_at = datetime.now() - def get_history(self, max_messages: int = 50) -> list[dict[str, Any]]: - """ - Get message history for LLM context. - - Args: - max_messages: Maximum messages to return. - - Returns: - List of messages in LLM format. 
- """ - # Get recent messages - recent = self.messages[-max_messages:] if len(self.messages) > max_messages else self.messages - - # Convert to LLM format (just role and content) - return [{"role": m["role"], "content": m["content"]} for m in recent] + def get_history(self, max_messages: int = 500) -> list[dict[str, Any]]: + """Get recent messages in LLM format (role + content only).""" + return [{"role": m["role"], "content": m["content"]} for m in self.messages[-max_messages:]] def clear(self) -> None: - """Clear all messages in the session.""" + """Clear all messages and reset session to initial state.""" self.messages = [] + self.last_consolidated = 0 self.updated_at = datetime.now() class SessionManager: """ Manages conversation sessions. - + Sessions are stored as JSONL files in the sessions directory. """ - + def __init__(self, workspace: Path): self.workspace = workspace self.sessions_dir = ensure_dir(Path.home() / ".nanobot" / "sessions") @@ -85,11 +79,9 @@ class SessionManager: Returns: The session. 
""" - # Check cache if key in self._cache: return self._cache[key] - # Try to load from disk session = self._load(key) if session is None: session = Session(key=key) @@ -100,34 +92,37 @@ class SessionManager: def _load(self, key: str) -> Session | None: """Load a session from disk.""" path = self._get_session_path(key) - + if not path.exists(): return None - + try: messages = [] metadata = {} created_at = None - + last_consolidated = 0 + with open(path) as f: for line in f: line = line.strip() if not line: continue - + data = json.loads(line) - + if data.get("_type") == "metadata": metadata = data.get("metadata", {}) created_at = datetime.fromisoformat(data["created_at"]) if data.get("created_at") else None + last_consolidated = data.get("last_consolidated", 0) else: messages.append(data) - + return Session( key=key, messages=messages, created_at=created_at or datetime.now(), - metadata=metadata + metadata=metadata, + last_consolidated=last_consolidated ) except Exception as e: logger.warning(f"Failed to load session {key}: {e}") @@ -136,42 +131,24 @@ class SessionManager: def save(self, session: Session) -> None: """Save a session to disk.""" path = self._get_session_path(session.key) - + with open(path, "w") as f: - # Write metadata first metadata_line = { "_type": "metadata", "created_at": session.created_at.isoformat(), "updated_at": session.updated_at.isoformat(), - "metadata": session.metadata + "metadata": session.metadata, + "last_consolidated": session.last_consolidated } f.write(json.dumps(metadata_line) + "\n") - - # Write messages for msg in session.messages: f.write(json.dumps(msg) + "\n") - + self._cache[session.key] = session - def delete(self, key: str) -> bool: - """ - Delete a session. - - Args: - key: Session key. - - Returns: - True if deleted, False if not found. 
- """ - # Remove from cache + def invalidate(self, key: str) -> None: + """Remove a session from the in-memory cache.""" self._cache.pop(key, None) - - # Remove file - path = self._get_session_path(key) - if path.exists(): - path.unlink() - return True - return False def list_sessions(self) -> list[dict[str, Any]]: """ diff --git a/nanobot/skills/cron/SKILL.md b/nanobot/skills/cron/SKILL.md index c8beecb..7db25d8 100644 --- a/nanobot/skills/cron/SKILL.md +++ b/nanobot/skills/cron/SKILL.md @@ -7,10 +7,11 @@ description: Schedule reminders and recurring tasks. Use the `cron` tool to schedule reminders or recurring tasks. -## Two Modes +## Three Modes 1. **Reminder** - message is sent directly to user 2. **Task** - message is a task description, agent executes and sends result +3. **One-time** - runs once at a specific time, then auto-deletes ## Examples @@ -24,6 +25,11 @@ Dynamic task (agent executes each time): cron(action="add", message="Check HKUDS/nanobot GitHub stars and report", every_seconds=600) ``` +One-time scheduled task (compute ISO datetime from current time): +``` +cron(action="add", message="Remind me about the meeting", at="") +``` + List/remove: ``` cron(action="list") @@ -38,3 +44,4 @@ cron(action="remove", job_id="abc123") | every hour | every_seconds: 3600 | | every day at 8am | cron_expr: "0 8 * * *" | | weekdays at 5pm | cron_expr: "0 17 * * 1-5" | +| at a specific time | at: ISO datetime string (compute from current time) | diff --git a/nanobot/skills/memory/SKILL.md b/nanobot/skills/memory/SKILL.md new file mode 100644 index 0000000..39adbde --- /dev/null +++ b/nanobot/skills/memory/SKILL.md @@ -0,0 +1,31 @@ +--- +name: memory +description: Two-layer memory system with grep-based recall. +always: true +--- + +# Memory + +## Structure + +- `memory/MEMORY.md` β€” Long-term facts (preferences, project context, relationships). Always loaded into your context. +- `memory/HISTORY.md` β€” Append-only event log. NOT loaded into context. 
Search it with grep. + +## Search Past Events + +```bash +grep -i "keyword" memory/HISTORY.md +``` + +Use the `exec` tool to run grep. Combine patterns: `grep -iE "meeting|deadline" memory/HISTORY.md` + +## When to Update MEMORY.md + +Write important facts immediately using `edit_file` or `write_file`: +- User preferences ("I prefer dark mode") +- Project context ("The API uses OAuth2") +- Relationships ("Alice is the project lead") + +## Auto-consolidation + +Old conversations are automatically summarized and appended to HISTORY.md when the session grows large. Long-term facts are extracted to MEMORY.md. You don't need to manage this. diff --git a/nanobot/utils/helpers.py b/nanobot/utils/helpers.py index 667b4c4..62f80ac 100644 --- a/nanobot/utils/helpers.py +++ b/nanobot/utils/helpers.py @@ -37,23 +37,12 @@ def get_sessions_path() -> Path: return ensure_dir(get_data_path() / "sessions") -def get_memory_path(workspace: Path | None = None) -> Path: - """Get the memory directory within the workspace.""" - ws = workspace or get_workspace_path() - return ensure_dir(ws / "memory") - - def get_skills_path(workspace: Path | None = None) -> Path: """Get the skills directory within the workspace.""" ws = workspace or get_workspace_path() return ensure_dir(ws / "skills") -def today_date() -> str: - """Get today's date in YYYY-MM-DD format.""" - return datetime.now().strftime("%Y-%m-%d") - - def timestamp() -> str: """Get current timestamp in ISO format.""" return datetime.now().isoformat() diff --git a/pyproject.toml b/pyproject.toml index bdccbf0..17c739f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "nanobot-ai" -version = "0.1.3.post6" +version = "0.1.3.post7" description = "A lightweight personal AI assistant framework" requires-python = ">=3.11" license = {text = "MIT"} diff --git a/tests/test_cli_input.py b/tests/test_cli_input.py index 6f9c257..9626120 100644 --- a/tests/test_cli_input.py +++ b/tests/test_cli_input.py @@ -12,7 
+12,8 @@ def mock_prompt_session(): """Mock the global prompt session.""" mock_session = MagicMock() mock_session.prompt_async = AsyncMock() - with patch("nanobot.cli.commands._PROMPT_SESSION", mock_session): + with patch("nanobot.cli.commands._PROMPT_SESSION", mock_session), \ + patch("nanobot.cli.commands.patch_stdout"): yield mock_session diff --git a/tests/test_commands.py b/tests/test_commands.py new file mode 100644 index 0000000..f5495fd --- /dev/null +++ b/tests/test_commands.py @@ -0,0 +1,92 @@ +import shutil +from pathlib import Path +from unittest.mock import patch + +import pytest +from typer.testing import CliRunner + +from nanobot.cli.commands import app + +runner = CliRunner() + + +@pytest.fixture +def mock_paths(): + """Mock config/workspace paths for test isolation.""" + with patch("nanobot.config.loader.get_config_path") as mock_cp, \ + patch("nanobot.config.loader.save_config") as mock_sc, \ + patch("nanobot.config.loader.load_config") as mock_lc, \ + patch("nanobot.utils.helpers.get_workspace_path") as mock_ws: + + base_dir = Path("./test_onboard_data") + if base_dir.exists(): + shutil.rmtree(base_dir) + base_dir.mkdir() + + config_file = base_dir / "config.json" + workspace_dir = base_dir / "workspace" + + mock_cp.return_value = config_file + mock_ws.return_value = workspace_dir + mock_sc.side_effect = lambda config: config_file.write_text("{}") + + yield config_file, workspace_dir + + if base_dir.exists(): + shutil.rmtree(base_dir) + + +def test_onboard_fresh_install(mock_paths): + """No existing config β€” should create from scratch.""" + config_file, workspace_dir = mock_paths + + result = runner.invoke(app, ["onboard"]) + + assert result.exit_code == 0 + assert "Created config" in result.stdout + assert "Created workspace" in result.stdout + assert "nanobot is ready" in result.stdout + assert config_file.exists() + assert (workspace_dir / "AGENTS.md").exists() + assert (workspace_dir / "memory" / "MEMORY.md").exists() + + +def 
test_onboard_existing_config_refresh(mock_paths): + """Config exists, user declines overwrite — should refresh (load-merge-save).""" + config_file, workspace_dir = mock_paths + config_file.write_text('{"existing": true}') + + result = runner.invoke(app, ["onboard"], input="n\n") + + assert result.exit_code == 0 + assert "Config already exists" in result.stdout + assert "existing values preserved" in result.stdout + assert workspace_dir.exists() + assert (workspace_dir / "AGENTS.md").exists() + + +def test_onboard_existing_config_overwrite(mock_paths): + """Config exists, user confirms overwrite — should reset to defaults.""" + config_file, workspace_dir = mock_paths + config_file.write_text('{"existing": true}') + + result = runner.invoke(app, ["onboard"], input="y\n") + + assert result.exit_code == 0 + assert "Config already exists" in result.stdout + assert "Config reset to defaults" in result.stdout + assert workspace_dir.exists() + + +def test_onboard_existing_workspace_safe_create(mock_paths): + """Workspace exists — should not recreate, but still add missing templates.""" + config_file, workspace_dir = mock_paths + workspace_dir.mkdir(parents=True) + config_file.write_text("{}") + + result = runner.invoke(app, ["onboard"], input="n\n") + + assert result.exit_code == 0 + assert "Created workspace" not in result.stdout + assert "Created AGENTS.md" in result.stdout + assert (workspace_dir / "AGENTS.md").exists() diff --git a/tests/test_consolidate_offset.py b/tests/test_consolidate_offset.py new file mode 100644 index 0000000..e204733 --- /dev/null +++ b/tests/test_consolidate_offset.py @@ -0,0 +1,477 @@ +"""Test session management with cache-friendly message handling.""" + +import pytest +from pathlib import Path +from nanobot.session.manager import Session, SessionManager + +# Test constants +MEMORY_WINDOW = 50 +KEEP_COUNT = MEMORY_WINDOW // 2 # 25 + + +def create_session_with_messages(key: str, count: int, role: str = "user") -> Session: + """Create a 
session and add the specified number of messages. + + Args: + key: Session identifier + count: Number of messages to add + role: Message role (default: "user") + + Returns: + Session with the specified messages + """ + session = Session(key=key) + for i in range(count): + session.add_message(role, f"msg{i}") + return session + + +def assert_messages_content(messages: list, start_index: int, end_index: int) -> None: + """Assert that messages contain expected content from start to end index. + + Args: + messages: List of message dictionaries + start_index: Expected first message index + end_index: Expected last message index + """ + assert len(messages) > 0 + assert messages[0]["content"] == f"msg{start_index}" + assert messages[-1]["content"] == f"msg{end_index}" + + +def get_old_messages(session: Session, last_consolidated: int, keep_count: int) -> list: + """Extract messages that would be consolidated using the standard slice logic. + + Args: + session: The session containing messages + last_consolidated: Index of last consolidated message + keep_count: Number of recent messages to keep + + Returns: + List of messages that would be consolidated + """ + return session.messages[last_consolidated:-keep_count] + + +class TestSessionLastConsolidated: + """Test last_consolidated tracking to avoid duplicate processing.""" + + def test_initial_last_consolidated_zero(self) -> None: + """Test that new session starts with last_consolidated=0.""" + session = Session(key="test:initial") + assert session.last_consolidated == 0 + + def test_last_consolidated_persistence(self, tmp_path) -> None: + """Test that last_consolidated persists across save/load.""" + manager = SessionManager(Path(tmp_path)) + session1 = create_session_with_messages("test:persist", 20) + session1.last_consolidated = 15 + manager.save(session1) + + session2 = manager.get_or_create("test:persist") + assert session2.last_consolidated == 15 + assert len(session2.messages) == 20 + + def 
test_clear_resets_last_consolidated(self) -> None: + """Test that clear() resets last_consolidated to 0.""" + session = create_session_with_messages("test:clear", 10) + session.last_consolidated = 5 + + session.clear() + assert len(session.messages) == 0 + assert session.last_consolidated == 0 + + +class TestSessionImmutableHistory: + """Test Session message immutability for cache efficiency.""" + + def test_initial_state(self) -> None: + """Test that new session has empty messages list.""" + session = Session(key="test:initial") + assert len(session.messages) == 0 + + def test_add_messages_appends_only(self) -> None: + """Test that adding messages only appends, never modifies.""" + session = Session(key="test:preserve") + session.add_message("user", "msg1") + session.add_message("assistant", "resp1") + session.add_message("user", "msg2") + assert len(session.messages) == 3 + assert session.messages[0]["content"] == "msg1" + + def test_get_history_returns_most_recent(self) -> None: + """Test get_history returns the most recent messages.""" + session = Session(key="test:history") + for i in range(10): + session.add_message("user", f"msg{i}") + session.add_message("assistant", f"resp{i}") + + history = session.get_history(max_messages=6) + assert len(history) == 6 + assert history[0]["content"] == "msg7" + assert history[-1]["content"] == "resp9" + + def test_get_history_with_all_messages(self) -> None: + """Test get_history with max_messages larger than actual.""" + session = create_session_with_messages("test:all", 5) + history = session.get_history(max_messages=100) + assert len(history) == 5 + assert history[0]["content"] == "msg0" + + def test_get_history_stable_for_same_session(self) -> None: + """Test that get_history returns same content for same max_messages.""" + session = create_session_with_messages("test:stable", 20) + history1 = session.get_history(max_messages=10) + history2 = session.get_history(max_messages=10) + assert history1 == history2 + + def 
test_messages_list_never_modified(self) -> None: + """Test that messages list is never modified after creation.""" + session = create_session_with_messages("test:immutable", 5) + original_len = len(session.messages) + + session.get_history(max_messages=2) + assert len(session.messages) == original_len + + for _ in range(10): + session.get_history(max_messages=3) + assert len(session.messages) == original_len + + +class TestSessionPersistence: + """Test Session persistence and reload.""" + + @pytest.fixture + def temp_manager(self, tmp_path): + return SessionManager(Path(tmp_path)) + + def test_persistence_roundtrip(self, temp_manager): + """Test that messages persist across save/load.""" + session1 = create_session_with_messages("test:persistence", 20) + temp_manager.save(session1) + + session2 = temp_manager.get_or_create("test:persistence") + assert len(session2.messages) == 20 + assert session2.messages[0]["content"] == "msg0" + assert session2.messages[-1]["content"] == "msg19" + + def test_get_history_after_reload(self, temp_manager): + """Test that get_history works correctly after reload.""" + session1 = create_session_with_messages("test:reload", 30) + temp_manager.save(session1) + + session2 = temp_manager.get_or_create("test:reload") + history = session2.get_history(max_messages=10) + assert len(history) == 10 + assert history[0]["content"] == "msg20" + assert history[-1]["content"] == "msg29" + + def test_clear_resets_session(self, temp_manager): + """Test that clear() properly resets session.""" + session = create_session_with_messages("test:clear", 10) + assert len(session.messages) == 10 + + session.clear() + assert len(session.messages) == 0 + + +class TestConsolidationTriggerConditions: + """Test consolidation trigger conditions and logic.""" + + def test_consolidation_needed_when_messages_exceed_window(self): + """Test consolidation logic: should trigger when messages > memory_window.""" + session = create_session_with_messages("test:trigger", 60) 
+ + total_messages = len(session.messages) + messages_to_process = total_messages - session.last_consolidated + + assert total_messages > MEMORY_WINDOW + assert messages_to_process > 0 + + expected_consolidate_count = total_messages - KEEP_COUNT + assert expected_consolidate_count == 35 + + def test_consolidation_skipped_when_within_keep_count(self): + """Test consolidation skipped when total messages <= keep_count.""" + session = create_session_with_messages("test:skip", 20) + + total_messages = len(session.messages) + assert total_messages <= KEEP_COUNT + + old_messages = get_old_messages(session, session.last_consolidated, KEEP_COUNT) + assert len(old_messages) == 0 + + def test_consolidation_skipped_when_no_new_messages(self): + """Test consolidation skipped when messages_to_process <= 0.""" + session = create_session_with_messages("test:already_consolidated", 40) + session.last_consolidated = len(session.messages) - KEEP_COUNT # 15 + + # Add a few more messages + for i in range(40, 42): + session.add_message("user", f"msg{i}") + + total_messages = len(session.messages) + messages_to_process = total_messages - session.last_consolidated + assert messages_to_process > 0 + + # Simulate last_consolidated catching up + session.last_consolidated = total_messages - KEEP_COUNT + old_messages = get_old_messages(session, session.last_consolidated, KEEP_COUNT) + assert len(old_messages) == 0 + + +class TestLastConsolidatedEdgeCases: + """Test last_consolidated edge cases and data corruption scenarios.""" + + def test_last_consolidated_exceeds_message_count(self): + """Test behavior when last_consolidated > len(messages) (data corruption).""" + session = create_session_with_messages("test:corruption", 10) + session.last_consolidated = 20 + + total_messages = len(session.messages) + messages_to_process = total_messages - session.last_consolidated + assert messages_to_process <= 0 + + old_messages = get_old_messages(session, session.last_consolidated, 5) + assert 
len(old_messages) == 0 + + def test_last_consolidated_negative_value(self): + """Test behavior with negative last_consolidated (invalid state).""" + session = create_session_with_messages("test:negative", 10) + session.last_consolidated = -5 + + keep_count = 3 + old_messages = get_old_messages(session, session.last_consolidated, keep_count) + + # messages[-5:-3] with 10 messages gives indices 5,6 + assert len(old_messages) == 2 + assert old_messages[0]["content"] == "msg5" + assert old_messages[-1]["content"] == "msg6" + + def test_messages_added_after_consolidation(self): + """Test correct behavior when new messages arrive after consolidation.""" + session = create_session_with_messages("test:new_messages", 40) + session.last_consolidated = len(session.messages) - KEEP_COUNT # 15 + + # Add new messages after consolidation + for i in range(40, 50): + session.add_message("user", f"msg{i}") + + total_messages = len(session.messages) + old_messages = get_old_messages(session, session.last_consolidated, KEEP_COUNT) + expected_consolidate_count = total_messages - KEEP_COUNT - session.last_consolidated + + assert len(old_messages) == expected_consolidate_count + assert_messages_content(old_messages, 15, 24) + + def test_slice_behavior_when_indices_overlap(self): + """Test slice behavior when last_consolidated >= total - keep_count.""" + session = create_session_with_messages("test:overlap", 30) + session.last_consolidated = 12 + + old_messages = get_old_messages(session, session.last_consolidated, 20) + assert len(old_messages) == 0 + + +class TestArchiveAllMode: + """Test archive_all mode (used by /new command).""" + + def test_archive_all_consolidates_everything(self): + """Test archive_all=True consolidates all messages.""" + session = create_session_with_messages("test:archive_all", 50) + + archive_all = True + if archive_all: + old_messages = session.messages + assert len(old_messages) == 50 + + assert session.last_consolidated == 0 + + def 
test_archive_all_resets_last_consolidated(self): + """Test that archive_all mode resets last_consolidated to 0.""" + session = create_session_with_messages("test:reset", 40) + session.last_consolidated = 15 + + archive_all = True + if archive_all: + session.last_consolidated = 0 + + assert session.last_consolidated == 0 + assert len(session.messages) == 40 + + def test_archive_all_vs_normal_consolidation(self): + """Test difference between archive_all and normal consolidation.""" + # Normal consolidation + session1 = create_session_with_messages("test:normal", 60) + session1.last_consolidated = len(session1.messages) - KEEP_COUNT + + # archive_all mode + session2 = create_session_with_messages("test:all", 60) + session2.last_consolidated = 0 + + assert session1.last_consolidated == 35 + assert len(session1.messages) == 60 + assert session2.last_consolidated == 0 + assert len(session2.messages) == 60 + + +class TestCacheImmutability: + """Test that consolidation doesn't modify session.messages (cache safety).""" + + def test_consolidation_does_not_modify_messages_list(self): + """Test that consolidation leaves messages list unchanged.""" + session = create_session_with_messages("test:immutable", 50) + + original_messages = session.messages.copy() + original_len = len(session.messages) + session.last_consolidated = original_len - KEEP_COUNT + + assert len(session.messages) == original_len + assert session.messages == original_messages + + def test_get_history_does_not_modify_messages(self): + """Test that get_history doesn't modify messages list.""" + session = create_session_with_messages("test:history_immutable", 40) + original_messages = [m.copy() for m in session.messages] + + for _ in range(5): + history = session.get_history(max_messages=10) + assert len(history) == 10 + + assert len(session.messages) == 40 + for i, msg in enumerate(session.messages): + assert msg["content"] == original_messages[i]["content"] + + def 
test_consolidation_only_updates_last_consolidated(self): + """Test that consolidation only updates last_consolidated field.""" + session = create_session_with_messages("test:field_only", 60) + + original_messages = session.messages.copy() + original_key = session.key + original_metadata = session.metadata.copy() + + session.last_consolidated = len(session.messages) - KEEP_COUNT + + assert session.messages == original_messages + assert session.key == original_key + assert session.metadata == original_metadata + assert session.last_consolidated == 35 + + +class TestSliceLogic: + """Test the slice logic: messages[last_consolidated:-keep_count].""" + + def test_slice_extracts_correct_range(self): + """Test that slice extracts the correct message range.""" + session = create_session_with_messages("test:slice", 60) + + old_messages = get_old_messages(session, 0, KEEP_COUNT) + + assert len(old_messages) == 35 + assert_messages_content(old_messages, 0, 34) + + remaining = session.messages[-KEEP_COUNT:] + assert len(remaining) == 25 + assert_messages_content(remaining, 35, 59) + + def test_slice_with_partial_consolidation(self): + """Test slice when some messages already consolidated.""" + session = create_session_with_messages("test:partial", 70) + + last_consolidated = 30 + old_messages = get_old_messages(session, last_consolidated, KEEP_COUNT) + + assert len(old_messages) == 15 + assert_messages_content(old_messages, 30, 44) + + def test_slice_with_various_keep_counts(self): + """Test slice behavior with different keep_count values.""" + session = create_session_with_messages("test:keep_counts", 50) + + test_cases = [(10, 40), (20, 30), (30, 20), (40, 10)] + + for keep_count, expected_count in test_cases: + old_messages = session.messages[0:-keep_count] + assert len(old_messages) == expected_count + + def test_slice_when_keep_count_exceeds_messages(self): + """Test slice when keep_count > len(messages).""" + session = create_session_with_messages("test:exceed", 10) + + 
old_messages = session.messages[0:-20] + assert len(old_messages) == 0 + + +class TestEmptyAndBoundarySessions: + """Test empty sessions and boundary conditions.""" + + def test_empty_session_consolidation(self): + """Test consolidation behavior with empty session.""" + session = Session(key="test:empty") + + assert len(session.messages) == 0 + assert session.last_consolidated == 0 + + messages_to_process = len(session.messages) - session.last_consolidated + assert messages_to_process == 0 + + old_messages = get_old_messages(session, session.last_consolidated, KEEP_COUNT) + assert len(old_messages) == 0 + + def test_single_message_session(self): + """Test consolidation with single message.""" + session = Session(key="test:single") + session.add_message("user", "only message") + + assert len(session.messages) == 1 + + old_messages = get_old_messages(session, session.last_consolidated, KEEP_COUNT) + assert len(old_messages) == 0 + + def test_exactly_keep_count_messages(self): + """Test session with exactly keep_count messages.""" + session = create_session_with_messages("test:exact", KEEP_COUNT) + + assert len(session.messages) == KEEP_COUNT + + old_messages = get_old_messages(session, session.last_consolidated, KEEP_COUNT) + assert len(old_messages) == 0 + + def test_just_over_keep_count(self): + """Test session with one message over keep_count.""" + session = create_session_with_messages("test:over", KEEP_COUNT + 1) + + assert len(session.messages) == 26 + + old_messages = get_old_messages(session, session.last_consolidated, KEEP_COUNT) + assert len(old_messages) == 1 + assert old_messages[0]["content"] == "msg0" + + def test_very_large_session(self): + """Test consolidation with very large message count.""" + session = create_session_with_messages("test:large", 1000) + + assert len(session.messages) == 1000 + + old_messages = get_old_messages(session, session.last_consolidated, KEEP_COUNT) + assert len(old_messages) == 975 + assert_messages_content(old_messages, 
0, 974) + + remaining = session.messages[-KEEP_COUNT:] + assert len(remaining) == 25 + assert_messages_content(remaining, 975, 999) + + def test_session_with_gaps_in_consolidation(self): + """Test session with potential gaps in consolidation history.""" + session = create_session_with_messages("test:gaps", 50) + session.last_consolidated = 10 + + # Add more messages + for i in range(50, 60): + session.add_message("user", f"msg{i}") + + old_messages = get_old_messages(session, session.last_consolidated, KEEP_COUNT) + + expected_count = 60 - KEEP_COUNT - 10 + assert len(old_messages) == expected_count + assert_messages_content(old_messages, 10, 34) diff --git a/workspace/AGENTS.md b/workspace/AGENTS.md index b4e5b5f..69bd823 100644 --- a/workspace/AGENTS.md +++ b/workspace/AGENTS.md @@ -20,8 +20,8 @@ You have access to: ## Memory -- Use `memory/` directory for daily notes -- Use `MEMORY.md` for long-term information +- `memory/MEMORY.md` — long-term facts (preferences, context, relationships) +- `memory/HISTORY.md` — append-only event log, search with grep to recall past events ## Scheduled Reminders