diff --git a/.gitignore b/.gitignore index 1940c42..3c94338 100644 --- a/.gitignore +++ b/.gitignore @@ -162,7 +162,7 @@ cython_debug/ *.db *.sqlite *.sqlite3 -data/ +/data/ logs/ checkpoints/ models/ diff --git a/src/data/__init__.py b/src/data/__init__.py new file mode 100644 index 0000000..4d17877 --- /dev/null +++ b/src/data/__init__.py @@ -0,0 +1,29 @@ +""" +Data layer: market data providers, models, and the local SQLite cache. +""" + +from .cache import CacheError, CacheManager +from .models import OptionChainData, PriceData +from .providers import ( + DataNotAvailable, + DataProviderError, + InvalidSymbol, + MarketDataRequest, + RateLimiter, + YFinanceProvider, + get_default_provider, +) + +__all__ = [ + "PriceData", + "OptionChainData", + "YFinanceProvider", + "RateLimiter", + "MarketDataRequest", + "DataProviderError", + "InvalidSymbol", + "DataNotAvailable", + "get_default_provider", + "CacheManager", + "CacheError", +] diff --git a/src/data/cache.py b/src/data/cache.py new file mode 100644 index 0000000..bdd2d37 --- /dev/null +++ b/src/data/cache.py @@ -0,0 +1,320 @@ +""" +SQLite-backed cache for market data. + +Wraps a data provider, persisting price history and option chains in a local +SQLite database via SQLAlchemy and serving cached results within a TTL. 
_DEFAULT_PRICE_TTL = 24 * 60 * 60  # 24 hours
_DEFAULT_OPTIONS_TTL = 15 * 60  # 15 minutes


class CacheError(Exception):
    """Raised when a cache operation fails."""


class CacheManager:
    """Cache market data in SQLite, refreshing from a provider on miss/expiry.

    Price bars and option-chain rows are persisted through SQLAlchemy. A
    request is answered from the database while the newest matching row is
    younger than the configured TTL; otherwise the provider is queried and the
    stored rows for that entry are replaced.

    Rows are keyed by a normalized (stripped, upper-cased) symbol so that
    ``"aapl"`` and ``"AAPL"`` share a single cache entry regardless of how the
    provider spells the symbol in the objects it returns.
    """

    def __init__(self, database_path=None, provider=None):
        """Open (creating if needed) the cache database and bind a provider.

        Args:
            database_path: Location of the SQLite file. When omitted, the
                path comes from ``_default_db_path()``.
            provider: Object exposing ``get_price_history`` and
                ``get_options_chain``. When omitted, the package default
                provider is used.
        """
        if database_path is None:
            database_path = self._default_db_path()
        self.database_path = Path(database_path)
        # SQLite creates the file itself but not its parent directory.
        self.database_path.parent.mkdir(parents=True, exist_ok=True)

        if provider is None:
            from .providers import get_default_provider
            provider = get_default_provider()
        self.provider = provider

        self._engine = create_engine(f"sqlite:///{self.database_path}")
        Base.metadata.create_all(self._engine)
        self.Session = sessionmaker(bind=self._engine)

        price_ttl, options_ttl = _DEFAULT_PRICE_TTL, _DEFAULT_OPTIONS_TTL
        try:
            from ..config import config
            price_ttl = config.cache.price_data_ttl
            options_ttl = config.cache.options_data_ttl
        except Exception:
            # Deliberate best effort: configuration is optional, so any
            # failure to import or read it falls back to the defaults above.
            pass
        self._ttl_cache = {
            "price_data_ttl_seconds": price_ttl,
            "options_data_ttl_seconds": options_ttl,
        }

        # In-memory hit/miss counters keyed by cache key.
        self._stats: Dict[str, Dict[str, int]] = {}

        # Expiration the provider chose the last time a chain was requested
        # without one, keyed by normalized symbol. Without it a default
        # request cannot tell which of several cached chains it refers to.
        self._default_expirations: Dict[str, datetime] = {}

    @staticmethod
    def _default_db_path() -> Path:
        """Return the configured database path, or ``data/cache.db``."""
        try:
            from ..config import config
            return Path(config.cache.cache_db_path)
        except Exception:
            # Configuration is optional; fall back to a relative default.
            return Path("data/cache.db")

    @staticmethod
    def _cache_symbol(symbol):
        """Return the key under which ``symbol`` is stored and looked up.

        Strings are stripped and upper-cased; any other value is returned
        unchanged so the provider can reject it with its own error.
        """
        if isinstance(symbol, str):
            return symbol.strip().upper()
        return symbol

    # ------------------------------------------------------------------ price

    def get_price_history(self, symbol: str, start: Optional[datetime] = None,
                          end: Optional[datetime] = None,
                          interval: str = "1d") -> PriceData:
        """Return cached price history when fresh, else refresh from provider.

        Args:
            symbol: Ticker symbol; matched case-insensitively against the cache.
            start: Passed to the provider on a cache miss.
            end: Passed to the provider on a cache miss.
            interval: Bar interval; part of the cache key.

        Returns:
            The cached bars (``source="cache"``) on a hit, otherwise the
            provider's result.

        Raises:
            CacheError: If persisting freshly fetched data fails.

        NOTE(review): a cache hit returns every stored bar for the
        symbol/interval; ``start`` and ``end`` are only honoured on a miss.
        """
        cache_symbol = self._cache_symbol(symbol)
        key = f"price:{cache_symbol}:{interval}"

        if self._price_is_fresh(cache_symbol, interval):
            self._record(key, "hit")
            return self._load_price_data(cache_symbol, interval)

        data = self.provider.get_price_history(
            symbol, start=start, end=end, interval=interval
        )

        session = self.Session()
        try:
            # Delete and insert under the same key. Using the caller's raw
            # symbol for the delete while storing the provider's normalized
            # one left the old bars in place and broke the unique constraint.
            self._delete_price_rows(session, cache_symbol, data.interval)
            self._store_price_data(session, data, symbol=cache_symbol)
            session.commit()
        except Exception as exc:
            session.rollback()
            raise CacheError(f"Failed to cache price data for {symbol}: {exc}") from exc
        finally:
            session.close()

        self._record(key, "miss")
        return data

    def _price_is_fresh(self, symbol: str, interval: str) -> bool:
        """Return True when the newest stored bar is younger than the price TTL."""
        session = self.Session()
        try:
            latest = session.query(func.max(PriceHistory.updated_at)).filter(
                PriceHistory.symbol == symbol,
                PriceHistory.interval == interval,
            ).scalar()
        finally:
            session.close()

        if latest is None:
            # Nothing stored for this symbol/interval.
            return False
        ttl = self._ttl_cache.get("price_data_ttl_seconds", _DEFAULT_PRICE_TTL)
        return (datetime.utcnow() - latest).total_seconds() < ttl

    def _load_price_data(self, symbol: str, interval: str) -> PriceData:
        """Rebuild a ``PriceData`` from the stored bars, ordered by date."""
        session = self.Session()
        try:
            rows = session.query(PriceHistory).filter(
                PriceHistory.symbol == symbol,
                PriceHistory.interval == interval,
            ).order_by(PriceHistory.date).all()
        finally:
            session.close()

        frame = pd.DataFrame(
            {
                "Open": [r.open for r in rows],
                "High": [r.high for r in rows],
                "Low": [r.low for r in rows],
                "Close": [r.close for r in rows],
                "Volume": [r.volume for r in rows],
            },
            index=pd.DatetimeIndex([r.date for r in rows], name="Date"),
        )
        return PriceData(
            symbol=symbol,
            data=frame,
            start_date=rows[0].date if rows else None,
            end_date=rows[-1].date if rows else None,
            interval=interval,
            source="cache",
        )

    def _store_price_data(self, session, price_data: PriceData,
                          symbol: Optional[str] = None) -> None:
        """Add one ``PriceHistory`` row per bar to ``session`` (no commit).

        Args:
            session: Open SQLAlchemy session that receives the rows.
            price_data: Bars to persist; its frame must carry Open, High,
                Low, Close and Volume columns.
            symbol: Key to store the rows under. Defaults to
                ``price_data.symbol``.
        """
        stored_symbol = price_data.symbol if symbol is None else symbol
        now = datetime.utcnow()
        for idx, row in price_data.data.iterrows():
            # pandas Timestamps are converted to plain datetimes for the ORM.
            timestamp = idx.to_pydatetime() if hasattr(idx, "to_pydatetime") else idx
            session.add(PriceHistory(
                symbol=stored_symbol,
                date=timestamp,
                open=float(row["Open"]),
                high=float(row["High"]),
                low=float(row["Low"]),
                close=float(row["Close"]),
                volume=float(row["Volume"]),
                interval=price_data.interval,
                source=price_data.source,
                updated_at=now,
            ))

    def _delete_price_rows(self, session, symbol: str, interval: str) -> None:
        """Delete every stored bar for ``symbol``/``interval`` (no commit)."""
        session.query(PriceHistory).filter(
            PriceHistory.symbol == symbol,
            PriceHistory.interval == interval,
        ).delete(synchronize_session=False)

    # ---------------------------------------------------------------- options

    def get_options_chain(self, symbol: str,
                          expiration: Optional[datetime] = None) -> OptionChainData:
        """Return cached option chain when fresh, else refresh from provider.

        Args:
            symbol: Ticker symbol; matched case-insensitively against the cache.
            expiration: Requested expiration. ``None`` lets the provider pick
                one; that choice is remembered for this manager's lifetime so
                later default requests can be served from cache.

        Returns:
            The cached chain (``source="cache"``) on a hit, otherwise the
            provider's result.

        Raises:
            CacheError: If persisting a freshly fetched chain fails.
        """
        cache_symbol = self._cache_symbol(symbol)
        key = f"options:{cache_symbol}:{expiration}"

        # A default request is only answerable from cache once we know which
        # expiration it stands for. Looking across all cached expirations
        # would merge unrelated chains into one result.
        lookup = expiration
        if lookup is None:
            lookup = self._default_expirations.get(cache_symbol)

        if lookup is not None and self._options_is_fresh(cache_symbol, lookup):
            self._record(key, "hit")
            return self._load_options_data(cache_symbol, lookup)

        data = self.provider.get_options_chain(symbol, expiration)

        session = self.Session()
        try:
            # Replace only the chain that was fetched, so chains cached for
            # other expirations of the same symbol survive.
            self._delete_options_rows(session, cache_symbol, data.expiration)
            self._store_options_data(session, data, symbol=cache_symbol)
            session.commit()
        except Exception as exc:
            session.rollback()
            raise CacheError(f"Failed to cache options for {symbol}: {exc}") from exc
        finally:
            session.close()

        if expiration is None and data.expiration is not None:
            self._default_expirations[cache_symbol] = data.expiration

        self._record(key, "miss")
        return data

    def _options_is_fresh(self, symbol: str,
                          expiration: Optional[datetime]) -> bool:
        """Return True when the newest matching row is younger than the options TTL.

        With ``expiration=None`` rows of every expiration are considered.
        """
        session = self.Session()
        try:
            query = session.query(func.max(OptionsData.updated_at)).filter(
                OptionsData.symbol == symbol,
            )
            if expiration is not None:
                query = query.filter(OptionsData.expiration == expiration)
            latest = query.scalar()
        finally:
            session.close()

        if latest is None:
            return False
        ttl = self._ttl_cache.get("options_data_ttl_seconds", _DEFAULT_OPTIONS_TTL)
        return (datetime.utcnow() - latest).total_seconds() < ttl

    def _load_options_data(self, symbol: str,
                           expiration: Optional[datetime]) -> OptionChainData:
        """Rebuild an ``OptionChainData`` from stored rows.

        With ``expiration=None`` rows of every expiration are loaded.
        """
        session = self.Session()
        try:
            query = session.query(OptionsData).filter(OptionsData.symbol == symbol)
            if expiration is not None:
                query = query.filter(OptionsData.expiration == expiration)
            rows = query.all()
        finally:
            session.close()

        calls = [r for r in rows if r.option_type == "call"]
        puts = [r for r in rows if r.option_type == "put"]
        # Every row of one stored chain carries the same underlying price.
        underlying = rows[0].underlying_price if rows else None

        return OptionChainData(
            symbol=symbol,
            expiration=expiration,
            calls=self._options_frame(calls),
            puts=self._options_frame(puts),
            underlying_price=underlying,
            source="cache",
        )

    @staticmethod
    def _options_frame(rows) -> pd.DataFrame:
        """Build a frame from ORM rows using the provider's column names."""
        return pd.DataFrame({
            "strike": [r.strike for r in rows],
            "bid": [r.bid for r in rows],
            "ask": [r.ask for r in rows],
            "volume": [r.volume for r in rows],
            "openInterest": [r.open_interest for r in rows],
            "impliedVolatility": [r.implied_volatility for r in rows],
        })

    def _store_options_data(self, session, options_data: OptionChainData,
                            symbol: Optional[str] = None) -> None:
        """Add one ``OptionsData`` row per contract to ``session`` (no commit).

        Args:
            session: Open SQLAlchemy session that receives the rows.
            options_data: Chain to persist.
            symbol: Key to store the rows under. Defaults to
                ``options_data.symbol``.
        """
        stored_symbol = options_data.symbol if symbol is None else symbol
        now = datetime.utcnow()
        frames = (("call", options_data.calls), ("put", options_data.puts))
        for option_type, frame in frames:
            for _, row in frame.iterrows():
                session.add(OptionsData(
                    symbol=stored_symbol,
                    expiration=options_data.expiration,
                    option_type=option_type,
                    strike=float(row["strike"]),
                    bid=self._opt_float(row, "bid"),
                    ask=self._opt_float(row, "ask"),
                    volume=self._opt_float(row, "volume"),
                    open_interest=self._opt_float(row, "openInterest"),
                    implied_volatility=self._opt_float(row, "impliedVolatility"),
                    underlying_price=options_data.underlying_price,
                    source=options_data.source,
                    updated_at=now,
                ))

    @staticmethod
    def _opt_float(row, column: str) -> Optional[float]:
        """Return ``row[column]`` as a float, or None when absent or NaN."""
        if column in row and pd.notna(row[column]):
            return float(row[column])
        return None

    def _delete_options_rows(self, session, symbol: str,
                             expiration: Optional[datetime]) -> None:
        """Delete stored option rows for ``symbol`` (no commit).

        With ``expiration=None`` rows of every expiration are deleted.
        """
        query = session.query(OptionsData).filter(OptionsData.symbol == symbol)
        if expiration is not None:
            query = query.filter(OptionsData.expiration == expiration)
        query.delete(synchronize_session=False)

    # ------------------------------------------------------------ maintenance

    def get_cache_statistics(self) -> dict:
        """Return database row counts and per-key hit/miss counters.

        The result holds a ``"database"`` entry with the row counts plus one
        entry per cache key with a copy of its ``hits``/``misses`` counters.
        """
        session = self.Session()
        try:
            price_records = session.query(func.count(PriceHistory.id)).scalar() or 0
            options_records = session.query(func.count(OptionsData.id)).scalar() or 0
        finally:
            session.close()

        stats: dict = {
            "database": {
                "price_records": price_records,
                "options_records": options_records,
            },
        }
        for key, counters in self._stats.items():
            # Copy so callers cannot mutate the live counters.
            stats[key] = dict(counters)
        return stats

    def cleanup_old_data(self, max_age_days: int = 30) -> int:
        """Delete cached rows older than max_age_days. Returns rows removed.

        Raises:
            CacheError: If the deletion fails; the transaction is rolled back.
        """
        cutoff = datetime.utcnow() - timedelta(days=max_age_days)
        session = self.Session()
        try:
            deleted = session.query(PriceHistory).filter(
                PriceHistory.updated_at < cutoff
            ).delete(synchronize_session=False)
            deleted += session.query(OptionsData).filter(
                OptionsData.updated_at < cutoff
            ).delete(synchronize_session=False)
            session.commit()
        except Exception as exc:
            session.rollback()
            raise CacheError(f"Failed to clean up cache: {exc}") from exc
        finally:
            session.close()
        return deleted

    def _record(self, key: str, outcome: str) -> None:
        """Increment the hit counter for ``"hit"``, else the miss counter."""
        counters = self._stats.setdefault(key, {"hits": 0, "misses": 0})
        counters["hits" if outcome == "hit" else "misses"] += 1
@dataclass
class PriceData:
    """OHLCV price history for a single symbol."""

    symbol: str
    data: pd.DataFrame
    start_date: Optional[datetime] = None
    end_date: Optional[datetime] = None
    interval: str = "1d"
    source: str = "yfinance"

    @property
    def is_valid(self) -> bool:
        """True when the frame holds at least one row with a Close column."""
        return (
            self.data is not None
            and not self.data.empty
            and "Close" in self.data.columns
        )

    @property
    def latest_price(self) -> Optional[float]:
        """Most recent non-missing close price, or None when there is none.

        Missing (NaN) closes are skipped, so a trailing gap in the data never
        surfaces as a NaN "price".
        """
        if not self.is_valid:
            return None
        closes = self.data["Close"].dropna()
        if closes.empty:
            return None
        return float(closes.iloc[-1])


@dataclass
class OptionChainData:
    """Calls and puts for a single symbol and expiration."""

    symbol: str
    expiration: Optional[datetime]
    calls: pd.DataFrame
    puts: pd.DataFrame
    underlying_price: Optional[float] = None
    source: str = "yfinance"

    @property
    def is_valid(self) -> bool:
        """True when both the call and put frames hold at least one row."""
        return (
            self.calls is not None
            and self.puts is not None
            and not self.calls.empty
            and not self.puts.empty
        )
class DataProviderError(Exception):
    """Base error for all data provider failures."""


class InvalidSymbol(DataProviderError):
    """Raised when a symbol is malformed or empty."""


class DataNotAvailable(DataProviderError):
    """Raised when a provider returns no data for an otherwise valid request."""


# Tickers are letters/digits with optional dot or hyphen (e.g. BRK.B, RDS-A).
_SYMBOL_PATTERN = re.compile(r"^[A-Z0-9.\-]{1,12}$")


@dataclass
class MarketDataRequest:
    """Description of a price-history request."""

    symbol: str
    start: Optional[datetime] = None
    end: Optional[datetime] = None
    interval: str = "1d"


class RateLimiter:
    """Sliding-window rate limiter capping requests per minute."""

    def __init__(self, max_requests_per_minute: int = 30):
        """Create a limiter allowing ``max_requests_per_minute`` per 60 s.

        Raises:
            ValueError: If the limit is not positive. Such a limit can never
                be satisfied and previously failed with an IndexError on the
                first ``wait_if_needed`` call.
        """
        if max_requests_per_minute <= 0:
            raise ValueError(
                "max_requests_per_minute must be positive, "
                f"got {max_requests_per_minute!r}"
            )
        self.max_requests_per_minute = max_requests_per_minute
        # Monotonic timestamps of requests issued within the last 60 seconds.
        self.requests: List[float] = []
        self._lock = threading.Lock()

    def wait_if_needed(self) -> None:
        """Block until issuing another request stays within the limit."""
        # The lock is held while sleeping, so concurrent callers queue up
        # behind the sleeper instead of racing for the slot that frees up.
        with self._lock:
            now = time.monotonic()
            self.requests = [t for t in self.requests if t > now - 60.0]

            # The emptiness check keeps requests[0] safe even if the limit
            # attribute is changed to a non-positive value after construction.
            if self.requests and len(self.requests) >= self.max_requests_per_minute:
                # Sleep until the oldest request leaves the 60-second window.
                sleep_for = self.requests[0] + 60.0 - now
                if sleep_for > 0:
                    time.sleep(sleep_for)
                now = time.monotonic()
                self.requests = [t for t in self.requests if t > now - 60.0]

            self.requests.append(now)


def _parse_expiration(value: str) -> datetime:
    """Parse a yfinance expiration string (YYYY-MM-DD) to a datetime."""
    return datetime.strptime(value, "%Y-%m-%d")


class YFinanceProvider:
    """Market data provider backed by yfinance."""

    def __init__(self, rate_limiter: Optional[RateLimiter] = None,
                 requests_per_minute: int = 30):
        """Use ``rate_limiter`` when given, else build one from ``requests_per_minute``."""
        self.rate_limiter = rate_limiter or RateLimiter(requests_per_minute)

    def _validate_symbol(self, symbol: Optional[str]) -> str:
        """Normalize and validate a ticker symbol.

        Returns the upper-cased, stripped symbol. Raises InvalidSymbol for
        empty, non-string, or malformed input.
        """
        if not symbol or not isinstance(symbol, str):
            raise InvalidSymbol(f"Invalid symbol: {symbol!r}")
        cleaned = symbol.strip().upper()
        if not _SYMBOL_PATTERN.fullmatch(cleaned):
            raise InvalidSymbol(f"Invalid symbol: {symbol!r}")
        return cleaned

    def get_price_history(self, symbol: str, start: Optional[datetime] = None,
                          end: Optional[datetime] = None,
                          interval: str = "1d") -> PriceData:
        """Fetch OHLCV history. Raises DataNotAvailable when empty.

        The returned ``PriceData`` carries the normalized symbol and echoes
        the requested ``start``/``end`` as its ``start_date``/``end_date``.

        Raises:
            InvalidSymbol: If ``symbol`` fails validation.
            DataNotAvailable: If yfinance returns no rows.
        """
        symbol = self._validate_symbol(symbol)
        self.rate_limiter.wait_if_needed()

        ticker = yf.Ticker(symbol)
        data = ticker.history(start=start, end=end, interval=interval)

        if data is None or data.empty:
            raise DataNotAvailable(f"No price data available for {symbol}")

        return PriceData(
            symbol=symbol,
            data=data,
            start_date=start,
            end_date=end,
            interval=interval,
            source="yfinance",
        )

    def get_current_price(self, symbol: str) -> float:
        """Return the latest price, falling back to recent history.

        Raises:
            InvalidSymbol: If ``symbol`` fails validation.
            DataNotAvailable: If neither ``fast_info`` nor the one-day
                history yields a price.
        """
        symbol = self._validate_symbol(symbol)
        self.rate_limiter.wait_if_needed()

        ticker = yf.Ticker(symbol)
        last_price = getattr(ticker.fast_info, "last_price", None)
        if last_price is not None:
            return float(last_price)

        # fast_info had no price; use the close of the most recent daily bar.
        history = ticker.history(period="1d")
        if history is None or history.empty:
            raise DataNotAvailable(f"No current price available for {symbol}")
        return float(history["Close"].iloc[-1])

    def batch_get_prices(self, symbols: List[str]) -> Dict[str, PriceData]:
        """Fetch price history for several symbols, skipping failures.

        Results are keyed by the symbols exactly as passed in. Only
        ``DataProviderError`` failures are skipped; other exceptions propagate.
        """
        results: Dict[str, PriceData] = {}
        for symbol in symbols:
            try:
                results[symbol] = self.get_price_history(symbol)
            except DataProviderError:
                continue
        return results

    def get_options_chain(self, symbol: str,
                          expiration: Optional[datetime] = None) -> OptionChainData:
        """Fetch the option chain for a symbol and expiration.

        Args:
            symbol: Ticker symbol.
            expiration: Expiration to fetch. ``None`` selects the first
                expiration yfinance lists. A ``datetime`` is formatted as
                ``YYYY-MM-DD``; any other value is converted with ``str`` and
                must parse in that format.

        Raises:
            InvalidSymbol: If ``symbol`` fails validation.
            DataNotAvailable: If yfinance lists no expirations.
        """
        symbol = self._validate_symbol(symbol)
        self.rate_limiter.wait_if_needed()

        ticker = yf.Ticker(symbol)
        available = ticker.options
        if not available:
            raise DataNotAvailable(f"No options available for {symbol}")

        if expiration is None:
            expiration = _parse_expiration(available[0])
            expiration_str = available[0]
        elif isinstance(expiration, datetime):
            expiration_str = expiration.strftime("%Y-%m-%d")
        else:
            expiration_str = str(expiration)
            expiration = _parse_expiration(expiration_str)

        chain = ticker.option_chain(expiration_str)
        underlying = getattr(ticker.fast_info, "last_price", None)

        return OptionChainData(
            symbol=symbol,
            expiration=expiration,
            calls=chain.calls,
            puts=chain.puts,
            underlying_price=float(underlying) if underlying is not None else None,
            source="yfinance",
        )


_default_provider: Optional[YFinanceProvider] = None


def get_default_provider() -> YFinanceProvider:
    """Return a process-wide default YFinanceProvider instance.

    The instance is created on first use. Its rate limit comes from
    ``config.data.yfinance_requests_per_minute`` when that can be read,
    otherwise 30 requests per minute.
    """
    global _default_provider
    if _default_provider is None:
        requests_per_minute = 30
        try:
            from ..config import config
            requests_per_minute = config.data.yfinance_requests_per_minute
        except Exception:
            # Deliberate best effort: configuration is optional.
            pass
        _default_provider = YFinanceProvider(requests_per_minute=requests_per_minute)
    return _default_provider
# Shared declarative base; every cache table below registers on its metadata.
Base = declarative_base()


class PriceHistory(Base):
    """One OHLCV bar for a symbol/interval.

    A bar is unique per (symbol, date, interval), so replacing cached bars
    requires deleting the old rows before inserting new ones.
    """

    __tablename__ = "price_history"

    id = Column(Integer, primary_key=True, autoincrement=True)
    symbol = Column(String(16), nullable=False, index=True)
    # Timestamp of the bar.
    date = Column(DateTime, nullable=False)
    open = Column(Float, nullable=False)
    high = Column(Float, nullable=False)
    low = Column(Float, nullable=False)
    close = Column(Float, nullable=False)
    volume = Column(Float, nullable=False, default=0.0)
    interval = Column(String(8), nullable=False, default="1d")
    source = Column(String(32), nullable=False, default="yfinance")
    # Defaults to the naive UTC time at which the row is created.
    updated_at = Column(DateTime, nullable=False, default=datetime.utcnow)

    __table_args__ = (
        UniqueConstraint("symbol", "date", "interval",
                         name="uq_price_symbol_date_interval"),
    )


class OptionsData(Base):
    """One option contract row (call or put) for a symbol/expiration.

    Only symbol, option_type, strike, source and updated_at are required;
    the quote fields are nullable. No uniqueness constraint is declared on
    this table.
    """

    __tablename__ = "options_data"

    id = Column(Integer, primary_key=True, autoincrement=True)
    symbol = Column(String(16), nullable=False, index=True)
    expiration = Column(DateTime, nullable=True)
    option_type = Column(String(4), nullable=False)  # 'call' or 'put'
    strike = Column(Float, nullable=False)
    bid = Column(Float)
    ask = Column(Float)
    volume = Column(Float)
    open_interest = Column(Float)
    implied_volatility = Column(Float)
    underlying_price = Column(Float)
    source = Column(String(32), nullable=False, default="yfinance")
    # Defaults to the naive UTC time at which the row is created.
    updated_at = Column(DateTime, nullable=False, default=datetime.utcnow)


class CacheMetadata(Base):
    """Per-key cache bookkeeping (last refresh, hit/miss counters)."""

    __tablename__ = "cache_metadata"

    id = Column(Integer, primary_key=True, autoincrement=True)
    # One row per cache key.
    cache_key = Column(String(128), nullable=False, unique=True)
    symbol = Column(String(16), index=True)
    data_type = Column(String(16))  # 'price' or 'options'
    hits = Column(Integer, nullable=False, default=0)
    misses = Column(Integer, nullable=False, default=0)
    last_updated = Column(DateTime, nullable=False, default=datetime.utcnow)