"""
Tests for trade loader (ETL).
"""

import json
from datetime import date
from decimal import Decimal
from pathlib import Path

from sqlalchemy import select

from pote.db.models import Official, Trade
from pote.ingestion.trade_loader import TradeLoader


def test_ingest_transactions_from_fixture(test_db_session):
    """Test ingesting transactions from fixture file.

    Loads ``fixtures/sample_house_watcher.json`` (resolved relative to this
    file), ingests it, then checks the returned counts and the rows stored
    for "Nancy Pelosi".
    """
    # Load fixture. Decode explicitly as UTF-8 so the result does not depend
    # on the platform's default locale encoding.
    fixture_path = Path(__file__).parent / "fixtures" / "sample_house_watcher.json"
    with open(fixture_path, encoding="utf-8") as f:
        transactions = json.load(f)

    # Ingest
    loader = TradeLoader(test_db_session)
    counts = loader.ingest_transactions(transactions)

    # Verify counts
    # NOTE(review): the name lists below mention more entries than the lower
    # bounds require (4 officials vs ">= 3", 5 tickers vs ">= 4") -- confirm
    # against the fixture and tighten either the comment or the assertion.
    assert counts["officials"] >= 3  # Nancy, Josh, Tommy, Dan
    assert counts["securities"] >= 4  # NVDA, MSFT, AAPL, TSLA, GOOGL
    assert counts["trades"] == 5

    # Verify data in DB
    stmt = select(Official).where(Official.name == "Nancy Pelosi")
    pelosi = test_db_session.scalars(stmt).first()
    assert pelosi is not None
    assert pelosi.chamber == "House"
    assert pelosi.party == "Democrat"

    # Verify trades
    stmt = select(Trade).where(Trade.official_id == pelosi.id)
    pelosi_trades = test_db_session.scalars(stmt).all()
    assert len(pelosi_trades) == 2  # NVDA and GOOGL

    # Check one trade in detail. Look it up with an explicit assertion so a
    # missing NVDA trade fails with a readable message instead of IndexError.
    nvda_trade = next(
        (t for t in pelosi_trades if t.security.ticker == "NVDA"),
        None,
    )
    assert nvda_trade is not None, "expected an NVDA trade for Nancy Pelosi"
    assert nvda_trade.transaction_date == date(2024, 1, 15)
    assert nvda_trade.filing_date == date(2024, 2, 1)
    assert nvda_trade.side == "buy"
    assert nvda_trade.value_min == Decimal("1001")
    assert nvda_trade.value_max == Decimal("15000")


def test_ingest_duplicate_transaction(test_db_session):
    """Test that duplicate transactions are not created.

    Ingests the same transaction dict twice through one loader and checks
    that only the first call reports a new trade and only one ``Trade`` row
    exists afterwards.
    """
    loader = TradeLoader(test_db_session)

    transaction = {
        "representative": "Test Official",
        "ticker": "AAPL",
        "transaction_date": "01/15/2024",
        "disclosure_date": "02/01/2024",
        "transaction": "Purchase",
        "amount": "$1,001 - $15,000",
        "house": "House",
        "party": "Independent",
    }

    # Ingest once
    counts1 = loader.ingest_transactions([transaction])
    assert counts1["trades"] == 1

    # Ingest again (should detect duplicate)
    counts2 = loader.ingest_transactions([transaction])
    assert counts2["trades"] == 0  # No new trade created

    # Verify only one trade in DB
    stmt = select(Trade)
    trades = test_db_session.scalars(stmt).all()
    assert len(trades) == 1


def test_ingest_transaction_missing_ticker(test_db_session):
    """Test that transactions without tickers are skipped.

    A transaction whose ``ticker`` is the empty string must not be counted
    as a new trade.
    """
    loader = TradeLoader(test_db_session)

    transaction = {
        "representative": "Test Official",
        "ticker": "",  # Missing ticker
        "transaction_date": "01/15/2024",
        "disclosure_date": "02/01/2024",
        "transaction": "Purchase",
        "amount": "$1,001 - $15,000",
        "house": "House",
        "party": "Independent",
    }

    counts = loader.ingest_transactions([transaction])
    assert counts["trades"] == 0


def test_get_or_create_official_senate(test_db_session):
    """Test creating a Senate official.

    Ingests one transaction with ``"house": "Senate"`` and checks the
    chamber and party stored on the resulting ``Official`` row.
    """
    loader = TradeLoader(test_db_session)

    transaction = {
        "representative": "Test Senator",
        "ticker": "AAPL",
        "transaction_date": "01/15/2024",
        "disclosure_date": "02/01/2024",
        "transaction": "Purchase",
        "amount": "$1,001 - $15,000",
        "house": "Senate",
        "party": "Republican",
    }

    loader.ingest_transactions([transaction])

    stmt = select(Official).where(Official.name == "Test Senator")
    official = test_db_session.scalars(stmt).first()
    assert official is not None
    assert official.chamber == "Senate"
    assert official.party == "Republican"


def test_multiple_trades_same_official(test_db_session):
    """Test multiple trades for the same official.

    Two transactions share the representative "Jane Doe"; the loader should
    report one new official and two new trades, and both trades should be
    linked to that official.
    """
    loader = TradeLoader(test_db_session)

    transactions = [
        {
            "representative": "Jane Doe",
            "ticker": "AAPL",
            "transaction_date": "01/10/2024",
            "disclosure_date": "01/25/2024",
            "transaction": "Purchase",
            "amount": "$1,001 - $15,000",
            "house": "House",
            "party": "Democrat",
        },
        {
            "representative": "Jane Doe",
            "ticker": "MSFT",
            "transaction_date": "01/15/2024",
            "disclosure_date": "01/30/2024",
            "transaction": "Sale",
            "amount": "$15,001 - $50,000",
            "house": "House",
            "party": "Democrat",
        },
    ]

    counts = loader.ingest_transactions(transactions)
    assert counts["officials"] == 1  # Only one official created
    assert counts["trades"] == 2

    stmt = select(Official).where(Official.name == "Jane Doe")
    official = test_db_session.scalars(stmt).first()
    # Fail with a clear assertion rather than AttributeError on `.id` below.
    assert official is not None

    stmt = select(Trade).where(Trade.official_id == official.id)
    trades = test_db_session.scalars(stmt).all()
    assert len(trades) == 2