"""Tests for analytics module."""

from datetime import date, timedelta
from decimal import Decimal

import pytest

from pote.analytics.returns import ReturnCalculator
from pote.analytics.benchmarks import BenchmarkComparison
from pote.analytics.metrics import PerformanceMetrics
from pote.db.models import Official, Security, Trade, Price

# First calendar day of the synthetic price history created by ``sample_prices``.
PRICE_START = date(2024, 1, 1)
# Number of consecutive calendar days of prices created per security
# (PRICE_START .. PRICE_START + 99 days, i.e. 2024-01-01 .. 2024-04-09).
PRICE_DAYS = 100


def _add_price_series(session, security_id, open_base, step, volume):
    """Add PRICE_DAYS daily price rows for ``security_id`` to ``session``.

    The open drifts upward by ``step`` each day starting at ``open_base``;
    each day's bar is derived from that open as high = open + 2,
    low = open - 1, close = open + 1.

    Args:
        session: database session the Price rows are added to (not committed).
        security_id: primary key of the Security the prices belong to.
        open_base: opening price on PRICE_START, as a decimal string.
        step: daily increase of every price field, as a decimal string.
        volume: volume recorded on every row.
    """
    # Build from decimal strings so values are exact; Decimal(float) would
    # carry binary-float noise (e.g. Decimal(0.3) != Decimal("0.3")).
    base = Decimal(open_base)
    increment = Decimal(step)
    for day in range(PRICE_DAYS):
        open_price = base + increment * day
        session.add(
            Price(
                security_id=security_id,
                date=PRICE_START + timedelta(days=day),
                open=open_price,
                high=open_price + 2,
                low=open_price - 1,
                close=open_price + 1,
                volume=volume,
            )
        )


def _make_trade(official_id, security_id, transaction_date, side="buy", **extra):
    """Build (but do not add or commit) a test Trade worth $10,000-$50,000.

    Args:
        official_id: primary key of the trading Official.
        security_id: primary key of the traded Security.
        transaction_date: date the trade was executed.
        side: trade direction, "buy" or "sell".
        **extra: any further Trade columns (e.g. ``filing_date``).
    """
    return Trade(
        official_id=official_id,
        security_id=security_id,
        source="test",
        transaction_date=transaction_date,
        side=side,
        value_min=Decimal("10000"),
        value_max=Decimal("50000"),
        **extra,
    )


@pytest.fixture
def sample_prices(test_db_session, sample_security):
    """Create 100 days of price data for SPY and sample_security.

    Returns the session so dependent tests can keep using it.
    """
    session = test_db_session

    # SPY serves as the benchmark security.
    spy = Security(ticker="SPY", name="SPDR S&P 500 ETF")
    session.add(spy)
    session.flush()  # assign spy.id before the Price rows reference it

    _add_price_series(session, spy.id, open_base="450", step="0.5", volume=1000000)
    # Prices for sample_security (AAPL).
    _add_price_series(
        session, sample_security.id, open_base="180", step="0.3", volume=50000000
    )

    session.commit()
    return session


def test_return_calculator_basic(test_db_session, sample_official, sample_security, sample_prices):
    """Test basic return calculation."""
    session = test_db_session
    trade = _make_trade(sample_official.id, sample_security.id, date(2024, 1, 15))
    session.add(trade)
    session.commit()

    calculator = ReturnCalculator(session)
    result = calculator.calculate_trade_return(trade, window_days=30)

    # Should have all required fields
    assert result is not None
    assert "ticker" in result
    assert "return_pct" in result
    assert "entry_price" in result
    assert "exit_price" in result


def test_return_calculator_sell_trade(test_db_session, sample_official, sample_security, sample_prices):
    """Test return calculation for sell trade."""
    session = test_db_session
    trade = _make_trade(
        sample_official.id, sample_security.id, date(2024, 1, 15), side="sell"
    )
    session.add(trade)
    session.commit()

    calculator = ReturnCalculator(session)
    result = calculator.calculate_trade_return(trade, window_days=30)

    # NOTE(review): sell-side returns are presumably inverted by the
    # calculator; this test only checks that the side is reported back.
    assert result is not None
    assert result["side"] == "sell"


def test_return_calculator_missing_data(test_db_session, sample_official, sample_security):
    """Test handling of missing price data."""
    session = test_db_session
    # No sample_prices fixture here, so the security has no Price rows.
    trade = _make_trade(sample_official.id, sample_security.id, date(2024, 1, 15))
    session.add(trade)
    session.commit()

    calculator = ReturnCalculator(session)
    result = calculator.calculate_trade_return(trade, window_days=30)

    # Should return None when data unavailable
    assert result is None


def test_benchmark_comparison(test_db_session, sample_official, sample_security, sample_prices):
    """Test benchmark comparison."""
    session = test_db_session
    # Trade the SPY security created by the sample_prices fixture.
    spy = session.query(Security).filter_by(ticker="SPY").first()
    trade = _make_trade(sample_official.id, spy.id, date(2024, 1, 15))
    session.add(trade)
    session.commit()

    benchmark = BenchmarkComparison(session)
    result = benchmark.compare_trade_to_benchmark(trade, window_days=30, benchmark="SPY")

    assert result is not None
    assert "trade_return" in result
    assert "benchmark_return" in result
    assert "abnormal_return" in result
    assert "beat_market" in result


def test_performance_metrics_official(test_db_session, sample_official, sample_security, sample_prices):
    """Test official performance metrics."""
    session = test_db_session
    spy = session.query(Security).filter_by(ticker="SPY").first()
    # Three buys on consecutive days (2024-01-10 .. 2024-01-12).
    session.add_all(
        [
            _make_trade(sample_official.id, spy.id, date(2024, 1, 10 + offset))
            for offset in range(3)
        ]
    )
    session.commit()

    metrics = PerformanceMetrics(session)
    perf = metrics.official_performance(sample_official.id, window_days=30)

    assert perf["name"] == sample_official.name
    assert "total_trades" in perf
    assert "avg_return" in perf or "message" in perf


def test_multiple_windows(test_db_session, sample_official, sample_security, sample_prices):
    """Test calculating returns for multiple windows."""
    session = test_db_session
    spy = session.query(Security).filter_by(ticker="SPY").first()
    trade = _make_trade(sample_official.id, spy.id, date(2024, 1, 15))
    session.add(trade)
    session.commit()

    calculator = ReturnCalculator(session)
    results = calculator.calculate_multiple_windows(trade, windows=[30, 60, 90])

    assert isinstance(results, dict)
    # Prices end on 2024-04-09, so the 90-day window from 2024-01-15 runs past
    # the available data; only validate the windows that were returned.
    for window in [30, 60, 90]:
        if window in results:
            assert results[window]["window_days"] == window


def test_sector_analysis(test_db_session, sample_official, sample_prices):
    """Test sector analysis."""
    session = test_db_session
    # Create securities in different sectors
    tech = Security(ticker="TECH", name="Tech Corp", sector="Technology")
    health = Security(ticker="HLTH", name="Health Inc", sector="Healthcare")
    session.add_all([tech, health])
    session.commit()  # assigns the ids used by the trades below

    # One trade per sector
    session.add_all(
        [
            _make_trade(sample_official.id, sec.id, date(2024, 1, 15))
            for sec in (tech, health)
        ]
    )
    session.commit()

    metrics = PerformanceMetrics(session)
    sectors = metrics.sector_analysis(window_days=30)

    # Should group by sector
    assert isinstance(sectors, list)


def test_timing_analysis(test_db_session, sample_official, sample_security):
    """Test disclosure timing analysis."""
    session = test_db_session
    # Each trade is filed 14 days after it was executed.
    session.add_all(
        [
            _make_trade(
                sample_official.id,
                sample_security.id,
                date(2024, 1, day + 1),
                filing_date=date(2024, 1, day + 15),
            )
            for day in range(3)
        ]
    )
    session.commit()

    metrics = PerformanceMetrics(session)
    timing = metrics.timing_analysis()

    assert "avg_disclosure_lag_days" in timing
    assert timing["avg_disclosure_lag_days"] > 0