fix: history messages should not be change[kvcache]
This commit is contained in:
parent
32c9431191
commit
740294fd74
1
.gitignore
vendored
1
.gitignore
vendored
@ -17,5 +17,4 @@ docs/
|
||||
__pycache__/
|
||||
poetry.lock
|
||||
.pytest_cache/
|
||||
tests/
|
||||
botpy.log
|
||||
|
||||
@ -2,6 +2,7 @@
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
@ -26,7 +27,7 @@ from nanobot.session.manager import SessionManager
|
||||
class AgentLoop:
|
||||
"""
|
||||
The agent loop is the core processing engine.
|
||||
|
||||
|
||||
It:
|
||||
1. Receives messages from the bus
|
||||
2. Builds context with history, memory, skills
|
||||
@ -34,7 +35,7 @@ class AgentLoop:
|
||||
4. Executes tool calls
|
||||
5. Sends responses back
|
||||
"""
|
||||
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
bus: MessageBus,
|
||||
@ -61,8 +62,10 @@ class AgentLoop:
|
||||
self.exec_config = exec_config or ExecToolConfig()
|
||||
self.cron_service = cron_service
|
||||
self.restrict_to_workspace = restrict_to_workspace
|
||||
|
||||
|
||||
self.context = ContextBuilder(workspace)
|
||||
|
||||
# Initialize session manager
|
||||
self.sessions = session_manager or SessionManager(workspace)
|
||||
self.tools = ToolRegistry()
|
||||
self.subagents = SubagentManager(
|
||||
@ -110,11 +113,81 @@ class AgentLoop:
|
||||
if self.cron_service:
|
||||
self.tools.register(CronTool(self.cron_service))
|
||||
|
||||
def _set_tool_context(self, channel: str, chat_id: str) -> None:
|
||||
"""Update context for all tools that need routing info."""
|
||||
if message_tool := self.tools.get("message"):
|
||||
if isinstance(message_tool, MessageTool):
|
||||
message_tool.set_context(channel, chat_id)
|
||||
|
||||
if spawn_tool := self.tools.get("spawn"):
|
||||
if isinstance(spawn_tool, SpawnTool):
|
||||
spawn_tool.set_context(channel, chat_id)
|
||||
|
||||
if cron_tool := self.tools.get("cron"):
|
||||
if isinstance(cron_tool, CronTool):
|
||||
cron_tool.set_context(channel, chat_id)
|
||||
|
||||
async def _run_agent_loop(self, initial_messages: list[dict]) -> tuple[str | None, list[str]]:
|
||||
"""
|
||||
Run the agent iteration loop.
|
||||
|
||||
Args:
|
||||
initial_messages: Starting messages for the LLM conversation.
|
||||
|
||||
Returns:
|
||||
Tuple of (final_content, list_of_tools_used).
|
||||
"""
|
||||
messages = initial_messages
|
||||
iteration = 0
|
||||
final_content = None
|
||||
tools_used: list[str] = []
|
||||
|
||||
while iteration < self.max_iterations:
|
||||
iteration += 1
|
||||
|
||||
response = await self.provider.chat(
|
||||
messages=messages,
|
||||
tools=self.tools.get_definitions(),
|
||||
model=self.model
|
||||
)
|
||||
|
||||
if response.has_tool_calls:
|
||||
tool_call_dicts = [
|
||||
{
|
||||
"id": tc.id,
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": tc.name,
|
||||
"arguments": json.dumps(tc.arguments)
|
||||
}
|
||||
}
|
||||
for tc in response.tool_calls
|
||||
]
|
||||
messages = self.context.add_assistant_message(
|
||||
messages, response.content, tool_call_dicts,
|
||||
reasoning_content=response.reasoning_content,
|
||||
)
|
||||
|
||||
for tool_call in response.tool_calls:
|
||||
tools_used.append(tool_call.name)
|
||||
args_str = json.dumps(tool_call.arguments, ensure_ascii=False)
|
||||
logger.info(f"Tool call: {tool_call.name}({args_str[:200]})")
|
||||
result = await self.tools.execute(tool_call.name, tool_call.arguments)
|
||||
messages = self.context.add_tool_result(
|
||||
messages, tool_call.id, tool_call.name, result
|
||||
)
|
||||
messages.append({"role": "user", "content": "Reflect on the results and decide next steps."})
|
||||
else:
|
||||
final_content = response.content
|
||||
break
|
||||
|
||||
return final_content, tools_used
|
||||
|
||||
async def run(self) -> None:
|
||||
"""Run the agent loop, processing messages from the bus."""
|
||||
self._running = True
|
||||
logger.info("Agent loop started")
|
||||
|
||||
|
||||
while self._running:
|
||||
try:
|
||||
# Wait for next message
|
||||
@ -173,8 +246,10 @@ class AgentLoop:
|
||||
await self._consolidate_memory(session, archive_all=True)
|
||||
session.clear()
|
||||
self.sessions.save(session)
|
||||
# Clear cache to force reload from disk on next request
|
||||
self.sessions._cache.pop(session.key, None)
|
||||
return OutboundMessage(channel=msg.channel, chat_id=msg.chat_id,
|
||||
content="🐈 New session started. Memory consolidated.")
|
||||
content="New session started. Memory consolidated.")
|
||||
if cmd == "/help":
|
||||
return OutboundMessage(channel=msg.channel, chat_id=msg.chat_id,
|
||||
content="🐈 nanobot commands:\n/new — Start a new conversation\n/help — Show available commands")
|
||||
@ -182,79 +257,22 @@ class AgentLoop:
|
||||
# Consolidate memory before processing if session is too large
|
||||
if len(session.messages) > self.memory_window:
|
||||
await self._consolidate_memory(session)
|
||||
|
||||
|
||||
# Update tool contexts
|
||||
message_tool = self.tools.get("message")
|
||||
if isinstance(message_tool, MessageTool):
|
||||
message_tool.set_context(msg.channel, msg.chat_id)
|
||||
|
||||
spawn_tool = self.tools.get("spawn")
|
||||
if isinstance(spawn_tool, SpawnTool):
|
||||
spawn_tool.set_context(msg.channel, msg.chat_id)
|
||||
|
||||
cron_tool = self.tools.get("cron")
|
||||
if isinstance(cron_tool, CronTool):
|
||||
cron_tool.set_context(msg.channel, msg.chat_id)
|
||||
|
||||
# Build initial messages (use get_history for LLM-formatted messages)
|
||||
messages = self.context.build_messages(
|
||||
self._set_tool_context(msg.channel, msg.chat_id)
|
||||
|
||||
# Build initial messages
|
||||
initial_messages = self.context.build_messages(
|
||||
history=session.get_history(),
|
||||
current_message=msg.content,
|
||||
media=msg.media if msg.media else None,
|
||||
channel=msg.channel,
|
||||
chat_id=msg.chat_id,
|
||||
)
|
||||
|
||||
# Agent loop
|
||||
iteration = 0
|
||||
final_content = None
|
||||
tools_used: list[str] = []
|
||||
|
||||
while iteration < self.max_iterations:
|
||||
iteration += 1
|
||||
|
||||
# Call LLM
|
||||
response = await self.provider.chat(
|
||||
messages=messages,
|
||||
tools=self.tools.get_definitions(),
|
||||
model=self.model
|
||||
)
|
||||
|
||||
# Handle tool calls
|
||||
if response.has_tool_calls:
|
||||
# Add assistant message with tool calls
|
||||
tool_call_dicts = [
|
||||
{
|
||||
"id": tc.id,
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": tc.name,
|
||||
"arguments": json.dumps(tc.arguments) # Must be JSON string
|
||||
}
|
||||
}
|
||||
for tc in response.tool_calls
|
||||
]
|
||||
messages = self.context.add_assistant_message(
|
||||
messages, response.content, tool_call_dicts,
|
||||
reasoning_content=response.reasoning_content,
|
||||
)
|
||||
|
||||
# Execute tools
|
||||
for tool_call in response.tool_calls:
|
||||
tools_used.append(tool_call.name)
|
||||
args_str = json.dumps(tool_call.arguments, ensure_ascii=False)
|
||||
logger.info(f"Tool call: {tool_call.name}({args_str[:200]})")
|
||||
result = await self.tools.execute(tool_call.name, tool_call.arguments)
|
||||
messages = self.context.add_tool_result(
|
||||
messages, tool_call.id, tool_call.name, result
|
||||
)
|
||||
# Interleaved CoT: reflect before next action
|
||||
messages.append({"role": "user", "content": "Reflect on the results and decide next steps."})
|
||||
else:
|
||||
# No tool calls, we're done
|
||||
final_content = response.content
|
||||
break
|
||||
|
||||
|
||||
# Run agent loop
|
||||
final_content, tools_used = await self._run_agent_loop(initial_messages)
|
||||
|
||||
if final_content is None:
|
||||
final_content = "I've completed processing but have no response to give."
|
||||
|
||||
@ -297,71 +315,21 @@ class AgentLoop:
|
||||
# Use the origin session for context
|
||||
session_key = f"{origin_channel}:{origin_chat_id}"
|
||||
session = self.sessions.get_or_create(session_key)
|
||||
|
||||
|
||||
# Update tool contexts
|
||||
message_tool = self.tools.get("message")
|
||||
if isinstance(message_tool, MessageTool):
|
||||
message_tool.set_context(origin_channel, origin_chat_id)
|
||||
|
||||
spawn_tool = self.tools.get("spawn")
|
||||
if isinstance(spawn_tool, SpawnTool):
|
||||
spawn_tool.set_context(origin_channel, origin_chat_id)
|
||||
|
||||
cron_tool = self.tools.get("cron")
|
||||
if isinstance(cron_tool, CronTool):
|
||||
cron_tool.set_context(origin_channel, origin_chat_id)
|
||||
|
||||
self._set_tool_context(origin_channel, origin_chat_id)
|
||||
|
||||
# Build messages with the announce content
|
||||
messages = self.context.build_messages(
|
||||
initial_messages = self.context.build_messages(
|
||||
history=session.get_history(),
|
||||
current_message=msg.content,
|
||||
channel=origin_channel,
|
||||
chat_id=origin_chat_id,
|
||||
)
|
||||
|
||||
# Agent loop (limited for announce handling)
|
||||
iteration = 0
|
||||
final_content = None
|
||||
|
||||
while iteration < self.max_iterations:
|
||||
iteration += 1
|
||||
|
||||
response = await self.provider.chat(
|
||||
messages=messages,
|
||||
tools=self.tools.get_definitions(),
|
||||
model=self.model
|
||||
)
|
||||
|
||||
if response.has_tool_calls:
|
||||
tool_call_dicts = [
|
||||
{
|
||||
"id": tc.id,
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": tc.name,
|
||||
"arguments": json.dumps(tc.arguments)
|
||||
}
|
||||
}
|
||||
for tc in response.tool_calls
|
||||
]
|
||||
messages = self.context.add_assistant_message(
|
||||
messages, response.content, tool_call_dicts,
|
||||
reasoning_content=response.reasoning_content,
|
||||
)
|
||||
|
||||
for tool_call in response.tool_calls:
|
||||
args_str = json.dumps(tool_call.arguments, ensure_ascii=False)
|
||||
logger.info(f"Tool call: {tool_call.name}({args_str[:200]})")
|
||||
result = await self.tools.execute(tool_call.name, tool_call.arguments)
|
||||
messages = self.context.add_tool_result(
|
||||
messages, tool_call.id, tool_call.name, result
|
||||
)
|
||||
# Interleaved CoT: reflect before next action
|
||||
messages.append({"role": "user", "content": "Reflect on the results and decide next steps."})
|
||||
else:
|
||||
final_content = response.content
|
||||
break
|
||||
|
||||
|
||||
# Run agent loop
|
||||
final_content, _ = await self._run_agent_loop(initial_messages)
|
||||
|
||||
if final_content is None:
|
||||
final_content = "Background task completed."
|
||||
|
||||
@ -377,19 +345,39 @@ class AgentLoop:
|
||||
)
|
||||
|
||||
async def _consolidate_memory(self, session, archive_all: bool = False) -> None:
|
||||
"""Consolidate old messages into MEMORY.md + HISTORY.md, then trim session."""
|
||||
if not session.messages:
|
||||
return
|
||||
"""Consolidate old messages into MEMORY.md + HISTORY.md.
|
||||
|
||||
Args:
|
||||
archive_all: If True, clear all messages and reset session (for /new command).
|
||||
If False, only write to files without modifying session.
|
||||
"""
|
||||
memory = MemoryStore(self.workspace)
|
||||
|
||||
# Handle /new command: clear session and consolidate everything
|
||||
if archive_all:
|
||||
old_messages = session.messages
|
||||
keep_count = 0
|
||||
old_messages = session.messages # All messages
|
||||
keep_count = 0 # Clear everything
|
||||
logger.info(f"Memory consolidation (archive_all): {len(session.messages)} total messages archived")
|
||||
else:
|
||||
keep_count = min(10, max(2, self.memory_window // 2))
|
||||
old_messages = session.messages[:-keep_count]
|
||||
if not old_messages:
|
||||
return
|
||||
logger.info(f"Memory consolidation started: {len(session.messages)} messages, archiving {len(old_messages)}, keeping {keep_count}")
|
||||
# Normal consolidation: only write files, keep session intact
|
||||
keep_count = self.memory_window // 2
|
||||
|
||||
# Check if consolidation is needed
|
||||
if len(session.messages) <= keep_count:
|
||||
logger.debug(f"Session {session.key}: No consolidation needed (messages={len(session.messages)}, keep={keep_count})")
|
||||
return
|
||||
|
||||
# Use last_consolidated to avoid re-processing messages
|
||||
messages_to_process = len(session.messages) - session.last_consolidated
|
||||
if messages_to_process <= 0:
|
||||
logger.debug(f"Session {session.key}: No new messages to consolidate (last_consolidated={session.last_consolidated}, total={len(session.messages)})")
|
||||
return
|
||||
|
||||
# Get messages to consolidate (from last_consolidated to keep_count from end)
|
||||
old_messages = session.messages[session.last_consolidated:-keep_count]
|
||||
if not old_messages:
|
||||
return
|
||||
logger.info(f"Memory consolidation started: {len(session.messages)} total, {len(old_messages)} new to consolidate, {keep_count} keep")
|
||||
|
||||
# Format messages for LLM (include tool names when available)
|
||||
lines = []
|
||||
@ -434,9 +422,18 @@ Respond with ONLY valid JSON, no markdown fences."""
|
||||
if update != current_memory:
|
||||
memory.write_long_term(update)
|
||||
|
||||
session.messages = session.messages[-keep_count:] if keep_count else []
|
||||
self.sessions.save(session)
|
||||
logger.info(f"Memory consolidation done, session trimmed to {len(session.messages)} messages")
|
||||
# Update last_consolidated to track what's been processed
|
||||
if archive_all:
|
||||
# /new command: reset to 0 after clearing
|
||||
session.last_consolidated = 0
|
||||
else:
|
||||
# Normal: mark up to (total - keep_count) as consolidated
|
||||
session.last_consolidated = len(session.messages) - keep_count
|
||||
|
||||
# Key: We do NOT modify session.messages (append-only for cache)
|
||||
# The consolidation is only for human-readable files (MEMORY.md/HISTORY.md)
|
||||
# LLM cache remains intact because the messages list is unchanged
|
||||
logger.info(f"Memory consolidation done: {len(session.messages)} total messages (unchanged), last_consolidated={session.last_consolidated}")
|
||||
except Exception as e:
|
||||
logger.error(f"Memory consolidation failed: {e}")
|
||||
|
||||
|
||||
@ -15,15 +15,20 @@ from nanobot.utils.helpers import ensure_dir, safe_filename
|
||||
class Session:
|
||||
"""
|
||||
A conversation session.
|
||||
|
||||
|
||||
Stores messages in JSONL format for easy reading and persistence.
|
||||
|
||||
Important: Messages are append-only for LLM cache efficiency.
|
||||
The consolidation process writes summaries to MEMORY.md/HISTORY.md
|
||||
but does NOT modify the messages list or get_history() output.
|
||||
"""
|
||||
|
||||
|
||||
key: str # channel:chat_id
|
||||
messages: list[dict[str, Any]] = field(default_factory=list)
|
||||
created_at: datetime = field(default_factory=datetime.now)
|
||||
updated_at: datetime = field(default_factory=datetime.now)
|
||||
metadata: dict[str, Any] = field(default_factory=dict)
|
||||
last_consolidated: int = 0 # Number of messages already consolidated to files
|
||||
|
||||
def add_message(self, role: str, content: str, **kwargs: Any) -> None:
|
||||
"""Add a message to the session."""
|
||||
@ -39,32 +44,36 @@ class Session:
|
||||
def get_history(self, max_messages: int = 50) -> list[dict[str, Any]]:
|
||||
"""
|
||||
Get message history for LLM context.
|
||||
|
||||
|
||||
Messages are returned in append-only order for cache efficiency.
|
||||
Only the most recent max_messages are returned, but the order
|
||||
is always stable for the same max_messages value.
|
||||
|
||||
Args:
|
||||
max_messages: Maximum messages to return.
|
||||
|
||||
max_messages: Maximum messages to return (most recent).
|
||||
|
||||
Returns:
|
||||
List of messages in LLM format.
|
||||
List of messages in LLM format (role and content only).
|
||||
"""
|
||||
# Get recent messages
|
||||
recent = self.messages[-max_messages:] if len(self.messages) > max_messages else self.messages
|
||||
|
||||
recent = self.messages[-max_messages:]
|
||||
|
||||
# Convert to LLM format (just role and content)
|
||||
return [{"role": m["role"], "content": m["content"]} for m in recent]
|
||||
|
||||
def clear(self) -> None:
|
||||
"""Clear all messages in the session."""
|
||||
"""Clear all messages and reset session to initial state."""
|
||||
self.messages = []
|
||||
self.last_consolidated = 0
|
||||
self.updated_at = datetime.now()
|
||||
|
||||
|
||||
class SessionManager:
|
||||
"""
|
||||
Manages conversation sessions.
|
||||
|
||||
|
||||
Sessions are stored as JSONL files in the sessions directory.
|
||||
"""
|
||||
|
||||
|
||||
def __init__(self, workspace: Path):
|
||||
self.workspace = workspace
|
||||
self.sessions_dir = ensure_dir(Path.home() / ".nanobot" / "sessions")
|
||||
@ -100,34 +109,37 @@ class SessionManager:
|
||||
def _load(self, key: str) -> Session | None:
|
||||
"""Load a session from disk."""
|
||||
path = self._get_session_path(key)
|
||||
|
||||
|
||||
if not path.exists():
|
||||
return None
|
||||
|
||||
|
||||
try:
|
||||
messages = []
|
||||
metadata = {}
|
||||
created_at = None
|
||||
|
||||
last_consolidated = 0
|
||||
|
||||
with open(path) as f:
|
||||
for line in f:
|
||||
line = line.strip()
|
||||
if not line:
|
||||
continue
|
||||
|
||||
|
||||
data = json.loads(line)
|
||||
|
||||
|
||||
if data.get("_type") == "metadata":
|
||||
metadata = data.get("metadata", {})
|
||||
created_at = datetime.fromisoformat(data["created_at"]) if data.get("created_at") else None
|
||||
last_consolidated = data.get("last_consolidated", 0)
|
||||
else:
|
||||
messages.append(data)
|
||||
|
||||
|
||||
return Session(
|
||||
key=key,
|
||||
messages=messages,
|
||||
created_at=created_at or datetime.now(),
|
||||
metadata=metadata
|
||||
metadata=metadata,
|
||||
last_consolidated=last_consolidated
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to load session {key}: {e}")
|
||||
@ -136,43 +148,24 @@ class SessionManager:
|
||||
def save(self, session: Session) -> None:
|
||||
"""Save a session to disk."""
|
||||
path = self._get_session_path(session.key)
|
||||
|
||||
|
||||
with open(path, "w") as f:
|
||||
# Write metadata first
|
||||
metadata_line = {
|
||||
"_type": "metadata",
|
||||
"created_at": session.created_at.isoformat(),
|
||||
"updated_at": session.updated_at.isoformat(),
|
||||
"metadata": session.metadata
|
||||
"metadata": session.metadata,
|
||||
"last_consolidated": session.last_consolidated
|
||||
}
|
||||
f.write(json.dumps(metadata_line) + "\n")
|
||||
|
||||
|
||||
# Write messages
|
||||
for msg in session.messages:
|
||||
f.write(json.dumps(msg) + "\n")
|
||||
|
||||
|
||||
self._cache[session.key] = session
|
||||
|
||||
def delete(self, key: str) -> bool:
|
||||
"""
|
||||
Delete a session.
|
||||
|
||||
Args:
|
||||
key: Session key.
|
||||
|
||||
Returns:
|
||||
True if deleted, False if not found.
|
||||
"""
|
||||
# Remove from cache
|
||||
self._cache.pop(key, None)
|
||||
|
||||
# Remove file
|
||||
path = self._get_session_path(key)
|
||||
if path.exists():
|
||||
path.unlink()
|
||||
return True
|
||||
return False
|
||||
|
||||
def list_sessions(self) -> list[dict[str, Any]]:
|
||||
"""
|
||||
List all sessions.
|
||||
|
||||
@ -12,7 +12,8 @@ def mock_prompt_session():
|
||||
"""Mock the global prompt session."""
|
||||
mock_session = MagicMock()
|
||||
mock_session.prompt_async = AsyncMock()
|
||||
with patch("nanobot.cli.commands._PROMPT_SESSION", mock_session):
|
||||
with patch("nanobot.cli.commands._PROMPT_SESSION", mock_session), \
|
||||
patch("nanobot.cli.commands.patch_stdout"):
|
||||
yield mock_session
|
||||
|
||||
|
||||
|
||||
159
tests/test_consolidate_offset.py
Normal file
159
tests/test_consolidate_offset.py
Normal file
@ -0,0 +1,159 @@
|
||||
"""Test session management with cache-friendly message handling."""
|
||||
|
||||
import pytest
|
||||
from pathlib import Path
|
||||
from typing import Callable
|
||||
from nanobot.session.manager import Session, SessionManager
|
||||
|
||||
|
||||
class TestSessionLastConsolidated:
|
||||
"""Test last_consolidated tracking to avoid duplicate processing."""
|
||||
|
||||
def test_initial_last_consolidated_zero(self) -> None:
|
||||
"""Test that new session starts with last_consolidated=0."""
|
||||
session = Session(key="test:initial")
|
||||
assert session.last_consolidated == 0
|
||||
|
||||
def test_last_consolidated_persistence(self, tmp_path) -> None:
|
||||
"""Test that last_consolidated persists across save/load."""
|
||||
manager = SessionManager(Path(tmp_path))
|
||||
|
||||
session1 = Session(key="test:persist")
|
||||
for i in range(20):
|
||||
session1.add_message("user", f"msg{i}")
|
||||
session1.last_consolidated = 15 # Simulate consolidation
|
||||
manager.save(session1)
|
||||
|
||||
session2 = manager.get_or_create("test:persist")
|
||||
assert session2.last_consolidated == 15
|
||||
assert len(session2.messages) == 20
|
||||
|
||||
def test_clear_resets_last_consolidated(self) -> None:
|
||||
"""Test that clear() resets last_consolidated to 0."""
|
||||
session = Session(key="test:clear")
|
||||
for i in range(10):
|
||||
session.add_message("user", f"msg{i}")
|
||||
session.last_consolidated = 5
|
||||
|
||||
session.clear()
|
||||
assert len(session.messages) == 0
|
||||
assert session.last_consolidated == 0
|
||||
|
||||
|
||||
class TestSessionImmutableHistory:
|
||||
"""Test Session message immutability for cache efficiency."""
|
||||
|
||||
def test_initial_state(self) -> None:
|
||||
"""Test that new session has empty messages list."""
|
||||
session = Session(key="test:initial")
|
||||
assert len(session.messages) == 0
|
||||
|
||||
def test_add_messages_appends_only(self) -> None:
|
||||
"""Test that adding messages only appends, never modifies."""
|
||||
session = Session(key="test:preserve")
|
||||
session.add_message("user", "msg1")
|
||||
session.add_message("assistant", "resp1")
|
||||
session.add_message("user", "msg2")
|
||||
assert len(session.messages) == 3
|
||||
# First message should always be the first message added
|
||||
assert session.messages[0]["content"] == "msg1"
|
||||
|
||||
def test_get_history_returns_most_recent(self) -> None:
|
||||
"""Test get_history returns the most recent messages."""
|
||||
session = Session(key="test:history")
|
||||
for i in range(10):
|
||||
session.add_message("user", f"msg{i}")
|
||||
session.add_message("assistant", f"resp{i}")
|
||||
history = session.get_history(max_messages=6)
|
||||
# Should return last 6 messages
|
||||
assert len(history) == 6
|
||||
# First returned should be resp4 (messages 7-12: msg7/resp7, msg8/resp8, msg9/resp9)
|
||||
# Actually: 20 messages total, last 6 are indices 14-19
|
||||
assert history[0]["content"] == "msg7" # Index 14 (7th user msg after 7 pairs)
|
||||
assert history[-1]["content"] == "resp9" # Index 19 (last assistant msg)
|
||||
|
||||
def test_get_history_with_all_messages(self) -> None:
|
||||
"""Test get_history with max_messages larger than actual."""
|
||||
session = Session(key="test:all")
|
||||
for i in range(5):
|
||||
session.add_message("user", f"msg{i}")
|
||||
history = session.get_history(max_messages=100)
|
||||
assert len(history) == 5
|
||||
assert history[0]["content"] == "msg0"
|
||||
|
||||
def test_get_history_stable_for_same_session(self) -> None:
|
||||
"""Test that get_history returns same content for same max_messages."""
|
||||
session = Session(key="test:stable")
|
||||
for i in range(20):
|
||||
session.add_message("user", f"msg{i}")
|
||||
|
||||
# Multiple calls with same max_messages should return identical content
|
||||
history1 = session.get_history(max_messages=10)
|
||||
history2 = session.get_history(max_messages=10)
|
||||
assert history1 == history2
|
||||
|
||||
def test_messages_list_never_modified(self) -> None:
|
||||
"""Test that messages list is never modified after creation."""
|
||||
session = Session(key="test:immutable")
|
||||
original_len = 0
|
||||
|
||||
# Add some messages
|
||||
for i in range(5):
|
||||
session.add_message("user", f"msg{i}")
|
||||
original_len += 1
|
||||
|
||||
assert len(session.messages) == original_len
|
||||
|
||||
# get_history should not modify the list
|
||||
session.get_history(max_messages=2)
|
||||
assert len(session.messages) == original_len
|
||||
|
||||
# Multiple calls should not affect messages
|
||||
for _ in range(10):
|
||||
session.get_history(max_messages=3)
|
||||
assert len(session.messages) == original_len
|
||||
|
||||
|
||||
class TestSessionPersistence:
|
||||
"""Test Session persistence and reload."""
|
||||
|
||||
@pytest.fixture
|
||||
def temp_manager(self, tmp_path):
|
||||
return SessionManager(Path(tmp_path))
|
||||
|
||||
def test_persistence_roundtrip(self, temp_manager):
|
||||
"""Test that messages persist across save/load."""
|
||||
session1 = Session(key="test:persistence")
|
||||
for i in range(20):
|
||||
session1.add_message("user", f"msg{i}")
|
||||
temp_manager.save(session1)
|
||||
|
||||
session2 = temp_manager.get_or_create("test:persistence")
|
||||
assert len(session2.messages) == 20
|
||||
assert session2.messages[0]["content"] == "msg0"
|
||||
assert session2.messages[-1]["content"] == "msg19"
|
||||
|
||||
def test_get_history_after_reload(self, temp_manager):
|
||||
"""Test that get_history works correctly after reload."""
|
||||
session1 = Session(key="test:reload")
|
||||
for i in range(30):
|
||||
session1.add_message("user", f"msg{i}")
|
||||
temp_manager.save(session1)
|
||||
|
||||
session2 = temp_manager.get_or_create("test:reload")
|
||||
history = session2.get_history(max_messages=10)
|
||||
# Should return last 10 messages (indices 20-29)
|
||||
assert len(history) == 10
|
||||
assert history[0]["content"] == "msg20"
|
||||
assert history[-1]["content"] == "msg29"
|
||||
|
||||
def test_clear_resets_session(self, temp_manager):
|
||||
"""Test that clear() properly resets session."""
|
||||
session = Session(key="test:clear")
|
||||
for i in range(10):
|
||||
session.add_message("user", f"msg{i}")
|
||||
assert len(session.messages) == 10
|
||||
|
||||
session.clear()
|
||||
assert len(session.messages) == 0
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user