From afc8d5065962f053204aa19a649966ad07841313 Mon Sep 17 00:00:00 2001 From: chengyongru <2755839590@qq.com> Date: Fri, 13 Feb 2026 16:30:43 +0800 Subject: [PATCH] test: add comprehensive tests for consolidate offset functionality Add 26 new test cases covering: - Consolidation trigger conditions (exceed window, within keep count, no new messages) - last_consolidated edge cases (exceeds message count, negative value, new messages after consolidation) - archive_all mode (/new command behavior) - Cache immutability (messages list never modified during consolidation) - Slice logic (messages[last_consolidated:-keep_count]) - Empty and boundary sessions (empty, single message, exact keep count, very large) Refactor tests with helper functions to reduce code duplication by 25%: - create_session_with_messages() - creates session with specified message count - assert_messages_content() - validates message content range - get_old_messages() - encapsulates standard slice logic All 35 tests passing. --- tests/test_consolidate_offset.py | 406 +++++++++++++++++++++++++++---- 1 file changed, 362 insertions(+), 44 deletions(-) diff --git a/tests/test_consolidate_offset.py b/tests/test_consolidate_offset.py index dfac395..e204733 100644 --- a/tests/test_consolidate_offset.py +++ b/tests/test_consolidate_offset.py @@ -2,9 +2,56 @@ import pytest from pathlib import Path -from typing import Callable from nanobot.session.manager import Session, SessionManager +# Test constants +MEMORY_WINDOW = 50 +KEEP_COUNT = MEMORY_WINDOW // 2 # 25 + + +def create_session_with_messages(key: str, count: int, role: str = "user") -> Session: + """Create a session and add the specified number of messages. + + Args: + key: Session identifier + count: Number of messages to add + role: Message role (default: "user") + + Returns: + Session with the specified messages + """ + session = Session(key=key) + for i in range(count): + session.add_message(role, f"msg{i}") + return session + + +def assert_messages_content(messages: list, start_index: int, end_index: int) -> None: + """Assert that messages contain expected content from start to end index. + + Args: + messages: List of message dictionaries + start_index: Expected first message index + end_index: Expected last message index + """ + assert len(messages) > 0 + assert messages[0]["content"] == f"msg{start_index}" + assert messages[-1]["content"] == f"msg{end_index}" + + +def get_old_messages(session: Session, last_consolidated: int, keep_count: int) -> list: + """Extract messages that would be consolidated using the standard slice logic. + + Args: + session: The session containing messages + last_consolidated: Index of last consolidated message + keep_count: Number of recent messages to keep + + Returns: + List of messages that would be consolidated + """ + return session.messages[last_consolidated:-keep_count] + class TestSessionLastConsolidated: """Test last_consolidated tracking to avoid duplicate processing.""" @@ -17,11 +64,8 @@ class TestSessionLastConsolidated: def test_last_consolidated_persistence(self, tmp_path) -> None: """Test that last_consolidated persists across save/load.""" manager = SessionManager(Path(tmp_path)) - - session1 = Session(key="test:persist") - for i in range(20): - session1.add_message("user", f"msg{i}") - session1.last_consolidated = 15 # Simulate consolidation + session1 = create_session_with_messages("test:persist", 20) + session1.last_consolidated = 15 manager.save(session1) session2 = manager.get_or_create("test:persist") @@ -30,9 +74,7 @@ class TestSessionLastConsolidated: def test_clear_resets_last_consolidated(self) -> None: """Test that clear() resets last_consolidated to 0.""" - session = Session(key="test:clear") - for i in range(10): - session.add_message("user", f"msg{i}") + session = create_session_with_messages("test:clear", 10) session.last_consolidated = 5 session.clear() @@ -55,7 +97,6 @@ class TestSessionImmutableHistory: session.add_message("assistant", "resp1") session.add_message("user", "msg2") assert len(session.messages) == 3 - # First message should always be the first message added assert session.messages[0]["content"] == "msg1" def test_get_history_returns_most_recent(self) -> None: @@ -64,51 +105,34 @@ class TestSessionImmutableHistory: for i in range(10): session.add_message("user", f"msg{i}") session.add_message("assistant", f"resp{i}") + history = session.get_history(max_messages=6) - # Should return last 6 messages assert len(history) == 6 - # First returned should be resp4 (messages 7-12: msg7/resp7, msg8/resp8, msg9/resp9) - # Actually: 20 messages total, last 6 are indices 14-19 - assert history[0]["content"] == "msg7" # Index 14 (7th user msg after 7 pairs) - assert history[-1]["content"] == "resp9" # Index 19 (last assistant msg) + assert history[0]["content"] == "msg7" + assert history[-1]["content"] == "resp9" def test_get_history_with_all_messages(self) -> None: """Test get_history with max_messages larger than actual.""" - session = Session(key="test:all") - for i in range(5): - session.add_message("user", f"msg{i}") + session = create_session_with_messages("test:all", 5) history = session.get_history(max_messages=100) assert len(history) == 5 assert history[0]["content"] == "msg0" def test_get_history_stable_for_same_session(self) -> None: """Test that get_history returns same content for same max_messages.""" - session = Session(key="test:stable") - for i in range(20): - session.add_message("user", f"msg{i}") - - # Multiple calls with same max_messages should return identical content + session = create_session_with_messages("test:stable", 20) history1 = session.get_history(max_messages=10) history2 = session.get_history(max_messages=10) assert history1 == history2 def test_messages_list_never_modified(self) -> None: """Test that messages list is never modified after creation.""" - session = Session(key="test:immutable") - original_len = 0 + session = create_session_with_messages("test:immutable", 5) + original_len = len(session.messages) - # Add some messages - for i in range(5): - session.add_message("user", f"msg{i}") - original_len += 1 - - assert len(session.messages) == original_len - - # get_history should not modify the list session.get_history(max_messages=2) assert len(session.messages) == original_len - # Multiple calls should not affect messages for _ in range(10): session.get_history(max_messages=3) assert len(session.messages) == original_len @@ -123,9 +147,7 @@ class TestSessionPersistence: def test_persistence_roundtrip(self, temp_manager): """Test that messages persist across save/load.""" - session1 = Session(key="test:persistence") - for i in range(20): - session1.add_message("user", f"msg{i}") + session1 = create_session_with_messages("test:persistence", 20) temp_manager.save(session1) session2 = temp_manager.get_or_create("test:persistence") @@ -135,25 +157,321 @@ class TestSessionPersistence: def test_get_history_after_reload(self, temp_manager): """Test that get_history works correctly after reload.""" - session1 = Session(key="test:reload") - for i in range(30): - session1.add_message("user", f"msg{i}") + session1 = create_session_with_messages("test:reload", 30) temp_manager.save(session1) session2 = temp_manager.get_or_create("test:reload") history = session2.get_history(max_messages=10) - # Should return last 10 messages (indices 20-29) assert len(history) == 10 assert history[0]["content"] == "msg20" assert history[-1]["content"] == "msg29" def test_clear_resets_session(self, temp_manager): """Test that clear() properly resets session.""" - session = Session(key="test:clear") - for i in range(10): - session.add_message("user", f"msg{i}") + session = create_session_with_messages("test:clear", 10) assert len(session.messages) == 10 session.clear() assert len(session.messages) == 0 + +class TestConsolidationTriggerConditions: + """Test consolidation trigger conditions and logic.""" + + def test_consolidation_needed_when_messages_exceed_window(self): + """Test consolidation logic: should trigger when messages > memory_window.""" + session = create_session_with_messages("test:trigger", 60) + + total_messages = len(session.messages) + messages_to_process = total_messages - session.last_consolidated + + assert total_messages > MEMORY_WINDOW + assert messages_to_process > 0 + + expected_consolidate_count = total_messages - KEEP_COUNT + assert expected_consolidate_count == 35 + + def test_consolidation_skipped_when_within_keep_count(self): + """Test consolidation skipped when total messages <= keep_count.""" + session = create_session_with_messages("test:skip", 20) + + total_messages = len(session.messages) + assert total_messages <= KEEP_COUNT + + old_messages = get_old_messages(session, session.last_consolidated, KEEP_COUNT) + assert len(old_messages) == 0 + + def test_consolidation_skipped_when_no_new_messages(self): + """Test consolidation skipped when messages_to_process <= 0.""" + session = create_session_with_messages("test:already_consolidated", 40) + session.last_consolidated = len(session.messages) - KEEP_COUNT # 15 + + # Add a few more messages + for i in range(40, 42): + session.add_message("user", f"msg{i}") + + total_messages = len(session.messages) + messages_to_process = total_messages - session.last_consolidated + assert messages_to_process > 0 + + # Simulate last_consolidated catching up + session.last_consolidated = total_messages - KEEP_COUNT + old_messages = get_old_messages(session, session.last_consolidated, KEEP_COUNT) + assert len(old_messages) == 0 + + +class TestLastConsolidatedEdgeCases: + """Test last_consolidated edge cases and data corruption scenarios.""" + + def test_last_consolidated_exceeds_message_count(self): + """Test behavior when last_consolidated > len(messages) (data corruption).""" + session = create_session_with_messages("test:corruption", 10) + session.last_consolidated = 20 + + total_messages = len(session.messages) + messages_to_process = total_messages - session.last_consolidated + assert messages_to_process <= 0 + + old_messages = get_old_messages(session, session.last_consolidated, 5) + assert len(old_messages) == 0 + + def test_last_consolidated_negative_value(self): + """Test behavior with negative last_consolidated (invalid state).""" + session = create_session_with_messages("test:negative", 10) + session.last_consolidated = -5 + + keep_count = 3 + old_messages = get_old_messages(session, session.last_consolidated, keep_count) + + # messages[-5:-3] with 10 messages gives indices 5,6 + assert len(old_messages) == 2 + assert old_messages[0]["content"] == "msg5" + assert old_messages[-1]["content"] == "msg6" + + def test_messages_added_after_consolidation(self): + """Test correct behavior when new messages arrive after consolidation.""" + session = create_session_with_messages("test:new_messages", 40) + session.last_consolidated = len(session.messages) - KEEP_COUNT # 15 + + # Add new messages after consolidation + for i in range(40, 50): + session.add_message("user", f"msg{i}") + + total_messages = len(session.messages) + old_messages = get_old_messages(session, session.last_consolidated, KEEP_COUNT) + expected_consolidate_count = total_messages - KEEP_COUNT - session.last_consolidated + + assert len(old_messages) == expected_consolidate_count + assert_messages_content(old_messages, 15, 24) + + def test_slice_behavior_when_indices_overlap(self): + """Test slice behavior when last_consolidated >= total - keep_count.""" + session = create_session_with_messages("test:overlap", 30) + session.last_consolidated = 12 + + old_messages = get_old_messages(session, session.last_consolidated, 20) + assert len(old_messages) == 0 + + +class TestArchiveAllMode: + """Test archive_all mode (used by /new command).""" + + def test_archive_all_consolidates_everything(self): + """Test archive_all=True consolidates all messages.""" + session = create_session_with_messages("test:archive_all", 50) + + archive_all = True + if archive_all: + old_messages = session.messages + assert len(old_messages) == 50 + + assert session.last_consolidated == 0 + + def test_archive_all_resets_last_consolidated(self): + """Test that archive_all mode resets last_consolidated to 0.""" + session = create_session_with_messages("test:reset", 40) + session.last_consolidated = 15 + + archive_all = True + if archive_all: + session.last_consolidated = 0 + + assert session.last_consolidated == 0 + assert len(session.messages) == 40 + + def test_archive_all_vs_normal_consolidation(self): + """Test difference between archive_all and normal consolidation.""" + # Normal consolidation + session1 = create_session_with_messages("test:normal", 60) + session1.last_consolidated = len(session1.messages) - KEEP_COUNT + + # archive_all mode + session2 = create_session_with_messages("test:all", 60) + session2.last_consolidated = 0 + + assert session1.last_consolidated == 35 + assert len(session1.messages) == 60 + assert session2.last_consolidated == 0 + assert len(session2.messages) == 60 + + +class TestCacheImmutability: + """Test that consolidation doesn't modify session.messages (cache safety).""" + + def test_consolidation_does_not_modify_messages_list(self): + """Test that consolidation leaves messages list unchanged.""" + session = create_session_with_messages("test:immutable", 50) + + original_messages = session.messages.copy() + original_len = len(session.messages) + session.last_consolidated = original_len - KEEP_COUNT + + assert len(session.messages) == original_len + assert session.messages == original_messages + + def test_get_history_does_not_modify_messages(self): + """Test that get_history doesn't modify messages list.""" + session = create_session_with_messages("test:history_immutable", 40) + original_messages = [m.copy() for m in session.messages] + + for _ in range(5): + history = session.get_history(max_messages=10) + assert len(history) == 10 + + assert len(session.messages) == 40 + for i, msg in enumerate(session.messages): + assert msg["content"] == original_messages[i]["content"] + + def test_consolidation_only_updates_last_consolidated(self): + """Test that consolidation only updates last_consolidated field.""" + session = create_session_with_messages("test:field_only", 60) + + original_messages = session.messages.copy() + original_key = session.key + original_metadata = session.metadata.copy() + + session.last_consolidated = len(session.messages) - KEEP_COUNT + + assert session.messages == original_messages + assert session.key == original_key + assert session.metadata == original_metadata + assert session.last_consolidated == 35 + + +class TestSliceLogic: + """Test the slice logic: messages[last_consolidated:-keep_count].""" + + def test_slice_extracts_correct_range(self): + """Test that slice extracts the correct message range.""" + session = create_session_with_messages("test:slice", 60) + + old_messages = get_old_messages(session, 0, KEEP_COUNT) + + assert len(old_messages) == 35 + assert_messages_content(old_messages, 0, 34) + + remaining = session.messages[-KEEP_COUNT:] + assert len(remaining) == 25 + assert_messages_content(remaining, 35, 59) + + def test_slice_with_partial_consolidation(self): + """Test slice when some messages already consolidated.""" + session = create_session_with_messages("test:partial", 70) + + last_consolidated = 30 + old_messages = get_old_messages(session, last_consolidated, KEEP_COUNT) + + assert len(old_messages) == 15 + assert_messages_content(old_messages, 30, 44) + + def test_slice_with_various_keep_counts(self): + """Test slice behavior with different keep_count values.""" + session = create_session_with_messages("test:keep_counts", 50) + + test_cases = [(10, 40), (20, 30), (30, 20), (40, 10)] + + for keep_count, expected_count in test_cases: + old_messages = session.messages[0:-keep_count] + assert len(old_messages) == expected_count + + def test_slice_when_keep_count_exceeds_messages(self): + """Test slice when keep_count > len(messages).""" + session = create_session_with_messages("test:exceed", 10) + + old_messages = session.messages[0:-20] + assert len(old_messages) == 0 + + +class TestEmptyAndBoundarySessions: + """Test empty sessions and boundary conditions.""" + + def test_empty_session_consolidation(self): + """Test consolidation behavior with empty session.""" + session = Session(key="test:empty") + + assert len(session.messages) == 0 + assert session.last_consolidated == 0 + + messages_to_process = len(session.messages) - session.last_consolidated + assert messages_to_process == 0 + + old_messages = get_old_messages(session, session.last_consolidated, KEEP_COUNT) + assert len(old_messages) == 0 + + def test_single_message_session(self): + """Test consolidation with single message.""" + session = Session(key="test:single") + session.add_message("user", "only message") + + assert len(session.messages) == 1 + + old_messages = get_old_messages(session, session.last_consolidated, KEEP_COUNT) + assert len(old_messages) == 0 + + def test_exactly_keep_count_messages(self): + """Test session with exactly keep_count messages.""" + session = create_session_with_messages("test:exact", KEEP_COUNT) + + assert len(session.messages) == KEEP_COUNT + + old_messages = get_old_messages(session, session.last_consolidated, KEEP_COUNT) + assert len(old_messages) == 0 + + def test_just_over_keep_count(self): + """Test session with one message over keep_count.""" + session = create_session_with_messages("test:over", KEEP_COUNT + 1) + + assert len(session.messages) == 26 + + old_messages = get_old_messages(session, session.last_consolidated, KEEP_COUNT) + assert len(old_messages) == 1 + assert old_messages[0]["content"] == "msg0" + + def test_very_large_session(self): + """Test consolidation with very large message count.""" + session = create_session_with_messages("test:large", 1000) + + assert len(session.messages) == 1000 + + old_messages = get_old_messages(session, session.last_consolidated, KEEP_COUNT) + assert len(old_messages) == 975 + assert_messages_content(old_messages, 0, 974) + + remaining = session.messages[-KEEP_COUNT:] + assert len(remaining) == 25 + assert_messages_content(remaining, 975, 999) + + def test_session_with_gaps_in_consolidation(self): + """Test session with potential gaps in consolidation history.""" + session = create_session_with_messages("test:gaps", 50) + session.last_consolidated = 10 + + # Add more messages + for i in range(50, 60): + session.add_message("user", f"msg{i}") + + old_messages = get_old_messages(session, session.last_consolidated, KEEP_COUNT) + + expected_count = 60 - KEEP_COUNT - 10 + assert len(old_messages) == expected_count + assert_messages_content(old_messages, 10, 34)