- PR1: Project scaffold, DB models, price loader - PR2: Congressional trade ingestion (House Stock Watcher) - PR3: Security enrichment + deployment infrastructure - 37 passing tests, 87%+ coverage - Docker + Proxmox deployment ready - Complete documentation - Works 100% offline with fixtures
223 lines
7.0 KiB
Python
223 lines
7.0 KiB
Python
"""
|
|
Tests for price loader.
|
|
"""
|
|
|
|
from datetime import date
|
|
from decimal import Decimal
|
|
from unittest.mock import MagicMock, patch
|
|
|
|
import pandas as pd
|
|
import pytest
|
|
from sqlalchemy import select
|
|
|
|
from pote.db.models import Price, Security
|
|
from pote.ingestion.prices import PriceLoader
|
|
|
|
|
|
@pytest.fixture
def price_loader(test_db_session):
    """Build the PriceLoader under test from the ``test_db_session`` fixture."""
    loader = PriceLoader(test_db_session)
    return loader
|
|
|
|
|
|
def test_get_or_create_security_new(price_loader, test_db_session):
    """Check that ``_get_or_create_security("MSFT")`` returns a saved stock row."""
    ticker = "MSFT"
    created = price_loader._get_or_create_security(ticker)

    assert created.id is not None
    assert created.ticker == ticker
    assert created.asset_type == "stock"

    # Re-read through the session to confirm the row is really in the database.
    lookup = select(Security).where(Security.ticker == ticker)
    stored = test_db_session.scalars(lookup).first()
    assert stored is not None
    assert stored.id == created.id
|
|
|
|
|
|
def test_get_or_create_security_existing(price_loader, test_db_session, sample_security):
    """Check that looking up AAPL reuses the ``sample_security`` row."""
    ticker = "AAPL"
    found = price_loader._get_or_create_security(ticker)

    assert found.id == sample_security.id
    assert found.ticker == ticker

    # Exactly one row must carry this ticker, i.e. no duplicate was inserted.
    lookup = select(Security).where(Security.ticker == ticker)
    matching_rows = test_db_session.scalars(lookup).all()
    assert len(matching_rows) == 1
|
|
|
|
|
|
def test_store_prices(price_loader, test_db_session, sample_security):
    """Store three daily OHLCV rows and read them back ordered by date."""
    trading_days = [date(2024, 1, day) for day in (1, 2, 3)]
    columns = {
        "date": trading_days,
        "open": [100.0, 101.0, 102.0],
        "high": [105.0, 106.0, 107.0],
        "low": [99.0, 100.0, 101.0],
        "close": [103.0, 104.0, 105.0],
        "volume": [1000000, 1100000, 1200000],
    }
    frame = pd.DataFrame(columns)

    stored = price_loader._store_prices(sample_security.id, frame)
    assert stored == 3

    # Read the rows back, ordered by date so positions are predictable.
    query = (
        select(Price)
        .where(Price.security_id == sample_security.id)
        .order_by(Price.date)
    )
    rows = test_db_session.scalars(query).all()

    assert len(rows) == 3
    first, last = rows[0], rows[-1]
    assert first.date == date(2024, 1, 1)
    assert first.close == Decimal("103.0")
    assert last.volume == 1200000
|
|
|
|
|
|
def test_store_prices_upsert(price_loader, test_db_session, sample_security):
    """Storing the same date twice must leave one row holding the newer values."""

    def one_day_frame(open_, high, low, close, volume):
        # Single-row frame for 2024-01-01; only the quoted values vary.
        return pd.DataFrame(
            {
                "date": [date(2024, 1, 1)],
                "open": [open_],
                "high": [high],
                "low": [low],
                "close": [close],
                "volume": [volume],
            }
        )

    initial = one_day_frame(100.0, 105.0, 99.0, 103.0, 1000000)
    assert price_loader._store_prices(sample_security.id, initial) == 1

    # Second write targets the same (security, date) with different numbers.
    revised = one_day_frame(100.5, 106.0, 99.5, 104.0, 1100000)
    assert price_loader._store_prices(sample_security.id, revised) == 1

    # One row should remain, and it should hold the revised values.
    query = select(Price).where(Price.security_id == sample_security.id)
    rows = test_db_session.scalars(query).all()

    assert len(rows) == 1
    only_row = rows[0]
    assert only_row.close == Decimal("104.0")
    assert only_row.volume == 1100000
|
|
|
|
|
|
def test_get_missing_date_range_start_no_data(price_loader, test_db_session, sample_security):
    """With no prices stored in this test, the missing range starts at the requested start."""
    window_start, window_end = date(2024, 1, 1), date(2024, 1, 31)

    result = price_loader._get_missing_date_range_start(
        sample_security.id,
        window_start,
        window_end,
    )

    assert result == window_start
|
|
|
|
|
|
def test_get_missing_date_range_start_partial_data(price_loader, test_db_session, sample_security):
    """With Jan 1-7 already stored, the missing range must begin on Jan 8."""
    # Seed seven consecutive days (Jan 1 through Jan 7) with closes 100.0-106.0.
    seeded_days = [date(2024, 1, day) for day in range(1, 8)]
    seeded_closes = [float(value) for value in range(100, 107)]
    seed_frame = pd.DataFrame({"date": seeded_days, "close": seeded_closes})
    price_loader._store_prices(sample_security.id, seed_frame)

    window_start, window_end = date(2024, 1, 1), date(2024, 1, 31)
    result = price_loader._get_missing_date_range_start(
        sample_security.id,
        window_start,
        window_end,
    )

    # The day after the last stored date is the first one still missing.
    assert result == date(2024, 1, 8)
|
|
|
|
|
|
@patch("pote.ingestion.prices.yf.Ticker")
|
|
def test_fetch_and_store_prices_integration(mock_ticker, price_loader, test_db_session):
|
|
"""Test the full fetch_and_store_prices flow with mocked yfinance."""
|
|
# Mock yfinance response
|
|
mock_hist_df = pd.DataFrame(
|
|
{
|
|
"Date": pd.to_datetime(["2024-01-01", "2024-01-02", "2024-01-03"]),
|
|
"Open": [100.0, 101.0, 102.0],
|
|
"High": [105.0, 106.0, 107.0],
|
|
"Low": [99.0, 100.0, 101.0],
|
|
"Close": [103.0, 104.0, 105.0],
|
|
"Volume": [1000000, 1100000, 1200000],
|
|
}
|
|
).set_index("Date")
|
|
|
|
mock_ticker_instance = MagicMock()
|
|
mock_ticker_instance.history.return_value = mock_hist_df
|
|
mock_ticker.return_value = mock_ticker_instance
|
|
|
|
# Fetch and store
|
|
count = price_loader.fetch_and_store_prices(
|
|
"TSLA",
|
|
start_date=date(2024, 1, 1),
|
|
end_date=date(2024, 1, 3),
|
|
)
|
|
|
|
assert count == 3
|
|
|
|
# Verify security was created
|
|
stmt = select(Security).where(Security.ticker == "TSLA")
|
|
security = test_db_session.scalars(stmt).first()
|
|
assert security is not None
|
|
|
|
# Verify prices were stored
|
|
stmt = select(Price).where(Price.security_id == security.id).order_by(Price.date)
|
|
prices = test_db_session.scalars(stmt).all()
|
|
|
|
assert len(prices) == 3
|
|
assert prices[0].close == Decimal("103.0")
|
|
assert prices[2].close == Decimal("105.0")
|
|
|
|
|
|
@patch("pote.ingestion.prices.yf.Ticker")
|
|
def test_fetch_and_store_prices_idempotent(mock_ticker, price_loader, test_db_session):
|
|
"""Test that re-fetching doesn't duplicate data."""
|
|
mock_hist_df = pd.DataFrame(
|
|
{
|
|
"Date": pd.to_datetime(["2024-01-01"]),
|
|
"Open": [100.0],
|
|
"High": [105.0],
|
|
"Low": [99.0],
|
|
"Close": [103.0],
|
|
"Volume": [1000000],
|
|
}
|
|
).set_index("Date")
|
|
|
|
mock_ticker_instance = MagicMock()
|
|
mock_ticker_instance.history.return_value = mock_hist_df
|
|
mock_ticker.return_value = mock_ticker_instance
|
|
|
|
# Fetch twice
|
|
count1 = price_loader.fetch_and_store_prices("TSLA", date(2024, 1, 1), date(2024, 1, 1))
|
|
count2 = price_loader.fetch_and_store_prices("TSLA", date(2024, 1, 1), date(2024, 1, 1))
|
|
|
|
# First call should insert, second should skip (no missing dates)
|
|
assert count1 == 1
|
|
assert count2 == 0 # No missing data
|
|
|
|
# Verify only one price record exists
|
|
stmt = select(Security).where(Security.ticker == "TSLA")
|
|
security = test_db_session.scalars(stmt).first()
|
|
|
|
stmt = select(Price).where(Price.security_id == security.id)
|
|
prices = test_db_session.scalars(stmt).all()
|
|
assert len(prices) == 1
|