Merge branch 'main' into pr-151
This commit is contained in:
commit
9e5f7348fe
3
.gitignore
vendored
3
.gitignore
vendored
@ -14,8 +14,9 @@ docs/
|
||||
*.pywz
|
||||
*.pyzz
|
||||
.venv/
|
||||
venv/
|
||||
__pycache__/
|
||||
poetry.lock
|
||||
.pytest_cache/
|
||||
tests/
|
||||
botpy.log
|
||||
tests/
|
||||
|
||||
267
README.md
267
README.md
@ -12,14 +12,19 @@
|
||||
</p>
|
||||
</div>
|
||||
|
||||
🐈 **nanobot** is an **ultra-lightweight** personal AI assistant inspired by [Clawdbot](https://github.com/openclaw/openclaw)
|
||||
🐈 **nanobot** is an **ultra-lightweight** personal AI assistant inspired by [OpenClaw](https://github.com/openclaw/openclaw)
|
||||
|
||||
⚡️ Delivers core agent functionality in just **~4,000** lines of code — **99% smaller** than Clawdbot's 430k+ lines.
|
||||
|
||||
📏 Real-time line count: **3,510 lines** (run `bash core_agent_lines.sh` to verify anytime)
|
||||
📏 Real-time line count: **3,663 lines** (run `bash core_agent_lines.sh` to verify anytime)
|
||||
|
||||
## 📢 News
|
||||
|
||||
- **2026-02-14** 🔌 nanobot now supports MCP! See [MCP section](#mcp-model-context-protocol) for details.
|
||||
- **2026-02-13** 🎉 Released v0.1.3.post7 — includes security hardening and multiple improvements. All users are recommended to upgrade to the latest version. See [release notes](https://github.com/HKUDS/nanobot/releases/tag/v0.1.3.post7) for more details.
|
||||
- **2026-02-12** 🧠 Redesigned memory system — Less code, more reliable. Join the [discussion](https://github.com/HKUDS/nanobot/discussions/566) about it!
|
||||
- **2026-02-11** ✨ Enhanced CLI experience and added MiniMax support!
|
||||
- **2026-02-10** 🎉 Released v0.1.3.post6 with improvements! Check the updates [notes](https://github.com/HKUDS/nanobot/releases/tag/v0.1.3.post6) and our [roadmap](https://github.com/HKUDS/nanobot/discussions/431).
|
||||
- **2026-02-09** 💬 Added Slack, Email, and QQ support — nanobot now supports multiple chat platforms!
|
||||
- **2026-02-08** 🔧 Refactored Providers—adding a new LLM provider now takes just 2 simple steps! Check [here](#providers).
|
||||
- **2026-02-07** 🚀 Released v0.1.3.post5 with Qwen support & several key improvements! Check [here](https://github.com/HKUDS/nanobot/releases/tag/v0.1.3.post5) for details.
|
||||
@ -94,7 +99,7 @@ pip install nanobot-ai
|
||||
|
||||
> [!TIP]
|
||||
> Set your API key in `~/.nanobot/config.json`.
|
||||
> Get API keys: [OpenRouter](https://openrouter.ai/keys) (Global) · [DashScope](https://dashscope.console.aliyun.com) (Qwen) · [Brave Search](https://brave.com/search/api/) (optional, for web search)
|
||||
> Get API keys: [OpenRouter](https://openrouter.ai/keys) (Global) · [Brave Search](https://brave.com/search/api/) (optional, for web search)
|
||||
|
||||
**1. Initialize**
|
||||
|
||||
@ -104,14 +109,22 @@ nanobot onboard
|
||||
|
||||
**2. Configure** (`~/.nanobot/config.json`)
|
||||
|
||||
For OpenRouter - recommended for global users:
|
||||
Add or merge these **two parts** into your config (other options have defaults).
|
||||
|
||||
*Set your API key* (e.g. OpenRouter, recommended for global users):
|
||||
```json
|
||||
{
|
||||
"providers": {
|
||||
"openrouter": {
|
||||
"apiKey": "sk-or-v1-xxx"
|
||||
}
|
||||
},
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
*Set your model*:
|
||||
```json
|
||||
{
|
||||
"agents": {
|
||||
"defaults": {
|
||||
"model": "anthropic/claude-opus-4-5"
|
||||
@ -123,51 +136,14 @@ For OpenRouter - recommended for global users:
|
||||
**3. Chat**
|
||||
|
||||
```bash
|
||||
nanobot agent -m "What is 2+2?"
|
||||
nanobot agent
|
||||
```
|
||||
|
||||
That's it! You have a working AI assistant in 2 minutes.
|
||||
|
||||
## 🖥️ Local Models (vLLM)
|
||||
|
||||
Run nanobot with your own local models using vLLM or any OpenAI-compatible server.
|
||||
|
||||
**1. Start your vLLM server**
|
||||
|
||||
```bash
|
||||
vllm serve meta-llama/Llama-3.1-8B-Instruct --port 8000
|
||||
```
|
||||
|
||||
**2. Configure** (`~/.nanobot/config.json`)
|
||||
|
||||
```json
|
||||
{
|
||||
"providers": {
|
||||
"vllm": {
|
||||
"apiKey": "dummy",
|
||||
"apiBase": "http://localhost:8000/v1"
|
||||
}
|
||||
},
|
||||
"agents": {
|
||||
"defaults": {
|
||||
"model": "meta-llama/Llama-3.1-8B-Instruct"
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
**3. Chat**
|
||||
|
||||
```bash
|
||||
nanobot agent -m "Hello from my local LLM!"
|
||||
```
|
||||
|
||||
> [!TIP]
|
||||
> The `apiKey` can be any non-empty string for local servers that don't require authentication.
|
||||
|
||||
## 💬 Chat Apps
|
||||
|
||||
Talk to your nanobot through Telegram, Discord, WhatsApp, Feishu, DingTalk, Slack, Email, or QQ — anytime, anywhere.
|
||||
Talk to your nanobot through Telegram, Discord, WhatsApp, Feishu, Mochat, DingTalk, Slack, Email, or QQ — anytime, anywhere.
|
||||
|
||||
| Channel | Setup |
|
||||
|---------|-------|
|
||||
@ -175,6 +151,7 @@ Talk to your nanobot through Telegram, Discord, WhatsApp, Feishu, DingTalk, Slac
|
||||
| **Discord** | Easy (bot token + intents) |
|
||||
| **WhatsApp** | Medium (scan QR) |
|
||||
| **Feishu** | Medium (app credentials) |
|
||||
| **Mochat** | Medium (claw token + websocket) |
|
||||
| **DingTalk** | Medium (app credentials) |
|
||||
| **Slack** | Medium (bot + app tokens) |
|
||||
| **Email** | Medium (IMAP/SMTP credentials) |
|
||||
@ -214,6 +191,63 @@ nanobot gateway
|
||||
|
||||
</details>
|
||||
|
||||
<details>
|
||||
<summary><b>Mochat (Claw IM)</b></summary>
|
||||
|
||||
Uses **Socket.IO WebSocket** by default, with HTTP polling fallback.
|
||||
|
||||
**1. Ask nanobot to set up Mochat for you**
|
||||
|
||||
Simply send this message to nanobot (replace `xxx@xxx` with your real email):
|
||||
|
||||
```
|
||||
Read https://raw.githubusercontent.com/HKUDS/MoChat/refs/heads/main/skills/nanobot/skill.md and register on MoChat. My Email account is xxx@xxx Bind me as your owner and DM me on MoChat.
|
||||
```
|
||||
|
||||
nanobot will automatically register, configure `~/.nanobot/config.json`, and connect to Mochat.
|
||||
|
||||
**2. Restart gateway**
|
||||
|
||||
```bash
|
||||
nanobot gateway
|
||||
```
|
||||
|
||||
That's it — nanobot handles the rest!
|
||||
|
||||
<br>
|
||||
|
||||
<details>
|
||||
<summary>Manual configuration (advanced)</summary>
|
||||
|
||||
If you prefer to configure manually, add the following to `~/.nanobot/config.json`:
|
||||
|
||||
> Keep `claw_token` private. It should only be sent in `X-Claw-Token` header to your Mochat API endpoint.
|
||||
|
||||
```json
|
||||
{
|
||||
"channels": {
|
||||
"mochat": {
|
||||
"enabled": true,
|
||||
"base_url": "https://mochat.io",
|
||||
"socket_url": "https://mochat.io",
|
||||
"socket_path": "/socket.io",
|
||||
"claw_token": "claw_xxx",
|
||||
"agent_user_id": "6982abcdef",
|
||||
"sessions": ["*"],
|
||||
"panels": ["*"],
|
||||
"reply_delay_mode": "non-mention",
|
||||
"reply_delay_ms": 120000
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
|
||||
|
||||
</details>
|
||||
|
||||
</details>
|
||||
|
||||
<details>
|
||||
<summary><b>Discord</b></summary>
|
||||
|
||||
@ -428,13 +462,17 @@ nanobot gateway
|
||||
Uses **Socket Mode** — no public URL required.
|
||||
|
||||
**1. Create a Slack app**
|
||||
- Go to [Slack API](https://api.slack.com/apps) → Create New App
|
||||
- **OAuth & Permissions**: Add bot scopes: `chat:write`, `reactions:write`, `app_mentions:read`
|
||||
- Install to your workspace and copy the **Bot Token** (`xoxb-...`)
|
||||
- **Socket Mode**: Enable it and generate an **App-Level Token** (`xapp-...`) with `connections:write` scope
|
||||
- **Event Subscriptions**: Subscribe to `message.im`, `message.channels`, `app_mention`
|
||||
- Go to [Slack API](https://api.slack.com/apps) → **Create New App** → "From scratch"
|
||||
- Pick a name and select your workspace
|
||||
|
||||
**2. Configure**
|
||||
**2. Configure the app**
|
||||
- **Socket Mode**: Toggle ON → Generate an **App-Level Token** with `connections:write` scope → copy it (`xapp-...`)
|
||||
- **OAuth & Permissions**: Add bot scopes: `chat:write`, `reactions:write`, `app_mentions:read`
|
||||
- **Event Subscriptions**: Toggle ON → Subscribe to bot events: `message.im`, `message.channels`, `app_mention` → Save Changes
|
||||
- **App Home**: Scroll to **Show Tabs** → Enable **Messages Tab** → Check **"Allow users to send Slash commands and messages from the messages tab"**
|
||||
- **Install App**: Click **Install to Workspace** → Authorize → copy the **Bot Token** (`xoxb-...`)
|
||||
|
||||
**3. Configure nanobot**
|
||||
|
||||
```json
|
||||
{
|
||||
@ -449,15 +487,18 @@ Uses **Socket Mode** — no public URL required.
|
||||
}
|
||||
```
|
||||
|
||||
> `groupPolicy`: `"mention"` (respond only when @mentioned), `"open"` (respond to all messages), or `"allowlist"` (restrict to specific channels).
|
||||
> DM policy defaults to open. Set `"dm": {"enabled": false}` to disable DMs.
|
||||
|
||||
**3. Run**
|
||||
**4. Run**
|
||||
|
||||
```bash
|
||||
nanobot gateway
|
||||
```
|
||||
|
||||
DM the bot directly or @mention it in a channel — it should respond!
|
||||
|
||||
> [!TIP]
|
||||
> - `groupPolicy`: `"mention"` (default — respond only when @mentioned), `"open"` (respond to all channel messages), or `"allowlist"` (restrict to specific channels).
|
||||
> - DM policy defaults to open. Set `"dm": {"enabled": false}` to disable DMs.
|
||||
|
||||
</details>
|
||||
|
||||
<details>
|
||||
@ -507,6 +548,17 @@ nanobot gateway
|
||||
|
||||
</details>
|
||||
|
||||
## 🌐 Agent Social Network
|
||||
|
||||
🐈 nanobot is capable of linking to the agent social network (agent community). **Just send one message and your nanobot joins automatically!**
|
||||
|
||||
| Platform | How to Join (send this message to your bot) |
|
||||
|----------|-------------|
|
||||
| [**Moltbook**](https://www.moltbook.com/) | `Read https://moltbook.com/skill.md and follow the instructions to join Moltbook` |
|
||||
| [**ClawdChat**](https://clawdchat.ai/) | `Read https://clawdchat.ai/skill.md and follow the instructions to join ClawdChat` |
|
||||
|
||||
Simply send the command above to your nanobot (via CLI or any chat channel), and it will handle the rest.
|
||||
|
||||
## ⚙️ Configuration
|
||||
|
||||
Config file: `~/.nanobot/config.json`
|
||||
@ -516,21 +568,86 @@ Config file: `~/.nanobot/config.json`
|
||||
> [!TIP]
|
||||
> - **Groq** provides free voice transcription via Whisper. If configured, Telegram voice messages will be automatically transcribed.
|
||||
> - **Zhipu Coding Plan**: If you're on Zhipu's coding plan, set `"apiBase": "https://open.bigmodel.cn/api/coding/paas/v4"` in your zhipu provider config.
|
||||
> - **MiniMax (Mainland China)**: If your API key is from MiniMax's mainland China platform (minimaxi.com), set `"apiBase": "https://api.minimaxi.com/v1"` in your minimax provider config.
|
||||
|
||||
| Provider | Purpose | Get API Key |
|
||||
|----------|---------|-------------|
|
||||
| `custom` | Any OpenAI-compatible endpoint | — |
|
||||
| `openrouter` | LLM (recommended, access to all models) | [openrouter.ai](https://openrouter.ai) |
|
||||
| `anthropic` | LLM (Claude direct) | [console.anthropic.com](https://console.anthropic.com) |
|
||||
| `openai` | LLM (GPT direct) | [platform.openai.com](https://platform.openai.com) |
|
||||
| `deepseek` | LLM (DeepSeek direct) | [platform.deepseek.com](https://platform.deepseek.com) |
|
||||
| `groq` | LLM + **Voice transcription** (Whisper) | [console.groq.com](https://console.groq.com) |
|
||||
| `gemini` | LLM (Gemini direct) | [aistudio.google.com](https://aistudio.google.com) |
|
||||
| `minimax` | LLM (MiniMax direct) | [platform.minimax.io](https://platform.minimax.io) |
|
||||
| `aihubmix` | LLM (API gateway, access to all models) | [aihubmix.com](https://aihubmix.com) |
|
||||
| `dashscope` | LLM (Qwen) | [dashscope.console.aliyun.com](https://dashscope.console.aliyun.com) |
|
||||
| `moonshot` | LLM (Moonshot/Kimi) | [platform.moonshot.cn](https://platform.moonshot.cn) |
|
||||
| `zhipu` | LLM (Zhipu GLM) | [open.bigmodel.cn](https://open.bigmodel.cn) |
|
||||
| `vllm` | LLM (local, any OpenAI-compatible server) | — |
|
||||
|
||||
<details>
|
||||
<summary><b>Custom Provider (Any OpenAI-compatible API)</b></summary>
|
||||
|
||||
If your provider is not listed above but exposes an **OpenAI-compatible API** (e.g. Together AI, Fireworks, Azure OpenAI, self-hosted endpoints), use the `custom` provider:
|
||||
|
||||
```json
|
||||
{
|
||||
"providers": {
|
||||
"custom": {
|
||||
"apiKey": "your-api-key",
|
||||
"apiBase": "https://api.your-provider.com/v1"
|
||||
}
|
||||
},
|
||||
"agents": {
|
||||
"defaults": {
|
||||
"model": "your-model-name"
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
> The `custom` provider routes through LiteLLM's OpenAI-compatible path. It works with any endpoint that follows the OpenAI chat completions API format. The model name is passed directly to the endpoint without any prefix.
|
||||
|
||||
</details>
|
||||
|
||||
<details>
|
||||
<summary><b>vLLM (local / OpenAI-compatible)</b></summary>
|
||||
|
||||
Run your own model with vLLM or any OpenAI-compatible server, then add to config:
|
||||
|
||||
**1. Start the server** (example):
|
||||
```bash
|
||||
vllm serve meta-llama/Llama-3.1-8B-Instruct --port 8000
|
||||
```
|
||||
|
||||
**2. Add to config** (partial — merge into `~/.nanobot/config.json`):
|
||||
|
||||
*Provider (key can be any non-empty string for local):*
|
||||
```json
|
||||
{
|
||||
"providers": {
|
||||
"vllm": {
|
||||
"apiKey": "dummy",
|
||||
"apiBase": "http://localhost:8000/v1"
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
*Model:*
|
||||
```json
|
||||
{
|
||||
"agents": {
|
||||
"defaults": {
|
||||
"model": "meta-llama/Llama-3.1-8B-Instruct"
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
</details>
|
||||
|
||||
<details>
|
||||
<summary><b>Adding a New Provider (Developer Guide)</b></summary>
|
||||
|
||||
@ -576,8 +693,43 @@ That's it! Environment variables, model prefixing, config matching, and `nanobot
|
||||
</details>
|
||||
|
||||
|
||||
### MCP (Model Context Protocol)
|
||||
|
||||
> [!TIP]
|
||||
> The config format is compatible with Claude Desktop / Cursor. You can copy MCP server configs directly from any MCP server's README.
|
||||
|
||||
nanobot supports [MCP](https://modelcontextprotocol.io/) — connect external tool servers and use them as native agent tools.
|
||||
|
||||
Add MCP servers to your `config.json`:
|
||||
|
||||
```json
|
||||
{
|
||||
"tools": {
|
||||
"mcpServers": {
|
||||
"filesystem": {
|
||||
"command": "npx",
|
||||
"args": ["-y", "@modelcontextprotocol/server-filesystem", "/path/to/dir"]
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
Two transport modes are supported:
|
||||
|
||||
| Mode | Config | Example |
|
||||
|------|--------|---------|
|
||||
| **Stdio** | `command` + `args` | Local process via `npx` / `uvx` |
|
||||
| **HTTP** | `url` | Remote endpoint (`https://mcp.example.com/sse`) |
|
||||
|
||||
MCP tools are automatically discovered and registered on startup. The LLM can use them alongside built-in tools — no extra configuration needed.
|
||||
|
||||
|
||||
|
||||
|
||||
### Security
|
||||
|
||||
> [!TIP]
|
||||
> For production deployments, set `"restrictToWorkspace": true` in your config to sandbox the agent.
|
||||
|
||||
| Option | Default | Description |
|
||||
@ -637,7 +789,7 @@ docker run -v ~/.nanobot:/root/.nanobot --rm nanobot onboard
|
||||
# Edit config on host to add API keys
|
||||
vim ~/.nanobot/config.json
|
||||
|
||||
# Run gateway (connects to Telegram/WhatsApp)
|
||||
# Run gateway (connects to enabled channels, e.g. Telegram/Discord/Mochat)
|
||||
docker run -v ~/.nanobot:/root/.nanobot -p 18790:18790 nanobot gateway
|
||||
|
||||
# Or run a single command
|
||||
@ -657,7 +809,7 @@ nanobot/
|
||||
│ ├── subagent.py # Background task execution
|
||||
│ └── tools/ # Built-in tools (incl. spawn)
|
||||
├── skills/ # 🎯 Bundled skills (github, weather, tmux...)
|
||||
├── channels/ # 📱 WhatsApp integration
|
||||
├── channels/ # 📱 Chat channel integrations
|
||||
├── bus/ # 🚌 Message routing
|
||||
├── cron/ # ⏰ Scheduled tasks
|
||||
├── heartbeat/ # 💓 Proactive wake-up
|
||||
@ -673,7 +825,6 @@ PRs welcome! The codebase is intentionally small and readable. 🤗
|
||||
|
||||
**Roadmap** — Pick an item and [open a PR](https://github.com/HKUDS/nanobot/pulls)!
|
||||
|
||||
- [x] **Voice Transcription** — Support for Groq Whisper (Issue #13)
|
||||
- [ ] **Multi-modal** — See and hear (images, voice, video)
|
||||
- [ ] **Long-term memory** — Never forget important context
|
||||
- [ ] **Better reasoning** — Multi-step planning and reflection
|
||||
@ -683,7 +834,7 @@ PRs welcome! The codebase is intentionally small and readable. 🤗
|
||||
### Contributors
|
||||
|
||||
<a href="https://github.com/HKUDS/nanobot/graphs/contributors">
|
||||
<img src="https://contrib.rocks/image?repo=HKUDS/nanobot&max=100&columns=12" />
|
||||
<img src="https://contrib.rocks/image?repo=HKUDS/nanobot&max=100&columns=12&updated=20260210" alt="Contributors" />
|
||||
</a>
|
||||
|
||||
|
||||
|
||||
@ -95,8 +95,8 @@ File operations have path traversal protection, but:
|
||||
- Consider using a firewall to restrict outbound connections if needed
|
||||
|
||||
**WhatsApp Bridge:**
|
||||
- The bridge runs on `localhost:3001` by default
|
||||
- If exposing to network, use proper authentication and TLS
|
||||
- The bridge binds to `127.0.0.1:3001` (localhost only, not accessible from external network)
|
||||
- Set `bridgeToken` in config to enable shared-secret authentication between Python and Node.js
|
||||
- Keep authentication data in `~/.nanobot/whatsapp-auth` secure (mode 0700)
|
||||
|
||||
### 6. Dependency Security
|
||||
@ -224,7 +224,7 @@ If you suspect a security breach:
|
||||
✅ **Secure Communication**
|
||||
- HTTPS for all external API calls
|
||||
- TLS for Telegram API
|
||||
- WebSocket security for WhatsApp bridge
|
||||
- WhatsApp bridge: localhost-only binding + optional token auth
|
||||
|
||||
## Known Limitations
|
||||
|
||||
|
||||
@ -25,11 +25,12 @@ import { join } from 'path';
|
||||
|
||||
const PORT = parseInt(process.env.BRIDGE_PORT || '3001', 10);
|
||||
const AUTH_DIR = process.env.AUTH_DIR || join(homedir(), '.nanobot', 'whatsapp-auth');
|
||||
const TOKEN = process.env.BRIDGE_TOKEN || undefined;
|
||||
|
||||
console.log('🐈 nanobot WhatsApp Bridge');
|
||||
console.log('========================\n');
|
||||
|
||||
const server = new BridgeServer(PORT, AUTH_DIR);
|
||||
const server = new BridgeServer(PORT, AUTH_DIR, TOKEN);
|
||||
|
||||
// Handle graceful shutdown
|
||||
process.on('SIGINT', async () => {
|
||||
|
||||
@ -1,5 +1,6 @@
|
||||
/**
|
||||
* WebSocket server for Python-Node.js bridge communication.
|
||||
* Security: binds to 127.0.0.1 only; optional BRIDGE_TOKEN auth.
|
||||
*/
|
||||
|
||||
import { WebSocketServer, WebSocket } from 'ws';
|
||||
@ -21,12 +22,13 @@ export class BridgeServer {
|
||||
private wa: WhatsAppClient | null = null;
|
||||
private clients: Set<WebSocket> = new Set();
|
||||
|
||||
constructor(private port: number, private authDir: string) {}
|
||||
constructor(private port: number, private authDir: string, private token?: string) {}
|
||||
|
||||
async start(): Promise<void> {
|
||||
// Create WebSocket server
|
||||
this.wss = new WebSocketServer({ port: this.port });
|
||||
console.log(`🌉 Bridge server listening on ws://localhost:${this.port}`);
|
||||
// Bind to localhost only — never expose to external network
|
||||
this.wss = new WebSocketServer({ host: '127.0.0.1', port: this.port });
|
||||
console.log(`🌉 Bridge server listening on ws://127.0.0.1:${this.port}`);
|
||||
if (this.token) console.log('🔒 Token authentication enabled');
|
||||
|
||||
// Initialize WhatsApp client
|
||||
this.wa = new WhatsAppClient({
|
||||
@ -38,35 +40,58 @@ export class BridgeServer {
|
||||
|
||||
// Handle WebSocket connections
|
||||
this.wss.on('connection', (ws) => {
|
||||
console.log('🔗 Python client connected');
|
||||
this.clients.add(ws);
|
||||
|
||||
ws.on('message', async (data) => {
|
||||
try {
|
||||
const cmd = JSON.parse(data.toString()) as SendCommand;
|
||||
await this.handleCommand(cmd);
|
||||
ws.send(JSON.stringify({ type: 'sent', to: cmd.to }));
|
||||
} catch (error) {
|
||||
console.error('Error handling command:', error);
|
||||
ws.send(JSON.stringify({ type: 'error', error: String(error) }));
|
||||
}
|
||||
});
|
||||
|
||||
ws.on('close', () => {
|
||||
console.log('🔌 Python client disconnected');
|
||||
this.clients.delete(ws);
|
||||
});
|
||||
|
||||
ws.on('error', (error) => {
|
||||
console.error('WebSocket error:', error);
|
||||
this.clients.delete(ws);
|
||||
});
|
||||
if (this.token) {
|
||||
// Require auth handshake as first message
|
||||
const timeout = setTimeout(() => ws.close(4001, 'Auth timeout'), 5000);
|
||||
ws.once('message', (data) => {
|
||||
clearTimeout(timeout);
|
||||
try {
|
||||
const msg = JSON.parse(data.toString());
|
||||
if (msg.type === 'auth' && msg.token === this.token) {
|
||||
console.log('🔗 Python client authenticated');
|
||||
this.setupClient(ws);
|
||||
} else {
|
||||
ws.close(4003, 'Invalid token');
|
||||
}
|
||||
} catch {
|
||||
ws.close(4003, 'Invalid auth message');
|
||||
}
|
||||
});
|
||||
} else {
|
||||
console.log('🔗 Python client connected');
|
||||
this.setupClient(ws);
|
||||
}
|
||||
});
|
||||
|
||||
// Connect to WhatsApp
|
||||
await this.wa.connect();
|
||||
}
|
||||
|
||||
private setupClient(ws: WebSocket): void {
|
||||
this.clients.add(ws);
|
||||
|
||||
ws.on('message', async (data) => {
|
||||
try {
|
||||
const cmd = JSON.parse(data.toString()) as SendCommand;
|
||||
await this.handleCommand(cmd);
|
||||
ws.send(JSON.stringify({ type: 'sent', to: cmd.to }));
|
||||
} catch (error) {
|
||||
console.error('Error handling command:', error);
|
||||
ws.send(JSON.stringify({ type: 'error', error: String(error) }));
|
||||
}
|
||||
});
|
||||
|
||||
ws.on('close', () => {
|
||||
console.log('🔌 Python client disconnected');
|
||||
this.clients.delete(ws);
|
||||
});
|
||||
|
||||
ws.on('error', (error) => {
|
||||
console.error('WebSocket error:', error);
|
||||
this.clients.delete(ws);
|
||||
});
|
||||
}
|
||||
|
||||
private async handleCommand(cmd: SendCommand): Promise<void> {
|
||||
if (cmd.type === 'send' && this.wa) {
|
||||
await this.wa.sendMessage(cmd.to, cmd.text);
|
||||
|
||||
@ -73,7 +73,9 @@ Skills with available="false" need dependencies installed first - you can try in
|
||||
def _get_identity(self) -> str:
|
||||
"""Get the core identity section."""
|
||||
from datetime import datetime
|
||||
import time as _time
|
||||
now = datetime.now().strftime("%Y-%m-%d %H:%M (%A)")
|
||||
tz = _time.strftime("%Z") or "UTC"
|
||||
workspace_path = str(self.workspace.expanduser().resolve())
|
||||
system = platform.system()
|
||||
runtime = f"{'macOS' if system == 'Darwin' else system} {platform.machine()}, Python {platform.python_version()}"
|
||||
@ -88,23 +90,24 @@ You are nanobot, a helpful AI assistant. You have access to tools that allow you
|
||||
- Spawn subagents for complex background tasks
|
||||
|
||||
## Current Time
|
||||
{now}
|
||||
{now} ({tz})
|
||||
|
||||
## Runtime
|
||||
{runtime}
|
||||
|
||||
## Workspace
|
||||
Your workspace is at: {workspace_path}
|
||||
- Memory files: {workspace_path}/memory/MEMORY.md
|
||||
- Daily notes: {workspace_path}/memory/YYYY-MM-DD.md
|
||||
- Long-term memory: {workspace_path}/memory/MEMORY.md
|
||||
- History log: {workspace_path}/memory/HISTORY.md (grep-searchable)
|
||||
- Custom skills: {workspace_path}/skills/{{skill-name}}/SKILL.md
|
||||
|
||||
IMPORTANT: When responding to direct questions or conversations, reply directly with your text response.
|
||||
Only use the 'message' tool when you need to send a message to a specific chat channel (like WhatsApp).
|
||||
For normal conversation, just respond with text - do not call the message tool.
|
||||
|
||||
Always be helpful, accurate, and concise. When using tools, explain what you're doing.
|
||||
When remembering something, write to {workspace_path}/memory/MEMORY.md"""
|
||||
Always be helpful, accurate, and concise. When using tools, think step by step: what you know, what you need, and why you chose this tool.
|
||||
When remembering something important, write to {workspace_path}/memory/MEMORY.md
|
||||
To recall past events, grep {workspace_path}/memory/HISTORY.md"""
|
||||
|
||||
def _load_bootstrap_files(self) -> str:
|
||||
"""Load all bootstrap files from workspace."""
|
||||
|
||||
@ -1,7 +1,9 @@
|
||||
"""Agent loop: the core processing engine."""
|
||||
|
||||
import asyncio
|
||||
from contextlib import AsyncExitStack
|
||||
import json
|
||||
import json_repair
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
@ -18,14 +20,15 @@ from nanobot.agent.tools.web import WebSearchTool, WebFetchTool
|
||||
from nanobot.agent.tools.message import MessageTool
|
||||
from nanobot.agent.tools.spawn import SpawnTool
|
||||
from nanobot.agent.tools.cron import CronTool
|
||||
from nanobot.agent.memory import MemoryStore
|
||||
from nanobot.agent.subagent import SubagentManager
|
||||
from nanobot.session.manager import SessionManager
|
||||
from nanobot.session.manager import Session, SessionManager
|
||||
|
||||
|
||||
class AgentLoop:
|
||||
"""
|
||||
The agent loop is the core processing engine.
|
||||
|
||||
|
||||
It:
|
||||
1. Receives messages from the bus
|
||||
2. Builds context with history, memory, skills
|
||||
@ -33,7 +36,7 @@ class AgentLoop:
|
||||
4. Executes tool calls
|
||||
5. Sends responses back
|
||||
"""
|
||||
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
bus: MessageBus,
|
||||
@ -41,11 +44,15 @@ class AgentLoop:
|
||||
workspace: Path,
|
||||
model: str | None = None,
|
||||
max_iterations: int = 20,
|
||||
temperature: float = 0.7,
|
||||
max_tokens: int = 4096,
|
||||
memory_window: int = 50,
|
||||
brave_api_key: str | None = None,
|
||||
exec_config: "ExecToolConfig | None" = None,
|
||||
cron_service: "CronService | None" = None,
|
||||
restrict_to_workspace: bool = False,
|
||||
session_manager: SessionManager | None = None,
|
||||
mcp_servers: dict | None = None,
|
||||
):
|
||||
from nanobot.config.schema import ExecToolConfig
|
||||
from nanobot.cron.service import CronService
|
||||
@ -54,11 +61,14 @@ class AgentLoop:
|
||||
self.workspace = workspace
|
||||
self.model = model or provider.get_default_model()
|
||||
self.max_iterations = max_iterations
|
||||
self.temperature = temperature
|
||||
self.max_tokens = max_tokens
|
||||
self.memory_window = memory_window
|
||||
self.brave_api_key = brave_api_key
|
||||
self.exec_config = exec_config or ExecToolConfig()
|
||||
self.cron_service = cron_service
|
||||
self.restrict_to_workspace = restrict_to_workspace
|
||||
|
||||
|
||||
self.context = ContextBuilder(workspace)
|
||||
self.sessions = session_manager or SessionManager(workspace)
|
||||
self.tools = ToolRegistry()
|
||||
@ -67,12 +77,17 @@ class AgentLoop:
|
||||
workspace=workspace,
|
||||
bus=bus,
|
||||
model=self.model,
|
||||
temperature=self.temperature,
|
||||
max_tokens=self.max_tokens,
|
||||
brave_api_key=brave_api_key,
|
||||
exec_config=self.exec_config,
|
||||
restrict_to_workspace=restrict_to_workspace,
|
||||
)
|
||||
|
||||
self._running = False
|
||||
self._mcp_servers = mcp_servers or {}
|
||||
self._mcp_stack: AsyncExitStack | None = None
|
||||
self._mcp_connected = False
|
||||
self._register_default_tools()
|
||||
|
||||
def _register_default_tools(self) -> None:
|
||||
@ -107,107 +122,64 @@ class AgentLoop:
|
||||
if self.cron_service:
|
||||
self.tools.register(CronTool(self.cron_service))
|
||||
|
||||
async def run(self) -> None:
|
||||
"""Run the agent loop, processing messages from the bus."""
|
||||
self._running = True
|
||||
logger.info("Agent loop started")
|
||||
|
||||
while self._running:
|
||||
try:
|
||||
# Wait for next message
|
||||
msg = await asyncio.wait_for(
|
||||
self.bus.consume_inbound(),
|
||||
timeout=1.0
|
||||
)
|
||||
|
||||
# Process it
|
||||
try:
|
||||
response = await self._process_message(msg)
|
||||
if response:
|
||||
await self.bus.publish_outbound(response)
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing message: {e}")
|
||||
# Send error response
|
||||
await self.bus.publish_outbound(OutboundMessage(
|
||||
channel=msg.channel,
|
||||
chat_id=msg.chat_id,
|
||||
content=f"Sorry, I encountered an error: {str(e)}"
|
||||
))
|
||||
except asyncio.TimeoutError:
|
||||
continue
|
||||
|
||||
def stop(self) -> None:
|
||||
"""Stop the agent loop."""
|
||||
self._running = False
|
||||
logger.info("Agent loop stopping")
|
||||
|
||||
async def _process_message(self, msg: InboundMessage) -> OutboundMessage | None:
|
||||
async def _connect_mcp(self) -> None:
|
||||
"""Connect to configured MCP servers (one-time, lazy)."""
|
||||
if self._mcp_connected or not self._mcp_servers:
|
||||
return
|
||||
self._mcp_connected = True
|
||||
from nanobot.agent.tools.mcp import connect_mcp_servers
|
||||
self._mcp_stack = AsyncExitStack()
|
||||
await self._mcp_stack.__aenter__()
|
||||
await connect_mcp_servers(self._mcp_servers, self.tools, self._mcp_stack)
|
||||
|
||||
def _set_tool_context(self, channel: str, chat_id: str) -> None:
|
||||
"""Update context for all tools that need routing info."""
|
||||
if message_tool := self.tools.get("message"):
|
||||
if isinstance(message_tool, MessageTool):
|
||||
message_tool.set_context(channel, chat_id)
|
||||
|
||||
if spawn_tool := self.tools.get("spawn"):
|
||||
if isinstance(spawn_tool, SpawnTool):
|
||||
spawn_tool.set_context(channel, chat_id)
|
||||
|
||||
if cron_tool := self.tools.get("cron"):
|
||||
if isinstance(cron_tool, CronTool):
|
||||
cron_tool.set_context(channel, chat_id)
|
||||
|
||||
async def _run_agent_loop(self, initial_messages: list[dict]) -> tuple[str | None, list[str]]:
|
||||
"""
|
||||
Process a single inbound message.
|
||||
|
||||
Run the agent iteration loop.
|
||||
|
||||
Args:
|
||||
msg: The inbound message to process.
|
||||
|
||||
initial_messages: Starting messages for the LLM conversation.
|
||||
|
||||
Returns:
|
||||
The response message, or None if no response needed.
|
||||
Tuple of (final_content, list_of_tools_used).
|
||||
"""
|
||||
# Handle system messages (subagent announces)
|
||||
# The chat_id contains the original "channel:chat_id" to route back to
|
||||
if msg.channel == "system":
|
||||
return await self._process_system_message(msg)
|
||||
|
||||
preview = msg.content[:80] + "..." if len(msg.content) > 80 else msg.content
|
||||
logger.info(f"Processing message from {msg.channel}:{msg.sender_id}: {preview}")
|
||||
|
||||
# Get or create session
|
||||
session = self.sessions.get_or_create(msg.session_key)
|
||||
|
||||
# Update tool contexts
|
||||
message_tool = self.tools.get("message")
|
||||
if isinstance(message_tool, MessageTool):
|
||||
message_tool.set_context(msg.channel, msg.chat_id)
|
||||
|
||||
spawn_tool = self.tools.get("spawn")
|
||||
if isinstance(spawn_tool, SpawnTool):
|
||||
spawn_tool.set_context(msg.channel, msg.chat_id)
|
||||
|
||||
cron_tool = self.tools.get("cron")
|
||||
if isinstance(cron_tool, CronTool):
|
||||
cron_tool.set_context(msg.channel, msg.chat_id)
|
||||
|
||||
# Build initial messages (use get_history for LLM-formatted messages)
|
||||
messages = self.context.build_messages(
|
||||
history=session.get_history(),
|
||||
current_message=msg.content,
|
||||
media=msg.media if msg.media else None,
|
||||
channel=msg.channel,
|
||||
chat_id=msg.chat_id,
|
||||
)
|
||||
|
||||
# Agent loop
|
||||
messages = initial_messages
|
||||
iteration = 0
|
||||
final_content = None
|
||||
|
||||
tools_used: list[str] = []
|
||||
|
||||
while iteration < self.max_iterations:
|
||||
iteration += 1
|
||||
|
||||
# Call LLM
|
||||
|
||||
response = await self.provider.chat(
|
||||
messages=messages,
|
||||
tools=self.tools.get_definitions(),
|
||||
model=self.model
|
||||
model=self.model,
|
||||
temperature=self.temperature,
|
||||
max_tokens=self.max_tokens,
|
||||
)
|
||||
|
||||
# Handle tool calls
|
||||
|
||||
if response.has_tool_calls:
|
||||
# Add assistant message with tool calls
|
||||
tool_call_dicts = [
|
||||
{
|
||||
"id": tc.id,
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": tc.name,
|
||||
"arguments": json.dumps(tc.arguments) # Must be JSON string
|
||||
"arguments": json.dumps(tc.arguments)
|
||||
}
|
||||
}
|
||||
for tc in response.tool_calls
|
||||
@ -216,30 +188,126 @@ class AgentLoop:
|
||||
messages, response.content, tool_call_dicts,
|
||||
reasoning_content=response.reasoning_content,
|
||||
)
|
||||
|
||||
# Execute tools
|
||||
|
||||
for tool_call in response.tool_calls:
|
||||
tools_used.append(tool_call.name)
|
||||
args_str = json.dumps(tool_call.arguments, ensure_ascii=False)
|
||||
logger.info(f"Tool call: {tool_call.name}({args_str[:200]})")
|
||||
result = await self.tools.execute(tool_call.name, tool_call.arguments)
|
||||
messages = self.context.add_tool_result(
|
||||
messages, tool_call.id, tool_call.name, result
|
||||
)
|
||||
messages.append({"role": "user", "content": "Reflect on the results and decide next steps."})
|
||||
else:
|
||||
# No tool calls, we're done
|
||||
final_content = response.content
|
||||
break
|
||||
|
||||
return final_content, tools_used
|
||||
|
||||
async def run(self) -> None:
|
||||
"""Run the agent loop, processing messages from the bus."""
|
||||
self._running = True
|
||||
await self._connect_mcp()
|
||||
logger.info("Agent loop started")
|
||||
|
||||
while self._running:
|
||||
try:
|
||||
msg = await asyncio.wait_for(
|
||||
self.bus.consume_inbound(),
|
||||
timeout=1.0
|
||||
)
|
||||
try:
|
||||
response = await self._process_message(msg)
|
||||
if response:
|
||||
await self.bus.publish_outbound(response)
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing message: {e}")
|
||||
await self.bus.publish_outbound(OutboundMessage(
|
||||
channel=msg.channel,
|
||||
chat_id=msg.chat_id,
|
||||
content=f"Sorry, I encountered an error: {str(e)}"
|
||||
))
|
||||
except asyncio.TimeoutError:
|
||||
continue
|
||||
|
||||
async def close_mcp(self) -> None:
|
||||
"""Close MCP connections."""
|
||||
if self._mcp_stack:
|
||||
try:
|
||||
await self._mcp_stack.aclose()
|
||||
except (RuntimeError, BaseExceptionGroup):
|
||||
pass # MCP SDK cancel scope cleanup is noisy but harmless
|
||||
self._mcp_stack = None
|
||||
|
||||
def stop(self) -> None:
|
||||
"""Stop the agent loop."""
|
||||
self._running = False
|
||||
logger.info("Agent loop stopping")
|
||||
|
||||
async def _process_message(self, msg: InboundMessage, session_key: str | None = None) -> OutboundMessage | None:
|
||||
"""
|
||||
Process a single inbound message.
|
||||
|
||||
Args:
|
||||
msg: The inbound message to process.
|
||||
session_key: Override session key (used by process_direct).
|
||||
|
||||
Returns:
|
||||
The response message, or None if no response needed.
|
||||
"""
|
||||
# System messages route back via chat_id ("channel:chat_id")
|
||||
if msg.channel == "system":
|
||||
return await self._process_system_message(msg)
|
||||
|
||||
preview = msg.content[:80] + "..." if len(msg.content) > 80 else msg.content
|
||||
logger.info(f"Processing message from {msg.channel}:{msg.sender_id}: {preview}")
|
||||
|
||||
key = session_key or msg.session_key
|
||||
session = self.sessions.get_or_create(key)
|
||||
|
||||
# Handle slash commands
|
||||
cmd = msg.content.strip().lower()
|
||||
if cmd == "/new":
|
||||
# Capture messages before clearing (avoid race condition with background task)
|
||||
messages_to_archive = session.messages.copy()
|
||||
session.clear()
|
||||
self.sessions.save(session)
|
||||
self.sessions.invalidate(session.key)
|
||||
|
||||
async def _consolidate_and_cleanup():
|
||||
temp_session = Session(key=session.key)
|
||||
temp_session.messages = messages_to_archive
|
||||
await self._consolidate_memory(temp_session, archive_all=True)
|
||||
|
||||
asyncio.create_task(_consolidate_and_cleanup())
|
||||
return OutboundMessage(channel=msg.channel, chat_id=msg.chat_id,
|
||||
content="New session started. Memory consolidation in progress.")
|
||||
if cmd == "/help":
|
||||
return OutboundMessage(channel=msg.channel, chat_id=msg.chat_id,
|
||||
content="🐈 nanobot commands:\n/new — Start a new conversation\n/help — Show available commands")
|
||||
|
||||
if len(session.messages) > self.memory_window:
|
||||
asyncio.create_task(self._consolidate_memory(session))
|
||||
|
||||
self._set_tool_context(msg.channel, msg.chat_id)
|
||||
initial_messages = self.context.build_messages(
|
||||
history=session.get_history(max_messages=self.memory_window),
|
||||
current_message=msg.content,
|
||||
media=msg.media if msg.media else None,
|
||||
channel=msg.channel,
|
||||
chat_id=msg.chat_id,
|
||||
)
|
||||
final_content, tools_used = await self._run_agent_loop(initial_messages)
|
||||
|
||||
if final_content is None:
|
||||
final_content = "I've completed processing but have no response to give."
|
||||
|
||||
# Log response preview
|
||||
preview = final_content[:120] + "..." if len(final_content) > 120 else final_content
|
||||
logger.info(f"Response to {msg.channel}:{msg.sender_id}: {preview}")
|
||||
|
||||
# Save to session
|
||||
session.add_message("user", msg.content)
|
||||
session.add_message("assistant", final_content)
|
||||
session.add_message("assistant", final_content,
|
||||
tools_used=tools_used if tools_used else None)
|
||||
self.sessions.save(session)
|
||||
|
||||
return OutboundMessage(
|
||||
@ -268,76 +336,20 @@ class AgentLoop:
|
||||
origin_channel = "cli"
|
||||
origin_chat_id = msg.chat_id
|
||||
|
||||
# Use the origin session for context
|
||||
session_key = f"{origin_channel}:{origin_chat_id}"
|
||||
session = self.sessions.get_or_create(session_key)
|
||||
|
||||
# Update tool contexts
|
||||
message_tool = self.tools.get("message")
|
||||
if isinstance(message_tool, MessageTool):
|
||||
message_tool.set_context(origin_channel, origin_chat_id)
|
||||
|
||||
spawn_tool = self.tools.get("spawn")
|
||||
if isinstance(spawn_tool, SpawnTool):
|
||||
spawn_tool.set_context(origin_channel, origin_chat_id)
|
||||
|
||||
cron_tool = self.tools.get("cron")
|
||||
if isinstance(cron_tool, CronTool):
|
||||
cron_tool.set_context(origin_channel, origin_chat_id)
|
||||
|
||||
# Build messages with the announce content
|
||||
messages = self.context.build_messages(
|
||||
history=session.get_history(),
|
||||
self._set_tool_context(origin_channel, origin_chat_id)
|
||||
initial_messages = self.context.build_messages(
|
||||
history=session.get_history(max_messages=self.memory_window),
|
||||
current_message=msg.content,
|
||||
channel=origin_channel,
|
||||
chat_id=origin_chat_id,
|
||||
)
|
||||
|
||||
# Agent loop (limited for announce handling)
|
||||
iteration = 0
|
||||
final_content = None
|
||||
|
||||
while iteration < self.max_iterations:
|
||||
iteration += 1
|
||||
|
||||
response = await self.provider.chat(
|
||||
messages=messages,
|
||||
tools=self.tools.get_definitions(),
|
||||
model=self.model
|
||||
)
|
||||
|
||||
if response.has_tool_calls:
|
||||
tool_call_dicts = [
|
||||
{
|
||||
"id": tc.id,
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": tc.name,
|
||||
"arguments": json.dumps(tc.arguments)
|
||||
}
|
||||
}
|
||||
for tc in response.tool_calls
|
||||
]
|
||||
messages = self.context.add_assistant_message(
|
||||
messages, response.content, tool_call_dicts,
|
||||
reasoning_content=response.reasoning_content,
|
||||
)
|
||||
|
||||
for tool_call in response.tool_calls:
|
||||
args_str = json.dumps(tool_call.arguments, ensure_ascii=False)
|
||||
logger.info(f"Tool call: {tool_call.name}({args_str[:200]})")
|
||||
result = await self.tools.execute(tool_call.name, tool_call.arguments)
|
||||
messages = self.context.add_tool_result(
|
||||
messages, tool_call.id, tool_call.name, result
|
||||
)
|
||||
else:
|
||||
final_content = response.content
|
||||
break
|
||||
|
||||
final_content, _ = await self._run_agent_loop(initial_messages)
|
||||
|
||||
if final_content is None:
|
||||
final_content = "Background task completed."
|
||||
|
||||
# Save to session (mark as system message in history)
|
||||
session.add_message("user", f"[System: {msg.sender_id}] {msg.content}")
|
||||
session.add_message("assistant", final_content)
|
||||
self.sessions.save(session)
|
||||
@ -348,6 +360,91 @@ class AgentLoop:
|
||||
content=final_content
|
||||
)
|
||||
|
||||
async def _consolidate_memory(self, session, archive_all: bool = False) -> None:
|
||||
"""Consolidate old messages into MEMORY.md + HISTORY.md.
|
||||
|
||||
Args:
|
||||
archive_all: If True, clear all messages and reset session (for /new command).
|
||||
If False, only write to files without modifying session.
|
||||
"""
|
||||
memory = MemoryStore(self.workspace)
|
||||
|
||||
if archive_all:
|
||||
old_messages = session.messages
|
||||
keep_count = 0
|
||||
logger.info(f"Memory consolidation (archive_all): {len(session.messages)} total messages archived")
|
||||
else:
|
||||
keep_count = self.memory_window // 2
|
||||
if len(session.messages) <= keep_count:
|
||||
logger.debug(f"Session {session.key}: No consolidation needed (messages={len(session.messages)}, keep={keep_count})")
|
||||
return
|
||||
|
||||
messages_to_process = len(session.messages) - session.last_consolidated
|
||||
if messages_to_process <= 0:
|
||||
logger.debug(f"Session {session.key}: No new messages to consolidate (last_consolidated={session.last_consolidated}, total={len(session.messages)})")
|
||||
return
|
||||
|
||||
old_messages = session.messages[session.last_consolidated:-keep_count]
|
||||
if not old_messages:
|
||||
return
|
||||
logger.info(f"Memory consolidation started: {len(session.messages)} total, {len(old_messages)} new to consolidate, {keep_count} keep")
|
||||
|
||||
lines = []
|
||||
for m in old_messages:
|
||||
if not m.get("content"):
|
||||
continue
|
||||
tools = f" [tools: {', '.join(m['tools_used'])}]" if m.get("tools_used") else ""
|
||||
lines.append(f"[{m.get('timestamp', '?')[:16]}] {m['role'].upper()}{tools}: {m['content']}")
|
||||
conversation = "\n".join(lines)
|
||||
current_memory = memory.read_long_term()
|
||||
|
||||
prompt = f"""You are a memory consolidation agent. Process this conversation and return a JSON object with exactly two keys:
|
||||
|
||||
1. "history_entry": A paragraph (2-5 sentences) summarizing the key events/decisions/topics. Start with a timestamp like [YYYY-MM-DD HH:MM]. Include enough detail to be useful when found by grep search later.
|
||||
|
||||
2. "memory_update": The updated long-term memory content. Add any new facts: user location, preferences, personal info, habits, project context, technical decisions, tools/services used. If nothing new, return the existing content unchanged.
|
||||
|
||||
## Current Long-term Memory
|
||||
{current_memory or "(empty)"}
|
||||
|
||||
## Conversation to Process
|
||||
{conversation}
|
||||
|
||||
Respond with ONLY valid JSON, no markdown fences."""
|
||||
|
||||
try:
|
||||
response = await self.provider.chat(
|
||||
messages=[
|
||||
{"role": "system", "content": "You are a memory consolidation agent. Respond only with valid JSON."},
|
||||
{"role": "user", "content": prompt},
|
||||
],
|
||||
model=self.model,
|
||||
)
|
||||
text = (response.content or "").strip()
|
||||
if not text:
|
||||
logger.warning("Memory consolidation: LLM returned empty response, skipping")
|
||||
return
|
||||
if text.startswith("```"):
|
||||
text = text.split("\n", 1)[-1].rsplit("```", 1)[0].strip()
|
||||
result = json_repair.loads(text)
|
||||
if not isinstance(result, dict):
|
||||
logger.warning(f"Memory consolidation: unexpected response type, skipping. Response: {text[:200]}")
|
||||
return
|
||||
|
||||
if entry := result.get("history_entry"):
|
||||
memory.append_history(entry)
|
||||
if update := result.get("memory_update"):
|
||||
if update != current_memory:
|
||||
memory.write_long_term(update)
|
||||
|
||||
if archive_all:
|
||||
session.last_consolidated = 0
|
||||
else:
|
||||
session.last_consolidated = len(session.messages) - keep_count
|
||||
logger.info(f"Memory consolidation done: {len(session.messages)} messages, last_consolidated={session.last_consolidated}")
|
||||
except Exception as e:
|
||||
logger.error(f"Memory consolidation failed: {e}")
|
||||
|
||||
async def process_direct(
|
||||
self,
|
||||
content: str,
|
||||
@ -360,13 +457,14 @@ class AgentLoop:
|
||||
|
||||
Args:
|
||||
content: The message content.
|
||||
session_key: Session identifier.
|
||||
channel: Source channel (for context).
|
||||
chat_id: Source chat ID (for context).
|
||||
session_key: Session identifier (overrides channel:chat_id for session lookup).
|
||||
channel: Source channel (for tool context routing).
|
||||
chat_id: Source chat ID (for tool context routing).
|
||||
|
||||
Returns:
|
||||
The agent's response.
|
||||
"""
|
||||
await self._connect_mcp()
|
||||
msg = InboundMessage(
|
||||
channel=channel,
|
||||
sender_id="user",
|
||||
@ -374,5 +472,5 @@ class AgentLoop:
|
||||
content=content
|
||||
)
|
||||
|
||||
response = await self._process_message(msg)
|
||||
response = await self._process_message(msg, session_key=session_key)
|
||||
return response.content if response else ""
|
||||
|
||||
@ -1,109 +1,30 @@
|
||||
"""Memory system for persistent agent memory."""
|
||||
|
||||
from pathlib import Path
|
||||
from datetime import datetime
|
||||
|
||||
from nanobot.utils.helpers import ensure_dir, today_date
|
||||
from nanobot.utils.helpers import ensure_dir
|
||||
|
||||
|
||||
class MemoryStore:
|
||||
"""
|
||||
Memory system for the agent.
|
||||
|
||||
Supports daily notes (memory/YYYY-MM-DD.md) and long-term memory (MEMORY.md).
|
||||
"""
|
||||
|
||||
"""Two-layer memory: MEMORY.md (long-term facts) + HISTORY.md (grep-searchable log)."""
|
||||
|
||||
def __init__(self, workspace: Path):
|
||||
self.workspace = workspace
|
||||
self.memory_dir = ensure_dir(workspace / "memory")
|
||||
self.memory_file = self.memory_dir / "MEMORY.md"
|
||||
|
||||
def get_today_file(self) -> Path:
|
||||
"""Get path to today's memory file."""
|
||||
return self.memory_dir / f"{today_date()}.md"
|
||||
|
||||
def read_today(self) -> str:
|
||||
"""Read today's memory notes."""
|
||||
today_file = self.get_today_file()
|
||||
if today_file.exists():
|
||||
return today_file.read_text(encoding="utf-8")
|
||||
return ""
|
||||
|
||||
def append_today(self, content: str) -> None:
|
||||
"""Append content to today's memory notes."""
|
||||
today_file = self.get_today_file()
|
||||
|
||||
if today_file.exists():
|
||||
existing = today_file.read_text(encoding="utf-8")
|
||||
content = existing + "\n" + content
|
||||
else:
|
||||
# Add header for new day
|
||||
header = f"# {today_date()}\n\n"
|
||||
content = header + content
|
||||
|
||||
today_file.write_text(content, encoding="utf-8")
|
||||
|
||||
self.history_file = self.memory_dir / "HISTORY.md"
|
||||
|
||||
def read_long_term(self) -> str:
|
||||
"""Read long-term memory (MEMORY.md)."""
|
||||
if self.memory_file.exists():
|
||||
return self.memory_file.read_text(encoding="utf-8")
|
||||
return ""
|
||||
|
||||
|
||||
def write_long_term(self, content: str) -> None:
|
||||
"""Write to long-term memory (MEMORY.md)."""
|
||||
self.memory_file.write_text(content, encoding="utf-8")
|
||||
|
||||
def get_recent_memories(self, days: int = 7) -> str:
|
||||
"""
|
||||
Get memories from the last N days.
|
||||
|
||||
Args:
|
||||
days: Number of days to look back.
|
||||
|
||||
Returns:
|
||||
Combined memory content.
|
||||
"""
|
||||
from datetime import timedelta
|
||||
|
||||
memories = []
|
||||
today = datetime.now().date()
|
||||
|
||||
for i in range(days):
|
||||
date = today - timedelta(days=i)
|
||||
date_str = date.strftime("%Y-%m-%d")
|
||||
file_path = self.memory_dir / f"{date_str}.md"
|
||||
|
||||
if file_path.exists():
|
||||
content = file_path.read_text(encoding="utf-8")
|
||||
memories.append(content)
|
||||
|
||||
return "\n\n---\n\n".join(memories)
|
||||
|
||||
def list_memory_files(self) -> list[Path]:
|
||||
"""List all memory files sorted by date (newest first)."""
|
||||
if not self.memory_dir.exists():
|
||||
return []
|
||||
|
||||
files = list(self.memory_dir.glob("????-??-??.md"))
|
||||
return sorted(files, reverse=True)
|
||||
|
||||
|
||||
def append_history(self, entry: str) -> None:
|
||||
with open(self.history_file, "a", encoding="utf-8") as f:
|
||||
f.write(entry.rstrip() + "\n\n")
|
||||
|
||||
def get_memory_context(self) -> str:
|
||||
"""
|
||||
Get memory context for the agent.
|
||||
|
||||
Returns:
|
||||
Formatted memory context including long-term and recent memories.
|
||||
"""
|
||||
parts = []
|
||||
|
||||
# Long-term memory
|
||||
long_term = self.read_long_term()
|
||||
if long_term:
|
||||
parts.append("## Long-term Memory\n" + long_term)
|
||||
|
||||
# Today's notes
|
||||
today = self.read_today()
|
||||
if today:
|
||||
parts.append("## Today's Notes\n" + today)
|
||||
|
||||
return "\n\n".join(parts) if parts else ""
|
||||
return f"## Long-term Memory\n{long_term}" if long_term else ""
|
||||
|
||||
@ -12,7 +12,7 @@ from nanobot.bus.events import InboundMessage
|
||||
from nanobot.bus.queue import MessageBus
|
||||
from nanobot.providers.base import LLMProvider
|
||||
from nanobot.agent.tools.registry import ToolRegistry
|
||||
from nanobot.agent.tools.filesystem import ReadFileTool, WriteFileTool, ListDirTool
|
||||
from nanobot.agent.tools.filesystem import ReadFileTool, WriteFileTool, EditFileTool, ListDirTool
|
||||
from nanobot.agent.tools.shell import ExecTool
|
||||
from nanobot.agent.tools.web import WebSearchTool, WebFetchTool
|
||||
|
||||
@ -32,6 +32,8 @@ class SubagentManager:
|
||||
workspace: Path,
|
||||
bus: MessageBus,
|
||||
model: str | None = None,
|
||||
temperature: float = 0.7,
|
||||
max_tokens: int = 4096,
|
||||
brave_api_key: str | None = None,
|
||||
exec_config: "ExecToolConfig | None" = None,
|
||||
restrict_to_workspace: bool = False,
|
||||
@ -41,6 +43,8 @@ class SubagentManager:
|
||||
self.workspace = workspace
|
||||
self.bus = bus
|
||||
self.model = model or provider.get_default_model()
|
||||
self.temperature = temperature
|
||||
self.max_tokens = max_tokens
|
||||
self.brave_api_key = brave_api_key
|
||||
self.exec_config = exec_config or ExecToolConfig()
|
||||
self.restrict_to_workspace = restrict_to_workspace
|
||||
@ -101,6 +105,7 @@ class SubagentManager:
|
||||
allowed_dir = self.workspace if self.restrict_to_workspace else None
|
||||
tools.register(ReadFileTool(allowed_dir=allowed_dir))
|
||||
tools.register(WriteFileTool(allowed_dir=allowed_dir))
|
||||
tools.register(EditFileTool(allowed_dir=allowed_dir))
|
||||
tools.register(ListDirTool(allowed_dir=allowed_dir))
|
||||
tools.register(ExecTool(
|
||||
working_dir=str(self.workspace),
|
||||
@ -129,6 +134,8 @@ class SubagentManager:
|
||||
messages=messages,
|
||||
tools=tools.get_definitions(),
|
||||
model=self.model,
|
||||
temperature=self.temperature,
|
||||
max_tokens=self.max_tokens,
|
||||
)
|
||||
|
||||
if response.has_tool_calls:
|
||||
@ -210,12 +217,17 @@ Summarize this naturally for the user. Keep it brief (1-2 sentences). Do not men
|
||||
|
||||
def _build_subagent_prompt(self, task: str) -> str:
|
||||
"""Build a focused system prompt for the subagent."""
|
||||
from datetime import datetime
|
||||
import time as _time
|
||||
now = datetime.now().strftime("%Y-%m-%d %H:%M (%A)")
|
||||
tz = _time.strftime("%Z") or "UTC"
|
||||
|
||||
return f"""# Subagent
|
||||
|
||||
You are a subagent spawned by the main agent to complete a specific task.
|
||||
## Current Time
|
||||
{now} ({tz})
|
||||
|
||||
## Your Task
|
||||
{task}
|
||||
You are a subagent spawned by the main agent to complete a specific task.
|
||||
|
||||
## Rules
|
||||
1. Stay focused - complete only the assigned task, nothing else
|
||||
@ -236,6 +248,7 @@ You are a subagent spawned by the main agent to complete a specific task.
|
||||
|
||||
## Workspace
|
||||
Your workspace is at: {self.workspace}
|
||||
Skills are available at: {self.workspace}/skills/ (read SKILL.md files as needed)
|
||||
|
||||
When you have completed the task, provide a clear summary of your findings or actions."""
|
||||
|
||||
|
||||
@ -50,6 +50,10 @@ class CronTool(Tool):
|
||||
"type": "string",
|
||||
"description": "Cron expression like '0 9 * * *' (for scheduled tasks)"
|
||||
},
|
||||
"at": {
|
||||
"type": "string",
|
||||
"description": "ISO datetime for one-time execution (e.g. '2026-02-12T10:30:00')"
|
||||
},
|
||||
"job_id": {
|
||||
"type": "string",
|
||||
"description": "Job ID (for remove)"
|
||||
@ -64,30 +68,38 @@ class CronTool(Tool):
|
||||
message: str = "",
|
||||
every_seconds: int | None = None,
|
||||
cron_expr: str | None = None,
|
||||
at: str | None = None,
|
||||
job_id: str | None = None,
|
||||
**kwargs: Any
|
||||
) -> str:
|
||||
if action == "add":
|
||||
return self._add_job(message, every_seconds, cron_expr)
|
||||
return self._add_job(message, every_seconds, cron_expr, at)
|
||||
elif action == "list":
|
||||
return self._list_jobs()
|
||||
elif action == "remove":
|
||||
return self._remove_job(job_id)
|
||||
return f"Unknown action: {action}"
|
||||
|
||||
def _add_job(self, message: str, every_seconds: int | None, cron_expr: str | None) -> str:
|
||||
def _add_job(self, message: str, every_seconds: int | None, cron_expr: str | None, at: str | None) -> str:
|
||||
if not message:
|
||||
return "Error: message is required for add"
|
||||
if not self._channel or not self._chat_id:
|
||||
return "Error: no session context (channel/chat_id)"
|
||||
|
||||
# Build schedule
|
||||
delete_after = False
|
||||
if every_seconds:
|
||||
schedule = CronSchedule(kind="every", every_ms=every_seconds * 1000)
|
||||
elif cron_expr:
|
||||
schedule = CronSchedule(kind="cron", expr=cron_expr)
|
||||
elif at:
|
||||
from datetime import datetime
|
||||
dt = datetime.fromisoformat(at)
|
||||
at_ms = int(dt.timestamp() * 1000)
|
||||
schedule = CronSchedule(kind="at", at_ms=at_ms)
|
||||
delete_after = True
|
||||
else:
|
||||
return "Error: either every_seconds or cron_expr is required"
|
||||
return "Error: either every_seconds, cron_expr, or at is required"
|
||||
|
||||
job = self._cron.add_job(
|
||||
name=message[:30],
|
||||
@ -96,6 +108,7 @@ class CronTool(Tool):
|
||||
deliver=True,
|
||||
channel=self._channel,
|
||||
to=self._chat_id,
|
||||
delete_after_run=delete_after,
|
||||
)
|
||||
return f"Created job '{job.name}' (id: {job.id})"
|
||||
|
||||
|
||||
80
nanobot/agent/tools/mcp.py
Normal file
80
nanobot/agent/tools/mcp.py
Normal file
@ -0,0 +1,80 @@
|
||||
"""MCP client: connects to MCP servers and wraps their tools as native nanobot tools."""
|
||||
|
||||
from contextlib import AsyncExitStack
|
||||
from typing import Any
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from nanobot.agent.tools.base import Tool
|
||||
from nanobot.agent.tools.registry import ToolRegistry
|
||||
|
||||
|
||||
class MCPToolWrapper(Tool):
|
||||
"""Wraps a single MCP server tool as a nanobot Tool."""
|
||||
|
||||
def __init__(self, session, server_name: str, tool_def):
|
||||
self._session = session
|
||||
self._original_name = tool_def.name
|
||||
self._name = f"mcp_{server_name}_{tool_def.name}"
|
||||
self._description = tool_def.description or tool_def.name
|
||||
self._parameters = tool_def.inputSchema or {"type": "object", "properties": {}}
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return self._name
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return self._description
|
||||
|
||||
@property
|
||||
def parameters(self) -> dict[str, Any]:
|
||||
return self._parameters
|
||||
|
||||
async def execute(self, **kwargs: Any) -> str:
|
||||
from mcp import types
|
||||
result = await self._session.call_tool(self._original_name, arguments=kwargs)
|
||||
parts = []
|
||||
for block in result.content:
|
||||
if isinstance(block, types.TextContent):
|
||||
parts.append(block.text)
|
||||
else:
|
||||
parts.append(str(block))
|
||||
return "\n".join(parts) or "(no output)"
|
||||
|
||||
|
||||
async def connect_mcp_servers(
|
||||
mcp_servers: dict, registry: ToolRegistry, stack: AsyncExitStack
|
||||
) -> None:
|
||||
"""Connect to configured MCP servers and register their tools."""
|
||||
from mcp import ClientSession, StdioServerParameters
|
||||
from mcp.client.stdio import stdio_client
|
||||
|
||||
for name, cfg in mcp_servers.items():
|
||||
try:
|
||||
if cfg.command:
|
||||
params = StdioServerParameters(
|
||||
command=cfg.command, args=cfg.args, env=cfg.env or None
|
||||
)
|
||||
read, write = await stack.enter_async_context(stdio_client(params))
|
||||
elif cfg.url:
|
||||
from mcp.client.streamable_http import streamable_http_client
|
||||
read, write, _ = await stack.enter_async_context(
|
||||
streamable_http_client(cfg.url)
|
||||
)
|
||||
else:
|
||||
logger.warning(f"MCP server '{name}': no command or url configured, skipping")
|
||||
continue
|
||||
|
||||
session = await stack.enter_async_context(ClientSession(read, write))
|
||||
await session.initialize()
|
||||
|
||||
tools = await session.list_tools()
|
||||
for tool_def in tools.tools:
|
||||
wrapper = MCPToolWrapper(session, name, tool_def)
|
||||
registry.register(wrapper)
|
||||
logger.debug(f"MCP: registered tool '{wrapper.name}' from server '{name}'")
|
||||
|
||||
logger.info(f"MCP server '{name}': connected, {len(tools.tools)} tools registered")
|
||||
except Exception as e:
|
||||
logger.error(f"MCP server '{name}': failed to connect: {e}")
|
||||
@ -128,14 +128,17 @@ class ExecTool(Tool):
|
||||
cwd_path = Path(cwd).resolve()
|
||||
|
||||
win_paths = re.findall(r"[A-Za-z]:\\[^\\\"']+", cmd)
|
||||
posix_paths = re.findall(r"/[^\s\"']+", cmd)
|
||||
# Only match absolute paths — avoid false positives on relative
|
||||
# paths like ".venv/bin/python" where "/bin/python" would be
|
||||
# incorrectly extracted by the old pattern.
|
||||
posix_paths = re.findall(r"(?:^|[\s|>])(/[^\s\"'>]+)", cmd)
|
||||
|
||||
for raw in win_paths + posix_paths:
|
||||
try:
|
||||
p = Path(raw).resolve()
|
||||
p = Path(raw.strip()).resolve()
|
||||
except Exception:
|
||||
continue
|
||||
if cwd_path not in p.parents and p != cwd_path:
|
||||
if p.is_absolute() and cwd_path not in p.parents and p != cwd_path:
|
||||
return "Error: Command blocked by safety guard (path outside working dir)"
|
||||
|
||||
return None
|
||||
|
||||
@ -137,8 +137,15 @@ class DingTalkChannel(BaseChannel):
|
||||
|
||||
logger.info("DingTalk bot started with Stream Mode")
|
||||
|
||||
# client.start() is an async infinite loop handling the websocket connection
|
||||
await self._client.start()
|
||||
# Reconnect loop: restart stream if SDK exits or crashes
|
||||
while self._running:
|
||||
try:
|
||||
await self._client.start()
|
||||
except Exception as e:
|
||||
logger.warning(f"DingTalk stream error: {e}")
|
||||
if self._running:
|
||||
logger.info("Reconnecting DingTalk stream in 5 seconds...")
|
||||
await asyncio.sleep(5)
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(f"Failed to start DingTalk channel: {e}")
|
||||
|
||||
@ -39,6 +39,53 @@ MSG_TYPE_MAP = {
|
||||
}
|
||||
|
||||
|
||||
def _extract_post_text(content_json: dict) -> str:
|
||||
"""Extract plain text from Feishu post (rich text) message content.
|
||||
|
||||
Supports two formats:
|
||||
1. Direct format: {"title": "...", "content": [...]}
|
||||
2. Localized format: {"zh_cn": {"title": "...", "content": [...]}}
|
||||
"""
|
||||
def extract_from_lang(lang_content: dict) -> str | None:
|
||||
if not isinstance(lang_content, dict):
|
||||
return None
|
||||
title = lang_content.get("title", "")
|
||||
content_blocks = lang_content.get("content", [])
|
||||
if not isinstance(content_blocks, list):
|
||||
return None
|
||||
text_parts = []
|
||||
if title:
|
||||
text_parts.append(title)
|
||||
for block in content_blocks:
|
||||
if not isinstance(block, list):
|
||||
continue
|
||||
for element in block:
|
||||
if isinstance(element, dict):
|
||||
tag = element.get("tag")
|
||||
if tag == "text":
|
||||
text_parts.append(element.get("text", ""))
|
||||
elif tag == "a":
|
||||
text_parts.append(element.get("text", ""))
|
||||
elif tag == "at":
|
||||
text_parts.append(f"@{element.get('user_name', 'user')}")
|
||||
return " ".join(text_parts).strip() if text_parts else None
|
||||
|
||||
# Try direct format first
|
||||
if "content" in content_json:
|
||||
result = extract_from_lang(content_json)
|
||||
if result:
|
||||
return result
|
||||
|
||||
# Try localized format
|
||||
for lang_key in ("zh_cn", "en_us", "ja_jp"):
|
||||
lang_content = content_json.get(lang_key)
|
||||
result = extract_from_lang(lang_content)
|
||||
if result:
|
||||
return result
|
||||
|
||||
return ""
|
||||
|
||||
|
||||
class FeishuChannel(BaseChannel):
|
||||
"""
|
||||
Feishu/Lark channel using WebSocket long connection.
|
||||
@ -98,12 +145,15 @@ class FeishuChannel(BaseChannel):
|
||||
log_level=lark.LogLevel.INFO
|
||||
)
|
||||
|
||||
# Start WebSocket client in a separate thread
|
||||
# Start WebSocket client in a separate thread with reconnect loop
|
||||
def run_ws():
|
||||
try:
|
||||
self._ws_client.start()
|
||||
except Exception as e:
|
||||
logger.error(f"Feishu WebSocket error: {e}")
|
||||
while self._running:
|
||||
try:
|
||||
self._ws_client.start()
|
||||
except Exception as e:
|
||||
logger.warning(f"Feishu WebSocket error: {e}")
|
||||
if self._running:
|
||||
import time; time.sleep(5)
|
||||
|
||||
self._ws_thread = threading.Thread(target=run_ws, daemon=True)
|
||||
self._ws_thread.start()
|
||||
@ -163,6 +213,10 @@ class FeishuChannel(BaseChannel):
|
||||
re.MULTILINE,
|
||||
)
|
||||
|
||||
_HEADING_RE = re.compile(r"^(#{1,6})\s+(.+)$", re.MULTILINE)
|
||||
|
||||
_CODE_BLOCK_RE = re.compile(r"(```[\s\S]*?```)", re.MULTILINE)
|
||||
|
||||
@staticmethod
|
||||
def _parse_md_table(table_text: str) -> dict | None:
|
||||
"""Parse a markdown table into a Feishu table element."""
|
||||
@ -182,17 +236,52 @@ class FeishuChannel(BaseChannel):
|
||||
}
|
||||
|
||||
def _build_card_elements(self, content: str) -> list[dict]:
|
||||
"""Split content into markdown + table elements for Feishu card."""
|
||||
"""Split content into div/markdown + table elements for Feishu card."""
|
||||
elements, last_end = [], 0
|
||||
for m in self._TABLE_RE.finditer(content):
|
||||
before = content[last_end:m.start()].strip()
|
||||
if before:
|
||||
elements.append({"tag": "markdown", "content": before})
|
||||
before = content[last_end:m.start()]
|
||||
if before.strip():
|
||||
elements.extend(self._split_headings(before))
|
||||
elements.append(self._parse_md_table(m.group(1)) or {"tag": "markdown", "content": m.group(1)})
|
||||
last_end = m.end()
|
||||
remaining = content[last_end:].strip()
|
||||
remaining = content[last_end:]
|
||||
if remaining.strip():
|
||||
elements.extend(self._split_headings(remaining))
|
||||
return elements or [{"tag": "markdown", "content": content}]
|
||||
|
||||
def _split_headings(self, content: str) -> list[dict]:
|
||||
"""Split content by headings, converting headings to div elements."""
|
||||
protected = content
|
||||
code_blocks = []
|
||||
for m in self._CODE_BLOCK_RE.finditer(content):
|
||||
code_blocks.append(m.group(1))
|
||||
protected = protected.replace(m.group(1), f"\x00CODE{len(code_blocks)-1}\x00", 1)
|
||||
|
||||
elements = []
|
||||
last_end = 0
|
||||
for m in self._HEADING_RE.finditer(protected):
|
||||
before = protected[last_end:m.start()].strip()
|
||||
if before:
|
||||
elements.append({"tag": "markdown", "content": before})
|
||||
level = len(m.group(1))
|
||||
text = m.group(2).strip()
|
||||
elements.append({
|
||||
"tag": "div",
|
||||
"text": {
|
||||
"tag": "lark_md",
|
||||
"content": f"**{text}**",
|
||||
},
|
||||
})
|
||||
last_end = m.end()
|
||||
remaining = protected[last_end:].strip()
|
||||
if remaining:
|
||||
elements.append({"tag": "markdown", "content": remaining})
|
||||
|
||||
for i, cb in enumerate(code_blocks):
|
||||
for el in elements:
|
||||
if el.get("tag") == "markdown":
|
||||
el["content"] = el["content"].replace(f"\x00CODE{i}\x00", cb)
|
||||
|
||||
return elements or [{"tag": "markdown", "content": content}]
|
||||
|
||||
async def send(self, msg: OutboundMessage) -> None:
|
||||
@ -284,6 +373,12 @@ class FeishuChannel(BaseChannel):
|
||||
content = json.loads(message.content).get("text", "")
|
||||
except json.JSONDecodeError:
|
||||
content = message.content or ""
|
||||
elif msg_type == "post":
|
||||
try:
|
||||
content_json = json.loads(message.content)
|
||||
content = _extract_post_text(content_json)
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
content = message.content or ""
|
||||
else:
|
||||
content = MSG_TYPE_MAP.get(msg_type, f"[{msg_type}]")
|
||||
|
||||
|
||||
@ -3,7 +3,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from typing import Any, TYPE_CHECKING
|
||||
from typing import Any
|
||||
|
||||
from loguru import logger
|
||||
|
||||
@ -12,9 +12,6 @@ from nanobot.bus.queue import MessageBus
|
||||
from nanobot.channels.base import BaseChannel
|
||||
from nanobot.config.schema import Config
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from nanobot.session.manager import SessionManager
|
||||
|
||||
|
||||
class ChannelManager:
|
||||
"""
|
||||
@ -26,10 +23,9 @@ class ChannelManager:
|
||||
- Route outbound messages
|
||||
"""
|
||||
|
||||
def __init__(self, config: Config, bus: MessageBus, session_manager: "SessionManager | None" = None):
|
||||
def __init__(self, config: Config, bus: MessageBus):
|
||||
self.config = config
|
||||
self.bus = bus
|
||||
self.session_manager = session_manager
|
||||
self.channels: dict[str, BaseChannel] = {}
|
||||
self._dispatch_task: asyncio.Task | None = None
|
||||
|
||||
@ -46,7 +42,6 @@ class ChannelManager:
|
||||
self.config.channels.telegram,
|
||||
self.bus,
|
||||
groq_api_key=self.config.providers.groq.api_key,
|
||||
session_manager=self.session_manager,
|
||||
)
|
||||
logger.info("Telegram channel enabled")
|
||||
except ImportError as e:
|
||||
@ -85,6 +80,18 @@ class ChannelManager:
|
||||
except ImportError as e:
|
||||
logger.warning(f"Feishu channel not available: {e}")
|
||||
|
||||
# Mochat channel
|
||||
if self.config.channels.mochat.enabled:
|
||||
try:
|
||||
from nanobot.channels.mochat import MochatChannel
|
||||
|
||||
self.channels["mochat"] = MochatChannel(
|
||||
self.config.channels.mochat, self.bus
|
||||
)
|
||||
logger.info("Mochat channel enabled")
|
||||
except ImportError as e:
|
||||
logger.warning(f"Mochat channel not available: {e}")
|
||||
|
||||
# DingTalk channel
|
||||
if self.config.channels.dingtalk.enabled:
|
||||
try:
|
||||
|
||||
895
nanobot/channels/mochat.py
Normal file
895
nanobot/channels/mochat.py
Normal file
@ -0,0 +1,895 @@
|
||||
"""Mochat channel implementation using Socket.IO with HTTP polling fallback."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
from collections import deque
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
import httpx
|
||||
from loguru import logger
|
||||
|
||||
from nanobot.bus.events import OutboundMessage
|
||||
from nanobot.bus.queue import MessageBus
|
||||
from nanobot.channels.base import BaseChannel
|
||||
from nanobot.config.schema import MochatConfig
|
||||
from nanobot.utils.helpers import get_data_path
|
||||
|
||||
try:
|
||||
import socketio
|
||||
SOCKETIO_AVAILABLE = True
|
||||
except ImportError:
|
||||
socketio = None
|
||||
SOCKETIO_AVAILABLE = False
|
||||
|
||||
try:
|
||||
import msgpack # noqa: F401
|
||||
MSGPACK_AVAILABLE = True
|
||||
except ImportError:
|
||||
MSGPACK_AVAILABLE = False
|
||||
|
||||
MAX_SEEN_MESSAGE_IDS = 2000
|
||||
CURSOR_SAVE_DEBOUNCE_S = 0.5
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Data classes
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@dataclass
|
||||
class MochatBufferedEntry:
|
||||
"""Buffered inbound entry for delayed dispatch."""
|
||||
raw_body: str
|
||||
author: str
|
||||
sender_name: str = ""
|
||||
sender_username: str = ""
|
||||
timestamp: int | None = None
|
||||
message_id: str = ""
|
||||
group_id: str = ""
|
||||
|
||||
|
||||
@dataclass
|
||||
class DelayState:
|
||||
"""Per-target delayed message state."""
|
||||
entries: list[MochatBufferedEntry] = field(default_factory=list)
|
||||
lock: asyncio.Lock = field(default_factory=asyncio.Lock)
|
||||
timer: asyncio.Task | None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class MochatTarget:
|
||||
"""Outbound target resolution result."""
|
||||
id: str
|
||||
is_panel: bool
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Pure helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _safe_dict(value: Any) -> dict:
|
||||
"""Return *value* if it's a dict, else empty dict."""
|
||||
return value if isinstance(value, dict) else {}
|
||||
|
||||
|
||||
def _str_field(src: dict, *keys: str) -> str:
|
||||
"""Return the first non-empty str value found for *keys*, stripped."""
|
||||
for k in keys:
|
||||
v = src.get(k)
|
||||
if isinstance(v, str) and v.strip():
|
||||
return v.strip()
|
||||
return ""
|
||||
|
||||
|
||||
def _make_synthetic_event(
|
||||
message_id: str, author: str, content: Any,
|
||||
meta: Any, group_id: str, converse_id: str,
|
||||
timestamp: Any = None, *, author_info: Any = None,
|
||||
) -> dict[str, Any]:
|
||||
"""Build a synthetic ``message.add`` event dict."""
|
||||
payload: dict[str, Any] = {
|
||||
"messageId": message_id, "author": author,
|
||||
"content": content, "meta": _safe_dict(meta),
|
||||
"groupId": group_id, "converseId": converse_id,
|
||||
}
|
||||
if author_info is not None:
|
||||
payload["authorInfo"] = _safe_dict(author_info)
|
||||
return {
|
||||
"type": "message.add",
|
||||
"timestamp": timestamp or datetime.utcnow().isoformat(),
|
||||
"payload": payload,
|
||||
}
|
||||
|
||||
|
||||
def normalize_mochat_content(content: Any) -> str:
|
||||
"""Normalize content payload to text."""
|
||||
if isinstance(content, str):
|
||||
return content.strip()
|
||||
if content is None:
|
||||
return ""
|
||||
try:
|
||||
return json.dumps(content, ensure_ascii=False)
|
||||
except TypeError:
|
||||
return str(content)
|
||||
|
||||
|
||||
def resolve_mochat_target(raw: str) -> MochatTarget:
|
||||
"""Resolve id and target kind from user-provided target string."""
|
||||
trimmed = (raw or "").strip()
|
||||
if not trimmed:
|
||||
return MochatTarget(id="", is_panel=False)
|
||||
|
||||
lowered = trimmed.lower()
|
||||
cleaned, forced_panel = trimmed, False
|
||||
for prefix in ("mochat:", "group:", "channel:", "panel:"):
|
||||
if lowered.startswith(prefix):
|
||||
cleaned = trimmed[len(prefix):].strip()
|
||||
forced_panel = prefix in {"group:", "channel:", "panel:"}
|
||||
break
|
||||
|
||||
if not cleaned:
|
||||
return MochatTarget(id="", is_panel=False)
|
||||
return MochatTarget(id=cleaned, is_panel=forced_panel or not cleaned.startswith("session_"))
|
||||
|
||||
|
||||
def extract_mention_ids(value: Any) -> list[str]:
|
||||
"""Extract mention ids from heterogeneous mention payload."""
|
||||
if not isinstance(value, list):
|
||||
return []
|
||||
ids: list[str] = []
|
||||
for item in value:
|
||||
if isinstance(item, str):
|
||||
if item.strip():
|
||||
ids.append(item.strip())
|
||||
elif isinstance(item, dict):
|
||||
for key in ("id", "userId", "_id"):
|
||||
candidate = item.get(key)
|
||||
if isinstance(candidate, str) and candidate.strip():
|
||||
ids.append(candidate.strip())
|
||||
break
|
||||
return ids
|
||||
|
||||
|
||||
def resolve_was_mentioned(payload: dict[str, Any], agent_user_id: str) -> bool:
|
||||
"""Resolve mention state from payload metadata and text fallback."""
|
||||
meta = payload.get("meta")
|
||||
if isinstance(meta, dict):
|
||||
if meta.get("mentioned") is True or meta.get("wasMentioned") is True:
|
||||
return True
|
||||
for f in ("mentions", "mentionIds", "mentionedUserIds", "mentionedUsers"):
|
||||
if agent_user_id and agent_user_id in extract_mention_ids(meta.get(f)):
|
||||
return True
|
||||
if not agent_user_id:
|
||||
return False
|
||||
content = payload.get("content")
|
||||
if not isinstance(content, str) or not content:
|
||||
return False
|
||||
return f"<@{agent_user_id}>" in content or f"@{agent_user_id}" in content
|
||||
|
||||
|
||||
def resolve_require_mention(config: MochatConfig, session_id: str, group_id: str) -> bool:
|
||||
"""Resolve mention requirement for group/panel conversations."""
|
||||
groups = config.groups or {}
|
||||
for key in (group_id, session_id, "*"):
|
||||
if key and key in groups:
|
||||
return bool(groups[key].require_mention)
|
||||
return bool(config.mention.require_in_groups)
|
||||
|
||||
|
||||
def build_buffered_body(entries: list[MochatBufferedEntry], is_group: bool) -> str:
|
||||
"""Build text body from one or more buffered entries."""
|
||||
if not entries:
|
||||
return ""
|
||||
if len(entries) == 1:
|
||||
return entries[0].raw_body
|
||||
lines: list[str] = []
|
||||
for entry in entries:
|
||||
if not entry.raw_body:
|
||||
continue
|
||||
if is_group:
|
||||
label = entry.sender_name.strip() or entry.sender_username.strip() or entry.author
|
||||
if label:
|
||||
lines.append(f"{label}: {entry.raw_body}")
|
||||
continue
|
||||
lines.append(entry.raw_body)
|
||||
return "\n".join(lines).strip()
|
||||
|
||||
|
||||
def parse_timestamp(value: Any) -> int | None:
|
||||
"""Parse event timestamp to epoch milliseconds."""
|
||||
if not isinstance(value, str) or not value.strip():
|
||||
return None
|
||||
try:
|
||||
return int(datetime.fromisoformat(value.replace("Z", "+00:00")).timestamp() * 1000)
|
||||
except ValueError:
|
||||
return None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Channel
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class MochatChannel(BaseChannel):
|
||||
"""Mochat channel using socket.io with fallback polling workers."""
|
||||
|
||||
name = "mochat"
|
||||
|
||||
def __init__(self, config: MochatConfig, bus: MessageBus):
|
||||
super().__init__(config, bus)
|
||||
self.config: MochatConfig = config
|
||||
self._http: httpx.AsyncClient | None = None
|
||||
self._socket: Any = None
|
||||
self._ws_connected = self._ws_ready = False
|
||||
|
||||
self._state_dir = get_data_path() / "mochat"
|
||||
self._cursor_path = self._state_dir / "session_cursors.json"
|
||||
self._session_cursor: dict[str, int] = {}
|
||||
self._cursor_save_task: asyncio.Task | None = None
|
||||
|
||||
self._session_set: set[str] = set()
|
||||
self._panel_set: set[str] = set()
|
||||
self._auto_discover_sessions = self._auto_discover_panels = False
|
||||
|
||||
self._cold_sessions: set[str] = set()
|
||||
self._session_by_converse: dict[str, str] = {}
|
||||
|
||||
self._seen_set: dict[str, set[str]] = {}
|
||||
self._seen_queue: dict[str, deque[str]] = {}
|
||||
self._delay_states: dict[str, DelayState] = {}
|
||||
|
||||
self._fallback_mode = False
|
||||
self._session_fallback_tasks: dict[str, asyncio.Task] = {}
|
||||
self._panel_fallback_tasks: dict[str, asyncio.Task] = {}
|
||||
self._refresh_task: asyncio.Task | None = None
|
||||
self._target_locks: dict[str, asyncio.Lock] = {}
|
||||
|
||||
# ---- lifecycle ---------------------------------------------------------
|
||||
|
||||
async def start(self) -> None:
|
||||
"""Start Mochat channel workers and websocket connection."""
|
||||
if not self.config.claw_token:
|
||||
logger.error("Mochat claw_token not configured")
|
||||
return
|
||||
|
||||
self._running = True
|
||||
self._http = httpx.AsyncClient(timeout=30.0)
|
||||
self._state_dir.mkdir(parents=True, exist_ok=True)
|
||||
await self._load_session_cursors()
|
||||
self._seed_targets_from_config()
|
||||
await self._refresh_targets(subscribe_new=False)
|
||||
|
||||
if not await self._start_socket_client():
|
||||
await self._ensure_fallback_workers()
|
||||
|
||||
self._refresh_task = asyncio.create_task(self._refresh_loop())
|
||||
while self._running:
|
||||
await asyncio.sleep(1)
|
||||
|
||||
async def stop(self) -> None:
|
||||
"""Stop all workers and clean up resources."""
|
||||
self._running = False
|
||||
if self._refresh_task:
|
||||
self._refresh_task.cancel()
|
||||
self._refresh_task = None
|
||||
|
||||
await self._stop_fallback_workers()
|
||||
await self._cancel_delay_timers()
|
||||
|
||||
if self._socket:
|
||||
try:
|
||||
await self._socket.disconnect()
|
||||
except Exception:
|
||||
pass
|
||||
self._socket = None
|
||||
|
||||
if self._cursor_save_task:
|
||||
self._cursor_save_task.cancel()
|
||||
self._cursor_save_task = None
|
||||
await self._save_session_cursors()
|
||||
|
||||
if self._http:
|
||||
await self._http.aclose()
|
||||
self._http = None
|
||||
self._ws_connected = self._ws_ready = False
|
||||
|
||||
async def send(self, msg: OutboundMessage) -> None:
|
||||
"""Send outbound message to session or panel."""
|
||||
if not self.config.claw_token:
|
||||
logger.warning("Mochat claw_token missing, skip send")
|
||||
return
|
||||
|
||||
parts = ([msg.content.strip()] if msg.content and msg.content.strip() else [])
|
||||
if msg.media:
|
||||
parts.extend(m for m in msg.media if isinstance(m, str) and m.strip())
|
||||
content = "\n".join(parts).strip()
|
||||
if not content:
|
||||
return
|
||||
|
||||
target = resolve_mochat_target(msg.chat_id)
|
||||
if not target.id:
|
||||
logger.warning("Mochat outbound target is empty")
|
||||
return
|
||||
|
||||
is_panel = (target.is_panel or target.id in self._panel_set) and not target.id.startswith("session_")
|
||||
try:
|
||||
if is_panel:
|
||||
await self._api_send("/api/claw/groups/panels/send", "panelId", target.id,
|
||||
content, msg.reply_to, self._read_group_id(msg.metadata))
|
||||
else:
|
||||
await self._api_send("/api/claw/sessions/send", "sessionId", target.id,
|
||||
content, msg.reply_to)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to send Mochat message: {e}")
|
||||
|
||||
# ---- config / init helpers ---------------------------------------------
|
||||
|
||||
def _seed_targets_from_config(self) -> None:
|
||||
sessions, self._auto_discover_sessions = self._normalize_id_list(self.config.sessions)
|
||||
panels, self._auto_discover_panels = self._normalize_id_list(self.config.panels)
|
||||
self._session_set.update(sessions)
|
||||
self._panel_set.update(panels)
|
||||
for sid in sessions:
|
||||
if sid not in self._session_cursor:
|
||||
self._cold_sessions.add(sid)
|
||||
|
||||
@staticmethod
|
||||
def _normalize_id_list(values: list[str]) -> tuple[list[str], bool]:
|
||||
cleaned = [str(v).strip() for v in values if str(v).strip()]
|
||||
return sorted({v for v in cleaned if v != "*"}), "*" in cleaned
|
||||
|
||||
# ---- websocket ---------------------------------------------------------
|
||||
|
||||
async def _start_socket_client(self) -> bool:
|
||||
if not SOCKETIO_AVAILABLE:
|
||||
logger.warning("python-socketio not installed, Mochat using polling fallback")
|
||||
return False
|
||||
|
||||
serializer = "default"
|
||||
if not self.config.socket_disable_msgpack:
|
||||
if MSGPACK_AVAILABLE:
|
||||
serializer = "msgpack"
|
||||
else:
|
||||
logger.warning("msgpack not installed but socket_disable_msgpack=false; using JSON")
|
||||
|
||||
client = socketio.AsyncClient(
|
||||
reconnection=True,
|
||||
reconnection_attempts=self.config.max_retry_attempts or None,
|
||||
reconnection_delay=max(0.1, self.config.socket_reconnect_delay_ms / 1000.0),
|
||||
reconnection_delay_max=max(0.1, self.config.socket_max_reconnect_delay_ms / 1000.0),
|
||||
logger=False, engineio_logger=False, serializer=serializer,
|
||||
)
|
||||
|
||||
@client.event
|
||||
async def connect() -> None:
|
||||
self._ws_connected, self._ws_ready = True, False
|
||||
logger.info("Mochat websocket connected")
|
||||
subscribed = await self._subscribe_all()
|
||||
self._ws_ready = subscribed
|
||||
await (self._stop_fallback_workers() if subscribed else self._ensure_fallback_workers())
|
||||
|
||||
@client.event
|
||||
async def disconnect() -> None:
|
||||
if not self._running:
|
||||
return
|
||||
self._ws_connected = self._ws_ready = False
|
||||
logger.warning("Mochat websocket disconnected")
|
||||
await self._ensure_fallback_workers()
|
||||
|
||||
@client.event
|
||||
async def connect_error(data: Any) -> None:
|
||||
logger.error(f"Mochat websocket connect error: {data}")
|
||||
|
||||
@client.on("claw.session.events")
|
||||
async def on_session_events(payload: dict[str, Any]) -> None:
|
||||
await self._handle_watch_payload(payload, "session")
|
||||
|
||||
@client.on("claw.panel.events")
|
||||
async def on_panel_events(payload: dict[str, Any]) -> None:
|
||||
await self._handle_watch_payload(payload, "panel")
|
||||
|
||||
for ev in ("notify:chat.inbox.append", "notify:chat.message.add",
|
||||
"notify:chat.message.update", "notify:chat.message.recall",
|
||||
"notify:chat.message.delete"):
|
||||
client.on(ev, self._build_notify_handler(ev))
|
||||
|
||||
socket_url = (self.config.socket_url or self.config.base_url).strip().rstrip("/")
|
||||
socket_path = (self.config.socket_path or "/socket.io").strip().lstrip("/")
|
||||
|
||||
try:
|
||||
self._socket = client
|
||||
await client.connect(
|
||||
socket_url, transports=["websocket"], socketio_path=socket_path,
|
||||
auth={"token": self.config.claw_token},
|
||||
wait_timeout=max(1.0, self.config.socket_connect_timeout_ms / 1000.0),
|
||||
)
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to connect Mochat websocket: {e}")
|
||||
try:
|
||||
await client.disconnect()
|
||||
except Exception:
|
||||
pass
|
||||
self._socket = None
|
||||
return False
|
||||
|
||||
def _build_notify_handler(self, event_name: str):
|
||||
async def handler(payload: Any) -> None:
|
||||
if event_name == "notify:chat.inbox.append":
|
||||
await self._handle_notify_inbox_append(payload)
|
||||
elif event_name.startswith("notify:chat.message."):
|
||||
await self._handle_notify_chat_message(payload)
|
||||
return handler
|
||||
|
||||
# ---- subscribe ---------------------------------------------------------
|
||||
|
||||
async def _subscribe_all(self) -> bool:
|
||||
ok = await self._subscribe_sessions(sorted(self._session_set))
|
||||
ok = await self._subscribe_panels(sorted(self._panel_set)) and ok
|
||||
if self._auto_discover_sessions or self._auto_discover_panels:
|
||||
await self._refresh_targets(subscribe_new=True)
|
||||
return ok
|
||||
|
||||
async def _subscribe_sessions(self, session_ids: list[str]) -> bool:
|
||||
if not session_ids:
|
||||
return True
|
||||
for sid in session_ids:
|
||||
if sid not in self._session_cursor:
|
||||
self._cold_sessions.add(sid)
|
||||
|
||||
ack = await self._socket_call("com.claw.im.subscribeSessions", {
|
||||
"sessionIds": session_ids, "cursors": self._session_cursor,
|
||||
"limit": self.config.watch_limit,
|
||||
})
|
||||
if not ack.get("result"):
|
||||
logger.error(f"Mochat subscribeSessions failed: {ack.get('message', 'unknown error')}")
|
||||
return False
|
||||
|
||||
data = ack.get("data")
|
||||
items: list[dict[str, Any]] = []
|
||||
if isinstance(data, list):
|
||||
items = [i for i in data if isinstance(i, dict)]
|
||||
elif isinstance(data, dict):
|
||||
sessions = data.get("sessions")
|
||||
if isinstance(sessions, list):
|
||||
items = [i for i in sessions if isinstance(i, dict)]
|
||||
elif "sessionId" in data:
|
||||
items = [data]
|
||||
for p in items:
|
||||
await self._handle_watch_payload(p, "session")
|
||||
return True
|
||||
|
||||
async def _subscribe_panels(self, panel_ids: list[str]) -> bool:
|
||||
if not self._auto_discover_panels and not panel_ids:
|
||||
return True
|
||||
ack = await self._socket_call("com.claw.im.subscribePanels", {"panelIds": panel_ids})
|
||||
if not ack.get("result"):
|
||||
logger.error(f"Mochat subscribePanels failed: {ack.get('message', 'unknown error')}")
|
||||
return False
|
||||
return True
|
||||
|
||||
async def _socket_call(self, event_name: str, payload: dict[str, Any]) -> dict[str, Any]:
|
||||
if not self._socket:
|
||||
return {"result": False, "message": "socket not connected"}
|
||||
try:
|
||||
raw = await self._socket.call(event_name, payload, timeout=10)
|
||||
except Exception as e:
|
||||
return {"result": False, "message": str(e)}
|
||||
return raw if isinstance(raw, dict) else {"result": True, "data": raw}
|
||||
|
||||
# ---- refresh / discovery -----------------------------------------------
|
||||
|
||||
async def _refresh_loop(self) -> None:
|
||||
interval_s = max(1.0, self.config.refresh_interval_ms / 1000.0)
|
||||
while self._running:
|
||||
await asyncio.sleep(interval_s)
|
||||
try:
|
||||
await self._refresh_targets(subscribe_new=self._ws_ready)
|
||||
except Exception as e:
|
||||
logger.warning(f"Mochat refresh failed: {e}")
|
||||
if self._fallback_mode:
|
||||
await self._ensure_fallback_workers()
|
||||
|
||||
async def _refresh_targets(self, subscribe_new: bool) -> None:
|
||||
if self._auto_discover_sessions:
|
||||
await self._refresh_sessions_directory(subscribe_new)
|
||||
if self._auto_discover_panels:
|
||||
await self._refresh_panels(subscribe_new)
|
||||
|
||||
async def _refresh_sessions_directory(self, subscribe_new: bool) -> None:
|
||||
try:
|
||||
response = await self._post_json("/api/claw/sessions/list", {})
|
||||
except Exception as e:
|
||||
logger.warning(f"Mochat listSessions failed: {e}")
|
||||
return
|
||||
|
||||
sessions = response.get("sessions")
|
||||
if not isinstance(sessions, list):
|
||||
return
|
||||
|
||||
new_ids: list[str] = []
|
||||
for s in sessions:
|
||||
if not isinstance(s, dict):
|
||||
continue
|
||||
sid = _str_field(s, "sessionId")
|
||||
if not sid:
|
||||
continue
|
||||
if sid not in self._session_set:
|
||||
self._session_set.add(sid)
|
||||
new_ids.append(sid)
|
||||
if sid not in self._session_cursor:
|
||||
self._cold_sessions.add(sid)
|
||||
cid = _str_field(s, "converseId")
|
||||
if cid:
|
||||
self._session_by_converse[cid] = sid
|
||||
|
||||
if not new_ids:
|
||||
return
|
||||
if self._ws_ready and subscribe_new:
|
||||
await self._subscribe_sessions(new_ids)
|
||||
if self._fallback_mode:
|
||||
await self._ensure_fallback_workers()
|
||||
|
||||
async def _refresh_panels(self, subscribe_new: bool) -> None:
|
||||
try:
|
||||
response = await self._post_json("/api/claw/groups/get", {})
|
||||
except Exception as e:
|
||||
logger.warning(f"Mochat getWorkspaceGroup failed: {e}")
|
||||
return
|
||||
|
||||
raw_panels = response.get("panels")
|
||||
if not isinstance(raw_panels, list):
|
||||
return
|
||||
|
||||
new_ids: list[str] = []
|
||||
for p in raw_panels:
|
||||
if not isinstance(p, dict):
|
||||
continue
|
||||
pt = p.get("type")
|
||||
if isinstance(pt, int) and pt != 0:
|
||||
continue
|
||||
pid = _str_field(p, "id", "_id")
|
||||
if pid and pid not in self._panel_set:
|
||||
self._panel_set.add(pid)
|
||||
new_ids.append(pid)
|
||||
|
||||
if not new_ids:
|
||||
return
|
||||
if self._ws_ready and subscribe_new:
|
||||
await self._subscribe_panels(new_ids)
|
||||
if self._fallback_mode:
|
||||
await self._ensure_fallback_workers()
|
||||
|
||||
# ---- fallback workers --------------------------------------------------
|
||||
|
||||
async def _ensure_fallback_workers(self) -> None:
|
||||
if not self._running:
|
||||
return
|
||||
self._fallback_mode = True
|
||||
for sid in sorted(self._session_set):
|
||||
t = self._session_fallback_tasks.get(sid)
|
||||
if not t or t.done():
|
||||
self._session_fallback_tasks[sid] = asyncio.create_task(self._session_watch_worker(sid))
|
||||
for pid in sorted(self._panel_set):
|
||||
t = self._panel_fallback_tasks.get(pid)
|
||||
if not t or t.done():
|
||||
self._panel_fallback_tasks[pid] = asyncio.create_task(self._panel_poll_worker(pid))
|
||||
|
||||
async def _stop_fallback_workers(self) -> None:
|
||||
self._fallback_mode = False
|
||||
tasks = [*self._session_fallback_tasks.values(), *self._panel_fallback_tasks.values()]
|
||||
for t in tasks:
|
||||
t.cancel()
|
||||
if tasks:
|
||||
await asyncio.gather(*tasks, return_exceptions=True)
|
||||
self._session_fallback_tasks.clear()
|
||||
self._panel_fallback_tasks.clear()
|
||||
|
||||
async def _session_watch_worker(self, session_id: str) -> None:
|
||||
while self._running and self._fallback_mode:
|
||||
try:
|
||||
payload = await self._post_json("/api/claw/sessions/watch", {
|
||||
"sessionId": session_id, "cursor": self._session_cursor.get(session_id, 0),
|
||||
"timeoutMs": self.config.watch_timeout_ms, "limit": self.config.watch_limit,
|
||||
})
|
||||
await self._handle_watch_payload(payload, "session")
|
||||
except asyncio.CancelledError:
|
||||
break
|
||||
except Exception as e:
|
||||
logger.warning(f"Mochat watch fallback error ({session_id}): {e}")
|
||||
await asyncio.sleep(max(0.1, self.config.retry_delay_ms / 1000.0))
|
||||
|
||||
async def _panel_poll_worker(self, panel_id: str) -> None:
|
||||
sleep_s = max(1.0, self.config.refresh_interval_ms / 1000.0)
|
||||
while self._running and self._fallback_mode:
|
||||
try:
|
||||
resp = await self._post_json("/api/claw/groups/panels/messages", {
|
||||
"panelId": panel_id, "limit": min(100, max(1, self.config.watch_limit)),
|
||||
})
|
||||
msgs = resp.get("messages")
|
||||
if isinstance(msgs, list):
|
||||
for m in reversed(msgs):
|
||||
if not isinstance(m, dict):
|
||||
continue
|
||||
evt = _make_synthetic_event(
|
||||
message_id=str(m.get("messageId") or ""),
|
||||
author=str(m.get("author") or ""),
|
||||
content=m.get("content"),
|
||||
meta=m.get("meta"), group_id=str(resp.get("groupId") or ""),
|
||||
converse_id=panel_id, timestamp=m.get("createdAt"),
|
||||
author_info=m.get("authorInfo"),
|
||||
)
|
||||
await self._process_inbound_event(panel_id, evt, "panel")
|
||||
except asyncio.CancelledError:
|
||||
break
|
||||
except Exception as e:
|
||||
logger.warning(f"Mochat panel polling error ({panel_id}): {e}")
|
||||
await asyncio.sleep(sleep_s)
|
||||
|
||||
# ---- inbound event processing ------------------------------------------
|
||||
|
||||
async def _handle_watch_payload(self, payload: dict[str, Any], target_kind: str) -> None:
|
||||
if not isinstance(payload, dict):
|
||||
return
|
||||
target_id = _str_field(payload, "sessionId")
|
||||
if not target_id:
|
||||
return
|
||||
|
||||
lock = self._target_locks.setdefault(f"{target_kind}:{target_id}", asyncio.Lock())
|
||||
async with lock:
|
||||
prev = self._session_cursor.get(target_id, 0) if target_kind == "session" else 0
|
||||
pc = payload.get("cursor")
|
||||
if target_kind == "session" and isinstance(pc, int) and pc >= 0:
|
||||
self._mark_session_cursor(target_id, pc)
|
||||
|
||||
raw_events = payload.get("events")
|
||||
if not isinstance(raw_events, list):
|
||||
return
|
||||
if target_kind == "session" and target_id in self._cold_sessions:
|
||||
self._cold_sessions.discard(target_id)
|
||||
return
|
||||
|
||||
for event in raw_events:
|
||||
if not isinstance(event, dict):
|
||||
continue
|
||||
seq = event.get("seq")
|
||||
if target_kind == "session" and isinstance(seq, int) and seq > self._session_cursor.get(target_id, prev):
|
||||
self._mark_session_cursor(target_id, seq)
|
||||
if event.get("type") == "message.add":
|
||||
await self._process_inbound_event(target_id, event, target_kind)
|
||||
|
||||
async def _process_inbound_event(self, target_id: str, event: dict[str, Any], target_kind: str) -> None:
|
||||
payload = event.get("payload")
|
||||
if not isinstance(payload, dict):
|
||||
return
|
||||
|
||||
author = _str_field(payload, "author")
|
||||
if not author or (self.config.agent_user_id and author == self.config.agent_user_id):
|
||||
return
|
||||
if not self.is_allowed(author):
|
||||
return
|
||||
|
||||
message_id = _str_field(payload, "messageId")
|
||||
seen_key = f"{target_kind}:{target_id}"
|
||||
if message_id and self._remember_message_id(seen_key, message_id):
|
||||
return
|
||||
|
||||
raw_body = normalize_mochat_content(payload.get("content")) or "[empty message]"
|
||||
ai = _safe_dict(payload.get("authorInfo"))
|
||||
sender_name = _str_field(ai, "nickname", "email")
|
||||
sender_username = _str_field(ai, "agentId")
|
||||
|
||||
group_id = _str_field(payload, "groupId")
|
||||
is_group = bool(group_id)
|
||||
was_mentioned = resolve_was_mentioned(payload, self.config.agent_user_id)
|
||||
require_mention = target_kind == "panel" and is_group and resolve_require_mention(self.config, target_id, group_id)
|
||||
use_delay = target_kind == "panel" and self.config.reply_delay_mode == "non-mention"
|
||||
|
||||
if require_mention and not was_mentioned and not use_delay:
|
||||
return
|
||||
|
||||
entry = MochatBufferedEntry(
|
||||
raw_body=raw_body, author=author, sender_name=sender_name,
|
||||
sender_username=sender_username, timestamp=parse_timestamp(event.get("timestamp")),
|
||||
message_id=message_id, group_id=group_id,
|
||||
)
|
||||
|
||||
if use_delay:
|
||||
delay_key = seen_key
|
||||
if was_mentioned:
|
||||
await self._flush_delayed_entries(delay_key, target_id, target_kind, "mention", entry)
|
||||
else:
|
||||
await self._enqueue_delayed_entry(delay_key, target_id, target_kind, entry)
|
||||
return
|
||||
|
||||
await self._dispatch_entries(target_id, target_kind, [entry], was_mentioned)
|
||||
|
||||
# ---- dedup / buffering -------------------------------------------------
|
||||
|
||||
def _remember_message_id(self, key: str, message_id: str) -> bool:
|
||||
seen_set = self._seen_set.setdefault(key, set())
|
||||
seen_queue = self._seen_queue.setdefault(key, deque())
|
||||
if message_id in seen_set:
|
||||
return True
|
||||
seen_set.add(message_id)
|
||||
seen_queue.append(message_id)
|
||||
while len(seen_queue) > MAX_SEEN_MESSAGE_IDS:
|
||||
seen_set.discard(seen_queue.popleft())
|
||||
return False
|
||||
|
||||
async def _enqueue_delayed_entry(self, key: str, target_id: str, target_kind: str, entry: MochatBufferedEntry) -> None:
|
||||
state = self._delay_states.setdefault(key, DelayState())
|
||||
async with state.lock:
|
||||
state.entries.append(entry)
|
||||
if state.timer:
|
||||
state.timer.cancel()
|
||||
state.timer = asyncio.create_task(self._delay_flush_after(key, target_id, target_kind))
|
||||
|
||||
async def _delay_flush_after(self, key: str, target_id: str, target_kind: str) -> None:
|
||||
await asyncio.sleep(max(0, self.config.reply_delay_ms) / 1000.0)
|
||||
await self._flush_delayed_entries(key, target_id, target_kind, "timer", None)
|
||||
|
||||
async def _flush_delayed_entries(self, key: str, target_id: str, target_kind: str, reason: str, entry: MochatBufferedEntry | None) -> None:
|
||||
state = self._delay_states.setdefault(key, DelayState())
|
||||
async with state.lock:
|
||||
if entry:
|
||||
state.entries.append(entry)
|
||||
current = asyncio.current_task()
|
||||
if state.timer and state.timer is not current:
|
||||
state.timer.cancel()
|
||||
state.timer = None
|
||||
entries = state.entries[:]
|
||||
state.entries.clear()
|
||||
if entries:
|
||||
await self._dispatch_entries(target_id, target_kind, entries, reason == "mention")
|
||||
|
||||
async def _dispatch_entries(self, target_id: str, target_kind: str, entries: list[MochatBufferedEntry], was_mentioned: bool) -> None:
|
||||
if not entries:
|
||||
return
|
||||
last = entries[-1]
|
||||
is_group = bool(last.group_id)
|
||||
body = build_buffered_body(entries, is_group) or "[empty message]"
|
||||
await self._handle_message(
|
||||
sender_id=last.author, chat_id=target_id, content=body,
|
||||
metadata={
|
||||
"message_id": last.message_id, "timestamp": last.timestamp,
|
||||
"is_group": is_group, "group_id": last.group_id,
|
||||
"sender_name": last.sender_name, "sender_username": last.sender_username,
|
||||
"target_kind": target_kind, "was_mentioned": was_mentioned,
|
||||
"buffered_count": len(entries),
|
||||
},
|
||||
)
|
||||
|
||||
async def _cancel_delay_timers(self) -> None:
|
||||
for state in self._delay_states.values():
|
||||
if state.timer:
|
||||
state.timer.cancel()
|
||||
self._delay_states.clear()
|
||||
|
||||
# ---- notify handlers ---------------------------------------------------
|
||||
|
||||
async def _handle_notify_chat_message(self, payload: Any) -> None:
|
||||
if not isinstance(payload, dict):
|
||||
return
|
||||
group_id = _str_field(payload, "groupId")
|
||||
panel_id = _str_field(payload, "converseId", "panelId")
|
||||
if not group_id or not panel_id:
|
||||
return
|
||||
if self._panel_set and panel_id not in self._panel_set:
|
||||
return
|
||||
|
||||
evt = _make_synthetic_event(
|
||||
message_id=str(payload.get("_id") or payload.get("messageId") or ""),
|
||||
author=str(payload.get("author") or ""),
|
||||
content=payload.get("content"), meta=payload.get("meta"),
|
||||
group_id=group_id, converse_id=panel_id,
|
||||
timestamp=payload.get("createdAt"), author_info=payload.get("authorInfo"),
|
||||
)
|
||||
await self._process_inbound_event(panel_id, evt, "panel")
|
||||
|
||||
async def _handle_notify_inbox_append(self, payload: Any) -> None:
|
||||
if not isinstance(payload, dict) or payload.get("type") != "message":
|
||||
return
|
||||
detail = payload.get("payload")
|
||||
if not isinstance(detail, dict):
|
||||
return
|
||||
if _str_field(detail, "groupId"):
|
||||
return
|
||||
converse_id = _str_field(detail, "converseId")
|
||||
if not converse_id:
|
||||
return
|
||||
|
||||
session_id = self._session_by_converse.get(converse_id)
|
||||
if not session_id:
|
||||
await self._refresh_sessions_directory(self._ws_ready)
|
||||
session_id = self._session_by_converse.get(converse_id)
|
||||
if not session_id:
|
||||
return
|
||||
|
||||
evt = _make_synthetic_event(
|
||||
message_id=str(detail.get("messageId") or payload.get("_id") or ""),
|
||||
author=str(detail.get("messageAuthor") or ""),
|
||||
content=str(detail.get("messagePlainContent") or detail.get("messageSnippet") or ""),
|
||||
meta={"source": "notify:chat.inbox.append", "converseId": converse_id},
|
||||
group_id="", converse_id=converse_id, timestamp=payload.get("createdAt"),
|
||||
)
|
||||
await self._process_inbound_event(session_id, evt, "session")
|
||||
|
||||
# ---- cursor persistence ------------------------------------------------
|
||||
|
||||
def _mark_session_cursor(self, session_id: str, cursor: int) -> None:
|
||||
if cursor < 0 or cursor < self._session_cursor.get(session_id, 0):
|
||||
return
|
||||
self._session_cursor[session_id] = cursor
|
||||
if not self._cursor_save_task or self._cursor_save_task.done():
|
||||
self._cursor_save_task = asyncio.create_task(self._save_cursor_debounced())
|
||||
|
||||
async def _save_cursor_debounced(self) -> None:
|
||||
await asyncio.sleep(CURSOR_SAVE_DEBOUNCE_S)
|
||||
await self._save_session_cursors()
|
||||
|
||||
async def _load_session_cursors(self) -> None:
|
||||
if not self._cursor_path.exists():
|
||||
return
|
||||
try:
|
||||
data = json.loads(self._cursor_path.read_text("utf-8"))
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to read Mochat cursor file: {e}")
|
||||
return
|
||||
cursors = data.get("cursors") if isinstance(data, dict) else None
|
||||
if isinstance(cursors, dict):
|
||||
for sid, cur in cursors.items():
|
||||
if isinstance(sid, str) and isinstance(cur, int) and cur >= 0:
|
||||
self._session_cursor[sid] = cur
|
||||
|
||||
async def _save_session_cursors(self) -> None:
|
||||
try:
|
||||
self._state_dir.mkdir(parents=True, exist_ok=True)
|
||||
self._cursor_path.write_text(json.dumps({
|
||||
"schemaVersion": 1, "updatedAt": datetime.utcnow().isoformat(),
|
||||
"cursors": self._session_cursor,
|
||||
}, ensure_ascii=False, indent=2) + "\n", "utf-8")
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to save Mochat cursor file: {e}")
|
||||
|
||||
# ---- HTTP helpers ------------------------------------------------------
|
||||
|
||||
async def _post_json(self, path: str, payload: dict[str, Any]) -> dict[str, Any]:
|
||||
if not self._http:
|
||||
raise RuntimeError("Mochat HTTP client not initialized")
|
||||
url = f"{self.config.base_url.strip().rstrip('/')}{path}"
|
||||
response = await self._http.post(url, headers={
|
||||
"Content-Type": "application/json", "X-Claw-Token": self.config.claw_token,
|
||||
}, json=payload)
|
||||
if not response.is_success:
|
||||
raise RuntimeError(f"Mochat HTTP {response.status_code}: {response.text[:200]}")
|
||||
try:
|
||||
parsed = response.json()
|
||||
except Exception:
|
||||
parsed = response.text
|
||||
if isinstance(parsed, dict) and isinstance(parsed.get("code"), int):
|
||||
if parsed["code"] != 200:
|
||||
msg = str(parsed.get("message") or parsed.get("name") or "request failed")
|
||||
raise RuntimeError(f"Mochat API error: {msg} (code={parsed['code']})")
|
||||
data = parsed.get("data")
|
||||
return data if isinstance(data, dict) else {}
|
||||
return parsed if isinstance(parsed, dict) else {}
|
||||
|
||||
async def _api_send(self, path: str, id_key: str, id_val: str,
|
||||
content: str, reply_to: str | None, group_id: str | None = None) -> dict[str, Any]:
|
||||
"""Unified send helper for session and panel messages."""
|
||||
body: dict[str, Any] = {id_key: id_val, "content": content}
|
||||
if reply_to:
|
||||
body["replyTo"] = reply_to
|
||||
if group_id:
|
||||
body["groupId"] = group_id
|
||||
return await self._post_json(path, body)
|
||||
|
||||
@staticmethod
|
||||
def _read_group_id(metadata: dict[str, Any]) -> str | None:
|
||||
if not isinstance(metadata, dict):
|
||||
return None
|
||||
value = metadata.get("group_id") or metadata.get("groupId")
|
||||
return value.strip() if isinstance(value, str) and value.strip() else None
|
||||
@ -75,12 +75,15 @@ class QQChannel(BaseChannel):
|
||||
logger.info("QQ bot started (C2C private message)")
|
||||
|
||||
async def _run_bot(self) -> None:
|
||||
"""Run the bot connection."""
|
||||
try:
|
||||
await self._client.start(appid=self.config.app_id, secret=self.config.secret)
|
||||
except Exception as e:
|
||||
logger.error(f"QQ auth failed, check AppID/Secret at q.qq.com: {e}")
|
||||
self._running = False
|
||||
"""Run the bot connection with auto-reconnect."""
|
||||
while self._running:
|
||||
try:
|
||||
await self._client.start(appid=self.config.app_id, secret=self.config.secret)
|
||||
except Exception as e:
|
||||
logger.warning(f"QQ bot error: {e}")
|
||||
if self._running:
|
||||
logger.info("Reconnecting QQ bot in 5 seconds...")
|
||||
await asyncio.sleep(5)
|
||||
|
||||
async def stop(self) -> None:
|
||||
"""Stop the QQ bot."""
|
||||
|
||||
@ -4,20 +4,16 @@ from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import re
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from loguru import logger
|
||||
from telegram import BotCommand, Update
|
||||
from telegram.ext import Application, CommandHandler, MessageHandler, filters, ContextTypes
|
||||
from telegram.request import HTTPXRequest
|
||||
|
||||
from nanobot.bus.events import OutboundMessage
|
||||
from nanobot.bus.queue import MessageBus
|
||||
from nanobot.channels.base import BaseChannel
|
||||
from nanobot.config.schema import TelegramConfig
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from nanobot.session.manager import SessionManager
|
||||
|
||||
|
||||
def _markdown_to_telegram_html(text: str) -> str:
|
||||
"""
|
||||
@ -94,7 +90,7 @@ class TelegramChannel(BaseChannel):
|
||||
# Commands registered with Telegram's command menu
|
||||
BOT_COMMANDS = [
|
||||
BotCommand("start", "Start the bot"),
|
||||
BotCommand("reset", "Reset conversation history"),
|
||||
BotCommand("new", "Start a new conversation"),
|
||||
BotCommand("help", "Show available commands"),
|
||||
]
|
||||
|
||||
@ -103,12 +99,10 @@ class TelegramChannel(BaseChannel):
|
||||
config: TelegramConfig,
|
||||
bus: MessageBus,
|
||||
groq_api_key: str = "",
|
||||
session_manager: SessionManager | None = None,
|
||||
):
|
||||
super().__init__(config, bus)
|
||||
self.config: TelegramConfig = config
|
||||
self.groq_api_key = groq_api_key
|
||||
self.session_manager = session_manager
|
||||
self._app: Application | None = None
|
||||
self._chat_ids: dict[str, int] = {} # Map sender_id to chat_id for replies
|
||||
self._typing_tasks: dict[str, asyncio.Task] = {} # chat_id -> typing loop task
|
||||
@ -121,16 +115,18 @@ class TelegramChannel(BaseChannel):
|
||||
|
||||
self._running = True
|
||||
|
||||
# Build the application
|
||||
builder = Application.builder().token(self.config.token)
|
||||
# Build the application with larger connection pool to avoid pool-timeout on long runs
|
||||
req = HTTPXRequest(connection_pool_size=16, pool_timeout=5.0, connect_timeout=30.0, read_timeout=30.0)
|
||||
builder = Application.builder().token(self.config.token).request(req).get_updates_request(req)
|
||||
if self.config.proxy:
|
||||
builder = builder.proxy(self.config.proxy).get_updates_proxy(self.config.proxy)
|
||||
self._app = builder.build()
|
||||
self._app.add_error_handler(self._on_error)
|
||||
|
||||
# Add command handlers
|
||||
self._app.add_handler(CommandHandler("start", self._on_start))
|
||||
self._app.add_handler(CommandHandler("reset", self._on_reset))
|
||||
self._app.add_handler(CommandHandler("help", self._on_help))
|
||||
self._app.add_handler(CommandHandler("new", self._forward_command))
|
||||
self._app.add_handler(CommandHandler("help", self._forward_command))
|
||||
|
||||
# Add message handler for text, photos, voice, documents
|
||||
self._app.add_handler(
|
||||
@ -226,40 +222,15 @@ class TelegramChannel(BaseChannel):
|
||||
"Type /help to see available commands."
|
||||
)
|
||||
|
||||
async def _on_reset(self, update: Update, context: ContextTypes.DEFAULT_TYPE) -> None:
|
||||
"""Handle /reset command — clear conversation history."""
|
||||
async def _forward_command(self, update: Update, context: ContextTypes.DEFAULT_TYPE) -> None:
|
||||
"""Forward slash commands to the bus for unified handling in AgentLoop."""
|
||||
if not update.message or not update.effective_user:
|
||||
return
|
||||
|
||||
chat_id = str(update.message.chat_id)
|
||||
session_key = f"{self.name}:{chat_id}"
|
||||
|
||||
if self.session_manager is None:
|
||||
logger.warning("/reset called but session_manager is not available")
|
||||
await update.message.reply_text("⚠️ Session management is not available.")
|
||||
return
|
||||
|
||||
session = self.session_manager.get_or_create(session_key)
|
||||
msg_count = len(session.messages)
|
||||
session.clear()
|
||||
self.session_manager.save(session)
|
||||
|
||||
logger.info(f"Session reset for {session_key} (cleared {msg_count} messages)")
|
||||
await update.message.reply_text("🔄 Conversation history cleared. Let's start fresh!")
|
||||
|
||||
async def _on_help(self, update: Update, context: ContextTypes.DEFAULT_TYPE) -> None:
|
||||
"""Handle /help command — show available commands."""
|
||||
if not update.message:
|
||||
return
|
||||
|
||||
help_text = (
|
||||
"🐈 <b>nanobot commands</b>\n\n"
|
||||
"/start — Start the bot\n"
|
||||
"/reset — Reset conversation history\n"
|
||||
"/help — Show this help message\n\n"
|
||||
"Just send me a text message to chat!"
|
||||
await self._handle_message(
|
||||
sender_id=str(update.effective_user.id),
|
||||
chat_id=str(update.message.chat_id),
|
||||
content=update.message.text,
|
||||
)
|
||||
await update.message.reply_text(help_text, parse_mode="HTML")
|
||||
|
||||
async def _on_message(self, update: Update, context: ContextTypes.DEFAULT_TYPE) -> None:
|
||||
"""Handle incoming messages (text, photos, voice, documents)."""
|
||||
@ -386,6 +357,10 @@ class TelegramChannel(BaseChannel):
|
||||
except Exception as e:
|
||||
logger.debug(f"Typing indicator stopped for {chat_id}: {e}")
|
||||
|
||||
async def _on_error(self, update: object, context: ContextTypes.DEFAULT_TYPE) -> None:
|
||||
"""Log polling / handler errors instead of silently swallowing them."""
|
||||
logger.error(f"Telegram error: {context.error}")
|
||||
|
||||
def _get_extension(self, media_type: str, mime_type: str | None) -> str:
|
||||
"""Get file extension based on media type."""
|
||||
if mime_type:
|
||||
|
||||
@ -42,6 +42,9 @@ class WhatsAppChannel(BaseChannel):
|
||||
try:
|
||||
async with websockets.connect(bridge_url) as ws:
|
||||
self._ws = ws
|
||||
# Send auth token if configured
|
||||
if self.config.bridge_token:
|
||||
await ws.send(json.dumps({"type": "auth", "token": self.config.bridge_token}))
|
||||
self._connected = True
|
||||
logger.info("Connected to WhatsApp bridge")
|
||||
|
||||
|
||||
@ -1,7 +1,6 @@
|
||||
"""CLI commands for nanobot."""
|
||||
|
||||
import asyncio
|
||||
import atexit
|
||||
import os
|
||||
import signal
|
||||
from pathlib import Path
|
||||
@ -11,10 +10,14 @@ import sys
|
||||
import typer
|
||||
from rich.console import Console
|
||||
from rich.markdown import Markdown
|
||||
from rich.panel import Panel
|
||||
from rich.table import Table
|
||||
from rich.text import Text
|
||||
|
||||
from prompt_toolkit import PromptSession
|
||||
from prompt_toolkit.formatted_text import HTML
|
||||
from prompt_toolkit.history import FileHistory
|
||||
from prompt_toolkit.patch_stdout import patch_stdout
|
||||
|
||||
from nanobot import __version__, __logo__
|
||||
from nanobot.config.schema import Config
|
||||
|
||||
@ -28,13 +31,10 @@ console = Console()
|
||||
EXIT_COMMANDS = {"exit", "quit", "/exit", "/quit", ":q"}
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Lightweight CLI input: readline for arrow keys / history, termios for flush
|
||||
# CLI input: prompt_toolkit for editing, paste, history, and display
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_READLINE = None
|
||||
_HISTORY_FILE: Path | None = None
|
||||
_HISTORY_HOOK_REGISTERED = False
|
||||
_USING_LIBEDIT = False
|
||||
_PROMPT_SESSION: PromptSession | None = None
|
||||
_SAVED_TERM_ATTRS = None # original termios settings, restored on exit
|
||||
|
||||
|
||||
@ -65,15 +65,6 @@ def _flush_pending_tty_input() -> None:
|
||||
return
|
||||
|
||||
|
||||
def _save_history() -> None:
|
||||
if _READLINE is None or _HISTORY_FILE is None:
|
||||
return
|
||||
try:
|
||||
_READLINE.write_history_file(str(_HISTORY_FILE))
|
||||
except Exception:
|
||||
return
|
||||
|
||||
|
||||
def _restore_terminal() -> None:
|
||||
"""Restore terminal to its original state (echo, line buffering, etc.)."""
|
||||
if _SAVED_TERM_ATTRS is None:
|
||||
@ -85,11 +76,11 @@ def _restore_terminal() -> None:
|
||||
pass
|
||||
|
||||
|
||||
def _enable_line_editing() -> None:
|
||||
"""Enable readline for arrow keys, line editing, and persistent history."""
|
||||
global _READLINE, _HISTORY_FILE, _HISTORY_HOOK_REGISTERED, _USING_LIBEDIT, _SAVED_TERM_ATTRS
|
||||
def _init_prompt_session() -> None:
|
||||
"""Create the prompt_toolkit session with persistent file history."""
|
||||
global _PROMPT_SESSION, _SAVED_TERM_ATTRS
|
||||
|
||||
# Save terminal state before readline touches it
|
||||
# Save terminal state so we can restore it on exit
|
||||
try:
|
||||
import termios
|
||||
_SAVED_TERM_ATTRS = termios.tcgetattr(sys.stdin.fileno())
|
||||
@ -98,43 +89,12 @@ def _enable_line_editing() -> None:
|
||||
|
||||
history_file = Path.home() / ".nanobot" / "history" / "cli_history"
|
||||
history_file.parent.mkdir(parents=True, exist_ok=True)
|
||||
_HISTORY_FILE = history_file
|
||||
|
||||
try:
|
||||
import readline
|
||||
except ImportError:
|
||||
return
|
||||
|
||||
_READLINE = readline
|
||||
_USING_LIBEDIT = "libedit" in (readline.__doc__ or "").lower()
|
||||
|
||||
try:
|
||||
if _USING_LIBEDIT:
|
||||
readline.parse_and_bind("bind ^I rl_complete")
|
||||
else:
|
||||
readline.parse_and_bind("tab: complete")
|
||||
readline.parse_and_bind("set editing-mode emacs")
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
try:
|
||||
readline.read_history_file(str(history_file))
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
if not _HISTORY_HOOK_REGISTERED:
|
||||
atexit.register(_save_history)
|
||||
_HISTORY_HOOK_REGISTERED = True
|
||||
|
||||
|
||||
def _prompt_text() -> str:
|
||||
"""Build a readline-friendly colored prompt."""
|
||||
if _READLINE is None:
|
||||
return "You: "
|
||||
# libedit on macOS does not honor GNU readline non-printing markers.
|
||||
if _USING_LIBEDIT:
|
||||
return "\033[1;34mYou:\033[0m "
|
||||
return "\001\033[1;34m\002You:\001\033[0m\002 "
|
||||
_PROMPT_SESSION = PromptSession(
|
||||
history=FileHistory(str(history_file)),
|
||||
enable_open_in_editor=False,
|
||||
multiline=False, # Enter submits (single line mode)
|
||||
)
|
||||
|
||||
|
||||
def _print_agent_response(response: str, render_markdown: bool) -> None:
|
||||
@ -142,15 +102,8 @@ def _print_agent_response(response: str, render_markdown: bool) -> None:
|
||||
content = response or ""
|
||||
body = Markdown(content) if render_markdown else Text(content)
|
||||
console.print()
|
||||
console.print(
|
||||
Panel(
|
||||
body,
|
||||
title=f"{__logo__} nanobot",
|
||||
title_align="left",
|
||||
border_style="cyan",
|
||||
padding=(0, 1),
|
||||
)
|
||||
)
|
||||
console.print(f"[cyan]{__logo__} nanobot[/cyan]")
|
||||
console.print(body)
|
||||
console.print()
|
||||
|
||||
|
||||
@ -160,13 +113,25 @@ def _is_exit_command(command: str) -> bool:
|
||||
|
||||
|
||||
async def _read_interactive_input_async() -> str:
|
||||
"""Read user input with arrow keys and history (runs input() in a thread)."""
|
||||
"""Read user input using prompt_toolkit (handles paste, history, display).
|
||||
|
||||
prompt_toolkit natively handles:
|
||||
- Multiline paste (bracketed paste mode)
|
||||
- History navigation (up/down arrows)
|
||||
- Clean display (no ghost characters or artifacts)
|
||||
"""
|
||||
if _PROMPT_SESSION is None:
|
||||
raise RuntimeError("Call _init_prompt_session() first")
|
||||
try:
|
||||
return await asyncio.to_thread(input, _prompt_text())
|
||||
with patch_stdout():
|
||||
return await _PROMPT_SESSION.prompt_async(
|
||||
HTML("<b fg='ansiblue'>You:</b> "),
|
||||
)
|
||||
except EOFError as exc:
|
||||
raise KeyboardInterrupt from exc
|
||||
|
||||
|
||||
|
||||
def version_callback(value: bool):
|
||||
if value:
|
||||
console.print(f"{__logo__} nanobot v{__version__}")
|
||||
@ -191,7 +156,7 @@ def main(
|
||||
@app.command()
|
||||
def onboard():
|
||||
"""Initialize nanobot configuration and workspace."""
|
||||
from nanobot.config.loader import get_config_path, save_config
|
||||
from nanobot.config.loader import get_config_path, load_config, save_config
|
||||
from nanobot.config.schema import Config
|
||||
from nanobot.utils.helpers import get_workspace_path
|
||||
|
||||
@ -199,17 +164,26 @@ def onboard():
|
||||
|
||||
if config_path.exists():
|
||||
console.print(f"[yellow]Config already exists at {config_path}[/yellow]")
|
||||
if not typer.confirm("Overwrite?"):
|
||||
raise typer.Exit()
|
||||
|
||||
# Create default config
|
||||
config = Config()
|
||||
save_config(config)
|
||||
console.print(f"[green]✓[/green] Created config at {config_path}")
|
||||
console.print(" [bold]y[/bold] = overwrite with defaults (existing values will be lost)")
|
||||
console.print(" [bold]N[/bold] = refresh config, keeping existing values and adding new fields")
|
||||
if typer.confirm("Overwrite?"):
|
||||
config = Config()
|
||||
save_config(config)
|
||||
console.print(f"[green]✓[/green] Config reset to defaults at {config_path}")
|
||||
else:
|
||||
config = load_config()
|
||||
save_config(config)
|
||||
console.print(f"[green]✓[/green] Config refreshed at {config_path} (existing values preserved)")
|
||||
else:
|
||||
save_config(Config())
|
||||
console.print(f"[green]✓[/green] Created config at {config_path}")
|
||||
|
||||
# Create workspace
|
||||
workspace = get_workspace_path()
|
||||
console.print(f"[green]✓[/green] Created workspace at {workspace}")
|
||||
|
||||
if not workspace.exists():
|
||||
workspace.mkdir(parents=True, exist_ok=True)
|
||||
console.print(f"[green]✓[/green] Created workspace at {workspace}")
|
||||
|
||||
# Create default bootstrap files
|
||||
_create_workspace_templates(workspace)
|
||||
@ -236,7 +210,7 @@ You are a helpful AI assistant. Be concise, accurate, and friendly.
|
||||
- Always explain what you're doing before taking actions
|
||||
- Ask for clarification when the request is ambiguous
|
||||
- Use tools to help accomplish tasks
|
||||
- Remember important information in your memory files
|
||||
- Remember important information in memory/MEMORY.md; past events are logged in memory/HISTORY.md
|
||||
""",
|
||||
"SOUL.md": """# Soul
|
||||
|
||||
@ -294,6 +268,15 @@ This file stores important information that should persist across sessions.
|
||||
(Things to remember)
|
||||
""")
|
||||
console.print(" [dim]Created memory/MEMORY.md[/dim]")
|
||||
|
||||
history_file = memory_dir / "HISTORY.md"
|
||||
if not history_file.exists():
|
||||
history_file.write_text("")
|
||||
console.print(" [dim]Created memory/HISTORY.md[/dim]")
|
||||
|
||||
# Create skills directory for custom user skills
|
||||
skills_dir = workspace / "skills"
|
||||
skills_dir.mkdir(exist_ok=True)
|
||||
|
||||
|
||||
def _make_provider(config: Config):
|
||||
@ -367,12 +350,16 @@ def gateway(
|
||||
provider=provider,
|
||||
workspace=config.workspace_path,
|
||||
model=config.agents.defaults.model,
|
||||
temperature=config.agents.defaults.temperature,
|
||||
max_tokens=config.agents.defaults.max_tokens,
|
||||
max_iterations=config.agents.defaults.max_tool_iterations,
|
||||
memory_window=config.agents.defaults.memory_window,
|
||||
brave_api_key=config.tools.web.search.api_key or None,
|
||||
exec_config=config.tools.exec,
|
||||
cron_service=cron,
|
||||
restrict_to_workspace=config.tools.restrict_to_workspace,
|
||||
session_manager=session_manager,
|
||||
mcp_servers=config.tools.mcp_servers,
|
||||
)
|
||||
|
||||
# Set cron callback (needs agent)
|
||||
@ -407,7 +394,7 @@ def gateway(
|
||||
)
|
||||
|
||||
# Create channel manager
|
||||
channels = ChannelManager(config, bus, session_manager=session_manager)
|
||||
channels = ChannelManager(config, bus)
|
||||
|
||||
if channels.enabled_channels:
|
||||
console.print(f"[green]✓[/green] Channels enabled: {', '.join(channels.enabled_channels)}")
|
||||
@ -430,6 +417,8 @@ def gateway(
|
||||
)
|
||||
except KeyboardInterrupt:
|
||||
console.print("\nShutting down...")
|
||||
finally:
|
||||
await agent.close_mcp()
|
||||
heartbeat.stop()
|
||||
cron.stop()
|
||||
agent.stop()
|
||||
@ -448,7 +437,7 @@ def gateway(
|
||||
@app.command()
|
||||
def agent(
|
||||
message: str = typer.Option(None, "--message", "-m", help="Message to send to the agent"),
|
||||
session_id: str = typer.Option("cli:default", "--session", "-s", help="Session ID"),
|
||||
session_id: str = typer.Option("cli:direct", "--session", "-s", help="Session ID"),
|
||||
markdown: bool = typer.Option(True, "--markdown/--no-markdown", help="Render assistant output as Markdown"),
|
||||
logs: bool = typer.Option(False, "--logs/--no-logs", help="Show nanobot runtime logs during chat"),
|
||||
):
|
||||
@ -472,9 +461,15 @@ def agent(
|
||||
bus=bus,
|
||||
provider=provider,
|
||||
workspace=config.workspace_path,
|
||||
model=config.agents.defaults.model,
|
||||
temperature=config.agents.defaults.temperature,
|
||||
max_tokens=config.agents.defaults.max_tokens,
|
||||
max_iterations=config.agents.defaults.max_tool_iterations,
|
||||
memory_window=config.agents.defaults.memory_window,
|
||||
brave_api_key=config.tools.web.search.api_key or None,
|
||||
exec_config=config.tools.exec,
|
||||
restrict_to_workspace=config.tools.restrict_to_workspace,
|
||||
mcp_servers=config.tools.mcp_servers,
|
||||
)
|
||||
|
||||
# Show spinner when logs are off (no output to miss); skip when logs are on
|
||||
@ -482,6 +477,7 @@ def agent(
|
||||
if logs:
|
||||
from contextlib import nullcontext
|
||||
return nullcontext()
|
||||
# Animated spinner is safe to use with prompt_toolkit input handling
|
||||
return console.status("[dim]nanobot is thinking...[/dim]", spinner="dots")
|
||||
|
||||
if message:
|
||||
@ -490,17 +486,15 @@ def agent(
|
||||
with _thinking_ctx():
|
||||
response = await agent_loop.process_direct(message, session_id)
|
||||
_print_agent_response(response, render_markdown=markdown)
|
||||
await agent_loop.close_mcp()
|
||||
|
||||
asyncio.run(run_once())
|
||||
else:
|
||||
# Interactive mode
|
||||
_enable_line_editing()
|
||||
_init_prompt_session()
|
||||
console.print(f"{__logo__} Interactive mode (type [bold]exit[/bold] or [bold]Ctrl+C[/bold] to quit)\n")
|
||||
|
||||
# input() runs in a worker thread that can't be cancelled.
|
||||
# Without this handler, asyncio.run() would hang waiting for it.
|
||||
def _exit_on_sigint(signum, frame):
|
||||
_save_history()
|
||||
_restore_terminal()
|
||||
console.print("\nGoodbye!")
|
||||
os._exit(0)
|
||||
@ -508,33 +502,33 @@ def agent(
|
||||
signal.signal(signal.SIGINT, _exit_on_sigint)
|
||||
|
||||
async def run_interactive():
|
||||
while True:
|
||||
try:
|
||||
_flush_pending_tty_input()
|
||||
user_input = await _read_interactive_input_async()
|
||||
command = user_input.strip()
|
||||
if not command:
|
||||
continue
|
||||
try:
|
||||
while True:
|
||||
try:
|
||||
_flush_pending_tty_input()
|
||||
user_input = await _read_interactive_input_async()
|
||||
command = user_input.strip()
|
||||
if not command:
|
||||
continue
|
||||
|
||||
if _is_exit_command(command):
|
||||
_save_history()
|
||||
if _is_exit_command(command):
|
||||
_restore_terminal()
|
||||
console.print("\nGoodbye!")
|
||||
break
|
||||
|
||||
with _thinking_ctx():
|
||||
response = await agent_loop.process_direct(user_input, session_id)
|
||||
_print_agent_response(response, render_markdown=markdown)
|
||||
except KeyboardInterrupt:
|
||||
_restore_terminal()
|
||||
console.print("\nGoodbye!")
|
||||
break
|
||||
|
||||
with _thinking_ctx():
|
||||
response = await agent_loop.process_direct(user_input, session_id)
|
||||
_print_agent_response(response, render_markdown=markdown)
|
||||
except KeyboardInterrupt:
|
||||
_save_history()
|
||||
_restore_terminal()
|
||||
console.print("\nGoodbye!")
|
||||
break
|
||||
except EOFError:
|
||||
_save_history()
|
||||
_restore_terminal()
|
||||
console.print("\nGoodbye!")
|
||||
break
|
||||
except EOFError:
|
||||
_restore_terminal()
|
||||
console.print("\nGoodbye!")
|
||||
break
|
||||
finally:
|
||||
await agent_loop.close_mcp()
|
||||
|
||||
asyncio.run(run_interactive())
|
||||
|
||||
@ -574,6 +568,24 @@ def channels_status():
|
||||
"✓" if dc.enabled else "✗",
|
||||
dc.gateway_url
|
||||
)
|
||||
|
||||
# Feishu
|
||||
fs = config.channels.feishu
|
||||
fs_config = f"app_id: {fs.app_id[:10]}..." if fs.app_id else "[dim]not configured[/dim]"
|
||||
table.add_row(
|
||||
"Feishu",
|
||||
"✓" if fs.enabled else "✗",
|
||||
fs_config
|
||||
)
|
||||
|
||||
# Mochat
|
||||
mc = config.channels.mochat
|
||||
mc_base = mc.base_url or "[dim]not configured[/dim]"
|
||||
table.add_row(
|
||||
"Mochat",
|
||||
"✓" if mc.enabled else "✗",
|
||||
mc_base
|
||||
)
|
||||
|
||||
# Telegram
|
||||
tg = config.channels.telegram
|
||||
@ -658,14 +670,20 @@ def _get_bridge_dir() -> Path:
|
||||
def channels_login():
|
||||
"""Link device via QR code."""
|
||||
import subprocess
|
||||
from nanobot.config.loader import load_config
|
||||
|
||||
config = load_config()
|
||||
bridge_dir = _get_bridge_dir()
|
||||
|
||||
console.print(f"{__logo__} Starting bridge...")
|
||||
console.print("Scan the QR code to connect.\n")
|
||||
|
||||
env = {**os.environ}
|
||||
if config.channels.whatsapp.bridge_token:
|
||||
env["BRIDGE_TOKEN"] = config.channels.whatsapp.bridge_token
|
||||
|
||||
try:
|
||||
subprocess.run(["npm", "start"], cwd=bridge_dir, check=True)
|
||||
subprocess.run(["npm", "start"], cwd=bridge_dir, check=True, env=env)
|
||||
except subprocess.CalledProcessError as e:
|
||||
console.print(f"[red]Bridge failed: {e}[/red]")
|
||||
except FileNotFoundError:
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
"""Configuration schema using Pydantic."""
|
||||
|
||||
from pathlib import Path
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic import BaseModel, Field, ConfigDict
|
||||
from pydantic_settings import BaseSettings
|
||||
|
||||
|
||||
@ -9,6 +9,7 @@ class WhatsAppConfig(BaseModel):
|
||||
"""WhatsApp channel configuration."""
|
||||
enabled: bool = False
|
||||
bridge_url: str = "ws://localhost:3001"
|
||||
bridge_token: str = "" # Shared token for bridge auth (optional, recommended)
|
||||
allow_from: list[str] = Field(default_factory=list) # Allowed phone numbers
|
||||
|
||||
|
||||
@ -77,6 +78,42 @@ class EmailConfig(BaseModel):
|
||||
allow_from: list[str] = Field(default_factory=list) # Allowed sender email addresses
|
||||
|
||||
|
||||
class MochatMentionConfig(BaseModel):
|
||||
"""Mochat mention behavior configuration."""
|
||||
require_in_groups: bool = False
|
||||
|
||||
|
||||
class MochatGroupRule(BaseModel):
|
||||
"""Mochat per-group mention requirement."""
|
||||
require_mention: bool = False
|
||||
|
||||
|
||||
class MochatConfig(BaseModel):
|
||||
"""Mochat channel configuration."""
|
||||
enabled: bool = False
|
||||
base_url: str = "https://mochat.io"
|
||||
socket_url: str = ""
|
||||
socket_path: str = "/socket.io"
|
||||
socket_disable_msgpack: bool = False
|
||||
socket_reconnect_delay_ms: int = 1000
|
||||
socket_max_reconnect_delay_ms: int = 10000
|
||||
socket_connect_timeout_ms: int = 10000
|
||||
refresh_interval_ms: int = 30000
|
||||
watch_timeout_ms: int = 25000
|
||||
watch_limit: int = 100
|
||||
retry_delay_ms: int = 500
|
||||
max_retry_attempts: int = 0 # 0 means unlimited retries
|
||||
claw_token: str = ""
|
||||
agent_user_id: str = ""
|
||||
sessions: list[str] = Field(default_factory=list)
|
||||
panels: list[str] = Field(default_factory=list)
|
||||
allow_from: list[str] = Field(default_factory=list)
|
||||
mention: MochatMentionConfig = Field(default_factory=MochatMentionConfig)
|
||||
groups: dict[str, MochatGroupRule] = Field(default_factory=dict)
|
||||
reply_delay_mode: str = "non-mention" # off | non-mention
|
||||
reply_delay_ms: int = 120000
|
||||
|
||||
|
||||
class SlackDMConfig(BaseModel):
|
||||
"""Slack DM policy configuration."""
|
||||
enabled: bool = True
|
||||
@ -92,7 +129,7 @@ class SlackConfig(BaseModel):
|
||||
bot_token: str = "" # xoxb-...
|
||||
app_token: str = "" # xapp-...
|
||||
user_token_read_only: bool = True
|
||||
group_policy: str = "open" # "open", "mention", "allowlist"
|
||||
group_policy: str = "mention" # "mention", "open", "allowlist"
|
||||
group_allow_from: list[str] = Field(default_factory=list) # Allowed channel IDs if allowlist
|
||||
dm: SlackDMConfig = Field(default_factory=SlackDMConfig)
|
||||
|
||||
@ -111,6 +148,7 @@ class ChannelsConfig(BaseModel):
|
||||
telegram: TelegramConfig = Field(default_factory=TelegramConfig)
|
||||
discord: DiscordConfig = Field(default_factory=DiscordConfig)
|
||||
feishu: FeishuConfig = Field(default_factory=FeishuConfig)
|
||||
mochat: MochatConfig = Field(default_factory=MochatConfig)
|
||||
dingtalk: DingTalkConfig = Field(default_factory=DingTalkConfig)
|
||||
email: EmailConfig = Field(default_factory=EmailConfig)
|
||||
slack: SlackConfig = Field(default_factory=SlackConfig)
|
||||
@ -124,6 +162,7 @@ class AgentDefaults(BaseModel):
|
||||
max_tokens: int = 8192
|
||||
temperature: float = 0.7
|
||||
max_tool_iterations: int = 20
|
||||
memory_window: int = 50
|
||||
|
||||
|
||||
class AgentsConfig(BaseModel):
|
||||
@ -140,6 +179,7 @@ class ProviderConfig(BaseModel):
|
||||
|
||||
class ProvidersConfig(BaseModel):
|
||||
"""Configuration for LLM providers."""
|
||||
custom: ProviderConfig = Field(default_factory=ProviderConfig) # Any OpenAI-compatible endpoint
|
||||
anthropic: ProviderConfig = Field(default_factory=ProviderConfig)
|
||||
openai: ProviderConfig = Field(default_factory=ProviderConfig)
|
||||
openrouter: ProviderConfig = Field(default_factory=ProviderConfig)
|
||||
@ -150,6 +190,7 @@ class ProvidersConfig(BaseModel):
|
||||
vllm: ProviderConfig = Field(default_factory=ProviderConfig)
|
||||
gemini: ProviderConfig = Field(default_factory=ProviderConfig)
|
||||
moonshot: ProviderConfig = Field(default_factory=ProviderConfig)
|
||||
minimax: ProviderConfig = Field(default_factory=ProviderConfig)
|
||||
aihubmix: ProviderConfig = Field(default_factory=ProviderConfig) # AiHubMix API gateway
|
||||
openai_codex: ProviderConfig = Field(default_factory=ProviderConfig) # OpenAI Codex (OAuth) # AiHubMix API gateway
|
||||
|
||||
@ -176,11 +217,20 @@ class ExecToolConfig(BaseModel):
|
||||
timeout: int = 60
|
||||
|
||||
|
||||
class MCPServerConfig(BaseModel):
|
||||
"""MCP server connection configuration (stdio or HTTP)."""
|
||||
command: str = "" # Stdio: command to run (e.g. "npx")
|
||||
args: list[str] = Field(default_factory=list) # Stdio: command arguments
|
||||
env: dict[str, str] = Field(default_factory=dict) # Stdio: extra env vars
|
||||
url: str = "" # HTTP: streamable HTTP endpoint URL
|
||||
|
||||
|
||||
class ToolsConfig(BaseModel):
|
||||
"""Tools configuration."""
|
||||
web: WebToolsConfig = Field(default_factory=WebToolsConfig)
|
||||
exec: ExecToolConfig = Field(default_factory=ExecToolConfig)
|
||||
restrict_to_workspace: bool = False # If true, restrict all tool access to workspace directory
|
||||
mcp_servers: dict[str, MCPServerConfig] = Field(default_factory=dict)
|
||||
|
||||
|
||||
class Config(BaseSettings):
|
||||
@ -248,6 +298,7 @@ class Config(BaseSettings):
|
||||
return spec.default_api_base
|
||||
return None
|
||||
|
||||
class Config:
|
||||
env_prefix = "NANOBOT_"
|
||||
env_nested_delimiter = "__"
|
||||
model_config = ConfigDict(
|
||||
env_prefix="NANOBOT_",
|
||||
env_nested_delimiter="__"
|
||||
)
|
||||
|
||||
@ -4,6 +4,7 @@ import asyncio
|
||||
import json
|
||||
import time
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, Coroutine
|
||||
|
||||
@ -30,9 +31,13 @@ def _compute_next_run(schedule: CronSchedule, now_ms: int) -> int | None:
|
||||
if schedule.kind == "cron" and schedule.expr:
|
||||
try:
|
||||
from croniter import croniter
|
||||
cron = croniter(schedule.expr, time.time())
|
||||
next_time = cron.get_next()
|
||||
return int(next_time * 1000)
|
||||
from zoneinfo import ZoneInfo
|
||||
base_time = time.time()
|
||||
tz = ZoneInfo(schedule.tz) if schedule.tz else datetime.now().astimezone().tzinfo
|
||||
base_dt = datetime.fromtimestamp(base_time, tz=tz)
|
||||
cron = croniter(schedule.expr, base_dt)
|
||||
next_dt = cron.get_next(datetime)
|
||||
return int(next_dt.timestamp() * 1000)
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
|
||||
@ -1,6 +1,7 @@
|
||||
"""LiteLLM provider implementation for multi-provider support."""
|
||||
|
||||
import json
|
||||
import json_repair
|
||||
import os
|
||||
from typing import Any
|
||||
|
||||
@ -15,7 +16,7 @@ class LiteLLMProvider(LLMProvider):
|
||||
"""
|
||||
LLM provider using LiteLLM for multi-provider support.
|
||||
|
||||
Supports OpenRouter, Anthropic, OpenAI, Gemini, and many other providers through
|
||||
Supports OpenRouter, Anthropic, OpenAI, Gemini, MiniMax, and many other providers through
|
||||
a unified interface. Provider-specific logic is driven by the registry
|
||||
(see providers/registry.py) — no if-elif chains needed here.
|
||||
"""
|
||||
@ -125,6 +126,10 @@ class LiteLLMProvider(LLMProvider):
|
||||
"""
|
||||
model = self._resolve_model(model or self.default_model)
|
||||
|
||||
# Clamp max_tokens to at least 1 — negative or zero values cause
|
||||
# LiteLLM to reject the request with "max_tokens must be at least 1".
|
||||
max_tokens = max(1, max_tokens)
|
||||
|
||||
kwargs: dict[str, Any] = {
|
||||
"model": model,
|
||||
"messages": messages,
|
||||
@ -135,6 +140,10 @@ class LiteLLMProvider(LLMProvider):
|
||||
# Apply model-specific overrides (e.g. kimi-k2.5 temperature)
|
||||
self._apply_model_overrides(model, kwargs)
|
||||
|
||||
# Pass api_key directly — more reliable than env vars alone
|
||||
if self.api_key:
|
||||
kwargs["api_key"] = self.api_key
|
||||
|
||||
# Pass api_base for custom endpoints
|
||||
if self.api_base:
|
||||
kwargs["api_base"] = self.api_base
|
||||
@ -168,10 +177,7 @@ class LiteLLMProvider(LLMProvider):
|
||||
# Parse arguments from JSON string if needed
|
||||
args = tc.function.arguments
|
||||
if isinstance(args, str):
|
||||
try:
|
||||
args = json.loads(args)
|
||||
except json.JSONDecodeError:
|
||||
args = {"raw": args}
|
||||
args = json_repair.loads(args)
|
||||
|
||||
tool_calls.append(ToolCallRequest(
|
||||
id=tc.id,
|
||||
|
||||
@ -66,6 +66,20 @@ class ProviderSpec:
|
||||
|
||||
PROVIDERS: tuple[ProviderSpec, ...] = (
|
||||
|
||||
# === Custom (user-provided OpenAI-compatible endpoint) =================
|
||||
# No auto-detection — only activates when user explicitly configures "custom".
|
||||
|
||||
ProviderSpec(
|
||||
name="custom",
|
||||
keywords=(),
|
||||
env_key="OPENAI_API_KEY",
|
||||
display_name="Custom",
|
||||
litellm_prefix="openai",
|
||||
skip_prefixes=("openai/",),
|
||||
is_gateway=True,
|
||||
strip_model_prefix=True,
|
||||
),
|
||||
|
||||
# === Gateways (detected by api_key / api_base, not model name) =========
|
||||
# Gateways can route any model, so they win in fallback.
|
||||
|
||||
@ -265,6 +279,25 @@ PROVIDERS: tuple[ProviderSpec, ...] = (
|
||||
),
|
||||
),
|
||||
|
||||
# MiniMax: needs "minimax/" prefix for LiteLLM routing.
|
||||
# Uses OpenAI-compatible API at api.minimax.io/v1.
|
||||
ProviderSpec(
|
||||
name="minimax",
|
||||
keywords=("minimax",),
|
||||
env_key="MINIMAX_API_KEY",
|
||||
display_name="MiniMax",
|
||||
litellm_prefix="minimax", # MiniMax-M2.1 → minimax/MiniMax-M2.1
|
||||
skip_prefixes=("minimax/", "openrouter/"),
|
||||
env_extras=(),
|
||||
is_gateway=False,
|
||||
is_local=False,
|
||||
detect_by_key_prefix="",
|
||||
detect_by_base_keyword="",
|
||||
default_api_base="https://api.minimax.io/v1",
|
||||
strip_model_prefix=False,
|
||||
model_overrides=(),
|
||||
),
|
||||
|
||||
# === Local deployment (matched by config key, NOT by api_base) =========
|
||||
|
||||
# vLLM / any OpenAI-compatible local server.
|
||||
|
||||
@ -15,15 +15,20 @@ from nanobot.utils.helpers import ensure_dir, safe_filename
|
||||
class Session:
|
||||
"""
|
||||
A conversation session.
|
||||
|
||||
|
||||
Stores messages in JSONL format for easy reading and persistence.
|
||||
|
||||
Important: Messages are append-only for LLM cache efficiency.
|
||||
The consolidation process writes summaries to MEMORY.md/HISTORY.md
|
||||
but does NOT modify the messages list or get_history() output.
|
||||
"""
|
||||
|
||||
|
||||
key: str # channel:chat_id
|
||||
messages: list[dict[str, Any]] = field(default_factory=list)
|
||||
created_at: datetime = field(default_factory=datetime.now)
|
||||
updated_at: datetime = field(default_factory=datetime.now)
|
||||
metadata: dict[str, Any] = field(default_factory=dict)
|
||||
last_consolidated: int = 0 # Number of messages already consolidated to files
|
||||
|
||||
def add_message(self, role: str, content: str, **kwargs: Any) -> None:
|
||||
"""Add a message to the session."""
|
||||
@ -36,35 +41,24 @@ class Session:
|
||||
self.messages.append(msg)
|
||||
self.updated_at = datetime.now()
|
||||
|
||||
def get_history(self, max_messages: int = 50) -> list[dict[str, Any]]:
|
||||
"""
|
||||
Get message history for LLM context.
|
||||
|
||||
Args:
|
||||
max_messages: Maximum messages to return.
|
||||
|
||||
Returns:
|
||||
List of messages in LLM format.
|
||||
"""
|
||||
# Get recent messages
|
||||
recent = self.messages[-max_messages:] if len(self.messages) > max_messages else self.messages
|
||||
|
||||
# Convert to LLM format (just role and content)
|
||||
return [{"role": m["role"], "content": m["content"]} for m in recent]
|
||||
def get_history(self, max_messages: int = 500) -> list[dict[str, Any]]:
|
||||
"""Get recent messages in LLM format (role + content only)."""
|
||||
return [{"role": m["role"], "content": m["content"]} for m in self.messages[-max_messages:]]
|
||||
|
||||
def clear(self) -> None:
|
||||
"""Clear all messages in the session."""
|
||||
"""Clear all messages and reset session to initial state."""
|
||||
self.messages = []
|
||||
self.last_consolidated = 0
|
||||
self.updated_at = datetime.now()
|
||||
|
||||
|
||||
class SessionManager:
|
||||
"""
|
||||
Manages conversation sessions.
|
||||
|
||||
|
||||
Sessions are stored as JSONL files in the sessions directory.
|
||||
"""
|
||||
|
||||
|
||||
def __init__(self, workspace: Path):
|
||||
self.workspace = workspace
|
||||
self.sessions_dir = ensure_dir(Path.home() / ".nanobot" / "sessions")
|
||||
@ -85,11 +79,9 @@ class SessionManager:
|
||||
Returns:
|
||||
The session.
|
||||
"""
|
||||
# Check cache
|
||||
if key in self._cache:
|
||||
return self._cache[key]
|
||||
|
||||
# Try to load from disk
|
||||
session = self._load(key)
|
||||
if session is None:
|
||||
session = Session(key=key)
|
||||
@ -100,34 +92,37 @@ class SessionManager:
|
||||
def _load(self, key: str) -> Session | None:
|
||||
"""Load a session from disk."""
|
||||
path = self._get_session_path(key)
|
||||
|
||||
|
||||
if not path.exists():
|
||||
return None
|
||||
|
||||
|
||||
try:
|
||||
messages = []
|
||||
metadata = {}
|
||||
created_at = None
|
||||
|
||||
last_consolidated = 0
|
||||
|
||||
with open(path) as f:
|
||||
for line in f:
|
||||
line = line.strip()
|
||||
if not line:
|
||||
continue
|
||||
|
||||
|
||||
data = json.loads(line)
|
||||
|
||||
|
||||
if data.get("_type") == "metadata":
|
||||
metadata = data.get("metadata", {})
|
||||
created_at = datetime.fromisoformat(data["created_at"]) if data.get("created_at") else None
|
||||
last_consolidated = data.get("last_consolidated", 0)
|
||||
else:
|
||||
messages.append(data)
|
||||
|
||||
|
||||
return Session(
|
||||
key=key,
|
||||
messages=messages,
|
||||
created_at=created_at or datetime.now(),
|
||||
metadata=metadata
|
||||
metadata=metadata,
|
||||
last_consolidated=last_consolidated
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to load session {key}: {e}")
|
||||
@ -136,42 +131,24 @@ class SessionManager:
|
||||
def save(self, session: Session) -> None:
|
||||
"""Save a session to disk."""
|
||||
path = self._get_session_path(session.key)
|
||||
|
||||
|
||||
with open(path, "w") as f:
|
||||
# Write metadata first
|
||||
metadata_line = {
|
||||
"_type": "metadata",
|
||||
"created_at": session.created_at.isoformat(),
|
||||
"updated_at": session.updated_at.isoformat(),
|
||||
"metadata": session.metadata
|
||||
"metadata": session.metadata,
|
||||
"last_consolidated": session.last_consolidated
|
||||
}
|
||||
f.write(json.dumps(metadata_line) + "\n")
|
||||
|
||||
# Write messages
|
||||
for msg in session.messages:
|
||||
f.write(json.dumps(msg) + "\n")
|
||||
|
||||
|
||||
self._cache[session.key] = session
|
||||
|
||||
def delete(self, key: str) -> bool:
|
||||
"""
|
||||
Delete a session.
|
||||
|
||||
Args:
|
||||
key: Session key.
|
||||
|
||||
Returns:
|
||||
True if deleted, False if not found.
|
||||
"""
|
||||
# Remove from cache
|
||||
def invalidate(self, key: str) -> None:
|
||||
"""Remove a session from the in-memory cache."""
|
||||
self._cache.pop(key, None)
|
||||
|
||||
# Remove file
|
||||
path = self._get_session_path(key)
|
||||
if path.exists():
|
||||
path.unlink()
|
||||
return True
|
||||
return False
|
||||
|
||||
def list_sessions(self) -> list[dict[str, Any]]:
|
||||
"""
|
||||
|
||||
@ -7,10 +7,11 @@ description: Schedule reminders and recurring tasks.
|
||||
|
||||
Use the `cron` tool to schedule reminders or recurring tasks.
|
||||
|
||||
## Two Modes
|
||||
## Three Modes
|
||||
|
||||
1. **Reminder** - message is sent directly to user
|
||||
2. **Task** - message is a task description, agent executes and sends result
|
||||
3. **One-time** - runs once at a specific time, then auto-deletes
|
||||
|
||||
## Examples
|
||||
|
||||
@ -24,6 +25,11 @@ Dynamic task (agent executes each time):
|
||||
cron(action="add", message="Check HKUDS/nanobot GitHub stars and report", every_seconds=600)
|
||||
```
|
||||
|
||||
One-time scheduled task (compute ISO datetime from current time):
|
||||
```
|
||||
cron(action="add", message="Remind me about the meeting", at="<ISO datetime>")
|
||||
```
|
||||
|
||||
List/remove:
|
||||
```
|
||||
cron(action="list")
|
||||
@ -38,3 +44,4 @@ cron(action="remove", job_id="abc123")
|
||||
| every hour | every_seconds: 3600 |
|
||||
| every day at 8am | cron_expr: "0 8 * * *" |
|
||||
| weekdays at 5pm | cron_expr: "0 17 * * 1-5" |
|
||||
| at a specific time | at: ISO datetime string (compute from current time) |
|
||||
|
||||
31
nanobot/skills/memory/SKILL.md
Normal file
31
nanobot/skills/memory/SKILL.md
Normal file
@ -0,0 +1,31 @@
|
||||
---
|
||||
name: memory
|
||||
description: Two-layer memory system with grep-based recall.
|
||||
always: true
|
||||
---
|
||||
|
||||
# Memory
|
||||
|
||||
## Structure
|
||||
|
||||
- `memory/MEMORY.md` — Long-term facts (preferences, project context, relationships). Always loaded into your context.
|
||||
- `memory/HISTORY.md` — Append-only event log. NOT loaded into context. Search it with grep.
|
||||
|
||||
## Search Past Events
|
||||
|
||||
```bash
|
||||
grep -i "keyword" memory/HISTORY.md
|
||||
```
|
||||
|
||||
Use the `exec` tool to run grep. Combine patterns: `grep -iE "meeting|deadline" memory/HISTORY.md`
|
||||
|
||||
## When to Update MEMORY.md
|
||||
|
||||
Write important facts immediately using `edit_file` or `write_file`:
|
||||
- User preferences ("I prefer dark mode")
|
||||
- Project context ("The API uses OAuth2")
|
||||
- Relationships ("Alice is the project lead")
|
||||
|
||||
## Auto-consolidation
|
||||
|
||||
Old conversations are automatically summarized and appended to HISTORY.md when the session grows large. Long-term facts are extracted to MEMORY.md. You don't need to manage this.
|
||||
@ -37,23 +37,12 @@ def get_sessions_path() -> Path:
|
||||
return ensure_dir(get_data_path() / "sessions")
|
||||
|
||||
|
||||
def get_memory_path(workspace: Path | None = None) -> Path:
|
||||
"""Get the memory directory within the workspace."""
|
||||
ws = workspace or get_workspace_path()
|
||||
return ensure_dir(ws / "memory")
|
||||
|
||||
|
||||
def get_skills_path(workspace: Path | None = None) -> Path:
|
||||
"""Get the skills directory within the workspace."""
|
||||
ws = workspace or get_workspace_path()
|
||||
return ensure_dir(ws / "skills")
|
||||
|
||||
|
||||
def today_date() -> str:
|
||||
"""Get today's date in YYYY-MM-DD format."""
|
||||
return datetime.now().strftime("%Y-%m-%d")
|
||||
|
||||
|
||||
def timestamp() -> str:
|
||||
"""Get current timestamp in ISO format."""
|
||||
return datetime.now().isoformat()
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
[project]
|
||||
name = "nanobot-ai"
|
||||
version = "0.1.3.post5"
|
||||
version = "0.1.3.post7"
|
||||
description = "A lightweight personal AI assistant framework"
|
||||
requires-python = ">=3.11"
|
||||
license = {text = "MIT"}
|
||||
@ -33,8 +33,14 @@ dependencies = [
|
||||
"python-telegram-bot[socks]>=21.0",
|
||||
"lark-oapi>=1.0.0",
|
||||
"socksio>=1.0.0",
|
||||
"python-socketio>=5.11.0",
|
||||
"msgpack>=1.0.8",
|
||||
"slack-sdk>=3.26.0",
|
||||
"qq-botpy>=1.0.0",
|
||||
"python-socks[asyncio]>=2.4.0",
|
||||
"prompt-toolkit>=3.0.0",
|
||||
"mcp>=1.0.0",
|
||||
"json-repair>=0.30.0",
|
||||
]
|
||||
|
||||
[project.optional-dependencies]
|
||||
|
||||
59
tests/test_cli_input.py
Normal file
59
tests/test_cli_input.py
Normal file
@ -0,0 +1,59 @@
|
||||
import asyncio
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from prompt_toolkit.formatted_text import HTML
|
||||
|
||||
from nanobot.cli import commands
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_prompt_session():
|
||||
"""Mock the global prompt session."""
|
||||
mock_session = MagicMock()
|
||||
mock_session.prompt_async = AsyncMock()
|
||||
with patch("nanobot.cli.commands._PROMPT_SESSION", mock_session), \
|
||||
patch("nanobot.cli.commands.patch_stdout"):
|
||||
yield mock_session
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_read_interactive_input_async_returns_input(mock_prompt_session):
|
||||
"""Test that _read_interactive_input_async returns the user input from prompt_session."""
|
||||
mock_prompt_session.prompt_async.return_value = "hello world"
|
||||
|
||||
result = await commands._read_interactive_input_async()
|
||||
|
||||
assert result == "hello world"
|
||||
mock_prompt_session.prompt_async.assert_called_once()
|
||||
args, _ = mock_prompt_session.prompt_async.call_args
|
||||
assert isinstance(args[0], HTML) # Verify HTML prompt is used
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_read_interactive_input_async_handles_eof(mock_prompt_session):
|
||||
"""Test that EOFError converts to KeyboardInterrupt."""
|
||||
mock_prompt_session.prompt_async.side_effect = EOFError()
|
||||
|
||||
with pytest.raises(KeyboardInterrupt):
|
||||
await commands._read_interactive_input_async()
|
||||
|
||||
|
||||
def test_init_prompt_session_creates_session():
|
||||
"""Test that _init_prompt_session initializes the global session."""
|
||||
# Ensure global is None before test
|
||||
commands._PROMPT_SESSION = None
|
||||
|
||||
with patch("nanobot.cli.commands.PromptSession") as MockSession, \
|
||||
patch("nanobot.cli.commands.FileHistory") as MockHistory, \
|
||||
patch("pathlib.Path.home") as mock_home:
|
||||
|
||||
mock_home.return_value = MagicMock()
|
||||
|
||||
commands._init_prompt_session()
|
||||
|
||||
assert commands._PROMPT_SESSION is not None
|
||||
MockSession.assert_called_once()
|
||||
_, kwargs = MockSession.call_args
|
||||
assert kwargs["multiline"] is False
|
||||
assert kwargs["enable_open_in_editor"] is False
|
||||
92
tests/test_commands.py
Normal file
92
tests/test_commands.py
Normal file
@ -0,0 +1,92 @@
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from typer.testing import CliRunner
|
||||
|
||||
from nanobot.cli.commands import app
|
||||
|
||||
runner = CliRunner()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_paths():
|
||||
"""Mock config/workspace paths for test isolation."""
|
||||
with patch("nanobot.config.loader.get_config_path") as mock_cp, \
|
||||
patch("nanobot.config.loader.save_config") as mock_sc, \
|
||||
patch("nanobot.config.loader.load_config") as mock_lc, \
|
||||
patch("nanobot.utils.helpers.get_workspace_path") as mock_ws:
|
||||
|
||||
base_dir = Path("./test_onboard_data")
|
||||
if base_dir.exists():
|
||||
shutil.rmtree(base_dir)
|
||||
base_dir.mkdir()
|
||||
|
||||
config_file = base_dir / "config.json"
|
||||
workspace_dir = base_dir / "workspace"
|
||||
|
||||
mock_cp.return_value = config_file
|
||||
mock_ws.return_value = workspace_dir
|
||||
mock_sc.side_effect = lambda config: config_file.write_text("{}")
|
||||
|
||||
yield config_file, workspace_dir
|
||||
|
||||
if base_dir.exists():
|
||||
shutil.rmtree(base_dir)
|
||||
|
||||
|
||||
def test_onboard_fresh_install(mock_paths):
|
||||
"""No existing config — should create from scratch."""
|
||||
config_file, workspace_dir = mock_paths
|
||||
|
||||
result = runner.invoke(app, ["onboard"])
|
||||
|
||||
assert result.exit_code == 0
|
||||
assert "Created config" in result.stdout
|
||||
assert "Created workspace" in result.stdout
|
||||
assert "nanobot is ready" in result.stdout
|
||||
assert config_file.exists()
|
||||
assert (workspace_dir / "AGENTS.md").exists()
|
||||
assert (workspace_dir / "memory" / "MEMORY.md").exists()
|
||||
|
||||
|
||||
def test_onboard_existing_config_refresh(mock_paths):
|
||||
"""Config exists, user declines overwrite — should refresh (load-merge-save)."""
|
||||
config_file, workspace_dir = mock_paths
|
||||
config_file.write_text('{"existing": true}')
|
||||
|
||||
result = runner.invoke(app, ["onboard"], input="n\n")
|
||||
|
||||
assert result.exit_code == 0
|
||||
assert "Config already exists" in result.stdout
|
||||
assert "existing values preserved" in result.stdout
|
||||
assert workspace_dir.exists()
|
||||
assert (workspace_dir / "AGENTS.md").exists()
|
||||
|
||||
|
||||
def test_onboard_existing_config_overwrite(mock_paths):
|
||||
"""Config exists, user confirms overwrite — should reset to defaults."""
|
||||
config_file, workspace_dir = mock_paths
|
||||
config_file.write_text('{"existing": true}')
|
||||
|
||||
result = runner.invoke(app, ["onboard"], input="y\n")
|
||||
|
||||
assert result.exit_code == 0
|
||||
assert "Config already exists" in result.stdout
|
||||
assert "Config reset to defaults" in result.stdout
|
||||
assert workspace_dir.exists()
|
||||
|
||||
|
||||
def test_onboard_existing_workspace_safe_create(mock_paths):
|
||||
"""Workspace exists — should not recreate, but still add missing templates."""
|
||||
config_file, workspace_dir = mock_paths
|
||||
workspace_dir.mkdir(parents=True)
|
||||
config_file.write_text("{}")
|
||||
|
||||
result = runner.invoke(app, ["onboard"], input="n\n")
|
||||
|
||||
assert result.exit_code == 0
|
||||
assert "Created workspace" not in result.stdout
|
||||
assert "Created AGENTS.md" in result.stdout
|
||||
assert (workspace_dir / "AGENTS.md").exists()
|
||||
477
tests/test_consolidate_offset.py
Normal file
477
tests/test_consolidate_offset.py
Normal file
@ -0,0 +1,477 @@
|
||||
"""Test session management with cache-friendly message handling."""
|
||||
|
||||
import pytest
|
||||
from pathlib import Path
|
||||
from nanobot.session.manager import Session, SessionManager
|
||||
|
||||
# Test constants
|
||||
MEMORY_WINDOW = 50
|
||||
KEEP_COUNT = MEMORY_WINDOW // 2 # 25
|
||||
|
||||
|
||||
def create_session_with_messages(key: str, count: int, role: str = "user") -> Session:
|
||||
"""Create a session and add the specified number of messages.
|
||||
|
||||
Args:
|
||||
key: Session identifier
|
||||
count: Number of messages to add
|
||||
role: Message role (default: "user")
|
||||
|
||||
Returns:
|
||||
Session with the specified messages
|
||||
"""
|
||||
session = Session(key=key)
|
||||
for i in range(count):
|
||||
session.add_message(role, f"msg{i}")
|
||||
return session
|
||||
|
||||
|
||||
def assert_messages_content(messages: list, start_index: int, end_index: int) -> None:
|
||||
"""Assert that messages contain expected content from start to end index.
|
||||
|
||||
Args:
|
||||
messages: List of message dictionaries
|
||||
start_index: Expected first message index
|
||||
end_index: Expected last message index
|
||||
"""
|
||||
assert len(messages) > 0
|
||||
assert messages[0]["content"] == f"msg{start_index}"
|
||||
assert messages[-1]["content"] == f"msg{end_index}"
|
||||
|
||||
|
||||
def get_old_messages(session: Session, last_consolidated: int, keep_count: int) -> list:
|
||||
"""Extract messages that would be consolidated using the standard slice logic.
|
||||
|
||||
Args:
|
||||
session: The session containing messages
|
||||
last_consolidated: Index of last consolidated message
|
||||
keep_count: Number of recent messages to keep
|
||||
|
||||
Returns:
|
||||
List of messages that would be consolidated
|
||||
"""
|
||||
return session.messages[last_consolidated:-keep_count]
|
||||
|
||||
|
||||
class TestSessionLastConsolidated:
|
||||
"""Test last_consolidated tracking to avoid duplicate processing."""
|
||||
|
||||
def test_initial_last_consolidated_zero(self) -> None:
|
||||
"""Test that new session starts with last_consolidated=0."""
|
||||
session = Session(key="test:initial")
|
||||
assert session.last_consolidated == 0
|
||||
|
||||
def test_last_consolidated_persistence(self, tmp_path) -> None:
|
||||
"""Test that last_consolidated persists across save/load."""
|
||||
manager = SessionManager(Path(tmp_path))
|
||||
session1 = create_session_with_messages("test:persist", 20)
|
||||
session1.last_consolidated = 15
|
||||
manager.save(session1)
|
||||
|
||||
session2 = manager.get_or_create("test:persist")
|
||||
assert session2.last_consolidated == 15
|
||||
assert len(session2.messages) == 20
|
||||
|
||||
def test_clear_resets_last_consolidated(self) -> None:
|
||||
"""Test that clear() resets last_consolidated to 0."""
|
||||
session = create_session_with_messages("test:clear", 10)
|
||||
session.last_consolidated = 5
|
||||
|
||||
session.clear()
|
||||
assert len(session.messages) == 0
|
||||
assert session.last_consolidated == 0
|
||||
|
||||
|
||||
class TestSessionImmutableHistory:
|
||||
"""Test Session message immutability for cache efficiency."""
|
||||
|
||||
def test_initial_state(self) -> None:
|
||||
"""Test that new session has empty messages list."""
|
||||
session = Session(key="test:initial")
|
||||
assert len(session.messages) == 0
|
||||
|
||||
def test_add_messages_appends_only(self) -> None:
|
||||
"""Test that adding messages only appends, never modifies."""
|
||||
session = Session(key="test:preserve")
|
||||
session.add_message("user", "msg1")
|
||||
session.add_message("assistant", "resp1")
|
||||
session.add_message("user", "msg2")
|
||||
assert len(session.messages) == 3
|
||||
assert session.messages[0]["content"] == "msg1"
|
||||
|
||||
def test_get_history_returns_most_recent(self) -> None:
|
||||
"""Test get_history returns the most recent messages."""
|
||||
session = Session(key="test:history")
|
||||
for i in range(10):
|
||||
session.add_message("user", f"msg{i}")
|
||||
session.add_message("assistant", f"resp{i}")
|
||||
|
||||
history = session.get_history(max_messages=6)
|
||||
assert len(history) == 6
|
||||
assert history[0]["content"] == "msg7"
|
||||
assert history[-1]["content"] == "resp9"
|
||||
|
||||
def test_get_history_with_all_messages(self) -> None:
|
||||
"""Test get_history with max_messages larger than actual."""
|
||||
session = create_session_with_messages("test:all", 5)
|
||||
history = session.get_history(max_messages=100)
|
||||
assert len(history) == 5
|
||||
assert history[0]["content"] == "msg0"
|
||||
|
||||
def test_get_history_stable_for_same_session(self) -> None:
|
||||
"""Test that get_history returns same content for same max_messages."""
|
||||
session = create_session_with_messages("test:stable", 20)
|
||||
history1 = session.get_history(max_messages=10)
|
||||
history2 = session.get_history(max_messages=10)
|
||||
assert history1 == history2
|
||||
|
||||
def test_messages_list_never_modified(self) -> None:
|
||||
"""Test that messages list is never modified after creation."""
|
||||
session = create_session_with_messages("test:immutable", 5)
|
||||
original_len = len(session.messages)
|
||||
|
||||
session.get_history(max_messages=2)
|
||||
assert len(session.messages) == original_len
|
||||
|
||||
for _ in range(10):
|
||||
session.get_history(max_messages=3)
|
||||
assert len(session.messages) == original_len
|
||||
|
||||
|
||||
class TestSessionPersistence:
|
||||
"""Test Session persistence and reload."""
|
||||
|
||||
@pytest.fixture
|
||||
def temp_manager(self, tmp_path):
|
||||
return SessionManager(Path(tmp_path))
|
||||
|
||||
def test_persistence_roundtrip(self, temp_manager):
|
||||
"""Test that messages persist across save/load."""
|
||||
session1 = create_session_with_messages("test:persistence", 20)
|
||||
temp_manager.save(session1)
|
||||
|
||||
session2 = temp_manager.get_or_create("test:persistence")
|
||||
assert len(session2.messages) == 20
|
||||
assert session2.messages[0]["content"] == "msg0"
|
||||
assert session2.messages[-1]["content"] == "msg19"
|
||||
|
||||
def test_get_history_after_reload(self, temp_manager):
|
||||
"""Test that get_history works correctly after reload."""
|
||||
session1 = create_session_with_messages("test:reload", 30)
|
||||
temp_manager.save(session1)
|
||||
|
||||
session2 = temp_manager.get_or_create("test:reload")
|
||||
history = session2.get_history(max_messages=10)
|
||||
assert len(history) == 10
|
||||
assert history[0]["content"] == "msg20"
|
||||
assert history[-1]["content"] == "msg29"
|
||||
|
||||
def test_clear_resets_session(self, temp_manager):
|
||||
"""Test that clear() properly resets session."""
|
||||
session = create_session_with_messages("test:clear", 10)
|
||||
assert len(session.messages) == 10
|
||||
|
||||
session.clear()
|
||||
assert len(session.messages) == 0
|
||||
|
||||
|
||||
class TestConsolidationTriggerConditions:
|
||||
"""Test consolidation trigger conditions and logic."""
|
||||
|
||||
def test_consolidation_needed_when_messages_exceed_window(self):
|
||||
"""Test consolidation logic: should trigger when messages > memory_window."""
|
||||
session = create_session_with_messages("test:trigger", 60)
|
||||
|
||||
total_messages = len(session.messages)
|
||||
messages_to_process = total_messages - session.last_consolidated
|
||||
|
||||
assert total_messages > MEMORY_WINDOW
|
||||
assert messages_to_process > 0
|
||||
|
||||
expected_consolidate_count = total_messages - KEEP_COUNT
|
||||
assert expected_consolidate_count == 35
|
||||
|
||||
def test_consolidation_skipped_when_within_keep_count(self):
|
||||
"""Test consolidation skipped when total messages <= keep_count."""
|
||||
session = create_session_with_messages("test:skip", 20)
|
||||
|
||||
total_messages = len(session.messages)
|
||||
assert total_messages <= KEEP_COUNT
|
||||
|
||||
old_messages = get_old_messages(session, session.last_consolidated, KEEP_COUNT)
|
||||
assert len(old_messages) == 0
|
||||
|
||||
def test_consolidation_skipped_when_no_new_messages(self):
|
||||
"""Test consolidation skipped when messages_to_process <= 0."""
|
||||
session = create_session_with_messages("test:already_consolidated", 40)
|
||||
session.last_consolidated = len(session.messages) - KEEP_COUNT # 15
|
||||
|
||||
# Add a few more messages
|
||||
for i in range(40, 42):
|
||||
session.add_message("user", f"msg{i}")
|
||||
|
||||
total_messages = len(session.messages)
|
||||
messages_to_process = total_messages - session.last_consolidated
|
||||
assert messages_to_process > 0
|
||||
|
||||
# Simulate last_consolidated catching up
|
||||
session.last_consolidated = total_messages - KEEP_COUNT
|
||||
old_messages = get_old_messages(session, session.last_consolidated, KEEP_COUNT)
|
||||
assert len(old_messages) == 0
|
||||
|
||||
|
||||
class TestLastConsolidatedEdgeCases:
|
||||
"""Test last_consolidated edge cases and data corruption scenarios."""
|
||||
|
||||
def test_last_consolidated_exceeds_message_count(self):
|
||||
"""Test behavior when last_consolidated > len(messages) (data corruption)."""
|
||||
session = create_session_with_messages("test:corruption", 10)
|
||||
session.last_consolidated = 20
|
||||
|
||||
total_messages = len(session.messages)
|
||||
messages_to_process = total_messages - session.last_consolidated
|
||||
assert messages_to_process <= 0
|
||||
|
||||
old_messages = get_old_messages(session, session.last_consolidated, 5)
|
||||
assert len(old_messages) == 0
|
||||
|
||||
def test_last_consolidated_negative_value(self):
|
||||
"""Test behavior with negative last_consolidated (invalid state)."""
|
||||
session = create_session_with_messages("test:negative", 10)
|
||||
session.last_consolidated = -5
|
||||
|
||||
keep_count = 3
|
||||
old_messages = get_old_messages(session, session.last_consolidated, keep_count)
|
||||
|
||||
# messages[-5:-3] with 10 messages gives indices 5,6
|
||||
assert len(old_messages) == 2
|
||||
assert old_messages[0]["content"] == "msg5"
|
||||
assert old_messages[-1]["content"] == "msg6"
|
||||
|
||||
def test_messages_added_after_consolidation(self):
|
||||
"""Test correct behavior when new messages arrive after consolidation."""
|
||||
session = create_session_with_messages("test:new_messages", 40)
|
||||
session.last_consolidated = len(session.messages) - KEEP_COUNT # 15
|
||||
|
||||
# Add new messages after consolidation
|
||||
for i in range(40, 50):
|
||||
session.add_message("user", f"msg{i}")
|
||||
|
||||
total_messages = len(session.messages)
|
||||
old_messages = get_old_messages(session, session.last_consolidated, KEEP_COUNT)
|
||||
expected_consolidate_count = total_messages - KEEP_COUNT - session.last_consolidated
|
||||
|
||||
assert len(old_messages) == expected_consolidate_count
|
||||
assert_messages_content(old_messages, 15, 24)
|
||||
|
||||
def test_slice_behavior_when_indices_overlap(self):
|
||||
"""Test slice behavior when last_consolidated >= total - keep_count."""
|
||||
session = create_session_with_messages("test:overlap", 30)
|
||||
session.last_consolidated = 12
|
||||
|
||||
old_messages = get_old_messages(session, session.last_consolidated, 20)
|
||||
assert len(old_messages) == 0
|
||||
|
||||
|
||||
class TestArchiveAllMode:
|
||||
"""Test archive_all mode (used by /new command)."""
|
||||
|
||||
def test_archive_all_consolidates_everything(self):
|
||||
"""Test archive_all=True consolidates all messages."""
|
||||
session = create_session_with_messages("test:archive_all", 50)
|
||||
|
||||
archive_all = True
|
||||
if archive_all:
|
||||
old_messages = session.messages
|
||||
assert len(old_messages) == 50
|
||||
|
||||
assert session.last_consolidated == 0
|
||||
|
||||
def test_archive_all_resets_last_consolidated(self):
|
||||
"""Test that archive_all mode resets last_consolidated to 0."""
|
||||
session = create_session_with_messages("test:reset", 40)
|
||||
session.last_consolidated = 15
|
||||
|
||||
archive_all = True
|
||||
if archive_all:
|
||||
session.last_consolidated = 0
|
||||
|
||||
assert session.last_consolidated == 0
|
||||
assert len(session.messages) == 40
|
||||
|
||||
def test_archive_all_vs_normal_consolidation(self):
|
||||
"""Test difference between archive_all and normal consolidation."""
|
||||
# Normal consolidation
|
||||
session1 = create_session_with_messages("test:normal", 60)
|
||||
session1.last_consolidated = len(session1.messages) - KEEP_COUNT
|
||||
|
||||
# archive_all mode
|
||||
session2 = create_session_with_messages("test:all", 60)
|
||||
session2.last_consolidated = 0
|
||||
|
||||
assert session1.last_consolidated == 35
|
||||
assert len(session1.messages) == 60
|
||||
assert session2.last_consolidated == 0
|
||||
assert len(session2.messages) == 60
|
||||
|
||||
|
||||
class TestCacheImmutability:
|
||||
"""Test that consolidation doesn't modify session.messages (cache safety)."""
|
||||
|
||||
def test_consolidation_does_not_modify_messages_list(self):
|
||||
"""Test that consolidation leaves messages list unchanged."""
|
||||
session = create_session_with_messages("test:immutable", 50)
|
||||
|
||||
original_messages = session.messages.copy()
|
||||
original_len = len(session.messages)
|
||||
session.last_consolidated = original_len - KEEP_COUNT
|
||||
|
||||
assert len(session.messages) == original_len
|
||||
assert session.messages == original_messages
|
||||
|
||||
def test_get_history_does_not_modify_messages(self):
|
||||
"""Test that get_history doesn't modify messages list."""
|
||||
session = create_session_with_messages("test:history_immutable", 40)
|
||||
original_messages = [m.copy() for m in session.messages]
|
||||
|
||||
for _ in range(5):
|
||||
history = session.get_history(max_messages=10)
|
||||
assert len(history) == 10
|
||||
|
||||
assert len(session.messages) == 40
|
||||
for i, msg in enumerate(session.messages):
|
||||
assert msg["content"] == original_messages[i]["content"]
|
||||
|
||||
def test_consolidation_only_updates_last_consolidated(self):
|
||||
"""Test that consolidation only updates last_consolidated field."""
|
||||
session = create_session_with_messages("test:field_only", 60)
|
||||
|
||||
original_messages = session.messages.copy()
|
||||
original_key = session.key
|
||||
original_metadata = session.metadata.copy()
|
||||
|
||||
session.last_consolidated = len(session.messages) - KEEP_COUNT
|
||||
|
||||
assert session.messages == original_messages
|
||||
assert session.key == original_key
|
||||
assert session.metadata == original_metadata
|
||||
assert session.last_consolidated == 35
|
||||
|
||||
|
||||
class TestSliceLogic:
|
||||
"""Test the slice logic: messages[last_consolidated:-keep_count]."""
|
||||
|
||||
def test_slice_extracts_correct_range(self):
|
||||
"""Test that slice extracts the correct message range."""
|
||||
session = create_session_with_messages("test:slice", 60)
|
||||
|
||||
old_messages = get_old_messages(session, 0, KEEP_COUNT)
|
||||
|
||||
assert len(old_messages) == 35
|
||||
assert_messages_content(old_messages, 0, 34)
|
||||
|
||||
remaining = session.messages[-KEEP_COUNT:]
|
||||
assert len(remaining) == 25
|
||||
assert_messages_content(remaining, 35, 59)
|
||||
|
||||
def test_slice_with_partial_consolidation(self):
|
||||
"""Test slice when some messages already consolidated."""
|
||||
session = create_session_with_messages("test:partial", 70)
|
||||
|
||||
last_consolidated = 30
|
||||
old_messages = get_old_messages(session, last_consolidated, KEEP_COUNT)
|
||||
|
||||
assert len(old_messages) == 15
|
||||
assert_messages_content(old_messages, 30, 44)
|
||||
|
||||
def test_slice_with_various_keep_counts(self):
|
||||
"""Test slice behavior with different keep_count values."""
|
||||
session = create_session_with_messages("test:keep_counts", 50)
|
||||
|
||||
test_cases = [(10, 40), (20, 30), (30, 20), (40, 10)]
|
||||
|
||||
for keep_count, expected_count in test_cases:
|
||||
old_messages = session.messages[0:-keep_count]
|
||||
assert len(old_messages) == expected_count
|
||||
|
||||
def test_slice_when_keep_count_exceeds_messages(self):
|
||||
"""Test slice when keep_count > len(messages)."""
|
||||
session = create_session_with_messages("test:exceed", 10)
|
||||
|
||||
old_messages = session.messages[0:-20]
|
||||
assert len(old_messages) == 0
|
||||
|
||||
|
||||
class TestEmptyAndBoundarySessions:
|
||||
"""Test empty sessions and boundary conditions."""
|
||||
|
||||
def test_empty_session_consolidation(self):
|
||||
"""Test consolidation behavior with empty session."""
|
||||
session = Session(key="test:empty")
|
||||
|
||||
assert len(session.messages) == 0
|
||||
assert session.last_consolidated == 0
|
||||
|
||||
messages_to_process = len(session.messages) - session.last_consolidated
|
||||
assert messages_to_process == 0
|
||||
|
||||
old_messages = get_old_messages(session, session.last_consolidated, KEEP_COUNT)
|
||||
assert len(old_messages) == 0
|
||||
|
||||
def test_single_message_session(self):
|
||||
"""Test consolidation with single message."""
|
||||
session = Session(key="test:single")
|
||||
session.add_message("user", "only message")
|
||||
|
||||
assert len(session.messages) == 1
|
||||
|
||||
old_messages = get_old_messages(session, session.last_consolidated, KEEP_COUNT)
|
||||
assert len(old_messages) == 0
|
||||
|
||||
def test_exactly_keep_count_messages(self):
|
||||
"""Test session with exactly keep_count messages."""
|
||||
session = create_session_with_messages("test:exact", KEEP_COUNT)
|
||||
|
||||
assert len(session.messages) == KEEP_COUNT
|
||||
|
||||
old_messages = get_old_messages(session, session.last_consolidated, KEEP_COUNT)
|
||||
assert len(old_messages) == 0
|
||||
|
||||
def test_just_over_keep_count(self):
|
||||
"""Test session with one message over keep_count."""
|
||||
session = create_session_with_messages("test:over", KEEP_COUNT + 1)
|
||||
|
||||
assert len(session.messages) == 26
|
||||
|
||||
old_messages = get_old_messages(session, session.last_consolidated, KEEP_COUNT)
|
||||
assert len(old_messages) == 1
|
||||
assert old_messages[0]["content"] == "msg0"
|
||||
|
||||
def test_very_large_session(self):
|
||||
"""Test consolidation with very large message count."""
|
||||
session = create_session_with_messages("test:large", 1000)
|
||||
|
||||
assert len(session.messages) == 1000
|
||||
|
||||
old_messages = get_old_messages(session, session.last_consolidated, KEEP_COUNT)
|
||||
assert len(old_messages) == 975
|
||||
assert_messages_content(old_messages, 0, 974)
|
||||
|
||||
remaining = session.messages[-KEEP_COUNT:]
|
||||
assert len(remaining) == 25
|
||||
assert_messages_content(remaining, 975, 999)
|
||||
|
||||
def test_session_with_gaps_in_consolidation(self):
|
||||
"""Test session with potential gaps in consolidation history."""
|
||||
session = create_session_with_messages("test:gaps", 50)
|
||||
session.last_consolidated = 10
|
||||
|
||||
# Add more messages
|
||||
for i in range(50, 60):
|
||||
session.add_message("user", f"msg{i}")
|
||||
|
||||
old_messages = get_old_messages(session, session.last_consolidated, KEEP_COUNT)
|
||||
|
||||
expected_count = 60 - KEEP_COUNT - 10
|
||||
assert len(old_messages) == expected_count
|
||||
assert_messages_content(old_messages, 10, 34)
|
||||
@ -20,8 +20,8 @@ You have access to:
|
||||
|
||||
## Memory
|
||||
|
||||
- Use `memory/` directory for daily notes
|
||||
- Use `MEMORY.md` for long-term information
|
||||
- `memory/MEMORY.md` — long-term facts (preferences, context, relationships)
|
||||
- `memory/HISTORY.md` — append-only event log, search with grep to recall past events
|
||||
|
||||
## Scheduled Reminders
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user