fix: history messages should not be change[kvcache]
This commit is contained in:
parent
32c9431191
commit
740294fd74
1
.gitignore
vendored
1
.gitignore
vendored
@ -17,5 +17,4 @@ docs/
|
|||||||
__pycache__/
|
__pycache__/
|
||||||
poetry.lock
|
poetry.lock
|
||||||
.pytest_cache/
|
.pytest_cache/
|
||||||
tests/
|
|
||||||
botpy.log
|
botpy.log
|
||||||
|
|||||||
@ -2,6 +2,7 @@
|
|||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import json
|
import json
|
||||||
|
from datetime import datetime
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
@ -26,7 +27,7 @@ from nanobot.session.manager import SessionManager
|
|||||||
class AgentLoop:
|
class AgentLoop:
|
||||||
"""
|
"""
|
||||||
The agent loop is the core processing engine.
|
The agent loop is the core processing engine.
|
||||||
|
|
||||||
It:
|
It:
|
||||||
1. Receives messages from the bus
|
1. Receives messages from the bus
|
||||||
2. Builds context with history, memory, skills
|
2. Builds context with history, memory, skills
|
||||||
@ -34,7 +35,7 @@ class AgentLoop:
|
|||||||
4. Executes tool calls
|
4. Executes tool calls
|
||||||
5. Sends responses back
|
5. Sends responses back
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
bus: MessageBus,
|
bus: MessageBus,
|
||||||
@ -61,8 +62,10 @@ class AgentLoop:
|
|||||||
self.exec_config = exec_config or ExecToolConfig()
|
self.exec_config = exec_config or ExecToolConfig()
|
||||||
self.cron_service = cron_service
|
self.cron_service = cron_service
|
||||||
self.restrict_to_workspace = restrict_to_workspace
|
self.restrict_to_workspace = restrict_to_workspace
|
||||||
|
|
||||||
self.context = ContextBuilder(workspace)
|
self.context = ContextBuilder(workspace)
|
||||||
|
|
||||||
|
# Initialize session manager
|
||||||
self.sessions = session_manager or SessionManager(workspace)
|
self.sessions = session_manager or SessionManager(workspace)
|
||||||
self.tools = ToolRegistry()
|
self.tools = ToolRegistry()
|
||||||
self.subagents = SubagentManager(
|
self.subagents = SubagentManager(
|
||||||
@ -110,11 +113,81 @@ class AgentLoop:
|
|||||||
if self.cron_service:
|
if self.cron_service:
|
||||||
self.tools.register(CronTool(self.cron_service))
|
self.tools.register(CronTool(self.cron_service))
|
||||||
|
|
||||||
|
def _set_tool_context(self, channel: str, chat_id: str) -> None:
|
||||||
|
"""Update context for all tools that need routing info."""
|
||||||
|
if message_tool := self.tools.get("message"):
|
||||||
|
if isinstance(message_tool, MessageTool):
|
||||||
|
message_tool.set_context(channel, chat_id)
|
||||||
|
|
||||||
|
if spawn_tool := self.tools.get("spawn"):
|
||||||
|
if isinstance(spawn_tool, SpawnTool):
|
||||||
|
spawn_tool.set_context(channel, chat_id)
|
||||||
|
|
||||||
|
if cron_tool := self.tools.get("cron"):
|
||||||
|
if isinstance(cron_tool, CronTool):
|
||||||
|
cron_tool.set_context(channel, chat_id)
|
||||||
|
|
||||||
|
async def _run_agent_loop(self, initial_messages: list[dict]) -> tuple[str | None, list[str]]:
|
||||||
|
"""
|
||||||
|
Run the agent iteration loop.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
initial_messages: Starting messages for the LLM conversation.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple of (final_content, list_of_tools_used).
|
||||||
|
"""
|
||||||
|
messages = initial_messages
|
||||||
|
iteration = 0
|
||||||
|
final_content = None
|
||||||
|
tools_used: list[str] = []
|
||||||
|
|
||||||
|
while iteration < self.max_iterations:
|
||||||
|
iteration += 1
|
||||||
|
|
||||||
|
response = await self.provider.chat(
|
||||||
|
messages=messages,
|
||||||
|
tools=self.tools.get_definitions(),
|
||||||
|
model=self.model
|
||||||
|
)
|
||||||
|
|
||||||
|
if response.has_tool_calls:
|
||||||
|
tool_call_dicts = [
|
||||||
|
{
|
||||||
|
"id": tc.id,
|
||||||
|
"type": "function",
|
||||||
|
"function": {
|
||||||
|
"name": tc.name,
|
||||||
|
"arguments": json.dumps(tc.arguments)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for tc in response.tool_calls
|
||||||
|
]
|
||||||
|
messages = self.context.add_assistant_message(
|
||||||
|
messages, response.content, tool_call_dicts,
|
||||||
|
reasoning_content=response.reasoning_content,
|
||||||
|
)
|
||||||
|
|
||||||
|
for tool_call in response.tool_calls:
|
||||||
|
tools_used.append(tool_call.name)
|
||||||
|
args_str = json.dumps(tool_call.arguments, ensure_ascii=False)
|
||||||
|
logger.info(f"Tool call: {tool_call.name}({args_str[:200]})")
|
||||||
|
result = await self.tools.execute(tool_call.name, tool_call.arguments)
|
||||||
|
messages = self.context.add_tool_result(
|
||||||
|
messages, tool_call.id, tool_call.name, result
|
||||||
|
)
|
||||||
|
messages.append({"role": "user", "content": "Reflect on the results and decide next steps."})
|
||||||
|
else:
|
||||||
|
final_content = response.content
|
||||||
|
break
|
||||||
|
|
||||||
|
return final_content, tools_used
|
||||||
|
|
||||||
async def run(self) -> None:
|
async def run(self) -> None:
|
||||||
"""Run the agent loop, processing messages from the bus."""
|
"""Run the agent loop, processing messages from the bus."""
|
||||||
self._running = True
|
self._running = True
|
||||||
logger.info("Agent loop started")
|
logger.info("Agent loop started")
|
||||||
|
|
||||||
while self._running:
|
while self._running:
|
||||||
try:
|
try:
|
||||||
# Wait for next message
|
# Wait for next message
|
||||||
@ -173,8 +246,10 @@ class AgentLoop:
|
|||||||
await self._consolidate_memory(session, archive_all=True)
|
await self._consolidate_memory(session, archive_all=True)
|
||||||
session.clear()
|
session.clear()
|
||||||
self.sessions.save(session)
|
self.sessions.save(session)
|
||||||
|
# Clear cache to force reload from disk on next request
|
||||||
|
self.sessions._cache.pop(session.key, None)
|
||||||
return OutboundMessage(channel=msg.channel, chat_id=msg.chat_id,
|
return OutboundMessage(channel=msg.channel, chat_id=msg.chat_id,
|
||||||
content="🐈 New session started. Memory consolidated.")
|
content="New session started. Memory consolidated.")
|
||||||
if cmd == "/help":
|
if cmd == "/help":
|
||||||
return OutboundMessage(channel=msg.channel, chat_id=msg.chat_id,
|
return OutboundMessage(channel=msg.channel, chat_id=msg.chat_id,
|
||||||
content="🐈 nanobot commands:\n/new — Start a new conversation\n/help — Show available commands")
|
content="🐈 nanobot commands:\n/new — Start a new conversation\n/help — Show available commands")
|
||||||
@ -182,79 +257,22 @@ class AgentLoop:
|
|||||||
# Consolidate memory before processing if session is too large
|
# Consolidate memory before processing if session is too large
|
||||||
if len(session.messages) > self.memory_window:
|
if len(session.messages) > self.memory_window:
|
||||||
await self._consolidate_memory(session)
|
await self._consolidate_memory(session)
|
||||||
|
|
||||||
# Update tool contexts
|
# Update tool contexts
|
||||||
message_tool = self.tools.get("message")
|
self._set_tool_context(msg.channel, msg.chat_id)
|
||||||
if isinstance(message_tool, MessageTool):
|
|
||||||
message_tool.set_context(msg.channel, msg.chat_id)
|
# Build initial messages
|
||||||
|
initial_messages = self.context.build_messages(
|
||||||
spawn_tool = self.tools.get("spawn")
|
|
||||||
if isinstance(spawn_tool, SpawnTool):
|
|
||||||
spawn_tool.set_context(msg.channel, msg.chat_id)
|
|
||||||
|
|
||||||
cron_tool = self.tools.get("cron")
|
|
||||||
if isinstance(cron_tool, CronTool):
|
|
||||||
cron_tool.set_context(msg.channel, msg.chat_id)
|
|
||||||
|
|
||||||
# Build initial messages (use get_history for LLM-formatted messages)
|
|
||||||
messages = self.context.build_messages(
|
|
||||||
history=session.get_history(),
|
history=session.get_history(),
|
||||||
current_message=msg.content,
|
current_message=msg.content,
|
||||||
media=msg.media if msg.media else None,
|
media=msg.media if msg.media else None,
|
||||||
channel=msg.channel,
|
channel=msg.channel,
|
||||||
chat_id=msg.chat_id,
|
chat_id=msg.chat_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Agent loop
|
# Run agent loop
|
||||||
iteration = 0
|
final_content, tools_used = await self._run_agent_loop(initial_messages)
|
||||||
final_content = None
|
|
||||||
tools_used: list[str] = []
|
|
||||||
|
|
||||||
while iteration < self.max_iterations:
|
|
||||||
iteration += 1
|
|
||||||
|
|
||||||
# Call LLM
|
|
||||||
response = await self.provider.chat(
|
|
||||||
messages=messages,
|
|
||||||
tools=self.tools.get_definitions(),
|
|
||||||
model=self.model
|
|
||||||
)
|
|
||||||
|
|
||||||
# Handle tool calls
|
|
||||||
if response.has_tool_calls:
|
|
||||||
# Add assistant message with tool calls
|
|
||||||
tool_call_dicts = [
|
|
||||||
{
|
|
||||||
"id": tc.id,
|
|
||||||
"type": "function",
|
|
||||||
"function": {
|
|
||||||
"name": tc.name,
|
|
||||||
"arguments": json.dumps(tc.arguments) # Must be JSON string
|
|
||||||
}
|
|
||||||
}
|
|
||||||
for tc in response.tool_calls
|
|
||||||
]
|
|
||||||
messages = self.context.add_assistant_message(
|
|
||||||
messages, response.content, tool_call_dicts,
|
|
||||||
reasoning_content=response.reasoning_content,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Execute tools
|
|
||||||
for tool_call in response.tool_calls:
|
|
||||||
tools_used.append(tool_call.name)
|
|
||||||
args_str = json.dumps(tool_call.arguments, ensure_ascii=False)
|
|
||||||
logger.info(f"Tool call: {tool_call.name}({args_str[:200]})")
|
|
||||||
result = await self.tools.execute(tool_call.name, tool_call.arguments)
|
|
||||||
messages = self.context.add_tool_result(
|
|
||||||
messages, tool_call.id, tool_call.name, result
|
|
||||||
)
|
|
||||||
# Interleaved CoT: reflect before next action
|
|
||||||
messages.append({"role": "user", "content": "Reflect on the results and decide next steps."})
|
|
||||||
else:
|
|
||||||
# No tool calls, we're done
|
|
||||||
final_content = response.content
|
|
||||||
break
|
|
||||||
|
|
||||||
if final_content is None:
|
if final_content is None:
|
||||||
final_content = "I've completed processing but have no response to give."
|
final_content = "I've completed processing but have no response to give."
|
||||||
|
|
||||||
@ -297,71 +315,21 @@ class AgentLoop:
|
|||||||
# Use the origin session for context
|
# Use the origin session for context
|
||||||
session_key = f"{origin_channel}:{origin_chat_id}"
|
session_key = f"{origin_channel}:{origin_chat_id}"
|
||||||
session = self.sessions.get_or_create(session_key)
|
session = self.sessions.get_or_create(session_key)
|
||||||
|
|
||||||
# Update tool contexts
|
# Update tool contexts
|
||||||
message_tool = self.tools.get("message")
|
self._set_tool_context(origin_channel, origin_chat_id)
|
||||||
if isinstance(message_tool, MessageTool):
|
|
||||||
message_tool.set_context(origin_channel, origin_chat_id)
|
|
||||||
|
|
||||||
spawn_tool = self.tools.get("spawn")
|
|
||||||
if isinstance(spawn_tool, SpawnTool):
|
|
||||||
spawn_tool.set_context(origin_channel, origin_chat_id)
|
|
||||||
|
|
||||||
cron_tool = self.tools.get("cron")
|
|
||||||
if isinstance(cron_tool, CronTool):
|
|
||||||
cron_tool.set_context(origin_channel, origin_chat_id)
|
|
||||||
|
|
||||||
# Build messages with the announce content
|
# Build messages with the announce content
|
||||||
messages = self.context.build_messages(
|
initial_messages = self.context.build_messages(
|
||||||
history=session.get_history(),
|
history=session.get_history(),
|
||||||
current_message=msg.content,
|
current_message=msg.content,
|
||||||
channel=origin_channel,
|
channel=origin_channel,
|
||||||
chat_id=origin_chat_id,
|
chat_id=origin_chat_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Agent loop (limited for announce handling)
|
# Run agent loop
|
||||||
iteration = 0
|
final_content, _ = await self._run_agent_loop(initial_messages)
|
||||||
final_content = None
|
|
||||||
|
|
||||||
while iteration < self.max_iterations:
|
|
||||||
iteration += 1
|
|
||||||
|
|
||||||
response = await self.provider.chat(
|
|
||||||
messages=messages,
|
|
||||||
tools=self.tools.get_definitions(),
|
|
||||||
model=self.model
|
|
||||||
)
|
|
||||||
|
|
||||||
if response.has_tool_calls:
|
|
||||||
tool_call_dicts = [
|
|
||||||
{
|
|
||||||
"id": tc.id,
|
|
||||||
"type": "function",
|
|
||||||
"function": {
|
|
||||||
"name": tc.name,
|
|
||||||
"arguments": json.dumps(tc.arguments)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
for tc in response.tool_calls
|
|
||||||
]
|
|
||||||
messages = self.context.add_assistant_message(
|
|
||||||
messages, response.content, tool_call_dicts,
|
|
||||||
reasoning_content=response.reasoning_content,
|
|
||||||
)
|
|
||||||
|
|
||||||
for tool_call in response.tool_calls:
|
|
||||||
args_str = json.dumps(tool_call.arguments, ensure_ascii=False)
|
|
||||||
logger.info(f"Tool call: {tool_call.name}({args_str[:200]})")
|
|
||||||
result = await self.tools.execute(tool_call.name, tool_call.arguments)
|
|
||||||
messages = self.context.add_tool_result(
|
|
||||||
messages, tool_call.id, tool_call.name, result
|
|
||||||
)
|
|
||||||
# Interleaved CoT: reflect before next action
|
|
||||||
messages.append({"role": "user", "content": "Reflect on the results and decide next steps."})
|
|
||||||
else:
|
|
||||||
final_content = response.content
|
|
||||||
break
|
|
||||||
|
|
||||||
if final_content is None:
|
if final_content is None:
|
||||||
final_content = "Background task completed."
|
final_content = "Background task completed."
|
||||||
|
|
||||||
@ -377,19 +345,39 @@ class AgentLoop:
|
|||||||
)
|
)
|
||||||
|
|
||||||
async def _consolidate_memory(self, session, archive_all: bool = False) -> None:
|
async def _consolidate_memory(self, session, archive_all: bool = False) -> None:
|
||||||
"""Consolidate old messages into MEMORY.md + HISTORY.md, then trim session."""
|
"""Consolidate old messages into MEMORY.md + HISTORY.md.
|
||||||
if not session.messages:
|
|
||||||
return
|
Args:
|
||||||
|
archive_all: If True, clear all messages and reset session (for /new command).
|
||||||
|
If False, only write to files without modifying session.
|
||||||
|
"""
|
||||||
memory = MemoryStore(self.workspace)
|
memory = MemoryStore(self.workspace)
|
||||||
|
|
||||||
|
# Handle /new command: clear session and consolidate everything
|
||||||
if archive_all:
|
if archive_all:
|
||||||
old_messages = session.messages
|
old_messages = session.messages # All messages
|
||||||
keep_count = 0
|
keep_count = 0 # Clear everything
|
||||||
|
logger.info(f"Memory consolidation (archive_all): {len(session.messages)} total messages archived")
|
||||||
else:
|
else:
|
||||||
keep_count = min(10, max(2, self.memory_window // 2))
|
# Normal consolidation: only write files, keep session intact
|
||||||
old_messages = session.messages[:-keep_count]
|
keep_count = self.memory_window // 2
|
||||||
if not old_messages:
|
|
||||||
return
|
# Check if consolidation is needed
|
||||||
logger.info(f"Memory consolidation started: {len(session.messages)} messages, archiving {len(old_messages)}, keeping {keep_count}")
|
if len(session.messages) <= keep_count:
|
||||||
|
logger.debug(f"Session {session.key}: No consolidation needed (messages={len(session.messages)}, keep={keep_count})")
|
||||||
|
return
|
||||||
|
|
||||||
|
# Use last_consolidated to avoid re-processing messages
|
||||||
|
messages_to_process = len(session.messages) - session.last_consolidated
|
||||||
|
if messages_to_process <= 0:
|
||||||
|
logger.debug(f"Session {session.key}: No new messages to consolidate (last_consolidated={session.last_consolidated}, total={len(session.messages)})")
|
||||||
|
return
|
||||||
|
|
||||||
|
# Get messages to consolidate (from last_consolidated to keep_count from end)
|
||||||
|
old_messages = session.messages[session.last_consolidated:-keep_count]
|
||||||
|
if not old_messages:
|
||||||
|
return
|
||||||
|
logger.info(f"Memory consolidation started: {len(session.messages)} total, {len(old_messages)} new to consolidate, {keep_count} keep")
|
||||||
|
|
||||||
# Format messages for LLM (include tool names when available)
|
# Format messages for LLM (include tool names when available)
|
||||||
lines = []
|
lines = []
|
||||||
@ -434,9 +422,18 @@ Respond with ONLY valid JSON, no markdown fences."""
|
|||||||
if update != current_memory:
|
if update != current_memory:
|
||||||
memory.write_long_term(update)
|
memory.write_long_term(update)
|
||||||
|
|
||||||
session.messages = session.messages[-keep_count:] if keep_count else []
|
# Update last_consolidated to track what's been processed
|
||||||
self.sessions.save(session)
|
if archive_all:
|
||||||
logger.info(f"Memory consolidation done, session trimmed to {len(session.messages)} messages")
|
# /new command: reset to 0 after clearing
|
||||||
|
session.last_consolidated = 0
|
||||||
|
else:
|
||||||
|
# Normal: mark up to (total - keep_count) as consolidated
|
||||||
|
session.last_consolidated = len(session.messages) - keep_count
|
||||||
|
|
||||||
|
# Key: We do NOT modify session.messages (append-only for cache)
|
||||||
|
# The consolidation is only for human-readable files (MEMORY.md/HISTORY.md)
|
||||||
|
# LLM cache remains intact because the messages list is unchanged
|
||||||
|
logger.info(f"Memory consolidation done: {len(session.messages)} total messages (unchanged), last_consolidated={session.last_consolidated}")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Memory consolidation failed: {e}")
|
logger.error(f"Memory consolidation failed: {e}")
|
||||||
|
|
||||||
|
|||||||
@ -15,15 +15,20 @@ from nanobot.utils.helpers import ensure_dir, safe_filename
|
|||||||
class Session:
|
class Session:
|
||||||
"""
|
"""
|
||||||
A conversation session.
|
A conversation session.
|
||||||
|
|
||||||
Stores messages in JSONL format for easy reading and persistence.
|
Stores messages in JSONL format for easy reading and persistence.
|
||||||
|
|
||||||
|
Important: Messages are append-only for LLM cache efficiency.
|
||||||
|
The consolidation process writes summaries to MEMORY.md/HISTORY.md
|
||||||
|
but does NOT modify the messages list or get_history() output.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
key: str # channel:chat_id
|
key: str # channel:chat_id
|
||||||
messages: list[dict[str, Any]] = field(default_factory=list)
|
messages: list[dict[str, Any]] = field(default_factory=list)
|
||||||
created_at: datetime = field(default_factory=datetime.now)
|
created_at: datetime = field(default_factory=datetime.now)
|
||||||
updated_at: datetime = field(default_factory=datetime.now)
|
updated_at: datetime = field(default_factory=datetime.now)
|
||||||
metadata: dict[str, Any] = field(default_factory=dict)
|
metadata: dict[str, Any] = field(default_factory=dict)
|
||||||
|
last_consolidated: int = 0 # Number of messages already consolidated to files
|
||||||
|
|
||||||
def add_message(self, role: str, content: str, **kwargs: Any) -> None:
|
def add_message(self, role: str, content: str, **kwargs: Any) -> None:
|
||||||
"""Add a message to the session."""
|
"""Add a message to the session."""
|
||||||
@ -39,32 +44,36 @@ class Session:
|
|||||||
def get_history(self, max_messages: int = 50) -> list[dict[str, Any]]:
|
def get_history(self, max_messages: int = 50) -> list[dict[str, Any]]:
|
||||||
"""
|
"""
|
||||||
Get message history for LLM context.
|
Get message history for LLM context.
|
||||||
|
|
||||||
|
Messages are returned in append-only order for cache efficiency.
|
||||||
|
Only the most recent max_messages are returned, but the order
|
||||||
|
is always stable for the same max_messages value.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
max_messages: Maximum messages to return.
|
max_messages: Maximum messages to return (most recent).
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
List of messages in LLM format.
|
List of messages in LLM format (role and content only).
|
||||||
"""
|
"""
|
||||||
# Get recent messages
|
recent = self.messages[-max_messages:]
|
||||||
recent = self.messages[-max_messages:] if len(self.messages) > max_messages else self.messages
|
|
||||||
|
|
||||||
# Convert to LLM format (just role and content)
|
# Convert to LLM format (just role and content)
|
||||||
return [{"role": m["role"], "content": m["content"]} for m in recent]
|
return [{"role": m["role"], "content": m["content"]} for m in recent]
|
||||||
|
|
||||||
def clear(self) -> None:
|
def clear(self) -> None:
|
||||||
"""Clear all messages in the session."""
|
"""Clear all messages and reset session to initial state."""
|
||||||
self.messages = []
|
self.messages = []
|
||||||
|
self.last_consolidated = 0
|
||||||
self.updated_at = datetime.now()
|
self.updated_at = datetime.now()
|
||||||
|
|
||||||
|
|
||||||
class SessionManager:
|
class SessionManager:
|
||||||
"""
|
"""
|
||||||
Manages conversation sessions.
|
Manages conversation sessions.
|
||||||
|
|
||||||
Sessions are stored as JSONL files in the sessions directory.
|
Sessions are stored as JSONL files in the sessions directory.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, workspace: Path):
|
def __init__(self, workspace: Path):
|
||||||
self.workspace = workspace
|
self.workspace = workspace
|
||||||
self.sessions_dir = ensure_dir(Path.home() / ".nanobot" / "sessions")
|
self.sessions_dir = ensure_dir(Path.home() / ".nanobot" / "sessions")
|
||||||
@ -100,34 +109,37 @@ class SessionManager:
|
|||||||
def _load(self, key: str) -> Session | None:
|
def _load(self, key: str) -> Session | None:
|
||||||
"""Load a session from disk."""
|
"""Load a session from disk."""
|
||||||
path = self._get_session_path(key)
|
path = self._get_session_path(key)
|
||||||
|
|
||||||
if not path.exists():
|
if not path.exists():
|
||||||
return None
|
return None
|
||||||
|
|
||||||
try:
|
try:
|
||||||
messages = []
|
messages = []
|
||||||
metadata = {}
|
metadata = {}
|
||||||
created_at = None
|
created_at = None
|
||||||
|
last_consolidated = 0
|
||||||
|
|
||||||
with open(path) as f:
|
with open(path) as f:
|
||||||
for line in f:
|
for line in f:
|
||||||
line = line.strip()
|
line = line.strip()
|
||||||
if not line:
|
if not line:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
data = json.loads(line)
|
data = json.loads(line)
|
||||||
|
|
||||||
if data.get("_type") == "metadata":
|
if data.get("_type") == "metadata":
|
||||||
metadata = data.get("metadata", {})
|
metadata = data.get("metadata", {})
|
||||||
created_at = datetime.fromisoformat(data["created_at"]) if data.get("created_at") else None
|
created_at = datetime.fromisoformat(data["created_at"]) if data.get("created_at") else None
|
||||||
|
last_consolidated = data.get("last_consolidated", 0)
|
||||||
else:
|
else:
|
||||||
messages.append(data)
|
messages.append(data)
|
||||||
|
|
||||||
return Session(
|
return Session(
|
||||||
key=key,
|
key=key,
|
||||||
messages=messages,
|
messages=messages,
|
||||||
created_at=created_at or datetime.now(),
|
created_at=created_at or datetime.now(),
|
||||||
metadata=metadata
|
metadata=metadata,
|
||||||
|
last_consolidated=last_consolidated
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"Failed to load session {key}: {e}")
|
logger.warning(f"Failed to load session {key}: {e}")
|
||||||
@ -136,43 +148,24 @@ class SessionManager:
|
|||||||
def save(self, session: Session) -> None:
|
def save(self, session: Session) -> None:
|
||||||
"""Save a session to disk."""
|
"""Save a session to disk."""
|
||||||
path = self._get_session_path(session.key)
|
path = self._get_session_path(session.key)
|
||||||
|
|
||||||
with open(path, "w") as f:
|
with open(path, "w") as f:
|
||||||
# Write metadata first
|
# Write metadata first
|
||||||
metadata_line = {
|
metadata_line = {
|
||||||
"_type": "metadata",
|
"_type": "metadata",
|
||||||
"created_at": session.created_at.isoformat(),
|
"created_at": session.created_at.isoformat(),
|
||||||
"updated_at": session.updated_at.isoformat(),
|
"updated_at": session.updated_at.isoformat(),
|
||||||
"metadata": session.metadata
|
"metadata": session.metadata,
|
||||||
|
"last_consolidated": session.last_consolidated
|
||||||
}
|
}
|
||||||
f.write(json.dumps(metadata_line) + "\n")
|
f.write(json.dumps(metadata_line) + "\n")
|
||||||
|
|
||||||
# Write messages
|
# Write messages
|
||||||
for msg in session.messages:
|
for msg in session.messages:
|
||||||
f.write(json.dumps(msg) + "\n")
|
f.write(json.dumps(msg) + "\n")
|
||||||
|
|
||||||
self._cache[session.key] = session
|
self._cache[session.key] = session
|
||||||
|
|
||||||
def delete(self, key: str) -> bool:
|
|
||||||
"""
|
|
||||||
Delete a session.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
key: Session key.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
True if deleted, False if not found.
|
|
||||||
"""
|
|
||||||
# Remove from cache
|
|
||||||
self._cache.pop(key, None)
|
|
||||||
|
|
||||||
# Remove file
|
|
||||||
path = self._get_session_path(key)
|
|
||||||
if path.exists():
|
|
||||||
path.unlink()
|
|
||||||
return True
|
|
||||||
return False
|
|
||||||
|
|
||||||
def list_sessions(self) -> list[dict[str, Any]]:
|
def list_sessions(self) -> list[dict[str, Any]]:
|
||||||
"""
|
"""
|
||||||
List all sessions.
|
List all sessions.
|
||||||
|
|||||||
@ -12,7 +12,8 @@ def mock_prompt_session():
|
|||||||
"""Mock the global prompt session."""
|
"""Mock the global prompt session."""
|
||||||
mock_session = MagicMock()
|
mock_session = MagicMock()
|
||||||
mock_session.prompt_async = AsyncMock()
|
mock_session.prompt_async = AsyncMock()
|
||||||
with patch("nanobot.cli.commands._PROMPT_SESSION", mock_session):
|
with patch("nanobot.cli.commands._PROMPT_SESSION", mock_session), \
|
||||||
|
patch("nanobot.cli.commands.patch_stdout"):
|
||||||
yield mock_session
|
yield mock_session
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
159
tests/test_consolidate_offset.py
Normal file
159
tests/test_consolidate_offset.py
Normal file
@ -0,0 +1,159 @@
|
|||||||
|
"""Test session management with cache-friendly message handling."""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Callable
|
||||||
|
from nanobot.session.manager import Session, SessionManager
|
||||||
|
|
||||||
|
|
||||||
|
class TestSessionLastConsolidated:
|
||||||
|
"""Test last_consolidated tracking to avoid duplicate processing."""
|
||||||
|
|
||||||
|
def test_initial_last_consolidated_zero(self) -> None:
|
||||||
|
"""Test that new session starts with last_consolidated=0."""
|
||||||
|
session = Session(key="test:initial")
|
||||||
|
assert session.last_consolidated == 0
|
||||||
|
|
||||||
|
def test_last_consolidated_persistence(self, tmp_path) -> None:
|
||||||
|
"""Test that last_consolidated persists across save/load."""
|
||||||
|
manager = SessionManager(Path(tmp_path))
|
||||||
|
|
||||||
|
session1 = Session(key="test:persist")
|
||||||
|
for i in range(20):
|
||||||
|
session1.add_message("user", f"msg{i}")
|
||||||
|
session1.last_consolidated = 15 # Simulate consolidation
|
||||||
|
manager.save(session1)
|
||||||
|
|
||||||
|
session2 = manager.get_or_create("test:persist")
|
||||||
|
assert session2.last_consolidated == 15
|
||||||
|
assert len(session2.messages) == 20
|
||||||
|
|
||||||
|
def test_clear_resets_last_consolidated(self) -> None:
|
||||||
|
"""Test that clear() resets last_consolidated to 0."""
|
||||||
|
session = Session(key="test:clear")
|
||||||
|
for i in range(10):
|
||||||
|
session.add_message("user", f"msg{i}")
|
||||||
|
session.last_consolidated = 5
|
||||||
|
|
||||||
|
session.clear()
|
||||||
|
assert len(session.messages) == 0
|
||||||
|
assert session.last_consolidated == 0
|
||||||
|
|
||||||
|
|
||||||
|
class TestSessionImmutableHistory:
|
||||||
|
"""Test Session message immutability for cache efficiency."""
|
||||||
|
|
||||||
|
def test_initial_state(self) -> None:
|
||||||
|
"""Test that new session has empty messages list."""
|
||||||
|
session = Session(key="test:initial")
|
||||||
|
assert len(session.messages) == 0
|
||||||
|
|
||||||
|
def test_add_messages_appends_only(self) -> None:
|
||||||
|
"""Test that adding messages only appends, never modifies."""
|
||||||
|
session = Session(key="test:preserve")
|
||||||
|
session.add_message("user", "msg1")
|
||||||
|
session.add_message("assistant", "resp1")
|
||||||
|
session.add_message("user", "msg2")
|
||||||
|
assert len(session.messages) == 3
|
||||||
|
# First message should always be the first message added
|
||||||
|
assert session.messages[0]["content"] == "msg1"
|
||||||
|
|
||||||
|
def test_get_history_returns_most_recent(self) -> None:
|
||||||
|
"""Test get_history returns the most recent messages."""
|
||||||
|
session = Session(key="test:history")
|
||||||
|
for i in range(10):
|
||||||
|
session.add_message("user", f"msg{i}")
|
||||||
|
session.add_message("assistant", f"resp{i}")
|
||||||
|
history = session.get_history(max_messages=6)
|
||||||
|
# Should return last 6 messages
|
||||||
|
assert len(history) == 6
|
||||||
|
# First returned should be resp4 (messages 7-12: msg7/resp7, msg8/resp8, msg9/resp9)
|
||||||
|
# Actually: 20 messages total, last 6 are indices 14-19
|
||||||
|
assert history[0]["content"] == "msg7" # Index 14 (7th user msg after 7 pairs)
|
||||||
|
assert history[-1]["content"] == "resp9" # Index 19 (last assistant msg)
|
||||||
|
|
||||||
|
def test_get_history_with_all_messages(self) -> None:
|
||||||
|
"""Test get_history with max_messages larger than actual."""
|
||||||
|
session = Session(key="test:all")
|
||||||
|
for i in range(5):
|
||||||
|
session.add_message("user", f"msg{i}")
|
||||||
|
history = session.get_history(max_messages=100)
|
||||||
|
assert len(history) == 5
|
||||||
|
assert history[0]["content"] == "msg0"
|
||||||
|
|
||||||
|
def test_get_history_stable_for_same_session(self) -> None:
|
||||||
|
"""Test that get_history returns same content for same max_messages."""
|
||||||
|
session = Session(key="test:stable")
|
||||||
|
for i in range(20):
|
||||||
|
session.add_message("user", f"msg{i}")
|
||||||
|
|
||||||
|
# Multiple calls with same max_messages should return identical content
|
||||||
|
history1 = session.get_history(max_messages=10)
|
||||||
|
history2 = session.get_history(max_messages=10)
|
||||||
|
assert history1 == history2
|
||||||
|
|
||||||
|
def test_messages_list_never_modified(self) -> None:
|
||||||
|
"""Test that messages list is never modified after creation."""
|
||||||
|
session = Session(key="test:immutable")
|
||||||
|
original_len = 0
|
||||||
|
|
||||||
|
# Add some messages
|
||||||
|
for i in range(5):
|
||||||
|
session.add_message("user", f"msg{i}")
|
||||||
|
original_len += 1
|
||||||
|
|
||||||
|
assert len(session.messages) == original_len
|
||||||
|
|
||||||
|
# get_history should not modify the list
|
||||||
|
session.get_history(max_messages=2)
|
||||||
|
assert len(session.messages) == original_len
|
||||||
|
|
||||||
|
# Multiple calls should not affect messages
|
||||||
|
for _ in range(10):
|
||||||
|
session.get_history(max_messages=3)
|
||||||
|
assert len(session.messages) == original_len
|
||||||
|
|
||||||
|
|
||||||
|
class TestSessionPersistence:
|
||||||
|
"""Test Session persistence and reload."""
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def temp_manager(self, tmp_path):
|
||||||
|
return SessionManager(Path(tmp_path))
|
||||||
|
|
||||||
|
def test_persistence_roundtrip(self, temp_manager):
|
||||||
|
"""Test that messages persist across save/load."""
|
||||||
|
session1 = Session(key="test:persistence")
|
||||||
|
for i in range(20):
|
||||||
|
session1.add_message("user", f"msg{i}")
|
||||||
|
temp_manager.save(session1)
|
||||||
|
|
||||||
|
session2 = temp_manager.get_or_create("test:persistence")
|
||||||
|
assert len(session2.messages) == 20
|
||||||
|
assert session2.messages[0]["content"] == "msg0"
|
||||||
|
assert session2.messages[-1]["content"] == "msg19"
|
||||||
|
|
||||||
|
def test_get_history_after_reload(self, temp_manager):
|
||||||
|
"""Test that get_history works correctly after reload."""
|
||||||
|
session1 = Session(key="test:reload")
|
||||||
|
for i in range(30):
|
||||||
|
session1.add_message("user", f"msg{i}")
|
||||||
|
temp_manager.save(session1)
|
||||||
|
|
||||||
|
session2 = temp_manager.get_or_create("test:reload")
|
||||||
|
history = session2.get_history(max_messages=10)
|
||||||
|
# Should return last 10 messages (indices 20-29)
|
||||||
|
assert len(history) == 10
|
||||||
|
assert history[0]["content"] == "msg20"
|
||||||
|
assert history[-1]["content"] == "msg29"
|
||||||
|
|
||||||
|
def test_clear_resets_session(self, temp_manager):
|
||||||
|
"""Test that clear() properly resets session."""
|
||||||
|
session = Session(key="test:clear")
|
||||||
|
for i in range(10):
|
||||||
|
session.add_message("user", f"msg{i}")
|
||||||
|
assert len(session.messages) == 10
|
||||||
|
|
||||||
|
session.clear()
|
||||||
|
assert len(session.messages) == 0
|
||||||
|
|
||||||
Loading…
x
Reference in New Issue
Block a user