diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml new file mode 100644 index 0000000..c45f823 --- /dev/null +++ b/.github/workflows/tests.yml @@ -0,0 +1,32 @@ +name: Unit tests + +on: + push: + branches: [main] + pull_request: + branches: [main] + +jobs: + test: + runs-on: ubuntu-latest + steps: + - name: Checkout + uses: actions/checkout@v4 + + - name: Install R + run: | + sudo apt-get update + sudo apt-get install -y r-base r-base-dev + + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: "3.12" + cache: pip + cache-dependency-path: requirements.txt + + - name: Install dependencies + run: pip install -r requirements.txt + + - name: Run unit tests + run: python -m unittest discover -s tests -v diff --git a/.gitignore b/.gitignore index 441ea33..051aa23 100644 --- a/.gitignore +++ b/.gitignore @@ -4,5 +4,8 @@ __pycache__/ .env data/*.db data/*.db-* +data/espn_active_players.json +data/ffanalytics_players.json +data/sleeper_players.json .DS_Store .streamlit/secrets.toml diff --git a/Dockerfile b/Dockerfile index 602b1e1..24a1d2d 100644 --- a/Dockerfile +++ b/Dockerfile @@ -4,18 +4,38 @@ WORKDIR /app ENV PYTHONDONTWRITEBYTECODE=1 \ PYTHONUNBUFFERED=1 \ - PIP_NO_CACHE_DIR=1 + PIP_NO_CACHE_DIR=1 \ + R_HOME=/usr/lib/R RUN apt-get update \ - && apt-get install -y --no-install-recommends curl \ + && apt-get install -y --no-install-recommends \ + curl \ + git \ + r-base \ + r-base-dev \ + libcurl4-openssl-dev \ + libssl-dev \ + libxml2-dev \ && rm -rf /var/lib/apt/lists/* +# CRAN may not ship ffanalytics for the distro R version; install from GitHub. +RUN Rscript -e '\ + install.packages("remotes", repos = "https://cloud.r-project.org"); \ + remotes::install_github( \ + "FantasyFootballAnalytics/ffanalytics", \ + upgrade = "never", \ + dependencies = TRUE \ + ) \ +' + COPY requirements.txt . RUN pip install -r requirements.txt COPY config.py . 
COPY api/ api/ COPY rag/ rag/ +COPY draft/ draft/ +COPY sdks/ sdks/ COPY app/ app/ COPY scripts/ scripts/ diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..2a156d6 --- /dev/null +++ b/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2026 jlee733 + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/README.md b/README.md index 87c1e40..6d4c2c5 100644 --- a/README.md +++ b/README.md @@ -1,18 +1,20 @@ # StatShift -Local-only RAG stack matching the architecture diagram: +Local fantasy football statistics assistant. Query seeded NFL/fantasy data through a read-only API, or chat with a local LLM for open-ended analysis. **Streamlit** → **RAG** → **FastAPI (read-only)** → **SQLite** **RAG** → **Ollama** → **Gemma** +Factual questions (stats, matchups, injuries) are answered from the database. Conversational prompts go to Gemma. 
+ ## Prerequisites - **Docker (recommended):** Docker Desktop or Docker Engine with Compose v2 - **Or local Python 3.11+** and [Ollama](https://ollama.com) with Gemma (`ollama pull gemma2:2b`) -## Run with Docker +The Docker image includes R, [ffanalytics](https://github.com/FantasyFootballAnalytics/ffanalytics), and `rpy2` for the Mock Draft tab — no separate R install when using Compose. -Yes — you interact with Streamlit in your **browser on your machine**. The UI container listens on `0.0.0.0:8501` and Compose publishes it to `localhost:8501`. Buttons, text input, and reruns all work normally; only the server runs inside Docker. +## Run with Docker ```bash docker compose up --build @@ -28,7 +30,7 @@ Optional: API docs at http://localhost:8000/docs # Stop docker compose down -# Reset DB + models +# Reset DB + models (re-seeds fantasy football sample data) docker compose down -v ``` @@ -40,7 +42,7 @@ Services: | `api` | FastAPI read-only API | 8000 | | `ollama` | Gemma inference | 11434 | -Inside the Compose network, the UI talks to `http://api:8000` and `http://ollama:11434`. You do not need to call those from the browser. +Inside the Compose network, the UI talks to `http://api:8000` and `http://ollama:11434`. ## Local setup (without Docker) @@ -73,7 +75,37 @@ chmod +x scripts/run_api.sh scripts/run_ui.sh ./scripts/run_ui.sh ``` -Open http://127.0.0.1:8501 and ask questions about the seeded basketball stats. +Open http://127.0.0.1:8501 and ask about players, weekly matchups, injuries, or PPR scoring from the seeded data. + +If you previously ran an older build with different seed data, delete `data/statshift.db` or run `docker compose down -v` before re-initializing. + +## Mock Draft (Docker Compose) + +With the stack running (`docker compose up --build`), open **http://localhost:8501** → **Mock Draft**. The player pool comes from **ffanalytics** (projections for Standard / Half-PPR / PPR plus ADP). R and ffanalytics are already in the `ui` image. 
+ +The cache lives in the **`statshift_data` volume** at `/app/data/ffanalytics_players.json` (24-hour TTL). It survives container restarts until you run `docker compose down -v`. + +**First visit:** opening Mock Draft or clicking **Load players** triggers a scrape (often several minutes). **Later visits** use the cache. + +Optional — pre-warm the cache before using the UI: + +```bash +# One-off sync service (stack does not need to be up) +docker compose --profile sync run --rm ffanalytics-sync + +# Or, while ui is running +docker compose exec ui python scripts/sync_ffanalytics_players.py +``` + +After rebuilding the image (`docker compose up --build`), run a sync again if ffanalytics or R packages changed. + +## Tests + +```bash +python -m unittest discover -s tests -v +``` + +CI runs the same suite on pull requests and pushes to `main` (ffanalytics parsing tests do not call R). ## Configuration @@ -96,3 +128,13 @@ API_BASE_URL=http://127.0.0.1:8000 | `GET /categories` | Distinct categories | `POST`, `PUT`, `PATCH`, and `DELETE` return **405** — writes are blocked at the API layer; SQLite is opened in read-only mode for queries. + +## Sample data categories + +| Category | Examples | +|----------|----------| +| `player` | McCaffrey workload, Chase targets, Kelce red-zone usage | +| `team` | Bills pace and scoring | +| `matchup` | Week 12 RB defensive matchups | +| `injury` | Week 15 practice reports | +| `league` | PPR scoring, passing leaders | diff --git a/api/main.py b/api/main.py index 29a8693..03d9c78 100644 --- a/api/main.py +++ b/api/main.py @@ -13,7 +13,7 @@ app = FastAPI( title="StatShift API", - description="Local read-only API over SQLite. Write operations are blocked.", + description="Read-only fantasy football stats API over SQLite. 
Write operations are blocked.", version="0.1.0", ) diff --git a/app/mock_draft_tab.py b/app/mock_draft_tab.py new file mode 100644 index 0000000..2f549eb --- /dev/null +++ b/app/mock_draft_tab.py @@ -0,0 +1,497 @@ +"""Streamlit UI for mock fantasy drafts.""" + +from __future__ import annotations + +import html + +import streamlit as st +from streamlit_autorefresh import st_autorefresh + +from draft.engine import MockDraftEngine, _snake_team_index +from draft.ffanalytics_loader import cache_is_fresh, load_active_players, load_player_cache +from draft.models import POSITION_COLORS, ROSTER_SLOTS, Position, RankingSource, ScoringFormat +from draft.player_pool import refresh_players +from draft.rpy2_setup import rpy2_session + + +def _init_draft_state() -> None: + if "mock_draft" not in st.session_state: + st.session_state["mock_draft"] = None + if "draft_pool" not in st.session_state: + st.session_state["draft_pool"] = None + + +def _ensure_player_pool(*, force_refresh: bool = False) -> list: + if force_refresh: + with rpy2_session(): + with st.spinner("Loading players…"): + st.session_state["draft_pool"] = refresh_players(force_refresh=True) + elif st.session_state.get("draft_pool") is None: + # Use JSON cache only on tab open — avoid a long R scrape until user clicks Refresh + st.session_state["draft_pool"] = load_player_cache() or [] + return st.session_state["draft_pool"] + + +def _position_colors(position: str) -> tuple[str, str]: + """Return (background, text) hex colors for a position.""" + return POSITION_COLORS.get(position, ("#374151", "#FFFFFF")) + + +def _player_pick_html( + name: str, + position: str, + team: str = "", + *, + highlight: bool = False, +) -> str: + """HTML for a Sleeper-style colored pick cell.""" + bg, fg = _position_colors(position) + border = "2px solid #FBBF24" if highlight else "1px solid rgba(255,255,255,0.15)" + safe_name = html.escape(name) + safe_pos = html.escape(position) + safe_team = html.escape(team) if team else "" + team_line 
= f'
{safe_pos}' + if safe_team: + team_line += f" · {safe_team}" + team_line += "
" + return ( + f'
' + f'
{safe_name}
' + f"{team_line}
" + ) + + +def _on_clock_html(text: str, *, is_user: bool) -> str: + bg = "#F59E0B" if is_user else "#4B5563" + return ( + f'
' + f"{html.escape(text)}
" + ) + + +def _empty_cell_html(round_num: int) -> str: + return ( + f'
R{round_num}
' + ) + + +def _render_position_legend() -> None: + chips = [] + for pos in ("QB", "RB", "WR", "TE", "K", "DEF"): + bg, fg = _position_colors(pos) + chips.append( + f'{pos}' + ) + st.markdown("".join(chips), unsafe_allow_html=True) + + +def _render_settings( + pool_size: int, +) -> tuple[ScoringFormat, int, int, int, float, RankingSource, str] | None: + st.subheader("League settings") + col1, col2, col3, col4, col5, col6 = st.columns(6) + + with col1: + scoring_label = st.selectbox( + "Scoring", + options=[s.value for s in ScoringFormat], + index=0, + ) + with col2: + ranking_options = [r.value for r in RankingSource] + ranking_label = st.selectbox( + "Rankings", + options=ranking_options, + index=0, + help="Consensus rankings source for CPU picks and best available.", + ) + with col3: + league_size = st.selectbox("League size", options=[8, 10, 12, 14], index=2) + with col4: + draft_slot = st.number_input( + "Your draft slot", + min_value=1, + max_value=int(league_size), + value=min(5, int(league_size)), + step=1, + ) + with col5: + if pool_size: + max_rounds = max(1, pool_size // max(int(league_size), 1)) + else: + max_rounds = 15 + rounds = st.number_input( + "Rounds", + min_value=1, + max_value=max(1, max_rounds), + value=min(15, max(1, max_rounds)), + step=1, + disabled=pool_size == 0, + ) + with col6: + user_team_name = st.text_input( + "Your team name", + value="My Team", + max_chars=25, + ) + + cpu_randomness = st.slider( + "Draft unpredictability", + min_value=0.0, + max_value=1.0, + value=0.35, + step=0.05, + help="Higher values make other teams less predictable.", + ) + + scoring = ScoringFormat(scoring_label) + ranking_source = RankingSource(ranking_label) + if pool_size: + st.caption(f"{pool_size} players · up to {max_rounds} rounds") + return scoring, int(league_size), int(draft_slot), int(rounds), float(cpu_randomness), ranking_source, user_team_name + + +def _render_sleeper_draft_board(draft: MockDraftEngine) -> None: + """Render Sleeper-style grid 
board (teams as columns, rounds as rows).""" + st.subheader("Draft Board") + _render_position_legend() + + num_teams = draft.settings.league_size + num_rounds = draft.settings.rounds + user_team_idx = draft.settings.draft_slot - 1 + + pick_grid: dict[tuple[int, int], tuple[str, str, str]] = {} + for pick in draft.picks: + pick_grid[(pick.round, pick.team_index)] = ( + pick.player.name, + pick.player.position.value, + pick.player.team, + ) + + current_round = draft.current_round + current_team = draft.current_team_index if not draft.is_complete else -1 + + header_cols = st.columns(num_teams) + for i, col in enumerate(header_cols): + team_name = draft.get_team_name(i) + suffix = " 🏈" if i == user_team_idx else "" + col.markdown( + f'
{html.escape(team_name)}{suffix}
', + unsafe_allow_html=True, + ) + + for round_num in range(1, num_rounds + 1): + row_cols = st.columns(num_teams) + + for pick_in_round in range(1, num_teams + 1): + team_idx = _snake_team_index(round_num, pick_in_round, num_teams) + col = row_cols[team_idx] + + pick_data = pick_grid.get((round_num, team_idx)) + is_current_pick = round_num == current_round and team_idx == current_team + is_user_team = team_idx == user_team_idx + + if pick_data: + name, position, team = pick_data + pick_html = _player_pick_html( + name, + position, + team, + highlight=is_user_team, + ) + col.markdown(pick_html, unsafe_allow_html=True) + elif is_current_pick: + remaining = draft.time_remaining() + if draft.is_user_turn: + text = f"⏱️ YOUR PICK ({remaining}s)" + else: + text = f"⏱️ On Clock ({remaining}s)" + col.markdown( + _on_clock_html(text, is_user=draft.is_user_turn), + unsafe_allow_html=True, + ) + else: + col.markdown(_empty_cell_html(round_num), unsafe_allow_html=True) + + +def _render_user_roster(draft: MockDraftEngine) -> None: + st.subheader("Your roster") + roster = draft.user_roster() + if not roster.picks: + st.caption("No players drafted yet.") + return + + pos_order = {p: i for i, p in enumerate(Position)} + sorted_picks = sorted(roster.picks, key=lambda p: pos_order.get(p.position, 99)) + + for p in sorted_picks: + fpg = round(p.fantasy_points(draft.settings.scoring), 1) + bg, fg = _position_colors(p.position.value) + st.markdown( + f'
' + f'
' + f"{p.position.value}
" + f'
' + f'
{p.name}
' + f'
{p.team} · {fpg} FPG
' + f"
", + unsafe_allow_html=True, + ) + + filled = [] + for slot, need in ROSTER_SLOTS.items(): + if slot == "FLEX": + continue + pos_key = slot + have = sum(1 for p in roster.picks if p.position.value == pos_key) + filled.append(f"{slot}: {have}/{need}") + st.caption("Starters · " + " · ".join(filled)) + + +def _render_pick_timer(draft: MockDraftEngine) -> None: + """Render the countdown timer for the current pick.""" + remaining = draft.time_remaining() + + if remaining <= 10: + st.error(f"⏱️ **{remaining}** seconds remaining!") + elif remaining <= 20: + st.warning(f"⏱️ **{remaining}** seconds remaining") + else: + st.info(f"⏱️ **{remaining}** seconds remaining") + + +def _handle_timer_expiration(draft: MockDraftEngine) -> bool: + """Auto-draft for the team on the clock when their 30s expires. One pick per timeout.""" + if draft.is_timer_expired() and not draft.is_complete: + draft.auto_pick() + if not draft.is_complete: + draft.start_pick_timer() + return True + return False + + +def _render_pick_controls(draft: MockDraftEngine) -> None: + if draft.is_complete: + if draft.pool_exhausted and len(draft.picks) < draft.settings.league_size * draft.requested_rounds: + st.warning( + f"Draft stopped early: player pool ran out after **{len(draft.picks)}** picks " + f"(round {draft.picks[-1].round if draft.picks else 0})." 
+ ) + else: + st.success("Draft complete!") + return + + if _handle_timer_expiration(draft): + st.rerun() + + st.subheader("On the clock") + round_num = draft.current_round + overall = draft.current_overall + team_name = draft.get_team_name(draft.current_team_index) + + _render_pick_timer(draft) + st.caption(f"Each team has **{draft.pick_time_limit}** seconds to pick.") + + if draft.is_user_turn: + st.info(f"Round {round_num} · Pick {overall} — **{team_name} (You)**") + ranked = draft.rank_available()[:25] + if not ranked: + st.error("No players left in the pool.") + return + options = {} + for p in ranked: + rank = p.rank_for_source(draft.ranking_source) + rank_str = f"#{int(rank)}" if rank < 999 else "—" + label = ( + f"{rank_str} {p.name} ({p.position.value}, {p.team}) — " + f"{p.fantasy_points(draft.settings.scoring):.1f} FPG" + ) + options[label] = p + choice = st.selectbox("Select player", options=list(options.keys())) + if st.button("Draft player", type="primary", use_container_width=True): + draft.make_pick(options[choice]) + if not draft.is_complete: + draft.start_pick_timer() + st.rerun() + else: + st.warning( + f"Round {round_num} · Pick {overall} — **{team_name}** is on the clock " + f"({draft.time_remaining()}s left)" + ) + if st.button("Skip to your pick", use_container_width=True): + while not draft.is_complete and not draft.is_user_turn: + draft.auto_pick() + if not draft.is_complete: + draft.start_pick_timer() + st.rerun() + + +def render_mock_draft_tab() -> None: + _init_draft_state() + + # Auto-load players if cache is missing or stale and not already attempted + if "auto_load_attempted" not in st.session_state: + st.session_state["auto_load_attempted"] = False + + if not st.session_state["auto_load_attempted"]: + if not cache_is_fresh(): + st.session_state["auto_load_attempted"] = True + try: + with rpy2_session(): + with st.spinner("Loading player pool…"): + pool_loaded = load_active_players(force_refresh=True) + 
refresh_players(force_refresh=True) + st.session_state["draft_pool"] = pool_loaded + st.session_state["players_loaded_successfully"] = True + except Exception as exc: + st.error(f"Could not auto-load players: {exc}") + st.session_state["players_loaded_successfully"] = False + else: + # Cache is fresh, just load from cache + st.session_state["auto_load_attempted"] = True + cached = load_player_cache() + if cached: + st.session_state["draft_pool"] = cached + st.session_state["players_loaded_successfully"] = True + + pool = _ensure_player_pool() + pool_size = len(pool) + + load_col, _ = st.columns([1, 3]) + with load_col: + if st.button("Load players", use_container_width=True): + try: + with rpy2_session(): + with st.spinner("Loading player pool…"): + pool_loaded = load_active_players(force_refresh=True) + refresh_players(force_refresh=True) + st.session_state["draft_pool"] = pool_loaded + st.session_state["players_loaded_successfully"] = True + if not pool_loaded: + st.error("No players were returned. 
Try again in a few minutes.") + st.session_state["players_loaded_successfully"] = False + else: + st.rerun() + except Exception as exc: + st.error(f"Could not load players: {exc}") + st.session_state["players_loaded_successfully"] = False + pool = _ensure_player_pool() + pool_size = len(pool) + + if st.session_state.get("players_loaded_successfully"): + st.success("Players successfully loaded!", icon="✅") + + if pool_size == 0: + st.info("Load players to start a mock draft.") + + settings = _render_settings(pool_size) + if settings is None: + return + + scoring, league_size, draft_slot, rounds, cpu_randomness, ranking_source, user_team_name = settings + + slot_cols = st.columns([1, 1, 2]) + with slot_cols[0]: + if st.button( + "Start new draft", + type="primary", + use_container_width=True, + disabled=pool_size == 0, + ): + try: + st.session_state["mock_draft"] = MockDraftEngine.create( + scoring=scoring, + league_size=league_size, + draft_slot=draft_slot, + rounds=rounds, + pool=pool, + cpu_randomness=cpu_randomness, + ranking_source=ranking_source, + user_team_name=user_team_name, + ) + st.rerun() + except ValueError as exc: + st.error(str(exc)) + with slot_cols[1]: + if st.button("Reset", use_container_width=True): + st.session_state["mock_draft"] = None + st.rerun() + + draft: MockDraftEngine | None = st.session_state.get("mock_draft") + if draft is None: + return + + st_autorefresh(interval=1000, key="draft_timer_refresh") + + meta = draft.settings + st.caption( + f"{meta.scoring.value} · {draft.ranking_source.value} rankings · " + f"{meta.league_size} teams · pick {meta.draft_slot} · " + f"round {draft.current_round} · " + f"#{min(draft.current_overall, draft.total_picks)} of {draft.total_picks}" + ) + + _render_pick_controls(draft) + + st.divider() + + _render_sleeper_draft_board(draft) + + st.divider() + + roster_col, best_col = st.columns([1, 1]) + with roster_col: + _render_user_roster(draft) + with best_col: + st.subheader("Best available") + for player 
in draft.rank_available()[:8]: + rank = player.rank_for_source(draft.ranking_source) + rank_str = f"#{int(rank)}" if rank < 999 else "—" + st.write( + f"**{player.name}** ({player.position.value}, {player.team}) — " + f"{player.fantasy_points(draft.settings.scoring):.1f} FPG · {rank_str}" + ) + + _render_monte_carlo_lookahead(draft) + + +def _render_monte_carlo_lookahead(draft: MockDraftEngine) -> None: + if draft.is_complete or draft.pool_exhausted or not draft.is_user_turn: + return + + num_sims = 40 + with st.spinner("Calculating pick outlook…"): + probabilities = draft.simulate_availability_at_next_user_pick(num_sims=num_sims) + + ranked = draft.rank_available()[:10] + if not ranked: + return + + rows = [] + for player in ranked: + pct = probabilities.get(player.name, 0.0) * 100 + rank = player.rank_for_source(draft.ranking_source) + rows.append( + { + "Player": player.name, + "Pos": player.position.value, + "Team": player.team, + "Rank": int(rank) if rank < 999 else "—", + "FPG": round(player.fantasy_points(draft.settings.scoring), 1), + "Avail next pick": f"{pct:.0f}%", + "_pct": pct, + } + ) + + rows.sort(key=lambda r: r["_pct"], reverse=True) + for row in rows: + del row["_pct"] + + st.subheader("Likely available next pick") + st.dataframe(rows, use_container_width=True, hide_index=True) diff --git a/app/research_tab.py b/app/research_tab.py new file mode 100644 index 0000000..1e97eaa --- /dev/null +++ b/app/research_tab.py @@ -0,0 +1,467 @@ +"""Streamlit UI for NFL player research.""" + +from __future__ import annotations + +import html + +import matplotlib.pyplot as plt +import numpy as np +import streamlit as st + +from sdks.distribution_fit import fit_best_distribution, format_params_string +from sdks.espn_player_loader import ( + CollegeSeasonStats, + CombineMetrics, + GameLogEntry, + InjuryInfo, + NewsArticle, + PlayerProfile, + SeasonStats, + calculate_fantasy_points, + get_available_seasons, + get_player_college_stats, + get_player_combine, + 
get_player_gamelog, + get_player_injuries, + get_player_news, + get_player_profile, + get_player_seasons, + get_player_stats, + is_upcoming_rookie, + search_players, +) + + +def _init_research_state() -> None: + if "research_selected_player" not in st.session_state: + st.session_state["research_selected_player"] = None + if "research_search_results" not in st.session_state: + st.session_state["research_search_results"] = [] + + +def _render_search() -> None: + """Render the player search section.""" + st.subheader("Search Players") + + if not st.session_state.get("research_selected_player"): + st.info("Search for a player below to view their profile, stats, and news.") + + with st.form("research_player_search", clear_on_submit=False): + col1, col2 = st.columns([4, 1]) + with col1: + query = st.text_input( + "Enter player name", + placeholder="e.g. Patrick Mahomes", + label_visibility="collapsed", + ) + with col2: + search_clicked = st.form_submit_button( + "Search", type="primary", use_container_width=True + ) + + if search_clicked and query.strip(): + with st.spinner("Searching..."): + results = search_players(query.strip()) + st.session_state["research_search_results"] = results + if not results: + st.warning("No players found. 
Try a different search term.") + + results = st.session_state.get("research_search_results", []) + if results and not st.session_state.get("research_selected_player"): + st.caption(f"Found {len(results)} active player(s)") + + for player in results: + img_col, info_col, btn_col = st.columns([1, 5, 1]) + with img_col: + headshot = player.get("headshot", "") + if headshot: + st.image(headshot, width=72) + else: + st.markdown("—") + with info_col: + display = f"**{player['name']}** — {player['position']}, {player['team']}" + st.markdown(display) + with btn_col: + if st.button("Select", key=f"select_{player['id']}", use_container_width=True): + st.session_state["research_selected_player"] = player["id"] + st.session_state["research_search_results"] = [] + st.rerun() + + +def _college_stats_to_rows(seasons: list[CollegeSeasonStats]) -> list[dict]: + """Convert college season stats to dataframe rows, omitting empty columns.""" + columns = [ + ("Season", "season"), + ("Team", "team"), + ("GP", "games_played"), + ("Cmp", "completions"), + ("Pass Yds", "passing_yards"), + ("Pass TD", "passing_tds"), + ("INT", "interceptions"), + ("Rush Att", "rush_attempts"), + ("Rush Yds", "rushing_yards"), + ("Rush TD", "rushing_tds"), + ("Rec", "receptions"), + ("Rec Yds", "receiving_yards"), + ("Rec TD", "receiving_tds"), + ] + raw_rows = [] + for season in seasons: + raw_rows.append({label: getattr(season, attr) for label, attr in columns}) + + active_labels = ["Season"] + for label, _ in columns[1:]: + if any(row.get(label, "—") not in ("—", "0", "0.0", "") for row in raw_rows): + active_labels.append(label) + + return [{label: row[label] for label in active_labels} for row in raw_rows] + + +def _render_college_stats(profile: PlayerProfile) -> None: + """Show college stats table for pre-season rookies.""" + with st.spinner("Loading college stats..."): + seasons = get_player_college_stats(profile.id) + + if not seasons: + st.info("No college statistics available for this player.") + 
return + + st.subheader("College stats") + rows = _college_stats_to_rows(seasons) + st.dataframe(rows, use_container_width=True, hide_index=True) + + +def _render_profile_card(profile: PlayerProfile) -> None: + """Render the player profile card.""" + st.divider() + + col1, col2 = st.columns([1, 3]) + + with col1: + if profile.headshot_url: + st.image(profile.headshot_url, width=150) + else: + st.markdown("*No photo available*") + + with col2: + rookie = is_upcoming_rookie(profile.draft_year) + name_line = f"## {profile.name}" + if rookie: + name_line += ' Rookie' + st.markdown(name_line, unsafe_allow_html=True) + st.markdown(f"**{profile.position}** | #{profile.jersey} | {profile.team}") + + info_col1, info_col2, info_col3 = st.columns(3) + with info_col1: + st.metric("Height", profile.height) + st.metric("College", profile.college) + with info_col2: + st.metric("Weight", profile.weight) + st.metric("Draft", profile.draft_info) + with info_col3: + if profile.age: + st.metric("Age", profile.age) + st.metric("Experience", f"{profile.experience} yrs") + + status_color = "green" if profile.status == "Active" else "orange" + st.markdown(f"**Status:** :{status_color}[{profile.status}]") + + if is_upcoming_rookie(profile.draft_year): + _render_college_stats(profile) + + +def _create_timeseries_chart(gamelog: list[GameLogEntry], scoring_key: str, scoring_label: str) -> plt.Figure: + """Create time series chart of fantasy points by week.""" + fig, ax = plt.subplots(figsize=(6, 3.5)) + + if not gamelog: + ax.text(0.5, 0.5, "No data available", ha="center", va="center", transform=ax.transAxes) + ax.set_xlabel("Week") + ax.set_ylabel("Fantasy Points") + return fig + + weeks = [] + points = [] + + for game in gamelog: + fp = calculate_fantasy_points(game, scoring_key) + weeks.append(game.week) + points.append(fp) + + if weeks: + ax.plot(weeks, points, marker='o', linestyle='-', linewidth=2, markersize=6, + color='#2E86AB', markerfacecolor='#2E86AB', markeredgecolor='white', 
markeredgewidth=1.5) + ax.fill_between(weeks, points, alpha=0.2, color='#2E86AB') + + max_week = max(weeks) + ax.set_xlim(0.5, max_week + 0.5) + ax.set_xticks(range(1, max_week + 1)) + + if points: + y_max = max(points) * 1.1 if max(points) > 0 else 10 + ax.set_ylim(0, y_max) + + ax.set_xlabel("Week", fontsize=10) + ax.set_ylabel(f"Fantasy Points ({scoring_label})", fontsize=10) + ax.set_title("Weekly Performance", fontsize=11, fontweight='bold') + ax.grid(alpha=0.3, linestyle='--', linewidth=0.5) + ax.spines['top'].set_visible(False) + ax.spines['right'].set_visible(False) + plt.tight_layout() + + return fig + + +def _create_histogram_with_fit(fp_data: list[float], scoring_label: str) -> plt.Figure: + """Create histogram with fitted distribution overlay.""" + nonzero = [x for x in fp_data if x > 0] + + fig, ax = plt.subplots(figsize=(6, 3.5)) + + if not nonzero: + ax.text(0.5, 0.5, "No data available", ha="center", va="center", transform=ax.transAxes) + ax.set_xlabel("Fantasy Points") + ax.set_ylabel("Density") + return fig + + ax.hist(nonzero, bins="auto", density=True, alpha=0.7, color="steelblue", edgecolor="white", label="Data") + + if len(nonzero) >= 5: + dist_name, params, pdf_fn = fit_best_distribution(nonzero) + x = np.linspace(min(nonzero), max(nonzero), 100) + ax.plot(x, pdf_fn(x), "r-", lw=2, label=dist_name) + + param_str = format_params_string(params) + ax.set_title(f"{dist_name} Distribution\n({param_str})", fontsize=10, fontweight='bold') + else: + ax.set_title(f"Fantasy Points Distribution\n(Need 5+ games for model fit, have {len(nonzero)})", fontsize=9, fontweight='bold') + + ax.set_xlabel(f"Fantasy Points ({scoring_label})", fontsize=10) + ax.set_ylabel("Density", fontsize=10) + ax.legend(loc="upper right") + ax.grid(alpha=0.3, linestyle='--', linewidth=0.5) + ax.spines['top'].set_visible(False) + ax.spines['right'].set_visible(False) + plt.tight_layout() + + return fig + + +def _build_stats_table(gamelog: list[GameLogEntry], scoring_key: str) -> 
list[dict]: + """Build stats table rows from game log.""" + rows = [] + for game in gamelog: + fp = calculate_fantasy_points(game, scoring_key) + rows.append({ + "Week": game.week, + "Pass Yds": game.passing_yards, + "Pass TD": game.passing_tds, + "INT": game.interceptions, + "Rush Yds": game.rushing_yards, + "Rush TD": game.rushing_tds, + "Rec": game.receptions, + "Rec Yds": game.receiving_yards, + "Rec TD": game.receiving_tds, + "Fum": game.fumbles_lost, + "FP": fp, + }) + return rows + + +def _render_stats_tab(player_id: str) -> None: + """Render the statistics sub-tab with season checkboxes and histogram.""" + scoring_options = ["PPR", "Half-PPR", "Standard"] + scoring_label = st.selectbox("Scoring Format", options=scoring_options, index=0) + scoring_key = scoring_label.lower().replace("-", "_") + + with st.spinner("Loading available seasons..."): + seasons = get_player_seasons(player_id) + + if not seasons: + st.info("No season data available for this player.") + return + + selected_seasons = st.multiselect( + "Select Seasons", + options=seasons, + default=[seasons[0]] if seasons else [], + ) + + if not selected_seasons: + st.info("Select one or more seasons above to view stats.") + return + + for season in selected_seasons: + st.subheader(f"{season} Season") + + with st.spinner(f"Loading {season} game log..."): + gamelog = get_player_gamelog(player_id, season) + + if not gamelog: + st.warning(f"No game data available for {season}.") + continue + + fp_list = [calculate_fantasy_points(g, scoring_key) for g in gamelog] + + col_table, col_charts = st.columns(2) + + with col_table: + st.caption(f"Games: {len(gamelog)}") + rows = _build_stats_table(gamelog, scoring_key) + st.dataframe(rows, use_container_width=True, hide_index=True) + + with col_charts: + nonzero_count = len([x for x in fp_list if x > 0]) + st.caption(f"Games with stats: {nonzero_count}") + + fig_ts = _create_timeseries_chart(gamelog, scoring_key, scoring_label) + st.pyplot(fig_ts) + 
plt.close(fig_ts) + + fig_hist = _create_histogram_with_fit(fp_list, scoring_label) + st.pyplot(fig_hist) + plt.close(fig_hist) + + st.divider() + + +def _render_combine_tab(player_id: str) -> None: + """Render the combine metrics sub-tab.""" + with st.spinner("Loading combine data..."): + combine = get_player_combine(player_id) + + if not combine: + st.info("No combine data available for this player.") + return + + if combine.year: + st.caption(f"NFL Combine {combine.year}") + + col1, col2 = st.columns(2) + + with col1: + st.metric("40-Yard Dash", f"{combine.forty_yard}s" if combine.forty_yard else "—") + st.metric("Bench Press", f"{combine.bench_press} reps" if combine.bench_press else "—") + st.metric("3-Cone Drill", f"{combine.three_cone}s" if combine.three_cone else "—") + + with col2: + st.metric("Vertical Jump", f'{combine.vertical_jump}"' if combine.vertical_jump else "—") + st.metric("Broad Jump", f'{combine.broad_jump}"' if combine.broad_jump else "—") + st.metric("20-Yard Shuttle", f"{combine.shuttle}s" if combine.shuttle else "—") + + all_none = all([ + combine.forty_yard is None, + combine.vertical_jump is None, + combine.bench_press is None, + combine.broad_jump is None, + combine.three_cone is None, + combine.shuttle is None, + ]) + + if all_none: + st.info("Combine metrics not available. 
Player may not have participated in the NFL Combine or data is not publicly available.") + + +def _render_injuries_tab(player_id: str) -> None: + """Render the injuries sub-tab.""" + with st.spinner("Loading injury data..."): + injuries = get_player_injuries(player_id) + + if not injuries: + st.success("No injuries reported.") + return + + for injury in injuries: + status_color = "green" if injury.status == "Active" else "red" + + st.markdown(f"### :{status_color}[{injury.status}]") + + if injury.injury_type and injury.injury_type != "—": + st.markdown(f"**Type:** {injury.injury_type}") + + if injury.details and injury.details != "—": + st.markdown(f"**Details:** {injury.details}") + + if injury.date and injury.date != "—": + st.caption(f"Date: {injury.date}") + + st.divider() + + +def _render_news_tab(player_id: str) -> None: + """Render the news sub-tab.""" + with st.spinner("Loading news..."): + articles = get_player_news(player_id) + + if not articles: + st.info("No recent news articles found.") + return + + for article in articles: + col1, col2 = st.columns([1, 4]) + + with col1: + if article.image_url: + st.image(article.image_url, width=120) + + with col2: + if article.link: + st.markdown(f"### [{article.headline}]({article.link})") + else: + st.markdown(f"### {article.headline}") + + if article.description: + st.markdown( + f'

' + f"{html.escape(article.description)}

", + unsafe_allow_html=True, + ) + + st.caption(article.published) + + st.divider() + + +def render_research_tab() -> None: + """Main entry point for the Research tab.""" + _init_research_state() + + _render_search() + + player_id = st.session_state.get("research_selected_player") + + if not player_id: + return + + with st.spinner("Loading player profile..."): + profile = get_player_profile(player_id) + + if not profile: + st.error("Could not load player profile. Please try again.") + st.session_state["research_selected_player"] = None + return + + if st.button("← Back to Search"): + st.session_state["research_selected_player"] = None + st.session_state["research_search_results"] = [] + st.rerun() + + _render_profile_card(profile) + + st.divider() + + stats_tab, combine_tab, injuries_tab, news_tab = st.tabs([ + "Stats", "Combine", "Injuries", "News" + ]) + + with stats_tab: + _render_stats_tab(player_id) + + with combine_tab: + _render_combine_tab(player_id) + + with injuries_tab: + _render_injuries_tab(player_id) + + with news_tab: + _render_news_tab(player_id) diff --git a/app/streamlit_app.py b/app/streamlit_app.py index c2892f0..475a5a9 100644 --- a/app/streamlit_app.py +++ b/app/streamlit_app.py @@ -11,9 +11,8 @@ if str(ROOT) not in sys.path: sys.path.insert(0, str(ROOT)) -from config import settings -from rag.engine import RAGEngine -from rag.ollama_client import OllamaError +from app.ui_styles import inject_global_styles +from draft.rpy2_setup import init_rpy2_on_main_thread st.set_page_config( page_title="StatShift", @@ -21,73 +20,35 @@ layout="wide", ) -st.title("StatShift") -st.caption("Local prototype — Streamlit → RAG → FastAPI (read-only) → SQLite → Ollama/Gemma") -engine = RAGEngine() +@st.cache_resource +def _rpy2_ready() -> bool: + """Initialize rpy2 once on Streamlit's main thread (required for Mock Draft).""" + init_rpy2_on_main_thread() + return True -with st.sidebar: - st.header("System status") - if st.button("Check health", 
use_container_width=True): - st.session_state["health"] = engine.health() - health = st.session_state.get("health") - if health: - st.write("API", "✅" if health["api_ok"] else "❌", health.get("api_detail", "")) - st.write("Ollama", "✅" if health["ollama_ok"] else "❌") - if health.get("ollama_models"): - st.caption("Models: " + ", ".join(health["ollama_models"][:5])) - st.caption(f"Configured model: {health['configured_model']}") - st.divider() - st.markdown( - """ - **Docker** - `docker compose up --build` → open http://localhost:8501 +_rpy2_ready() +inject_global_styles() - **Local** - 1. `python scripts/init_db.py` - 2. `uvicorn api.main:app --reload` - 3. `ollama pull gemma2:2b` - 4. `streamlit run app/streamlit_app.py` - """ - ) - st.caption(f"API: {settings.api_base_url}") - st.caption(f"Ollama: {settings.ollama_base_url}") +VIEWS_DIR = Path(__file__).parent / "views" -query = st.text_input( - "Ask about players, teams, or league stats", - placeholder="e.g. How did Victor Wembanyama perform in 2024-25?", +home_page = st.Page( + str(VIEWS_DIR / "home_page.py"), + title="Home", + icon="🏠", + default=True, +) +ask_page = st.Page( + str(VIEWS_DIR / "ask_page.py"), + title="Ask", + icon="💬", +) +research_page = st.Page( + str(VIEWS_DIR / "research_page.py"), + title="Research", + icon="🔍", ) -col1, col2 = st.columns([1, 4]) -with col1: - run = st.button("Run RAG", type="primary", use_container_width=True) -with col2: - show_prompt = st.checkbox("Show prompt sent to Gemma") - -if run and query.strip(): - with st.spinner("Retrieving context and generating answer..."): - try: - result = engine.ask(query.strip()) - st.session_state["last_result"] = result - except OllamaError as exc: - st.error(str(exc)) - except Exception as exc: - st.error(f"RAG failed: {exc}") - -result = st.session_state.get("last_result") -if result: - st.subheader("Answer") - st.markdown(result.answer) - - st.subheader("Retrieved sources") - for doc in result.sources: - with 
st.expander(f"[{doc['id']}] {doc['title']} · {doc['category']}"): - st.write(doc["content"]) - - if show_prompt: - st.subheader("Prompt") - st.code(result.prompt) - -elif run and not query.strip(): - st.warning("Enter a question first.") +pg = st.navigation([home_page, ask_page, research_page]) +pg.run() diff --git a/app/ui_styles.py b/app/ui_styles.py new file mode 100644 index 0000000..83b79b1 --- /dev/null +++ b/app/ui_styles.py @@ -0,0 +1,68 @@ +"""Shared Streamlit UI styles for readable, non-truncated text.""" + +from __future__ import annotations + +import streamlit as st + + +def inject_global_styles() -> None: + """Apply CSS so long labels and values wrap instead of showing ellipsis.""" + st.markdown( + """ + + """, + unsafe_allow_html=True, + ) diff --git a/app/views/ask_page.py b/app/views/ask_page.py new file mode 100644 index 0000000..4a073a3 --- /dev/null +++ b/app/views/ask_page.py @@ -0,0 +1,66 @@ +"""Ask page - RAG-powered Q&A.""" + +from __future__ import annotations + +import streamlit as st + +from rag.engine import RAGEngine +from rag.ollama_client import OllamaClient, OllamaError + +st.title("StatShift") + +# Model selection +model_options = { + "Gemma 2": "gemma2:latest", + "Llama 3": "llama3:latest", +} + +model_col, _ = st.columns([1, 3]) +with model_col: + selected_model_name = st.selectbox( + "AI Model", + options=list(model_options.keys()), + index=0, + ) +selected_model = model_options[selected_model_name] + +query = st.text_input( + "Ask about players, matchups, injuries, or fantasy scoring", + placeholder="e.g. 
How many targets did Ja'Marr Chase have in 2024?", +) + +col1, col2 = st.columns([1, 4]) +with col1: + run = st.button("Ask", type="primary", use_container_width=True) +with col2: + show_prompt = st.checkbox("Show AI prompt") + +if run and query.strip(): + with st.spinner("Routing your question..."): + try: + ollama_client = OllamaClient(model=selected_model) + engine = RAGEngine(ollama=ollama_client) + result = engine.ask(query.strip()) + st.session_state["last_result"] = result + except OllamaError as exc: + st.error(str(exc)) + except Exception as exc: + st.error(f"Something went wrong: {exc}") + +result = st.session_state.get("last_result") +if result: + st.subheader("Answer") + st.markdown(result.answer) + + if result.sources: + st.subheader("Sources") + for doc in result.sources: + with st.expander(f"[{doc['id']}] {doc['title']} · {doc['category']}"): + st.write(doc["content"]) + + if show_prompt and result.prompt: + st.subheader("Prompt") + st.code(result.prompt) + +elif run and not query.strip(): + st.warning("Enter a question first.") diff --git a/app/views/home_page.py b/app/views/home_page.py new file mode 100644 index 0000000..098fd2a --- /dev/null +++ b/app/views/home_page.py @@ -0,0 +1,11 @@ +"""Home page - Mock Draft.""" + +from __future__ import annotations + +import streamlit as st + +from app.mock_draft_tab import render_mock_draft_tab + +st.title("StatShift") + +render_mock_draft_tab() diff --git a/app/views/research_page.py b/app/views/research_page.py new file mode 100644 index 0000000..8ee3821 --- /dev/null +++ b/app/views/research_page.py @@ -0,0 +1,11 @@ +"""Research page - NFL player research.""" + +from __future__ import annotations + +import streamlit as st + +from app.research_tab import render_research_tab + +st.title("StatShift") + +render_research_tab() diff --git a/docker-compose.yml b/docker-compose.yml index 7a0fe90..8739a6f 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -23,7 +23,8 @@ services: ports: - "11434:11434" 
healthcheck: - test: ["CMD", "curl", "-f", "http://127.0.0.1:11434/"] + # ollama/ollama has no curl; use the bundled CLI against the local server + test: ["CMD", "ollama", "list"] interval: 10s timeout: 5s retries: 10 @@ -37,7 +38,9 @@ services: condition: service_healthy environment: OLLAMA_HOST: http://ollama:11434 - entrypoint: ["ollama", "pull", "gemma2:2b"] + volumes: + - ./scripts:/scripts:ro + entrypoint: ["/bin/sh", "/scripts/pull-ollama-models.sh"] restart: "no" ui: @@ -46,7 +49,9 @@ services: environment: API_BASE_URL: http://api:8000 OLLAMA_BASE_URL: http://ollama:11434 - OLLAMA_MODEL: gemma2:2b + OLLAMA_MODEL: gemma2:latest + volumes: + - statshift_data:/app/data ports: - "8501:8501" depends_on: @@ -58,6 +63,15 @@ services: condition: service_completed_successfully restart: unless-stopped + # Optional: pre-warm mock-draft player cache (ffanalytics scrape; can take several minutes) + ffanalytics-sync: + profiles: [sync] + build: . + volumes: + - statshift_data:/app/data + entrypoint: ["python", "scripts/sync_ffanalytics_players.py"] + restart: "no" + volumes: statshift_data: ollama_data: diff --git a/draft/__init__.py b/draft/__init__.py new file mode 100644 index 0000000..303528a --- /dev/null +++ b/draft/__init__.py @@ -0,0 +1,4 @@ +from draft.engine import MockDraftEngine +from draft.models import LeagueSettings, Player, ScoringFormat + +__all__ = ["MockDraftEngine", "LeagueSettings", "Player", "ScoringFormat"] diff --git a/draft/engine.py b/draft/engine.py new file mode 100644 index 0000000..832719b --- /dev/null +++ b/draft/engine.py @@ -0,0 +1,346 @@ +"""Snake-draft engine for mock fantasy leagues.""" + +from __future__ import annotations + +import math +import random +import time +from dataclasses import dataclass, field + +from draft.models import ( + ROSTER_SLOTS, + TEAM_NAME_POOL, + DraftPick, + LeagueSettings, + Player, + Position, + RankingSource, + ScoringFormat, + TeamRoster, +) +from draft.player_pool import get_players, 
max_rounds_for_league + + +def _generate_team_names( + league_size: int, + user_team_name: str = "My Team", + user_slot: int = 1, + rng: random.Random | None = None, +) -> list[str]: + """Generate team names with user's name at their draft slot.""" + rng = rng or random.Random() + cpu_names = rng.sample(TEAM_NAME_POOL, min(league_size - 1, len(TEAM_NAME_POOL))) + + names: list[str] = [] + cpu_idx = 0 + for i in range(league_size): + if i == user_slot - 1: + names.append(user_team_name) + else: + if cpu_idx < len(cpu_names): + names.append(cpu_names[cpu_idx]) + cpu_idx += 1 + else: + names.append(f"Team {i + 1}") + return names + + +def _snake_team_index(round_num: int, pick_in_round: int, league_size: int) -> int: + if round_num % 2 == 1: + return pick_in_round - 1 + return league_size - pick_in_round + + +@dataclass +class MockDraftEngine: + settings: LeagueSettings + pool: list[Player] = field(default_factory=list) + picks: list[DraftPick] = field(default_factory=list) + rosters: list[TeamRoster] = field(default_factory=list) + available: list[Player] = field(default_factory=list) + requested_rounds: int = 15 + max_supported_rounds: int = 15 + cpu_randomness: float = 0.35 + cpu_top_k: int = 12 + seed: int | None = None + ranking_source: RankingSource = RankingSource.YAHOO + team_names: list[str] = field(default_factory=list) + user_team_name: str = "My Team" + pick_time_limit: int = 30 + pick_start_time: float | None = None + _rng: random.Random = field( + init=False, repr=False, compare=False, default_factory=random.Random + ) + + def __post_init__(self) -> None: + if not self.pool: + self.pool = get_players() + self.rosters = [TeamRoster(team_index=i) for i in range(self.settings.league_size)] + self.available = sorted( + self.pool, + key=lambda p: p.fantasy_points(self.settings.scoring), + reverse=True, + ) + self._rng = random.Random(self.seed) + if not self.team_names: + self.team_names = _generate_team_names( + self.settings.league_size, + 
self.user_team_name, + self.settings.draft_slot, + self._rng, + ) + + @property + def total_picks(self) -> int: + return self.settings.league_size * self.settings.rounds + + @property + def pool_exhausted(self) -> bool: + return not self.available + + @property + def is_complete(self) -> bool: + return len(self.picks) >= self.total_picks or self.pool_exhausted + + @property + def current_overall(self) -> int: + return len(self.picks) + 1 + + @property + def current_round(self) -> int: + return (len(self.picks) // self.settings.league_size) + 1 + + @property + def current_pick_in_round(self) -> int: + return (len(self.picks) % self.settings.league_size) + 1 + + @property + def current_team_index(self) -> int: + return _snake_team_index( + self.current_round, + self.current_pick_in_round, + self.settings.league_size, + ) + + @property + def is_user_turn(self) -> bool: + return ( + not self.is_complete + and self.current_team_index == self.settings.draft_slot - 1 + ) + + def start_pick_timer(self) -> None: + """Start or reset the pick timer for the current pick.""" + self.pick_start_time = time.time() + + def time_remaining(self) -> int: + """Return seconds remaining on the pick clock.""" + if self.pick_start_time is None: + return self.pick_time_limit + elapsed = time.time() - self.pick_start_time + return max(0, self.pick_time_limit - int(elapsed)) + + def is_timer_expired(self) -> bool: + """Check if the pick timer has expired.""" + return self.pick_start_time is not None and self.time_remaining() <= 0 + + def get_team_name(self, team_index: int) -> str: + """Get the display name for a team.""" + if 0 <= team_index < len(self.team_names): + return self.team_names[team_index] + return f"Team {team_index + 1}" + + def drafted_names(self) -> set[str]: + return {pick.player.name for pick in self.picks} + + def _roster_need_score(self, roster: TeamRoster, player: Player) -> float: + pos = player.position + counts = {p: roster.count(p) for p in Position} + need = 0.0 + + 
if pos == Position.QB and counts[Position.QB] < ROSTER_SLOTS["QB"]: + need += 3.0 + elif pos == Position.RB and counts[Position.RB] < ROSTER_SLOTS["RB"]: + need += 2.5 + elif pos == Position.WR and counts[Position.WR] < ROSTER_SLOTS["WR"]: + need += 2.5 + elif pos == Position.TE and counts[Position.TE] < ROSTER_SLOTS["TE"]: + need += 2.0 + elif pos in (Position.RB, Position.WR, Position.TE): + starters = ( + min(counts[Position.RB], ROSTER_SLOTS["RB"]) + + min(counts[Position.WR], ROSTER_SLOTS["WR"]) + + min(counts[Position.TE], ROSTER_SLOTS["TE"]) + ) + flex_filled = roster.flex_eligible_count() - starters + if flex_filled < ROSTER_SLOTS["FLEX"]: + need += 1.5 + elif pos == Position.K and counts[Position.K] < ROSTER_SLOTS["K"]: + need += 0.8 + elif pos == Position.DEF and counts[Position.DEF] < ROSTER_SLOTS["DEF"]: + need += 0.8 + else: + need -= 1.0 + + return need + + def _candidate_score(self, player: Player, team_index: int) -> float: + roster = self.rosters[team_index] + scoring = self.settings.scoring + value = player.fantasy_points(scoring) + need = self._roster_need_score(roster, player) + rank = player.rank_for_source(self.ranking_source) + rank_bonus = max(0.0, (80.0 - rank) / 80.0) * 0.5 + return value + need * 2.0 + rank_bonus + + def rank_available(self, team_index: int | None = None) -> list[Player]: + idx = team_index if team_index is not None else self.current_team_index + return sorted( + self.available, + key=lambda p: self._candidate_score(p, idx), + reverse=True, + ) + + def _sample_cpu_pick(self, candidates: list[Player]) -> Player: + if not candidates: + raise RuntimeError("No candidates available") + if self.cpu_randomness <= 0.0 or len(candidates) == 1: + return candidates[0] + top_k = candidates[: max(1, self.cpu_top_k)] + team_idx = self.current_team_index + scores = [self._candidate_score(p, team_idx) for p in top_k] + top = scores[0] + # Softmax with temperature: 0.5 (near-deterministic) → 5.0 (nearly uniform) + temperature = 0.5 + 
self.cpu_randomness * 4.5 + weights = [math.exp((s - top) / temperature) for s in scores] + return self._rng.choices(top_k, weights=weights, k=1)[0] + + def make_pick(self, player: Player) -> DraftPick: + if self.is_complete: + raise RuntimeError("Draft is already complete") + if player.name not in {p.name for p in self.available}: + raise ValueError(f"{player.name} is not available") + + team_index = self.current_team_index + pick = DraftPick( + round=self.current_round, + pick_in_round=self.current_pick_in_round, + overall=self.current_overall, + team_index=team_index, + player=player, + ) + self.picks.append(pick) + self.rosters[team_index].picks.append(player) + self.available = [p for p in self.available if p.name != player.name] + return pick + + def auto_pick(self) -> DraftPick: + ranked = self.rank_available() + if not ranked: + raise RuntimeError("No players available") + return self.make_pick(self._sample_cpu_pick(ranked)) + + def run_cpu_picks_until_user(self) -> list[DraftPick]: + made: list[DraftPick] = [] + while not self.is_complete and not self.is_user_turn: + if self.pool_exhausted: + break + made.append(self.auto_pick()) + return made + + def user_roster(self) -> TeamRoster: + return self.rosters[self.settings.draft_slot - 1] + + def _clone_for_simulation(self) -> MockDraftEngine: + """Lightweight copy that shares the immutable Player/pool references.""" + clone = MockDraftEngine( + settings=self.settings, + pool=self.pool, + cpu_randomness=self.cpu_randomness, + cpu_top_k=self.cpu_top_k, + ranking_source=self.ranking_source, + team_names=self.team_names, + user_team_name=self.user_team_name, + pick_time_limit=self.pick_time_limit, + ) + clone.picks = list(self.picks) + clone.rosters = [ + TeamRoster(team_index=r.team_index, picks=list(r.picks)) + for r in self.rosters + ] + clone.available = list(self.available) + clone.requested_rounds = self.requested_rounds + clone.max_supported_rounds = self.max_supported_rounds + clone._rng = random.Random() 
+ return clone + + def simulate_availability_at_next_user_pick( + self, *, num_sims: int = 40 + ) -> dict[str, float]: + """ + Monte Carlo: probability each currently-available player is still on the + board when the user is next on the clock. + + Each sim clones the engine and runs CPU picks (with cpu_randomness) until + the user's next turn. If the user is currently on the clock, the sim + treats that as an auto-pick (so the result describes the *following* pick). + """ + if self.is_complete or self.pool_exhausted: + return {p.name: 1.0 for p in self.available} + + counts: dict[str, int] = {p.name: 0 for p in self.available} + + for _ in range(num_sims): + sim = self._clone_for_simulation() + if sim.is_user_turn and not sim.pool_exhausted: + sim.auto_pick() + while not sim.is_complete and not sim.is_user_turn: + sim.auto_pick() + for player in sim.available: + if player.name in counts: + counts[player.name] += 1 + + return {name: count / num_sims for name, count in counts.items()} + + @staticmethod + def create( + scoring: ScoringFormat, + league_size: int, + draft_slot: int, + rounds: int = 15, + *, + pool: list[Player] | None = None, + cpu_randomness: float = 0.35, + cpu_top_k: int = 12, + seed: int | None = None, + ranking_source: RankingSource = RankingSource.YAHOO, + user_team_name: str = "My Team", + pick_time_limit: int = 30, + ) -> MockDraftEngine: + player_pool = pool if pool is not None else get_players() + if not player_pool: + raise ValueError( + "Player pool is empty. Refresh the player pool from ffanalytics first." 
+ ) + supported = len(player_pool) // max(league_size, 1) + effective_rounds = max(1, min(rounds, supported)) + settings = LeagueSettings( + scoring=scoring, + league_size=league_size, + draft_slot=draft_slot, + rounds=effective_rounds, + ) + engine = MockDraftEngine( + settings=settings, + pool=player_pool, + cpu_randomness=cpu_randomness, + cpu_top_k=cpu_top_k, + seed=seed, + ranking_source=ranking_source, + user_team_name=user_team_name, + pick_time_limit=pick_time_limit, + ) + engine.requested_rounds = rounds + engine.max_supported_rounds = supported + engine.start_pick_timer() + return engine diff --git a/draft/espn_loader.py b/draft/espn_loader.py new file mode 100644 index 0000000..46b94a4 --- /dev/null +++ b/draft/espn_loader.py @@ -0,0 +1,240 @@ +"""Load active NFL players from ESPN into draft Player models.""" + +from __future__ import annotations + +import json +import re +from concurrent.futures import ThreadPoolExecutor, as_completed +from datetime import datetime, timezone +from pathlib import Path +from typing import Any + +import httpx + +from config import settings +from draft.models import Player, Position +from sdks.espnSDK import ESPNEndpoint + +FANTASY_ABBREV: dict[str, Position] = { + "QB": Position.QB, + "RB": Position.RB, + "WR": Position.WR, + "TE": Position.TE, + "K": Position.K, + "PK": Position.K, + "FB": Position.RB, +} + +POSITION_BASE_FP: dict[Position, float] = { + Position.QB: 17.0, + Position.RB: 11.0, + Position.WR: 10.0, + Position.TE: 8.5, + Position.K: 7.5, +} + +CACHE_PATH = settings.project_root / "data" / "espn_active_players.json" +CACHE_MAX_AGE_HOURS = 24 +_LIST_PAGE_SIZE = 1000 +_RESOLVE_WORKERS = 32 + + +def _cache_path(path: Path | None = None) -> Path: + return path or CACHE_PATH + + +def _team_id_from_ref(ref: str) -> str | None: + match = re.search(r"/teams/(\d+)", ref) + return match.group(1) if match else None + + +def _estimate_fantasy_points(position: Position, experience_years: int) -> tuple[float, float, 
float]: + base = POSITION_BASE_FP[position] + min(experience_years, 12) * 0.35 + if position in (Position.WR, Position.RB, Position.TE): + return ( + round(base + 1.2, 1), + round(base + 0.6, 1), + round(base, 1), + ) + return (round(base, 1), round(base, 1), round(base, 1)) + + +def parse_athlete_payload( + payload: dict[str, Any], + team_abbreviations: dict[str, str], +) -> Player | None: + pos_data = payload.get("position") or {} + abbrev = (pos_data.get("abbreviation") or "").upper() + position = FANTASY_ABBREV.get(abbrev) + if position is None: + return None + + name = payload.get("displayName") or payload.get("fullName") or "" + if not name: + return None + + experience = payload.get("experience") or {} + years = int(experience.get("years") or 0) + + team = "" + team_ref = payload.get("team") or {} + if isinstance(team_ref, dict): + ref = team_ref.get("$ref", "") + team_id = team_ref.get("id") or _team_id_from_ref(ref) + if team_id: + team = team_abbreviations.get(str(team_id), "") + + fp_ppr, fp_half, fp_std = _estimate_fantasy_points(position, years) + espn_id = str(payload.get("id") or "") + + return Player( + name=name, + position=position, + fp_ppr=fp_ppr, + fp_half=fp_half, + fp_std=fp_std, + team=team, + adp=999.0, + ) + + +def _load_team_abbreviations(http: httpx.Client, timeout: float) -> dict[str, str]: + teams_api = ESPNEndpoint("teams", client=http, timeout=timeout) + mapping: dict[str, str] = {} + for item in teams_api.iter_items(limit=50): + payload = teams_api.resolve_ref(item) + team_id = str(payload.get("id", "")) + abbrev = payload.get("abbreviation") or payload.get("shortDisplayName") or "" + if team_id and abbrev: + mapping[team_id] = abbrev + return mapping + + +def _collect_athlete_refs(api: ESPNEndpoint) -> list[str]: + refs: list[str] = [] + for page in api.iter_pages(limit=_LIST_PAGE_SIZE, params={"active": "true"}): + for item in page.get("items", []): + ref = item.get("$ref") + if ref: + refs.append(ref) + return refs + + +def 
_resolve_athlete( + ref: str, + team_map: dict[str, str], + http: httpx.Client, +) -> Player | None: + response = http.get(ref) + response.raise_for_status() + return parse_athlete_payload(response.json(), team_map) + + +def fetch_active_players( + *, + timeout: float = 30.0, + max_workers: int = _RESOLVE_WORKERS, +) -> list[Player]: + """Fetch all active NFL players draftable in fantasy (QB/RB/WR/TE/K).""" + with ESPNEndpoint("athletes", timeout=timeout) as api: + team_map = _load_team_abbreviations(api.client, timeout) + refs = _collect_athlete_refs(api) + + players: list[Player] = [] + with ThreadPoolExecutor(max_workers=max_workers) as executor: + futures = { + executor.submit(_resolve_athlete, ref, team_map, api.client): ref + for ref in refs + } + for future in as_completed(futures): + try: + player = future.result() + if player is not None: + players.append(player) + except httpx.HTTPError: + continue + + players.sort(key=lambda p: p.fp_ppr, reverse=True) + for rank, player in enumerate(players, start=1): + player.adp = float(rank) + return players + + +def save_player_cache(players: list[Player], path: Path | None = None) -> Path: + target = _cache_path(path) + target.parent.mkdir(parents=True, exist_ok=True) + payload = { + "fetched_at": datetime.now(timezone.utc).isoformat(), + "count": len(players), + "players": [ + { + "name": p.name, + "position": p.position.value, + "fp_ppr": p.fp_ppr, + "fp_half": p.fp_half, + "fp_std": p.fp_std, + "team": p.team, + "adp": p.adp, + } + for p in players + ], + } + target.write_text(json.dumps(payload, indent=2), encoding="utf-8") + return target + + +def load_player_cache(path: Path | None = None) -> list[Player] | None: + target = _cache_path(path) + if not target.exists(): + return None + try: + payload = json.loads(target.read_text(encoding="utf-8")) + except (json.JSONDecodeError, OSError): + return None + + players: list[Player] = [] + for row in payload.get("players", []): + try: + players.append( + Player( + 
name=row["name"], + position=Position(row["position"]), + fp_ppr=float(row["fp_ppr"]), + fp_half=float(row["fp_half"]), + fp_std=float(row["fp_std"]), + team=row.get("team", ""), + adp=float(row.get("adp", 999)), + ) + ) + except (KeyError, ValueError): + continue + return players if players else None + + +def cache_is_fresh(path: Path | None = None, max_age_hours: int = CACHE_MAX_AGE_HOURS) -> bool: + target = _cache_path(path) + if not target.exists(): + return False + try: + payload = json.loads(target.read_text(encoding="utf-8")) + fetched_at = datetime.fromisoformat(payload["fetched_at"]) + age_hours = (datetime.now(timezone.utc) - fetched_at).total_seconds() / 3600 + return age_hours < max_age_hours + except (json.JSONDecodeError, OSError, KeyError, ValueError): + return False + + +def load_active_players( + *, + force_refresh: bool = False, + cache_path: Path | None = None, +) -> list[Player]: + path = _cache_path(cache_path) + if not force_refresh and cache_is_fresh(path): + cached = load_player_cache(path) + if cached: + return cached + + players = fetch_active_players() + save_player_cache(players, path) + return players diff --git a/draft/ffanalytics_loader.py b/draft/ffanalytics_loader.py new file mode 100644 index 0000000..6e008ad --- /dev/null +++ b/draft/ffanalytics_loader.py @@ -0,0 +1,500 @@ +"""Load fantasy player projections and ADP via R ffanalytics (rpy2).""" + +from __future__ import annotations + +import json +from datetime import datetime, timezone +from typing import Any + +from config import settings +from draft.models import Player, Position +from draft.sleeper_loader import get_sleeper_rankings, lookup_sleeper_rank + +CACHE_PATH = settings.project_root / "data" / "ffanalytics_players.json" +CACHE_MAX_AGE_HOURS = 24 + +DRAFT_POSITIONS = ("QB", "RB", "WR", "TE", "K") +DEFAULT_SCRAPE_SOURCES = ("FantasyPros", "ESPN", "CBS", "Yahoo") +DEFAULT_ADP_SOURCES = ("ESPN", "Yahoo", "CBS", "NFL") + +POSITION_MAP: dict[str, Position] = { + "QB": 
Position.QB, + "RB": Position.RB, + "WR": Position.WR, + "TE": Position.TE, + "K": Position.K, + "PK": Position.K, +} + + +def _cache_path(path: Any = None) -> Any: + return path or CACHE_PATH + + +def current_nfl_season() -> int: + """Season year used by ffanalytics scrape_data (Mar+ = current calendar year).""" + now = datetime.now() + return now.year if now.month >= 3 else now.year - 1 + + +def _require_rpy2(): + try: + import rpy2.robjects as ro # noqa: F401 + from rpy2.robjects.packages import importr # noqa: F401 + except ImportError as exc: + raise RuntimeError( + "rpy2 is not installed. Run: pip install rpy2" + ) from exc + return ro, importr + + +def _to_python_scalar(value: Any) -> Any: + """Coerce pandas/numpy scalars to plain Python types for merge logic.""" + if value is None: + return None + type_name = type(value).__module__ + if type_name == "numpy": + return value.item() if hasattr(value, "item") else value + try: + if value != value: # NaN + return None + except (TypeError, ValueError): + pass + return value + + +def _r_dataframe_to_records(r_df: Any) -> list[dict[str, Any]]: + import rpy2.robjects as ro + + pdf = ro.conversion.rpy2py(r_df) + if pdf is None or len(pdf) == 0: + return [] + records = pdf.to_dict(orient="records") + return [ + {key: _to_python_scalar(val) for key, val in row.items()} + for row in records + ] + + +def _pick_column(row: dict[str, Any], *candidates: str) -> Any: + for key in candidates: + if key in row and row[key] is not None: + val = row[key] + try: + if val != val: # NaN + continue + except TypeError: + pass + return val + return None + + +def _float_or(value: Any, default: float = 0.0) -> float: + if value is None: + return default + try: + return float(value) + except (TypeError, ValueError): + return default + + +def _player_name_from_row(row: dict[str, Any]) -> str | None: + name = _pick_column(row, "player_name", "player", "name", "Player") + if name: + return str(name).strip() + first = _pick_column(row, 
"first_name") + last = _pick_column(row, "last_name") + if first or last: + return f"{first or ''} {last or ''}".strip() + return None + + +def _adp_from_row(row: dict[str, Any]) -> float | None: + """ADP from ffanalytics get_adp (avg column or adp_ columns).""" + avg = _pick_column(row, "avg", "average", "mean") + if avg is not None: + val = _float_or(avg, 0.0) + return val if val > 0 else None + values: list[float] = [] + for key, raw in row.items(): + if not str(key).startswith("adp_"): + continue + val = _float_or(raw, 0.0) + if val > 0: + values.append(val) + if not values: + return None + return sum(values) / len(values) + + +def _per_source_adp_from_row(row: dict[str, Any]) -> dict[str, float]: + """Extract per-source ADP values from ffanalytics get_adp row.""" + result: dict[str, float] = {} + + for key, raw in row.items(): + key_lower = str(key).lower() + if key_lower.startswith("adp_"): + source = key_lower[4:] + val = _float_or(raw, 0.0) + if val > 0: + result[source] = val + + return result + + +def parse_projection_row( + row: dict[str, Any], + *, + points: float, + adp_by_id: dict[str, float], + adp_by_name: dict[str, float], + espn_adp_by_id: dict[str, float] | None = None, + espn_adp_by_name: dict[str, float] | None = None, + yahoo_adp_by_id: dict[str, float] | None = None, + yahoo_adp_by_name: dict[str, float] | None = None, +) -> Player | None: + pos_raw = _pick_column(row, "pos", "position") + if pos_raw is None: + return None + position = POSITION_MAP.get(str(pos_raw).upper()) + if position is None: + return None + + name = _player_name_from_row(row) + if not name: + return None + + team = str(_pick_column(row, "team", "nfl_team") or "").strip() + player_id = _pick_column(row, "id", "player_id") + pid_str = str(player_id) if player_id is not None else None + name_lower = name.lower() + + adp = 999.0 + if pid_str is not None: + adp = adp_by_id.get(pid_str, adp) + if adp >= 999.0: + adp = adp_by_name.get(name_lower, adp) + + rank_espn = 999.0 + if 
def merge_projection_tables(
    std_rows: list[dict[str, Any]],
    half_rows: list[dict[str, Any]],
    ppr_rows: list[dict[str, Any]],
    adp_rows: list[dict[str, Any]],
) -> list[Player]:
    """Combine standard, half-PPR, and PPR projection rows into Player models.

    Args:
        std_rows: Projection records scored with standard rules.
        half_rows: Projection records scored with half-PPR rules.
        ppr_rows: Projection records scored with full-PPR rules.
        adp_rows: ADP records; each may carry an overall ADP and per-source
            ``adp_<source>`` values.

    Returns:
        One ``Player`` per id that has a positive projection in at least one
        scoring format, ordered by ADP. Players without a real ADP are placed
        after every player that has one and receive their 1-based position in
        that ordering as a synthetic ADP.
    """
    adp_by_id: dict[str, float] = {}
    adp_by_name: dict[str, float] = {}
    espn_adp_by_id: dict[str, float] = {}
    espn_adp_by_name: dict[str, float] = {}
    yahoo_adp_by_id: dict[str, float] = {}
    yahoo_adp_by_name: dict[str, float] = {}

    def record(
        by_id: dict[str, float],
        by_name: dict[str, float],
        pid_key: str | None,
        name_key: str | None,
        value: float | None,
    ) -> None:
        """Store *value* under whichever of the id / lowercased-name keys exist."""
        if value is None:
            return
        if pid_key:
            by_id[pid_key] = value
        if name_key:
            by_name[name_key] = value

    for row in adp_rows:
        pid = _pick_column(row, "id", "player_id")
        adp_val = _adp_from_row(row)
        per_source = _per_source_adp_from_row(row)

        pid_str = str(pid) if pid is not None else None
        name = _player_name_from_row(row)
        name_lower = name.lower() if name else None

        record(adp_by_id, adp_by_name, pid_str, name_lower, adp_val)
        record(espn_adp_by_id, espn_adp_by_name, pid_str, name_lower, per_source.get("espn"))
        record(yahoo_adp_by_id, yahoo_adp_by_name, pid_str, name_lower, per_source.get("yahoo"))

    def index_by_id(rows: list[dict[str, Any]]) -> dict[str, tuple[dict[str, Any], float]]:
        """Map player id -> (row, points); rows without an id or with points <= 0 are dropped."""
        out: dict[str, tuple[dict[str, Any], float]] = {}
        for row in rows:
            pid = _pick_column(row, "id", "player_id")
            if pid is None:
                continue
            pts = _float_or(_pick_column(row, "points", "avg_points", "fpts"))
            if pts <= 0:
                continue
            out[str(pid)] = (row, pts)
        return out

    std_idx = index_by_id(std_rows)
    half_idx = index_by_id(half_rows)
    ppr_idx = index_by_id(ppr_rows)

    all_ids = set(std_idx) | set(half_idx) | set(ppr_idx)
    players: list[Player] = []

    # Iterate in sorted order: a set of str ids has no stable iteration order
    # across interpreter runs, and the result must not depend on it.
    for pid in sorted(all_ids):
        base_row, _ = std_idx.get(pid) or half_idx.get(pid) or ppr_idx.get(pid) or ({}, 0.0)
        if not base_row:
            continue
        # A format missing for this id falls back to the standard value
        # (0.0 when standard is missing too).
        fp_std = std_idx.get(pid, (base_row, 0.0))[1]
        fp_half = half_idx.get(pid, (base_row, fp_std))[1]
        fp_ppr = ppr_idx.get(pid, (base_row, fp_std))[1]
        if fp_std <= 0 and fp_half <= 0 and fp_ppr <= 0:
            continue

        player = parse_projection_row(
            base_row,
            points=fp_ppr or fp_half or fp_std,
            adp_by_id=adp_by_id,
            adp_by_name=adp_by_name,
            espn_adp_by_id=espn_adp_by_id,
            espn_adp_by_name=espn_adp_by_name,
            yahoo_adp_by_id=yahoo_adp_by_id,
            yahoo_adp_by_name=yahoo_adp_by_name,
        )
        if player is None:
            continue
        # Fill each format from the nearest available one so no field is left at 0.
        player.fp_std = round(fp_std or fp_half or fp_ppr, 1)
        player.fp_half = round(fp_half or fp_std or fp_ppr, 1)
        player.fp_ppr = round(fp_ppr or fp_half or fp_std, 1)
        players.append(player)

    def adp_sort_key(p: Player) -> tuple[float, float, str]:
        """Real ADP first; ties and ADP-less players ordered by PPR points, then name."""
        primary = p.adp if p.adp < 999 else 9999
        return (primary, -p.fp_ppr, p.name)

    players.sort(key=adp_sort_key)
    for rank, player in enumerate(players, start=1):
        # 999 is the "no ADP" sentinel; replace it with the player's position
        # in the ordering so downstream consumers always have a usable number.
        if player.adp >= 999:
            player.adp = float(rank)
    return players
= DEFAULT_ADP_SOURCES, +) -> list[Player]: + """Scrape projections (3 scoring formats) and ADP from ffanalytics.""" + from draft.rpy2_setup import rpy2_session + + with rpy2_session(): + return _fetch_ffanalytics_players_impl( + season=season, + scrape_sources=scrape_sources, + adp_sources=adp_sources, + ) + + +def _run_ffanalytics_pipeline_r( + ro: Any, + *, + season: int, + scrape_sources: tuple[str, ...], + adp_sources: tuple[str, ...], +) -> tuple[Any, Any, Any, Any]: + """ + Run scrape → projections (3 formats) → ADP entirely in R. + + Intermediate objects must not round-trip through rpy2; that drops attributes + (e.g. season/year) and breaks projections_table(). + """ + from rpy2.robjects.vectors import IntVector, StrVector + + ro.globalenv[".statshift_src"] = StrVector(list(scrape_sources)) + ro.globalenv[".statshift_adp_src"] = StrVector(list(adp_sources)) + ro.globalenv[".statshift_season"] = IntVector([season]) + ro.globalenv[".statshift_pos"] = StrVector(list(DRAFT_POSITIONS)) + + ro.r( + """ + { + library(ffanalytics) + build_scoring <- function(rec_pts) { + s <- unserialize(serialize(ffanalytics::scoring, NULL)) + s$rec$rec <- rec_pts + s + } + scraped <- scrape_data( + src = .statshift_src, + pos = .statshift_pos, + season = .statshift_season, + week = 0L + ) + .statshift_proj_std <<- add_player_info( + projections_table(scraped, scoring_rules = build_scoring(0)) + ) + .statshift_proj_half <<- add_player_info( + projections_table(scraped, scoring_rules = build_scoring(0.5)) + ) + .statshift_proj_ppr <<- add_player_info( + projections_table(scraped, scoring_rules = build_scoring(1)) + ) + .statshift_adp <<- get_adp(sources = .statshift_adp_src) + invisible(NULL) + } + """ + ) + return ( + ro.globalenv[".statshift_proj_std"], + ro.globalenv[".statshift_proj_half"], + ro.globalenv[".statshift_proj_ppr"], + ro.globalenv[".statshift_adp"], + ) + + +def _fetch_ffanalytics_players_impl( + *, + season: int | None, + scrape_sources: tuple[str, ...], + 
def _populate_sleeper_rankings(players: list[Player]) -> None:
    """Fetch Sleeper rankings and populate rank_sleeper for each player.

    This is best-effort: if the rankings cannot be obtained, the players are
    left untouched and the failure is logged instead of being raised, so a
    Sleeper outage does not break the projection load.

    Args:
        players: Players to update in place.
    """
    import logging

    try:
        sleeper_rankings = get_sleeper_rankings()
    except Exception:
        # Deliberately broad: any failure here must not abort the caller.
        # Log it so a silent loss of Sleeper ranks can be diagnosed.
        logging.getLogger(__name__).warning(
            "Could not load Sleeper rankings; leaving rank_sleeper unchanged",
            exc_info=True,
        )
        return

    for player in players:
        player.rank_sleeper = lookup_sleeper_rank(player.name, sleeper_rankings)
def cache_is_fresh(path: Any = None, max_age_hours: int = CACHE_MAX_AGE_HOURS) -> bool:
    """Return True when the player cache exists and is younger than *max_age_hours*.

    Args:
        path: Optional cache location, resolved through ``_cache_path``.
        max_age_hours: Maximum accepted age of the cache, in hours.

    Returns:
        False when the file is missing, unreadable, not valid JSON, lacks a
        usable ``fetched_at`` timestamp, or is too old; True otherwise.
    """
    target = _cache_path(path)
    if not target.exists():
        return False
    try:
        payload = json.loads(target.read_text(encoding="utf-8"))
        fetched_at = datetime.fromisoformat(payload["fetched_at"])
    except (json.JSONDecodeError, OSError, KeyError, TypeError, ValueError):
        # TypeError covers a payload that is not a JSON object and a
        # ``fetched_at`` value that is not a string.
        return False

    # Subtracting a naive datetime from an aware one raises TypeError;
    # interpret a stamp without an offset as UTC, which is what the writer uses.
    if fetched_at.tzinfo is None:
        fetched_at = fetched_at.replace(tzinfo=timezone.utc)

    age_hours = (datetime.now(timezone.utc) - fetched_at).total_seconds() / 3600
    return age_hours < max_age_hours
@dataclass(frozen=True)
class LeagueSettings:
    """Immutable configuration for a single mock draft.

    Construction fails with ``ValueError`` when the league size is outside
    4..16, the draft slot is outside 1..league_size, or there is less than
    one round.
    """

    scoring: ScoringFormat
    league_size: int
    draft_slot: int
    rounds: int = 15

    def __post_init__(self) -> None:
        """Validate the numeric fields right after the dataclass sets them."""
        size = self.league_size
        slot = self.draft_slot

        if size > 16 or size < 4:
            raise ValueError("league_size must be between 4 and 16")

        # The slot is checked against the (already validated) league size.
        if slot > size or slot < 1:
            raise ValueError("draft_slot must be between 1 and league_size")

        if 1 > self.rounds:
            raise ValueError("rounds must be at least 1")
@dataclass
class TeamRoster:
    """Players drafted so far by one team, identified by its index."""

    team_index: int
    # Each roster gets its own list; picks are kept in the order they are added.
    picks: list[Player] = field(default_factory=list)

    def count(self, position: Position) -> int:
        """Return how many drafted players have exactly this position."""
        matching = [pick for pick in self.picks if pick.position == position]
        return len(matching)

    def flex_eligible_count(self) -> int:
        """Return how many drafted players are RB, WR, or TE."""
        total = 0
        for pick in self.picks:
            if pick.position in (Position.RB, Position.WR, Position.TE):
                total += 1
        return total
f"ffanalytics cache ({count} players) at {CACHE_PATH.name}" + if CACHE_PATH.exists(): + return "ffanalytics cache stale — will refresh on next load" + return "No ffanalytics cache — will scrape on first Mock Draft load" diff --git a/draft/rpy2_setup.py b/draft/rpy2_setup.py new file mode 100644 index 0000000..94f1c14 --- /dev/null +++ b/draft/rpy2_setup.py @@ -0,0 +1,51 @@ +"""Initialize rpy2 conversion rules for Streamlit (main thread + context manager).""" + +from __future__ import annotations + +from contextlib import contextmanager +from typing import Any, Iterator + +_converter: Any = None + + +def get_rpy2_converter() -> Any: + """Build (once) the default + pandas conversion bundle.""" + global _converter + if _converter is not None: + return _converter + + from rpy2.robjects import default_converter, numpy2ri, pandas2ri + + _converter = default_converter + numpy2ri.converter + pandas2ri.converter + return _converter + + +def init_rpy2_on_main_thread() -> Any: + """ + Call once from Streamlit's main script thread before any R work. + + Sets global conversion rules and marks this thread as R's init thread. + """ + import rpy2.rinterface_lib.embedded as embedded + from rpy2.robjects import conversion + + if not embedded.isinitialized(): + embedded.set_init_thread() + + converter = get_rpy2_converter() + conversion.set_conversion(converter) + return converter + + +@contextmanager +def rpy2_session() -> Iterator[Any]: + """ + Run rpy2/R code with conversion rules applied. + + Use around every block that calls into R from Streamlit reruns or worker threads. 
+ """ + from rpy2.robjects.conversion import localconverter + + converter = get_rpy2_converter() + with localconverter(converter): + yield converter diff --git a/draft/sleeper_loader.py b/draft/sleeper_loader.py new file mode 100644 index 0000000..960044a --- /dev/null +++ b/draft/sleeper_loader.py @@ -0,0 +1,126 @@ +"""Load player rankings from the Sleeper API.""" + +from __future__ import annotations + +import json +from datetime import datetime, timezone +from typing import Any + +import requests + +from config import settings + +SLEEPER_PLAYERS_URL = "https://api.sleeper.app/v1/players/nfl" +CACHE_PATH = settings.project_root / "data" / "sleeper_players.json" +CACHE_MAX_AGE_HOURS = 24 + + +def _normalize_name(name: str) -> str: + """Normalize player name for matching: lowercase, strip suffixes.""" + name = name.lower().strip() + for suffix in (" jr", " sr", " ii", " iii", " iv", " v"): + if name.endswith(suffix): + name = name[: -len(suffix)].strip() + return name + + +def fetch_sleeper_players() -> dict[str, Any]: + """Fetch all NFL players from Sleeper API.""" + resp = requests.get(SLEEPER_PLAYERS_URL, timeout=30) + resp.raise_for_status() + return resp.json() + + +def parse_sleeper_rankings(raw: dict[str, Any]) -> dict[str, int]: + """ + Extract player name -> search_rank mapping from Sleeper data. + + Returns a dict mapping normalized player names to their search_rank. + Lower search_rank = higher ranked player. 
def load_sleeper_cache(path: Any = None) -> dict[str, int] | None:
    """Load cached Sleeper rankings if available and fresh.

    Args:
        path: Optional cache file; ``CACHE_PATH`` is used when falsy.

    Returns:
        The cached name -> rank mapping, or None when the file is missing,
        unreadable, malformed, older than ``CACHE_MAX_AGE_HOURS``, or does not
        contain a ``rankings`` object. None tells the caller to refetch.
    """
    target = path or CACHE_PATH
    if not target.exists():
        return None

    try:
        payload = json.loads(target.read_text(encoding="utf-8"))
        fetched_at = datetime.fromisoformat(payload["fetched_at"])
    except (json.JSONDecodeError, OSError, KeyError, TypeError, ValueError):
        # TypeError covers a payload that is not a JSON object and a
        # ``fetched_at`` value that is not a string.
        return None

    # Subtracting a naive datetime from an aware one raises TypeError;
    # interpret a stamp without an offset as UTC, which is what the writer uses.
    if fetched_at.tzinfo is None:
        fetched_at = fetched_at.replace(tzinfo=timezone.utc)

    age_hours = (datetime.now(timezone.utc) - fetched_at).total_seconds() / 3600
    if age_hours >= CACHE_MAX_AGE_HOURS:
        return None

    # A payload without a rankings mapping is treated as a miss rather than
    # as an empty (but valid) result, so the caller fetches fresh data.
    rankings = payload.get("rankings")
    return rankings if isinstance(rankings, dict) else None
def lookup_sleeper_rank(name: str, rankings: dict[str, int]) -> float:
    """Return the Sleeper rank stored for *name*, or 999.0 when there is none.

    The name is normalized with ``_normalize_name`` before the lookup, so it
    is matched against the keys of *rankings* in normalized form.
    """
    found = rankings.get(_normalize_name(name))
    if found is None:
        return 999.0
    return float(found)
local basketball analytics assistant. -Answer using ONLY the context below. If the context is insufficient, say what is missing. -Keep answers concise and cite source numbers like [1] when you use a fact. - -Context: -{context} - -Question: {query} + def _answer_from_sources(sources: list[dict[str, Any]]) -> str: + if not sources: + return ( + "No matching records were found in the StatShift database for that question. " + "Try rephrasing with a player, team, week, or season." + ) + lines = ["Based on stored records:\n"] + for idx, doc in enumerate(sources, start=1): + lines.append( + f"[{idx}] **{doc['title']}** ({doc['category']})\n{doc['content']}" + ) + return "\n\n".join(lines) -Answer:""" + @staticmethod + def _build_conversational_prompt(query: str) -> str: + return f"""You are StatShift, a friendly fantasy football analytics assistant. +Answer the user's message helpfully. You may use general NFL and fantasy football knowledge. +If they ask for specific stats from the StatShift database, suggest they ask a direct factual question. + +User: {query} + +Assistant:""" + + def ask(self, query: str, *, intent: IntentResult | None = None) -> RAGResult: + resolved = intent or detect_prompt_intent(query) + + if resolved.intent == PromptIntent.DEFINITIVE: + sources = self._search(query) + return RAGResult( + answer=self._answer_from_sources(sources), + sources=sources, + prompt="", + intent=resolved.intent, + route="api", + intent_reason=resolved.reason, + ) - def ask(self, query: str) -> RAGResult: - sources = self._search(query) - context = self._format_context(sources) - prompt = self._build_prompt(query, context) + prompt = self._build_conversational_prompt(query) try: answer = self.ollama.generate(prompt) + except OllamaError: + # Re-raise OllamaError as-is (already has specific message) + raise except httpx.HTTPError as exc: raise OllamaError( - "Could not reach Ollama. Start it locally and ensure Gemma is pulled." 
+ f"Network error: {type(exc).__name__} - {str(exc)}" ) from exc - return RAGResult(answer=answer, sources=sources, prompt=prompt) + return RAGResult( + answer=answer, + sources=[], + prompt=prompt, + intent=resolved.intent, + route="gemma", + intent_reason=resolved.reason, + ) def health(self) -> dict[str, Any]: api_ok = False diff --git a/rag/intent.py b/rag/intent.py new file mode 100644 index 0000000..97287b8 --- /dev/null +++ b/rag/intent.py @@ -0,0 +1,121 @@ +"""Heuristic intent detection for user prompts.""" + +from __future__ import annotations + +import re +from dataclasses import dataclass +from enum import Enum + + +class PromptIntent(str, Enum): + """Whether a prompt should be answered from the API or via Gemma.""" + + DEFINITIVE = "definitive" + CONVERSATIONAL = "conversational" + + +@dataclass(frozen=True) +class IntentResult: + intent: PromptIntent + reason: str + + +_QUESTION_STARTERS = re.compile( + r"^\s*(" + r"what|who|whom|which|when|where|why|how|" + r"is|are|was|were|do|does|did|can|could|will|would|should|has|have|had" + r")\b", + re.IGNORECASE, +) + +_DEFINITIVE_PATTERNS: tuple[tuple[re.Pattern[str], str], ...] 
= ( + (re.compile(r"\b(how many|how much|how often)\b", re.I), "quantity question"), + (re.compile(r"\b(what was|what is|what were|what are)\b", re.I), "fact lookup"), + (re.compile(r"\b(who (led|won|scored|averaged|had|is|was|threw))\b", re.I), "who question"), + (re.compile(r"\b(how did|how does|how do)\b", re.I), "performance question"), + (re.compile(r"\b(average[ds]?|stats?|statline|fantasy points?|fppg|ppg)\b", re.I), "stats terms"), + (re.compile(r"\b(yards?|touchdowns?|tds?|receptions?|targets?|carries|attempts?)\b", re.I), "football counting stats"), + (re.compile(r"\b(ppr|half[- ]?ppr|standard scoring|snap share|target share|red[- ]?zone)\b", re.I), "fantasy terms"), + (re.compile(r"\b(week \d{1,2}|matchup|projection|start[- ]?sit)\b", re.I), "weekly fantasy context"), + (re.compile(r"\b(injury|injured|questionable|doubtful|out|ir|trade[d]?|waiver)\b", re.I), "injury/roster terms"), + (re.compile(r"\b\d{4}\s*(season|stats?)?\b"), "season reference"), + (re.compile(r"\b\d+(\.\d+)?\s*%"), "percentage"), + (re.compile(r"\b\d+(\.\d+)?\s*(yds|yards|tds|rec|tgt|car)\b", re.I), "rate stat"), +) + +_CONVERSATIONAL_PATTERNS: tuple[tuple[re.Pattern[str], str], ...] = ( + (re.compile(r"^\s*(hi|hello|hey|thanks|thank you)\b", re.I), "greeting"), + (re.compile(r"\b(what do you think|in your opinion|do you believe)\b", re.I), "opinion"), + (re.compile(r"\b(should i draft|would you draft|dynasty value)\b", re.I), "draft advice"), + (re.compile(r"\b(explain like|eli5|tell me a story|write a|poem|joke)\b", re.I), "creative"), + (re.compile(r"\b(brainstorm|ideas? 
for|help me (think|decide))\b", re.I), "open-ended help"), + (re.compile(r"\b(compare .+ (vs\.?|versus|or) .+ (opinion|better overall))\b", re.I), "subjective compare"), + (re.compile(r"\b(best (ever|of all time)|goat|all[- ]time greatest)\b", re.I), "subjective ranking"), + (re.compile(r"\b(why do you think|why is fantasy football)\b", re.I), "general discussion"), +) + + +def detect_prompt_intent(prompt: str) -> IntentResult: + """Classify whether a prompt needs a definitive DB-backed answer or open chat. + + Definitive prompts are routed to the read-only API (search/retrieval). + Conversational prompts are sent directly to Gemma without API retrieval. + """ + text = prompt.strip() + if not text: + return IntentResult( + PromptIntent.CONVERSATIONAL, + reason="empty prompt", + ) + + definitive_score = 0 + conversational_score = 0 + definitive_reasons: list[str] = [] + conversational_reasons: list[str] = [] + + for pattern, label in _CONVERSATIONAL_PATTERNS: + if pattern.search(text): + conversational_score += 4 + conversational_reasons.append(label) + + for pattern, label in _DEFINITIVE_PATTERNS: + if pattern.search(text): + definitive_score += 2 + definitive_reasons.append(label) + + if conversational_score == 0: + if text.endswith("?"): + definitive_score += 2 + definitive_reasons.append("question mark") + + if _QUESTION_STARTERS.search(text): + definitive_score += 2 + definitive_reasons.append("question starter") + + word_count = len(text.split()) + if word_count <= 12 and re.search( + r"\b(stats?|statline|overview|splits|profile|report|summary|projection)\b", + text, + re.I, + ): + definitive_score += 2 + definitive_reasons.append("stats lookup phrase") + + if definitive_score > conversational_score: + reason = ", ".join(definitive_reasons) or "factual question signals" + return IntentResult(PromptIntent.DEFINITIVE, reason=reason) + + if conversational_score > 0: + reason = ", ".join(conversational_reasons) + return IntentResult(PromptIntent.CONVERSATIONAL, 
reason=reason) + + if text.endswith("?") or _QUESTION_STARTERS.search(text): + return IntentResult( + PromptIntent.DEFINITIVE, + reason="unclassified question defaulting to API", + ) + + return IntentResult( + PromptIntent.CONVERSATIONAL, + reason="no definitive signals; defaulting to Gemma", + ) diff --git a/rag/ollama_client.py b/rag/ollama_client.py index 594894b..a139ac4 100644 --- a/rag/ollama_client.py +++ b/rag/ollama_client.py @@ -1,4 +1,4 @@ -"""Thin client for local Ollama (Gemma) inference.""" +"""Thin client for local Ollama inference (Gemma by default).""" from __future__ import annotations @@ -16,7 +16,7 @@ def __init__( self, base_url: str | None = None, model: str | None = None, - timeout: float = 120.0, + timeout: float = 300.0, # Increased to 5 minutes for initial model load ) -> None: self.base_url = (base_url or settings.ollama_base_url).rstrip("/") self.model = model or settings.ollama_model @@ -37,20 +37,37 @@ def list_models(self) -> list[str]: return [m["name"] for m in payload.get("models", [])] def generate(self, prompt: str, *, temperature: float = 0.2) -> str: - response = httpx.post( - f"{self.base_url}/api/generate", - json={ - "model": self.model, - "prompt": prompt, - "stream": False, - "options": {"temperature": temperature}, - }, - timeout=self.timeout, - ) - if response.status_code == 404: - raise OllamaError( - f"Model '{self.model}' not found. Pull it with: ollama pull {self.model}" + try: + response = httpx.post( + f"{self.base_url}/api/generate", + json={ + "model": self.model, + "prompt": prompt, + "stream": False, + "options": {"temperature": temperature}, + }, + timeout=self.timeout, ) - response.raise_for_status() - data = response.json() - return (data.get("response") or "").strip() + if response.status_code == 404: + raise OllamaError( + f"Model '{self.model}' not found. 
Pull it with: ollama pull {self.model}" + ) + response.raise_for_status() + data = response.json() + return (data.get("response") or "").strip() + except httpx.ConnectError as e: + raise OllamaError( + f"Cannot connect to Ollama at {self.base_url}. Error: {str(e)}" + ) from e + except httpx.TimeoutException as e: + raise OllamaError( + f"Timeout connecting to Ollama at {self.base_url}. Error: {str(e)}" + ) from e + except httpx.HTTPStatusError as e: + raise OllamaError( + f"HTTP error from Ollama (status {e.response.status_code}): {str(e)}" + ) from e + except httpx.HTTPError as e: + raise OllamaError( + f"Network error connecting to Ollama at {self.base_url}: {type(e).__name__} - {str(e)}" + ) from e diff --git a/requirements.txt b/requirements.txt index 92b5e20..44d7d33 100644 --- a/requirements.txt +++ b/requirements.txt @@ -2,5 +2,10 @@ fastapi>=0.115.0 uvicorn[standard]>=0.32.0 httpx>=0.27.0 streamlit>=1.40.0 +streamlit-autorefresh>=1.0.0 pydantic>=2.9.0 pydantic-settings>=2.6.0 +rpy2>=3.5.0 +scipy>=1.11.0 +matplotlib>=3.8.0 +numpy>=1.26.0 diff --git a/scripts/init_db.py b/scripts/init_db.py index 317e847..a92ff1f 100644 --- a/scripts/init_db.py +++ b/scripts/init_db.py @@ -14,75 +14,76 @@ DOCUMENTS = [ { - "title": "2024-25 Lakers team overview", - "category": "team", + "title": "Christian McCaffrey 2024 rushing workload", + "category": "player", "content": ( - "The Los Angeles Lakers finished the 2024-25 regular season 50-32. " - "They ranked 8th in offensive rating (115.4) and 12th in defensive rating (112.1). " - "Anthony Davis averaged 24.7 PPG, 12.6 RPG, and 2.4 BPG. " - "LeBron James averaged 25.3 PPG, 7.4 RPG, and 8.1 APG at age 40." + "Christian McCaffrey played 16 games in 2024 with 272 carries for 1,459 rushing yards " + "and 14 rushing TDs. He added 67 receptions for 564 yards and 7 receiving TDs in PPR leagues. " + "His 21.8 fantasy PPG ranked RB1 in season-long formats." 
), }, { - "title": "Stephen Curry shooting splits", + "title": "Ja'Marr Chase 2024 target share", "category": "player", "content": ( - "Stephen Curry shot 45.2% from the field, 40.8% from three, and 92.1% from the line " - "in 2024-25. He attempted 11.8 threes per game and led the Warriors with 26.4 PPG. " - "His true shooting percentage was 62.8%." + "Ja'Marr Chase led the NFL with 175 targets in 2024, a 28.4% team target share. " + "He posted 127 receptions, 1,708 yards, and 12 TDs. " + "In PPR scoring he averaged 22.1 fantasy points per game." ), }, { - "title": "Celtics defensive profile", - "category": "team", + "title": "Week 12 RB matchup notes", + "category": "matchup", "content": ( - "Boston held opponents to 108.9 points per 100 possessions, 3rd in the league. " - "Jrue Holiday and Derrick White anchored perimeter defense. " - "Boston forced the 5th-most turnovers per game (15.2)." + "Week 12 featured several run-heavy game scripts. " + "Baltimore allowed 5.1 yards per carry to RBs (4th-most). " + "Derrick Henry projects as an RB1 with 18-22 expected touches against that front. " + "Green Bay limited RB receiving work to 3.8 targets per game over the prior four weeks." ), }, { - "title": "Victor Wembanyama sophomore season", - "category": "player", + "title": "2024 NFL passing yards leaders", + "category": "league", "content": ( - "Victor Wembanyama averaged 22.1 PPG, 10.8 RPG, 3.6 BPG, and 1.2 SPG in 2024-25. " - "San Antonio improved defensively when he was on the floor (+8.2 net rating swing). " - "He shot 48.9% inside the arc and 34.2% from three on low volume." + "Joe Burrow led the NFL with 4,918 passing yards in 2024. " + "Jared Goff (4,629) and Baker Mayfield (4,500) rounded out the top three. " + "League average team passing yards per game was 224.6." 
), }, { - "title": "League pace and efficiency leaders", - "category": "league", + "title": "Travis Kelce red-zone usage 2024", + "category": "player", "content": ( - "League average pace was 99.2 possessions per 48 minutes in 2024-25. " - "Indiana played the fastest pace (102.4); New York played the slowest (96.1). " - "Denver led offensive rating (119.8); Charlotte ranked last defensively (118.4 DRtg)." + "Travis Kelce drew 22 red-zone targets in 2024, 8th among tight ends. " + "He scored 5 red-zone TDs on a 22.7% red-zone target share inside the 20. " + "His 12.4 fantasy PPG in half-PPR was TE3 in season-long ranks." ), }, { - "title": "Nikola Jokic playmaking", - "category": "player", + "title": "Buffalo Bills offensive pace 2024", + "category": "team", "content": ( - "Nikola Jokic recorded 9.2 APG with a 42.1% assist rate, highest among centers. " - "He posted a 31.8 PER and 14.2 win shares. Denver's offense was +12.4 per 100 with him on court." + "Buffalo averaged 64.2 plays per game (7th) and 3.04 drives leading to scores per game. " + "Josh Allen accounted for 42 total TDs (32 pass, 10 rush). " + "Bills pass rate over expected was +4.2% in neutral script situations." ), }, { - "title": "Injury report: March 2025", + "title": "Injury report: Week 15 2024", "category": "injury", "content": ( - "Kawhi Leonard missed 18 games with knee management. " - "Joel Embiid was limited to 35 games due to knee soreness. " - "Ja Morant returned in late February after a shoulder strain." + "Tyreek Hill was listed questionable with a wrist sprain but practiced fully Friday. " + "Joe Mixon missed practice Wednesday with an ankle issue, then returned limited Thursday. " + "Mark Andrews was ruled out with an ankle injury before Sunday kickoff." 
# Ensure a single Ollama model is present locally.
#   $1 - model tag, compared against the first column of `ollama list`.
ensure_model() {
  model_name="$1"

  # grep flags: -F literal string, -x whole-line match, -q no output.
  if ollama list | awk '{print $1}' | grep -Fxq "$model_name"; then
    echo "✓ $model_name is already cached, skipping download."
    return 0
  fi

  echo "↓ Pulling $model_name..."
  ollama pull "$model_name"
  echo "✓ $model_name pulled successfully."
}
def main() -> None:
    """Fetch the active player list from ESPN and write it to the cache file."""
    print("Fetching active NFL players from ESPN (this may take 1–2 minutes)...")
    roster = fetch_active_players()
    cache_file = save_player_cache(roster)
    print(f"Saved {len(roster)} draftable players to {cache_file}")
def _normal_from_moments(
    arr: np.ndarray,
) -> tuple[str, dict[str, float], Callable[[np.ndarray], np.ndarray]]:
    """Describe *arr* as a Normal built from its sample mean and (ddof=0) std."""
    mean = float(np.mean(arr)) if len(arr) > 0 else 0.0
    std = float(np.std(arr)) if len(arr) > 1 else 1.0
    return (
        "Normal",
        {"mean": mean, "std": std},
        # Floor the scale so a zero-variance sample still yields a usable PDF.
        lambda x, m=mean, s=max(std, 0.1): stats.norm.pdf(x, loc=m, scale=s),
    )


def fit_best_distribution(
    data: list[float],
) -> tuple[str, dict[str, float], Callable[[np.ndarray], np.ndarray]]:
    """Fit candidate distributions and return best fit by AIC.

    Only strictly positive, finite values are used. With fewer than five such
    values, or when no candidate can be fitted, a Normal built from the sample
    mean and standard deviation is returned instead.

    Args:
        data: List of fantasy point values (non-zero values recommended)

    Returns:
        Tuple of (distribution_name, parameters_dict, pdf_function)
    """
    arr = np.array([x for x in data if x > 0 and np.isfinite(x)], dtype=float)

    if len(arr) < 5:
        return _normal_from_moments(arr)

    # Names reported for each value returned by ``dist.fit``; ``None`` marks
    # the location parameter that is pinned to 0 and therefore not reported.
    param_labels = {
        "Normal": ("mean", "std"),
        "Gamma": ("shape", None, "scale"),
        "Log-Normal": ("s", None, "scale"),
        "Weibull": ("c", None, "scale"),
    }

    best_fit = None
    best_aic = float("inf")

    for name, dist in CANDIDATE_DISTRIBUTIONS:
        labels = param_labels.get(name)
        if labels is None:
            continue
        loc_is_pinned = None in labels

        try:
            params = dist.fit(arr, floc=0) if loc_is_pinned else dist.fit(arr)
            log_likelihood = float(np.sum(dist.logpdf(arr, *params)))
        except (ValueError, RuntimeError, ArithmeticError):
            # A candidate that cannot be fitted simply drops out of the race.
            continue

        if not np.isfinite(log_likelihood):
            continue

        # AIC penalises *free* parameters only; a pinned loc was not estimated.
        k = len(params) - (1 if loc_is_pinned else 0)
        aic = 2 * k - 2 * log_likelihood

        if aic < best_aic:
            best_aic = aic
            best_fit = (
                name,
                {
                    label: float(value)
                    for label, value in zip(labels, params)
                    if label is not None
                },
                # Bind dist/params as defaults to avoid late-binding closures.
                lambda x, d=dist, p=params: d.pdf(x, *p),
            )

    if best_fit is None:
        return _normal_from_moments(arr)
    return best_fit
def iter_pages(
    self,
    *,
    limit: int = 100,
    params: dict[str, Any] | None = None,
    max_pages: int | None = None,
) -> Iterator[dict[str, Any]]:
    """Yield each paginated JSON response until pageCount is exhausted.

    Args:
        limit: Page size forwarded to ``get_page``.
        params: Extra query parameters forwarded to ``get_page``.
        max_pages: Upper bound on pages fetched; ``None`` means no bound and
            a value of 0 or less fetches nothing.

    Yields:
        The decoded payload of every fetched page, in page order.
    """
    page = 1

    while max_pages is None or page <= max_pages:
        payload = self.get_page(page=page, limit=limit, params=params)
        yield payload

        # A missing, null or non-numeric pageCount is treated as "one page"
        # rather than raising on the comparison below.
        try:
            page_count = int(payload.get("pageCount", 1))
        except (TypeError, ValueError):
            page_count = 1

        if page >= page_count:
            break
        page += 1
@dataclass
class PlayerProfile:
    """Display-ready snapshot of one athlete, as assembled by get_player_profile().

    Most fields are pre-formatted strings; "—" marks a value that was absent
    from the ESPN payload.
    """

    id: str  # athlete id, stringified
    name: str  # displayName; "Unknown" when absent
    position: str  # position displayName; "—" when absent
    team: str  # team displayName; "Free Agent" when absent
    jersey: str  # jersey value, stringified; "—" when absent
    height: str  # output of _format_height (feet'inches") or "—"
    weight: str  # output of _format_weight ("<n> lbs") or "—"
    age: int | None  # derived from dateOfBirth; None when missing or unparseable
    birth_date: str  # first 10 characters of dateOfBirth, or "—"
    college: str  # name resolved through _resolve_ref_field
    draft_info: str  # "<year> Round <round>, Pick <pick>", or "—"
    draft_year: int | None  # draft year as int; None when not drafted/unknown
    experience: int  # experience.years; 0 when absent
    headshot_url: str  # headshot href; "" when absent
    status: str  # status name; "Active" when absent
three_cone: float | None + shuttle: float | None + + +@dataclass +class InjuryInfo: + status: str + injury_type: str + details: str + date: str + + +@dataclass +class NewsArticle: + headline: str + description: str + published: str + link: str + image_url: str | None + + +@dataclass +class GameLogEntry: + """Single game stats for fantasy point calculation.""" + week: int + opponent: str + result: str + passing_yards: int + passing_tds: int + interceptions: int + rushing_yards: int + rushing_tds: int + receptions: int + receiving_yards: int + receiving_tds: int + fumbles_lost: int + + +def _get_json(url: str, params: dict[str, Any] | None = None) -> dict[str, Any]: + """Make a GET request and return JSON response.""" + with httpx.Client(timeout=TIMEOUT, follow_redirects=True) as client: + response = client.get(url, params=params) + response.raise_for_status() + return response.json() + + +def _safe_get(data: dict, *keys: str, default: Any = None) -> Any: + """Safely navigate nested dict keys.""" + for key in keys: + if not isinstance(data, dict): + return default + data = data.get(key, default) + if data is None: + return default + return data + + +def _format_height(inches: int | None) -> str: + """Convert height in inches to feet'inches format.""" + if not inches: + return "—" + feet = inches // 12 + remaining = inches % 12 + return f"{feet}'{remaining}\"" + + +def _format_weight(pounds: int | None) -> str: + """Format weight with lbs suffix.""" + if not pounds: + return "—" + return f"{pounds} lbs" + + +def nfl_season_has_started(season_year: int, *, now: datetime | None = None) -> bool: + """Return True if the NFL regular season for season_year has begun.""" + now = now or datetime.now() + if now.year > season_year: + return True + if now.year < season_year: + return False + return now.month >= 9 + + +def is_upcoming_rookie(draft_year: int | None, *, now: datetime | None = None) -> bool: + """True when drafted this calendar year and that NFL season has not 
started.""" + now = now or datetime.now() + if draft_year is None: + return False + return draft_year == now.year and not nfl_season_has_started(draft_year, now=now) + + +def _resolve_ref_field(ref_obj: Any, *, name_keys: tuple[str, ...] = ("name", "displayName", "abbreviation")) -> str: + """Resolve a display name from an embedded object or ESPN $ref.""" + if not ref_obj: + return "—" + if isinstance(ref_obj, dict): + for key in name_keys: + value = ref_obj.get(key) + if value: + return str(value) + ref_url = ref_obj.get("$ref") + if ref_url: + try: + data = _get_json(ref_url) + except httpx.HTTPError: + return "—" + for key in name_keys: + value = data.get(key) + if value: + return str(value) + return "—" + + +def _stat_display_value(categories: list[dict], stat_name: str) -> str: + for cat in categories: + for stat in cat.get("stats", []): + if stat.get("name") == stat_name: + return str(stat.get("displayValue", stat.get("value", "—"))) + return "—" + + +def _empty_college_season(season: int, team: str = "") -> CollegeSeasonStats: + return CollegeSeasonStats( + season=season, + team=team, + games_played="—", + completions="—", + passing_yards="—", + passing_tds="—", + interceptions="—", + rush_attempts="—", + rushing_yards="—", + rushing_tds="—", + receptions="—", + receiving_yards="—", + receiving_tds="—", + ) + + +def _college_season_from_categories(season: int, team: str, categories: list[dict]) -> CollegeSeasonStats: + row = _empty_college_season(season, team) + for stat_name, attr in _COLLEGE_STAT_FIELDS.items(): + value = _stat_display_value(categories, stat_name) + if value != "—": + setattr(row, attr, value) + return row + + +def _team_from_search_item(item: dict[str, Any]) -> str: + for rel in item.get("teamRelationships", []) or []: + if rel.get("type") == "team": + return ( + rel.get("displayName") + or _safe_get(rel, "core", "displayName", default="Free Agent") + ) + return "Free Agent" + + +def _headshot_from_search_item(item: dict[str, Any]) -> 
def _completed_years(born: datetime, today: datetime) -> int:
    """Return the whole calendar years elapsed from *born* to *today*."""
    before_birthday = (today.month, today.day) < (born.month, born.day)
    return today.year - born.year - int(before_birthday)


def get_player_profile(player_id: str) -> PlayerProfile | None:
    """Get detailed player profile information.

    Args:
        player_id: ESPN athlete id used to build the core-API athlete URL.

    Returns:
        A populated ``PlayerProfile``, or ``None`` when the request fails with
        an ``httpx.HTTPError``.
    """
    url = f"{CORE_API_BASE}/athletes/{player_id}"

    try:
        data = _get_json(url)
    except httpx.HTTPError:
        return None

    birth_date = data.get("dateOfBirth", "")
    age = None
    if birth_date:
        try:
            dob = datetime.fromisoformat(birth_date.replace("Z", "+00:00"))
            # Count birthdays rather than dividing days by 365, which drifts
            # by leap days and can report the next age a few days early.
            age = _completed_years(dob, datetime.now(dob.tzinfo))
        except (ValueError, TypeError):
            # Unparseable date: leave the age unknown.
            pass

    draft = data.get("draft", {})
    draft_info = "—"
    draft_year: int | None = None
    if draft:
        draft_year = draft.get("year")
        round_num = draft.get("round", "")
        pick = draft.get("selection", "")
        if draft_year and round_num and pick:
            draft_info = f"{draft_year} Round {round_num}, Pick {pick}"

    # Resolve the team the same way as the college: use an embedded name when
    # present, otherwise follow the `$ref` link.
    # NOTE(review): the core API appears to return `team` as a bare `$ref`,
    # which previously always fell through to "Free Agent" — confirm.
    team = _resolve_ref_field(
        data.get("team"), name_keys=("displayName", "name", "abbreviation")
    )
    if team == "—":
        team = "Free Agent"

    return PlayerProfile(
        id=str(data.get("id", "")),
        name=data.get("displayName", "Unknown"),
        position=_safe_get(data, "position", "displayName", default="—"),
        team=team,
        jersey=str(data.get("jersey", "—")),
        height=_format_height(data.get("height")),
        weight=_format_weight(data.get("weight")),
        age=age,
        birth_date=birth_date[:10] if birth_date else "—",
        college=_resolve_ref_field(data.get("college")),
        draft_info=draft_info,
        draft_year=int(draft_year) if draft_year else None,
        # _safe_get tolerates an explicit null `experience` object.
        experience=_safe_get(data, "experience", "years", default=0),
        headshot_url=_safe_get(data, "headshot", "href", default=""),
        status=_safe_get(data, "status", "name", default="Active"),
    )
def get_player_stats(player_id: str, season: int | None = None) -> list[SeasonStats]:
    """Fetch an athlete's statistics payload and flatten it into one row.

    Every stat is stored under the key ``"<category> - <stat>"``. An empty
    list is returned on an HTTP error or when the payload has no categories.

    NOTE(review): ``season`` only labels the returned row (defaulting to the
    current calendar year); it is not part of the request URL — confirm
    whether a season-scoped fetch was intended.
    """
    try:
        payload = _get_json(f"{CORE_API_BASE}/athletes/{player_id}/statistics")
    except httpx.HTTPError:
        return []

    groups = payload.get("splits", {}).get("categories", [])
    if not groups:
        return []

    flattened: dict[str, Any] = {}
    appearances = 0

    for group in groups:
        prefix = group.get("displayName", "")
        for entry in group.get("stats", []):
            label = entry.get("displayName", entry.get("name", ""))
            if not label:
                continue
            shown = entry.get("displayValue", entry.get("value", "—"))
            flattened[f"{prefix} - {label}"] = shown
            if label.lower() not in ("games played", "gp"):
                continue
            try:
                appearances = int(shown)
            except (ValueError, TypeError):
                # Keep the previous count when the value is not an integer.
                pass

    return [
        SeasonStats(
            season=season or datetime.now().year,
            games_played=appearances,
            stats=flattened,
        )
    ]
def get_player_injuries(player_id: str) -> list[InjuryInfo]:
    """Get injury information for a player.

    Nested fields are read through ``_safe_get`` so that explicit JSON nulls
    (``"injuries": null``, ``"type": null`` ...) degrade to placeholders
    instead of raising.

    Args:
        player_id: ESPN athlete id used to build the core-API athlete URL.

    Returns:
        One ``InjuryInfo`` per reported injury. With no injuries, a single
        placeholder row is returned: it carries the athlete's status name when
        a status object is present, otherwise a generic "Active" row. An empty
        list is returned only when the request fails.
    """
    url = f"{CORE_API_BASE}/athletes/{player_id}"

    try:
        data = _get_json(url)
    except httpx.HTTPError:
        return []

    injuries = data.get("injuries") or []
    results: list[InjuryInfo] = []

    for injury in injuries:
        reported = _safe_get(injury, "date", default="")
        results.append(InjuryInfo(
            status=_safe_get(injury, "status", "type", default="Unknown"),
            injury_type=_safe_get(injury, "type", "text", default="—"),
            details=_safe_get(injury, "details", "detail", default="—"),
            date=str(reported)[:10] if reported else "—",
        ))

    if results:
        return results

    if data.get("status"):
        return [InjuryInfo(
            status=_safe_get(data, "status", "name", default="Active"),
            injury_type="—",
            details="No current injuries reported",
            date="—",
        )]

    return [InjuryInfo(
        status="Active",
        injury_type="—",
        details="No injury information available",
        date="—",
    )]
published: + try: + dt = datetime.fromisoformat(published.replace("Z", "+00:00")) + published = dt.strftime("%b %d, %Y") + except (ValueError, TypeError): + published = published[:10] + + images = article.get("images", []) + image_url = images[0].get("url") if images else None + + links = article.get("links", {}) + web_link = links.get("web", {}).get("href", "") + + results.append(NewsArticle( + headline=article.get("headline", "No headline"), + description=article.get("description", ""), + published=published, + link=web_link, + image_url=image_url, + )) + + return results + + +def get_available_seasons(player_id: str) -> list[int]: + """Get list of seasons with available statistics for a player.""" + current_year = datetime.now().year + return list(range(current_year, current_year - 5, -1)) + + +def get_player_seasons(player_id: str) -> list[int]: + """Fetch actual seasons with game data from ESPN statisticslog.""" + url = f"{CORE_API_BASE}/athletes/{player_id}/statisticslog" + + try: + data = _get_json(url) + except httpx.HTTPError: + return [] + + seasons: list[int] = [] + entries = data.get("entries", []) + + for entry in entries: + season_ref = _safe_get(entry, "season", "$ref", default="") + if season_ref and "/seasons/" in season_ref: + try: + year_str = season_ref.split("/seasons/")[1].split("?")[0] + year = int(year_str) + if year not in seasons: + seasons.append(year) + except (IndexError, ValueError): + continue + + seasons.sort(reverse=True) + return seasons + + +def _extract_stat_value(categories: list[dict], category_name: str, stat_name: str) -> float: + """Extract a specific stat value from ESPN categories structure.""" + for cat in categories: + if cat.get("name") == category_name: + for stat in cat.get("stats", []): + if stat.get("name") == stat_name: + try: + return float(stat.get("value", 0)) + except (TypeError, ValueError): + return 0.0 + return 0.0 + + +def get_player_gamelog(player_id: str, season: int) -> list[GameLogEntry]: + """Fetch 
def calculate_fantasy_points(game: GameLogEntry, scoring: str) -> float:
    """Score a single game log under the requested fantasy format.

    Args:
        game: GameLogEntry with stats
        scoring: 'ppr', 'half_ppr', or 'standard'; any other value earns no
            reception bonus

    Returns:
        Fantasy points rounded to two decimals
    """
    per_reception = {"ppr": 1.0, "half_ppr": 0.5}.get(scoring)

    contributions = [
        # Passing: 0.04 per yard, 4 per TD, -2 per INT
        game.passing_yards * 0.04,
        game.passing_tds * 4,
        -(game.interceptions * 2),
        # Rushing: 0.1 per yard, 6 per TD
        game.rushing_yards * 0.1,
        game.rushing_tds * 6,
        # Receiving: 0.1 per yard, 6 per TD
        game.receiving_yards * 0.1,
        game.receiving_tds * 6,
    ]
    if per_reception is not None:
        contributions.append(game.receptions * per_reception)
    # Fumbles lost: -2
    contributions.append(-(game.fumbles_lost * 2))

    # Accumulate left to right so the floating-point result is reproducible.
    total = 0.0
    for delta in contributions:
        total += delta
    return round(total, 2)
class APITests(DatabaseTestCase):
    """Exercise the read-only FastAPI app against the seeded temporary database."""

    def setUp(self) -> None:
        super().setUp()
        self.settings_patcher = patch.object(main_module.settings, "db_path", self.db_path)
        # main.py imports `db` by value; patch the name bound in api.main
        self.db_patcher = patch.object(main_module, "db", self.db)
        for patcher in (self.settings_patcher, self.db_patcher):
            patcher.start()
        self.client = TestClient(app)

    def tearDown(self) -> None:
        # Undo the patches in reverse order of activation.
        for patcher in (self.db_patcher, self.settings_patcher):
            patcher.stop()
        super().tearDown()

    def test_health_ok(self) -> None:
        reply = self.client.get("/health")
        self.assertEqual(reply.status_code, 200)
        body = reply.json()
        self.assertEqual(body["status"], "ok")
        self.assertTrue(body["read_only"])
        self.assertEqual(body["db_path"], str(self.db_path))

    def test_list_documents(self) -> None:
        reply = self.client.get("/documents", params={"limit": 3})
        self.assertEqual(reply.status_code, 200)
        self.assertEqual(len(reply.json()), 3)

    def test_get_document(self) -> None:
        reply = self.client.get("/documents/1")
        self.assertEqual(reply.status_code, 200)
        self.assertEqual(reply.json()["id"], 1)

    def test_get_document_not_found(self) -> None:
        reply = self.client.get("/documents/99999")
        self.assertEqual(reply.status_code, 404)

    def test_search(self) -> None:
        reply = self.client.get("/search", params={"q": "Chase", "limit": 5})
        self.assertEqual(reply.status_code, 200)
        hits = reply.json()
        self.assertTrue(hits)
        self.assertIn("Chase", hits[0]["title"])

    def test_categories(self) -> None:
        reply = self.client.get("/categories")
        self.assertEqual(reply.status_code, 200)
        self.assertIn("matchup", reply.json()["categories"])

    def test_write_methods_blocked(self) -> None:
        for verb in ("post", "put", "patch", "delete"):
            send = getattr(self.client, verb)
            reply = send("/documents")
            self.assertEqual(reply.status_code, 405)
            self.assertIn("Write operations are disabled", reply.json()["detail"])

    def test_health_missing_database(self) -> None:
        absent = self.db_path.parent / "missing.db"
        with patch.object(main_module.settings, "db_path", absent):
            reply = self.client.get("/health")
        self.assertEqual(reply.status_code, 503)
doc["content"] for doc in results) + ) + + def test_search_team_content(self) -> None: + results = self.db.search_documents("Bills", limit=3) + self.assertLessEqual(len(results), 3) + + def test_list_categories(self) -> None: + categories = self.db.list_categories() + self.assertIn("player", categories) + self.assertIn("matchup", categories) + + +if __name__ == "__main__": + import unittest + + unittest.main() diff --git a/tests/test_draft_engine.py b/tests/test_draft_engine.py new file mode 100644 index 0000000..92f9906 --- /dev/null +++ b/tests/test_draft_engine.py @@ -0,0 +1,258 @@ +"""Unit tests for mock draft engine.""" + +from __future__ import annotations + +import unittest + +from draft.engine import MockDraftEngine, _snake_team_index +from draft.models import Player, Position, RankingSource, ScoringFormat + + +def _test_pool() -> list[Player]: + return [ + Player( + f"QB {i}", Position.QB, 20.0, 20.0, 20.0, + adp=float(i), + rank_espn=float(i), + rank_yahoo=float(5 - i) if i <= 4 else 999.0, + rank_sleeper=float(i * 2), + ) + for i in range(1, 5) + ] + [ + Player( + f"RB {i}", Position.RB, 15.0, 14.0, 13.0, + adp=float(10 + i), + rank_espn=float(10 + i), + rank_yahoo=float(20 - i), + rank_sleeper=float(5 + i), + ) + for i in range(1, 10) + ] + [ + Player( + f"WR {i}", Position.WR, 14.0, 13.0, 12.0, + adp=float(20 + i), + rank_espn=float(20 + i), + rank_yahoo=float(10 + i), + rank_sleeper=float(15 + i), + ) + for i in range(1, 10) + ] + + +class SnakeDraftTests(unittest.TestCase): + def test_snake_order_round_two_reverses(self) -> None: + self.assertEqual(_snake_team_index(1, 1, 10), 0) + self.assertEqual(_snake_team_index(1, 10, 10), 9) + self.assertEqual(_snake_team_index(2, 1, 10), 9) + self.assertEqual(_snake_team_index(2, 10, 10), 0) + + def test_ppr_ranks_pass_catchers_higher(self) -> None: + pool = _test_pool() + draft = MockDraftEngine.create( + scoring=ScoringFormat.PPR, + league_size=4, + draft_slot=1, + rounds=1, + pool=pool, + ) + wr = next(p 
for p in draft.rank_available() if p.position == Position.WR) + draft.make_pick(wr) + self.assertEqual(draft.picks[0].player.position, Position.WR) + + def test_auto_pick_reduces_available_pool(self) -> None: + draft = MockDraftEngine.create( + scoring=ScoringFormat.STANDARD, + league_size=4, + draft_slot=4, + rounds=2, + pool=_test_pool(), + ) + draft.run_cpu_picks_until_user() + self.assertEqual(len(draft.picks), 3) + self.assertTrue(draft.is_user_turn) + + def test_create_rejects_empty_pool(self) -> None: + with self.assertRaises(ValueError) as ctx: + MockDraftEngine.create( + scoring=ScoringFormat.PPR, + league_size=12, + draft_slot=1, + rounds=15, + pool=[], + ) + self.assertIn("empty", str(ctx.exception).lower()) + + def test_draft_completes_after_all_rounds(self) -> None: + pool = _test_pool() + draft = MockDraftEngine.create( + scoring=ScoringFormat.HALF_PPR, + league_size=4, + draft_slot=1, + rounds=2, + pool=pool, + ) + while not draft.is_complete: + draft.auto_pick() + self.assertEqual(len(draft.picks), 8) + self.assertEqual(len(draft.available), len(pool) - 8) + + +class MonteCarloDraftTests(unittest.TestCase): + def _run_full(self, *, cpu_randomness: float, seed: int | None) -> list[str]: + draft = MockDraftEngine.create( + scoring=ScoringFormat.PPR, + league_size=4, + draft_slot=2, + rounds=3, + pool=_test_pool(), + cpu_randomness=cpu_randomness, + seed=seed, + ) + while not draft.is_complete: + draft.auto_pick() + return [p.player.name for p in draft.picks] + + def test_randomness_zero_is_deterministic(self) -> None: + # seed should be irrelevant when randomness is off + a = self._run_full(cpu_randomness=0.0, seed=None) + b = self._run_full(cpu_randomness=0.0, seed=999) + self.assertEqual(a, b) + + def test_seed_reproduces_random_draft(self) -> None: + a = self._run_full(cpu_randomness=0.6, seed=42) + b = self._run_full(cpu_randomness=0.6, seed=42) + self.assertEqual(a, b) + + def test_different_seeds_diverge(self) -> None: + a = 
self._run_full(cpu_randomness=0.8, seed=1) + b = self._run_full(cpu_randomness=0.8, seed=2) + self.assertNotEqual(a, b) + + def test_simulate_availability_returns_probabilities(self) -> None: + draft = MockDraftEngine.create( + scoring=ScoringFormat.PPR, + league_size=4, + draft_slot=2, + rounds=2, + pool=_test_pool(), + cpu_randomness=0.5, + seed=7, + ) + draft.run_cpu_picks_until_user() + self.assertTrue(draft.is_user_turn) + + probs = draft.simulate_availability_at_next_user_pick(num_sims=10) + + self.assertEqual(set(probs.keys()), {p.name for p in draft.available}) + self.assertTrue(all(0.0 <= v <= 1.0 for v in probs.values())) + + def test_simulate_does_not_mutate_live_draft(self) -> None: + draft = MockDraftEngine.create( + scoring=ScoringFormat.PPR, + league_size=4, + draft_slot=2, + rounds=2, + pool=_test_pool(), + cpu_randomness=0.4, + seed=11, + ) + draft.run_cpu_picks_until_user() + picks_before = len(draft.picks) + available_before = [p.name for p in draft.available] + + draft.simulate_availability_at_next_user_pick(num_sims=5) + + self.assertEqual(len(draft.picks), picks_before) + self.assertEqual([p.name for p in draft.available], available_before) + + +class RankingSourceTests(unittest.TestCase): + def test_player_rank_for_source_espn(self) -> None: + player = Player( + "Test Player", Position.QB, 20.0, 20.0, 20.0, + rank_espn=5.0, rank_yahoo=10.0, rank_sleeper=15.0, + ) + self.assertEqual(player.rank_for_source(RankingSource.ESPN), 5.0) + + def test_player_rank_for_source_yahoo(self) -> None: + player = Player( + "Test Player", Position.QB, 20.0, 20.0, 20.0, + rank_espn=5.0, rank_yahoo=10.0, rank_sleeper=15.0, + ) + self.assertEqual(player.rank_for_source(RankingSource.YAHOO), 10.0) + + def test_player_rank_for_source_sleeper(self) -> None: + player = Player( + "Test Player", Position.QB, 20.0, 20.0, 20.0, + rank_espn=5.0, rank_yahoo=10.0, rank_sleeper=15.0, + ) + self.assertEqual(player.rank_for_source(RankingSource.SLEEPER), 15.0) + + def 
test_different_ranking_sources_produce_different_orders(self) -> None: + pool = _test_pool() + + draft_yahoo = MockDraftEngine.create( + scoring=ScoringFormat.PPR, + league_size=4, + draft_slot=2, + rounds=2, + pool=pool, + cpu_randomness=0.0, + ranking_source=RankingSource.YAHOO, + ) + draft_espn = MockDraftEngine.create( + scoring=ScoringFormat.PPR, + league_size=4, + draft_slot=2, + rounds=2, + pool=pool, + cpu_randomness=0.0, + ranking_source=RankingSource.ESPN, + ) + + while not draft_yahoo.is_complete: + draft_yahoo.auto_pick() + while not draft_espn.is_complete: + draft_espn.auto_pick() + + yahoo_picks = [p.player.name for p in draft_yahoo.picks] + espn_picks = [p.player.name for p in draft_espn.picks] + + self.assertNotEqual(yahoo_picks, espn_picks) + + def test_default_ranking_source_is_yahoo(self) -> None: + draft = MockDraftEngine.create( + scoring=ScoringFormat.PPR, + league_size=4, + draft_slot=1, + rounds=1, + pool=_test_pool(), + ) + self.assertEqual(draft.ranking_source, RankingSource.YAHOO) + + def test_ranking_source_passed_to_create(self) -> None: + draft = MockDraftEngine.create( + scoring=ScoringFormat.PPR, + league_size=4, + draft_slot=1, + rounds=1, + pool=_test_pool(), + ranking_source=RankingSource.SLEEPER, + ) + self.assertEqual(draft.ranking_source, RankingSource.SLEEPER) + + def test_clone_preserves_ranking_source(self) -> None: + draft = MockDraftEngine.create( + scoring=ScoringFormat.PPR, + league_size=4, + draft_slot=2, + rounds=2, + pool=_test_pool(), + ranking_source=RankingSource.ESPN, + ) + clone = draft._clone_for_simulation() + self.assertEqual(clone.ranking_source, RankingSource.ESPN) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_espn_loader.py b/tests/test_espn_loader.py new file mode 100644 index 0000000..c127b0e --- /dev/null +++ b/tests/test_espn_loader.py @@ -0,0 +1,46 @@ +"""Unit tests for ESPN player parsing.""" + +from __future__ import annotations + +import unittest + +from draft.espn_loader 
import parse_athlete_payload +from draft.models import Position + + +class ParseAthletePayloadTests(unittest.TestCase): + def test_parses_skill_position(self) -> None: + payload = { + "id": "1", + "displayName": "Ja'Marr Chase", + "position": {"abbreviation": "WR"}, + "experience": {"years": 4}, + "team": {"id": "4", "$ref": "http://example/teams/4"}, + } + player = parse_athlete_payload(payload, {"4": "CIN"}) + assert player is not None + self.assertEqual(player.name, "Ja'Marr Chase") + self.assertEqual(player.position, Position.WR) + self.assertEqual(player.team, "CIN") + + def test_skips_offensive_line(self) -> None: + payload = { + "displayName": "Some Lineman", + "position": {"abbreviation": "OT"}, + "experience": {"years": 3}, + } + self.assertIsNone(parse_athlete_payload(payload, {})) + + def test_maps_pk_to_kicker(self) -> None: + payload = { + "displayName": "Justin Tucker", + "position": {"abbreviation": "PK"}, + "experience": {"years": 10}, + } + player = parse_athlete_payload(payload, {}) + assert player is not None + self.assertEqual(player.position, Position.K) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_espn_player_research.py b/tests/test_espn_player_research.py new file mode 100644 index 0000000..a9d99da --- /dev/null +++ b/tests/test_espn_player_research.py @@ -0,0 +1,62 @@ +"""Tests for ESPN player research helpers.""" + +from __future__ import annotations + +import unittest +from datetime import datetime + +from sdks.espn_player_loader import ( + _college_season_from_categories, + is_upcoming_rookie, + nfl_season_has_started, +) + + +class NflSeasonTests(unittest.TestCase): + def test_season_not_started_before_september(self) -> None: + now = datetime(2026, 5, 20) + self.assertFalse(nfl_season_has_started(2026, now=now)) + + def test_season_started_in_september(self) -> None: + now = datetime(2026, 9, 10) + self.assertTrue(nfl_season_has_started(2026, now=now)) + + def 
test_rookie_when_drafted_this_year_pre_season(self) -> None: + now = datetime(2026, 5, 20) + self.assertTrue(is_upcoming_rookie(2026, now=now)) + + def test_not_rookie_when_drafted_prior_year(self) -> None: + now = datetime(2026, 5, 20) + self.assertFalse(is_upcoming_rookie(2025, now=now)) + + def test_not_rookie_after_season_starts(self) -> None: + now = datetime(2026, 9, 10) + self.assertFalse(is_upcoming_rookie(2026, now=now)) + + +class CollegeStatsParsingTests(unittest.TestCase): + def test_parses_categories_into_season_row(self) -> None: + categories = [ + { + "name": "general", + "stats": [{"name": "gamesPlayed", "displayValue": "12"}], + }, + { + "name": "passing", + "stats": [ + {"name": "completions", "displayValue": "250"}, + {"name": "netPassingYards", "displayValue": "3,000"}, + {"name": "passingTouchdowns", "displayValue": "25"}, + {"name": "interceptions", "displayValue": "5"}, + ], + }, + ] + row = _college_season_from_categories(2024, "MIA", categories) + self.assertEqual(row.season, 2024) + self.assertEqual(row.team, "MIA") + self.assertEqual(row.games_played, "12") + self.assertEqual(row.passing_yards, "3,000") + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_ffanalytics_loader.py b/tests/test_ffanalytics_loader.py new file mode 100644 index 0000000..64a7a37 --- /dev/null +++ b/tests/test_ffanalytics_loader.py @@ -0,0 +1,157 @@ +"""Unit tests for ffanalytics player parsing (no R required).""" + +from __future__ import annotations + +import unittest + +from draft.ffanalytics_loader import ( + merge_projection_tables, + parse_projection_row, + _player_name_from_row, + _adp_from_row, + _per_source_adp_from_row, +) +from draft.models import Position + + +class ParseHelpersTests(unittest.TestCase): + def test_player_name_from_first_last(self) -> None: + row = {"first_name": "Josh", "last_name": "Allen", "pos": "QB"} + self.assertEqual(_player_name_from_row(row), "Josh Allen") + + def test_adp_from_source_columns(self) -> None: 
+ row = {"id": "1", "adp_espn": 2.2, "adp_yahoo": 3.0} + self.assertAlmostEqual(_adp_from_row(row), 2.6) + + def test_per_source_adp_from_row(self) -> None: + row = {"id": "1", "adp_espn": 5.0, "adp_yahoo": 8.0, "adp_cbs": 6.0} + result = _per_source_adp_from_row(row) + self.assertEqual(result["espn"], 5.0) + self.assertEqual(result["yahoo"], 8.0) + self.assertEqual(result["cbs"], 6.0) + + def test_per_source_adp_ignores_zero_values(self) -> None: + row = {"id": "1", "adp_espn": 5.0, "adp_yahoo": 0.0} + result = _per_source_adp_from_row(row) + self.assertEqual(result["espn"], 5.0) + self.assertNotIn("yahoo", result) + + +class MergeProjectionTablesTests(unittest.TestCase): + def test_merges_scoring_formats_and_adp(self) -> None: + std_rows = [ + { + "id": "1", + "first_name": "Ja'Marr", + "last_name": "Chase", + "pos": "WR", + "team": "CIN", + "points": 18.0, + }, + { + "id": "2", + "first_name": "Christian", + "last_name": "McCaffrey", + "pos": "RB", + "team": "SF", + "points": 20.0, + }, + ] + half_rows = [ + { + "id": "1", + "first_name": "Ja'Marr", + "last_name": "Chase", + "pos": "WR", + "team": "CIN", + "points": 19.0, + }, + { + "id": "2", + "first_name": "Christian", + "last_name": "McCaffrey", + "pos": "RB", + "team": "SF", + "points": 20.5, + }, + ] + ppr_rows = [ + { + "id": "1", + "first_name": "Ja'Marr", + "last_name": "Chase", + "pos": "WR", + "team": "CIN", + "points": 20.0, + }, + { + "id": "2", + "first_name": "Christian", + "last_name": "McCaffrey", + "pos": "RB", + "team": "SF", + "points": 21.0, + }, + ] + adp_rows = [ + {"id": "1", "adp_espn": 1.5}, + {"id": "2", "adp_espn": 2.0}, + ] + players = merge_projection_tables(std_rows, half_rows, ppr_rows, adp_rows) + self.assertEqual(len(players), 2) + chase = next(p for p in players if p.name == "Ja'Marr Chase") + self.assertEqual(chase.position, Position.WR) + self.assertEqual(chase.fp_std, 18.0) + self.assertEqual(chase.fp_half, 19.0) + self.assertEqual(chase.fp_ppr, 20.0) + 
self.assertEqual(chase.adp, 1.5) + + def test_skips_unknown_positions(self) -> None: + row = {"id": "9", "player_name": "Some Lineman", "pos": "OT", "points": 5.0} + player = parse_projection_row(row, points=5.0, adp_by_id={}, adp_by_name={}) + self.assertIsNone(player) + + def test_populates_per_source_rankings(self) -> None: + std_rows = [ + { + "id": "1", + "first_name": "Patrick", + "last_name": "Mahomes", + "pos": "QB", + "team": "KC", + "points": 25.0, + }, + ] + half_rows = [ + { + "id": "1", + "first_name": "Patrick", + "last_name": "Mahomes", + "pos": "QB", + "team": "KC", + "points": 25.0, + }, + ] + ppr_rows = [ + { + "id": "1", + "first_name": "Patrick", + "last_name": "Mahomes", + "pos": "QB", + "team": "KC", + "points": 25.0, + }, + ] + adp_rows = [ + {"id": "1", "adp_espn": 10.0, "adp_yahoo": 12.0}, + ] + players = merge_projection_tables(std_rows, half_rows, ppr_rows, adp_rows) + self.assertEqual(len(players), 1) + mahomes = players[0] + self.assertEqual(mahomes.rank_espn, 10.0) + self.assertEqual(mahomes.rank_yahoo, 12.0) + self.assertEqual(mahomes.rank_sleeper, 999.0) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_intent.py b/tests/test_intent.py new file mode 100644 index 0000000..a7e2e3b --- /dev/null +++ b/tests/test_intent.py @@ -0,0 +1,45 @@ +"""Unit tests for prompt intent detection.""" + +from __future__ import annotations + +import unittest + +from rag.intent import PromptIntent, detect_prompt_intent + + +class DetectPromptIntentTests(unittest.TestCase): + def test_factual_player_question_is_definitive(self) -> None: + result = detect_prompt_intent( + "How many targets did Ja'Marr Chase have in 2024?" 
+ ) + self.assertEqual(result.intent, PromptIntent.DEFINITIVE) + + def test_stats_lookup_is_definitive(self) -> None: + result = detect_prompt_intent("Christian McCaffrey 2024 stats") + self.assertEqual(result.intent, PromptIntent.DEFINITIVE) + + def test_passing_leader_question_is_definitive(self) -> None: + result = detect_prompt_intent("Who led the NFL in passing yards in 2024?") + self.assertEqual(result.intent, PromptIntent.DEFINITIVE) + + def test_opinion_question_is_conversational(self) -> None: + result = detect_prompt_intent( + "What do you think about the Bills playoff chances?" + ) + self.assertEqual(result.intent, PromptIntent.CONVERSATIONAL) + + def test_greeting_is_conversational(self) -> None: + result = detect_prompt_intent("Hello, can you help me?") + self.assertEqual(result.intent, PromptIntent.CONVERSATIONAL) + + def test_creative_request_is_conversational(self) -> None: + result = detect_prompt_intent("Write a poem about fantasy football") + self.assertEqual(result.intent, PromptIntent.CONVERSATIONAL) + + def test_empty_prompt_is_conversational(self) -> None: + result = detect_prompt_intent(" ") + self.assertEqual(result.intent, PromptIntent.CONVERSATIONAL) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_ollama_client.py b/tests/test_ollama_client.py new file mode 100644 index 0000000..b9ea2f6 --- /dev/null +++ b/tests/test_ollama_client.py @@ -0,0 +1,60 @@ +"""Unit tests for the Ollama HTTP client.""" + +from __future__ import annotations + +import unittest +from unittest.mock import MagicMock, patch + +import httpx + +from rag.ollama_client import OllamaClient, OllamaError + + +class OllamaClientTests(unittest.TestCase): + def setUp(self) -> None: + self.client = OllamaClient(base_url="http://ollama.test", model="gemma2:2b") + + @patch("rag.ollama_client.httpx.get") + def test_is_available_true(self, mock_get: MagicMock) -> None: + mock_get.return_value = MagicMock(status_code=200, raise_for_status=MagicMock()) + 
self.assertTrue(self.client.is_available()) + mock_get.assert_called_once_with("http://ollama.test/api/tags", timeout=5.0) + + @patch("rag.ollama_client.httpx.get") + def test_is_available_false_on_error(self, mock_get: MagicMock) -> None: + mock_get.side_effect = httpx.ConnectError("down") + self.assertFalse(self.client.is_available()) + + @patch("rag.ollama_client.httpx.get") + def test_list_models(self, mock_get: MagicMock) -> None: + mock_get.return_value = MagicMock( + raise_for_status=MagicMock(), + json=MagicMock(return_value={"models": [{"name": "gemma2:2b"}, {"name": "llama3"}]}), + ) + self.assertEqual(self.client.list_models(), ["gemma2:2b", "llama3"]) + + @patch("rag.ollama_client.httpx.post") + def test_generate_returns_response_text(self, mock_post: MagicMock) -> None: + mock_post.return_value = MagicMock( + status_code=200, + raise_for_status=MagicMock(), + json=MagicMock(return_value={"response": " Answer text "}), + ) + result = self.client.generate("prompt") + self.assertEqual(result, "Answer text") + mock_post.assert_called_once() + payload = mock_post.call_args.kwargs["json"] + self.assertEqual(payload["model"], "gemma2:2b") + self.assertEqual(payload["prompt"], "prompt") + self.assertFalse(payload["stream"]) + + @patch("rag.ollama_client.httpx.post") + def test_generate_missing_model_raises(self, mock_post: MagicMock) -> None: + mock_post.return_value = MagicMock(status_code=404) + with self.assertRaises(OllamaError) as ctx: + self.client.generate("prompt") + self.assertIn("gemma2:2b", str(ctx.exception)) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_rag_engine.py b/tests/test_rag_engine.py new file mode 100644 index 0000000..9f4a6ef --- /dev/null +++ b/tests/test_rag_engine.py @@ -0,0 +1,136 @@ +"""Unit tests for RAG orchestration.""" + +from __future__ import annotations + +import unittest +from unittest.mock import MagicMock, patch + +import httpx + +from rag.engine import APIError, RAGEngine +from rag.intent 
import IntentResult, PromptIntent +from rag.ollama_client import OllamaClient, OllamaError + + +class RAGEngineTests(unittest.TestCase): + def setUp(self) -> None: + self.mock_ollama = MagicMock(spec=OllamaClient) + self.mock_ollama.model = "gemma2:2b" + self.engine = RAGEngine( + api_base_url="http://api.test", + ollama=self.mock_ollama, + top_k=2, + ) + + def test_format_context_empty(self) -> None: + self.assertEqual( + RAGEngine._format_context([]), + "No relevant documents were retrieved.", + ) + + def test_format_context_numbered_blocks(self) -> None: + sources = [ + {"title": "Doc A", "category": "team", "content": "Line one."}, + {"title": "Doc B", "category": "player", "content": "Line two."}, + ] + context = RAGEngine._format_context(sources) + self.assertIn("[1] Doc A (team)", context) + self.assertIn("[2] Doc B (player)", context) + self.assertIn("Line one.", context) + + def test_answer_from_sources_empty(self) -> None: + answer = RAGEngine._answer_from_sources([]) + self.assertIn("No matching records", answer) + + def test_answer_from_sources_includes_content(self) -> None: + sources = [ + { + "title": "Ja'Marr Chase 2024 target share", + "category": "player", + "content": "175 targets", + }, + ] + answer = RAGEngine._answer_from_sources(sources) + self.assertIn("175 targets", answer) + self.assertIn("Chase", answer) + + def test_build_conversational_prompt_includes_query(self) -> None: + prompt = RAGEngine._build_conversational_prompt("Tell me a fun fact") + self.assertIn("Tell me a fun fact", prompt) + self.assertIn("StatShift", prompt) + + @patch("rag.engine.httpx.get") + def test_ask_definitive_uses_api_not_gemma(self, mock_get: MagicMock) -> None: + mock_get.return_value = MagicMock( + raise_for_status=MagicMock(), + json=MagicMock( + return_value=[ + { + "id": 1, + "title": "Ja'Marr Chase 2024 target share", + "category": "player", + "content": "175 targets", + } + ] + ), + ) + intent = IntentResult(PromptIntent.DEFINITIVE, reason="test") + + 
result = self.engine.ask("How many targets did Chase have?", intent=intent) + + self.assertEqual(result.route, "api") + self.assertIn("175 targets", result.answer) + self.assertEqual(len(result.sources), 1) + self.assertEqual(result.prompt, "") + mock_get.assert_called_once() + self.mock_ollama.generate.assert_not_called() + + @patch("rag.engine.httpx.get") + def test_ask_conversational_uses_gemma_not_api(self, mock_get: MagicMock) -> None: + self.mock_ollama.generate.return_value = "Fantasy football is a weekly game." + intent = IntentResult(PromptIntent.CONVERSATIONAL, reason="test") + + result = self.engine.ask("Hello!", intent=intent) + + self.assertEqual(result.route, "gemma") + self.assertEqual(result.answer, "Fantasy football is a weekly game.") + self.assertEqual(result.sources, []) + self.assertIn("Hello!", result.prompt) + mock_get.assert_not_called() + self.mock_ollama.generate.assert_called_once() + + @patch("rag.engine.httpx.get") + def test_search_api_error_raises(self, mock_get: MagicMock) -> None: + mock_get.side_effect = httpx.ConnectError("refused") + intent = IntentResult(PromptIntent.DEFINITIVE, reason="test") + with self.assertRaises(APIError): + self.engine.ask("How many yards did Burrow throw for?", intent=intent) + + @patch("rag.engine.httpx.get") + def test_ask_ollama_error_wrapped(self, mock_get: MagicMock) -> None: + self.mock_ollama.generate.side_effect = httpx.ConnectError("down") + intent = IntentResult(PromptIntent.CONVERSATIONAL, reason="test") + with self.assertRaises(OllamaError): + self.engine.ask("Hi there", intent=intent) + mock_get.assert_not_called() + + @patch("rag.engine.httpx.get") + def test_health_reports_status(self, mock_get: MagicMock) -> None: + mock_get.return_value = MagicMock( + status_code=200, + json=MagicMock(return_value={"status": "ok"}), + ) + self.mock_ollama.is_available.return_value = True + self.mock_ollama.list_models.return_value = ["gemma2:2b"] + + health = self.engine.health() + + 
self.assertTrue(health["api_ok"]) + self.assertEqual(health["api_detail"], "ok") + self.assertTrue(health["ollama_ok"]) + self.assertEqual(health["ollama_models"], ["gemma2:2b"]) + self.assertEqual(health["configured_model"], "gemma2:2b") + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_sleeper_loader.py b/tests/test_sleeper_loader.py new file mode 100644 index 0000000..a896d2d --- /dev/null +++ b/tests/test_sleeper_loader.py @@ -0,0 +1,85 @@ +"""Unit tests for Sleeper loader.""" + +from __future__ import annotations + +import unittest + +from draft.sleeper_loader import _normalize_name, lookup_sleeper_rank, parse_sleeper_rankings + + +class SleeperLoaderTests(unittest.TestCase): + def test_normalize_name_lowercase(self) -> None: + self.assertEqual(_normalize_name("Patrick Mahomes"), "patrick mahomes") + + def test_normalize_name_strips_jr(self) -> None: + self.assertEqual(_normalize_name("Marvin Harrison Jr"), "marvin harrison") + + def test_normalize_name_strips_sr(self) -> None: + self.assertEqual(_normalize_name("John Smith Sr"), "john smith") + + def test_normalize_name_strips_ii(self) -> None: + self.assertEqual(_normalize_name("Robert Griffin II"), "robert griffin") + + def test_normalize_name_strips_iii(self) -> None: + self.assertEqual(_normalize_name("Robert Griffin III"), "robert griffin") + + def test_parse_sleeper_rankings_extracts_search_rank(self) -> None: + raw = { + "1234": { + "first_name": "Patrick", + "last_name": "Mahomes", + "position": "QB", + "search_rank": 5, + }, + "5678": { + "first_name": "Travis", + "last_name": "Kelce", + "position": "TE", + "search_rank": 12, + }, + } + rankings = parse_sleeper_rankings(raw) + self.assertEqual(rankings["patrick mahomes"], 5) + self.assertEqual(rankings["travis kelce"], 12) + + def test_parse_sleeper_rankings_ignores_non_skill_positions(self) -> None: + raw = { + "1234": { + "first_name": "Andy", + "last_name": "Reid", + "position": "HC", + "search_rank": 100, + }, + } + 
rankings = parse_sleeper_rankings(raw) + self.assertNotIn("andy reid", rankings) + + def test_parse_sleeper_rankings_handles_missing_search_rank(self) -> None: + raw = { + "1234": { + "first_name": "Patrick", + "last_name": "Mahomes", + "position": "QB", + }, + } + rankings = parse_sleeper_rankings(raw) + self.assertNotIn("patrick mahomes", rankings) + + def test_lookup_sleeper_rank_found(self) -> None: + rankings = {"patrick mahomes": 5} + rank = lookup_sleeper_rank("Patrick Mahomes", rankings) + self.assertEqual(rank, 5.0) + + def test_lookup_sleeper_rank_not_found(self) -> None: + rankings = {"patrick mahomes": 5} + rank = lookup_sleeper_rank("Unknown Player", rankings) + self.assertEqual(rank, 999.0) + + def test_lookup_sleeper_rank_with_suffix(self) -> None: + rankings = {"marvin harrison": 3} + rank = lookup_sleeper_rank("Marvin Harrison Jr", rankings) + self.assertEqual(rank, 3.0) + + +if __name__ == "__main__": + unittest.main()