"""
Tests for price loader.

Covers the PriceLoader helpers exercised below (``_get_or_create_security``,
``_store_prices``, ``_get_missing_date_range_start``) and the public
``fetch_and_store_prices`` entry point with yfinance mocked out.
"""

from datetime import date
from decimal import Decimal
from unittest.mock import MagicMock, patch

import pandas as pd
import pytest
from sqlalchemy import select

from pote.db.models import Price, Security
from pote.ingestion.prices import PriceLoader


@pytest.fixture
def price_loader(test_db_session):
    """Create a PriceLoader instance with test session."""
    return PriceLoader(test_db_session)


def test_get_or_create_security_new(price_loader, test_db_session):
    """Test creating a new security.

    Asserts the returned security has an id, the requested ticker and an
    asset_type of "stock", and that the same row is readable via the session.
    """
    security = price_loader._get_or_create_security("MSFT")

    assert security.id is not None
    assert security.ticker == "MSFT"
    assert security.asset_type == "stock"

    # Verify it's in the database
    stmt = select(Security).where(Security.ticker == "MSFT")
    db_security = test_db_session.scalars(stmt).first()
    assert db_security is not None
    assert db_security.id == security.id


def test_get_or_create_security_existing(price_loader, test_db_session, sample_security):
    """Test getting an existing security.

    The "AAPL" lookup is expected to return the row provided by the
    ``sample_security`` fixture rather than inserting a second one.
    """
    security = price_loader._get_or_create_security("AAPL")

    assert security.id == sample_security.id
    assert security.ticker == "AAPL"

    # Verify no duplicate was created
    stmt = select(Security).where(Security.ticker == "AAPL")
    count = len(test_db_session.scalars(stmt).all())
    assert count == 1


def test_store_prices(price_loader, test_db_session, sample_security):
    """Test storing price data.

    Stores three daily OHLCV rows and checks the returned count plus the
    persisted date, close and volume values.
    """
    df = pd.DataFrame(
        {
            "date": [date(2024, 1, 1), date(2024, 1, 2), date(2024, 1, 3)],
            "open": [100.0, 101.0, 102.0],
            "high": [105.0, 106.0, 107.0],
            "low": [99.0, 100.0, 101.0],
            "close": [103.0, 104.0, 105.0],
            "volume": [1000000, 1100000, 1200000],
        }
    )

    count = price_loader._store_prices(sample_security.id, df)
    assert count == 3

    # Verify prices in database (ordered so the index-based checks are stable)
    stmt = select(Price).where(Price.security_id == sample_security.id).order_by(Price.date)
    prices = test_db_session.scalars(stmt).all()
    assert len(prices) == 3
    assert prices[0].date == date(2024, 1, 1)
    assert prices[0].close == Decimal("103.0")
    assert prices[2].volume == 1200000


def test_store_prices_upsert(price_loader, test_db_session, sample_security):
    """Test that storing prices twice performs upsert (update on conflict).

    The same date is stored twice with different values; exactly one row must
    remain and it must carry the values from the second call.
    """
    df1 = pd.DataFrame(
        {
            "date": [date(2024, 1, 1)],
            "open": [100.0],
            "high": [105.0],
            "low": [99.0],
            "close": [103.0],
            "volume": [1000000],
        }
    )

    count1 = price_loader._store_prices(sample_security.id, df1)
    assert count1 == 1

    # Store again with updated values
    df2 = pd.DataFrame(
        {
            "date": [date(2024, 1, 1)],
            "open": [100.5],
            "high": [106.0],
            "low": [99.5],
            "close": [104.0],
            "volume": [1100000],
        }
    )

    count2 = price_loader._store_prices(sample_security.id, df2)
    assert count2 == 1

    # Verify only one price exists with updated values
    stmt = select(Price).where(Price.security_id == sample_security.id)
    prices = test_db_session.scalars(stmt).all()
    assert len(prices) == 1
    assert prices[0].close == Decimal("104.0")
    assert prices[0].volume == 1100000


def test_get_missing_date_range_start_no_data(price_loader, test_db_session, sample_security):
    """Test finding missing date range when no data exists.

    With no stored prices the missing range is expected to begin at the
    requested start date.
    """
    start = date(2024, 1, 1)
    end = date(2024, 1, 31)

    missing_start = price_loader._get_missing_date_range_start(sample_security.id, start, end)
    assert missing_start == start


def test_get_missing_date_range_start_partial_data(price_loader, test_db_session, sample_security):
    """Test finding missing date range when partial data exists.

    Prices for Jan 1-7 are stored first; the missing range is then expected to
    begin on Jan 8.
    """
    # Add prices for first week of January (only "date" and "close" supplied)
    df = pd.DataFrame(
        {
            "date": [date(2024, 1, d) for d in range(1, 8)],
            "close": [100.0 + d for d in range(7)],
        }
    )
    price_loader._store_prices(sample_security.id, df)

    start = date(2024, 1, 1)
    end = date(2024, 1, 31)

    missing_start = price_loader._get_missing_date_range_start(sample_security.id, start, end)
    # Should start from day after last existing (Jan 8)
    assert missing_start == date(2024, 1, 8)


@patch("pote.ingestion.prices.yf.Ticker")
def test_fetch_and_store_prices_integration(mock_ticker, price_loader, test_db_session):
    """Test the full fetch_and_store_prices flow with mocked yfinance.

    ``yf.Ticker`` is patched so ``history()`` returns a fixed three-row frame;
    the test then checks that the security row and three price rows exist.
    """
    # Mock yfinance response: capitalised columns with a "Date" index
    mock_hist_df = pd.DataFrame(
        {
            "Date": pd.to_datetime(["2024-01-01", "2024-01-02", "2024-01-03"]),
            "Open": [100.0, 101.0, 102.0],
            "High": [105.0, 106.0, 107.0],
            "Low": [99.0, 100.0, 101.0],
            "Close": [103.0, 104.0, 105.0],
            "Volume": [1000000, 1100000, 1200000],
        }
    ).set_index("Date")

    mock_ticker_instance = MagicMock()
    mock_ticker_instance.history.return_value = mock_hist_df
    mock_ticker.return_value = mock_ticker_instance

    # Fetch and store
    count = price_loader.fetch_and_store_prices(
        "TSLA",
        start_date=date(2024, 1, 1),
        end_date=date(2024, 1, 3),
    )

    assert count == 3

    # Verify security was created
    stmt = select(Security).where(Security.ticker == "TSLA")
    security = test_db_session.scalars(stmt).first()
    assert security is not None

    # Verify prices were stored
    stmt = select(Price).where(Price.security_id == security.id).order_by(Price.date)
    prices = test_db_session.scalars(stmt).all()
    assert len(prices) == 3
    assert prices[0].close == Decimal("103.0")
    assert prices[2].close == Decimal("105.0")


@patch("pote.ingestion.prices.yf.Ticker")
def test_fetch_and_store_prices_idempotent(mock_ticker, price_loader, test_db_session):
    """Test that re-fetching doesn't duplicate data.

    The same single-day range is fetched twice against a mocked ``yf.Ticker``;
    the second call must report zero rows and only one price row may exist.
    """
    mock_hist_df = pd.DataFrame(
        {
            "Date": pd.to_datetime(["2024-01-01"]),
            "Open": [100.0],
            "High": [105.0],
            "Low": [99.0],
            "Close": [103.0],
            "Volume": [1000000],
        }
    ).set_index("Date")

    mock_ticker_instance = MagicMock()
    mock_ticker_instance.history.return_value = mock_hist_df
    mock_ticker.return_value = mock_ticker_instance

    # Fetch twice
    count1 = price_loader.fetch_and_store_prices("TSLA", date(2024, 1, 1), date(2024, 1, 1))
    count2 = price_loader.fetch_and_store_prices("TSLA", date(2024, 1, 1), date(2024, 1, 1))

    # First call should insert, second should skip (no missing dates)
    assert count1 == 1
    assert count2 == 0  # No missing data

    # Verify only one price record exists
    stmt = select(Security).where(Security.ticker == "TSLA")
    security = test_db_session.scalars(stmt).first()
    # Guard before dereferencing: .first() yields None when no row matched, and
    # an explicit assertion fails more clearly than an AttributeError on .id.
    assert security is not None

    stmt = select(Price).where(Price.security_id == security.id)
    prices = test_db_session.scalars(stmt).all()
    assert len(prices) == 1