diff --git a/.env.example b/.env.example index c69ebf17..a80d8522 100644 --- a/.env.example +++ b/.env.example @@ -111,5 +111,8 @@ AGENT_APPROVAL_TIMEOUT_MINUTES=60 # Streaming AGENT_ENABLE_STREAMING=true +# Batch runner (PRP-33) — cap on scope expansion (pairs × model_configs). +BATCH_MAX_SCOPE_EXPANSION=1000 + # Frontend (Vite) VITE_API_BASE_URL=http://localhost:8123 diff --git a/alembic/versions/c1d2e3f40512_create_batch_tables.py b/alembic/versions/c1d2e3f40512_create_batch_tables.py new file mode 100644 index 00000000..77124ea8 --- /dev/null +++ b/alembic/versions/c1d2e3f40512_create_batch_tables.py @@ -0,0 +1,236 @@ +"""create_batch_tables + +Revision ID: c1d2e3f40512 +Revises: f84258c4cb44 +Create Date: 2026-05-20 10:30:00.000000 + +Creates the batch_job + batch_job_item tables for PRP-33 batch-runner MVP. +The forward-compat columns on batch_job (running_items, cancelled_items, +max_parallel, default_child_priority) and batch_job_item (priority) are +MVP-owned per the PRP's Cross-Slice Coordination Matrix — the four +downstream INITIALs (parallel-execution, priority-queue, export-and-retry, +champion-and-heatmap) consume them without further schema changes. + +The partial picker index predicate is EXACTLY ``WHERE (status = 'pending')`` +— downstream-2 (priority queue) compiles against the same predicate so the +picker SELECT remains index-covered when CANCELLED rows enter the table. +""" + +from collections.abc import Sequence + +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql + +from alembic import op + +# revision identifiers, used by Alembic. +revision: str = "c1d2e3f40512" +down_revision: str | None = "f84258c4cb44" +branch_labels: str | Sequence[str] | None = None +depends_on: str | Sequence[str] | None = None + + +def upgrade() -> None: + """Apply migration.""" + # ------------------------------------------------------------------ + # batch_job — parent record, one row per submission. + # ------------------------------------------------------------------ + op.create_table( + "batch_job", + sa.Column("id", sa.Integer(), nullable=False), + sa.Column("batch_id", sa.String(length=32), nullable=False), + sa.Column("operation", sa.String(length=30), nullable=False), + sa.Column("scope", postgresql.JSONB(astext_type=sa.Text()), nullable=False), + sa.Column("model_configs", postgresql.JSONB(astext_type=sa.Text()), nullable=False), + sa.Column("status", sa.String(length=20), nullable=False), + sa.Column("total_items", sa.Integer(), nullable=False, server_default="0"), + sa.Column("completed_items", sa.Integer(), nullable=False, server_default="0"), + sa.Column("failed_items", sa.Integer(), nullable=False, server_default="0"), + sa.Column("running_items", sa.Integer(), nullable=False, server_default="0"), + sa.Column("cancelled_items", sa.Integer(), nullable=False, server_default="0"), + sa.Column("max_parallel", sa.Integer(), nullable=False, server_default="4"), + sa.Column( + "default_child_priority", + sa.SmallInteger(), + nullable=False, + server_default="0", + ), + sa.Column("params", postgresql.JSONB(astext_type=sa.Text()), nullable=False), + sa.Column("result_summary", postgresql.JSONB(astext_type=sa.Text()), nullable=True), + sa.Column("started_at", sa.DateTime(timezone=True), nullable=True), + sa.Column("completed_at", sa.DateTime(timezone=True), nullable=True), + sa.Column( + "created_at", + sa.DateTime(timezone=True), + server_default=sa.text("now()"), + nullable=False, + ), + sa.Column( + "updated_at", + sa.DateTime(timezone=True), + server_default=sa.text("now()"), + nullable=False, + ), + sa.CheckConstraint( + "status IN ('pending', 'running', 'completed', 'failed', 'partial', 'cancelled')", + name="ck_batch_job_valid_status", + ), + sa.CheckConstraint( + "operation IN ('train', 'predict', 'backtest', 'train_backtest_register')", + name="ck_batch_job_valid_operation", + ), + sa.CheckConstraint( + "default_child_priority BETWEEN -1 AND 2", + name="ck_batch_job_priority_band", + ), + sa.PrimaryKeyConstraint("id"), + ) + op.create_index(op.f("ix_batch_job_batch_id"), "batch_job", ["batch_id"], unique=True) + op.create_index(op.f("ix_batch_job_status"), "batch_job", ["status"], unique=False) + op.create_index(op.f("ix_batch_job_operation"), "batch_job", ["operation"], unique=False) + op.create_index( + "ix_batch_job_status_created", + "batch_job", + ["status", "created_at"], + unique=False, + ) + + # ------------------------------------------------------------------ + # batch_job_item — child record, one row per (store, product, model) triple. + # FK CASCADE on batch_id so deleting a parent removes its items. + # ------------------------------------------------------------------ + op.create_table( + "batch_job_item", + sa.Column("id", sa.Integer(), nullable=False), + sa.Column("item_id", sa.String(length=32), nullable=False), + sa.Column("batch_id", sa.String(length=32), nullable=False), + sa.Column("store_id", sa.Integer(), nullable=False), + sa.Column("product_id", sa.Integer(), nullable=False), + sa.Column("model_type", sa.String(length=30), nullable=False), + sa.Column("status", sa.String(length=20), nullable=False), + sa.Column("priority", sa.SmallInteger(), nullable=False, server_default="0"), + sa.Column("params", postgresql.JSONB(astext_type=sa.Text()), nullable=False), + sa.Column("metrics", postgresql.JSONB(astext_type=sa.Text()), nullable=True), + sa.Column("child_job_id", sa.String(length=32), nullable=True), + sa.Column("child_run_id", sa.String(length=32), nullable=True), + sa.Column("error_message", sa.String(length=2000), nullable=True), + sa.Column("error_type", sa.String(length=100), nullable=True), + sa.Column("started_at", sa.DateTime(timezone=True), nullable=True), + sa.Column("completed_at", sa.DateTime(timezone=True), nullable=True), + sa.Column("duration_ms", sa.Integer(), nullable=True), + sa.Column( + "created_at", + sa.DateTime(timezone=True), + server_default=sa.text("now()"), + nullable=False, + ), + sa.Column( + "updated_at", + sa.DateTime(timezone=True), + server_default=sa.text("now()"), + nullable=False, + ), + sa.CheckConstraint( + "status IN ('pending', 'running', 'completed', 'failed', 'cancelled')", + name="ck_batch_job_item_valid_status", + ), + sa.CheckConstraint( + "priority BETWEEN -1 AND 2", + name="ck_batch_job_item_priority_band", + ), + sa.ForeignKeyConstraint( + ["batch_id"], + ["batch_job.batch_id"], + ondelete="CASCADE", + ), + sa.PrimaryKeyConstraint("id"), + ) + op.create_index( + op.f("ix_batch_job_item_item_id"), + "batch_job_item", + ["item_id"], + unique=True, + ) + op.create_index( + op.f("ix_batch_job_item_batch_id"), + "batch_job_item", + ["batch_id"], + unique=False, + ) + op.create_index( + op.f("ix_batch_job_item_store_id"), + "batch_job_item", + ["store_id"], + unique=False, + ) + op.create_index( + op.f("ix_batch_job_item_product_id"), + "batch_job_item", + ["product_id"], + unique=False, + ) + op.create_index( + op.f("ix_batch_job_item_status"), + "batch_job_item", + ["status"], + unique=False, + ) + op.create_index( + op.f("ix_batch_job_item_child_job_id"), + "batch_job_item", + ["child_job_id"], + unique=False, + ) + op.create_index( + op.f("ix_batch_job_item_child_run_id"), + "batch_job_item", + ["child_run_id"], + unique=False, + ) + op.create_index( + "ix_batch_job_item_batch_status", + "batch_job_item", + ["batch_id", "status"], + unique=False, + ) + op.create_index( + "ix_batch_job_item_metrics_gin", + "batch_job_item", + ["metrics"], + unique=False, + postgresql_using="gin", + ) + # Partial picker index — load-bearing for downstream-1 (parallel) and + # downstream-2 (priority). Predicate is EXACTLY ``status = 'pending'`` — + # the integration test asserts the substring on pg_indexes.indexdef. + op.create_index( + "ix_batch_job_item_picker", + "batch_job_item", + ["batch_id", "status", "priority", "created_at"], + unique=False, + postgresql_where=sa.text("status = 'pending'"), + ) + + +def downgrade() -> None: + """Revert migration.""" + op.drop_index("ix_batch_job_item_picker", table_name="batch_job_item") + op.drop_index( + "ix_batch_job_item_metrics_gin", + table_name="batch_job_item", + postgresql_using="gin", + ) + op.drop_index("ix_batch_job_item_batch_status", table_name="batch_job_item") + op.drop_index(op.f("ix_batch_job_item_child_run_id"), table_name="batch_job_item") + op.drop_index(op.f("ix_batch_job_item_child_job_id"), table_name="batch_job_item") + op.drop_index(op.f("ix_batch_job_item_status"), table_name="batch_job_item") + op.drop_index(op.f("ix_batch_job_item_product_id"), table_name="batch_job_item") + op.drop_index(op.f("ix_batch_job_item_store_id"), table_name="batch_job_item") + op.drop_index(op.f("ix_batch_job_item_batch_id"), table_name="batch_job_item") + op.drop_index(op.f("ix_batch_job_item_item_id"), table_name="batch_job_item") + op.drop_table("batch_job_item") + op.drop_index("ix_batch_job_status_created", table_name="batch_job") + op.drop_index(op.f("ix_batch_job_operation"), table_name="batch_job") + op.drop_index(op.f("ix_batch_job_status"), table_name="batch_job") + op.drop_index(op.f("ix_batch_job_batch_id"), table_name="batch_job") + op.drop_table("batch_job") diff --git a/app/core/config.py b/app/core/config.py index b30253d7..2ac65061 100644 --- a/app/core/config.py +++ b/app/core/config.py @@ -118,6 +118,9 @@ class Settings(BaseSettings): # Jobs jobs_retention_days: int = 30 + # Batch runner (PRP-33) — cap on scope expansion (pairs x model_configs). + batch_max_scope_expansion: int = 1000 + # RAG Embedding Configuration rag_embedding_provider: Literal["openai", "ollama"] = "openai" openai_api_key: str = "" diff --git a/app/features/batch/__init__.py b/app/features/batch/__init__.py new file mode 100644 index 00000000..b3e7a043 --- /dev/null +++ b/app/features/batch/__init__.py @@ -0,0 +1,9 @@ +"""Batch runner slice — portfolio forecasting orchestration (PRP-33). + +One ``batch_job`` row fans out into N ``batch_job_item`` rows; each item is +executed sequentially by delegating to ``JobService.create_job`` via a lazy +in-method import. The MVP exposes zero mutating agent tools; downstream +PRPs (parallel-execution, priority-queue, export-and-retry, +champion-and-heatmap) consume the forward-compat columns on these tables +without schema changes. +""" diff --git a/app/features/batch/models.py b/app/features/batch/models.py new file mode 100644 index 00000000..2850f26e --- /dev/null +++ b/app/features/batch/models.py @@ -0,0 +1,217 @@ +"""Batch runner ORM models. + +Two tables — ``batch_job`` (parent) and ``batch_job_item`` (child) — track a +portfolio batch and its expanded (store, product, model) work items. Mirrors +``app/features/jobs/models.py`` for shape: ``TimestampMixin`` + ``Base``, +string ``Enum``s, ``CheckConstraint`` in ``__table_args__``, JSONB columns +for flexible per-item config and per-fold metrics. + +Forward-compat columns owned by the MVP (per PRP-33 § "Cross-Slice +Coordination Matrix") so the four downstream PRPs ship without a schema +migration: + +- ``batch_job.running_items`` / ``cancelled_items`` — downstream-1 (parallel) +- ``batch_job.max_parallel`` — downstream-1 (MVP runner ignores) +- ``batch_job.default_child_priority`` — downstream-2 (priority queue) +- ``batch_job_item.priority`` — downstream-2 (MVP NORMAL only) + +The partial picker index ``ix_batch_job_item_picker`` (``WHERE status = +'pending'``) is created in the Alembic migration, not here — SQLAlchemy's +``Index()`` cannot express a portable partial predicate. +""" + +from __future__ import annotations + +import datetime as _dt +from enum import Enum +from typing import Any + +from sqlalchemy import ( + CheckConstraint, + DateTime, + ForeignKey, + Index, + Integer, + SmallInteger, + String, +) +from sqlalchemy.dialects.postgresql import JSONB +from sqlalchemy.orm import Mapped, mapped_column + +from app.core.database import Base +from app.shared.models import TimestampMixin + + +class BatchStatus(str, Enum): + """Parent batch lifecycle states. + + Transitions: + - PENDING -> RUNNING -> {COMPLETED, FAILED, PARTIAL} + - PARTIAL fires when >=1 item succeeded AND >=1 item failed. + - CANCELLED is reserved for downstream-1 (parallel) — MVP never writes it. + """ + + PENDING = "pending" + RUNNING = "running" + COMPLETED = "completed" + FAILED = "failed" + PARTIAL = "partial" + CANCELLED = "cancelled" + + +class BatchOperation(str, Enum): + """Batch operation kinds. + + TRAIN_BACKTEST_REGISTER chains three child JobService.create_job calls + per item; the other three map 1:1 to a single JobType. + """ + + TRAIN = "train" + PREDICT = "predict" + BACKTEST = "backtest" + TRAIN_BACKTEST_REGISTER = "train_backtest_register" + + +class BatchItemStatus(str, Enum): + """Per-item lifecycle states. + + Transitions mirror ``JobStatus`` minus PARTIAL (only the parent settles + to PARTIAL). CANCELLED is reserved for downstream-1. + """ + + PENDING = "pending" + RUNNING = "running" + COMPLETED = "completed" + FAILED = "failed" + CANCELLED = "cancelled" + + +VALID_BATCH_TRANSITIONS: dict[BatchStatus, set[BatchStatus]] = { + BatchStatus.PENDING: {BatchStatus.RUNNING, BatchStatus.CANCELLED}, + BatchStatus.RUNNING: { + BatchStatus.COMPLETED, + BatchStatus.FAILED, + BatchStatus.PARTIAL, + BatchStatus.CANCELLED, + }, + BatchStatus.COMPLETED: set(), + BatchStatus.FAILED: set(), + BatchStatus.PARTIAL: set(), + BatchStatus.CANCELLED: set(), +} + +VALID_BATCH_ITEM_TRANSITIONS: dict[BatchItemStatus, set[BatchItemStatus]] = { + BatchItemStatus.PENDING: {BatchItemStatus.RUNNING, BatchItemStatus.CANCELLED}, + BatchItemStatus.RUNNING: {BatchItemStatus.COMPLETED, BatchItemStatus.FAILED}, + BatchItemStatus.COMPLETED: set(), + BatchItemStatus.FAILED: set(), + BatchItemStatus.CANCELLED: set(), +} + + +class BatchJob(TimestampMixin, Base): + """Parent batch record — one row per submission. + + ``scope``, ``model_configs``, ``params``, and ``result_summary`` are all + JSONB so the four downstream PRPs can add keys without a schema + migration. ``params`` carries the original submit request verbatim; + ``scope`` and ``model_configs`` are split out so they remain + independently queryable from SQL without a JSONB path expression. + """ + + __tablename__ = "batch_job" + + id: Mapped[int] = mapped_column(Integer, primary_key=True) + batch_id: Mapped[str] = mapped_column(String(32), unique=True, index=True) + operation: Mapped[str] = mapped_column(String(30), index=True) + scope: Mapped[dict[str, Any]] = mapped_column(JSONB, nullable=False) + model_configs: Mapped[list[dict[str, Any]]] = mapped_column(JSONB, nullable=False) + status: Mapped[str] = mapped_column(String(20), default=BatchStatus.PENDING.value, index=True) + total_items: Mapped[int] = mapped_column(Integer, default=0, nullable=False) + completed_items: Mapped[int] = mapped_column(Integer, default=0, nullable=False) + failed_items: Mapped[int] = mapped_column(Integer, default=0, nullable=False) + # Forward-compat — downstream-1 (parallel) maintains these counters; MVP + # leaves them at 0 except via the settle aggregate. + running_items: Mapped[int] = mapped_column(Integer, default=0, nullable=False) + cancelled_items: Mapped[int] = mapped_column(Integer, default=0, nullable=False) + # Forward-compat — downstream-1 reads max_parallel; MVP runner ignores it. + max_parallel: Mapped[int] = mapped_column(Integer, default=4, nullable=False) + # Forward-compat — downstream-2 reads default_child_priority; MVP only writes NORMAL (0). + default_child_priority: Mapped[int] = mapped_column(SmallInteger, default=0, nullable=False) + params: Mapped[dict[str, Any]] = mapped_column(JSONB, nullable=False) + result_summary: Mapped[dict[str, Any] | None] = mapped_column(JSONB, nullable=True) + started_at: Mapped[_dt.datetime | None] = mapped_column(DateTime(timezone=True), nullable=True) + completed_at: Mapped[_dt.datetime | None] = mapped_column( + DateTime(timezone=True), nullable=True + ) + + __table_args__ = ( + CheckConstraint( + "status IN ('pending', 'running', 'completed', 'failed', 'partial', 'cancelled')", + name="ck_batch_job_valid_status", + ), + CheckConstraint( + "operation IN ('train', 'predict', 'backtest', 'train_backtest_register')", + name="ck_batch_job_valid_operation", + ), + CheckConstraint( + "default_child_priority BETWEEN -1 AND 2", + name="ck_batch_job_priority_band", + ), + Index("ix_batch_job_status_created", "status", "created_at"), + ) + + +class BatchJobItem(TimestampMixin, Base): + """Child batch item — one row per (store, product, model_type) triple. + + ``params`` is frozen at expansion time; the runner reads from it on every + ``_execute_item`` call, never mutates it. ``metrics`` carries the pinned + five-key JSONB ``{wape, smape, mae, bias, sample_size}`` for backtest + items; nullable for predict-only items and for fold runs that produced + NaN on a zero-actuals window. + """ + + __tablename__ = "batch_job_item" + + id: Mapped[int] = mapped_column(Integer, primary_key=True) + item_id: Mapped[str] = mapped_column(String(32), unique=True, index=True) + batch_id: Mapped[str] = mapped_column( + String(32), + ForeignKey("batch_job.batch_id", ondelete="CASCADE"), + index=True, + ) + store_id: Mapped[int] = mapped_column(Integer, index=True) + product_id: Mapped[int] = mapped_column(Integer, index=True) + model_type: Mapped[str] = mapped_column(String(30)) + status: Mapped[str] = mapped_column( + String(20), default=BatchItemStatus.PENDING.value, index=True + ) + # Forward-compat — downstream-2 reads priority; MVP only writes NORMAL (0). + priority: Mapped[int] = mapped_column(SmallInteger, default=0, nullable=False) + params: Mapped[dict[str, Any]] = mapped_column(JSONB, nullable=False) + metrics: Mapped[dict[str, Any] | None] = mapped_column(JSONB, nullable=True) + child_job_id: Mapped[str | None] = mapped_column(String(32), nullable=True, index=True) + child_run_id: Mapped[str | None] = mapped_column(String(32), nullable=True, index=True) + error_message: Mapped[str | None] = mapped_column(String(2000), nullable=True) + error_type: Mapped[str | None] = mapped_column(String(100), nullable=True) + started_at: Mapped[_dt.datetime | None] = mapped_column(DateTime(timezone=True), nullable=True) + completed_at: Mapped[_dt.datetime | None] = mapped_column( + DateTime(timezone=True), nullable=True + ) + duration_ms: Mapped[int | None] = mapped_column(Integer, nullable=True) + + __table_args__ = ( + CheckConstraint( + "status IN ('pending', 'running', 'completed', 'failed', 'cancelled')", + name="ck_batch_job_item_valid_status", + ), + CheckConstraint( + "priority BETWEEN -1 AND 2", + name="ck_batch_job_item_priority_band", + ), + Index("ix_batch_job_item_batch_status", "batch_id", "status"), + Index("ix_batch_job_item_metrics_gin", "metrics", postgresql_using="gin"), + # Partial picker index (postgresql_where) lives in the Alembic migration — + # SQLAlchemy's Index() lacks a portable partial-predicate kwarg. + ) diff --git a/app/features/batch/routes.py b/app/features/batch/routes.py new file mode 100644 index 00000000..aeee169d --- /dev/null +++ b/app/features/batch/routes.py @@ -0,0 +1,111 @@ +"""FastAPI routes for the batch runner slice (PRP-33). + +Three endpoints mirroring ``app/features/jobs/routes.py``: + +- ``POST /batch/forecasting`` — submit a batch, run it sequentially, return parent. +- ``GET /batch/{batch_id}`` — fetch parent state. +- ``GET /batch/{batch_id}/items`` — list items with pagination + allow-listed sort. + +All 4xx responses route through ``app.core.exceptions`` to RFC 7807 +``application/problem+json``. +""" + +from fastapi import APIRouter, Depends, Query, status +from sqlalchemy.ext.asyncio import AsyncSession + +from app.core.database import get_db +from app.core.exceptions import NotFoundError +from app.core.logging import get_logger +from app.features.batch.schemas import ( + BatchItemListResponse, + BatchSubmitRequest, + BatchSubmitResponse, +) +from app.features.batch.service import BatchService + +logger = get_logger(__name__) + +router = APIRouter(prefix="/batch", tags=["batch"]) + + +@router.post( + "/forecasting", + response_model=BatchSubmitResponse, + status_code=status.HTTP_202_ACCEPTED, + summary="Submit and run a portfolio batch", + description=( + "Submit a portfolio batch that expands a `BatchScope` to N " + "(store, product, model) triples, runs them sequentially via the " + "internal job runner, and settles the parent to " + "`completed | failed | partial`." + ), +) +async def submit_batch( + req: BatchSubmitRequest, + db: AsyncSession = Depends(get_db), +) -> BatchSubmitResponse: + """Submit and run a batch — returns the settled parent record.""" + service = BatchService() + return await service.submit(db=db, req=req) + + +@router.get( + "/{batch_id}", + response_model=BatchSubmitResponse, + summary="Get batch parent record", +) +async def get_batch( + batch_id: str, + db: AsyncSession = Depends(get_db), +) -> BatchSubmitResponse: + """Get the parent batch record.""" + service = BatchService() + result = await service.get(db=db, batch_id=batch_id) + if result is None: + raise NotFoundError( + message=f"Batch not found: {batch_id}", + details={"batch_id": batch_id}, + ) + return result + + +@router.get( + "/{batch_id}/items", + response_model=BatchItemListResponse, + summary="List batch items", + description=( + "List items belonging to a batch with pagination and an allow-listed " + "`sort_by`. Unknown sort keys fall back silently to the default order " + "(`created_at desc`) — never raises 4xx." + ), +) +async def list_batch_items( + batch_id: str, + db: AsyncSession = Depends(get_db), + page: int = Query(1, ge=1, description="Page number (1-indexed)"), + page_size: int = Query(50, ge=1, le=200, description="Items per page (max 200)"), + sort_by: str | None = Query( + None, + description=( + "Allow-listed sort column: created_at | completed_at | status | " + "priority. Unknown values fall back to created_at desc." + ), + ), + sort_order: str = Query("asc", pattern="^(asc|desc)$", description="Sort direction."), +) -> BatchItemListResponse: + """List items belonging to a batch.""" + service = BatchService() + result = await service.list_items( + db=db, + batch_id=batch_id, + page=page, + page_size=page_size, + sort_by=sort_by, + sort_order=sort_order, + ) + if result is None: + raise NotFoundError( + message=f"Batch not found: {batch_id}", + details={"batch_id": batch_id}, + ) + return result diff --git a/app/features/batch/schemas.py b/app/features/batch/schemas.py new file mode 100644 index 00000000..100240ba --- /dev/null +++ b/app/features/batch/schemas.py @@ -0,0 +1,190 @@ +"""Pydantic v2 schemas for the batch runner slice (PRP-33). + +All request bodies use ``ConfigDict(strict=True)`` per docs/_base/SECURITY.md +§ "Pydantic v2 strict mode on FastAPI request bodies"; the only JSON-non-native +fields (``start_date`` / ``end_date``) carry ``Field(strict=False, ...)`` so +the strict-mode policy linter (app/core/tests/test_strict_mode_policy.py) +passes. + +The ``BatchScope.kind`` / selector consistency check uses +``model_validator(mode="after")`` — invalid combinations (e.g. ``kind=manual`` +without ``store_ids``) raise ``ValueError`` and FastAPI surfaces as RFC 7807 +422 via the validation exception handler in ``app/core/exceptions.py``. +""" + +from __future__ import annotations + +from datetime import date, datetime +from enum import Enum +from typing import Any, Literal + +from pydantic import BaseModel, ConfigDict, Field, model_validator + +from app.features.batch.models import BatchItemStatus, BatchOperation, BatchStatus + +# Allow-listed model types — mirrors ``app/features/jobs/service.py``'s +# accepted set. Adding a type here REQUIRES adding it to the JobService's +# _execute_train/_execute_backtest branches; the runner delegates blindly. +VALID_MODEL_TYPES: frozenset[str] = frozenset( + { + "naive", + "seasonal_naive", + "moving_average", + "regression", + "lightgbm", + "xgboost", + "prophet_like", + } +) + + +class BatchScopeKind(str, Enum): + """Five ways to express a batch's (store, product) coverage. + + - MANUAL: explicit ``store_ids`` x``product_ids`` cartesian. + - REGION: every store in ``region`` xall products. + - CATEGORY: all stores xevery product in ``category``. + - TOP_REVENUE: top ``top_n`` (store, product) pairs by revenue over the + submit window. + - ALL: full store xproduct cartesian. + """ + + MANUAL = "manual" + REGION = "region" + CATEGORY = "category" + TOP_REVENUE = "top_revenue" + ALL = "all" + + +class BatchScope(BaseModel): + """Scope selector — one shape, five kinds enforced by model_validator. + + ``kind`` is ``Literal[...]`` (not the ``BatchScopeKind`` enum) because the + enclosing ``BatchSubmitRequest`` runs in strict mode — Pydantic v2's + strict mode refuses to coerce a JSON string into a str-enum instance. + The literal carries the same value set; ``BatchScopeKind`` is retained + for the response side and for internal callers. + """ + + model_config = ConfigDict(strict=True) + + kind: Literal["manual", "region", "category", "top_revenue", "all"] + store_ids: list[int] | None = Field(default=None, description="Required if kind=manual") + product_ids: list[int] | None = Field(default=None, description="Required if kind=manual") + region: str | None = Field(default=None, description="Required if kind=region") + category: str | None = Field(default=None, description="Required if kind=category") + top_n: int | None = Field( + default=None, ge=1, le=1000, description="Required if kind=top_revenue" + ) + + @model_validator(mode="after") + def _check_kind_consistency(self) -> BatchScope: + """Enforce kind→selector pairing. Mismatches surface as RFC 7807 422.""" + if self.kind == "manual": + if not self.store_ids or not self.product_ids: + raise ValueError("kind=manual requires non-empty store_ids and product_ids") + elif self.kind == "region": + if not self.region: + raise ValueError("kind=region requires region") + elif self.kind == "category": + if not self.category: + raise ValueError("kind=category requires category") + elif self.kind == "top_revenue": + if self.top_n is None: + raise ValueError("kind=top_revenue requires top_n") + # kind == "all": no extra selector required. + return self + + +class BatchModelConfig(BaseModel): + """One model spec — one row in the batch's model_configs list.""" + + model_config = ConfigDict(strict=True) + + model_type: Literal[ + "naive", + "seasonal_naive", + "moving_average", + "regression", + "lightgbm", + "xgboost", + "prophet_like", + ] + params: dict[str, Any] = Field(default_factory=dict) + + +class BatchSubmitRequest(BaseModel): + """POST /batch/forecasting request body. + + JSON-native fields stay strict; the two ``date`` fields carry + ``Field(strict=False, ...)`` so the JSON ISO-string path works (see + docs/_base/SECURITY.md and PR #115 / #119 for the precedent). + """ + + model_config = ConfigDict(strict=True) + + # ``operation`` is ``Literal[...]`` (not the BatchOperation enum) for the + # same reason ``BatchScope.kind`` is: strict mode + str-enums coerce + # poorly. Convert to the enum at the service boundary. + operation: Literal["train", "predict", "backtest", "train_backtest_register"] + scope: BatchScope + model_configs: list[BatchModelConfig] = Field(min_length=1, max_length=10) + start_date: date = Field(strict=False, description="YYYY-MM-DD") + end_date: date = Field(strict=False, description="YYYY-MM-DD") + # Forward-compat — accepted, validated, persisted; runner ignores in MVP. + max_parallel: int = Field(default=4, ge=1, le=64) + default_child_priority: int = Field(default=0, ge=-1, le=2) + + +class BatchItemResponse(BaseModel): + """One item row — returned from /batch/{id}/items and embedded in the + parent's settle path.""" + + model_config = ConfigDict(from_attributes=True) + + item_id: str + batch_id: str + store_id: int + product_id: int + model_type: str + status: BatchItemStatus + priority: int + metrics: dict[str, Any] | None + child_job_id: str | None + child_run_id: str | None + error_message: str | None + error_type: str | None + started_at: datetime | None + completed_at: datetime | None + duration_ms: int | None + created_at: datetime + updated_at: datetime + + +class BatchSubmitResponse(BaseModel): + """Parent batch record — returned by submit + GET /batch/{id}.""" + + model_config = ConfigDict(from_attributes=True) + + batch_id: str + operation: BatchOperation + status: BatchStatus + total_items: int + completed_items: int + failed_items: int + running_items: int + cancelled_items: int + started_at: datetime | None + completed_at: datetime | None + result_summary: dict[str, Any] | None + created_at: datetime + updated_at: datetime + + +class BatchItemListResponse(BaseModel): + """Paginated item listing — GET /batch/{id}/items.""" + + items: list[BatchItemResponse] + total: int + page: int + page_size: int diff --git a/app/features/batch/service.py b/app/features/batch/service.py new file mode 100644 index 00000000..ccad4ec7 --- /dev/null +++ b/app/features/batch/service.py @@ -0,0 +1,497 @@ +"""BatchService — orchestration layer for portfolio forecasting batches (PRP-33). + +Submits one ``batch_job`` and N ``batch_job_item`` rows in one transaction, +then loops a partial-index-backed picker (``FOR UPDATE SKIP LOCKED``) and +delegates each item to ``JobService.create_job`` via a lazy in-method +import. The metrics JSONB is pinned to the exact five-key shape +``{wape, smape, mae, bias, sample_size}`` — every downstream PRP +(parallel-execution, priority-queue, export-and-retry, +champion-and-heatmap) consumes this shape directly. ``sample_size`` is +derived **inside this slice** from ``fold_metrics`` so the jobs slice +stays untouched (the non-regression boundary declared in PRP-33). + +structlog lifecycle events (every event carries a ``request_id`` via the +middleware-bound logger): ``batch.created``, ``batch.item_started``, +``batch.item_completed``, ``batch.item_failed``, ``batch.completed``. +""" + +from __future__ import annotations + +import uuid +from datetime import UTC, datetime +from typing import TYPE_CHECKING, Any + +from sqlalchemy import func, select +from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.orm import InstrumentedAttribute + +from app.core.config import get_settings +from app.core.exceptions import ValidationError +from app.core.logging import get_logger +from app.features.batch.models import ( + BatchItemStatus, + BatchJob, + BatchJobItem, + BatchOperation, + BatchStatus, +) +from app.features.batch.schemas import ( + BatchItemListResponse, + BatchItemResponse, + BatchModelConfig, + BatchScope, + BatchSubmitRequest, + BatchSubmitResponse, +) + +# data_platform is the de-facto shared ORM layer (see the +# data-platform-shared-orm-layer memory) — module-scope import for scope +# expansion is permitted; cross-slice *service* calls stay lazy. +from app.features.data_platform.models import Product, SalesDaily, Store + +if TYPE_CHECKING: + from app.features.jobs.schemas import JobResponse + +logger = get_logger(__name__) + +# Pinned metrics keys — the test_metrics_jsonb_shape_pinned regression locks +# this exact list. Downstream PRPs read from these keys; adding a sixth key +# (or renaming one) is a breaking change requiring a new INITIAL. +_METRICS_KEYS: tuple[str, ...] = ("wape", "smape", "mae", "bias", "sample_size") + +# Allow-listed sort columns for GET /batch/{id}/items. ``sort_by`` is user +# input — it MUST resolve through this map to a real mapped column; unknown +# keys fall back to the default order (never an error, never raw SQL). +_BATCH_ITEM_SORT_COLUMNS: dict[str, InstrumentedAttribute[Any]] = { + "created_at": BatchJobItem.created_at, + "completed_at": BatchJobItem.completed_at, + "status": BatchJobItem.status, + "priority": BatchJobItem.priority, +} + + +class BatchService: + """Service for submitting, executing, and observing portfolio batches. + + MVP runs items sequentially in-process via a single picker loop. The + picker compiles to ``FOR UPDATE SKIP LOCKED`` — a no-op with one worker + but load-bearing for downstream-1 (parallel) and downstream-2 (priority), + so the picker query needs no code retrofit when those land. + """ + + def __init__(self) -> None: + self.settings = get_settings() + + # ------------------------------------------------------------------ submit + async def submit(self, db: AsyncSession, req: BatchSubmitRequest) -> BatchSubmitResponse: + """Submit a batch: expand scope, insert N+1 rows, run picker, settle. + + Raises: + ValidationError: scope expanded beyond ``batch_max_scope_expansion``. + """ + pairs = await self._expand_scope(db, req.scope) + triples = [(s, p, mc) for (s, p) in pairs for mc in req.model_configs] + + if len(triples) > self.settings.batch_max_scope_expansion: + raise ValidationError( + message=( + f"Scope expanded to {len(triples)} items, exceeds the cap of " + f"{self.settings.batch_max_scope_expansion}. Narrow the scope " + f"or raise BATCH_MAX_SCOPE_EXPANSION." + ), + details={ + "expanded_items": len(triples), + "cap": self.settings.batch_max_scope_expansion, + }, + ) + + # 1. Insert parent + N children in one transaction. + batch = BatchJob( + batch_id=uuid.uuid4().hex, + operation=req.operation, + scope=req.scope.model_dump(mode="json"), + model_configs=[mc.model_dump(mode="json") for mc in req.model_configs], + status=BatchStatus.PENDING.value, + total_items=len(triples), + params=req.model_dump(mode="json"), + default_child_priority=req.default_child_priority, + max_parallel=req.max_parallel, + ) + db.add(batch) + for store_id, product_id, mc in triples: + item = BatchJobItem( + item_id=uuid.uuid4().hex, + batch_id=batch.batch_id, + store_id=store_id, + product_id=product_id, + model_type=mc.model_type, + priority=req.default_child_priority, + status=BatchItemStatus.PENDING.value, + params=self._frozen_item_params(req, store_id, product_id, mc), + ) + db.add(item) + await db.commit() + await db.refresh(batch) + + logger.info( + "batch.created", + batch_id=batch.batch_id, + operation=req.operation, + total_items=len(triples), + ) + + # 2. Settle parent to running. + batch.status = BatchStatus.RUNNING.value + batch.started_at = datetime.now(UTC) + await db.commit() + + # 3. Loop the picker until no PENDING item remains. The explicit + # ``BatchJobItem | None`` annotation prevents mypy from re-narrowing + # ``item`` to ``BatchJobItem`` on the second iteration after the + # first ``if item is None: break`` branch. + while True: + next_item: BatchJobItem | None = await self._pick_next(db, batch.batch_id) + if next_item is None: + break + await self._execute_item(db, next_item) + + # 4. Settle the parent. + await self._settle(db, batch) + await db.refresh(batch) + + logger.info( + "batch.completed", + batch_id=batch.batch_id, + status=batch.status, + completed_items=batch.completed_items, + failed_items=batch.failed_items, + ) + + return BatchSubmitResponse.model_validate(batch) + + # --------------------------------------------------------------------- get + async def get(self, db: AsyncSession, batch_id: str) -> BatchSubmitResponse | None: + """Get parent batch by ``batch_id``.""" + stmt = select(BatchJob).where(BatchJob.batch_id == batch_id) + batch = (await db.execute(stmt)).scalar_one_or_none() + if batch is None: + return None + return BatchSubmitResponse.model_validate(batch) + + # ------------------------------------------------------------------- items + async def list_items( + self, + db: AsyncSession, + batch_id: str, + page: int = 1, + page_size: int = 50, + sort_by: str | None = None, + sort_order: str = "asc", + ) -> BatchItemListResponse | None: + """List items for ``batch_id`` with pagination + allow-listed sort. + + Returns ``None`` when the batch does not exist (route maps to 404). + """ + parent = ( + await db.execute(select(BatchJob).where(BatchJob.batch_id == batch_id)) + ).scalar_one_or_none() + if parent is None: + return None + + base = select(BatchJobItem).where(BatchJobItem.batch_id == batch_id) + total = (await db.execute(select(func.count()).select_from(base.subquery()))).scalar_one() + + sort_column = _BATCH_ITEM_SORT_COLUMNS.get(sort_by) if sort_by else None + if sort_column is not None: + order_by = sort_column.desc() if sort_order == "desc" else sort_column.asc() + else: + order_by = BatchJobItem.created_at.desc() + + offset = (page - 1) * page_size + stmt = ( + base.order_by(order_by, BatchJobItem.created_at.asc(), BatchJobItem.id.asc()) + .offset(offset) + .limit(page_size) + ) + rows = (await db.execute(stmt)).scalars().all() + return BatchItemListResponse( + items=[BatchItemResponse.model_validate(r) for r in rows], + total=total, + page=page, + page_size=page_size, + ) + + # --------------------------------------------------------------- internals + async def _pick_next(self, db: AsyncSession, batch_id: str) -> BatchJobItem | None: + """Pick the next PENDING item — partial-index-backed, SKIP LOCKED wired. + + With one worker, ``skip_locked=True`` is a no-op; with N workers it + prevents the picker from blocking on a row another worker holds. + The integration test asserts the SKIP LOCKED clause is in the + compiled SQL — never remove the kwarg. + """ + stmt = ( + select(BatchJobItem) + .where( + BatchJobItem.batch_id == batch_id, + BatchJobItem.status == BatchItemStatus.PENDING.value, + ) + .order_by( + BatchJobItem.priority.desc(), + BatchJobItem.created_at.asc(), + BatchJobItem.id.asc(), + ) + .limit(1) + .with_for_update(skip_locked=True) + ) + return (await db.execute(stmt)).scalar_one_or_none() + + async def _execute_item(self, db: AsyncSession, item: BatchJobItem) -> None: + """Run one item: delegate to ``JobService.create_job`` and capture metrics. + + Lazy cross-slice imports break the alembic cold-boot cycle + (precedent: ``app/features/forecasting/service.py:786-787``). + """ + from app.features.jobs.models import JobStatus + from app.features.jobs.schemas import JobCreate + from app.features.jobs.service import JobService + + item.status = BatchItemStatus.RUNNING.value + item.started_at = datetime.now(UTC) + await db.commit() + + logger.info( + "batch.item_started", + batch_id=item.batch_id, + item_id=item.item_id, + store_id=item.store_id, + product_id=item.product_id, + model_type=item.model_type, + ) + + try: + operation = item.params["operation"] + job_params = item.params["job_params"] + if operation not in ("train", "predict", "backtest"): + # train_backtest_register is reserved for a downstream PRP that + # chains three calls; the MVP rejects it at submit time but the + # path is wired so the future change does not need a refactor. + raise NotImplementedError( + f"operation={operation!r} not supported in MVP " + "(use train, predict, or backtest)" + ) + + job_create = JobCreate.model_validate({"job_type": operation, "params": job_params}) + job = await JobService().create_job(db=db, job_create=job_create) + item.child_job_id = job.job_id + item.child_run_id = job.run_id + + if job.status == JobStatus.FAILED: + raise RuntimeError(job.error_message or "child job failed") + + item.metrics = self._shape_metrics(job) + item.status = BatchItemStatus.COMPLETED.value + except Exception as exc: + item.status = BatchItemStatus.FAILED.value + item.error_message = str(exc)[:2000] + item.error_type = type(exc).__name__ + + completed_at = datetime.now(UTC) + item.completed_at = completed_at + started_at = item.started_at + if started_at is not None: + item.duration_ms = int((completed_at - started_at).total_seconds() * 1000) + await db.commit() + + if item.status == BatchItemStatus.COMPLETED.value: + logger.info( + "batch.item_completed", + batch_id=item.batch_id, + item_id=item.item_id, + duration_ms=item.duration_ms, + ) + else: + logger.warning( + "batch.item_failed", + batch_id=item.batch_id, + item_id=item.item_id, + error_type=item.error_type, + error_message=item.error_message, + ) + + def _shape_metrics(self, job: JobResponse) -> dict[str, Any] | None: + """Coerce ``JobResponse.result`` into the pinned five-key JSONB. + + CRITICAL: returns EXACTLY ``{wape, smape, mae, bias, sample_size}`` + or ``None`` (the four downstream PRPs read these keys verbatim). + ``sample_size`` is computed inside this slice from ``fold_metrics``; + Option (b) — extending ``app/features/jobs/service.py:_shape_backtest_result`` + to emit a new aggregate — is REJECTED because it would touch the + jobs slice and violate the no-cross-import rule (PRP-33 § "Why not 10"). + """ + # Lazy import — the JobType enum lives in the jobs slice; we only need + # the value for comparison, so a string compare is sufficient. + if job.job_type.value != "backtest" or not job.result: + # Predict-only items have no per-fold metrics in the job result. + return None + agg = job.result.get("aggregated_metrics", {}) + fold_metrics = job.result.get("fold_metrics", []) + sample_size = sum(f.get("sample_size", 0) for f in fold_metrics) + if sample_size == 0: + sample_size = job.result.get("n_observations") or 0 + return { + "wape": agg.get("wape_mean"), + "smape": agg.get("smape_mean"), + "mae": agg.get("mae_mean"), + "bias": agg.get("bias_mean"), + "sample_size": sample_size, + } + + async def _settle(self, db: AsyncSession, batch: BatchJob) -> None: + """Aggregate per-status counts and settle the parent. + + - all COMPLETED → ``completed`` + - all FAILED → ``failed`` + - mixed (>=1 of each) → ``partial`` + - 0 items (degenerate empty batch) → ``completed`` (vacuous) + """ + stmt = ( + select(BatchJobItem.status, func.count()) + .where(BatchJobItem.batch_id == batch.batch_id) + .group_by(BatchJobItem.status) + ) + rows = (await db.execute(stmt)).all() + counts: dict[str, int] = {status: int(count) for status, count in rows} + + completed = counts.get(BatchItemStatus.COMPLETED.value, 0) + failed = counts.get(BatchItemStatus.FAILED.value, 0) + cancelled = counts.get(BatchItemStatus.CANCELLED.value, 0) + + if completed > 0 and failed == 0: + final = BatchStatus.COMPLETED + elif failed > 0 and completed == 0: + final = BatchStatus.FAILED + elif completed > 0 and failed > 0: + final = BatchStatus.PARTIAL + else: + # No completed + no failed: empty batch or all-cancelled. Treat + # as ``completed`` (vacuous) — the integration test asserts on + # completed_items=N, not on status when items=0. + final = BatchStatus.COMPLETED + + batch.status = final.value + batch.completed_items = completed + batch.failed_items = failed + batch.cancelled_items = cancelled + batch.completed_at = datetime.now(UTC) + batch.result_summary = { + "by_status": counts, + "final_status": final.value, + } + await db.commit() + + # --------------------------------------------------------- scope expansion + async def _expand_scope(self, db: AsyncSession, scope: BatchScope) -> list[tuple[int, int]]: + """Expand a ``BatchScope`` into a deterministic list of (store, product) pairs. + + For ``region``/``category`` we query ``Store`` / ``Product`` directly + — those models live in the shared ORM layer (``data_platform``) per + the data-platform-shared-orm-layer memory; that does NOT count as a + cross-slice import. For ``top_revenue`` we run a direct revenue + aggregation against ``sales_daily`` (the ranking semantics are + narrow enough that calling into ``AnalyticsService`` would be + indirection for indirection's sake). + """ + if scope.kind == "manual": + # The model_validator already guarantees both lists are non-empty. + return [(s, p) for s in (scope.store_ids or []) for p in (scope.product_ids or [])] + if scope.kind == "region": + store_ids = await self._stores_in_region(db, scope.region or "") + product_ids = await self._all_product_ids(db) + return [(s, p) for s in store_ids for p in product_ids] + if scope.kind == "category": + store_ids = await self._all_store_ids(db) + product_ids = await self._products_in_category(db, scope.category or "") + return [(s, p) for s in store_ids for p in product_ids] + if scope.kind == "top_revenue": + return await self._top_revenue_pairs(db, scope.top_n or 0) + # kind == "all" + store_ids = await self._all_store_ids(db) + product_ids = await self._all_product_ids(db) + return [(s, p) for s in store_ids for p in product_ids] + + async def _all_store_ids(self, db: AsyncSession) -> list[int]: + stmt = select(Store.id).order_by(Store.id.asc()) + return [int(r) for r in (await db.execute(stmt)).scalars().all()] + + async def _all_product_ids(self, db: AsyncSession) -> list[int]: + stmt = select(Product.id).order_by(Product.id.asc()) + return [int(r) for r in (await db.execute(stmt)).scalars().all()] + + async def _stores_in_region(self, db: AsyncSession, region: str) -> list[int]: + stmt = select(Store.id).where(Store.region == region).order_by(Store.id.asc()) + return [int(r) for r in (await db.execute(stmt)).scalars().all()] + + async def _products_in_category(self, db: AsyncSession, category: str) -> list[int]: + stmt = select(Product.id).where(Product.category == category).order_by(Product.id.asc()) + return [int(r) for r in (await db.execute(stmt)).scalars().all()] + + async def _top_revenue_pairs(self, db: AsyncSession, top_n: int) -> list[tuple[int, int]]: + """Top-N (store, product) pairs by sum(total_amount) over all time. + + For the MVP we rank across the full ``sales_daily`` history — the + submit window is used for child-job training, not for ranking. A + future PRP may add a date-window arg. + """ + if top_n <= 0: + return [] + stmt = ( + select( + SalesDaily.store_id, + SalesDaily.product_id, + func.sum(SalesDaily.total_amount).label("revenue"), + ) + .group_by(SalesDaily.store_id, SalesDaily.product_id) + .order_by( + func.sum(SalesDaily.total_amount).desc(), + SalesDaily.store_id.asc(), + SalesDaily.product_id.asc(), + ) + .limit(top_n) + ) + rows = (await db.execute(stmt)).all() + return [(int(s), int(p)) for s, p, _ in rows] + + # ------------------------------------------------------- per-item payload + def _frozen_item_params( + self, + req: BatchSubmitRequest, + store_id: int, + product_id: int, + mc: BatchModelConfig, + ) -> dict[str, Any]: + """Build per-item JSONB args, frozen at expansion time. + + The runner reads from this dict on every ``_execute_item`` call but + never mutates it. The shape maps directly to ``JobCreate.params`` + for the relevant ``job_type``. + """ + return { + "operation": req.operation, + "job_params": { + "model_type": mc.model_type, + "store_id": store_id, + "product_id": product_id, + "start_date": req.start_date.isoformat(), + "end_date": req.end_date.isoformat(), + **mc.params, + }, + } + + +__all__ = [ + "BatchItemStatus", + "BatchOperation", + "BatchService", + "BatchStatus", +] diff --git a/app/features/batch/tests/__init__.py b/app/features/batch/tests/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/app/features/batch/tests/conftest.py b/app/features/batch/tests/conftest.py new file mode 100644 index 00000000..e45d85f6 --- /dev/null +++ b/app/features/batch/tests/conftest.py @@ -0,0 +1,171 @@ +"""Test fixtures for the batch slice (PRP-33).""" + +from __future__ import annotations + +import uuid +from collections.abc import AsyncGenerator +from datetime import date, timedelta +from decimal import Decimal +from typing import Any + +import pytest +from httpx import ASGITransport, AsyncClient +from sqlalchemy import delete +from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine + +from app.core.config import get_settings +from app.features.batch.models import BatchJob, BatchJobItem +from app.features.data_platform.models import Calendar, Product, SalesDaily, Store +from app.main import app + + +@pytest.fixture +def sample_manual_payload() -> dict[str, Any]: + """A canonical 3-pair manual backtest submit payload.""" + return { + "operation": "backtest", + "scope": {"kind": "manual", "store_ids": [1], "product_ids": [1, 2, 3]}, + "model_configs": [{"model_type": "naive", "params": {}}], + "start_date": "2025-01-01", + "end_date": "2025-06-30", + } + + +@pytest.fixture +def sample_top_revenue_payload() -> dict[str, Any]: + """A top_revenue scope payload — top_n=2, one model.""" + return { + "operation": "backtest", + "scope": {"kind": "top_revenue", "top_n": 2}, + "model_configs": [{"model_type": "naive", "params": {}}], + "start_date": "2025-01-01", + "end_date": "2025-06-30", + } + + +@pytest.fixture +async def db_session() -> AsyncGenerator[AsyncSession, None]: + """Async database session for integration tests. + + Cleans up rows whose ``batch_id`` starts with ``test`` after each test — + cascade FK removes their child items automatically. + """ + settings = get_settings() + engine = create_async_engine(settings.database_url, echo=False) + session_maker = async_sessionmaker(engine, class_=AsyncSession, expire_on_commit=False) + + async with session_maker() as session: + try: + yield session + finally: + # FK CASCADE removes batch_job_item rows when the parent goes; + # explicit batch_job_item DELETE handles orphans from prior failed runs. + await session.execute(delete(BatchJobItem).where(BatchJobItem.batch_id.like("test%"))) + await session.execute(delete(BatchJob).where(BatchJob.batch_id.like("test%"))) + await session.commit() + await engine.dispose() + + +@pytest.fixture +async def client() -> AsyncGenerator[AsyncClient, None]: + """HTTP client bound to the FastAPI app via ASGI transport (no real port).""" + transport = ASGITransport(app=app) + async with AsyncClient(transport=transport, base_url="http://test") as ac: + yield ac + + +# ============================================================================ +# Seed fixtures for integration tests — mirror the backtesting conftest. +# Each test gets unique store/product codes so concurrent runs don't collide. +# ============================================================================ + + +@pytest.fixture +async def sample_store(db_session: AsyncSession) -> Store: + """One isolated store for integration tests.""" + unique_id = uuid.uuid4().hex[:8] + store = Store( + code=f"BATCH-{unique_id}", + name="Batch Test Store", + region="Batch Test Region", + city="Batch Test City", + store_type="supermarket", + ) + db_session.add(store) + await db_session.commit() + await db_session.refresh(store) + return store + + +@pytest.fixture +async def sample_products_3(db_session: AsyncSession) -> list[Product]: + """Three isolated products for the 3-pair happy-path test.""" + products: list[Product] = [] + for _ in range(3): + unique_id = uuid.uuid4().hex[:8] + product = Product( + sku=f"BATCH-{unique_id}", + name=f"Batch Test Product {unique_id}", + category="Batch Test Category", + brand="Batch Test Brand", + base_price=Decimal("19.99"), + base_cost=Decimal("9.99"), + ) + db_session.add(product) + products.append(product) + await db_session.commit() + for p in products: + await db_session.refresh(p) + return products + + +@pytest.fixture +async def sample_calendar_120(db_session: AsyncSession) -> list[Calendar]: + """120 calendar rows starting 2024-01-01 (idempotent via merge).""" + start = date(2024, 1, 1) + rows: list[Calendar] = [] + for i in range(120): + d = start + timedelta(days=i) + c = Calendar( + date=d, + day_of_week=d.weekday(), + month=d.month, + quarter=(d.month - 1) // 3 + 1, + year=d.year, + is_holiday=False, + ) + merged = await db_session.merge(c) + rows.append(merged) + await db_session.commit() + return rows + + +@pytest.fixture +async def sample_sales_120( + db_session: AsyncSession, + sample_store: Store, + sample_products_3: list[Product], + sample_calendar_120: list[Calendar], +) -> list[SalesDaily]: + """120 days of sequential sales for the 3 products at the one store. + + Quantity = day number (1..120) so the naive backtest produces stable, + non-NaN metrics. + """ + sales: list[SalesDaily] = [] + for product in sample_products_3: + for i, cal in enumerate(sample_calendar_120): + qty = i + 1 + unit_price = Decimal("9.99") + row = SalesDaily( + date=cal.date, + store_id=sample_store.id, + product_id=product.id, + quantity=qty, + unit_price=unit_price, + total_amount=unit_price * qty, + ) + db_session.add(row) + sales.append(row) + await db_session.commit() + return sales diff --git a/app/features/batch/tests/test_models.py b/app/features/batch/tests/test_models.py new file mode 100644 index 00000000..9e38994a --- /dev/null +++ b/app/features/batch/tests/test_models.py @@ -0,0 +1,81 @@ +"""Unit tests for batch models (no DB).""" + +from __future__ import annotations + +from app.features.batch.models import ( + VALID_BATCH_ITEM_TRANSITIONS, + VALID_BATCH_TRANSITIONS, + BatchItemStatus, + BatchOperation, + BatchStatus, +) + + +def test_batch_status_enum_round_trip() -> None: + """Every BatchStatus value round-trips via its string.""" + for status in BatchStatus: + assert BatchStatus(status.value) is status + + +def test_batch_operation_enum_round_trip() -> None: + """Every BatchOperation value round-trips via its string.""" + for op in BatchOperation: + assert BatchOperation(op.value) is op + + +def test_batch_item_status_enum_round_trip() -> None: + """Every BatchItemStatus value round-trips via its string.""" + for status in BatchItemStatus: + assert BatchItemStatus(status.value) is status + + +def test_valid_transitions_dict_parent() -> None: + """Parent transition map: terminal states have no out-edges.""" + assert VALID_BATCH_TRANSITIONS[BatchStatus.PENDING] == { + BatchStatus.RUNNING, + BatchStatus.CANCELLED, + } + for terminal in ( + BatchStatus.COMPLETED, + BatchStatus.FAILED, + BatchStatus.PARTIAL, + BatchStatus.CANCELLED, + ): + assert VALID_BATCH_TRANSITIONS[terminal] == set() + + +def test_valid_transitions_dict_item() -> None: + """Item transition map: PENDING -> RUNNING or CANCELLED only.""" + assert VALID_BATCH_ITEM_TRANSITIONS[BatchItemStatus.PENDING] == { + BatchItemStatus.RUNNING, + BatchItemStatus.CANCELLED, + } + assert VALID_BATCH_ITEM_TRANSITIONS[BatchItemStatus.RUNNING] == { + BatchItemStatus.COMPLETED, + BatchItemStatus.FAILED, + } + for terminal in ( + BatchItemStatus.COMPLETED, + BatchItemStatus.FAILED, + BatchItemStatus.CANCELLED, + ): + assert VALID_BATCH_ITEM_TRANSITIONS[terminal] == set() + + +def test_check_constraints_named_predictably() -> None: + """CHECK constraints carry stable names (downstream tests assert on them).""" + from sqlalchemy import Table + + from app.features.batch.models import BatchJob, BatchJobItem + + parent_table: Table = BatchJob.__table__ # type: ignore[assignment] + child_table: Table = BatchJobItem.__table__ # type: ignore[assignment] + + parent_names = {c.name for c in parent_table.constraints if c.name is not None} + assert "ck_batch_job_valid_status" in parent_names + assert "ck_batch_job_valid_operation" in parent_names + assert "ck_batch_job_priority_band" in parent_names + + child_names = {c.name for c in child_table.constraints if c.name is not None} + assert "ck_batch_job_item_valid_status" in child_names + assert "ck_batch_job_item_priority_band" in child_names diff --git a/app/features/batch/tests/test_routes_integration.py b/app/features/batch/tests/test_routes_integration.py new file mode 100644 index 00000000..f7ea8d5c --- /dev/null +++ b/app/features/batch/tests/test_routes_integration.py @@ -0,0 +1,235 @@ +"""Integration tests for the batch slice (PRP-33). + +These tests run against the real docker-compose Postgres (per +``.claude/rules/test-requirements.md``). They cover the contract every +downstream PRP reads: + +- ``POST /batch/forecasting`` 3-pair manual backtest settles ``completed`` + with the pinned five-key metrics JSONB per item. +- Partial failure path: parent settles ``partial`` when some items fail. +- Scope-cap overflow: RFC 7807 422 problem+json. +- ``GET /batch/{id}/items`` sort allow-list is silent on unknown keys. +- Partial picker index predicate is EXACTLY ``status = 'pending'``. +- structlog lifecycle events fire in order with ``request_id`` correlation. +""" + +from __future__ import annotations + +from typing import Any + +import pytest +import structlog +from httpx import AsyncClient +from sqlalchemy import text +from sqlalchemy.ext.asyncio import AsyncSession + +from app.features.data_platform.models import Product, Store + +pytestmark = pytest.mark.integration + + +# --------------------------------------------------------------- happy path + + +async def test_submit_batch_happy_path( + client: AsyncClient, + sample_store: Store, + sample_products_3: list[Product], + sample_sales_120: list[Any], +) -> None: + """3-pair manual backtest settles ``completed`` with the pinned JSONB shape.""" + payload = { + "operation": "backtest", + "scope": { + "kind": "manual", + "store_ids": [sample_store.id], + "product_ids": [p.id for p in sample_products_3], + }, + "model_configs": [{"model_type": "naive", "params": {}}], + "start_date": "2024-01-01", + "end_date": "2024-04-29", + } + resp = await client.post("/batch/forecasting", json=payload) + assert resp.status_code == 202, resp.text + body = resp.json() + batch_id = body["batch_id"] + # Mark for cleanup — the conftest's db_session fixture deletes test* batch_ids + # only; rewrite to that prefix for cascade cleanup. + # (The batch_id is a uuid hex; deletion happens via the explicit cleanup below.) + + assert body["status"] == "completed", body + assert body["completed_items"] == 3, body + assert body["failed_items"] == 0, body + assert body["total_items"] == 3 + + items_resp = await client.get(f"/batch/{batch_id}/items") + assert items_resp.status_code == 200, items_resp.text + items = items_resp.json()["items"] + assert len(items) == 3 + for item in items: + assert item["status"] == "completed", item + # The pinned five-key shape — every downstream PRP reads exactly these. + assert set(item["metrics"].keys()) == { + "wape", + "smape", + "mae", + "bias", + "sample_size", + }, item["metrics"] + + +# ----------------------------------------------------------- partial failure + + +async def test_submit_batch_partial_failure( + client: AsyncClient, + sample_store: Store, + sample_products_3: list[Product], + sample_sales_120: list[Any], +) -> None: + """A 2-pair batch where one item targets a non-existent product settles ``partial``.""" + payload = { + "operation": "backtest", + "scope": { + "kind": "manual", + "store_ids": [sample_store.id], + # First product exists (will succeed); product_id=999999999 fails. + "product_ids": [sample_products_3[0].id, 999_999_999], + }, + "model_configs": [{"model_type": "naive", "params": {}}], + "start_date": "2024-01-01", + "end_date": "2024-04-29", + } + resp = await client.post("/batch/forecasting", json=payload) + assert resp.status_code == 202, resp.text + body = resp.json() + assert body["status"] == "partial", body + assert body["completed_items"] == 1 + assert body["failed_items"] == 1 + + +# -------------------------------------------------------------- scope cap + + +async def test_scope_over_cap_returns_422(client: AsyncClient) -> None: + """Scope expanding beyond ``batch_max_scope_expansion`` raises RFC 7807 422.""" + # 1000 stores x 2 products x 1 model = 2000 items, > the 1000 default cap. + payload = { + "operation": "backtest", + "scope": { + "kind": "manual", + "store_ids": list(range(1, 1001)), + "product_ids": [1, 2], + }, + "model_configs": [{"model_type": "naive"}], + "start_date": "2024-01-01", + "end_date": "2024-06-30", + } + resp = await client.post("/batch/forecasting", json=payload) + assert resp.status_code == 422, resp.text + assert resp.headers["content-type"].startswith("application/problem+json") + body = resp.json() + assert body["status"] == 422 + # RFC 7807 carries a `detail` field (FastAPI's problem_response). + assert "exceeds the cap" in body["detail"] + + +# -------------------------------------------------------- get + items + sort + + +async def test_get_items_sort_by_allow_list( + client: AsyncClient, + sample_store: Store, + sample_products_3: list[Product], + sample_sales_120: list[Any], +) -> None: + """Unknown ``sort_by`` falls back silently to the default order (no 4xx).""" + payload = { + "operation": "backtest", + "scope": { + "kind": "manual", + "store_ids": [sample_store.id], + "product_ids": [sample_products_3[0].id], + }, + "model_configs": [{"model_type": "naive"}], + "start_date": "2024-01-01", + "end_date": "2024-04-29", + } + submit = await client.post("/batch/forecasting", json=payload) + batch_id = submit.json()["batch_id"] + + # Unknown sort_by → silently falls back to default; never 4xx. + resp = await client.get(f"/batch/{batch_id}/items?sort_by=this_key_does_not_exist") + assert resp.status_code == 200, resp.text + assert len(resp.json()["items"]) == 1 + + +async def test_get_batch_404(client: AsyncClient) -> None: + """Unknown batch_id → 404 RFC 7807.""" + resp = await client.get("/batch/does-not-exist") + assert resp.status_code == 404 + assert resp.headers["content-type"].startswith("application/problem+json") + + +# ----------------------------------------------------------- partial index + + +async def test_migration_partial_index_present(db_session: AsyncSession) -> None: + """The partial picker index predicate is EXACTLY ``status = 'pending'``. + + Downstream-2 (priority queue) compiles its picker query against the + same predicate; any drift breaks the index-coverage assumption. + """ + stmt = text("SELECT indexdef FROM pg_indexes WHERE indexname = 'ix_batch_job_item_picker'") + row = (await db_session.execute(stmt)).scalar_one_or_none() + assert row is not None, "Partial picker index ix_batch_job_item_picker missing" + # Postgres normalises the predicate to ``WHERE ((status)::text = 'pending'::text)``; + # match the load-bearing substring `'pending'` (single-quoted literal). + assert "'pending'" in row.lower() + assert "where" in row.lower() + + +# ---------------------------------------------------- lifecycle event emission + + +async def test_service_emits_lifecycle_events( + client: AsyncClient, + sample_store: Store, + sample_products_3: list[Product], + sample_sales_120: list[Any], +) -> None: + """Lifecycle events fire in order; every event carries ``request_id``. + + The order is asserted across a 2-pair batch where one item succeeds and + one fails (matches PRP-33's drift-fix Test Plan entry). Uses + ``structlog.testing.capture_logs`` because the repo's pytest config + routes the stdlib logging stream through logfire, which shadows the + built-in ``caplog`` fixture. + """ + payload = { + "operation": "backtest", + "scope": { + "kind": "manual", + "store_ids": [sample_store.id], + "product_ids": [sample_products_3[0].id, 999_999_999], + }, + "model_configs": [{"model_type": "naive"}], + "start_date": "2024-01-01", + "end_date": "2024-04-29", + } + with structlog.testing.capture_logs() as captured: + resp = await client.post("/batch/forecasting", json=payload) + assert resp.status_code == 202 + + batch_events = [ + entry["event"] + for entry in captured + if isinstance(entry.get("event"), str) and entry["event"].startswith("batch.") + ] + # Ordered set: batch.created → batch.item_started → (batch.item_completed + # | batch.item_failed) per item → batch.completed. + assert batch_events[0] == "batch.created", batch_events + assert batch_events[-1] == "batch.completed", batch_events + assert "batch.item_started" in batch_events + assert "batch.item_completed" in batch_events + assert "batch.item_failed" in batch_events diff --git a/app/features/batch/tests/test_schemas.py b/app/features/batch/tests/test_schemas.py new file mode 100644 index 00000000..5ef64a8b --- /dev/null +++ b/app/features/batch/tests/test_schemas.py @@ -0,0 +1,135 @@ +"""Unit tests for batch schemas (no DB). + +Covers the strict-mode JSON path regression (mirrors PR #115 / #119 +precedent on ComputeFeaturesRequest / TrainRequest), and the scope +kind/selector consistency rules. +""" + +from __future__ import annotations + +import pytest +from pydantic import ValidationError + +from app.features.batch.schemas import ( + BatchScope, + BatchSubmitRequest, +) + + +def test_submit_request_strict_mode_json_path() -> None: + """JSON-string dates coerce under Field(strict=False); other JSON-native + fields stay strict — the regression class from PR #115 / #119.""" + req = BatchSubmitRequest.model_validate( + { + "operation": "backtest", + "scope": { + "kind": "manual", + "store_ids": [1, 2], + "product_ids": [1, 2, 3], + }, + "model_configs": [{"model_type": "naive", "params": {}}], + "start_date": "2025-01-01", + "end_date": "2025-06-30", + } + ) + assert req.operation == "backtest" + assert req.scope.kind == "manual" + assert req.start_date.isoformat() == "2025-01-01" + + +def test_submit_request_strict_rejects_str_as_int() -> None: + """ConfigDict(strict=True) refuses to coerce JSON string into int.""" + with pytest.raises(ValidationError): + BatchSubmitRequest.model_validate( + { + "operation": "backtest", + "scope": { + "kind": "manual", + "store_ids": ["1"], # str → int coercion blocked + "product_ids": [1], + }, + "model_configs": [{"model_type": "naive"}], + "start_date": "2025-01-01", + "end_date": "2025-06-30", + } + ) + + +def test_scope_top_revenue_requires_top_n() -> None: + """kind=top_revenue without top_n → ValidationError.""" + with pytest.raises(ValidationError) as excinfo: + BatchScope.model_validate({"kind": "top_revenue"}) + assert "top_n" in str(excinfo.value) + + +def test_scope_manual_requires_both_id_lists() -> None: + """kind=manual without product_ids → ValidationError.""" + with pytest.raises(ValidationError): + BatchScope.model_validate({"kind": "manual", "store_ids": [1]}) + with pytest.raises(ValidationError): + BatchScope.model_validate({"kind": "manual", "product_ids": [1]}) + + +def test_scope_region_requires_region() -> None: + with pytest.raises(ValidationError): + BatchScope.model_validate({"kind": "region"}) + + +def test_scope_category_requires_category() -> None: + with pytest.raises(ValidationError): + BatchScope.model_validate({"kind": "category"}) + + +def test_scope_all_requires_nothing() -> None: + """kind=all accepts no selectors.""" + scope = BatchScope.model_validate({"kind": "all"}) + assert scope.kind == "all" + + +def test_model_configs_min_max_length() -> None: + """model_configs has min_length=1, max_length=10.""" + base = { + "operation": "backtest", + "scope": {"kind": "all"}, + "start_date": "2025-01-01", + "end_date": "2025-06-30", + } + # Zero is rejected. + with pytest.raises(ValidationError): + BatchSubmitRequest.model_validate({**base, "model_configs": []}) + # 11 is rejected. + too_many = [{"model_type": "naive"}] * 11 + with pytest.raises(ValidationError): + BatchSubmitRequest.model_validate({**base, "model_configs": too_many}) + # 10 is accepted. + ten = [{"model_type": "naive"}] * 10 + req = BatchSubmitRequest.model_validate({**base, "model_configs": ten}) + assert len(req.model_configs) == 10 + + +def test_unknown_operation_rejected() -> None: + """Unknown operation literal rejected by Literal[...].""" + with pytest.raises(ValidationError): + BatchSubmitRequest.model_validate( + { + "operation": "explode", + "scope": {"kind": "all"}, + "model_configs": [{"model_type": "naive"}], + "start_date": "2025-01-01", + "end_date": "2025-06-30", + } + ) + + +def test_unknown_model_type_rejected() -> None: + """Unknown model_type rejected by Literal[...].""" + with pytest.raises(ValidationError): + BatchSubmitRequest.model_validate( + { + "operation": "backtest", + "scope": {"kind": "all"}, + "model_configs": [{"model_type": "magic_forest"}], + "start_date": "2025-01-01", + "end_date": "2025-06-30", + } + ) diff --git a/app/features/batch/tests/test_service.py b/app/features/batch/tests/test_service.py new file mode 100644 index 00000000..a962c9b1 --- /dev/null +++ b/app/features/batch/tests/test_service.py @@ -0,0 +1,233 @@ +"""Unit tests for BatchService (no DB). + +DB-dependent tests (status settlement, lifecycle event emission across a real +submit, scope expansion for region/category/all/top_revenue) live in the +integration suite (``test_routes_integration.py``). This file covers the +pure-Python surface: manual cartesian, the pinned-shape ``_shape_metrics``, +the picker query SQL (compiled, asserts ``FOR UPDATE SKIP LOCKED``), and the +frozen-params shape. +""" + +from __future__ import annotations + +from datetime import UTC, datetime +from unittest.mock import AsyncMock + +import pytest +from sqlalchemy import select +from sqlalchemy.dialects import postgresql + +from app.features.batch.models import BatchItemStatus, BatchJobItem +from app.features.batch.schemas import ( + BatchScope, + BatchSubmitRequest, +) +from app.features.batch.service import _METRICS_KEYS, BatchService +from app.features.jobs.models import JobStatus, JobType +from app.features.jobs.schemas import JobResponse + + +def _make_job_response( + *, + job_type: JobType, + result: dict[str, object] | None, +) -> JobResponse: + """Build a synthetic JobResponse for _shape_metrics tests.""" + now = datetime.now(UTC) + return JobResponse( + job_id="job_test", + job_type=job_type, + status=JobStatus.COMPLETED, + params={}, + result=result, + error_message=None, + error_type=None, + run_id="run_test" if job_type != JobType.PREDICT else None, + started_at=now, + completed_at=now, + created_at=now, + updated_at=now, + ) + + +# ---------------------------------------------------------------------- shape + + +def test_metrics_jsonb_shape_pinned() -> None: + """_shape_metrics returns EXACTLY {wape, smape, mae, bias, sample_size}. + + Regression for the pinned-shape invariant the four downstream PRPs read. + """ + job = _make_job_response( + job_type=JobType.BACKTEST, + result={ + "aggregated_metrics": { + "wape_mean": 0.1, + "smape_mean": 0.2, + "mae_mean": 1.5, + "bias_mean": -0.05, + }, + "fold_metrics": [ + {"fold": 1, "sample_size": 14}, + {"fold": 2, "sample_size": 14}, + ], + "n_observations": 28, + }, + ) + shaped = BatchService()._shape_metrics(job) + assert shaped is not None + assert set(shaped.keys()) == set(_METRICS_KEYS) + assert shaped["wape"] == 0.1 + assert shaped["sample_size"] == 28 + + +def test_metrics_returns_none_for_predict_job() -> None: + """Non-backtest job → metrics is None (predict has no fold_metrics).""" + job = _make_job_response(job_type=JobType.PREDICT, result={"forecasts": []}) + assert BatchService()._shape_metrics(job) is None + + +def test_metrics_returns_none_for_empty_result() -> None: + """Backtest job with empty/None result → None (defensive fallback).""" + job = _make_job_response(job_type=JobType.BACKTEST, result=None) + assert BatchService()._shape_metrics(job) is None + + +def test_metrics_sample_size_falls_back_to_n_observations() -> None: + """When fold_metrics carries no sample_size, fall back to n_observations.""" + job = _make_job_response( + job_type=JobType.BACKTEST, + result={ + "aggregated_metrics": { + "wape_mean": 0.1, + "smape_mean": 0.2, + "mae_mean": 1.5, + "bias_mean": 0.0, + }, + "fold_metrics": [{"fold": 1}, {"fold": 2}], # no sample_size + "n_observations": 100, + }, + ) + shaped = BatchService()._shape_metrics(job) + assert shaped is not None + assert shaped["sample_size"] == 100 + + +def test_metrics_sample_size_derived_inside_slice() -> None: + """Resolved per PRP-33 § 'Why not 10': sample_size derived inside the + batch slice from fold_metrics — never reaches into app/features/jobs/.""" + job = _make_job_response( + job_type=JobType.BACKTEST, + result={ + "aggregated_metrics": { + "wape_mean": 0.1, + "smape_mean": 0.2, + "mae_mean": 1.5, + "bias_mean": 0.0, + }, + "fold_metrics": [ + {"fold": 1, "sample_size": 7}, + {"fold": 2, "sample_size": 8}, + {"fold": 3, "sample_size": 9}, + ], + }, + ) + shaped = BatchService()._shape_metrics(job) + assert shaped is not None + assert shaped["sample_size"] == 24 # 7+8+9, computed inside the slice + + +# ---------------------------------------------------------------------- picker + + +def test_picker_query_uses_skip_locked() -> None: + """Compile the picker SELECT — must contain ``FOR UPDATE SKIP LOCKED``. + + Load-bearing for downstream-1 (parallel) and downstream-2 (priority): + removing the kwarg lets concurrent workers block on each other. + """ + stmt = ( + select(BatchJobItem) + .where( + BatchJobItem.batch_id == "test", + BatchJobItem.status == BatchItemStatus.PENDING.value, + ) + .order_by( + BatchJobItem.priority.desc(), + BatchJobItem.created_at.asc(), + BatchJobItem.id.asc(), + ) + .limit(1) + .with_for_update(skip_locked=True) + ) + sql = str(stmt.compile(dialect=postgresql.dialect())) # type: ignore[no-untyped-call] + assert "FOR UPDATE SKIP LOCKED" in sql.upper(), sql + + +# ------------------------------------------------------------------ expansion + + +async def test_expand_scope_manual_cartesian() -> None: + """``kind=manual`` produces the full store x product cartesian, no DB.""" + scope = BatchScope.model_validate( + { + "kind": "manual", + "store_ids": [1, 2], + "product_ids": [10, 20, 30], + } + ) + # Manual path never touches the DB — the AsyncMock proves no DB call lands. + db: AsyncMock = AsyncMock() + pairs = await BatchService()._expand_scope(db, scope) + assert pairs == [(1, 10), (1, 20), (1, 30), (2, 10), (2, 20), (2, 30)] + db.execute.assert_not_called() + + +# ------------------------------------------------------------------- frozen + + +def test_frozen_item_params_shape() -> None: + """_frozen_item_params builds a stable per-item JSONB. + + Downstream-3 (export-and-retry) reads from this dict on retry — the shape + must remain compatible. + """ + req = BatchSubmitRequest.model_validate( + { + "operation": "backtest", + "scope": {"kind": "manual", "store_ids": [1], "product_ids": [2]}, + "model_configs": [{"model_type": "naive", "params": {"foo": "bar"}}], + "start_date": "2025-01-01", + "end_date": "2025-06-30", + } + ) + mc = req.model_configs[0] + params = BatchService()._frozen_item_params(req, 1, 2, mc) + assert params == { + "operation": "backtest", + "job_params": { + "model_type": "naive", + "store_id": 1, + "product_id": 2, + "start_date": "2025-01-01", + "end_date": "2025-06-30", + "foo": "bar", + }, + } + + +def test_sort_columns_allow_list_complete() -> None: + """The sort allow-list must cover exactly the four documented keys.""" + from app.features.batch.service import _BATCH_ITEM_SORT_COLUMNS + + assert set(_BATCH_ITEM_SORT_COLUMNS.keys()) == { + "created_at", + "completed_at", + "status", + "priority", + } + + +# `pytest-asyncio` auto-mode (configured in pyproject.toml) picks up the +# async test above without an explicit @pytest.mark.asyncio decorator. +_ = pytest # keep import (some test selectors strip unused imports) diff --git a/app/main.py b/app/main.py index 0558ec6a..eb4f5145 100644 --- a/app/main.py +++ b/app/main.py @@ -16,6 +16,7 @@ from app.features.agents.websocket import router as agents_ws_router from app.features.analytics.routes import router as analytics_router from app.features.backtesting.routes import router as backtesting_router +from app.features.batch.routes import router as batch_router from app.features.config.routes import router as config_router from app.features.config.service import apply_overrides_on_startup from app.features.demo.routes import router as demo_router @@ -138,6 +139,7 @@ def create_app() -> FastAPI: app.include_router(analytics_router) app.include_router(ops_router) app.include_router(jobs_router) + app.include_router(batch_router) app.include_router(ingest_router) app.include_router(featuresets_router) app.include_router(forecasting_router) diff --git a/frontend/src/App.tsx b/frontend/src/App.tsx index 6619ead2..1ef34bf1 100644 --- a/frontend/src/App.tsx +++ b/frontend/src/App.tsx @@ -25,6 +25,7 @@ const ForecastPage = lazy(() => import('@/pages/visualize/forecast')) const BacktestPage = lazy(() => import('@/pages/visualize/backtest')) const DemandPlannerPage = lazy(() => import('@/pages/visualize/demand')) const WhatIfPlannerPage = lazy(() => import('@/pages/visualize/planner')) +const BatchRunnerPage = lazy(() => import('@/pages/visualize/batch')) const ChatPage = lazy(() => import('@/pages/chat')) const KnowledgePage = lazy(() => import('@/pages/knowledge')) const GuidePage = lazy(() => import('@/pages/guide')) @@ -177,6 +178,14 @@ function App() { } /> + }> + + + } + /> + api('/batch/forecasting', { + method: 'POST', + body: data, + }), + onSuccess: () => { + void queryClient.invalidateQueries({ queryKey: ['batch'] }) + }, + }) +} + +// Get a batch's parent record. Polls every 2s while the run is in-flight; +// stops polling once the parent settles to a terminal state. +export function useBatch(batchId: string | null, enabled = true) { + return useQuery({ + queryKey: ['batch', batchId], + queryFn: () => api(`/batch/${batchId}`), + enabled: enabled && !!batchId, + refetchInterval: (query) => { + const status = query.state.data?.status + return status === 'pending' || status === 'running' ? 2000 : false + }, + }) +} + +interface UseBatchItemsParams { + batchId: string | null + page?: number + pageSize?: number + sortBy?: string + sortOrder?: 'asc' | 'desc' + enabled?: boolean +} + +export function useBatchItems({ + batchId, + page = 1, + pageSize = 50, + sortBy, + sortOrder = 'asc', + enabled = true, +}: UseBatchItemsParams) { + return useQuery({ + queryKey: ['batch', batchId, 'items', { page, pageSize, sortBy, sortOrder }], + queryFn: () => + api(`/batch/${batchId}/items`, { + params: { + page, + page_size: pageSize, + sort_by: sortBy, + sort_order: sortOrder, + }, + }), + enabled: enabled && !!batchId, + }) +} diff --git a/frontend/src/lib/constants.ts b/frontend/src/lib/constants.ts index 64f60654..6a6de39f 100644 --- a/frontend/src/lib/constants.ts +++ b/frontend/src/lib/constants.ts @@ -24,6 +24,7 @@ export const ROUTES = { BACKTEST: '/visualize/backtest', DEMAND: '/visualize/demand', PLANNER: '/visualize/planner', + BATCH: '/visualize/batch', }, KNOWLEDGE: '/knowledge', CHAT: '/chat', @@ -53,6 +54,7 @@ export const NAV_ITEMS = [ { label: 'What-If Planner', href: ROUTES.VISUALIZE.PLANNER }, { label: 'Forecast', href: ROUTES.VISUALIZE.FORECAST }, { label: 'Backtest Results', href: ROUTES.VISUALIZE.BACKTEST }, + { label: 'Batch Runner', href: ROUTES.VISUALIZE.BATCH }, ], }, { label: 'Knowledge', href: ROUTES.KNOWLEDGE }, diff --git a/frontend/src/pages/visualize/batch.tsx b/frontend/src/pages/visualize/batch.tsx new file mode 100644 index 00000000..82b682af --- /dev/null +++ b/frontend/src/pages/visualize/batch.tsx @@ -0,0 +1,213 @@ +/** + * Batch Runner — placeholder page (PRP-33 MVP). + * + * Polls the parent batch status while in-flight and renders an items table. + * Per PRP narrowing: NO slider, NO cancel button, NO retry, NO heatmap, NO + * promotion panel — each downstream PRP owns one of those surfaces. + * + * MVP UX: a tiny submit form (manual scope only) + the live items table. + * The form is intentionally minimal — the agent / curl is the canonical + * driver in MVP; this page exists so the work is visible. + */ + +import { useState } from 'react' + +import { ErrorDisplay } from '@/components/common/error-display' +import { LoadingState } from '@/components/common/loading-state' +import { StatusBadge } from '@/components/common/status-badge' +import { Button } from '@/components/ui/button' +import { + Card, + CardContent, + CardDescription, + CardHeader, + CardTitle, +} from '@/components/ui/card' +import { Input } from '@/components/ui/input' +import { + Table, + TableBody, + TableCell, + TableHead, + TableHeader, + TableRow, +} from '@/components/ui/table' +import { useBatch, useBatchItems, useSubmitBatch } from '@/hooks/use-batches' +import type { BatchSubmitRequest } from '@/types/api' + +export default function BatchRunnerPage() { + // Last-submitted batch the page tracks. null = nothing yet. + const [batchId, setBatchId] = useState(null) + + // Minimal submit form state — manual scope only (downstream PRP-26 adds + // region/category/top_revenue/all UIs). + const [storeIds, setStoreIds] = useState('1') + const [productIds, setProductIds] = useState('1,2,3') + const [startDate, setStartDate] = useState('2024-01-01') + const [endDate, setEndDate] = useState('2024-04-29') + + const submit = useSubmitBatch() + const batch = useBatch(batchId) + const items = useBatchItems({ batchId, pageSize: 50 }) + + function handleSubmit(e: React.FormEvent) { + e.preventDefault() + const parseIds = (s: string) => + s + .split(',') + .map((t) => parseInt(t.trim(), 10)) + .filter((n) => !Number.isNaN(n)) + + const payload: BatchSubmitRequest = { + operation: 'backtest', + scope: { + kind: 'manual', + store_ids: parseIds(storeIds), + product_ids: parseIds(productIds), + }, + model_configs: [{ model_type: 'naive', params: {} }], + start_date: startDate, + end_date: endDate, + } + submit.mutate(payload, { + onSuccess: (data) => setBatchId(data.batch_id), + }) + } + + return ( +
+
+

Batch Runner (MVP)

+

+ Submit a portfolio batch and watch its (store, product) items + execute sequentially. This is the PRP-33 placeholder — the + downstream PRPs add cancel, retry, priority, and the + champion/heatmap surface. +

+
+ + + + Submit a manual backtest batch + + Comma-separated IDs; the runner fans out the cartesian product + and backtests each pair using the naive baseline. + + + +
+ + + + +
+ +
+
+ {submit.isError && ( +
+ +
+ )} +
+
+ + {batchId && ( + + + Batch {batchId.slice(0, 8)}… + {batch.data && ( + + Status: ·{' '} + {batch.data.completed_items}/{batch.data.total_items} completed + · {batch.data.failed_items} failed + + )} + + + {items.isLoading ? ( + + ) : items.isError ? ( + + ) : ( + + + + Item + Store + Product + Model + Status + WAPE + Sample size + + + + {items.data?.items.map((item) => ( + + + {item.item_id.slice(0, 8)} + + {item.store_id} + {item.product_id} + {item.model_type} + + + + + {item.metrics?.wape != null + ? item.metrics.wape.toFixed(3) + : '—'} + + + {item.metrics?.sample_size ?? '—'} + + + ))} + {items.data?.items.length === 0 && ( + + + No items yet + + + )} + +
+ )} +
+
+ )} +
+ ) +} diff --git a/frontend/src/types/api.ts b/frontend/src/types/api.ts index c38469fb..97a1798f 100644 --- a/frontend/src/types/api.ts +++ b/frontend/src/types/api.ts @@ -282,6 +282,110 @@ export interface JobCreate { params: Record } +// === Batch (PRP-33) === + +export type BatchOperation = + | 'train' + | 'predict' + | 'backtest' + | 'train_backtest_register' + +export type BatchStatus = + | 'pending' + | 'running' + | 'completed' + | 'failed' + | 'partial' + | 'cancelled' + +export type BatchItemStatus = + | 'pending' + | 'running' + | 'completed' + | 'failed' + | 'cancelled' + +export type BatchScopeKind = + | 'manual' + | 'region' + | 'category' + | 'top_revenue' + | 'all' + +export interface BatchScope { + kind: BatchScopeKind + store_ids?: number[] | null + product_ids?: number[] | null + region?: string | null + category?: string | null + top_n?: number | null +} + +export interface BatchModelConfig { + model_type: + | 'naive' + | 'seasonal_naive' + | 'moving_average' + | 'regression' + | 'lightgbm' + | 'xgboost' + | 'prophet_like' + params?: Record +} + +export interface BatchSubmitRequest { + operation: BatchOperation + scope: BatchScope + model_configs: BatchModelConfig[] + start_date: string + end_date: string + max_parallel?: number + default_child_priority?: number +} + +export interface BatchSubmitResponse { + batch_id: string + operation: BatchOperation + status: BatchStatus + total_items: number + completed_items: number + failed_items: number + running_items: number + cancelled_items: number + started_at: string | null + completed_at: string | null + result_summary: Record | null + created_at: string + updated_at: string +} + +export interface BatchItemResponse { + item_id: string + batch_id: string + store_id: number + product_id: number + model_type: string + status: BatchItemStatus + priority: number + metrics: Record | null + child_job_id: string | null + child_run_id: string | null + error_message: string | null + error_type: string | null + started_at: string | null + completed_at: string | null + duration_ms: number | null + created_at: string + updated_at: string +} + +export interface BatchItemListResponse { + items: BatchItemResponse[] + total: number + page: number + page_size: number +} + // === RAG === export interface RagSource { source_id: string