class IngestYouTubeRequest(BaseModel):
    """Request body for ``POST /api/v1/ingest/youtube``."""

    # URL of the video to ingest; whitespace-stripped and validated below.
    url: str
    # Forwarded unchanged to the summarisation pipeline.
    detail_level: Literal["brief", "medium", "detailed"] = "medium"
    # BART always runs; "t5" and "both" additionally request a T5 summary.
    model: Literal["bart", "t5", "both"] = "bart"

    @field_validator("url")
    @classmethod
    def url_must_be_youtube(cls, v: str) -> str:
        """Return *v* stripped; raise ValueError if no video ID can be extracted."""
        v = v.strip()
        if not v:
            raise ValueError("url must not be empty")
        # Delegate deep validation to video_id_from_url. It already raises
        # ValueError on bad input, so there is nothing to catch and re-wrap.
        video_id_from_url(v)
        return v


class JobResponse(BaseModel):
    """Response body returned once a job has been enqueued."""

    job_id: str


# ---------------------------------------------------------------------------
# Background job
# ---------------------------------------------------------------------------

_MODEL_KEY_MAP = {
    "bart": "bart-large-cnn",
    "t5": "t5-base",
}


def run_youtube_job(job_id: str, url: str, detail_level: str, model: str) -> None:
    """
    Heavy lifting executed inside FastAPI BackgroundTasks.

    Steps
    -----
    1. Extract / transcribe the YouTube video.
    2. Save the transcript to disk.
    3. Summarise with BART (always) and optionally with T5.
    4. Build and persist the FAISS index.
    5. Update the job record to "done" (or "error" on failure).

    Parameters
    ----------
    job_id : id of the job record that is updated as the work progresses.
    url : YouTube URL to process.
    detail_level : summary verbosity, forwarded to the pipeline and to T5.
    model : "bart", "t5" or "both"; "t5" and "both" also produce a T5 summary.

    Pipeline failures are not re-raised: they are logged and recorded on the
    job record as status "error".
    """
    try:
        video_id = video_id_from_url(url)
    except ValueError as exc:
        update_job(job_id, status="error", error=str(exc))
        return

    try:
        # --- Step 1: transcript ------------------------------------------------
        update_job(job_id, status="running", progress="Extracting transcript…")
        text, source, bart_summary, bart_metrics = process_youtube_pipeline(
            url, detail_level
        )

        # --- Step 2: persist transcript ----------------------------------------
        update_job(job_id, progress="Saving transcript…")
        save_transcript(video_id, text, meta={"source": source, "url": url})

        # --- Step 3: BART summary (returned by pipeline) ----------------------
        update_job(job_id, progress="Saving BART summary…")
        save_summary(video_id, "bart-large-cnn", bart_summary, bart_metrics)

        # --- Step 3b: T5 summary (on request) ---------------------------------
        run_t5 = model in ("t5", "both")
        # Skip the expensive T5 pass when a summary is already stored.
        if run_t5 and not summary_exists(video_id, "t5-base"):
            update_job(job_id, progress="Generating T5 summary…")
            t5_summary, t5_metrics = summarize_text(
                text, detail_level=detail_level, model_name="t5-base", return_metrics=True
            )
            save_summary(video_id, "t5-base", t5_summary, t5_metrics)

        # --- Step 4: FAISS index -----------------------------------------------
        update_job(job_id, progress="Building vector index…")
        index, chunks = build_vector_store(text)
        if index is not None:
            save_index(video_id, index, chunks)

        # --- Step 5: done -------------------------------------------------------
        result = {
            "video_id": video_id,
            "source": source,
            "models_run": ["bart-large-cnn"] + (["t5-base"] if run_t5 else []),
        }
        update_job(job_id, status="done", progress="Completed", result=result)

    except Exception as exc:
        # Top-level boundary of the background task: log with traceback and
        # surface the failure to the client through the job record.
        logger.exception("Job %s failed", job_id)
        update_job(
            job_id,
            status="error",
            error=f"{type(exc).__name__}: {exc}",
        )
# ---------------------------------------------------------------------------
# Endpoints
# ---------------------------------------------------------------------------


@app.get("/health", tags=["system"])
def health():
    """Liveness check; always returns ``{"status": "ok"}``."""
    return {"status": "ok"}


@app.post("/api/v1/ingest/youtube", status_code=202, response_model=JobResponse, tags=["ingest"])
def ingest_youtube(
    body: IngestYouTubeRequest,
    background_tasks: BackgroundTasks,
):
    """Enqueue a YouTube video for transcription and summarisation."""
    job = create_job()
    background_tasks.add_task(
        run_youtube_job,
        job["job_id"],
        body.url,
        body.detail_level,
        body.model,
    )
    return {"job_id": job["job_id"]}


@app.get("/api/v1/job/{job_id}", tags=["jobs"])
def get_job_status(job_id: str):
    """Return the current status of a background job (404 if unknown)."""
    record = get_job(job_id)
    if record is None:
        raise HTTPException(status_code=404, detail=f"Job {job_id!r} not found")
    return record


def _load_summary_or_none(video_id: str, model_key: str):
    """Load a stored summary, treating an id the storage layer rejects as absent.

    ``load_summary`` raises ValueError for ids containing unsafe characters.
    Such an id can never have a stored summary, so the endpoints report it as
    "not found" (as ``get_job_status`` does for malformed job ids) instead of
    letting the exception surface as HTTP 500.
    """
    import json  # local import: only needed to tell the two ValueError kinds apart

    try:
        return load_summary(video_id, model_key)
    except json.JSONDecodeError:
        # JSONDecodeError subclasses ValueError; a corrupt file on disk is a
        # genuine server error and must not be masked as "not found".
        raise
    except ValueError:
        return None


@app.get("/api/v1/summary/{video_id}", tags=["summaries"])
def get_summary(video_id: str, model: str = "bart"):
    """
    Return the stored summary for *video_id*.

    Query params
    ------------
    model : "bart" (default) | "t5"
        Any other value falls back to the BART summary.

    Raises 404 if no summary is stored for the requested model, or if
    *video_id* is not a valid storage id.
    """
    model_key = _MODEL_KEY_MAP.get(model, "bart-large-cnn")
    data = _load_summary_or_none(video_id, model_key)
    if data is None:
        raise HTTPException(status_code=404, detail="Summary not found")

    # Tell the client whether the *other* model's summary exists too, so it
    # knows the compare endpoint will succeed.
    other_key = "t5-base" if model_key == "bart-large-cnn" else "bart-large-cnn"
    comparison_available = summary_exists(video_id, other_key)

    return {
        **data,
        "video_id": video_id,
        "comparison_available": comparison_available,
    }


@app.get("/api/v1/compare/{video_id}", tags=["summaries"])
def compare_summaries(video_id: str):
    """
    Return both BART and T5 summaries together with lightweight comparison metrics.

    Raises 404 unless both summaries are available; the detail message lists
    the missing model(s). An invalid *video_id* is reported the same way.
    """
    bart_data = _load_summary_or_none(video_id, "bart-large-cnn")
    t5_data = _load_summary_or_none(video_id, "t5-base")

    if bart_data is None or t5_data is None:
        missing = []
        if bart_data is None:
            missing.append("bart-large-cnn")
        if t5_data is None:
            missing.append("t5-base")
        raise HTTPException(
            status_code=404,
            detail=f"Summaries not yet available for model(s): {missing}",
        )

    # Lightweight comparison metrics (word counts, compression ratios)
    def _wc(text: str) -> int:
        return len(text.split())

    bart_words = _wc(bart_data["summary"])
    t5_words = _wc(t5_data["summary"])

    comparison = {
        "bart_word_count": bart_words,
        "t5_word_count": t5_words,
        "bart_compression_ratio": bart_data.get("metrics", {}).get("compression_ratio"),
        "t5_compression_ratio": t5_data.get("metrics", {}).get("compression_ratio"),
        "bart_processing_time": bart_data.get("metrics", {}).get("processing_time"),
        "t5_processing_time": t5_data.get("metrics", {}).get("processing_time"),
    }

    return {
        "video_id": video_id,
        "bart": bart_data,
        "t5": t5_data,
        "comparison": comparison,
    }
def transcribe_audio(audio_path):
    """Run the lazily loaded Whisper model on *audio_path* and return its ``"text"`` field."""
    whisper_model = get_whisper_model()
    transcription = whisper_model.transcribe(audio_path)
    return transcription["text"]
_JOB_ID_RE = re.compile(r"^[0-9a-f]{8}-[0-9a-f]{4}-4[0-9a-f]{3}-[89ab][0-9a-f]{3}-[0-9a-f]{12}$")


def _validate_job_id(job_id: str) -> None:
    """Raise ValueError if *job_id* is not a valid UUID4 string."""
    # fullmatch(), not match(): with match() the pattern's trailing "$" also
    # accepts a string that ends in "\n", which would leak into the file name.
    if not _JOB_ID_RE.fullmatch(job_id):
        raise ValueError(f"Invalid job_id format: {job_id!r}")


def _jobs_dir() -> Path:
    """Return the jobs directory, creating it on first use."""
    _JOBS_DIR.mkdir(parents=True, exist_ok=True)
    return _JOBS_DIR


def _job_path(job_id: str) -> Path:
    """Return the JSON file path for *job_id*; raise ValueError if the id is malformed."""
    _validate_job_id(job_id)
    return _jobs_dir() / f"{job_id}.json"


def _now_iso() -> str:
    """Return the current UTC time as an ISO-8601 string."""
    return datetime.now(timezone.utc).isoformat()


def create_job() -> Dict[str, Any]:
    """Create a new job record, persist it, and return the record."""
    job_id = str(uuid.uuid4())
    # Take the timestamp once so created_at and updated_at are identical on a
    # fresh record instead of differing by a few microseconds.
    now = _now_iso()
    record: Dict[str, Any] = {
        "job_id": job_id,
        "status": "pending",
        "progress": "Queued",
        "result": None,
        "error": None,
        "created_at": now,
        "updated_at": now,
    }
    _save(record)
    return record


def update_job(
    job_id: str,
    *,
    status: Optional[str] = None,
    progress: Optional[str] = None,
    result: Optional[Dict[str, Any]] = None,
    error: Optional[str] = None,
) -> Dict[str, Any]:
    """Update fields on an existing job record and persist the change.

    Only arguments that are not ``None`` are written; ``updated_at`` is always
    refreshed. Raises ValueError if the job does not exist.
    """
    record = get_job(job_id)
    if record is None:
        raise ValueError(f"Job {job_id!r} not found")
    if status is not None:
        record["status"] = status
    if progress is not None:
        record["progress"] = progress
    if result is not None:
        record["result"] = result
    if error is not None:
        record["error"] = error
    record["updated_at"] = _now_iso()
    _save(record)
    return record


def get_job(job_id: str) -> Optional[Dict[str, Any]]:
    """Return the job record or None if it does not exist (or the id is malformed)."""
    try:
        path = _job_path(job_id)
    except ValueError:
        return None
    # EAFP: open directly rather than exists()-then-open, which is two steps.
    try:
        with open(path, "r", encoding="utf-8") as fh:
            return json.load(fh)
    except FileNotFoundError:
        return None


def _save(record: Dict[str, Any]) -> None:
    """Persist *record* so that readers never observe a half-written file.

    The status endpoint reads job files while the background task rewrites
    them. Opening the target with "w" truncates it first, so a concurrent
    reader could see empty or partial JSON. Writing to a sibling temp file and
    renaming it over the target swaps the content in one step.
    """
    path = _job_path(record["job_id"])
    # Same directory as the target so the rename stays on one filesystem.
    tmp_path = path.with_name(f".{path.name}.{uuid.uuid4().hex}.tmp")
    try:
        with open(tmp_path, "w", encoding="utf-8") as fh:
            json.dump(record, fh, indent=2)
        os.replace(tmp_path, path)
    finally:
        # After a successful replace the temp file is gone; this only cleans
        # up when writing or renaming failed.
        if tmp_path.exists():
            tmp_path.unlink()
_SAFE_ID_RE = re.compile(r"^[A-Za-z0-9_-]{1,64}$")


def _validate_id(value: str, label: str = "id") -> None:
    """Raise ValueError if *value* contains unsafe path characters.

    *label* only names the offending field in the error message.
    """
    # fullmatch(), not match(): with match() the pattern's trailing "$" also
    # accepts a value ending in "\n", which would end up in directory names.
    if not _SAFE_ID_RE.fullmatch(value):
        raise ValueError(f"Unsafe {label}: {value!r}")
# Character allow-list for model keys (letters, digits, "_", "." and "-").
# It excludes path separators, so a model key cannot introduce extra path
# components into the summary file name.
_SAFE_MODEL_RE = re.compile(r"^[A-Za-z0-9_.-]{1,64}$")


def _validate_model_key(model_key: str) -> None:
    """Raise ValueError if *model_key* contains characters outside the allow-list."""
    # fullmatch(), not match(): with match() the pattern's trailing "$" also
    # accepts a key ending in "\n".
    if not _SAFE_MODEL_RE.fullmatch(model_key):
        raise ValueError(f"Unsafe model_key: {model_key!r}")
# Characters (and maximum length) a video ID may have to be usable as a storage key.
_VIDEO_ID_RE = re.compile(r"[A-Za-z0-9_-]{1,64}")
# Hosts serving the short-link form https://youtu.be/VIDEO_ID.
_SHORT_LINK_HOSTS = ("youtu.be", "www.youtu.be")
# Registered domains (any subdomain allowed) serving /watch, /shorts, /embed, /v.
_YOUTUBE_DOMAINS = ("youtube.com", "youtube-nocookie.com")


def video_id_from_url(url: str) -> str:
    """
    Extract the YouTube video ID from a URL and return it as a canonical storage key.

    Supports the common URL forms:
      - https://www.youtube.com/watch?v=VIDEO_ID
      - https://youtu.be/VIDEO_ID
      - https://www.youtube.com/shorts/VIDEO_ID
      - https://www.youtube.com/embed/VIDEO_ID

    The host must be youtu.be, youtube.com or youtube-nocookie.com (or a
    subdomain such as ``m.youtube.com``), and the extracted ID must consist of
    1-64 characters from ``[A-Za-z0-9_-]`` so it is always safe to use as a
    file-system key.

    Raises ValueError if the video ID cannot be determined.
    """
    parsed = urlparse(url)
    # .hostname drops any port/userinfo and is lower-cased; None for host-less input.
    host = parsed.hostname or ""
    candidate = None

    if host in _SHORT_LINK_HOSTS:
        # youtu.be short links: the ID is the first path segment.
        candidate = parsed.path.lstrip("/").split("/")[0]
    elif any(host == domain or host.endswith("." + domain) for domain in _YOUTUBE_DOMAINS):
        if parsed.path == "/watch":
            # Standard watch URL (?v=...)
            candidate = parse_qs(parsed.query).get("v", [None])[0]
        else:
            # /shorts/, /embed/, /v/
            m = re.match(r"^/(?:shorts|embed|v)/([A-Za-z0-9_-]+)", parsed.path)
            if m:
                candidate = m.group(1)

    # Query values are percent-decoded by parse_qs and may contain anything
    # (e.g. "../" or a newline); never hand such a value out as a storage key.
    if candidate and _VIDEO_ID_RE.fullmatch(candidate):
        return candidate

    raise ValueError(f"Cannot extract YouTube video ID from URL: {url!r}")
def _make_stub(name: str) -> ModuleType:
    """Return a new, empty module object named *name* to stand in for a heavy dependency."""
    return ModuleType(name)
def test_get_job_status_found(tmp_path, monkeypatch):
    """A freshly created job is reported by the status endpoint as ``pending``."""
    # Redirect data/jobs to a temp directory so the test doesn't pollute the repo.
    monkeypatch.setattr("src.jobs._JOBS_DIR", tmp_path / "jobs")

    # Only create_job is needed here; the record is read back through the API.
    from src.jobs import create_job

    job_id = create_job()["job_id"]

    response = client.get(f"/api/v1/job/{job_id}")
    assert response.status_code == 200
    data = response.json()
    assert data["job_id"] == job_id
    assert data["status"] == "pending"
def test_compare_missing_one_model():
    """Compare returns 404 unless both summaries exist, and names what is missing."""
    bart = {
        "summary": "BART summary text.",
        "metrics": {"compression_ratio": 55.0, "processing_time": 2.0},
        "model_key": "bart-large-cnn",
    }

    def _only_bart(video_id, model_key):
        # BART is stored, T5 is not.
        return bart if model_key == "bart-large-cnn" else None

    # Exactly one model missing: the detail must name only that model.
    with patch("app.main.load_summary", side_effect=_only_bart):
        response = client.get("/api/v1/compare/test_video")
    assert response.status_code == 404
    detail = response.json()["detail"]
    assert "t5-base" in detail
    assert "bart-large-cnn" not in detail

    # Both models missing is still a 404.
    with patch("app.main.load_summary", return_value=None):
        response = client.get("/api/v1/compare/test_video")
    assert response.status_code == 404
data["comparison"]["bart_word_count"] == 3 + assert data["comparison"]["t5_word_count"] == 3