Merge branch 'main' into pr-151

Commit: 9e5f7348fe
**.gitignore** (vendored, 3 changes)

````diff
@@ -14,8 +14,9 @@ docs/
 *.pywz
 *.pyzz
 .venv/
+venv/
 __pycache__/
 poetry.lock
 .pytest_cache/
-tests/
 botpy.log
+tests/
````
**README.md** (267 changes)

````diff
@@ -12,14 +12,19 @@
 </p>
 </div>
 
-🐈 **nanobot** is an **ultra-lightweight** personal AI assistant inspired by [Clawdbot](https://github.com/openclaw/openclaw)
+🐈 **nanobot** is an **ultra-lightweight** personal AI assistant inspired by [OpenClaw](https://github.com/openclaw/openclaw)
 
 ⚡️ Delivers core agent functionality in just **~4,000** lines of code — **99% smaller** than Clawdbot's 430k+ lines.
 
-📏 Real-time line count: **3,510 lines** (run `bash core_agent_lines.sh` to verify anytime)
+📏 Real-time line count: **3,663 lines** (run `bash core_agent_lines.sh` to verify anytime)
 
 ## 📢 News
 
+- **2026-02-14** 🔌 nanobot now supports MCP! See [MCP section](#mcp-model-context-protocol) for details.
+- **2026-02-13** 🎉 Released v0.1.3.post7 — includes security hardening and multiple improvements. All users are recommended to upgrade to the latest version. See [release notes](https://github.com/HKUDS/nanobot/releases/tag/v0.1.3.post7) for more details.
+- **2026-02-12** 🧠 Redesigned memory system — Less code, more reliable. Join the [discussion](https://github.com/HKUDS/nanobot/discussions/566) about it!
+- **2026-02-11** ✨ Enhanced CLI experience and added MiniMax support!
+- **2026-02-10** 🎉 Released v0.1.3.post6 with improvements! Check the updates [notes](https://github.com/HKUDS/nanobot/releases/tag/v0.1.3.post6) and our [roadmap](https://github.com/HKUDS/nanobot/discussions/431).
 - **2026-02-09** 💬 Added Slack, Email, and QQ support — nanobot now supports multiple chat platforms!
 - **2026-02-08** 🔧 Refactored Providers—adding a new LLM provider now takes just 2 simple steps! Check [here](#providers).
 - **2026-02-07** 🚀 Released v0.1.3.post5 with Qwen support & several key improvements! Check [here](https://github.com/HKUDS/nanobot/releases/tag/v0.1.3.post5) for details.
````
````diff
@@ -94,7 +99,7 @@ pip install nanobot-ai
 
 > [!TIP]
 > Set your API key in `~/.nanobot/config.json`.
-> Get API keys: [OpenRouter](https://openrouter.ai/keys) (Global) · [DashScope](https://dashscope.console.aliyun.com) (Qwen) · [Brave Search](https://brave.com/search/api/) (optional, for web search)
+> Get API keys: [OpenRouter](https://openrouter.ai/keys) (Global) · [Brave Search](https://brave.com/search/api/) (optional, for web search)
 
 **1. Initialize**
 
@@ -104,14 +109,22 @@ nanobot onboard
 
 **2. Configure** (`~/.nanobot/config.json`)
 
-For OpenRouter - recommended for global users:
+Add or merge these **two parts** into your config (other options have defaults).
+
+*Set your API key* (e.g. OpenRouter, recommended for global users):
 ```json
 {
   "providers": {
     "openrouter": {
       "apiKey": "sk-or-v1-xxx"
     }
-  },
+  }
+}
+```
+
+*Set your model*:
+```json
+{
   "agents": {
     "defaults": {
       "model": "anthropic/claude-opus-4-5"
````
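The README change above splits the quick-start config into two fragments that the user must merge into an existing `~/.nanobot/config.json`. A minimal sketch of such a recursive merge — `deep_merge` is a hypothetical helper for illustration, not part of nanobot:

```python
import json

def deep_merge(base: dict, patch: dict) -> dict:
    """Recursively merge `patch` into a copy of `base`; patch values win."""
    merged = dict(base)
    for key, value in patch.items():
        if isinstance(value, dict) and isinstance(merged.get(key), dict):
            merged[key] = deep_merge(merged[key], value)
        else:
            merged[key] = value
    return merged

# merging the two README fragments into an existing config
config = {"agents": {"defaults": {"model": "old-model"}}}
part1 = {"providers": {"openrouter": {"apiKey": "sk-or-v1-xxx"}}}
part2 = {"agents": {"defaults": {"model": "anthropic/claude-opus-4-5"}}}
config = deep_merge(deep_merge(config, part1), part2)
print(json.dumps(config, indent=2))
```

Nested keys that are absent in a fragment keep their existing values, which is why the README can say "other options have defaults".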
````diff
@@ -123,51 +136,14 @@ For OpenRouter - recommended for global users:
 **3. Chat**
 
 ```bash
-nanobot agent -m "What is 2+2?"
+nanobot agent
 ```
 
 That's it! You have a working AI assistant in 2 minutes.
 
-## 🖥️ Local Models (vLLM)
-
-Run nanobot with your own local models using vLLM or any OpenAI-compatible server.
-
-**1. Start your vLLM server**
-
-```bash
-vllm serve meta-llama/Llama-3.1-8B-Instruct --port 8000
-```
-
-**2. Configure** (`~/.nanobot/config.json`)
-
-```json
-{
-  "providers": {
-    "vllm": {
-      "apiKey": "dummy",
-      "apiBase": "http://localhost:8000/v1"
-    }
-  },
-  "agents": {
-    "defaults": {
-      "model": "meta-llama/Llama-3.1-8B-Instruct"
-    }
-  }
-}
-```
-
-**3. Chat**
-
-```bash
-nanobot agent -m "Hello from my local LLM!"
-```
-
-> [!TIP]
-> The `apiKey` can be any non-empty string for local servers that don't require authentication.
-
 ## 💬 Chat Apps
 
-Talk to your nanobot through Telegram, Discord, WhatsApp, Feishu, DingTalk, Slack, Email, or QQ — anytime, anywhere.
+Talk to your nanobot through Telegram, Discord, WhatsApp, Feishu, Mochat, DingTalk, Slack, Email, or QQ — anytime, anywhere.
 
 | Channel | Setup |
 |---------|-------|
````
````diff
@@ -175,6 +151,7 @@ Talk to your nanobot through Telegram, Discord, WhatsApp, Feishu, DingTalk, Slac
 | **Discord** | Easy (bot token + intents) |
 | **WhatsApp** | Medium (scan QR) |
 | **Feishu** | Medium (app credentials) |
+| **Mochat** | Medium (claw token + websocket) |
 | **DingTalk** | Medium (app credentials) |
 | **Slack** | Medium (bot + app tokens) |
 | **Email** | Medium (IMAP/SMTP credentials) |
````
````diff
@@ -214,6 +191,63 @@ nanobot gateway
 
 </details>
 
+<details>
+<summary><b>Mochat (Claw IM)</b></summary>
+
+Uses **Socket.IO WebSocket** by default, with HTTP polling fallback.
+
+**1. Ask nanobot to set up Mochat for you**
+
+Simply send this message to nanobot (replace `xxx@xxx` with your real email):
+
+```
+Read https://raw.githubusercontent.com/HKUDS/MoChat/refs/heads/main/skills/nanobot/skill.md and register on MoChat. My Email account is xxx@xxx Bind me as your owner and DM me on MoChat.
+```
+
+nanobot will automatically register, configure `~/.nanobot/config.json`, and connect to Mochat.
+
+**2. Restart gateway**
+
+```bash
+nanobot gateway
+```
+
+That's it — nanobot handles the rest!
+
+<br>
+
+<details>
+<summary>Manual configuration (advanced)</summary>
+
+If you prefer to configure manually, add the following to `~/.nanobot/config.json`:
+
+> Keep `claw_token` private. It should only be sent in `X-Claw-Token` header to your Mochat API endpoint.
+
+```json
+{
+  "channels": {
+    "mochat": {
+      "enabled": true,
+      "base_url": "https://mochat.io",
+      "socket_url": "https://mochat.io",
+      "socket_path": "/socket.io",
+      "claw_token": "claw_xxx",
+      "agent_user_id": "6982abcdef",
+      "sessions": ["*"],
+      "panels": ["*"],
+      "reply_delay_mode": "non-mention",
+      "reply_delay_ms": 120000
+    }
+  }
+}
+```
+
+</details>
+
+</details>
+
 <details>
 <summary><b>Discord</b></summary>
 
````
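The `reply_delay_mode: "non-mention"` setting in the Mochat config above suggests that replies to messages which do not @mention the bot are debounced by `reply_delay_ms`. A sketch of that decision logic, with hypothetical names (the actual implementation may differ):

```python
def compute_reply_delay_ms(mode: str, mentioned: bool, delay_ms: int) -> int:
    """Return how long to wait before replying, in milliseconds."""
    if mode == "non-mention" and not mentioned:
        return delay_ms  # debounce chatter that doesn't address the bot
    return 0  # mentions (or other modes) get an immediate reply

# with the config values above: 120s delay for un-mentioned group chatter
print(compute_reply_delay_ms("non-mention", mentioned=False, delay_ms=120000))
```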
````diff
@@ -428,13 +462,17 @@ nanobot gateway
 Uses **Socket Mode** — no public URL required.
 
 **1. Create a Slack app**
-- Go to [Slack API](https://api.slack.com/apps) → Create New App
-- **OAuth & Permissions**: Add bot scopes: `chat:write`, `reactions:write`, `app_mentions:read`
-- Install to your workspace and copy the **Bot Token** (`xoxb-...`)
-- **Socket Mode**: Enable it and generate an **App-Level Token** (`xapp-...`) with `connections:write` scope
-- **Event Subscriptions**: Subscribe to `message.im`, `message.channels`, `app_mention`
+- Go to [Slack API](https://api.slack.com/apps) → **Create New App** → "From scratch"
+- Pick a name and select your workspace
 
-**2. Configure**
+**2. Configure the app**
+- **Socket Mode**: Toggle ON → Generate an **App-Level Token** with `connections:write` scope → copy it (`xapp-...`)
+- **OAuth & Permissions**: Add bot scopes: `chat:write`, `reactions:write`, `app_mentions:read`
+- **Event Subscriptions**: Toggle ON → Subscribe to bot events: `message.im`, `message.channels`, `app_mention` → Save Changes
+- **App Home**: Scroll to **Show Tabs** → Enable **Messages Tab** → Check **"Allow users to send Slash commands and messages from the messages tab"**
+- **Install App**: Click **Install to Workspace** → Authorize → copy the **Bot Token** (`xoxb-...`)
+
+**3. Configure nanobot**
 
 ```json
 {
````
````diff
@@ -449,15 +487,18 @@ Uses **Socket Mode** — no public URL required.
 }
 ```
 
-> `groupPolicy`: `"mention"` (respond only when @mentioned), `"open"` (respond to all messages), or `"allowlist"` (restrict to specific channels).
-> DM policy defaults to open. Set `"dm": {"enabled": false}` to disable DMs.
-
-**3. Run**
+**4. Run**
 
 ```bash
 nanobot gateway
 ```
 
+DM the bot directly or @mention it in a channel — it should respond!
+
+> [!TIP]
+> - `groupPolicy`: `"mention"` (default — respond only when @mentioned), `"open"` (respond to all channel messages), or `"allowlist"` (restrict to specific channels).
+> - DM policy defaults to open. Set `"dm": {"enabled": false}` to disable DMs.
+
 </details>
 
 <details>
````
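The `groupPolicy` values described in the Slack tip reduce to a small decision function. A sketch of that logic, with hypothetical names (nanobot's actual routing code is not shown in this diff):

```python
def should_respond(policy: str, mentioned: bool, channel: str, allowlist=()) -> bool:
    """Decide whether the bot replies to a channel message under a group policy."""
    if policy == "open":
        return True                 # respond to every channel message
    if policy == "mention":         # the default
        return mentioned            # respond only when @mentioned
    if policy == "allowlist":
        return channel in allowlist # respond only in whitelisted channels
    return False                    # unknown policy: stay quiet
```

DMs are handled separately (open by default), which is why the policy only gates channel traffic.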
````diff
@@ -507,6 +548,17 @@ nanobot gateway
 
 </details>
 
+## 🌐 Agent Social Network
+
+🐈 nanobot is capable of linking to the agent social network (agent community). **Just send one message and your nanobot joins automatically!**
+
+| Platform | How to Join (send this message to your bot) |
+|----------|-------------|
+| [**Moltbook**](https://www.moltbook.com/) | `Read https://moltbook.com/skill.md and follow the instructions to join Moltbook` |
+| [**ClawdChat**](https://clawdchat.ai/) | `Read https://clawdchat.ai/skill.md and follow the instructions to join ClawdChat` |
+
+Simply send the command above to your nanobot (via CLI or any chat channel), and it will handle the rest.
+
 ## ⚙️ Configuration
 
 Config file: `~/.nanobot/config.json`
````
````diff
@@ -516,21 +568,86 @@ Config file: `~/.nanobot/config.json`
 
 > [!TIP]
 > - **Groq** provides free voice transcription via Whisper. If configured, Telegram voice messages will be automatically transcribed.
 > - **Zhipu Coding Plan**: If you're on Zhipu's coding plan, set `"apiBase": "https://open.bigmodel.cn/api/coding/paas/v4"` in your zhipu provider config.
+> - **MiniMax (Mainland China)**: If your API key is from MiniMax's mainland China platform (minimaxi.com), set `"apiBase": "https://api.minimaxi.com/v1"` in your minimax provider config.
 
 | Provider | Purpose | Get API Key |
 |----------|---------|-------------|
+| `custom` | Any OpenAI-compatible endpoint | — |
 | `openrouter` | LLM (recommended, access to all models) | [openrouter.ai](https://openrouter.ai) |
 | `anthropic` | LLM (Claude direct) | [console.anthropic.com](https://console.anthropic.com) |
 | `openai` | LLM (GPT direct) | [platform.openai.com](https://platform.openai.com) |
 | `deepseek` | LLM (DeepSeek direct) | [platform.deepseek.com](https://platform.deepseek.com) |
 | `groq` | LLM + **Voice transcription** (Whisper) | [console.groq.com](https://console.groq.com) |
 | `gemini` | LLM (Gemini direct) | [aistudio.google.com](https://aistudio.google.com) |
+| `minimax` | LLM (MiniMax direct) | [platform.minimax.io](https://platform.minimax.io) |
 | `aihubmix` | LLM (API gateway, access to all models) | [aihubmix.com](https://aihubmix.com) |
 | `dashscope` | LLM (Qwen) | [dashscope.console.aliyun.com](https://dashscope.console.aliyun.com) |
 | `moonshot` | LLM (Moonshot/Kimi) | [platform.moonshot.cn](https://platform.moonshot.cn) |
 | `zhipu` | LLM (Zhipu GLM) | [open.bigmodel.cn](https://open.bigmodel.cn) |
 | `vllm` | LLM (local, any OpenAI-compatible server) | — |
 
+<details>
+<summary><b>Custom Provider (Any OpenAI-compatible API)</b></summary>
+
+If your provider is not listed above but exposes an **OpenAI-compatible API** (e.g. Together AI, Fireworks, Azure OpenAI, self-hosted endpoints), use the `custom` provider:
+
+```json
+{
+  "providers": {
+    "custom": {
+      "apiKey": "your-api-key",
+      "apiBase": "https://api.your-provider.com/v1"
+    }
+  },
+  "agents": {
+    "defaults": {
+      "model": "your-model-name"
+    }
+  }
+}
+```
+
+> The `custom` provider routes through LiteLLM's OpenAI-compatible path. It works with any endpoint that follows the OpenAI chat completions API format. The model name is passed directly to the endpoint without any prefix.
+
+</details>
+
+<details>
+<summary><b>vLLM (local / OpenAI-compatible)</b></summary>
+
+Run your own model with vLLM or any OpenAI-compatible server, then add to config:
+
+**1. Start the server** (example):
+```bash
+vllm serve meta-llama/Llama-3.1-8B-Instruct --port 8000
+```
+
+**2. Add to config** (partial — merge into `~/.nanobot/config.json`):
+
+*Provider (key can be any non-empty string for local):*
+```json
+{
+  "providers": {
+    "vllm": {
+      "apiKey": "dummy",
+      "apiBase": "http://localhost:8000/v1"
+    }
+  }
+}
+```
+
+*Model:*
+```json
+{
+  "agents": {
+    "defaults": {
+      "model": "meta-llama/Llama-3.1-8B-Instruct"
+    }
+  }
+}
+```
+
+</details>
+
 <details>
 <summary><b>Adding a New Provider (Developer Guide)</b></summary>
````
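The provider table implies that nanobot must pick one configured provider from `config.json` at runtime. A hedged sketch of one plausible resolution strategy — the priority order and helper name here are assumptions for illustration, not nanobot's actual logic:

```python
# hypothetical priority order, roughly mirroring the README table
PROVIDER_ORDER = ["openrouter", "anthropic", "openai", "deepseek", "groq",
                  "gemini", "minimax", "aihubmix", "dashscope", "moonshot",
                  "zhipu", "vllm", "custom"]

def pick_provider(providers: dict):
    """Return the first provider (by priority) that has an apiKey configured."""
    for name in PROVIDER_ORDER:
        entry = providers.get(name) or {}
        if entry.get("apiKey"):
            return name
    return None
```

With only one provider configured, as in the quick-start, the choice is unambiguous regardless of ordering.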
````diff
@@ -576,8 +693,43 @@ That's it! Environment variables, model prefixing, config matching, and `nanobot
 </details>
 
+
+### MCP (Model Context Protocol)
+
+> [!TIP]
+> The config format is compatible with Claude Desktop / Cursor. You can copy MCP server configs directly from any MCP server's README.
+
+nanobot supports [MCP](https://modelcontextprotocol.io/) — connect external tool servers and use them as native agent tools.
+
+Add MCP servers to your `config.json`:
+
+```json
+{
+  "tools": {
+    "mcpServers": {
+      "filesystem": {
+        "command": "npx",
+        "args": ["-y", "@modelcontextprotocol/server-filesystem", "/path/to/dir"]
+      }
+    }
+  }
+}
+```
+
+Two transport modes are supported:
+
+| Mode | Config | Example |
+|------|--------|---------|
+| **Stdio** | `command` + `args` | Local process via `npx` / `uvx` |
+| **HTTP** | `url` | Remote endpoint (`https://mcp.example.com/sse`) |
+
+MCP tools are automatically discovered and registered on startup. The LLM can use them alongside built-in tools — no extra configuration needed.
+
 ### Security
 
+> [!TIP]
 > For production deployments, set `"restrictToWorkspace": true` in your config to sandbox the agent.
 
 | Option | Default | Description |
````
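The two MCP transport modes in the table above are distinguished purely by which keys appear in the server entry. A small sketch of that dispatch (hypothetical helper name; the real loader may do more validation):

```python
def mcp_transport(server: dict) -> str:
    """Classify an MCP server config entry by transport mode."""
    if "url" in server:
        return "http"    # remote endpoint
    if "command" in server:
        return "stdio"   # spawn a local process and talk over stdin/stdout
    raise ValueError("MCP server entry needs either 'url' or 'command'")

print(mcp_transport({"command": "npx",
                     "args": ["-y", "@modelcontextprotocol/server-filesystem", "/path/to/dir"]}))
```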
````diff
@@ -637,7 +789,7 @@ docker run -v ~/.nanobot:/root/.nanobot --rm nanobot onboard
 # Edit config on host to add API keys
 vim ~/.nanobot/config.json
 
-# Run gateway (connects to Telegram/WhatsApp)
+# Run gateway (connects to enabled channels, e.g. Telegram/Discord/Mochat)
 docker run -v ~/.nanobot:/root/.nanobot -p 18790:18790 nanobot gateway
 
 # Or run a single command
````
````diff
@@ -657,7 +809,7 @@ nanobot/
 │   ├── subagent.py    # Background task execution
 │   └── tools/         # Built-in tools (incl. spawn)
 ├── skills/            # 🎯 Bundled skills (github, weather, tmux...)
-├── channels/          # 📱 WhatsApp integration
+├── channels/          # 📱 Chat channel integrations
 ├── bus/               # 🚌 Message routing
 ├── cron/              # ⏰ Scheduled tasks
 ├── heartbeat/         # 💓 Proactive wake-up
````
````diff
@@ -673,7 +825,6 @@ PRs welcome! The codebase is intentionally small and readable. 🤗
 
 **Roadmap** — Pick an item and [open a PR](https://github.com/HKUDS/nanobot/pulls)!
 
-- [x] **Voice Transcription** — Support for Groq Whisper (Issue #13)
 - [ ] **Multi-modal** — See and hear (images, voice, video)
 - [ ] **Long-term memory** — Never forget important context
 - [ ] **Better reasoning** — Multi-step planning and reflection
````
````diff
@@ -683,7 +834,7 @@ PRs welcome! The codebase is intentionally small and readable. 🤗
 ### Contributors
 
 <a href="https://github.com/HKUDS/nanobot/graphs/contributors">
-  <img src="https://contrib.rocks/image?repo=HKUDS/nanobot&max=100&columns=12" />
+  <img src="https://contrib.rocks/image?repo=HKUDS/nanobot&max=100&columns=12&updated=20260210" alt="Contributors" />
 </a>
 
````
**Security documentation:**

````diff
@@ -95,8 +95,8 @@ File operations have path traversal protection, but:
 - Consider using a firewall to restrict outbound connections if needed
 
 **WhatsApp Bridge:**
-- The bridge runs on `localhost:3001` by default
-- If exposing to network, use proper authentication and TLS
+- The bridge binds to `127.0.0.1:3001` (localhost only, not accessible from external network)
+- Set `bridgeToken` in config to enable shared-secret authentication between Python and Node.js
 - Keep authentication data in `~/.nanobot/whatsapp-auth` secure (mode 0700)
 
 ### 6. Dependency Security
````
````diff
@@ -224,7 +224,7 @@ If you suspect a security breach:
 ✅ **Secure Communication**
 - HTTPS for all external API calls
 - TLS for Telegram API
-- WebSocket security for WhatsApp bridge
+- WhatsApp bridge: localhost-only binding + optional token auth
 
 ## Known Limitations
 
````
**WhatsApp bridge entry point (TypeScript):**

````diff
@@ -25,11 +25,12 @@ import { join } from 'path';
 
 const PORT = parseInt(process.env.BRIDGE_PORT || '3001', 10);
 const AUTH_DIR = process.env.AUTH_DIR || join(homedir(), '.nanobot', 'whatsapp-auth');
+const TOKEN = process.env.BRIDGE_TOKEN || undefined;
 
 console.log('🐈 nanobot WhatsApp Bridge');
 console.log('========================\n');
 
-const server = new BridgeServer(PORT, AUTH_DIR);
+const server = new BridgeServer(PORT, AUTH_DIR, TOKEN);
 
 // Handle graceful shutdown
 process.on('SIGINT', async () => {
````
**WhatsApp bridge WebSocket server (TypeScript):**

````diff
@@ -1,5 +1,6 @@
 /**
  * WebSocket server for Python-Node.js bridge communication.
+ * Security: binds to 127.0.0.1 only; optional BRIDGE_TOKEN auth.
  */
 
 import { WebSocketServer, WebSocket } from 'ws';
@@ -21,12 +22,13 @@ export class BridgeServer {
   private wa: WhatsAppClient | null = null;
   private clients: Set<WebSocket> = new Set();
 
-  constructor(private port: number, private authDir: string) {}
+  constructor(private port: number, private authDir: string, private token?: string) {}
 
   async start(): Promise<void> {
-    // Create WebSocket server
-    this.wss = new WebSocketServer({ port: this.port });
-    console.log(`🌉 Bridge server listening on ws://localhost:${this.port}`);
+    // Bind to localhost only — never expose to external network
+    this.wss = new WebSocketServer({ host: '127.0.0.1', port: this.port });
+    console.log(`🌉 Bridge server listening on ws://127.0.0.1:${this.port}`);
+    if (this.token) console.log('🔒 Token authentication enabled');
 
     // Initialize WhatsApp client
     this.wa = new WhatsAppClient({
````
````diff
@@ -38,35 +40,58 @@ export class BridgeServer {
 
     // Handle WebSocket connections
     this.wss.on('connection', (ws) => {
-      console.log('🔗 Python client connected');
-      this.clients.add(ws);
-
-      ws.on('message', async (data) => {
-        try {
-          const cmd = JSON.parse(data.toString()) as SendCommand;
-          await this.handleCommand(cmd);
-          ws.send(JSON.stringify({ type: 'sent', to: cmd.to }));
-        } catch (error) {
-          console.error('Error handling command:', error);
-          ws.send(JSON.stringify({ type: 'error', error: String(error) }));
-        }
-      });
-
-      ws.on('close', () => {
-        console.log('🔌 Python client disconnected');
-        this.clients.delete(ws);
-      });
-
-      ws.on('error', (error) => {
-        console.error('WebSocket error:', error);
-        this.clients.delete(ws);
-      });
+      if (this.token) {
+        // Require auth handshake as first message
+        const timeout = setTimeout(() => ws.close(4001, 'Auth timeout'), 5000);
+        ws.once('message', (data) => {
+          clearTimeout(timeout);
+          try {
+            const msg = JSON.parse(data.toString());
+            if (msg.type === 'auth' && msg.token === this.token) {
+              console.log('🔗 Python client authenticated');
+              this.setupClient(ws);
+            } else {
+              ws.close(4003, 'Invalid token');
+            }
+          } catch {
+            ws.close(4003, 'Invalid auth message');
+          }
+        });
+      } else {
+        console.log('🔗 Python client connected');
+        this.setupClient(ws);
+      }
     });
 
     // Connect to WhatsApp
     await this.wa.connect();
   }
 
+  private setupClient(ws: WebSocket): void {
+    this.clients.add(ws);
+
+    ws.on('message', async (data) => {
+      try {
+        const cmd = JSON.parse(data.toString()) as SendCommand;
+        await this.handleCommand(cmd);
+        ws.send(JSON.stringify({ type: 'sent', to: cmd.to }));
+      } catch (error) {
+        console.error('Error handling command:', error);
+        ws.send(JSON.stringify({ type: 'error', error: String(error) }));
+      }
+    });
+
+    ws.on('close', () => {
+      console.log('🔌 Python client disconnected');
+      this.clients.delete(ws);
+    });
+
+    ws.on('error', (error) => {
+      console.error('WebSocket error:', error);
+      this.clients.delete(ws);
+    });
+  }
+
   private async handleCommand(cmd: SendCommand): Promise<void> {
     if (cmd.type === 'send' && this.wa) {
       await this.wa.sendMessage(cmd.to, cmd.text);
````
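When the token is set, the bridge server expects `{"type": "auth", "token": ...}` as the very first frame from the Python side. A Python sketch of building that handshake and of the matching server-side check (mirroring the TypeScript logic; the WebSocket transport itself is omitted):

```python
import json

def make_auth_message(token: str) -> str:
    """First frame a Python client should send when the bridge requires auth."""
    return json.dumps({"type": "auth", "token": token})

def check_auth(raw, expected: str) -> bool:
    """Server-side check: type must be 'auth' and the token must match exactly."""
    try:
        msg = json.loads(raw)
    except (ValueError, TypeError):
        return False  # unparseable frame → close with 'Invalid auth message'
    return isinstance(msg, dict) and msg.get("type") == "auth" and msg.get("token") == expected
```

A client that never sends the frame is disconnected by the 5-second timeout, so the handshake must happen immediately after connecting.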
**Agent system prompt (Python):**

````diff
@@ -73,7 +73,9 @@ Skills with available="false" need dependencies installed first - you can try in
     def _get_identity(self) -> str:
         """Get the core identity section."""
         from datetime import datetime
+        import time as _time
         now = datetime.now().strftime("%Y-%m-%d %H:%M (%A)")
+        tz = _time.strftime("%Z") or "UTC"
         workspace_path = str(self.workspace.expanduser().resolve())
         system = platform.system()
         runtime = f"{'macOS' if system == 'Darwin' else system} {platform.machine()}, Python {platform.python_version()}"
@@ -88,23 +90,24 @@ You are nanobot, a helpful AI assistant. You have access to tools that allow you
 - Spawn subagents for complex background tasks
 
 ## Current Time
-{now}
+{now} ({tz})
 
 ## Runtime
 {runtime}
 
 ## Workspace
 Your workspace is at: {workspace_path}
-- Memory files: {workspace_path}/memory/MEMORY.md
-- Daily notes: {workspace_path}/memory/YYYY-MM-DD.md
+- Long-term memory: {workspace_path}/memory/MEMORY.md
+- History log: {workspace_path}/memory/HISTORY.md (grep-searchable)
 - Custom skills: {workspace_path}/skills/{{skill-name}}/SKILL.md
 
 IMPORTANT: When responding to direct questions or conversations, reply directly with your text response.
 Only use the 'message' tool when you need to send a message to a specific chat channel (like WhatsApp).
 For normal conversation, just respond with text - do not call the message tool.
 
-Always be helpful, accurate, and concise. When using tools, explain what you're doing.
-When remembering something, write to {workspace_path}/memory/MEMORY.md"""
+Always be helpful, accurate, and concise. When using tools, think step by step: what you know, what you need, and why you chose this tool.
+When remembering something important, write to {workspace_path}/memory/MEMORY.md
+To recall past events, grep {workspace_path}/memory/HISTORY.md"""
 
     def _load_bootstrap_files(self) -> str:
         """Load all bootstrap files from workspace."""
````
@@ -1,7 +1,9 @@
 """Agent loop: the core processing engine."""

 import asyncio
+from contextlib import AsyncExitStack
 import json
+import json_repair
 from pathlib import Path
 from typing import Any

@@ -18,14 +20,15 @@ from nanobot.agent.tools.web import WebSearchTool, WebFetchTool
 from nanobot.agent.tools.message import MessageTool
 from nanobot.agent.tools.spawn import SpawnTool
 from nanobot.agent.tools.cron import CronTool
+from nanobot.agent.memory import MemoryStore
 from nanobot.agent.subagent import SubagentManager
-from nanobot.session.manager import SessionManager
+from nanobot.session.manager import Session, SessionManager


 class AgentLoop:
     """
     The agent loop is the core processing engine.

     It:
     1. Receives messages from the bus
     2. Builds context with history, memory, skills
@@ -33,7 +36,7 @@ class AgentLoop:
     4. Executes tool calls
     5. Sends responses back
     """

     def __init__(
         self,
         bus: MessageBus,
@@ -41,11 +44,15 @@ class AgentLoop:
         workspace: Path,
         model: str | None = None,
         max_iterations: int = 20,
+        temperature: float = 0.7,
+        max_tokens: int = 4096,
+        memory_window: int = 50,
         brave_api_key: str | None = None,
         exec_config: "ExecToolConfig | None" = None,
         cron_service: "CronService | None" = None,
         restrict_to_workspace: bool = False,
         session_manager: SessionManager | None = None,
+        mcp_servers: dict | None = None,
     ):
         from nanobot.config.schema import ExecToolConfig
         from nanobot.cron.service import CronService
@@ -54,11 +61,14 @@ class AgentLoop:
         self.workspace = workspace
         self.model = model or provider.get_default_model()
         self.max_iterations = max_iterations
+        self.temperature = temperature
+        self.max_tokens = max_tokens
+        self.memory_window = memory_window
         self.brave_api_key = brave_api_key
         self.exec_config = exec_config or ExecToolConfig()
         self.cron_service = cron_service
         self.restrict_to_workspace = restrict_to_workspace

         self.context = ContextBuilder(workspace)
         self.sessions = session_manager or SessionManager(workspace)
         self.tools = ToolRegistry()
@@ -67,12 +77,17 @@ class AgentLoop:
             workspace=workspace,
             bus=bus,
             model=self.model,
+            temperature=self.temperature,
+            max_tokens=self.max_tokens,
             brave_api_key=brave_api_key,
             exec_config=self.exec_config,
             restrict_to_workspace=restrict_to_workspace,
         )

         self._running = False
+        self._mcp_servers = mcp_servers or {}
+        self._mcp_stack: AsyncExitStack | None = None
+        self._mcp_connected = False
         self._register_default_tools()

     def _register_default_tools(self) -> None:
@@ -107,107 +122,64 @@ class AgentLoop:
         if self.cron_service:
             self.tools.register(CronTool(self.cron_service))

-    async def run(self) -> None:
-        """Run the agent loop, processing messages from the bus."""
-        self._running = True
-        logger.info("Agent loop started")
-
-        while self._running:
-            try:
-                # Wait for next message
-                msg = await asyncio.wait_for(
-                    self.bus.consume_inbound(),
-                    timeout=1.0
-                )
-
-                # Process it
-                try:
-                    response = await self._process_message(msg)
-                    if response:
-                        await self.bus.publish_outbound(response)
-                except Exception as e:
-                    logger.error(f"Error processing message: {e}")
-                    # Send error response
-                    await self.bus.publish_outbound(OutboundMessage(
-                        channel=msg.channel,
-                        chat_id=msg.chat_id,
-                        content=f"Sorry, I encountered an error: {str(e)}"
-                    ))
-            except asyncio.TimeoutError:
-                continue
-
-    def stop(self) -> None:
-        """Stop the agent loop."""
-        self._running = False
-        logger.info("Agent loop stopping")
-
-    async def _process_message(self, msg: InboundMessage) -> OutboundMessage | None:
+    async def _connect_mcp(self) -> None:
+        """Connect to configured MCP servers (one-time, lazy)."""
+        if self._mcp_connected or not self._mcp_servers:
+            return
+        self._mcp_connected = True
+        from nanobot.agent.tools.mcp import connect_mcp_servers
+        self._mcp_stack = AsyncExitStack()
+        await self._mcp_stack.__aenter__()
+        await connect_mcp_servers(self._mcp_servers, self.tools, self._mcp_stack)
+
+    def _set_tool_context(self, channel: str, chat_id: str) -> None:
+        """Update context for all tools that need routing info."""
+        if message_tool := self.tools.get("message"):
+            if isinstance(message_tool, MessageTool):
+                message_tool.set_context(channel, chat_id)
+
+        if spawn_tool := self.tools.get("spawn"):
+            if isinstance(spawn_tool, SpawnTool):
+                spawn_tool.set_context(channel, chat_id)
+
+        if cron_tool := self.tools.get("cron"):
+            if isinstance(cron_tool, CronTool):
+                cron_tool.set_context(channel, chat_id)
+
+    async def _run_agent_loop(self, initial_messages: list[dict]) -> tuple[str | None, list[str]]:
         """
-        Process a single inbound message.
+        Run the agent iteration loop.

         Args:
-            msg: The inbound message to process.
+            initial_messages: Starting messages for the LLM conversation.

         Returns:
-            The response message, or None if no response needed.
+            Tuple of (final_content, list_of_tools_used).
         """
-        # Handle system messages (subagent announces)
-        # The chat_id contains the original "channel:chat_id" to route back to
-        if msg.channel == "system":
-            return await self._process_system_message(msg)
-
-        preview = msg.content[:80] + "..." if len(msg.content) > 80 else msg.content
-        logger.info(f"Processing message from {msg.channel}:{msg.sender_id}: {preview}")
-
-        # Get or create session
-        session = self.sessions.get_or_create(msg.session_key)
-
-        # Update tool contexts
-        message_tool = self.tools.get("message")
-        if isinstance(message_tool, MessageTool):
-            message_tool.set_context(msg.channel, msg.chat_id)
-
-        spawn_tool = self.tools.get("spawn")
-        if isinstance(spawn_tool, SpawnTool):
-            spawn_tool.set_context(msg.channel, msg.chat_id)
-
-        cron_tool = self.tools.get("cron")
-        if isinstance(cron_tool, CronTool):
-            cron_tool.set_context(msg.channel, msg.chat_id)
-
-        # Build initial messages (use get_history for LLM-formatted messages)
-        messages = self.context.build_messages(
-            history=session.get_history(),
-            current_message=msg.content,
-            media=msg.media if msg.media else None,
-            channel=msg.channel,
-            chat_id=msg.chat_id,
-        )
-
-        # Agent loop
+        messages = initial_messages
         iteration = 0
         final_content = None
+        tools_used: list[str] = []

         while iteration < self.max_iterations:
             iteration += 1

-            # Call LLM
             response = await self.provider.chat(
                 messages=messages,
                 tools=self.tools.get_definitions(),
-                model=self.model
+                model=self.model,
+                temperature=self.temperature,
+                max_tokens=self.max_tokens,
             )

-            # Handle tool calls
             if response.has_tool_calls:
-                # Add assistant message with tool calls
                 tool_call_dicts = [
                     {
                         "id": tc.id,
                         "type": "function",
                         "function": {
                             "name": tc.name,
-                            "arguments": json.dumps(tc.arguments)  # Must be JSON string
+                            "arguments": json.dumps(tc.arguments)
                         }
                     }
                     for tc in response.tool_calls
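The `_connect_mcp` pattern above pairs a one-shot guard with an `AsyncExitStack` that owns every server connection, so a single `aclose()` unwinds them all in reverse order. A minimal sketch of that pattern in isolation; `fake_server` and `Client` are hypothetical stand-ins for the real MCP plumbing, only the `AsyncExitStack` usage mirrors the diff:

```python
import asyncio
from contextlib import AsyncExitStack, asynccontextmanager


@asynccontextmanager
async def fake_server(name: str, log: list[str]):
    # Hypothetical stand-in for an MCP server connection.
    log.append(f"open:{name}")
    try:
        yield name
    finally:
        log.append(f"close:{name}")


class Client:
    def __init__(self, servers: list[str]):
        self._servers = servers
        self._stack: AsyncExitStack | None = None
        self._connected = False
        self.log: list[str] = []

    async def connect(self) -> None:
        # One-time, lazy: safe to call from every entry point.
        if self._connected or not self._servers:
            return
        self._connected = True
        self._stack = AsyncExitStack()
        await self._stack.__aenter__()
        for name in self._servers:
            # Register each connection; the stack unwinds them in reverse order.
            await self._stack.enter_async_context(fake_server(name, self.log))

    async def close(self) -> None:
        if self._stack:
            await self._stack.aclose()
            self._stack = None


async def main() -> list[str]:
    c = Client(["a", "b"])
    await c.connect()
    await c.connect()  # no-op: the guard prevents a double connection
    await c.close()
    return c.log


print(asyncio.run(main()))
```

The guard also explains why both `run()` and `process_direct()` can call the connect method unconditionally.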
@@ -216,30 +188,126 @@ class AgentLoop:
                     messages, response.content, tool_call_dicts,
                     reasoning_content=response.reasoning_content,
                 )

-                # Execute tools
                 for tool_call in response.tool_calls:
+                    tools_used.append(tool_call.name)
                     args_str = json.dumps(tool_call.arguments, ensure_ascii=False)
                     logger.info(f"Tool call: {tool_call.name}({args_str[:200]})")
                     result = await self.tools.execute(tool_call.name, tool_call.arguments)
                     messages = self.context.add_tool_result(
                         messages, tool_call.id, tool_call.name, result
                     )
+                messages.append({"role": "user", "content": "Reflect on the results and decide next steps."})
             else:
-                # No tool calls, we're done
                 final_content = response.content
                 break

+        return final_content, tools_used
+
+    async def run(self) -> None:
+        """Run the agent loop, processing messages from the bus."""
+        self._running = True
+        await self._connect_mcp()
+        logger.info("Agent loop started")
+
+        while self._running:
+            try:
+                msg = await asyncio.wait_for(
+                    self.bus.consume_inbound(),
+                    timeout=1.0
+                )
+                try:
+                    response = await self._process_message(msg)
+                    if response:
+                        await self.bus.publish_outbound(response)
+                except Exception as e:
+                    logger.error(f"Error processing message: {e}")
+                    await self.bus.publish_outbound(OutboundMessage(
+                        channel=msg.channel,
+                        chat_id=msg.chat_id,
+                        content=f"Sorry, I encountered an error: {str(e)}"
+                    ))
+            except asyncio.TimeoutError:
+                continue
+
+    async def close_mcp(self) -> None:
+        """Close MCP connections."""
+        if self._mcp_stack:
+            try:
+                await self._mcp_stack.aclose()
+            except (RuntimeError, BaseExceptionGroup):
+                pass  # MCP SDK cancel scope cleanup is noisy but harmless
+            self._mcp_stack = None
+
+    def stop(self) -> None:
+        """Stop the agent loop."""
+        self._running = False
+        logger.info("Agent loop stopping")
+
+    async def _process_message(self, msg: InboundMessage, session_key: str | None = None) -> OutboundMessage | None:
+        """
+        Process a single inbound message.
+
+        Args:
+            msg: The inbound message to process.
+            session_key: Override session key (used by process_direct).
+
+        Returns:
+            The response message, or None if no response needed.
+        """
+        # System messages route back via chat_id ("channel:chat_id")
+        if msg.channel == "system":
+            return await self._process_system_message(msg)
+
+        preview = msg.content[:80] + "..." if len(msg.content) > 80 else msg.content
+        logger.info(f"Processing message from {msg.channel}:{msg.sender_id}: {preview}")
+
+        key = session_key or msg.session_key
+        session = self.sessions.get_or_create(key)
+
+        # Handle slash commands
+        cmd = msg.content.strip().lower()
+        if cmd == "/new":
+            # Capture messages before clearing (avoid race condition with background task)
+            messages_to_archive = session.messages.copy()
+            session.clear()
+            self.sessions.save(session)
+            self.sessions.invalidate(session.key)
+
+            async def _consolidate_and_cleanup():
+                temp_session = Session(key=session.key)
+                temp_session.messages = messages_to_archive
+                await self._consolidate_memory(temp_session, archive_all=True)
+
+            asyncio.create_task(_consolidate_and_cleanup())
+            return OutboundMessage(channel=msg.channel, chat_id=msg.chat_id,
+                                   content="New session started. Memory consolidation in progress.")
+        if cmd == "/help":
+            return OutboundMessage(channel=msg.channel, chat_id=msg.chat_id,
+                                   content="🐈 nanobot commands:\n/new — Start a new conversation\n/help — Show available commands")
+
+        if len(session.messages) > self.memory_window:
+            asyncio.create_task(self._consolidate_memory(session))
+
+        self._set_tool_context(msg.channel, msg.chat_id)
+        initial_messages = self.context.build_messages(
+            history=session.get_history(max_messages=self.memory_window),
+            current_message=msg.content,
+            media=msg.media if msg.media else None,
+            channel=msg.channel,
+            chat_id=msg.chat_id,
+        )
+        final_content, tools_used = await self._run_agent_loop(initial_messages)

         if final_content is None:
             final_content = "I've completed processing but have no response to give."

-        # Log response preview
         preview = final_content[:120] + "..." if len(final_content) > 120 else final_content
         logger.info(f"Response to {msg.channel}:{msg.sender_id}: {preview}")

-        # Save to session
         session.add_message("user", msg.content)
-        session.add_message("assistant", final_content)
+        session.add_message("assistant", final_content,
+                            tools_used=tools_used if tools_used else None)
         self.sessions.save(session)

         return OutboundMessage(
@@ -268,76 +336,20 @@ class AgentLoop:
             origin_channel = "cli"
             origin_chat_id = msg.chat_id

-        # Use the origin session for context
         session_key = f"{origin_channel}:{origin_chat_id}"
         session = self.sessions.get_or_create(session_key)
-
-        # Update tool contexts
-        message_tool = self.tools.get("message")
-        if isinstance(message_tool, MessageTool):
-            message_tool.set_context(origin_channel, origin_chat_id)
-
-        spawn_tool = self.tools.get("spawn")
-        if isinstance(spawn_tool, SpawnTool):
-            spawn_tool.set_context(origin_channel, origin_chat_id)
-
-        cron_tool = self.tools.get("cron")
-        if isinstance(cron_tool, CronTool):
-            cron_tool.set_context(origin_channel, origin_chat_id)
-
-        # Build messages with the announce content
-        messages = self.context.build_messages(
-            history=session.get_history(),
+        self._set_tool_context(origin_channel, origin_chat_id)
+        initial_messages = self.context.build_messages(
+            history=session.get_history(max_messages=self.memory_window),
             current_message=msg.content,
             channel=origin_channel,
             chat_id=origin_chat_id,
         )
-
-        # Agent loop (limited for announce handling)
-        iteration = 0
-        final_content = None
-
-        while iteration < self.max_iterations:
-            iteration += 1
-
-            response = await self.provider.chat(
-                messages=messages,
-                tools=self.tools.get_definitions(),
-                model=self.model
-            )
-
-            if response.has_tool_calls:
-                tool_call_dicts = [
-                    {
-                        "id": tc.id,
-                        "type": "function",
-                        "function": {
-                            "name": tc.name,
-                            "arguments": json.dumps(tc.arguments)
-                        }
-                    }
-                    for tc in response.tool_calls
-                ]
-                messages = self.context.add_assistant_message(
-                    messages, response.content, tool_call_dicts,
-                    reasoning_content=response.reasoning_content,
-                )
-
-                for tool_call in response.tool_calls:
-                    args_str = json.dumps(tool_call.arguments, ensure_ascii=False)
-                    logger.info(f"Tool call: {tool_call.name}({args_str[:200]})")
-                    result = await self.tools.execute(tool_call.name, tool_call.arguments)
-                    messages = self.context.add_tool_result(
-                        messages, tool_call.id, tool_call.name, result
-                    )
-            else:
-                final_content = response.content
-                break
+        final_content, _ = await self._run_agent_loop(initial_messages)

         if final_content is None:
             final_content = "Background task completed."

-        # Save to session (mark as system message in history)
         session.add_message("user", f"[System: {msg.sender_id}] {msg.content}")
         session.add_message("assistant", final_content)
         self.sessions.save(session)
@@ -348,6 +360,91 @@ class AgentLoop:
             content=final_content
         )

+    async def _consolidate_memory(self, session, archive_all: bool = False) -> None:
+        """Consolidate old messages into MEMORY.md + HISTORY.md.
+
+        Args:
+            archive_all: If True, clear all messages and reset session (for /new command).
+                If False, only write to files without modifying session.
+        """
+        memory = MemoryStore(self.workspace)
+
+        if archive_all:
+            old_messages = session.messages
+            keep_count = 0
+            logger.info(f"Memory consolidation (archive_all): {len(session.messages)} total messages archived")
+        else:
+            keep_count = self.memory_window // 2
+            if len(session.messages) <= keep_count:
+                logger.debug(f"Session {session.key}: No consolidation needed (messages={len(session.messages)}, keep={keep_count})")
+                return
+
+            messages_to_process = len(session.messages) - session.last_consolidated
+            if messages_to_process <= 0:
+                logger.debug(f"Session {session.key}: No new messages to consolidate (last_consolidated={session.last_consolidated}, total={len(session.messages)})")
+                return
+
+            old_messages = session.messages[session.last_consolidated:-keep_count]
+            if not old_messages:
+                return
+            logger.info(f"Memory consolidation started: {len(session.messages)} total, {len(old_messages)} new to consolidate, {keep_count} keep")
+
+        lines = []
+        for m in old_messages:
+            if not m.get("content"):
+                continue
+            tools = f" [tools: {', '.join(m['tools_used'])}]" if m.get("tools_used") else ""
+            lines.append(f"[{m.get('timestamp', '?')[:16]}] {m['role'].upper()}{tools}: {m['content']}")
+        conversation = "\n".join(lines)
+        current_memory = memory.read_long_term()
+
+        prompt = f"""You are a memory consolidation agent. Process this conversation and return a JSON object with exactly two keys:
+
+1. "history_entry": A paragraph (2-5 sentences) summarizing the key events/decisions/topics. Start with a timestamp like [YYYY-MM-DD HH:MM]. Include enough detail to be useful when found by grep search later.
+
+2. "memory_update": The updated long-term memory content. Add any new facts: user location, preferences, personal info, habits, project context, technical decisions, tools/services used. If nothing new, return the existing content unchanged.
+
+## Current Long-term Memory
+{current_memory or "(empty)"}
+
+## Conversation to Process
+{conversation}
+
+Respond with ONLY valid JSON, no markdown fences."""
+
+        try:
+            response = await self.provider.chat(
+                messages=[
+                    {"role": "system", "content": "You are a memory consolidation agent. Respond only with valid JSON."},
+                    {"role": "user", "content": prompt},
+                ],
+                model=self.model,
+            )
+            text = (response.content or "").strip()
+            if not text:
+                logger.warning("Memory consolidation: LLM returned empty response, skipping")
+                return
+            if text.startswith("```"):
+                text = text.split("\n", 1)[-1].rsplit("```", 1)[0].strip()
+            result = json_repair.loads(text)
+            if not isinstance(result, dict):
+                logger.warning(f"Memory consolidation: unexpected response type, skipping. Response: {text[:200]}")
+                return
+
+            if entry := result.get("history_entry"):
+                memory.append_history(entry)
+            if update := result.get("memory_update"):
+                if update != current_memory:
+                    memory.write_long_term(update)
+
+            if archive_all:
+                session.last_consolidated = 0
+            else:
+                session.last_consolidated = len(session.messages) - keep_count
+            logger.info(f"Memory consolidation done: {len(session.messages)} messages, last_consolidated={session.last_consolidated}")
+        except Exception as e:
+            logger.error(f"Memory consolidation failed: {e}")
+
     async def process_direct(
         self,
         content: str,
@@ -360,13 +457,14 @@ class AgentLoop:

         Args:
             content: The message content.
-            session_key: Session identifier.
-            channel: Source channel (for context).
-            chat_id: Source chat ID (for context).
+            session_key: Session identifier (overrides channel:chat_id for session lookup).
+            channel: Source channel (for tool context routing).
+            chat_id: Source chat ID (for tool context routing).

         Returns:
             The agent's response.
         """
+        await self._connect_mcp()
         msg = InboundMessage(
             channel=channel,
             sender_id="user",
@@ -374,5 +472,5 @@ class AgentLoop:
             content=content
         )

-        response = await self._process_message(msg)
+        response = await self._process_message(msg, session_key=session_key)
         return response.content if response else ""
@@ -1,109 +1,30 @@
 """Memory system for persistent agent memory."""

 from pathlib import Path
-from datetime import datetime

-from nanobot.utils.helpers import ensure_dir, today_date
+from nanobot.utils.helpers import ensure_dir


 class MemoryStore:
-    """
-    Memory system for the agent.
-
-    Supports daily notes (memory/YYYY-MM-DD.md) and long-term memory (MEMORY.md).
-    """
+    """Two-layer memory: MEMORY.md (long-term facts) + HISTORY.md (grep-searchable log)."""

     def __init__(self, workspace: Path):
-        self.workspace = workspace
         self.memory_dir = ensure_dir(workspace / "memory")
         self.memory_file = self.memory_dir / "MEMORY.md"
+        self.history_file = self.memory_dir / "HISTORY.md"

-    def get_today_file(self) -> Path:
-        """Get path to today's memory file."""
-        return self.memory_dir / f"{today_date()}.md"
-
-    def read_today(self) -> str:
-        """Read today's memory notes."""
-        today_file = self.get_today_file()
-        if today_file.exists():
-            return today_file.read_text(encoding="utf-8")
-        return ""
-
-    def append_today(self, content: str) -> None:
-        """Append content to today's memory notes."""
-        today_file = self.get_today_file()
-
-        if today_file.exists():
-            existing = today_file.read_text(encoding="utf-8")
-            content = existing + "\n" + content
-        else:
-            # Add header for new day
-            header = f"# {today_date()}\n\n"
-            content = header + content
-
-        today_file.write_text(content, encoding="utf-8")
-
     def read_long_term(self) -> str:
-        """Read long-term memory (MEMORY.md)."""
         if self.memory_file.exists():
             return self.memory_file.read_text(encoding="utf-8")
         return ""

     def write_long_term(self, content: str) -> None:
-        """Write to long-term memory (MEMORY.md)."""
         self.memory_file.write_text(content, encoding="utf-8")

-    def get_recent_memories(self, days: int = 7) -> str:
-        """
-        Get memories from the last N days.
-
-        Args:
-            days: Number of days to look back.
-
-        Returns:
-            Combined memory content.
-        """
-        from datetime import timedelta
-
-        memories = []
-        today = datetime.now().date()
-
-        for i in range(days):
-            date = today - timedelta(days=i)
-            date_str = date.strftime("%Y-%m-%d")
-            file_path = self.memory_dir / f"{date_str}.md"
-
-            if file_path.exists():
-                content = file_path.read_text(encoding="utf-8")
-                memories.append(content)
-
-        return "\n\n---\n\n".join(memories)
-
-    def list_memory_files(self) -> list[Path]:
-        """List all memory files sorted by date (newest first)."""
-        if not self.memory_dir.exists():
-            return []
-
-        files = list(self.memory_dir.glob("????-??-??.md"))
-        return sorted(files, reverse=True)
+    def append_history(self, entry: str) -> None:
+        with open(self.history_file, "a", encoding="utf-8") as f:
+            f.write(entry.rstrip() + "\n\n")

     def get_memory_context(self) -> str:
-        """
-        Get memory context for the agent.
-
-        Returns:
-            Formatted memory context including long-term and recent memories.
-        """
-        parts = []
-
-        # Long-term memory
         long_term = self.read_long_term()
-        if long_term:
-            parts.append("## Long-term Memory\n" + long_term)
-
-        # Today's notes
-        today = self.read_today()
-        if today:
-            parts.append("## Today's Notes\n" + today)
-
-        return "\n\n".join(parts) if parts else ""
+        return f"## Long-term Memory\n{long_term}" if long_term else ""
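The new `append_history` writes one paragraph per entry, separated by a blank line, so `grep` (or a paragraph split) can pull entries back out later. A small self-contained sketch of that append-only log; the `HistoryLog` class and its `grep` method are illustrative, not nanobot's API:

```python
import tempfile
from pathlib import Path


class HistoryLog:
    """Append-only, grep-searchable log: one timestamped paragraph per entry."""

    def __init__(self, memory_dir: Path):
        memory_dir.mkdir(parents=True, exist_ok=True)
        self.history_file = memory_dir / "HISTORY.md"

    def append(self, entry: str) -> None:
        # A blank line between entries keeps each paragraph greppable as a unit
        with open(self.history_file, "a", encoding="utf-8") as f:
            f.write(entry.rstrip() + "\n\n")

    def grep(self, needle: str) -> list[str]:
        if not self.history_file.exists():
            return []
        paragraphs = self.history_file.read_text(encoding="utf-8").split("\n\n")
        return [p for p in paragraphs if needle.lower() in p.lower()]


with tempfile.TemporaryDirectory() as d:
    log = HistoryLog(Path(d) / "memory")
    log.append("[2026-02-14 09:00] Set up MCP servers.")
    log.append("[2026-02-14 10:30] Discussed memory consolidation.")
    hits = log.grep("mcp")
    print(len(hits))
```

Because the file is plain append-only text, the agent can also search it directly with `grep` from its exec tool, which is exactly what the updated system prompt tells it to do.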
@@ -12,7 +12,7 @@ from nanobot.bus.events import InboundMessage
 from nanobot.bus.queue import MessageBus
 from nanobot.providers.base import LLMProvider
 from nanobot.agent.tools.registry import ToolRegistry
-from nanobot.agent.tools.filesystem import ReadFileTool, WriteFileTool, ListDirTool
+from nanobot.agent.tools.filesystem import ReadFileTool, WriteFileTool, EditFileTool, ListDirTool
 from nanobot.agent.tools.shell import ExecTool
 from nanobot.agent.tools.web import WebSearchTool, WebFetchTool

@ -32,6 +32,8 @@ class SubagentManager:
|
|||||||
workspace: Path,
|
workspace: Path,
|
||||||
bus: MessageBus,
|
bus: MessageBus,
|
||||||
model: str | None = None,
|
model: str | None = None,
|
||||||
|
temperature: float = 0.7,
|
||||||
|
max_tokens: int = 4096,
|
||||||
brave_api_key: str | None = None,
|
brave_api_key: str | None = None,
|
||||||
exec_config: "ExecToolConfig | None" = None,
|
exec_config: "ExecToolConfig | None" = None,
|
||||||
restrict_to_workspace: bool = False,
|
restrict_to_workspace: bool = False,
|
||||||
@ -41,6 +43,8 @@ class SubagentManager:
|
|||||||
self.workspace = workspace
|
self.workspace = workspace
|
||||||
self.bus = bus
|
self.bus = bus
|
||||||
self.model = model or provider.get_default_model()
|
self.model = model or provider.get_default_model()
|
||||||
|
self.temperature = temperature
|
||||||
|
self.max_tokens = max_tokens
|
||||||
self.brave_api_key = brave_api_key
|
self.brave_api_key = brave_api_key
|
||||||
self.exec_config = exec_config or ExecToolConfig()
|
self.exec_config = exec_config or ExecToolConfig()
|
||||||
self.restrict_to_workspace = restrict_to_workspace
|
self.restrict_to_workspace = restrict_to_workspace
|
||||||
@ -101,6 +105,7 @@ class SubagentManager:
|
|||||||
allowed_dir = self.workspace if self.restrict_to_workspace else None
|
allowed_dir = self.workspace if self.restrict_to_workspace else None
|
||||||
tools.register(ReadFileTool(allowed_dir=allowed_dir))
|
tools.register(ReadFileTool(allowed_dir=allowed_dir))
|
||||||
tools.register(WriteFileTool(allowed_dir=allowed_dir))
|
tools.register(WriteFileTool(allowed_dir=allowed_dir))
|
||||||
|
tools.register(EditFileTool(allowed_dir=allowed_dir))
|
||||||
tools.register(ListDirTool(allowed_dir=allowed_dir))
|
tools.register(ListDirTool(allowed_dir=allowed_dir))
|
||||||
tools.register(ExecTool(
|
tools.register(ExecTool(
|
||||||
working_dir=str(self.workspace),
|
working_dir=str(self.workspace),
|
||||||
@ -129,6 +134,8 @@ class SubagentManager:
|
|||||||
messages=messages,
|
messages=messages,
|
||||||
tools=tools.get_definitions(),
|
tools=tools.get_definitions(),
|
||||||
model=self.model,
|
model=self.model,
|
||||||
|
temperature=self.temperature,
|
||||||
|
max_tokens=self.max_tokens,
|
||||||
)
|
)
|
||||||
|
|
||||||
if response.has_tool_calls:
|
if response.has_tool_calls:
|
||||||
@ -210,12 +217,17 @@ Summarize this naturally for the user. Keep it brief (1-2 sentences). Do not men
|
|||||||
|
|
||||||
def _build_subagent_prompt(self, task: str) -> str:
|
def _build_subagent_prompt(self, task: str) -> str:
|
||||||
"""Build a focused system prompt for the subagent."""
|
"""Build a focused system prompt for the subagent."""
|
||||||
|
from datetime import datetime
|
||||||
|
import time as _time
|
||||||
|
now = datetime.now().strftime("%Y-%m-%d %H:%M (%A)")
|
||||||
|
tz = _time.strftime("%Z") or "UTC"
|
||||||
|
|
||||||
return f"""# Subagent
|
return f"""# Subagent
|
||||||
|
|
||||||
You are a subagent spawned by the main agent to complete a specific task.
|
## Current Time
|
||||||
|
{now} ({tz})
|
||||||
|
|
||||||
## Your Task
|
You are a subagent spawned by the main agent to complete a specific task.
|
||||||
{task}
|
|
||||||
|
|
||||||
## Rules
|
## Rules
|
||||||
1. Stay focused - complete only the assigned task, nothing else
|
1. Stay focused - complete only the assigned task, nothing else
|
||||||
@ -236,6 +248,7 @@ You are a subagent spawned by the main agent to complete a specific task.
|
|||||||
|
|
||||||
## Workspace
|
## Workspace
|
||||||
Your workspace is at: {self.workspace}
|
Your workspace is at: {self.workspace}
|
||||||
|
Skills are available at: {self.workspace}/skills/ (read SKILL.md files as needed)
|
||||||
|
|
||||||
When you have completed the task, provide a clear summary of your findings or actions."""
|
When you have completed the task, provide a clear summary of your findings or actions."""
|
||||||
|
|
||||||
|
|||||||
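The prompt change above stamps each subagent with the current local time via a `strftime` format string. A minimal sketch of what that format renders, using a fixed, hypothetical datetime instead of `datetime.now()`:

```python
from datetime import datetime

# Hypothetical fixed timestamp, just to illustrate the format string
# used in _build_subagent_prompt; the real code calls datetime.now().
stamp = datetime(2026, 2, 14, 9, 30).strftime("%Y-%m-%d %H:%M (%A)")
print(stamp)  # 2026-02-14 09:30 (Saturday)
```

Note that `%A` is locale-dependent; the English weekday name assumes the default C locale.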
@@ -50,6 +50,10 @@ class CronTool(Tool):
                 "type": "string",
                 "description": "Cron expression like '0 9 * * *' (for scheduled tasks)"
             },
+            "at": {
+                "type": "string",
+                "description": "ISO datetime for one-time execution (e.g. '2026-02-12T10:30:00')"
+            },
             "job_id": {
                 "type": "string",
                 "description": "Job ID (for remove)"
@@ -64,30 +68,38 @@ class CronTool(Tool):
         message: str = "",
         every_seconds: int | None = None,
         cron_expr: str | None = None,
+        at: str | None = None,
         job_id: str | None = None,
         **kwargs: Any
     ) -> str:
         if action == "add":
-            return self._add_job(message, every_seconds, cron_expr)
+            return self._add_job(message, every_seconds, cron_expr, at)
         elif action == "list":
             return self._list_jobs()
         elif action == "remove":
             return self._remove_job(job_id)
         return f"Unknown action: {action}"
 
-    def _add_job(self, message: str, every_seconds: int | None, cron_expr: str | None) -> str:
+    def _add_job(self, message: str, every_seconds: int | None, cron_expr: str | None, at: str | None) -> str:
         if not message:
             return "Error: message is required for add"
         if not self._channel or not self._chat_id:
             return "Error: no session context (channel/chat_id)"
 
         # Build schedule
+        delete_after = False
         if every_seconds:
             schedule = CronSchedule(kind="every", every_ms=every_seconds * 1000)
         elif cron_expr:
             schedule = CronSchedule(kind="cron", expr=cron_expr)
+        elif at:
+            from datetime import datetime
+            dt = datetime.fromisoformat(at)
+            at_ms = int(dt.timestamp() * 1000)
+            schedule = CronSchedule(kind="at", at_ms=at_ms)
+            delete_after = True
         else:
-            return "Error: either every_seconds or cron_expr is required"
+            return "Error: either every_seconds, cron_expr, or at is required"
 
         job = self._cron.add_job(
             name=message[:30],
@@ -96,6 +108,7 @@ class CronTool(Tool):
             deliver=True,
             channel=self._channel,
             to=self._chat_id,
+            delete_after_run=delete_after,
         )
         return f"Created job '{job.name}' (id: {job.id})"
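The new one-shot `at` branch above boils down to converting an ISO-8601 datetime into epoch milliseconds before handing it to the scheduler. A standalone sketch of that conversion (the `CronSchedule` wiring is omitted; note that a naive string such as `'2026-02-12T10:30:00'`, as in the tool description, is interpreted in the host's local timezone):

```python
from datetime import datetime, timezone

def at_to_epoch_ms(at: str) -> int:
    # Same conversion as the "at" branch: parse ISO-8601, then scale
    # the POSIX timestamp to milliseconds.
    return int(datetime.fromisoformat(at).timestamp() * 1000)

# With an explicit UTC offset the result is unambiguous:
print(at_to_epoch_ms("2026-02-12T10:30:00+00:00"))
```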
80 nanobot/agent/tools/mcp.py Normal file
@@ -0,0 +1,80 @@
+"""MCP client: connects to MCP servers and wraps their tools as native nanobot tools."""
+
+from contextlib import AsyncExitStack
+from typing import Any
+
+from loguru import logger
+
+from nanobot.agent.tools.base import Tool
+from nanobot.agent.tools.registry import ToolRegistry
+
+
+class MCPToolWrapper(Tool):
+    """Wraps a single MCP server tool as a nanobot Tool."""
+
+    def __init__(self, session, server_name: str, tool_def):
+        self._session = session
+        self._original_name = tool_def.name
+        self._name = f"mcp_{server_name}_{tool_def.name}"
+        self._description = tool_def.description or tool_def.name
+        self._parameters = tool_def.inputSchema or {"type": "object", "properties": {}}
+
+    @property
+    def name(self) -> str:
+        return self._name
+
+    @property
+    def description(self) -> str:
+        return self._description
+
+    @property
+    def parameters(self) -> dict[str, Any]:
+        return self._parameters
+
+    async def execute(self, **kwargs: Any) -> str:
+        from mcp import types
+        result = await self._session.call_tool(self._original_name, arguments=kwargs)
+        parts = []
+        for block in result.content:
+            if isinstance(block, types.TextContent):
+                parts.append(block.text)
+            else:
+                parts.append(str(block))
+        return "\n".join(parts) or "(no output)"
+
+
+async def connect_mcp_servers(
+    mcp_servers: dict, registry: ToolRegistry, stack: AsyncExitStack
+) -> None:
+    """Connect to configured MCP servers and register their tools."""
+    from mcp import ClientSession, StdioServerParameters
+    from mcp.client.stdio import stdio_client
+
+    for name, cfg in mcp_servers.items():
+        try:
+            if cfg.command:
+                params = StdioServerParameters(
+                    command=cfg.command, args=cfg.args, env=cfg.env or None
+                )
+                read, write = await stack.enter_async_context(stdio_client(params))
+            elif cfg.url:
+                from mcp.client.streamable_http import streamable_http_client
+                read, write, _ = await stack.enter_async_context(
+                    streamable_http_client(cfg.url)
+                )
+            else:
+                logger.warning(f"MCP server '{name}': no command or url configured, skipping")
+                continue
+
+            session = await stack.enter_async_context(ClientSession(read, write))
+            await session.initialize()
+
+            tools = await session.list_tools()
+            for tool_def in tools.tools:
+                wrapper = MCPToolWrapper(session, name, tool_def)
+                registry.register(wrapper)
+                logger.debug(f"MCP: registered tool '{wrapper.name}' from server '{name}'")
+
+            logger.info(f"MCP server '{name}': connected, {len(tools.tools)} tools registered")
+        except Exception as e:
+            logger.error(f"MCP server '{name}': failed to connect: {e}")
@@ -128,14 +128,17 @@ class ExecTool(Tool):
         cwd_path = Path(cwd).resolve()
 
         win_paths = re.findall(r"[A-Za-z]:\\[^\\\"']+", cmd)
-        posix_paths = re.findall(r"/[^\s\"']+", cmd)
+        # Only match absolute paths — avoid false positives on relative
+        # paths like ".venv/bin/python" where "/bin/python" would be
+        # incorrectly extracted by the old pattern.
+        posix_paths = re.findall(r"(?:^|[\s|>])(/[^\s\"'>]+)", cmd)
 
         for raw in win_paths + posix_paths:
             try:
-                p = Path(raw).resolve()
+                p = Path(raw.strip()).resolve()
             except Exception:
                 continue
-            if cwd_path not in p.parents and p != cwd_path:
+            if p.is_absolute() and cwd_path not in p.parents and p != cwd_path:
                 return "Error: Command blocked by safety guard (path outside working dir)"
 
         return None
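The tightened guard above only treats a `/...` token as a path when it starts the command or follows whitespace, a pipe, or a redirect. A quick comparison of the two patterns copied from the hunk:

```python
import re

OLD = r"/[^\s\"']+"                    # old pattern from the hunk
NEW = r"(?:^|[\s|>])(/[^\s\"'>]+)"     # new pattern from the hunk

cmd = ".venv/bin/python script.py"
print(re.findall(OLD, cmd))               # ['/bin/python']  <- false positive
print(re.findall(NEW, cmd))               # []
print(re.findall(NEW, "cat /etc/passwd"))  # ['/etc/passwd']
```

The old pattern matched any slash-prefixed run of characters, so the tail of a relative path looked like an absolute one; the new anchor prevents that.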
@@ -137,8 +137,15 @@ class DingTalkChannel(BaseChannel):
 
             logger.info("DingTalk bot started with Stream Mode")
 
-            # client.start() is an async infinite loop handling the websocket connection
-            await self._client.start()
+            # Reconnect loop: restart stream if SDK exits or crashes
+            while self._running:
+                try:
+                    await self._client.start()
+                except Exception as e:
+                    logger.warning(f"DingTalk stream error: {e}")
+                if self._running:
+                    logger.info("Reconnecting DingTalk stream in 5 seconds...")
+                    await asyncio.sleep(5)
 
         except Exception as e:
             logger.exception(f"Failed to start DingTalk channel: {e}")
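The DingTalk change above, and the Feishu change below, apply the same pattern: wrap the SDK's blocking `start()` in a supervising loop that sleeps and retries while the channel is still running. A self-contained sketch of that pattern with a hypothetical stub client (simplified: this version stops supervising after the first clean return):

```python
import asyncio

class FlakyClient:
    """Hypothetical stub standing in for an SDK stream client."""
    def __init__(self, failures: int):
        self.failures = failures
        self.attempts = 0

    async def start(self) -> None:
        self.attempts += 1
        if self.failures > 0:
            self.failures -= 1
            raise ConnectionError("stream dropped")

async def run_with_reconnect(client, *, retry_delay: float = 0.01) -> int:
    running = True
    while running:
        try:
            await client.start()
            running = False                       # clean exit: stop supervising
        except ConnectionError:
            await asyncio.sleep(retry_delay)      # back off, then retry
    return client.attempts

print(asyncio.run(run_with_reconnect(FlakyClient(failures=2))))  # 3
```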
@@ -39,6 +39,53 @@ MSG_TYPE_MAP = {
 }
 
 
+def _extract_post_text(content_json: dict) -> str:
+    """Extract plain text from Feishu post (rich text) message content.
+
+    Supports two formats:
+    1. Direct format: {"title": "...", "content": [...]}
+    2. Localized format: {"zh_cn": {"title": "...", "content": [...]}}
+    """
+    def extract_from_lang(lang_content: dict) -> str | None:
+        if not isinstance(lang_content, dict):
+            return None
+        title = lang_content.get("title", "")
+        content_blocks = lang_content.get("content", [])
+        if not isinstance(content_blocks, list):
+            return None
+        text_parts = []
+        if title:
+            text_parts.append(title)
+        for block in content_blocks:
+            if not isinstance(block, list):
+                continue
+            for element in block:
+                if isinstance(element, dict):
+                    tag = element.get("tag")
+                    if tag == "text":
+                        text_parts.append(element.get("text", ""))
+                    elif tag == "a":
+                        text_parts.append(element.get("text", ""))
+                    elif tag == "at":
+                        text_parts.append(f"@{element.get('user_name', 'user')}")
+        return " ".join(text_parts).strip() if text_parts else None
+
+    # Try direct format first
+    if "content" in content_json:
+        result = extract_from_lang(content_json)
+        if result:
+            return result
+
+    # Try localized format
+    for lang_key in ("zh_cn", "en_us", "ja_jp"):
+        lang_content = content_json.get(lang_key)
+        result = extract_from_lang(lang_content)
+        if result:
+            return result
+
+    return ""
+
+
 class FeishuChannel(BaseChannel):
     """
     Feishu/Lark channel using WebSocket long connection.
@@ -98,12 +145,15 @@ class FeishuChannel(BaseChannel):
             log_level=lark.LogLevel.INFO
         )
 
-        # Start WebSocket client in a separate thread
+        # Start WebSocket client in a separate thread with reconnect loop
         def run_ws():
-            try:
-                self._ws_client.start()
-            except Exception as e:
-                logger.error(f"Feishu WebSocket error: {e}")
+            while self._running:
+                try:
+                    self._ws_client.start()
+                except Exception as e:
+                    logger.warning(f"Feishu WebSocket error: {e}")
+                    if self._running:
+                        import time; time.sleep(5)
 
         self._ws_thread = threading.Thread(target=run_ws, daemon=True)
         self._ws_thread.start()
@@ -163,6 +213,10 @@ class FeishuChannel(BaseChannel):
         re.MULTILINE,
     )
 
+    _HEADING_RE = re.compile(r"^(#{1,6})\s+(.+)$", re.MULTILINE)
+
+    _CODE_BLOCK_RE = re.compile(r"(```[\s\S]*?```)", re.MULTILINE)
+
     @staticmethod
     def _parse_md_table(table_text: str) -> dict | None:
         """Parse a markdown table into a Feishu table element."""
@@ -182,17 +236,52 @@ class FeishuChannel(BaseChannel):
         }
 
     def _build_card_elements(self, content: str) -> list[dict]:
-        """Split content into markdown + table elements for Feishu card."""
+        """Split content into div/markdown + table elements for Feishu card."""
         elements, last_end = [], 0
         for m in self._TABLE_RE.finditer(content):
-            before = content[last_end:m.start()].strip()
-            if before:
-                elements.append({"tag": "markdown", "content": before})
+            before = content[last_end:m.start()]
+            if before.strip():
+                elements.extend(self._split_headings(before))
             elements.append(self._parse_md_table(m.group(1)) or {"tag": "markdown", "content": m.group(1)})
             last_end = m.end()
-        remaining = content[last_end:].strip()
+        remaining = content[last_end:]
+        if remaining.strip():
+            elements.extend(self._split_headings(remaining))
+        return elements or [{"tag": "markdown", "content": content}]
+
+    def _split_headings(self, content: str) -> list[dict]:
+        """Split content by headings, converting headings to div elements."""
+        protected = content
+        code_blocks = []
+        for m in self._CODE_BLOCK_RE.finditer(content):
+            code_blocks.append(m.group(1))
+            protected = protected.replace(m.group(1), f"\x00CODE{len(code_blocks)-1}\x00", 1)
+
+        elements = []
+        last_end = 0
+        for m in self._HEADING_RE.finditer(protected):
+            before = protected[last_end:m.start()].strip()
+            if before:
+                elements.append({"tag": "markdown", "content": before})
+            level = len(m.group(1))
+            text = m.group(2).strip()
+            elements.append({
+                "tag": "div",
+                "text": {
+                    "tag": "lark_md",
+                    "content": f"**{text}**",
+                },
+            })
+            last_end = m.end()
+        remaining = protected[last_end:].strip()
         if remaining:
             elements.append({"tag": "markdown", "content": remaining})
+
+        for i, cb in enumerate(code_blocks):
+            for el in elements:
+                if el.get("tag") == "markdown":
+                    el["content"] = el["content"].replace(f"\x00CODE{i}\x00", cb)
+
         return elements or [{"tag": "markdown", "content": content}]
 
     async def send(self, msg: OutboundMessage) -> None:
@@ -284,6 +373,12 @@ class FeishuChannel(BaseChannel):
                 content = json.loads(message.content).get("text", "")
             except json.JSONDecodeError:
                 content = message.content or ""
+        elif msg_type == "post":
+            try:
+                content_json = json.loads(message.content)
+                content = _extract_post_text(content_json)
+            except (json.JSONDecodeError, TypeError):
+                content = message.content or ""
         else:
             content = MSG_TYPE_MAP.get(msg_type, f"[{msg_type}]")
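The `post` handling added above walks Feishu's nested rich-text structure: a dict keyed by locale, each locale holding a `title` plus a list of lines, each line a list of tagged elements. A trimmed re-implementation sketch covering only the localized format (the direct format and defensive type checks of the full `_extract_post_text` are omitted):

```python
import json

def extract_post_text(content_json: dict) -> str:
    # Trimmed sketch: localized format only. "text"/"a" elements
    # contribute their text; "at" mentions become @name.
    for lang_key in ("zh_cn", "en_us", "ja_jp"):
        lang = content_json.get(lang_key)
        if not isinstance(lang, dict):
            continue
        parts = [lang["title"]] if lang.get("title") else []
        for line in lang.get("content", []):
            for el in line:
                if not isinstance(el, dict):
                    continue
                tag = el.get("tag")
                if tag in ("text", "a"):
                    parts.append(el.get("text", ""))
                elif tag == "at":
                    parts.append(f"@{el.get('user_name', 'user')}")
        if parts:
            return " ".join(parts).strip()
    return ""

raw = '{"zh_cn": {"title": "Plan", "content": [[{"tag": "text", "text": "ship it"}]]}}'
print(extract_post_text(json.loads(raw)))  # Plan ship it
```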
@@ -3,7 +3,7 @@
 from __future__ import annotations
 
 import asyncio
-from typing import Any, TYPE_CHECKING
+from typing import Any
 
 from loguru import logger
 
@@ -12,9 +12,6 @@ from nanobot.bus.queue import MessageBus
 from nanobot.channels.base import BaseChannel
 from nanobot.config.schema import Config
 
-if TYPE_CHECKING:
-    from nanobot.session.manager import SessionManager
-
 
 class ChannelManager:
     """
@@ -26,10 +23,9 @@ class ChannelManager:
     - Route outbound messages
     """
 
-    def __init__(self, config: Config, bus: MessageBus, session_manager: "SessionManager | None" = None):
+    def __init__(self, config: Config, bus: MessageBus):
         self.config = config
         self.bus = bus
-        self.session_manager = session_manager
         self.channels: dict[str, BaseChannel] = {}
         self._dispatch_task: asyncio.Task | None = None
 
@@ -46,7 +42,6 @@ class ChannelManager:
                 self.config.channels.telegram,
                 self.bus,
                 groq_api_key=self.config.providers.groq.api_key,
-                session_manager=self.session_manager,
             )
             logger.info("Telegram channel enabled")
         except ImportError as e:
@@ -85,6 +80,18 @@ class ChannelManager:
             except ImportError as e:
                 logger.warning(f"Feishu channel not available: {e}")
 
+        # Mochat channel
+        if self.config.channels.mochat.enabled:
+            try:
+                from nanobot.channels.mochat import MochatChannel
+
+                self.channels["mochat"] = MochatChannel(
+                    self.config.channels.mochat, self.bus
+                )
+                logger.info("Mochat channel enabled")
+            except ImportError as e:
+                logger.warning(f"Mochat channel not available: {e}")
+
         # DingTalk channel
         if self.config.channels.dingtalk.enabled:
            try:
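The Mochat registration above follows the manager's existing convention: gate on config, import the channel lazily, and degrade with a warning if an optional dependency is missing instead of crashing startup. The import-guard half of that pattern in isolation:

```python
import importlib

def load_optional_channel(module_name: str) -> bool:
    # Lazy import with graceful degradation, as in ChannelManager:
    # a missing optional dependency disables the channel rather than
    # raising at startup.
    try:
        importlib.import_module(module_name)
        return True
    except ImportError:
        return False

print(load_optional_channel("json"))                 # True
print(load_optional_channel("nonexistent_sdk_xyz"))  # False
```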
895
nanobot/channels/mochat.py
Normal file
895
nanobot/channels/mochat.py
Normal file
@ -0,0 +1,895 @@
|
|||||||
|
"""Mochat channel implementation using Socket.IO with HTTP polling fallback."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import json
|
||||||
|
from collections import deque
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from datetime import datetime
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
from loguru import logger
|
||||||
|
|
||||||
|
from nanobot.bus.events import OutboundMessage
|
||||||
|
from nanobot.bus.queue import MessageBus
|
||||||
|
from nanobot.channels.base import BaseChannel
|
||||||
|
from nanobot.config.schema import MochatConfig
|
||||||
|
from nanobot.utils.helpers import get_data_path
|
||||||
|
|
||||||
|
try:
|
||||||
|
import socketio
|
||||||
|
SOCKETIO_AVAILABLE = True
|
||||||
|
except ImportError:
|
||||||
|
socketio = None
|
||||||
|
SOCKETIO_AVAILABLE = False
|
||||||
|
|
||||||
|
try:
|
||||||
|
import msgpack # noqa: F401
|
||||||
|
MSGPACK_AVAILABLE = True
|
||||||
|
except ImportError:
|
||||||
|
MSGPACK_AVAILABLE = False
|
||||||
|
|
||||||
|
MAX_SEEN_MESSAGE_IDS = 2000
|
||||||
|
CURSOR_SAVE_DEBOUNCE_S = 0.5
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Data classes
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class MochatBufferedEntry:
|
||||||
|
"""Buffered inbound entry for delayed dispatch."""
|
||||||
|
raw_body: str
|
||||||
|
author: str
|
||||||
|
sender_name: str = ""
|
||||||
|
sender_username: str = ""
|
||||||
|
timestamp: int | None = None
|
||||||
|
message_id: str = ""
|
||||||
|
group_id: str = ""
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class DelayState:
|
||||||
|
"""Per-target delayed message state."""
|
||||||
|
entries: list[MochatBufferedEntry] = field(default_factory=list)
|
||||||
|
lock: asyncio.Lock = field(default_factory=asyncio.Lock)
|
||||||
|
timer: asyncio.Task | None = None
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class MochatTarget:
|
||||||
|
"""Outbound target resolution result."""
|
||||||
|
id: str
|
||||||
|
is_panel: bool
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Pure helpers
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
def _safe_dict(value: Any) -> dict:
|
||||||
|
"""Return *value* if it's a dict, else empty dict."""
|
||||||
|
return value if isinstance(value, dict) else {}
|
||||||
|
|
||||||
|
|
||||||
|
def _str_field(src: dict, *keys: str) -> str:
|
||||||
|
"""Return the first non-empty str value found for *keys*, stripped."""
|
||||||
|
for k in keys:
|
||||||
|
v = src.get(k)
|
||||||
|
if isinstance(v, str) and v.strip():
|
||||||
|
return v.strip()
|
||||||
|
return ""
|
||||||
|
|
||||||
|
|
||||||
|
def _make_synthetic_event(
|
||||||
|
message_id: str, author: str, content: Any,
|
||||||
|
meta: Any, group_id: str, converse_id: str,
|
||||||
|
timestamp: Any = None, *, author_info: Any = None,
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
"""Build a synthetic ``message.add`` event dict."""
|
||||||
|
payload: dict[str, Any] = {
|
||||||
|
"messageId": message_id, "author": author,
|
||||||
|
"content": content, "meta": _safe_dict(meta),
|
||||||
|
"groupId": group_id, "converseId": converse_id,
|
||||||
|
}
|
||||||
|
if author_info is not None:
|
||||||
|
payload["authorInfo"] = _safe_dict(author_info)
|
||||||
|
return {
|
||||||
|
"type": "message.add",
|
||||||
|
"timestamp": timestamp or datetime.utcnow().isoformat(),
|
||||||
|
"payload": payload,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def normalize_mochat_content(content: Any) -> str:
|
||||||
|
"""Normalize content payload to text."""
|
||||||
|
if isinstance(content, str):
|
||||||
|
return content.strip()
|
||||||
|
if content is None:
|
||||||
|
return ""
|
||||||
|
try:
|
||||||
|
return json.dumps(content, ensure_ascii=False)
|
||||||
|
except TypeError:
|
||||||
|
return str(content)
|
||||||
|
|
||||||
|
|
||||||
|
def resolve_mochat_target(raw: str) -> MochatTarget:
|
||||||
|
"""Resolve id and target kind from user-provided target string."""
|
||||||
|
trimmed = (raw or "").strip()
|
||||||
|
if not trimmed:
|
||||||
|
return MochatTarget(id="", is_panel=False)
|
||||||
|
|
||||||
|
lowered = trimmed.lower()
|
||||||
|
cleaned, forced_panel = trimmed, False
|
||||||
|
for prefix in ("mochat:", "group:", "channel:", "panel:"):
|
||||||
|
if lowered.startswith(prefix):
|
||||||
|
cleaned = trimmed[len(prefix):].strip()
|
||||||
|
forced_panel = prefix in {"group:", "channel:", "panel:"}
|
||||||
|
break
|
||||||
|
|
||||||
|
if not cleaned:
|
||||||
|
return MochatTarget(id="", is_panel=False)
|
||||||
|
return MochatTarget(id=cleaned, is_panel=forced_panel or not cleaned.startswith("session_"))
|
||||||
|
|
||||||
|
|
||||||
|
def extract_mention_ids(value: Any) -> list[str]:
|
||||||
|
"""Extract mention ids from heterogeneous mention payload."""
|
||||||
|
if not isinstance(value, list):
|
||||||
|
return []
|
||||||
|
ids: list[str] = []
|
||||||
|
for item in value:
|
||||||
|
if isinstance(item, str):
|
||||||
|
if item.strip():
|
||||||
|
ids.append(item.strip())
|
||||||
|
elif isinstance(item, dict):
|
||||||
|
for key in ("id", "userId", "_id"):
|
||||||
|
candidate = item.get(key)
|
||||||
|
if isinstance(candidate, str) and candidate.strip():
|
||||||
|
ids.append(candidate.strip())
|
||||||
|
break
|
||||||
|
return ids
|
||||||
|
|
||||||
|
|
||||||
|
def resolve_was_mentioned(payload: dict[str, Any], agent_user_id: str) -> bool:
|
||||||
|
"""Resolve mention state from payload metadata and text fallback."""
|
||||||
|
meta = payload.get("meta")
|
||||||
|
if isinstance(meta, dict):
|
||||||
|
if meta.get("mentioned") is True or meta.get("wasMentioned") is True:
|
||||||
|
return True
|
||||||
|
for f in ("mentions", "mentionIds", "mentionedUserIds", "mentionedUsers"):
|
||||||
|
if agent_user_id and agent_user_id in extract_mention_ids(meta.get(f)):
|
||||||
|
return True
|
||||||
|
if not agent_user_id:
|
||||||
|
return False
|
||||||
|
content = payload.get("content")
|
||||||
|
if not isinstance(content, str) or not content:
|
||||||
|
return False
|
||||||
|
return f"<@{agent_user_id}>" in content or f"@{agent_user_id}" in content
|
||||||
|
|
||||||
|
|
||||||
|
def resolve_require_mention(config: MochatConfig, session_id: str, group_id: str) -> bool:
|
||||||
|
"""Resolve mention requirement for group/panel conversations."""
|
||||||
|
groups = config.groups or {}
|
||||||
|
for key in (group_id, session_id, "*"):
|
||||||
|
if key and key in groups:
|
||||||
|
return bool(groups[key].require_mention)
|
||||||
|
return bool(config.mention.require_in_groups)
|
||||||
|
|
||||||
|
|
||||||
|
def build_buffered_body(entries: list[MochatBufferedEntry], is_group: bool) -> str:
|
||||||
|
"""Build text body from one or more buffered entries."""
|
||||||
|
if not entries:
|
||||||
|
return ""
|
||||||
|
if len(entries) == 1:
|
||||||
|
return entries[0].raw_body
|
||||||
|
lines: list[str] = []
|
||||||
|
for entry in entries:
|
||||||
|
if not entry.raw_body:
|
||||||
|
continue
|
||||||
|
if is_group:
|
||||||
|
label = entry.sender_name.strip() or entry.sender_username.strip() or entry.author
|
||||||
|
if label:
|
||||||
|
lines.append(f"{label}: {entry.raw_body}")
|
||||||
|
continue
|
||||||
|
lines.append(entry.raw_body)
|
||||||
|
return "\n".join(lines).strip()
|
||||||
|
|
||||||
|
|
||||||
|
def parse_timestamp(value: Any) -> int | None:
|
||||||
|
"""Parse event timestamp to epoch milliseconds."""
|
||||||
|
if not isinstance(value, str) or not value.strip():
|
||||||
|
return None
|
||||||
|
try:
|
||||||
|
return int(datetime.fromisoformat(value.replace("Z", "+00:00")).timestamp() * 1000)
|
||||||
|
except ValueError:
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Channel
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class MochatChannel(BaseChannel):
    """Mochat channel using socket.io with fallback polling workers."""

    name = "mochat"

    def __init__(self, config: MochatConfig, bus: MessageBus):
        super().__init__(config, bus)
        self.config: MochatConfig = config
        self._http: httpx.AsyncClient | None = None
        self._socket: Any = None
        self._ws_connected = self._ws_ready = False

        self._state_dir = get_data_path() / "mochat"
        self._cursor_path = self._state_dir / "session_cursors.json"
        self._session_cursor: dict[str, int] = {}
        self._cursor_save_task: asyncio.Task | None = None

        self._session_set: set[str] = set()
        self._panel_set: set[str] = set()
        self._auto_discover_sessions = self._auto_discover_panels = False

        self._cold_sessions: set[str] = set()
        self._session_by_converse: dict[str, str] = {}

        self._seen_set: dict[str, set[str]] = {}
        self._seen_queue: dict[str, deque[str]] = {}
        self._delay_states: dict[str, DelayState] = {}

        self._fallback_mode = False
        self._session_fallback_tasks: dict[str, asyncio.Task] = {}
        self._panel_fallback_tasks: dict[str, asyncio.Task] = {}
        self._refresh_task: asyncio.Task | None = None
        self._target_locks: dict[str, asyncio.Lock] = {}

    # ---- lifecycle ---------------------------------------------------------

    async def start(self) -> None:
        """Start Mochat channel workers and websocket connection."""
        if not self.config.claw_token:
            logger.error("Mochat claw_token not configured")
            return

        self._running = True
        self._http = httpx.AsyncClient(timeout=30.0)
        self._state_dir.mkdir(parents=True, exist_ok=True)
        await self._load_session_cursors()
        self._seed_targets_from_config()
        await self._refresh_targets(subscribe_new=False)

        if not await self._start_socket_client():
            await self._ensure_fallback_workers()

        self._refresh_task = asyncio.create_task(self._refresh_loop())
        while self._running:
            await asyncio.sleep(1)

    async def stop(self) -> None:
        """Stop all workers and clean up resources."""
        self._running = False
        if self._refresh_task:
            self._refresh_task.cancel()
            self._refresh_task = None

        await self._stop_fallback_workers()
        await self._cancel_delay_timers()

        if self._socket:
            try:
                await self._socket.disconnect()
            except Exception:
                pass
            self._socket = None

        if self._cursor_save_task:
            self._cursor_save_task.cancel()
            self._cursor_save_task = None
        await self._save_session_cursors()

        if self._http:
            await self._http.aclose()
            self._http = None
        self._ws_connected = self._ws_ready = False

    async def send(self, msg: OutboundMessage) -> None:
        """Send outbound message to session or panel."""
        if not self.config.claw_token:
            logger.warning("Mochat claw_token missing, skip send")
            return

        parts = ([msg.content.strip()] if msg.content and msg.content.strip() else [])
        if msg.media:
            parts.extend(m for m in msg.media if isinstance(m, str) and m.strip())
        content = "\n".join(parts).strip()
        if not content:
            return

        target = resolve_mochat_target(msg.chat_id)
        if not target.id:
            logger.warning("Mochat outbound target is empty")
            return

        is_panel = (target.is_panel or target.id in self._panel_set) and not target.id.startswith("session_")
        try:
            if is_panel:
                await self._api_send("/api/claw/groups/panels/send", "panelId", target.id,
                                     content, msg.reply_to, self._read_group_id(msg.metadata))
            else:
                await self._api_send("/api/claw/sessions/send", "sessionId", target.id,
                                     content, msg.reply_to)
        except Exception as e:
            logger.error(f"Failed to send Mochat message: {e}")

    # ---- config / init helpers ---------------------------------------------

    def _seed_targets_from_config(self) -> None:
        sessions, self._auto_discover_sessions = self._normalize_id_list(self.config.sessions)
        panels, self._auto_discover_panels = self._normalize_id_list(self.config.panels)
        self._session_set.update(sessions)
        self._panel_set.update(panels)
        for sid in sessions:
            if sid not in self._session_cursor:
                self._cold_sessions.add(sid)

    @staticmethod
    def _normalize_id_list(values: list[str]) -> tuple[list[str], bool]:
        cleaned = [str(v).strip() for v in values if str(v).strip()]
        return sorted({v for v in cleaned if v != "*"}), "*" in cleaned

    # ---- websocket ---------------------------------------------------------

    async def _start_socket_client(self) -> bool:
        if not SOCKETIO_AVAILABLE:
            logger.warning("python-socketio not installed, Mochat using polling fallback")
            return False

        serializer = "default"
        if not self.config.socket_disable_msgpack:
            if MSGPACK_AVAILABLE:
                serializer = "msgpack"
            else:
                logger.warning("msgpack not installed but socket_disable_msgpack=false; using JSON")

        client = socketio.AsyncClient(
            reconnection=True,
            reconnection_attempts=self.config.max_retry_attempts or None,
            reconnection_delay=max(0.1, self.config.socket_reconnect_delay_ms / 1000.0),
            reconnection_delay_max=max(0.1, self.config.socket_max_reconnect_delay_ms / 1000.0),
            logger=False, engineio_logger=False, serializer=serializer,
        )

        @client.event
        async def connect() -> None:
            self._ws_connected, self._ws_ready = True, False
            logger.info("Mochat websocket connected")
            subscribed = await self._subscribe_all()
            self._ws_ready = subscribed
            await (self._stop_fallback_workers() if subscribed else self._ensure_fallback_workers())

        @client.event
        async def disconnect() -> None:
            if not self._running:
                return
            self._ws_connected = self._ws_ready = False
            logger.warning("Mochat websocket disconnected")
            await self._ensure_fallback_workers()

        @client.event
        async def connect_error(data: Any) -> None:
            logger.error(f"Mochat websocket connect error: {data}")

        @client.on("claw.session.events")
        async def on_session_events(payload: dict[str, Any]) -> None:
            await self._handle_watch_payload(payload, "session")

        @client.on("claw.panel.events")
        async def on_panel_events(payload: dict[str, Any]) -> None:
            await self._handle_watch_payload(payload, "panel")

        for ev in ("notify:chat.inbox.append", "notify:chat.message.add",
                   "notify:chat.message.update", "notify:chat.message.recall",
                   "notify:chat.message.delete"):
            client.on(ev, self._build_notify_handler(ev))

        socket_url = (self.config.socket_url or self.config.base_url).strip().rstrip("/")
        socket_path = (self.config.socket_path or "/socket.io").strip().lstrip("/")

        try:
            self._socket = client
            await client.connect(
                socket_url, transports=["websocket"], socketio_path=socket_path,
                auth={"token": self.config.claw_token},
                wait_timeout=max(1.0, self.config.socket_connect_timeout_ms / 1000.0),
            )
            return True
        except Exception as e:
            logger.error(f"Failed to connect Mochat websocket: {e}")
            try:
                await client.disconnect()
            except Exception:
                pass
            self._socket = None
            return False

    def _build_notify_handler(self, event_name: str):
        async def handler(payload: Any) -> None:
            if event_name == "notify:chat.inbox.append":
                await self._handle_notify_inbox_append(payload)
            elif event_name.startswith("notify:chat.message."):
                await self._handle_notify_chat_message(payload)
        return handler

    # ---- subscribe ---------------------------------------------------------

    async def _subscribe_all(self) -> bool:
        ok = await self._subscribe_sessions(sorted(self._session_set))
        ok = await self._subscribe_panels(sorted(self._panel_set)) and ok
        if self._auto_discover_sessions or self._auto_discover_panels:
            await self._refresh_targets(subscribe_new=True)
        return ok

    async def _subscribe_sessions(self, session_ids: list[str]) -> bool:
        if not session_ids:
            return True
        for sid in session_ids:
            if sid not in self._session_cursor:
                self._cold_sessions.add(sid)

        ack = await self._socket_call("com.claw.im.subscribeSessions", {
            "sessionIds": session_ids, "cursors": self._session_cursor,
            "limit": self.config.watch_limit,
        })
        if not ack.get("result"):
            logger.error(f"Mochat subscribeSessions failed: {ack.get('message', 'unknown error')}")
            return False

        data = ack.get("data")
        items: list[dict[str, Any]] = []
        if isinstance(data, list):
            items = [i for i in data if isinstance(i, dict)]
        elif isinstance(data, dict):
            sessions = data.get("sessions")
            if isinstance(sessions, list):
                items = [i for i in sessions if isinstance(i, dict)]
            elif "sessionId" in data:
                items = [data]
        for p in items:
            await self._handle_watch_payload(p, "session")
        return True

    async def _subscribe_panels(self, panel_ids: list[str]) -> bool:
        if not self._auto_discover_panels and not panel_ids:
            return True
        ack = await self._socket_call("com.claw.im.subscribePanels", {"panelIds": panel_ids})
        if not ack.get("result"):
            logger.error(f"Mochat subscribePanels failed: {ack.get('message', 'unknown error')}")
            return False
        return True

    async def _socket_call(self, event_name: str, payload: dict[str, Any]) -> dict[str, Any]:
        if not self._socket:
            return {"result": False, "message": "socket not connected"}
        try:
            raw = await self._socket.call(event_name, payload, timeout=10)
        except Exception as e:
            return {"result": False, "message": str(e)}
        return raw if isinstance(raw, dict) else {"result": True, "data": raw}

    # ---- refresh / discovery -----------------------------------------------

    async def _refresh_loop(self) -> None:
        interval_s = max(1.0, self.config.refresh_interval_ms / 1000.0)
        while self._running:
            await asyncio.sleep(interval_s)
            try:
                await self._refresh_targets(subscribe_new=self._ws_ready)
            except Exception as e:
                logger.warning(f"Mochat refresh failed: {e}")
            if self._fallback_mode:
                await self._ensure_fallback_workers()

    async def _refresh_targets(self, subscribe_new: bool) -> None:
        if self._auto_discover_sessions:
            await self._refresh_sessions_directory(subscribe_new)
        if self._auto_discover_panels:
            await self._refresh_panels(subscribe_new)

    async def _refresh_sessions_directory(self, subscribe_new: bool) -> None:
        try:
            response = await self._post_json("/api/claw/sessions/list", {})
        except Exception as e:
            logger.warning(f"Mochat listSessions failed: {e}")
            return

        sessions = response.get("sessions")
        if not isinstance(sessions, list):
            return

        new_ids: list[str] = []
        for s in sessions:
            if not isinstance(s, dict):
                continue
            sid = _str_field(s, "sessionId")
            if not sid:
                continue
            if sid not in self._session_set:
                self._session_set.add(sid)
                new_ids.append(sid)
                if sid not in self._session_cursor:
                    self._cold_sessions.add(sid)
            cid = _str_field(s, "converseId")
            if cid:
                self._session_by_converse[cid] = sid

        if not new_ids:
            return
        if self._ws_ready and subscribe_new:
            await self._subscribe_sessions(new_ids)
        if self._fallback_mode:
            await self._ensure_fallback_workers()

    async def _refresh_panels(self, subscribe_new: bool) -> None:
        try:
            response = await self._post_json("/api/claw/groups/get", {})
        except Exception as e:
            logger.warning(f"Mochat getWorkspaceGroup failed: {e}")
            return

        raw_panels = response.get("panels")
        if not isinstance(raw_panels, list):
            return

        new_ids: list[str] = []
        for p in raw_panels:
            if not isinstance(p, dict):
                continue
            pt = p.get("type")
            if isinstance(pt, int) and pt != 0:
                continue
            pid = _str_field(p, "id", "_id")
            if pid and pid not in self._panel_set:
                self._panel_set.add(pid)
                new_ids.append(pid)

        if not new_ids:
            return
        if self._ws_ready and subscribe_new:
            await self._subscribe_panels(new_ids)
        if self._fallback_mode:
            await self._ensure_fallback_workers()

    # ---- fallback workers --------------------------------------------------

    async def _ensure_fallback_workers(self) -> None:
        if not self._running:
            return
        self._fallback_mode = True
        for sid in sorted(self._session_set):
            t = self._session_fallback_tasks.get(sid)
            if not t or t.done():
                self._session_fallback_tasks[sid] = asyncio.create_task(self._session_watch_worker(sid))
        for pid in sorted(self._panel_set):
            t = self._panel_fallback_tasks.get(pid)
            if not t or t.done():
                self._panel_fallback_tasks[pid] = asyncio.create_task(self._panel_poll_worker(pid))

    async def _stop_fallback_workers(self) -> None:
        self._fallback_mode = False
        tasks = [*self._session_fallback_tasks.values(), *self._panel_fallback_tasks.values()]
        for t in tasks:
            t.cancel()
        if tasks:
            await asyncio.gather(*tasks, return_exceptions=True)
        self._session_fallback_tasks.clear()
        self._panel_fallback_tasks.clear()

    async def _session_watch_worker(self, session_id: str) -> None:
        while self._running and self._fallback_mode:
            try:
                payload = await self._post_json("/api/claw/sessions/watch", {
                    "sessionId": session_id, "cursor": self._session_cursor.get(session_id, 0),
                    "timeoutMs": self.config.watch_timeout_ms, "limit": self.config.watch_limit,
                })
                await self._handle_watch_payload(payload, "session")
            except asyncio.CancelledError:
                break
            except Exception as e:
                logger.warning(f"Mochat watch fallback error ({session_id}): {e}")
                await asyncio.sleep(max(0.1, self.config.retry_delay_ms / 1000.0))

    async def _panel_poll_worker(self, panel_id: str) -> None:
        sleep_s = max(1.0, self.config.refresh_interval_ms / 1000.0)
        while self._running and self._fallback_mode:
            try:
                resp = await self._post_json("/api/claw/groups/panels/messages", {
                    "panelId": panel_id, "limit": min(100, max(1, self.config.watch_limit)),
                })
                msgs = resp.get("messages")
                if isinstance(msgs, list):
                    for m in reversed(msgs):
                        if not isinstance(m, dict):
                            continue
                        evt = _make_synthetic_event(
                            message_id=str(m.get("messageId") or ""),
                            author=str(m.get("author") or ""),
                            content=m.get("content"),
                            meta=m.get("meta"), group_id=str(resp.get("groupId") or ""),
                            converse_id=panel_id, timestamp=m.get("createdAt"),
                            author_info=m.get("authorInfo"),
                        )
                        await self._process_inbound_event(panel_id, evt, "panel")
            except asyncio.CancelledError:
                break
            except Exception as e:
                logger.warning(f"Mochat panel polling error ({panel_id}): {e}")
            await asyncio.sleep(sleep_s)

    # ---- inbound event processing ------------------------------------------

    async def _handle_watch_payload(self, payload: dict[str, Any], target_kind: str) -> None:
        if not isinstance(payload, dict):
            return
        target_id = _str_field(payload, "sessionId")
        if not target_id:
            return

        lock = self._target_locks.setdefault(f"{target_kind}:{target_id}", asyncio.Lock())
        async with lock:
            prev = self._session_cursor.get(target_id, 0) if target_kind == "session" else 0
            pc = payload.get("cursor")
            if target_kind == "session" and isinstance(pc, int) and pc >= 0:
                self._mark_session_cursor(target_id, pc)

            raw_events = payload.get("events")
            if not isinstance(raw_events, list):
                return
            if target_kind == "session" and target_id in self._cold_sessions:
                self._cold_sessions.discard(target_id)
                return

            for event in raw_events:
                if not isinstance(event, dict):
                    continue
                seq = event.get("seq")
                if target_kind == "session" and isinstance(seq, int) and seq > self._session_cursor.get(target_id, prev):
                    self._mark_session_cursor(target_id, seq)
                if event.get("type") == "message.add":
                    await self._process_inbound_event(target_id, event, target_kind)

    async def _process_inbound_event(self, target_id: str, event: dict[str, Any], target_kind: str) -> None:
        payload = event.get("payload")
        if not isinstance(payload, dict):
            return

        author = _str_field(payload, "author")
        if not author or (self.config.agent_user_id and author == self.config.agent_user_id):
            return
        if not self.is_allowed(author):
            return

        message_id = _str_field(payload, "messageId")
        seen_key = f"{target_kind}:{target_id}"
        if message_id and self._remember_message_id(seen_key, message_id):
            return

        raw_body = normalize_mochat_content(payload.get("content")) or "[empty message]"
        ai = _safe_dict(payload.get("authorInfo"))
        sender_name = _str_field(ai, "nickname", "email")
        sender_username = _str_field(ai, "agentId")

        group_id = _str_field(payload, "groupId")
        is_group = bool(group_id)
        was_mentioned = resolve_was_mentioned(payload, self.config.agent_user_id)
        require_mention = target_kind == "panel" and is_group and resolve_require_mention(self.config, target_id, group_id)
        use_delay = target_kind == "panel" and self.config.reply_delay_mode == "non-mention"

        if require_mention and not was_mentioned and not use_delay:
            return

        entry = MochatBufferedEntry(
            raw_body=raw_body, author=author, sender_name=sender_name,
            sender_username=sender_username, timestamp=parse_timestamp(event.get("timestamp")),
            message_id=message_id, group_id=group_id,
        )

        if use_delay:
            delay_key = seen_key
            if was_mentioned:
                await self._flush_delayed_entries(delay_key, target_id, target_kind, "mention", entry)
            else:
                await self._enqueue_delayed_entry(delay_key, target_id, target_kind, entry)
            return

        await self._dispatch_entries(target_id, target_kind, [entry], was_mentioned)

    # ---- dedup / buffering -------------------------------------------------

    def _remember_message_id(self, key: str, message_id: str) -> bool:
        seen_set = self._seen_set.setdefault(key, set())
        seen_queue = self._seen_queue.setdefault(key, deque())
        if message_id in seen_set:
            return True
        seen_set.add(message_id)
        seen_queue.append(message_id)
        while len(seen_queue) > MAX_SEEN_MESSAGE_IDS:
            seen_set.discard(seen_queue.popleft())
        return False

    async def _enqueue_delayed_entry(self, key: str, target_id: str, target_kind: str, entry: MochatBufferedEntry) -> None:
        state = self._delay_states.setdefault(key, DelayState())
        async with state.lock:
            state.entries.append(entry)
            if state.timer:
                state.timer.cancel()
            state.timer = asyncio.create_task(self._delay_flush_after(key, target_id, target_kind))

    async def _delay_flush_after(self, key: str, target_id: str, target_kind: str) -> None:
        await asyncio.sleep(max(0, self.config.reply_delay_ms) / 1000.0)
        await self._flush_delayed_entries(key, target_id, target_kind, "timer", None)

    async def _flush_delayed_entries(self, key: str, target_id: str, target_kind: str, reason: str, entry: MochatBufferedEntry | None) -> None:
        state = self._delay_states.setdefault(key, DelayState())
        async with state.lock:
            if entry:
                state.entries.append(entry)
            current = asyncio.current_task()
            if state.timer and state.timer is not current:
                state.timer.cancel()
            state.timer = None
            entries = state.entries[:]
            state.entries.clear()
        if entries:
            await self._dispatch_entries(target_id, target_kind, entries, reason == "mention")

    async def _dispatch_entries(self, target_id: str, target_kind: str, entries: list[MochatBufferedEntry], was_mentioned: bool) -> None:
        if not entries:
            return
        last = entries[-1]
        is_group = bool(last.group_id)
        body = build_buffered_body(entries, is_group) or "[empty message]"
        await self._handle_message(
            sender_id=last.author, chat_id=target_id, content=body,
            metadata={
                "message_id": last.message_id, "timestamp": last.timestamp,
                "is_group": is_group, "group_id": last.group_id,
                "sender_name": last.sender_name, "sender_username": last.sender_username,
                "target_kind": target_kind, "was_mentioned": was_mentioned,
                "buffered_count": len(entries),
            },
        )

    async def _cancel_delay_timers(self) -> None:
        for state in self._delay_states.values():
            if state.timer:
                state.timer.cancel()
        self._delay_states.clear()

    # ---- notify handlers ---------------------------------------------------

    async def _handle_notify_chat_message(self, payload: Any) -> None:
        if not isinstance(payload, dict):
            return
        group_id = _str_field(payload, "groupId")
        panel_id = _str_field(payload, "converseId", "panelId")
        if not group_id or not panel_id:
            return
        if self._panel_set and panel_id not in self._panel_set:
            return

        evt = _make_synthetic_event(
            message_id=str(payload.get("_id") or payload.get("messageId") or ""),
            author=str(payload.get("author") or ""),
            content=payload.get("content"), meta=payload.get("meta"),
            group_id=group_id, converse_id=panel_id,
            timestamp=payload.get("createdAt"), author_info=payload.get("authorInfo"),
        )
        await self._process_inbound_event(panel_id, evt, "panel")

    async def _handle_notify_inbox_append(self, payload: Any) -> None:
        if not isinstance(payload, dict) or payload.get("type") != "message":
            return
        detail = payload.get("payload")
        if not isinstance(detail, dict):
            return
        if _str_field(detail, "groupId"):
            return
        converse_id = _str_field(detail, "converseId")
        if not converse_id:
            return

        session_id = self._session_by_converse.get(converse_id)
        if not session_id:
            await self._refresh_sessions_directory(self._ws_ready)
            session_id = self._session_by_converse.get(converse_id)
        if not session_id:
            return

        evt = _make_synthetic_event(
            message_id=str(detail.get("messageId") or payload.get("_id") or ""),
            author=str(detail.get("messageAuthor") or ""),
            content=str(detail.get("messagePlainContent") or detail.get("messageSnippet") or ""),
            meta={"source": "notify:chat.inbox.append", "converseId": converse_id},
            group_id="", converse_id=converse_id, timestamp=payload.get("createdAt"),
        )
        await self._process_inbound_event(session_id, evt, "session")

    # ---- cursor persistence ------------------------------------------------

    def _mark_session_cursor(self, session_id: str, cursor: int) -> None:
        if cursor < 0 or cursor < self._session_cursor.get(session_id, 0):
            return
        self._session_cursor[session_id] = cursor
        if not self._cursor_save_task or self._cursor_save_task.done():
            self._cursor_save_task = asyncio.create_task(self._save_cursor_debounced())

    async def _save_cursor_debounced(self) -> None:
        await asyncio.sleep(CURSOR_SAVE_DEBOUNCE_S)
        await self._save_session_cursors()

    async def _load_session_cursors(self) -> None:
        if not self._cursor_path.exists():
            return
        try:
            data = json.loads(self._cursor_path.read_text("utf-8"))
        except Exception as e:
            logger.warning(f"Failed to read Mochat cursor file: {e}")
            return
        cursors = data.get("cursors") if isinstance(data, dict) else None
        if isinstance(cursors, dict):
            for sid, cur in cursors.items():
                if isinstance(sid, str) and isinstance(cur, int) and cur >= 0:
                    self._session_cursor[sid] = cur

    async def _save_session_cursors(self) -> None:
        try:
            self._state_dir.mkdir(parents=True, exist_ok=True)
            self._cursor_path.write_text(json.dumps({
                "schemaVersion": 1, "updatedAt": datetime.utcnow().isoformat(),
                "cursors": self._session_cursor,
            }, ensure_ascii=False, indent=2) + "\n", "utf-8")
        except Exception as e:
            logger.warning(f"Failed to save Mochat cursor file: {e}")

    # ---- HTTP helpers ------------------------------------------------------

    async def _post_json(self, path: str, payload: dict[str, Any]) -> dict[str, Any]:
        if not self._http:
            raise RuntimeError("Mochat HTTP client not initialized")
        url = f"{self.config.base_url.strip().rstrip('/')}{path}"
        response = await self._http.post(url, headers={
            "Content-Type": "application/json", "X-Claw-Token": self.config.claw_token,
        }, json=payload)
        if not response.is_success:
            raise RuntimeError(f"Mochat HTTP {response.status_code}: {response.text[:200]}")
        try:
            parsed = response.json()
        except Exception:
            parsed = response.text
        if isinstance(parsed, dict) and isinstance(parsed.get("code"), int):
            if parsed["code"] != 200:
                msg = str(parsed.get("message") or parsed.get("name") or "request failed")
                raise RuntimeError(f"Mochat API error: {msg} (code={parsed['code']})")
            data = parsed.get("data")
            return data if isinstance(data, dict) else {}
        return parsed if isinstance(parsed, dict) else {}

    async def _api_send(self, path: str, id_key: str, id_val: str,
                        content: str, reply_to: str | None, group_id: str | None = None) -> dict[str, Any]:
        """Unified send helper for session and panel messages."""
        body: dict[str, Any] = {id_key: id_val, "content": content}
        if reply_to:
            body["replyTo"] = reply_to
        if group_id:
            body["groupId"] = group_id
        return await self._post_json(path, body)

    @staticmethod
    def _read_group_id(metadata: dict[str, Any]) -> str | None:
        if not isinstance(metadata, dict):
            return None
        value = metadata.get("group_id") or metadata.get("groupId")
        return value.strip() if isinstance(value, str) and value.strip() else None
@@ -75,12 +75,15 @@ class QQChannel(BaseChannel):
         logger.info("QQ bot started (C2C private message)")
 
     async def _run_bot(self) -> None:
-        """Run the bot connection."""
-        try:
-            await self._client.start(appid=self.config.app_id, secret=self.config.secret)
-        except Exception as e:
-            logger.error(f"QQ auth failed, check AppID/Secret at q.qq.com: {e}")
-            self._running = False
+        """Run the bot connection with auto-reconnect."""
+        while self._running:
+            try:
+                await self._client.start(appid=self.config.app_id, secret=self.config.secret)
+            except Exception as e:
+                logger.warning(f"QQ bot error: {e}")
+                if self._running:
+                    logger.info("Reconnecting QQ bot in 5 seconds...")
+                    await asyncio.sleep(5)
 
     async def stop(self) -> None:
         """Stop the QQ bot."""
@@ -4,20 +4,16 @@ from __future__ import annotations
 
 import asyncio
 import re
-from typing import TYPE_CHECKING
 
 from loguru import logger
 from telegram import BotCommand, Update
 from telegram.ext import Application, CommandHandler, MessageHandler, filters, ContextTypes
+from telegram.request import HTTPXRequest
 
 from nanobot.bus.events import OutboundMessage
 from nanobot.bus.queue import MessageBus
 from nanobot.channels.base import BaseChannel
 from nanobot.config.schema import TelegramConfig
 
-if TYPE_CHECKING:
-    from nanobot.session.manager import SessionManager
-
-
 def _markdown_to_telegram_html(text: str) -> str:
     """
@ -94,7 +90,7 @@ class TelegramChannel(BaseChannel):
|
|||||||
# Commands registered with Telegram's command menu
|
# Commands registered with Telegram's command menu
|
||||||
BOT_COMMANDS = [
|
BOT_COMMANDS = [
|
||||||
BotCommand("start", "Start the bot"),
|
BotCommand("start", "Start the bot"),
|
||||||
BotCommand("reset", "Reset conversation history"),
|
BotCommand("new", "Start a new conversation"),
|
||||||
BotCommand("help", "Show available commands"),
|
BotCommand("help", "Show available commands"),
|
||||||
]
|
]
|
||||||
|
|
||||||
@ -103,12 +99,10 @@ class TelegramChannel(BaseChannel):
|
|||||||
config: TelegramConfig,
|
config: TelegramConfig,
|
||||||
bus: MessageBus,
|
bus: MessageBus,
|
||||||
groq_api_key: str = "",
|
groq_api_key: str = "",
|
||||||
session_manager: SessionManager | None = None,
|
|
||||||
):
|
):
|
||||||
super().__init__(config, bus)
|
super().__init__(config, bus)
|
||||||
self.config: TelegramConfig = config
|
self.config: TelegramConfig = config
|
||||||
self.groq_api_key = groq_api_key
|
self.groq_api_key = groq_api_key
|
||||||
self.session_manager = session_manager
|
|
||||||
self._app: Application | None = None
|
self._app: Application | None = None
|
||||||
self._chat_ids: dict[str, int] = {} # Map sender_id to chat_id for replies
|
self._chat_ids: dict[str, int] = {} # Map sender_id to chat_id for replies
|
||||||
self._typing_tasks: dict[str, asyncio.Task] = {} # chat_id -> typing loop task
|
self._typing_tasks: dict[str, asyncio.Task] = {} # chat_id -> typing loop task
|
||||||
@ -121,16 +115,18 @@ class TelegramChannel(BaseChannel):
|
|||||||
|
|
||||||
self._running = True
|
self._running = True
|
||||||
|
|
||||||
# Build the application
|
# Build the application with larger connection pool to avoid pool-timeout on long runs
|
||||||
builder = Application.builder().token(self.config.token)
|
req = HTTPXRequest(connection_pool_size=16, pool_timeout=5.0, connect_timeout=30.0, read_timeout=30.0)
|
||||||
|
builder = Application.builder().token(self.config.token).request(req).get_updates_request(req)
|
||||||
if self.config.proxy:
|
if self.config.proxy:
|
||||||
builder = builder.proxy(self.config.proxy).get_updates_proxy(self.config.proxy)
|
builder = builder.proxy(self.config.proxy).get_updates_proxy(self.config.proxy)
|
||||||
self._app = builder.build()
|
self._app = builder.build()
|
||||||
|
self._app.add_error_handler(self._on_error)
|
||||||
|
|
||||||
# Add command handlers
|
# Add command handlers
|
||||||
self._app.add_handler(CommandHandler("start", self._on_start))
|
self._app.add_handler(CommandHandler("start", self._on_start))
|
||||||
self._app.add_handler(CommandHandler("reset", self._on_reset))
|
self._app.add_handler(CommandHandler("new", self._forward_command))
|
||||||
self._app.add_handler(CommandHandler("help", self._on_help))
|
self._app.add_handler(CommandHandler("help", self._forward_command))
|
||||||
|
|
||||||
# Add message handler for text, photos, voice, documents
|
# Add message handler for text, photos, voice, documents
|
||||||
self._app.add_handler(
|
self._app.add_handler(
|
||||||
@ -226,40 +222,15 @@ class TelegramChannel(BaseChannel):
|
|||||||
"Type /help to see available commands."
|
"Type /help to see available commands."
|
||||||
)
|
)
|
||||||
|
|
||||||
async def _on_reset(self, update: Update, context: ContextTypes.DEFAULT_TYPE) -> None:
|
async def _forward_command(self, update: Update, context: ContextTypes.DEFAULT_TYPE) -> None:
|
||||||
"""Handle /reset command — clear conversation history."""
|
"""Forward slash commands to the bus for unified handling in AgentLoop."""
|
||||||
if not update.message or not update.effective_user:
|
if not update.message or not update.effective_user:
|
||||||
return
|
return
|
||||||
|
await self._handle_message(
|
||||||
chat_id = str(update.message.chat_id)
|
sender_id=str(update.effective_user.id),
|
||||||
session_key = f"{self.name}:{chat_id}"
|
chat_id=str(update.message.chat_id),
|
||||||
|
content=update.message.text,
|
||||||
if self.session_manager is None:
|
|
||||||
logger.warning("/reset called but session_manager is not available")
|
|
||||||
await update.message.reply_text("⚠️ Session management is not available.")
|
|
||||||
return
|
|
||||||
|
|
||||||
session = self.session_manager.get_or_create(session_key)
|
|
||||||
msg_count = len(session.messages)
|
|
||||||
session.clear()
|
|
||||||
self.session_manager.save(session)
|
|
||||||
|
|
||||||
logger.info(f"Session reset for {session_key} (cleared {msg_count} messages)")
|
|
||||||
await update.message.reply_text("🔄 Conversation history cleared. Let's start fresh!")
|
|
||||||
|
|
||||||
async def _on_help(self, update: Update, context: ContextTypes.DEFAULT_TYPE) -> None:
|
|
||||||
"""Handle /help command — show available commands."""
|
|
||||||
if not update.message:
|
|
||||||
return
|
|
||||||
|
|
||||||
help_text = (
|
|
||||||
"🐈 <b>nanobot commands</b>\n\n"
|
|
||||||
"/start — Start the bot\n"
|
|
||||||
"/reset — Reset conversation history\n"
|
|
||||||
"/help — Show this help message\n\n"
|
|
||||||
"Just send me a text message to chat!"
|
|
||||||
)
|
)
|
||||||
await update.message.reply_text(help_text, parse_mode="HTML")
|
|
||||||
|
|
||||||
async def _on_message(self, update: Update, context: ContextTypes.DEFAULT_TYPE) -> None:
|
async def _on_message(self, update: Update, context: ContextTypes.DEFAULT_TYPE) -> None:
|
||||||
"""Handle incoming messages (text, photos, voice, documents)."""
|
"""Handle incoming messages (text, photos, voice, documents)."""
|
||||||
@ -386,6 +357,10 @@ class TelegramChannel(BaseChannel):
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.debug(f"Typing indicator stopped for {chat_id}: {e}")
|
logger.debug(f"Typing indicator stopped for {chat_id}: {e}")
|
||||||
|
|
||||||
|
async def _on_error(self, update: object, context: ContextTypes.DEFAULT_TYPE) -> None:
|
||||||
|
"""Log polling / handler errors instead of silently swallowing them."""
|
||||||
|
logger.error(f"Telegram error: {context.error}")
|
||||||
|
|
||||||
def _get_extension(self, media_type: str, mime_type: str | None) -> str:
|
def _get_extension(self, media_type: str, mime_type: str | None) -> str:
|
||||||
"""Get file extension based on media type."""
|
"""Get file extension based on media type."""
|
||||||
if mime_type:
|
if mime_type:
|
||||||
|
|||||||
@@ -42,6 +42,9 @@ class WhatsAppChannel(BaseChannel):
             try:
                 async with websockets.connect(bridge_url) as ws:
                     self._ws = ws
+                    # Send auth token if configured
+                    if self.config.bridge_token:
+                        await ws.send(json.dumps({"type": "auth", "token": self.config.bridge_token}))
                     self._connected = True
                     logger.info("Connected to WhatsApp bridge")
 
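With `bridge_token` set, the channel's first frame after connecting is the JSON auth message shown above. A minimal sketch of building that frame, plus a hypothetical bridge-side check; only the frame shape comes from the diff, the checker is illustrative:

```python
import json

def build_auth_frame(token: str) -> str:
    # Same shape the channel sends as its first message after connecting
    return json.dumps({"type": "auth", "token": token})

def check_auth_frame(raw: str, expected_token: str) -> bool:
    # Hypothetical bridge-side validation: accept only a matching shared token
    try:
        msg = json.loads(raw)
    except json.JSONDecodeError:
        return False
    return msg.get("type") == "auth" and msg.get("token") == expected_token

frame = build_auth_frame("s3cret")
print(check_auth_frame(frame, "s3cret"))  # True
print(check_auth_frame(frame, "other"))   # False
```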
@@ -1,7 +1,6 @@
 """CLI commands for nanobot."""
 
 import asyncio
-import atexit
 import os
 import signal
 from pathlib import Path

@@ -11,10 +10,14 @@ import sys
 import typer
 from rich.console import Console
 from rich.markdown import Markdown
-from rich.panel import Panel
 from rich.table import Table
 from rich.text import Text
 
+from prompt_toolkit import PromptSession
+from prompt_toolkit.formatted_text import HTML
+from prompt_toolkit.history import FileHistory
+from prompt_toolkit.patch_stdout import patch_stdout
+
 from nanobot import __version__, __logo__
 from nanobot.config.schema import Config
 

@@ -28,13 +31,10 @@ console = Console()
 EXIT_COMMANDS = {"exit", "quit", "/exit", "/quit", ":q"}
 
 # ---------------------------------------------------------------------------
-# Lightweight CLI input: readline for arrow keys / history, termios for flush
+# CLI input: prompt_toolkit for editing, paste, history, and display
 # ---------------------------------------------------------------------------
 
-_READLINE = None
-_HISTORY_FILE: Path | None = None
-_HISTORY_HOOK_REGISTERED = False
-_USING_LIBEDIT = False
+_PROMPT_SESSION: PromptSession | None = None
 _SAVED_TERM_ATTRS = None  # original termios settings, restored on exit
 
 
@@ -65,15 +65,6 @@ def _flush_pending_tty_input() -> None:
         return
 
 
-def _save_history() -> None:
-    if _READLINE is None or _HISTORY_FILE is None:
-        return
-    try:
-        _READLINE.write_history_file(str(_HISTORY_FILE))
-    except Exception:
-        return
-
-
 def _restore_terminal() -> None:
     """Restore terminal to its original state (echo, line buffering, etc.)."""
     if _SAVED_TERM_ATTRS is None:

@@ -85,11 +76,11 @@ def _restore_terminal() -> None:
         pass
 
 
-def _enable_line_editing() -> None:
-    """Enable readline for arrow keys, line editing, and persistent history."""
-    global _READLINE, _HISTORY_FILE, _HISTORY_HOOK_REGISTERED, _USING_LIBEDIT, _SAVED_TERM_ATTRS
+def _init_prompt_session() -> None:
+    """Create the prompt_toolkit session with persistent file history."""
+    global _PROMPT_SESSION, _SAVED_TERM_ATTRS
 
-    # Save terminal state before readline touches it
+    # Save terminal state so we can restore it on exit
     try:
         import termios
         _SAVED_TERM_ATTRS = termios.tcgetattr(sys.stdin.fileno())

@@ -98,43 +89,12 @@ def _enable_line_editing() -> None:
 
     history_file = Path.home() / ".nanobot" / "history" / "cli_history"
     history_file.parent.mkdir(parents=True, exist_ok=True)
-    _HISTORY_FILE = history_file
 
-    try:
-        import readline
-    except ImportError:
-        return
-
-    _READLINE = readline
-    _USING_LIBEDIT = "libedit" in (readline.__doc__ or "").lower()
-
-    try:
-        if _USING_LIBEDIT:
-            readline.parse_and_bind("bind ^I rl_complete")
-        else:
-            readline.parse_and_bind("tab: complete")
-            readline.parse_and_bind("set editing-mode emacs")
-    except Exception:
-        pass
-
-    try:
-        readline.read_history_file(str(history_file))
-    except Exception:
-        pass
-
-    if not _HISTORY_HOOK_REGISTERED:
-        atexit.register(_save_history)
-        _HISTORY_HOOK_REGISTERED = True
-
-
-def _prompt_text() -> str:
-    """Build a readline-friendly colored prompt."""
-    if _READLINE is None:
-        return "You: "
-    # libedit on macOS does not honor GNU readline non-printing markers.
-    if _USING_LIBEDIT:
-        return "\033[1;34mYou:\033[0m "
-    return "\001\033[1;34m\002You:\001\033[0m\002 "
+    _PROMPT_SESSION = PromptSession(
+        history=FileHistory(str(history_file)),
+        enable_open_in_editor=False,
+        multiline=False,  # Enter submits (single line mode)
+    )
 
 
 def _print_agent_response(response: str, render_markdown: bool) -> None:

@@ -142,15 +102,8 @@ def _print_agent_response(response: str, render_markdown: bool) -> None:
     content = response or ""
     body = Markdown(content) if render_markdown else Text(content)
     console.print()
-    console.print(
-        Panel(
-            body,
-            title=f"{__logo__} nanobot",
-            title_align="left",
-            border_style="cyan",
-            padding=(0, 1),
-        )
-    )
+    console.print(f"[cyan]{__logo__} nanobot[/cyan]")
+    console.print(body)
     console.print()
 
 
@@ -160,13 +113,25 @@ def _is_exit_command(command: str) -> bool:
 
 
 async def _read_interactive_input_async() -> str:
-    """Read user input with arrow keys and history (runs input() in a thread)."""
+    """Read user input using prompt_toolkit (handles paste, history, display).
+
+    prompt_toolkit natively handles:
+    - Multiline paste (bracketed paste mode)
+    - History navigation (up/down arrows)
+    - Clean display (no ghost characters or artifacts)
+    """
+    if _PROMPT_SESSION is None:
+        raise RuntimeError("Call _init_prompt_session() first")
     try:
-        return await asyncio.to_thread(input, _prompt_text())
+        with patch_stdout():
+            return await _PROMPT_SESSION.prompt_async(
+                HTML("<b fg='ansiblue'>You:</b> "),
+            )
     except EOFError as exc:
         raise KeyboardInterrupt from exc
 
 
 def version_callback(value: bool):
     if value:
         console.print(f"{__logo__} nanobot v{__version__}")

@@ -191,7 +156,7 @@ def main(
 @app.command()
 def onboard():
     """Initialize nanobot configuration and workspace."""
-    from nanobot.config.loader import get_config_path, save_config
+    from nanobot.config.loader import get_config_path, load_config, save_config
     from nanobot.config.schema import Config
     from nanobot.utils.helpers import get_workspace_path
 

@@ -199,17 +164,26 @@ def onboard():
 
     if config_path.exists():
         console.print(f"[yellow]Config already exists at {config_path}[/yellow]")
-        if not typer.confirm("Overwrite?"):
-            raise typer.Exit()
-
-    # Create default config
-    config = Config()
-    save_config(config)
-    console.print(f"[green]✓[/green] Created config at {config_path}")
+        console.print("  [bold]y[/bold] = overwrite with defaults (existing values will be lost)")
+        console.print("  [bold]N[/bold] = refresh config, keeping existing values and adding new fields")
+        if typer.confirm("Overwrite?"):
+            config = Config()
+            save_config(config)
+            console.print(f"[green]✓[/green] Config reset to defaults at {config_path}")
+        else:
+            config = load_config()
+            save_config(config)
+            console.print(f"[green]✓[/green] Config refreshed at {config_path} (existing values preserved)")
+    else:
+        save_config(Config())
+        console.print(f"[green]✓[/green] Created config at {config_path}")
 
     # Create workspace
     workspace = get_workspace_path()
-    console.print(f"[green]✓[/green] Created workspace at {workspace}")
+    if not workspace.exists():
+        workspace.mkdir(parents=True, exist_ok=True)
+        console.print(f"[green]✓[/green] Created workspace at {workspace}")
 
     # Create default bootstrap files
     _create_workspace_templates(workspace)

@@ -236,7 +210,7 @@ You are a helpful AI assistant. Be concise, accurate, and friendly.
 - Always explain what you're doing before taking actions
 - Ask for clarification when the request is ambiguous
 - Use tools to help accomplish tasks
-- Remember important information in your memory files
+- Remember important information in memory/MEMORY.md; past events are logged in memory/HISTORY.md
 """,
     "SOUL.md": """# Soul
 

@@ -294,6 +268,15 @@ This file stores important information that should persist across sessions.
 (Things to remember)
 """)
         console.print("  [dim]Created memory/MEMORY.md[/dim]")
 
+    history_file = memory_dir / "HISTORY.md"
+    if not history_file.exists():
+        history_file.write_text("")
+        console.print("  [dim]Created memory/HISTORY.md[/dim]")
+
+    # Create skills directory for custom user skills
+    skills_dir = workspace / "skills"
+    skills_dir.mkdir(exist_ok=True)
+
 
 def _make_provider(config: Config):
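The refresh branch of `onboard` reloads the existing config and immediately saves it back, so newly added fields pick up their defaults while user-set values survive. The effect can be sketched with plain dicts; the field names here are illustrative, not nanobot's actual schema:

```python
def refresh(existing: dict, defaults: dict) -> dict:
    # New fields gain their defaults; existing values are preserved
    merged = dict(defaults)
    merged.update(existing)
    return merged

defaults = {"model": "some-default", "memory_window": 50}
existing = {"model": "my-model"}  # user's old config predates memory_window
print(refresh(existing, defaults))
# {'model': 'my-model', 'memory_window': 50}
```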
@@ -367,12 +350,16 @@ def gateway(
         provider=provider,
         workspace=config.workspace_path,
         model=config.agents.defaults.model,
+        temperature=config.agents.defaults.temperature,
+        max_tokens=config.agents.defaults.max_tokens,
         max_iterations=config.agents.defaults.max_tool_iterations,
+        memory_window=config.agents.defaults.memory_window,
         brave_api_key=config.tools.web.search.api_key or None,
         exec_config=config.tools.exec,
         cron_service=cron,
         restrict_to_workspace=config.tools.restrict_to_workspace,
         session_manager=session_manager,
+        mcp_servers=config.tools.mcp_servers,
     )
 
     # Set cron callback (needs agent)

@@ -407,7 +394,7 @@ def gateway(
     )
 
     # Create channel manager
-    channels = ChannelManager(config, bus, session_manager=session_manager)
+    channels = ChannelManager(config, bus)
 
     if channels.enabled_channels:
         console.print(f"[green]✓[/green] Channels enabled: {', '.join(channels.enabled_channels)}")

@@ -430,6 +417,8 @@ def gateway(
         )
     except KeyboardInterrupt:
         console.print("\nShutting down...")
+    finally:
+        await agent.close_mcp()
         heartbeat.stop()
         cron.stop()
         agent.stop()

@@ -448,7 +437,7 @@ def gateway(
 @app.command()
 def agent(
     message: str = typer.Option(None, "--message", "-m", help="Message to send to the agent"),
-    session_id: str = typer.Option("cli:default", "--session", "-s", help="Session ID"),
+    session_id: str = typer.Option("cli:direct", "--session", "-s", help="Session ID"),
     markdown: bool = typer.Option(True, "--markdown/--no-markdown", help="Render assistant output as Markdown"),
     logs: bool = typer.Option(False, "--logs/--no-logs", help="Show nanobot runtime logs during chat"),
 ):

@@ -472,9 +461,15 @@ def agent(
         bus=bus,
         provider=provider,
         workspace=config.workspace_path,
+        model=config.agents.defaults.model,
+        temperature=config.agents.defaults.temperature,
+        max_tokens=config.agents.defaults.max_tokens,
         max_iterations=config.agents.defaults.max_tool_iterations,
+        memory_window=config.agents.defaults.memory_window,
         brave_api_key=config.tools.web.search.api_key or None,
         exec_config=config.tools.exec,
         restrict_to_workspace=config.tools.restrict_to_workspace,
+        mcp_servers=config.tools.mcp_servers,
     )
 
     # Show spinner when logs are off (no output to miss); skip when logs are on

@@ -482,6 +477,7 @@ def agent(
         if logs:
             from contextlib import nullcontext
             return nullcontext()
+        # Animated spinner is safe to use with prompt_toolkit input handling
        return console.status("[dim]nanobot is thinking...[/dim]", spinner="dots")
 
     if message:

@@ -490,17 +486,15 @@ def agent(
             with _thinking_ctx():
                 response = await agent_loop.process_direct(message, session_id)
             _print_agent_response(response, render_markdown=markdown)
+            await agent_loop.close_mcp()
 
         asyncio.run(run_once())
     else:
         # Interactive mode
-        _enable_line_editing()
+        _init_prompt_session()
         console.print(f"{__logo__} Interactive mode (type [bold]exit[/bold] or [bold]Ctrl+C[/bold] to quit)\n")
 
-        # input() runs in a worker thread that can't be cancelled.
-        # Without this handler, asyncio.run() would hang waiting for it.
         def _exit_on_sigint(signum, frame):
-            _save_history()
             _restore_terminal()
             console.print("\nGoodbye!")
             os._exit(0)

@@ -508,33 +502,33 @@ def agent(
         signal.signal(signal.SIGINT, _exit_on_sigint)
 
         async def run_interactive():
-            while True:
-                try:
-                    _flush_pending_tty_input()
-                    user_input = await _read_interactive_input_async()
-                    command = user_input.strip()
-                    if not command:
-                        continue
+            try:
+                while True:
+                    try:
+                        _flush_pending_tty_input()
+                        user_input = await _read_interactive_input_async()
+                        command = user_input.strip()
+                        if not command:
+                            continue
 
-                    if _is_exit_command(command):
-                        _save_history()
-                        _restore_terminal()
-                        console.print("\nGoodbye!")
-                        break
+                        if _is_exit_command(command):
+                            _restore_terminal()
+                            console.print("\nGoodbye!")
+                            break
 
-                    with _thinking_ctx():
-                        response = await agent_loop.process_direct(user_input, session_id)
-                    _print_agent_response(response, render_markdown=markdown)
-                except KeyboardInterrupt:
-                    _save_history()
-                    _restore_terminal()
-                    console.print("\nGoodbye!")
-                    break
-                except EOFError:
-                    _save_history()
-                    _restore_terminal()
-                    console.print("\nGoodbye!")
-                    break
+                        with _thinking_ctx():
+                            response = await agent_loop.process_direct(user_input, session_id)
+                        _print_agent_response(response, render_markdown=markdown)
+                    except KeyboardInterrupt:
+                        _restore_terminal()
+                        console.print("\nGoodbye!")
+                        break
+                    except EOFError:
+                        _restore_terminal()
+                        console.print("\nGoodbye!")
+                        break
+            finally:
+                await agent_loop.close_mcp()
 
         asyncio.run(run_interactive())
 
@@ -574,6 +568,24 @@ def channels_status():
         "✓" if dc.enabled else "✗",
         dc.gateway_url
     )
 
+    # Feishu
+    fs = config.channels.feishu
+    fs_config = f"app_id: {fs.app_id[:10]}..." if fs.app_id else "[dim]not configured[/dim]"
+    table.add_row(
+        "Feishu",
+        "✓" if fs.enabled else "✗",
+        fs_config
+    )
+
+    # Mochat
+    mc = config.channels.mochat
+    mc_base = mc.base_url or "[dim]not configured[/dim]"
+    table.add_row(
+        "Mochat",
+        "✓" if mc.enabled else "✗",
+        mc_base
+    )
+
     # Telegram
     tg = config.channels.telegram

@@ -658,14 +670,20 @@ def _get_bridge_dir() -> Path:
 def channels_login():
     """Link device via QR code."""
     import subprocess
+    from nanobot.config.loader import load_config
 
+    config = load_config()
     bridge_dir = _get_bridge_dir()
 
     console.print(f"{__logo__} Starting bridge...")
     console.print("Scan the QR code to connect.\n")
 
+    env = {**os.environ}
+    if config.channels.whatsapp.bridge_token:
+        env["BRIDGE_TOKEN"] = config.channels.whatsapp.bridge_token
+
     try:
-        subprocess.run(["npm", "start"], cwd=bridge_dir, check=True)
+        subprocess.run(["npm", "start"], cwd=bridge_dir, check=True, env=env)
     except subprocess.CalledProcessError as e:
         console.print(f"[red]Bridge failed: {e}[/red]")
     except FileNotFoundError:
@@ -1,7 +1,7 @@
 """Configuration schema using Pydantic."""
 
 from pathlib import Path
-from pydantic import BaseModel, Field
+from pydantic import BaseModel, Field, ConfigDict
 from pydantic_settings import BaseSettings
 
 

@@ -9,6 +9,7 @@ class WhatsAppConfig(BaseModel):
     """WhatsApp channel configuration."""
     enabled: bool = False
     bridge_url: str = "ws://localhost:3001"
+    bridge_token: str = ""  # Shared token for bridge auth (optional, recommended)
     allow_from: list[str] = Field(default_factory=list)  # Allowed phone numbers
 
 

@@ -77,6 +78,42 @@ class EmailConfig(BaseModel):
     allow_from: list[str] = Field(default_factory=list)  # Allowed sender email addresses
 
 
+class MochatMentionConfig(BaseModel):
+    """Mochat mention behavior configuration."""
+    require_in_groups: bool = False
+
+
+class MochatGroupRule(BaseModel):
+    """Mochat per-group mention requirement."""
+    require_mention: bool = False
+
+
+class MochatConfig(BaseModel):
+    """Mochat channel configuration."""
+    enabled: bool = False
+    base_url: str = "https://mochat.io"
+    socket_url: str = ""
+    socket_path: str = "/socket.io"
+    socket_disable_msgpack: bool = False
+    socket_reconnect_delay_ms: int = 1000
+    socket_max_reconnect_delay_ms: int = 10000
+    socket_connect_timeout_ms: int = 10000
+    refresh_interval_ms: int = 30000
+    watch_timeout_ms: int = 25000
+    watch_limit: int = 100
+    retry_delay_ms: int = 500
+    max_retry_attempts: int = 0  # 0 means unlimited retries
+    claw_token: str = ""
+    agent_user_id: str = ""
+    sessions: list[str] = Field(default_factory=list)
+    panels: list[str] = Field(default_factory=list)
+    allow_from: list[str] = Field(default_factory=list)
+    mention: MochatMentionConfig = Field(default_factory=MochatMentionConfig)
+    groups: dict[str, MochatGroupRule] = Field(default_factory=dict)
+    reply_delay_mode: str = "non-mention"  # off | non-mention
+    reply_delay_ms: int = 120000
+
+
 class SlackDMConfig(BaseModel):
     """Slack DM policy configuration."""
     enabled: bool = True
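`MochatConfig` pairs `socket_reconnect_delay_ms` with `socket_max_reconnect_delay_ms`, which suggests a capped, growing reconnect delay. One plausible reading is sketched below; the doubling factor is an assumption, and only the default values come from the diff:

```python
def reconnect_delay_ms(attempt: int, base_ms: int = 1000, max_ms: int = 10000) -> int:
    # Assumed policy: double the delay per failed attempt, capped at max_ms
    return min(base_ms * (2 ** attempt), max_ms)

print([reconnect_delay_ms(a) for a in range(6)])
# [1000, 2000, 4000, 8000, 10000, 10000]
```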
@ -92,7 +129,7 @@ class SlackConfig(BaseModel):
|
|||||||
bot_token: str = "" # xoxb-...
|
bot_token: str = "" # xoxb-...
|
||||||
app_token: str = "" # xapp-...
|
app_token: str = "" # xapp-...
|
||||||
user_token_read_only: bool = True
|
user_token_read_only: bool = True
|
||||||
group_policy: str = "open" # "open", "mention", "allowlist"
|
group_policy: str = "mention" # "mention", "open", "allowlist"
|
||||||
group_allow_from: list[str] = Field(default_factory=list) # Allowed channel IDs if allowlist
|
group_allow_from: list[str] = Field(default_factory=list) # Allowed channel IDs if allowlist
|
||||||
dm: SlackDMConfig = Field(default_factory=SlackDMConfig)
|
dm: SlackDMConfig = Field(default_factory=SlackDMConfig)
|
||||||
|
|
||||||
@@ -111,6 +148,7 @@ class ChannelsConfig(BaseModel):
     telegram: TelegramConfig = Field(default_factory=TelegramConfig)
     discord: DiscordConfig = Field(default_factory=DiscordConfig)
     feishu: FeishuConfig = Field(default_factory=FeishuConfig)
+    mochat: MochatConfig = Field(default_factory=MochatConfig)
     dingtalk: DingTalkConfig = Field(default_factory=DingTalkConfig)
     email: EmailConfig = Field(default_factory=EmailConfig)
     slack: SlackConfig = Field(default_factory=SlackConfig)
@@ -124,6 +162,7 @@ class AgentDefaults(BaseModel):
     max_tokens: int = 8192
     temperature: float = 0.7
     max_tool_iterations: int = 20
+    memory_window: int = 50


 class AgentsConfig(BaseModel):
@@ -140,6 +179,7 @@ class ProviderConfig(BaseModel):

 class ProvidersConfig(BaseModel):
     """Configuration for LLM providers."""
+    custom: ProviderConfig = Field(default_factory=ProviderConfig)  # Any OpenAI-compatible endpoint
     anthropic: ProviderConfig = Field(default_factory=ProviderConfig)
     openai: ProviderConfig = Field(default_factory=ProviderConfig)
     openrouter: ProviderConfig = Field(default_factory=ProviderConfig)
@@ -150,6 +190,7 @@ class ProvidersConfig(BaseModel):
     vllm: ProviderConfig = Field(default_factory=ProviderConfig)
     gemini: ProviderConfig = Field(default_factory=ProviderConfig)
     moonshot: ProviderConfig = Field(default_factory=ProviderConfig)
+    minimax: ProviderConfig = Field(default_factory=ProviderConfig)
     aihubmix: ProviderConfig = Field(default_factory=ProviderConfig)  # AiHubMix API gateway
     openai_codex: ProviderConfig = Field(default_factory=ProviderConfig)  # OpenAI Codex (OAuth)

@@ -176,11 +217,20 @@ class ExecToolConfig(BaseModel):
     timeout: int = 60


+class MCPServerConfig(BaseModel):
+    """MCP server connection configuration (stdio or HTTP)."""
+    command: str = ""  # Stdio: command to run (e.g. "npx")
+    args: list[str] = Field(default_factory=list)  # Stdio: command arguments
+    env: dict[str, str] = Field(default_factory=dict)  # Stdio: extra env vars
+    url: str = ""  # HTTP: streamable HTTP endpoint URL
+
+
 class ToolsConfig(BaseModel):
     """Tools configuration."""
     web: WebToolsConfig = Field(default_factory=WebToolsConfig)
     exec: ExecToolConfig = Field(default_factory=ExecToolConfig)
     restrict_to_workspace: bool = False  # If true, restrict all tool access to workspace directory
+    mcp_servers: dict[str, MCPServerConfig] = Field(default_factory=dict)


 class Config(BaseSettings):
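The `MCPServerConfig` added above distinguishes stdio servers (`command`/`args`/`env`) from HTTP servers (`url`). A minimal sketch of what a resulting `mcp_servers` mapping could look like; the server names and package below are illustrative assumptions, not nanobot defaults:

```python
# Hypothetical mcp_servers mapping; "filesystem"/"remote" and the npx
# package name are illustrative, not part of nanobot's shipped config.
mcp_servers = {
    "filesystem": {  # stdio server: launched as a subprocess
        "command": "npx",
        "args": ["-y", "@modelcontextprotocol/server-everything"],
        "env": {},
        "url": "",
    },
    "remote": {  # HTTP server: streamable HTTP endpoint
        "command": "",
        "args": [],
        "env": {},
        "url": "https://example.com/mcp",
    },
}

# A stdio entry sets command and leaves url empty; an HTTP entry does the opposite.
```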
@@ -248,6 +298,7 @@ class Config(BaseSettings):
                 return spec.default_api_base
         return None

-    class Config:
-        env_prefix = "NANOBOT_"
-        env_nested_delimiter = "__"
+    model_config = ConfigDict(
+        env_prefix="NANOBOT_",
+        env_nested_delimiter="__"
+    )

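With `env_prefix="NANOBOT_"` and `env_nested_delimiter="__"`, pydantic-settings maps environment variables onto nested config fields. A rough illustration of the naming scheme only (the actual parsing is done by pydantic-settings, not by this code):

```python
# Illustration: how a NANOBOT_-prefixed env var name decomposes into a
# nested config path under env_nested_delimiter="__".
var = "NANOBOT_CHANNELS__SLACK__ENABLED"
assert var.startswith("NANOBOT_")
path = var.removeprefix("NANOBOT_").lower().split("__")
assert path == ["channels", "slack", "enabled"]  # -> config.channels.slack.enabled
```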
@@ -4,6 +4,7 @@ import asyncio
 import json
 import time
 import uuid
+from datetime import datetime
 from pathlib import Path
 from typing import Any, Callable, Coroutine

@@ -30,9 +31,13 @@ def _compute_next_run(schedule: CronSchedule, now_ms: int) -> int | None:
     if schedule.kind == "cron" and schedule.expr:
         try:
             from croniter import croniter
-            cron = croniter(schedule.expr, time.time())
-            next_time = cron.get_next()
-            return int(next_time * 1000)
+            from zoneinfo import ZoneInfo
+            base_time = time.time()
+            tz = ZoneInfo(schedule.tz) if schedule.tz else datetime.now().astimezone().tzinfo
+            base_dt = datetime.fromtimestamp(base_time, tz=tz)
+            cron = croniter(schedule.expr, base_dt)
+            next_dt = cron.get_next(datetime)
+            return int(next_dt.timestamp() * 1000)
         except Exception:
             return None

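The rewritten `_compute_next_run` feeds croniter a timezone-aware datetime instead of a bare epoch float, so an expression like `0 8 * * *` fires at 8am in the job's configured timezone. The property the return statement relies on is that `.timestamp()` on an aware datetime yields correct epoch seconds regardless of the zone, as this stdlib-only sketch shows:

```python
from datetime import datetime
from zoneinfo import ZoneInfo

epoch_s = 1_700_000_000
aware = datetime.fromtimestamp(epoch_s, tz=ZoneInfo("UTC"))
# Round-tripping through an aware datetime preserves the epoch value, which
# is why the new code can safely return int(next_dt.timestamp() * 1000).
assert int(aware.timestamp() * 1000) == 1_700_000_000_000
```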
@@ -1,6 +1,7 @@
 """LiteLLM provider implementation for multi-provider support."""

 import json
+import json_repair
 import os
 from typing import Any

@@ -15,7 +16,7 @@ class LiteLLMProvider(LLMProvider):
     """
     LLM provider using LiteLLM for multi-provider support.

-    Supports OpenRouter, Anthropic, OpenAI, Gemini, and many other providers through
+    Supports OpenRouter, Anthropic, OpenAI, Gemini, MiniMax, and many other providers through
     a unified interface. Provider-specific logic is driven by the registry
     (see providers/registry.py) — no if-elif chains needed here.
     """
@@ -125,6 +126,10 @@ class LiteLLMProvider(LLMProvider):
         """
         model = self._resolve_model(model or self.default_model)

+        # Clamp max_tokens to at least 1 — negative or zero values cause
+        # LiteLLM to reject the request with "max_tokens must be at least 1".
+        max_tokens = max(1, max_tokens)
+
         kwargs: dict[str, Any] = {
             "model": model,
             "messages": messages,
@@ -135,6 +140,10 @@ class LiteLLMProvider(LLMProvider):
         # Apply model-specific overrides (e.g. kimi-k2.5 temperature)
         self._apply_model_overrides(model, kwargs)

+        # Pass api_key directly — more reliable than env vars alone
+        if self.api_key:
+            kwargs["api_key"] = self.api_key
+
         # Pass api_base for custom endpoints
         if self.api_base:
             kwargs["api_base"] = self.api_base
@@ -168,10 +177,7 @@ class LiteLLMProvider(LLMProvider):
             # Parse arguments from JSON string if needed
             args = tc.function.arguments
             if isinstance(args, str):
-                try:
-                    args = json.loads(args)
-                except json.JSONDecodeError:
-                    args = {"raw": args}
+                args = json_repair.loads(args)

             tool_calls.append(ToolCallRequest(
                 id=tc.id,
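The tool-call parsing above swaps `json.loads` plus its fallback for `json_repair.loads`, which tolerates the near-JSON that models sometimes emit (trailing commas, single quotes, truncated output). Strict `json` shows why the old code needed the try/except at all:

```python
import json

# Strict JSON rejects a trailing comma: this is the malformed-arguments case
# that previously collapsed into {"raw": ...} and that the third-party
# json_repair.loads() can now recover into a proper dict.
try:
    json.loads('{"city": "Paris",}')
    parsed = True
except json.JSONDecodeError:
    parsed = False
assert parsed is False
```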
@@ -66,6 +66,20 @@ class ProviderSpec:

 PROVIDERS: tuple[ProviderSpec, ...] = (

+    # === Custom (user-provided OpenAI-compatible endpoint) =================
+    # No auto-detection — only activates when user explicitly configures "custom".
+    ProviderSpec(
+        name="custom",
+        keywords=(),
+        env_key="OPENAI_API_KEY",
+        display_name="Custom",
+        litellm_prefix="openai",
+        skip_prefixes=("openai/",),
+        is_gateway=True,
+        strip_model_prefix=True,
+    ),
+
     # === Gateways (detected by api_key / api_base, not model name) =========
     # Gateways can route any model, so they win in fallback.

@@ -265,6 +279,25 @@ PROVIDERS: tuple[ProviderSpec, ...] = (
         ),
     ),
+
+    # MiniMax: needs "minimax/" prefix for LiteLLM routing.
+    # Uses OpenAI-compatible API at api.minimax.io/v1.
+    ProviderSpec(
+        name="minimax",
+        keywords=("minimax",),
+        env_key="MINIMAX_API_KEY",
+        display_name="MiniMax",
+        litellm_prefix="minimax",  # MiniMax-M2.1 → minimax/MiniMax-M2.1
+        skip_prefixes=("minimax/", "openrouter/"),
+        env_extras=(),
+        is_gateway=False,
+        is_local=False,
+        detect_by_key_prefix="",
+        detect_by_base_keyword="",
+        default_api_base="https://api.minimax.io/v1",
+        strip_model_prefix=False,
+        model_overrides=(),
+    ),

     # === Local deployment (matched by config key, NOT by api_base) =========

     # vLLM / any OpenAI-compatible local server.
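The MiniMax spec's `litellm_prefix`/`skip_prefixes` pair drives model-name routing: bare model names get the provider prefix, already-prefixed names pass through untouched. A simplified sketch of that logic with a hypothetical helper (the real implementation lives in the registry and provider code):

```python
def route_model(model: str) -> str:
    """Hypothetical sketch of litellm_prefix + skip_prefixes routing."""
    litellm_prefix = "minimax"
    skip_prefixes = ("minimax/", "openrouter/")
    if any(model.startswith(p) for p in skip_prefixes):
        return model  # already routed, leave untouched
    return f"{litellm_prefix}/{model}"

assert route_model("MiniMax-M2.1") == "minimax/MiniMax-M2.1"
assert route_model("minimax/MiniMax-M2.1") == "minimax/MiniMax-M2.1"
assert route_model("openrouter/minimax/MiniMax-M2.1") == "openrouter/minimax/MiniMax-M2.1"
```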
@@ -15,15 +15,20 @@ from nanobot.utils.helpers import ensure_dir, safe_filename
 class Session:
     """
     A conversation session.

     Stores messages in JSONL format for easy reading and persistence.

+    Important: Messages are append-only for LLM cache efficiency.
+    The consolidation process writes summaries to MEMORY.md/HISTORY.md
+    but does NOT modify the messages list or get_history() output.
     """

     key: str  # channel:chat_id
     messages: list[dict[str, Any]] = field(default_factory=list)
     created_at: datetime = field(default_factory=datetime.now)
     updated_at: datetime = field(default_factory=datetime.now)
     metadata: dict[str, Any] = field(default_factory=dict)
+    last_consolidated: int = 0  # Number of messages already consolidated to files

     def add_message(self, role: str, content: str, **kwargs: Any) -> None:
         """Add a message to the session."""
@@ -36,35 +41,24 @@ class Session:
         self.messages.append(msg)
         self.updated_at = datetime.now()

-    def get_history(self, max_messages: int = 50) -> list[dict[str, Any]]:
-        """
-        Get message history for LLM context.
-
-        Args:
-            max_messages: Maximum messages to return.
-
-        Returns:
-            List of messages in LLM format.
-        """
-        # Get recent messages
-        recent = self.messages[-max_messages:] if len(self.messages) > max_messages else self.messages
-
-        # Convert to LLM format (just role and content)
-        return [{"role": m["role"], "content": m["content"]} for m in recent]
+    def get_history(self, max_messages: int = 500) -> list[dict[str, Any]]:
+        """Get recent messages in LLM format (role + content only)."""
+        return [{"role": m["role"], "content": m["content"]} for m in self.messages[-max_messages:]]

     def clear(self) -> None:
-        """Clear all messages in the session."""
+        """Clear all messages and reset session to initial state."""
         self.messages = []
+        self.last_consolidated = 0
         self.updated_at = datetime.now()


 class SessionManager:
     """
     Manages conversation sessions.

     Sessions are stored as JSONL files in the sessions directory.
     """

     def __init__(self, workspace: Path):
         self.workspace = workspace
         self.sessions_dir = ensure_dir(Path.home() / ".nanobot" / "sessions")
@@ -85,11 +79,9 @@ class SessionManager:
         Returns:
             The session.
         """
-        # Check cache
         if key in self._cache:
             return self._cache[key]

-        # Try to load from disk
         session = self._load(key)
         if session is None:
             session = Session(key=key)
@@ -100,34 +92,37 @@ class SessionManager:
     def _load(self, key: str) -> Session | None:
         """Load a session from disk."""
         path = self._get_session_path(key)

         if not path.exists():
             return None

         try:
             messages = []
             metadata = {}
             created_at = None
+            last_consolidated = 0

             with open(path) as f:
                 for line in f:
                     line = line.strip()
                     if not line:
                         continue

                     data = json.loads(line)

                     if data.get("_type") == "metadata":
                         metadata = data.get("metadata", {})
                         created_at = datetime.fromisoformat(data["created_at"]) if data.get("created_at") else None
+                        last_consolidated = data.get("last_consolidated", 0)
                     else:
                         messages.append(data)

             return Session(
                 key=key,
                 messages=messages,
                 created_at=created_at or datetime.now(),
-                metadata=metadata
+                metadata=metadata,
+                last_consolidated=last_consolidated
             )
         except Exception as e:
             logger.warning(f"Failed to load session {key}: {e}")
@@ -136,42 +131,24 @@ class SessionManager:
     def save(self, session: Session) -> None:
         """Save a session to disk."""
         path = self._get_session_path(session.key)

         with open(path, "w") as f:
-            # Write metadata first
             metadata_line = {
                 "_type": "metadata",
                 "created_at": session.created_at.isoformat(),
                 "updated_at": session.updated_at.isoformat(),
-                "metadata": session.metadata
+                "metadata": session.metadata,
+                "last_consolidated": session.last_consolidated
             }
             f.write(json.dumps(metadata_line) + "\n")

-            # Write messages
             for msg in session.messages:
                 f.write(json.dumps(msg) + "\n")

         self._cache[session.key] = session

-    def delete(self, key: str) -> bool:
-        """
-        Delete a session.
-
-        Args:
-            key: Session key.
-
-        Returns:
-            True if deleted, False if not found.
-        """
-        # Remove from cache
-        self._cache.pop(key, None)
-
-        # Remove file
-        path = self._get_session_path(key)
-        if path.exists():
-            path.unlink()
-            return True
-        return False
+    def invalidate(self, key: str) -> None:
+        """Remove a session from the in-memory cache."""
+        self._cache.pop(key, None)

     def list_sessions(self) -> list[dict[str, Any]]:
         """
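The simplified `get_history` leans on Python's negative-slice semantics: `self.messages[-max_messages:]` already returns the whole list when it is shorter than `max_messages`, so the old explicit length check was redundant. A plain-data sketch of that behavior:

```python
# Behavior of the list slice used by get_history.
messages = [{"role": "user", "content": f"msg{i}"} for i in range(600)]
recent = [{"role": m["role"], "content": m["content"]} for m in messages[-500:]]
assert len(recent) == 500
assert recent[0]["content"] == "msg100"   # oldest retained message
assert recent[-1]["content"] == "msg599"  # newest message

short = [{"role": "user", "content": "hi"}]
assert short[-500:] == short  # shorter than the window: returned whole
```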
@@ -7,10 +7,11 @@ description: Schedule reminders and recurring tasks.

 Use the `cron` tool to schedule reminders or recurring tasks.

-## Two Modes
+## Three Modes

 1. **Reminder** - message is sent directly to user
 2. **Task** - message is a task description, agent executes and sends result
+3. **One-time** - runs once at a specific time, then auto-deletes

 ## Examples

@@ -24,6 +25,11 @@ Dynamic task (agent executes each time):
 cron(action="add", message="Check HKUDS/nanobot GitHub stars and report", every_seconds=600)
 ```

+One-time scheduled task (compute ISO datetime from current time):
+```
+cron(action="add", message="Remind me about the meeting", at="<ISO datetime>")
+```
+
 List/remove:
 ```
 cron(action="list")
@@ -38,3 +44,4 @@ cron(action="remove", job_id="abc123")
 | every hour | every_seconds: 3600 |
 | every day at 8am | cron_expr: "0 8 * * *" |
 | weekdays at 5pm | cron_expr: "0 17 * * 1-5" |
+| at a specific time | at: ISO datetime string (compute from current time) |
nanobot/skills/memory/SKILL.md (new file, 31 lines)
@@ -0,0 +1,31 @@
+---
+name: memory
+description: Two-layer memory system with grep-based recall.
+always: true
+---
+
+# Memory
+
+## Structure
+
+- `memory/MEMORY.md` — Long-term facts (preferences, project context, relationships). Always loaded into your context.
+- `memory/HISTORY.md` — Append-only event log. NOT loaded into context. Search it with grep.
+
+## Search Past Events
+
+```bash
+grep -i "keyword" memory/HISTORY.md
+```
+
+Use the `exec` tool to run grep. Combine patterns: `grep -iE "meeting|deadline" memory/HISTORY.md`
+
+## When to Update MEMORY.md
+
+Write important facts immediately using `edit_file` or `write_file`:
+- User preferences ("I prefer dark mode")
+- Project context ("The API uses OAuth2")
+- Relationships ("Alice is the project lead")
+
+## Auto-consolidation
+
+Old conversations are automatically summarized and appended to HISTORY.md when the session grows large. Long-term facts are extracted to MEMORY.md. You don't need to manage this.
@@ -37,23 +37,12 @@ def get_sessions_path() -> Path:
     return ensure_dir(get_data_path() / "sessions")


-def get_memory_path(workspace: Path | None = None) -> Path:
-    """Get the memory directory within the workspace."""
-    ws = workspace or get_workspace_path()
-    return ensure_dir(ws / "memory")
-
-
 def get_skills_path(workspace: Path | None = None) -> Path:
     """Get the skills directory within the workspace."""
     ws = workspace or get_workspace_path()
     return ensure_dir(ws / "skills")


-def today_date() -> str:
-    """Get today's date in YYYY-MM-DD format."""
-    return datetime.now().strftime("%Y-%m-%d")
-
-
 def timestamp() -> str:
     """Get current timestamp in ISO format."""
     return datetime.now().isoformat()
@@ -1,6 +1,6 @@
 [project]
 name = "nanobot-ai"
-version = "0.1.3.post5"
+version = "0.1.3.post7"
 description = "A lightweight personal AI assistant framework"
 requires-python = ">=3.11"
 license = {text = "MIT"}
@@ -33,8 +33,14 @@ dependencies = [
     "python-telegram-bot[socks]>=21.0",
     "lark-oapi>=1.0.0",
     "socksio>=1.0.0",
+    "python-socketio>=5.11.0",
+    "msgpack>=1.0.8",
     "slack-sdk>=3.26.0",
     "qq-botpy>=1.0.0",
+    "python-socks[asyncio]>=2.4.0",
+    "prompt-toolkit>=3.0.0",
+    "mcp>=1.0.0",
+    "json-repair>=0.30.0",
 ]

 [project.optional-dependencies]
tests/test_cli_input.py (new file, 59 lines)
@@ -0,0 +1,59 @@
+import asyncio
+from unittest.mock import AsyncMock, MagicMock, patch
+
+import pytest
+from prompt_toolkit.formatted_text import HTML
+
+from nanobot.cli import commands
+
+
+@pytest.fixture
+def mock_prompt_session():
+    """Mock the global prompt session."""
+    mock_session = MagicMock()
+    mock_session.prompt_async = AsyncMock()
+    with patch("nanobot.cli.commands._PROMPT_SESSION", mock_session), \
+         patch("nanobot.cli.commands.patch_stdout"):
+        yield mock_session
+
+
+@pytest.mark.asyncio
+async def test_read_interactive_input_async_returns_input(mock_prompt_session):
+    """Test that _read_interactive_input_async returns the user input from prompt_session."""
+    mock_prompt_session.prompt_async.return_value = "hello world"
+
+    result = await commands._read_interactive_input_async()
+
+    assert result == "hello world"
+    mock_prompt_session.prompt_async.assert_called_once()
+    args, _ = mock_prompt_session.prompt_async.call_args
+    assert isinstance(args[0], HTML)  # Verify HTML prompt is used
+
+
+@pytest.mark.asyncio
+async def test_read_interactive_input_async_handles_eof(mock_prompt_session):
+    """Test that EOFError converts to KeyboardInterrupt."""
+    mock_prompt_session.prompt_async.side_effect = EOFError()
+
+    with pytest.raises(KeyboardInterrupt):
+        await commands._read_interactive_input_async()
+
+
+def test_init_prompt_session_creates_session():
+    """Test that _init_prompt_session initializes the global session."""
+    # Ensure global is None before test
+    commands._PROMPT_SESSION = None
+
+    with patch("nanobot.cli.commands.PromptSession") as MockSession, \
+         patch("nanobot.cli.commands.FileHistory") as MockHistory, \
+         patch("pathlib.Path.home") as mock_home:
+
+        mock_home.return_value = MagicMock()
+
+        commands._init_prompt_session()
+
+        assert commands._PROMPT_SESSION is not None
+        MockSession.assert_called_once()
+        _, kwargs = MockSession.call_args
+        assert kwargs["multiline"] is False
+        assert kwargs["enable_open_in_editor"] is False
tests/test_commands.py (new file, 92 lines)
@@ -0,0 +1,92 @@
+import shutil
+from pathlib import Path
+from unittest.mock import patch
+
+import pytest
+from typer.testing import CliRunner
+
+from nanobot.cli.commands import app
+
+runner = CliRunner()
+
+
+@pytest.fixture
+def mock_paths():
+    """Mock config/workspace paths for test isolation."""
+    with patch("nanobot.config.loader.get_config_path") as mock_cp, \
+         patch("nanobot.config.loader.save_config") as mock_sc, \
+         patch("nanobot.config.loader.load_config") as mock_lc, \
+         patch("nanobot.utils.helpers.get_workspace_path") as mock_ws:
+
+        base_dir = Path("./test_onboard_data")
+        if base_dir.exists():
+            shutil.rmtree(base_dir)
+        base_dir.mkdir()
+
+        config_file = base_dir / "config.json"
+        workspace_dir = base_dir / "workspace"
+
+        mock_cp.return_value = config_file
+        mock_ws.return_value = workspace_dir
+        mock_sc.side_effect = lambda config: config_file.write_text("{}")
+
+        yield config_file, workspace_dir
+
+        if base_dir.exists():
+            shutil.rmtree(base_dir)
+
+
+def test_onboard_fresh_install(mock_paths):
+    """No existing config — should create from scratch."""
+    config_file, workspace_dir = mock_paths
+
+    result = runner.invoke(app, ["onboard"])
+
+    assert result.exit_code == 0
+    assert "Created config" in result.stdout
+    assert "Created workspace" in result.stdout
+    assert "nanobot is ready" in result.stdout
+    assert config_file.exists()
+    assert (workspace_dir / "AGENTS.md").exists()
+    assert (workspace_dir / "memory" / "MEMORY.md").exists()
+
+
+def test_onboard_existing_config_refresh(mock_paths):
+    """Config exists, user declines overwrite — should refresh (load-merge-save)."""
+    config_file, workspace_dir = mock_paths
+    config_file.write_text('{"existing": true}')
+
+    result = runner.invoke(app, ["onboard"], input="n\n")
+
+    assert result.exit_code == 0
+    assert "Config already exists" in result.stdout
+    assert "existing values preserved" in result.stdout
+    assert workspace_dir.exists()
+    assert (workspace_dir / "AGENTS.md").exists()
+
+
+def test_onboard_existing_config_overwrite(mock_paths):
+    """Config exists, user confirms overwrite — should reset to defaults."""
+    config_file, workspace_dir = mock_paths
+    config_file.write_text('{"existing": true}')
+
+    result = runner.invoke(app, ["onboard"], input="y\n")
+
+    assert result.exit_code == 0
+    assert "Config already exists" in result.stdout
+    assert "Config reset to defaults" in result.stdout
+    assert workspace_dir.exists()
+
+
+def test_onboard_existing_workspace_safe_create(mock_paths):
+    """Workspace exists — should not recreate, but still add missing templates."""
+    config_file, workspace_dir = mock_paths
+    workspace_dir.mkdir(parents=True)
+    config_file.write_text("{}")
+
+    result = runner.invoke(app, ["onboard"], input="n\n")
+
+    assert result.exit_code == 0
+    assert "Created workspace" not in result.stdout
+    assert "Created AGENTS.md" in result.stdout
+    assert (workspace_dir / "AGENTS.md").exists()
477
tests/test_consolidate_offset.py
Normal file
477
tests/test_consolidate_offset.py
Normal file
@ -0,0 +1,477 @@
"""Test session management with cache-friendly message handling."""

import pytest
from pathlib import Path
from nanobot.session.manager import Session, SessionManager

# Test constants
MEMORY_WINDOW = 50
KEEP_COUNT = MEMORY_WINDOW // 2  # 25


def create_session_with_messages(key: str, count: int, role: str = "user") -> Session:
    """Create a session and add the specified number of messages.

    Args:
        key: Session identifier
        count: Number of messages to add
        role: Message role (default: "user")

    Returns:
        Session with the specified messages
    """
    session = Session(key=key)
    for i in range(count):
        session.add_message(role, f"msg{i}")
    return session


def assert_messages_content(messages: list, start_index: int, end_index: int) -> None:
    """Assert that messages contain expected content from start to end index.

    Args:
        messages: List of message dictionaries
        start_index: Expected first message index
        end_index: Expected last message index
    """
    assert len(messages) > 0
    assert messages[0]["content"] == f"msg{start_index}"
    assert messages[-1]["content"] == f"msg{end_index}"


def get_old_messages(session: Session, last_consolidated: int, keep_count: int) -> list:
    """Extract messages that would be consolidated using the standard slice logic.

    Args:
        session: The session containing messages
        last_consolidated: Index of last consolidated message
        keep_count: Number of recent messages to keep

    Returns:
        List of messages that would be consolidated
    """
    return session.messages[last_consolidated:-keep_count]


class TestSessionLastConsolidated:
    """Test last_consolidated tracking to avoid duplicate processing."""

    def test_initial_last_consolidated_zero(self) -> None:
        """Test that new session starts with last_consolidated=0."""
        session = Session(key="test:initial")
        assert session.last_consolidated == 0

    def test_last_consolidated_persistence(self, tmp_path) -> None:
        """Test that last_consolidated persists across save/load."""
        manager = SessionManager(Path(tmp_path))
        session1 = create_session_with_messages("test:persist", 20)
        session1.last_consolidated = 15
        manager.save(session1)

        session2 = manager.get_or_create("test:persist")
        assert session2.last_consolidated == 15
        assert len(session2.messages) == 20

    def test_clear_resets_last_consolidated(self) -> None:
        """Test that clear() resets last_consolidated to 0."""
        session = create_session_with_messages("test:clear", 10)
        session.last_consolidated = 5

        session.clear()
        assert len(session.messages) == 0
        assert session.last_consolidated == 0


class TestSessionImmutableHistory:
    """Test Session message immutability for cache efficiency."""

    def test_initial_state(self) -> None:
        """Test that new session has empty messages list."""
        session = Session(key="test:initial")
        assert len(session.messages) == 0

    def test_add_messages_appends_only(self) -> None:
        """Test that adding messages only appends, never modifies."""
        session = Session(key="test:preserve")
        session.add_message("user", "msg1")
        session.add_message("assistant", "resp1")
        session.add_message("user", "msg2")
        assert len(session.messages) == 3
        assert session.messages[0]["content"] == "msg1"

    def test_get_history_returns_most_recent(self) -> None:
        """Test get_history returns the most recent messages."""
        session = Session(key="test:history")
        for i in range(10):
            session.add_message("user", f"msg{i}")
            session.add_message("assistant", f"resp{i}")

        history = session.get_history(max_messages=6)
        assert len(history) == 6
        assert history[0]["content"] == "msg7"
        assert history[-1]["content"] == "resp9"

    def test_get_history_with_all_messages(self) -> None:
        """Test get_history with max_messages larger than actual."""
        session = create_session_with_messages("test:all", 5)
        history = session.get_history(max_messages=100)
        assert len(history) == 5
        assert history[0]["content"] == "msg0"

    def test_get_history_stable_for_same_session(self) -> None:
        """Test that get_history returns same content for same max_messages."""
        session = create_session_with_messages("test:stable", 20)
        history1 = session.get_history(max_messages=10)
        history2 = session.get_history(max_messages=10)
        assert history1 == history2

    def test_messages_list_never_modified(self) -> None:
        """Test that messages list is never modified after creation."""
        session = create_session_with_messages("test:immutable", 5)
        original_len = len(session.messages)

        session.get_history(max_messages=2)
        assert len(session.messages) == original_len

        for _ in range(10):
            session.get_history(max_messages=3)
        assert len(session.messages) == original_len


class TestSessionPersistence:
    """Test Session persistence and reload."""

    @pytest.fixture
    def temp_manager(self, tmp_path):
        return SessionManager(Path(tmp_path))

    def test_persistence_roundtrip(self, temp_manager):
        """Test that messages persist across save/load."""
        session1 = create_session_with_messages("test:persistence", 20)
        temp_manager.save(session1)

        session2 = temp_manager.get_or_create("test:persistence")
        assert len(session2.messages) == 20
        assert session2.messages[0]["content"] == "msg0"
        assert session2.messages[-1]["content"] == "msg19"

    def test_get_history_after_reload(self, temp_manager):
        """Test that get_history works correctly after reload."""
        session1 = create_session_with_messages("test:reload", 30)
        temp_manager.save(session1)

        session2 = temp_manager.get_or_create("test:reload")
        history = session2.get_history(max_messages=10)
        assert len(history) == 10
        assert history[0]["content"] == "msg20"
        assert history[-1]["content"] == "msg29"

    def test_clear_resets_session(self, temp_manager):
        """Test that clear() properly resets session."""
        session = create_session_with_messages("test:clear", 10)
        assert len(session.messages) == 10

        session.clear()
        assert len(session.messages) == 0


class TestConsolidationTriggerConditions:
    """Test consolidation trigger conditions and logic."""

    def test_consolidation_needed_when_messages_exceed_window(self):
        """Test consolidation logic: should trigger when messages > memory_window."""
        session = create_session_with_messages("test:trigger", 60)

        total_messages = len(session.messages)
        messages_to_process = total_messages - session.last_consolidated

        assert total_messages > MEMORY_WINDOW
        assert messages_to_process > 0

        expected_consolidate_count = total_messages - KEEP_COUNT
        assert expected_consolidate_count == 35

    def test_consolidation_skipped_when_within_keep_count(self):
        """Test consolidation skipped when total messages <= keep_count."""
        session = create_session_with_messages("test:skip", 20)

        total_messages = len(session.messages)
        assert total_messages <= KEEP_COUNT

        old_messages = get_old_messages(session, session.last_consolidated, KEEP_COUNT)
        assert len(old_messages) == 0

    def test_consolidation_skipped_when_no_new_messages(self):
        """Test consolidation skipped when messages_to_process <= 0."""
        session = create_session_with_messages("test:already_consolidated", 40)
        session.last_consolidated = len(session.messages) - KEEP_COUNT  # 15

        # Add a few more messages
        for i in range(40, 42):
            session.add_message("user", f"msg{i}")

        total_messages = len(session.messages)
        messages_to_process = total_messages - session.last_consolidated
        assert messages_to_process > 0

        # Simulate last_consolidated catching up
        session.last_consolidated = total_messages - KEEP_COUNT
        old_messages = get_old_messages(session, session.last_consolidated, KEEP_COUNT)
        assert len(old_messages) == 0


class TestLastConsolidatedEdgeCases:
    """Test last_consolidated edge cases and data corruption scenarios."""

    def test_last_consolidated_exceeds_message_count(self):
        """Test behavior when last_consolidated > len(messages) (data corruption)."""
        session = create_session_with_messages("test:corruption", 10)
        session.last_consolidated = 20

        total_messages = len(session.messages)
        messages_to_process = total_messages - session.last_consolidated
        assert messages_to_process <= 0

        old_messages = get_old_messages(session, session.last_consolidated, 5)
        assert len(old_messages) == 0

    def test_last_consolidated_negative_value(self):
        """Test behavior with negative last_consolidated (invalid state)."""
        session = create_session_with_messages("test:negative", 10)
        session.last_consolidated = -5

        keep_count = 3
        old_messages = get_old_messages(session, session.last_consolidated, keep_count)

        # messages[-5:-3] with 10 messages gives indices 5,6
        assert len(old_messages) == 2
        assert old_messages[0]["content"] == "msg5"
        assert old_messages[-1]["content"] == "msg6"

    def test_messages_added_after_consolidation(self):
        """Test correct behavior when new messages arrive after consolidation."""
        session = create_session_with_messages("test:new_messages", 40)
        session.last_consolidated = len(session.messages) - KEEP_COUNT  # 15

        # Add new messages after consolidation
        for i in range(40, 50):
            session.add_message("user", f"msg{i}")

        total_messages = len(session.messages)
        old_messages = get_old_messages(session, session.last_consolidated, KEEP_COUNT)
        expected_consolidate_count = total_messages - KEEP_COUNT - session.last_consolidated

        assert len(old_messages) == expected_consolidate_count
        assert_messages_content(old_messages, 15, 24)

    def test_slice_behavior_when_indices_overlap(self):
        """Test slice behavior when last_consolidated >= total - keep_count."""
        session = create_session_with_messages("test:overlap", 30)
        session.last_consolidated = 12

        old_messages = get_old_messages(session, session.last_consolidated, 20)
        assert len(old_messages) == 0


class TestArchiveAllMode:
    """Test archive_all mode (used by /new command)."""

    def test_archive_all_consolidates_everything(self):
        """Test archive_all=True consolidates all messages."""
        session = create_session_with_messages("test:archive_all", 50)

        archive_all = True
        if archive_all:
            old_messages = session.messages
            assert len(old_messages) == 50

        assert session.last_consolidated == 0

    def test_archive_all_resets_last_consolidated(self):
        """Test that archive_all mode resets last_consolidated to 0."""
        session = create_session_with_messages("test:reset", 40)
        session.last_consolidated = 15

        archive_all = True
        if archive_all:
            session.last_consolidated = 0

        assert session.last_consolidated == 0
        assert len(session.messages) == 40

    def test_archive_all_vs_normal_consolidation(self):
        """Test difference between archive_all and normal consolidation."""
        # Normal consolidation
        session1 = create_session_with_messages("test:normal", 60)
        session1.last_consolidated = len(session1.messages) - KEEP_COUNT

        # archive_all mode
        session2 = create_session_with_messages("test:all", 60)
        session2.last_consolidated = 0

        assert session1.last_consolidated == 35
        assert len(session1.messages) == 60
        assert session2.last_consolidated == 0
        assert len(session2.messages) == 60


class TestCacheImmutability:
    """Test that consolidation doesn't modify session.messages (cache safety)."""

    def test_consolidation_does_not_modify_messages_list(self):
        """Test that consolidation leaves messages list unchanged."""
        session = create_session_with_messages("test:immutable", 50)

        original_messages = session.messages.copy()
        original_len = len(session.messages)
        session.last_consolidated = original_len - KEEP_COUNT

        assert len(session.messages) == original_len
        assert session.messages == original_messages

    def test_get_history_does_not_modify_messages(self):
        """Test that get_history doesn't modify messages list."""
        session = create_session_with_messages("test:history_immutable", 40)
        original_messages = [m.copy() for m in session.messages]

        for _ in range(5):
            history = session.get_history(max_messages=10)
            assert len(history) == 10

        assert len(session.messages) == 40
        for i, msg in enumerate(session.messages):
            assert msg["content"] == original_messages[i]["content"]

    def test_consolidation_only_updates_last_consolidated(self):
        """Test that consolidation only updates last_consolidated field."""
        session = create_session_with_messages("test:field_only", 60)

        original_messages = session.messages.copy()
        original_key = session.key
        original_metadata = session.metadata.copy()

        session.last_consolidated = len(session.messages) - KEEP_COUNT

        assert session.messages == original_messages
        assert session.key == original_key
        assert session.metadata == original_metadata
        assert session.last_consolidated == 35


class TestSliceLogic:
    """Test the slice logic: messages[last_consolidated:-keep_count]."""

    def test_slice_extracts_correct_range(self):
        """Test that slice extracts the correct message range."""
        session = create_session_with_messages("test:slice", 60)

        old_messages = get_old_messages(session, 0, KEEP_COUNT)

        assert len(old_messages) == 35
        assert_messages_content(old_messages, 0, 34)

        remaining = session.messages[-KEEP_COUNT:]
        assert len(remaining) == 25
        assert_messages_content(remaining, 35, 59)

    def test_slice_with_partial_consolidation(self):
        """Test slice when some messages already consolidated."""
        session = create_session_with_messages("test:partial", 70)

        last_consolidated = 30
        old_messages = get_old_messages(session, last_consolidated, KEEP_COUNT)

        assert len(old_messages) == 15
        assert_messages_content(old_messages, 30, 44)

    def test_slice_with_various_keep_counts(self):
        """Test slice behavior with different keep_count values."""
        session = create_session_with_messages("test:keep_counts", 50)

        test_cases = [(10, 40), (20, 30), (30, 20), (40, 10)]

        for keep_count, expected_count in test_cases:
            old_messages = session.messages[0:-keep_count]
            assert len(old_messages) == expected_count

    def test_slice_when_keep_count_exceeds_messages(self):
        """Test slice when keep_count > len(messages)."""
        session = create_session_with_messages("test:exceed", 10)

        old_messages = session.messages[0:-20]
        assert len(old_messages) == 0


class TestEmptyAndBoundarySessions:
    """Test empty sessions and boundary conditions."""

    def test_empty_session_consolidation(self):
        """Test consolidation behavior with empty session."""
        session = Session(key="test:empty")

        assert len(session.messages) == 0
        assert session.last_consolidated == 0

        messages_to_process = len(session.messages) - session.last_consolidated
        assert messages_to_process == 0

        old_messages = get_old_messages(session, session.last_consolidated, KEEP_COUNT)
        assert len(old_messages) == 0

    def test_single_message_session(self):
        """Test consolidation with single message."""
        session = Session(key="test:single")
        session.add_message("user", "only message")

        assert len(session.messages) == 1

        old_messages = get_old_messages(session, session.last_consolidated, KEEP_COUNT)
        assert len(old_messages) == 0

    def test_exactly_keep_count_messages(self):
        """Test session with exactly keep_count messages."""
        session = create_session_with_messages("test:exact", KEEP_COUNT)

        assert len(session.messages) == KEEP_COUNT

        old_messages = get_old_messages(session, session.last_consolidated, KEEP_COUNT)
        assert len(old_messages) == 0

    def test_just_over_keep_count(self):
        """Test session with one message over keep_count."""
        session = create_session_with_messages("test:over", KEEP_COUNT + 1)

        assert len(session.messages) == 26

        old_messages = get_old_messages(session, session.last_consolidated, KEEP_COUNT)
        assert len(old_messages) == 1
        assert old_messages[0]["content"] == "msg0"

    def test_very_large_session(self):
        """Test consolidation with very large message count."""
        session = create_session_with_messages("test:large", 1000)

        assert len(session.messages) == 1000

        old_messages = get_old_messages(session, session.last_consolidated, KEEP_COUNT)
        assert len(old_messages) == 975
        assert_messages_content(old_messages, 0, 974)

        remaining = session.messages[-KEEP_COUNT:]
        assert len(remaining) == 25
        assert_messages_content(remaining, 975, 999)

    def test_session_with_gaps_in_consolidation(self):
        """Test session with potential gaps in consolidation history."""
        session = create_session_with_messages("test:gaps", 50)
        session.last_consolidated = 10

        # Add more messages
        for i in range(50, 60):
            session.add_message("user", f"msg{i}")

        old_messages = get_old_messages(session, session.last_consolidated, KEEP_COUNT)

        expected_count = 60 - KEEP_COUNT - 10
        assert len(old_messages) == expected_count
        assert_messages_content(old_messages, 10, 34)
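The slice these tests exercise can be reproduced standalone. A minimal sketch with plain lists (no nanobot imports; `last_consolidated=15` and `keep_count=25` mirror the test constants, `old_messages` is a hypothetical stand-in for the helper in the file above):

```python
# messages[last_consolidated:-keep_count] selects everything that is
# not yet consolidated and not inside the recent keep window.

def old_messages(messages: list, last_consolidated: int, keep_count: int) -> list:
    return messages[last_consolidated:-keep_count]

msgs = [f"msg{i}" for i in range(60)]

# Fresh session: consolidate all but the newest 25 messages.
assert old_messages(msgs, 0, 25) == [f"msg{i}" for i in range(35)]

# Partially consolidated: only the gap between index 15 and the keep window.
assert old_messages(msgs, 15, 25) == [f"msg{i}" for i in range(15, 35)]

# Already caught up: empty slice, so consolidation is skipped.
assert old_messages(msgs, 35, 25) == []
```

Because an out-of-range start index just yields an empty slice in Python, corrupted `last_consolidated` values degrade to a no-op rather than an error, which is what the edge-case tests assert.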
@@ -20,8 +20,8 @@ You have access to:

 ## Memory
-- Use `memory/` directory for daily notes
-- Use `MEMORY.md` for long-term information
+- `memory/MEMORY.md` — long-term facts (preferences, context, relationships)
+- `memory/HISTORY.md` — append-only event log, search with grep to recall past events

 ## Scheduled Reminders