"""Tests for the pattern detection module (``PatternDetector``)."""

from datetime import date, datetime, time, timedelta, timezone
from decimal import Decimal

import pytest

from pote.db.models import MarketAlert, Official, Security, Trade
from pote.monitoring.pattern_detector import PatternDetector

# Every fixture trade/alert falls in 2024. Keep the original 3650-day window,
# but widen it as time passes so the 2024 data never silently ages out of it.
# NOTE(review): assumes ``lookback_days`` is counted back from today — confirm
# against PatternDetector.
_FIXTURE_START = date(2024, 1, 1)
LOOKBACK_DAYS = max(3650, (date.today() - _FIXTURE_START).days + 30)


def _utc_midnight(day):
    """Return a timezone-aware UTC ``datetime`` at 00:00 on ``day``."""
    return datetime.combine(day, time.min, tzinfo=timezone.utc)


def _assert_sorted_desc(rows, key="avg_timing_score"):
    """Assert that ``rows`` are ordered by ``row[key]`` from highest to lowest."""
    scores = [row[key] for row in rows]
    for higher, lower in zip(scores, scores[1:]):
        assert higher >= lower, f"rows not sorted by {key} descending: {scores}"


def _assert_has_fields(row, fields):
    """Assert that every name in ``fields`` is present in the mapping ``row``."""
    missing = [field for field in fields if field not in row]
    assert not missing, f"missing fields {missing} in {row!r}"


@pytest.fixture
def multiple_officials_with_patterns(test_db_session):
    """Create three officials with distinct alert-timing patterns.

    * Nancy Pelosi (Democrat): 5 NVDA buys, each preceded by two
      ``unusual_volume`` alerts 3 and 4 days earlier.
    * Tommy Tuberville (Republican): 4 MSFT buys; only the first two are
      preceded by a ``price_spike`` alert (5 days earlier).
    * Clean Trader (Independent): 3 XOM buys with no alerts at all.

    NVDA and MSFT are in the Technology sector; XOM is in Energy.

    Returns:
        dict with ``"officials"`` and ``"securities"`` lists, in that order.
    """
    session = test_db_session

    pelosi = Official(name="Nancy Pelosi", chamber="House", party="Democrat", state="CA")
    tuberville = Official(
        name="Tommy Tuberville", chamber="Senate", party="Republican", state="AL"
    )
    clean_trader = Official(
        name="Clean Trader", chamber="House", party="Independent", state="TX"
    )
    session.add_all([pelosi, tuberville, clean_trader])

    nvda = Security(ticker="NVDA", name="NVIDIA", sector="Technology")
    msft = Security(ticker="MSFT", name="Microsoft", sector="Technology")
    xom = Security(ticker="XOM", name="Exxon", sector="Energy")
    session.add_all([nvda, msft, xom])

    # Flush so the primary keys referenced by the Trade rows below are assigned.
    session.flush()

    def add_buy(official, security, trade_date, value_min):
        """Queue a ``buy`` trade of ``security`` by ``official``."""
        session.add(
            Trade(
                official_id=official.id,
                security_id=security.id,
                source="test",
                transaction_date=trade_date,
                side="buy",
                value_min=Decimal(value_min),
                value_max=Decimal("50000"),
            )
        )

    def add_alert(ticker, alert_type, alert_day, severity):
        """Queue a market alert for ``ticker`` at UTC midnight of ``alert_day``."""
        session.add(
            MarketAlert(
                ticker=ticker,
                alert_type=alert_type,
                timestamp=_utc_midnight(alert_day),
                severity=severity,
            )
        )

    # Pelosi: every trade is preceded by alerts (suspicious pattern).
    for i in range(5):
        trade_date = date(2024, 1, 15) + timedelta(days=30 * i)
        add_buy(pelosi, nvda, trade_date, "15001")
        for j in range(2):
            add_alert("NVDA", "unusual_volume", trade_date - timedelta(days=3 + j), 7 + j)

    # Tuberville: mixed pattern — only the first two trades have a prior alert.
    for i in range(4):
        trade_date = date(2024, 2, 1) + timedelta(days=30 * i)
        add_buy(tuberville, msft, trade_date, "10000")
        if i < 2:
            add_alert("MSFT", "price_spike", trade_date - timedelta(days=5), 6)

    # Clean trader: no alerts anywhere near the trades.
    for i in range(3):
        trade_date = date(2024, 3, 1) + timedelta(days=30 * i)
        add_buy(clean_trader, xom, trade_date, "10000")

    session.commit()

    return {
        "officials": [pelosi, tuberville, clean_trader],
        "securities": [nvda, msft, xom],
    }


def test_rank_officials_by_timing(test_db_session, multiple_officials_with_patterns):
    """Rankings include officials with 3+ trades, sorted by timing score."""
    detector = PatternDetector(test_db_session)

    rankings = detector.rank_officials_by_timing(lookback_days=LOOKBACK_DAYS, min_trades=3)

    assert len(rankings) >= 2  # at least 2 officials have 3+ trades
    _assert_sorted_desc(rankings)
    for ranking in rankings:
        _assert_has_fields(
            ranking,
            ("name", "party", "chamber", "trade_count", "avg_timing_score", "suspicious_rate"),
        )


def test_identify_repeat_offenders(test_db_session, multiple_officials_with_patterns):
    """Only officials at or above the suspicious-rate threshold are returned."""
    detector = PatternDetector(test_db_session)

    offenders = detector.identify_repeat_offenders(
        lookback_days=LOOKBACK_DAYS,
        min_suspicious_rate=0.7,  # 70%+
    )

    assert isinstance(offenders, list)
    # The threshold is passed as a fraction (0.7) while the reported rate is
    # checked as a percentage (70). NOTE(review): confirm this unit split in
    # PatternDetector. Pelosi is the expected offender, but her presence is not
    # asserted because it depends on the detector's alert window.
    for offender in offenders:
        assert offender["suspicious_rate"] >= 70


def test_analyze_ticker_patterns(test_db_session, multiple_officials_with_patterns):
    """Ticker patterns are returned for tickers with 3+ trades, sorted by score."""
    detector = PatternDetector(test_db_session)

    ticker_patterns = detector.analyze_ticker_patterns(
        lookback_days=LOOKBACK_DAYS, min_trades=3
    )

    assert isinstance(ticker_patterns, list)
    assert len(ticker_patterns) >= 1  # at least NVDA should qualify
    _assert_sorted_desc(ticker_patterns)
    for pattern in ticker_patterns:
        _assert_has_fields(
            pattern, ("ticker", "trade_count", "avg_timing_score", "suspicious_rate")
        )


def test_get_sector_timing_analysis(test_db_session, multiple_officials_with_patterns):
    """Sector analysis covers both fixture sectors with the expected fields."""
    detector = PatternDetector(test_db_session)

    sector_stats = detector.get_sector_timing_analysis(lookback_days=LOOKBACK_DAYS)

    assert isinstance(sector_stats, dict)
    assert len(sector_stats) >= 2  # Technology and Energy

    # Technology holds all NVDA and MSFT trades (the ones with alerts).
    assert "Technology" in sector_stats
    tech = sector_stats["Technology"]
    assert tech["trade_count"] >= 9  # 5 NVDA + 4 MSFT
    _assert_has_fields(tech, ("avg_timing_score", "alert_rate", "suspicious_rate"))


def test_get_party_comparison(test_db_session, multiple_officials_with_patterns):
    """Party comparison returns per-party aggregates with the expected fields."""
    detector = PatternDetector(test_db_session)

    party_stats = detector.get_party_comparison(lookback_days=LOOKBACK_DAYS)

    assert isinstance(party_stats, dict)
    assert len(party_stats) >= 2  # fixture has Democrat, Republican, Independent
    for stats in party_stats.values():
        _assert_has_fields(
            stats, ("official_count", "total_trades", "avg_timing_score", "suspicious_rate")
        )


def test_generate_pattern_report(test_db_session, multiple_officials_with_patterns):
    """The comprehensive report has every section and a populated summary."""
    detector = PatternDetector(test_db_session)

    report = detector.generate_pattern_report(lookback_days=LOOKBACK_DAYS)

    _assert_has_fields(
        report,
        (
            "period_days",
            "summary",
            "top_suspicious_officials",
            "repeat_offenders",
            "suspicious_tickers",
            "sector_analysis",
            "party_comparison",
        ),
    )

    summary = report["summary"]
    assert summary["total_officials_analyzed"] >= 2
    assert "avg_timing_score" in summary

    assert len(report["top_suspicious_officials"]) >= 2
    assert isinstance(report["suspicious_tickers"], list)


def test_rank_officials_min_trades_filter(test_db_session, multiple_officials_with_patterns):
    """``min_trades`` excludes officials with fewer trades than the threshold."""
    detector = PatternDetector(test_db_session)

    # Only Pelosi has 5+ trades in the fixture.
    rankings_high = detector.rank_officials_by_timing(lookback_days=LOOKBACK_DAYS, min_trades=5)
    # Pelosi (5), Tuberville (4) and Clean Trader (3) all have 3+ trades.
    rankings_low = detector.rank_officials_by_timing(lookback_days=LOOKBACK_DAYS, min_trades=3)

    assert len(rankings_low) >= 2
    assert len(rankings_low) >= len(rankings_high)
    for ranking in rankings_high:
        assert ranking["trade_count"] >= 5
        assert "Pelosi" in ranking["name"]


def test_empty_data_handling(test_db_session):
    """With no data in the database, every analysis returns an empty result."""
    detector = PatternDetector(test_db_session)

    assert detector.rank_officials_by_timing(lookback_days=30, min_trades=1) == []
    assert detector.identify_repeat_offenders(lookback_days=30) == []
    assert detector.analyze_ticker_patterns(lookback_days=30) == []
    assert detector.get_sector_timing_analysis(lookback_days=30) == {}


def test_ranking_score_accuracy(test_db_session, multiple_officials_with_patterns):
    """An official trading after alerts outranks one whose trades have no alerts."""
    detector = PatternDetector(test_db_session)

    rankings = detector.rank_officials_by_timing(lookback_days=LOOKBACK_DAYS, min_trades=3)

    pelosi_rank = next((r for r in rankings if "Pelosi" in r["name"]), None)
    clean_rank = next((r for r in rankings if "Clean" in r["name"]), None)
    # Both have >= 3 trades, so both must be ranked; without these asserts the
    # comparisons below could be skipped and the test would pass vacuously.
    assert pelosi_rank is not None
    assert clean_rank is not None

    assert pelosi_rank["avg_timing_score"] > clean_rank["avg_timing_score"]
    assert pelosi_rank["trades_with_alerts"] > clean_rank["trades_with_alerts"]


def test_sector_stats_accuracy(test_db_session, multiple_officials_with_patterns):
    """The Energy sector (no alerts in the fixture) reports a clean profile."""
    detector = PatternDetector(test_db_session)

    sector_stats = detector.get_sector_timing_analysis(lookback_days=LOOKBACK_DAYS)

    assert "Energy" in sector_stats
    energy = sector_stats["Energy"]
    assert energy["suspicious_count"] == 0
    assert energy["alert_rate"] == 0.0


def test_party_stats_completeness(test_db_session, multiple_officials_with_patterns):
    """Democrat aggregates reflect Pelosi's trades and suspicious activity."""
    detector = PatternDetector(test_db_session)

    party_stats = detector.get_party_comparison(lookback_days=LOOKBACK_DAYS)

    assert "Democrat" in party_stats
    dem = party_stats["Democrat"]
    assert dem["official_count"] >= 1
    assert dem["total_trades"] >= 5  # Pelosi has 5 trades
    assert dem["total_suspicious"] > 0  # Pelosi's trades follow alerts