"""Tests for market monitoring module.""" import pytest from datetime import date, datetime, timedelta, timezone from decimal import Decimal from pote.monitoring.market_monitor import MarketMonitor from pote.monitoring.alert_manager import AlertManager from pote.db.models import Official, Security, Trade, MarketAlert @pytest.fixture def sample_congressional_trades(test_db_session): """Create sample congressional trades for watchlist building.""" session = test_db_session # Create officials pelosi = Official(name="Nancy Pelosi", chamber="House", party="Democrat", state="CA") tuberville = Official(name="Tommy Tuberville", chamber="Senate", party="Republican", state="AL") session.add_all([pelosi, tuberville]) session.flush() # Create securities nvda = Security(ticker="NVDA", name="NVIDIA Corporation", sector="Technology") msft = Security(ticker="MSFT", name="Microsoft Corporation", sector="Technology") aapl = Security(ticker="AAPL", name="Apple Inc.", sector="Technology") tsla = Security(ticker="TSLA", name="Tesla, Inc.", sector="Automotive") spy = Security(ticker="SPY", name="SPDR S&P 500 ETF", sector="Financial") session.add_all([nvda, msft, aapl, tsla, spy]) session.flush() # Create multiple trades (NVDA is most traded) trades = [ Trade(official_id=pelosi.id, security_id=nvda.id, source="test", transaction_date=date(2024, 1, 15), side="buy", value_min=Decimal("15001"), value_max=Decimal("50000")), Trade(official_id=pelosi.id, security_id=nvda.id, source="test", transaction_date=date(2024, 2, 1), side="buy", value_min=Decimal("15001"), value_max=Decimal("50000")), Trade(official_id=tuberville.id, security_id=nvda.id, source="test", transaction_date=date(2024, 2, 15), side="buy", value_min=Decimal("50001"), value_max=Decimal("100000")), Trade(official_id=pelosi.id, security_id=msft.id, source="test", transaction_date=date(2024, 1, 20), side="sell", value_min=Decimal("15001"), value_max=Decimal("50000")), Trade(official_id=tuberville.id, security_id=aapl.id, source="test", transaction_date=date(2024, 2, 10), side="buy", value_min=Decimal("15001"), value_max=Decimal("50000")), ] session.add_all(trades) session.commit() return { "officials": [pelosi, tuberville], "securities": [nvda, msft, aapl, tsla, spy], "trades": trades, } @pytest.fixture def sample_alerts(test_db_session): """Create sample market alerts.""" session = test_db_session now = datetime.now(timezone.utc) alerts = [ MarketAlert( ticker="NVDA", alert_type="unusual_volume", timestamp=now - timedelta(hours=2), details={"current_volume": 100000000, "avg_volume": 30000000, "multiplier": 3.33}, price=Decimal("495.50"), volume=100000000, change_pct=Decimal("2.5"), severity=7, ), MarketAlert( ticker="NVDA", alert_type="price_spike", timestamp=now - timedelta(hours=1), details={"current_price": 505.00, "prev_price": 495.50, "change_pct": 1.92}, price=Decimal("505.00"), volume=85000000, change_pct=Decimal("5.5"), severity=4, ), MarketAlert( ticker="MSFT", alert_type="high_volatility", timestamp=now - timedelta(hours=3), details={"recent_volatility": 4.5, "avg_volatility": 2.0, "multiplier": 2.25}, price=Decimal("380.25"), volume=50000000, change_pct=Decimal("1.2"), severity=5, ), ] session.add_all(alerts) session.commit() return alerts def test_get_congressional_watchlist(test_db_session, sample_congressional_trades): """Test building watchlist from congressional trades.""" session = test_db_session monitor = MarketMonitor(session) watchlist = monitor.get_congressional_watchlist(limit=10) assert len(watchlist) > 0 assert "NVDA" in watchlist # Most traded assert watchlist[0] == "NVDA" # Should be first (3 trades) def test_check_ticker_basic(test_db_session): """Test basic ticker checking (may not find alerts with real data).""" session = test_db_session monitor = MarketMonitor(session) # This uses real yfinance data, so alerts depend on current market # We test that it doesn't crash alerts = monitor.check_ticker("AAPL", lookback_days=5) assert isinstance(alerts, list) # Each alert should have required fields for alert in alerts: assert "ticker" in alert assert "alert_type" in alert assert "timestamp" in alert assert "severity" in alert def test_scan_watchlist_with_mock(test_db_session, sample_congressional_trades, monkeypatch): """Test scanning watchlist with mocked data.""" session = test_db_session monitor = MarketMonitor(session) # Mock the check_ticker method to return controlled data def mock_check_ticker(ticker, lookback_days=5): if ticker == "NVDA": return [ { "ticker": ticker, "alert_type": "unusual_volume", "timestamp": datetime.now(timezone.utc), "details": {"multiplier": 3.5}, "price": Decimal("500.00"), "volume": 100000000, "change_pct": Decimal("2.5"), "severity": 7, } ] return [] monkeypatch.setattr(monitor, "check_ticker", mock_check_ticker) # Scan with limited watchlist alerts = monitor.scan_watchlist(tickers=["NVDA", "MSFT"], lookback_days=5) assert len(alerts) == 1 assert alerts[0]["ticker"] == "NVDA" assert alerts[0]["alert_type"] == "unusual_volume" def test_save_alerts(test_db_session): """Test saving alerts to database.""" session = test_db_session monitor = MarketMonitor(session) alerts_data = [ { "ticker": "TSLA", "alert_type": "price_spike", "timestamp": datetime.now(timezone.utc), "details": {"change_pct": 7.5}, "price": Decimal("250.00"), "volume": 75000000, "change_pct": Decimal("7.5"), "severity": 8, }, { "ticker": "TSLA", "alert_type": "unusual_volume", "timestamp": datetime.now(timezone.utc), "details": {"multiplier": 4.0}, "price": Decimal("250.00"), "volume": 120000000, "change_pct": Decimal("7.5"), "severity": 9, }, ] saved_count = monitor.save_alerts(alerts_data) assert saved_count == 2 # Verify in database alerts = session.query(MarketAlert).filter_by(ticker="TSLA").all() assert len(alerts) == 2 def test_get_recent_alerts(test_db_session, sample_alerts): """Test querying recent alerts.""" session = test_db_session monitor = MarketMonitor(session) # Get all alerts all_alerts = monitor.get_recent_alerts(days=1) assert len(all_alerts) >= 3 # Filter by ticker nvda_alerts = monitor.get_recent_alerts(ticker="NVDA", days=1) assert len(nvda_alerts) == 2 assert all(a.ticker == "NVDA" for a in nvda_alerts) # Filter by alert type volume_alerts = monitor.get_recent_alerts(alert_type="unusual_volume", days=1) assert len(volume_alerts) == 1 assert volume_alerts[0].alert_type == "unusual_volume" # Filter by severity high_sev_alerts = monitor.get_recent_alerts(min_severity=6, days=1) assert all(a.severity >= 6 for a in high_sev_alerts) def test_get_ticker_alert_summary(test_db_session, sample_alerts): """Test alert summary by ticker.""" session = test_db_session monitor = MarketMonitor(session) summary = monitor.get_ticker_alert_summary(days=1) assert "NVDA" in summary assert "MSFT" in summary nvda_summary = summary["NVDA"] assert nvda_summary["alert_count"] == 2 assert nvda_summary["max_severity"] == 7 assert 4 <= nvda_summary["avg_severity"] <= 7 def test_alert_manager_format_text(test_db_session, sample_alerts): """Test text formatting of alerts.""" session = test_db_session alert_mgr = AlertManager(session) alert = sample_alerts[0] # NVDA unusual volume text = alert_mgr.format_alert_text(alert) assert "NVDA" in text assert "UNUSUAL VOLUME" in text assert "Severity" in text assert "$495.50" in text def test_alert_manager_format_html(test_db_session, sample_alerts): """Test HTML formatting of alerts.""" session = test_db_session alert_mgr = AlertManager(session) alert = sample_alerts[0] html = alert_mgr.format_alert_html(alert) assert "