POTE/tests/test_price_loader.py
ilia 204cd0e75b Initial commit: POTE Phase 1 complete
- PR1: Project scaffold, DB models, price loader
- PR2: Congressional trade ingestion (House Stock Watcher)
- PR3: Security enrichment + deployment infrastructure
- 37 passing tests, 87%+ coverage
- Docker + Proxmox deployment ready
- Complete documentation
- Works 100% offline with fixtures
2025-12-14 20:45:34 -05:00

223 lines
7.0 KiB
Python

"""
Tests for the price loader (``pote.ingestion.prices.PriceLoader``).
"""
from datetime import date
from decimal import Decimal
from unittest.mock import MagicMock, patch
import pandas as pd
import pytest
from sqlalchemy import select
from pote.db.models import Price, Security
from pote.ingestion.prices import PriceLoader
@pytest.fixture
def price_loader(test_db_session):
    """Provide a ``PriceLoader`` constructed with the test database session."""
    loader = PriceLoader(test_db_session)
    return loader
def test_get_or_create_security_new(price_loader, test_db_session):
    """A ticker requested via ``_get_or_create_security`` comes back as a stock row.

    The returned object must carry an id, the requested ticker and the
    ``"stock"`` asset type, and the same row must be retrievable through
    the session by ticker.
    """
    created = price_loader._get_or_create_security("MSFT")

    assert created.id is not None
    assert created.ticker == "MSFT"
    assert created.asset_type == "stock"

    # Query the session independently to confirm the row is really there.
    lookup = select(Security).where(Security.ticker == "MSFT")
    persisted = test_db_session.scalars(lookup).first()

    assert persisted is not None
    assert persisted.id == created.id
def test_get_or_create_security_existing(price_loader, test_db_session, sample_security):
    """Requesting the ticker of ``sample_security`` returns that row, not a new one.

    NOTE(review): relies on the ``sample_security`` fixture using ticker
    "AAPL" -- confirm against the fixture definition.
    """
    found = price_loader._get_or_create_security("AAPL")

    assert found.id == sample_security.id
    assert found.ticker == "AAPL"

    # Exactly one row may exist for the ticker, i.e. no duplicate was inserted.
    by_ticker = select(Security).where(Security.ticker == "AAPL")
    matches = test_db_session.scalars(by_ticker).all()
    assert len(matches) == 1
def test_store_prices(price_loader, test_db_session, sample_security):
    """Storing a three-row OHLCV frame reports three rows and persists them."""
    ohlcv = {
        "date": [date(2024, 1, day) for day in (1, 2, 3)],
        "open": [100.0, 101.0, 102.0],
        "high": [105.0, 106.0, 107.0],
        "low": [99.0, 100.0, 101.0],
        "close": [103.0, 104.0, 105.0],
        "volume": [1000000, 1100000, 1200000],
    }
    frame = pd.DataFrame(ohlcv)

    stored = price_loader._store_prices(sample_security.id, frame)
    assert stored == 3

    # Read the rows back in date order and spot-check the first and last.
    query = (
        select(Price)
        .where(Price.security_id == sample_security.id)
        .order_by(Price.date)
    )
    rows = test_db_session.scalars(query).all()
    assert len(rows) == 3

    earliest, latest = rows[0], rows[2]
    assert earliest.date == date(2024, 1, 1)
    assert earliest.close == Decimal("103.0")
    assert latest.volume == 1200000
def test_store_prices_upsert(price_loader, test_db_session, sample_security):
    """Storing the same date twice leaves one row carrying the newer values."""
    trading_day = date(2024, 1, 1)

    def one_day_frame(open_, high, low, close, volume):
        # Single-row OHLCV frame for ``trading_day``.
        return pd.DataFrame(
            {
                "date": [trading_day],
                "open": [open_],
                "high": [high],
                "low": [low],
                "close": [close],
                "volume": [volume],
            }
        )

    initial = one_day_frame(100.0, 105.0, 99.0, 103.0, 1000000)
    first_count = price_loader._store_prices(sample_security.id, initial)
    assert first_count == 1

    # Same date, different values: expected to overwrite rather than insert.
    revised = one_day_frame(100.5, 106.0, 99.5, 104.0, 1100000)
    second_count = price_loader._store_prices(sample_security.id, revised)
    assert second_count == 1

    # Only one row may remain, and it must hold the revised figures.
    query = select(Price).where(Price.security_id == sample_security.id)
    rows = test_db_session.scalars(query).all()
    assert len(rows) == 1

    row = rows[0]
    assert row.close == Decimal("104.0")
    assert row.volume == 1100000
def test_get_missing_date_range_start_no_data(price_loader, test_db_session, sample_security):
    """With nothing stored by this test, the missing range starts at the requested start."""
    window_start, window_end = date(2024, 1, 1), date(2024, 1, 31)

    result = price_loader._get_missing_date_range_start(
        sample_security.id,
        window_start,
        window_end,
    )

    assert result == window_start
def test_get_missing_date_range_start_partial_data(price_loader, test_db_session, sample_security):
    """With Jan 1-7 stored, the missing range for January starts on Jan 8."""
    # Seed the first seven days of January; only "date" and "close" are given.
    first_week = [date(2024, 1, day) for day in range(1, 8)]
    closes = [100.0 + offset for offset in range(7)]
    seeded = pd.DataFrame({"date": first_week, "close": closes})
    price_loader._store_prices(sample_security.id, seeded)

    window_start, window_end = date(2024, 1, 1), date(2024, 1, 31)
    result = price_loader._get_missing_date_range_start(
        sample_security.id,
        window_start,
        window_end,
    )

    # Expected: the day after the last stored date.
    assert result == date(2024, 1, 8)
@patch("pote.ingestion.prices.yf.Ticker")
def test_fetch_and_store_prices_integration(mock_ticker, price_loader, test_db_session):
    """Run ``fetch_and_store_prices`` end to end against a mocked yfinance ticker.

    The patched ``Ticker`` returns a three-day history frame; the call is
    expected to report three rows, create the "TSLA" security and persist
    the three closes.
    """
    # Canned yfinance-style history: capitalised columns, indexed by "Date".
    history_columns = {
        "Date": pd.to_datetime(["2024-01-01", "2024-01-02", "2024-01-03"]),
        "Open": [100.0, 101.0, 102.0],
        "High": [105.0, 106.0, 107.0],
        "Low": [99.0, 100.0, 101.0],
        "Close": [103.0, 104.0, 105.0],
        "Volume": [1000000, 1100000, 1200000],
    }
    history = pd.DataFrame(history_columns).set_index("Date")

    fake_ticker = MagicMock()
    fake_ticker.history.return_value = history
    mock_ticker.return_value = fake_ticker

    stored = price_loader.fetch_and_store_prices(
        "TSLA",
        start_date=date(2024, 1, 1),
        end_date=date(2024, 1, 3),
    )
    assert stored == 3

    # The security row must have been created as part of the fetch.
    security_query = select(Security).where(Security.ticker == "TSLA")
    security = test_db_session.scalars(security_query).first()
    assert security is not None

    # All three days must be stored; check the closes at both ends.
    price_query = (
        select(Price)
        .where(Price.security_id == security.id)
        .order_by(Price.date)
    )
    rows = test_db_session.scalars(price_query).all()
    assert len(rows) == 3
    assert rows[0].close == Decimal("103.0")
    assert rows[2].close == Decimal("105.0")
@patch("pote.ingestion.prices.yf.Ticker")
def test_fetch_and_store_prices_idempotent(mock_ticker, price_loader, test_db_session):
    """Test that re-fetching doesn't duplicate data.

    The same one-day range is fetched twice against a mocked yfinance
    ticker. The first call is expected to store one row and the second to
    report zero, leaving a single price row for the security.
    """
    # Canned one-day history in yfinance's shape (capitalised columns,
    # indexed by "Date").
    mock_hist_df = pd.DataFrame(
        {
            "Date": pd.to_datetime(["2024-01-01"]),
            "Open": [100.0],
            "High": [105.0],
            "Low": [99.0],
            "Close": [103.0],
            "Volume": [1000000],
        }
    ).set_index("Date")
    mock_ticker_instance = MagicMock()
    mock_ticker_instance.history.return_value = mock_hist_df
    mock_ticker.return_value = mock_ticker_instance

    # Fetch twice
    count1 = price_loader.fetch_and_store_prices("TSLA", date(2024, 1, 1), date(2024, 1, 1))
    count2 = price_loader.fetch_and_store_prices("TSLA", date(2024, 1, 1), date(2024, 1, 1))

    # First call should insert, second should skip (no missing dates)
    assert count1 == 1
    assert count2 == 0  # No missing data

    # Verify only one price record exists
    stmt = select(Security).where(Security.ticker == "TSLA")
    security = test_db_session.scalars(stmt).first()
    # ``first()`` yields None when no row matches; assert here so a missing
    # security fails clearly instead of raising AttributeError on
    # ``security.id`` below.
    assert security is not None
    stmt = select(Price).where(Price.security_id == security.id)
    prices = test_db_session.scalars(stmt).all()
    assert len(prices) == 1