From 3b5750c83467be01cd3edb8eda3e90dedf91e917 Mon Sep 17 00:00:00 2001 From: 2002yy <15135142681@163.com> Date: Fri, 5 Jun 2026 18:06:39 +0800 Subject: [PATCH] Add RAG knowledge UI and vector backend scaffold --- .env.example | 6 ++ README.md | 25 +++-- docs/INTERVIEW_NOTES.md | 2 +- docs/RAG.md | 45 ++++++-- docs/TECH_STACK.md | 6 +- docs/TESTING.md | 7 +- src/rag/__init__.py | 14 +++ src/rag/backends.py | 120 ++++++++++++++++++++++ src/rag/chroma_backend.py | 144 ++++++++++++++++++++++++++ src/rag/embeddings.py | 21 ++++ src/rag/eval.py | 8 ++ src/rag/service.py | 13 ++- src/ui/rag_panel.py | 205 +++++++++++++++++++++++++++++++++++-- tests/test_rag.py | 57 ++++++++++- tests/test_rag_backends.py | 135 ++++++++++++++++++++++++ 15 files changed, 774 insertions(+), 34 deletions(-) create mode 100644 src/rag/backends.py create mode 100644 src/rag/chroma_backend.py create mode 100644 src/rag/embeddings.py create mode 100644 tests/test_rag_backends.py diff --git a/.env.example b/.env.example index e39ac38..82fc438 100644 --- a/.env.example +++ b/.env.example @@ -64,3 +64,9 @@ DEEPSEEK_MODEL_PRO_NAME=deepseek-v4-pro # 本地 trafilatura/readability/raw 和 Firecrawl 都失败后,是否允许调用 hosted Jina Reader 兜底。 # 默认关闭;开启后会把公开 HTTP(S) URL 发给 https://r.jina.ai/ 读取 Markdown。 # NEWS_ENABLE_JINA_READER=false + +# === RAG 向量后端(默认 local,无需额外依赖)=== +# local 使用当前 deterministic hash-vector prototype;chroma 是可选持久化适配器,需要自行安装 chromadb。 +# RAG_VECTOR_BACKEND=local +# RAG_CHROMA_PATH=logs/chroma +# RAG_CHROMA_COLLECTION=study_agent diff --git a/README.md b/README.md index db30e08..0821573 100644 --- a/README.md +++ b/README.md @@ -3,7 +3,7 @@
A local AI learning assistant with long-term memory, role-based group chat, @@ -17,7 +17,7 @@ Study Agent 是一个本地优先的 AI 学习助手,重点不是简单调用 - **长期记忆**:Markdown memory + safe writer - **上下文分层**:fast / light / deep / archive - **联网搜索**:RSS / News fetch → article extraction → LLM digest → source tracing -- **RAG MVP**:本地 Markdown / TXT / DOCX / PDF 索引、关键词 / 本地向量原型 / hybrid 检索、引用上下文、来源块、Streamlit 检索面板、聊天注入和 FastAPI RAG 接口 +- **RAG MVP**:本地 Markdown / TXT / DOCX / PDF 索引、关键词 / 本地向量原型 / hybrid / backend-vector 检索、引用上下文、来源块、Streamlit 检索/调试面板、聊天注入和 FastAPI RAG 接口 - **工程安全**:SSRF protection、detect-secrets、配置模板 - **工程质量**:pytest 测试套件、Ruff、GitHub Actions CI、打包检查 @@ -27,11 +27,11 @@ Study Agent 是一个本地优先的 AI 学习助手,重点不是简单调用 - **Model routing** with fast / light / deep / archive context tiers - **Long-term memory** based on Markdown files and safe-writer persistence - **Web search pipeline**: feed registry → URL safety checks → article extraction → LLM digest → auditable source trace -- **RAG MVP**: local Markdown / TXT / DOCX / PDF indexing, lexical / local vector prototype / hybrid retrieval, citation-first context formatting, source blocks, a Streamlit retrieval panel, optional chat injection, and FastAPI RAG endpoints +- **RAG MVP**: local Markdown / TXT / DOCX / PDF indexing, lexical / local vector prototype / hybrid / backend-vector retrieval, citation-first context formatting, source blocks, a Streamlit retrieval/debug panel, optional chat injection, and FastAPI RAG endpoints - **SSRF protection** for article fetching, **detect-secrets** in CI - **Batched session logging** and multi-layer caching for performance - **Performance budget**: mode-based `max_tokens` bounds on the main chat, WeChat, and news LLM paths -- **265 pytest tests**, Ruff clean, GitHub Actions CI workflow +- **273 pytest tests**, Ruff clean, GitHub Actions CI workflow For a detailed breakdown of the stack and engineering highlights, see [Technical Stack & Engineering Highlights](docs/TECH_STACK.md). @@ -109,7 +109,7 @@ Study Agent 的定位很明确:**一个运行在你本地的、有长期记忆 | **角色群聊** | 四位角色(三月七、刻晴、纳西妲、流萤)群聊讨论,各有独立人设 | | **联网搜索** | Google News + Bing News + RSSHub 多源聚合,页面正文三层提取 | | **来源追溯** | 搜索结果写入群聊记录,可回溯依据 | -| **RAG MVP** | 本地 Markdown / TXT / DOCX / PDF 文档索引,前端面板返回带文件路径、行号、分数和命中词的引用片段,并可注入单人聊天和微信群互动回复;FastAPI 提供 `/health`、`/rag`、`/rag/index`、`/rag/query` | +| **RAG MVP** | 本地 Markdown / TXT / DOCX / PDF 文档索引,前端面板返回带文件路径、行号、分数、命中词和 score breakdown 的引用片段,并可注入单人聊天和微信群互动回复;FastAPI 提供 `/health`、`/rag`、`/rag/index`、`/rag/query` | | **课后总结** | 学习完成后自动总结进展,用户确认后写入记忆 | | **长期记忆** | 学习者画像、进度追踪、项目上下文、当前焦点,多级记忆档案 | | **多 Provider** | 支持 OpenAI / DeepSeek / OpenRouter / SiliconFlow / 本地模型 | @@ -207,6 +207,15 @@ pip-compile requirements-dev.in # 重新锁定开发依赖 参数优先级:代码显式参数 → 任务级环境变量 → 任务默认值 → 全局环境变量 → provider 级环境变量。完整配置见 [`.env.example`](.env.example) 和 [用户指南](USER_GUIDE.md)。 +RAG 向量后端默认使用 `local`,不需要额外服务;可选 `chroma` adapter 需要用户自行安装 `chromadb`: + +```bash +RAG_VECTOR_BACKEND=local +# RAG_VECTOR_BACKEND=chroma +# RAG_CHROMA_PATH=logs/chroma +# RAG_CHROMA_COLLECTION=study_agent +``` + --- ## 项目结构 @@ -234,7 +243,7 @@ pip-compile requirements-dev.in # 重新锁定开发依赖 │ ├── config.py # 全局配置 │ ├── router.py # 路由配置 │ ├── news/ # 新闻聚合链路 -│ ├── rag/ # 本地 RAG MVP:加载、分块、索引、关键词/向量原型检索 +│ ├── rag/ # 本地 RAG MVP:加载、分块、索引、关键词/向量原型/可选后端检索 │ └── ui/ # Streamlit UI 组件 ├── tests/ # pytest 测试套件 ├── docs/ # 设计文档与工程说明 @@ -255,7 +264,7 @@ pip-compile requirements-dev.in # 重新锁定开发依赖 ## 测试 ```bash -pytest tests/ -v # current local baseline: 265 passed +pytest tests/ -v # current local baseline: 273 passed pytest tests/ --cov=src # 覆盖率 ruff check src/ tests/ # linting mypy --explicit-package-bases src/ # CI soft check; may report type debt @@ -299,7 +308,7 @@ CI 通过 GitHub Actions 在 push / pull request 上运行,集成 `pytest`、` - [ ] FastAPI service layer (partial): `/health`, `/rag`, `/rag/index`, `/rag/query` implemented; `/chat` and `/memory` remain planned - [x] RAG MVP: Markdown / TXT / DOCX / PDF loading, chunking, local keyword retrieval, local vector prototype, hybrid retrieval, citation context, source blocks, Streamlit retrieval panel, optional single-chat and WeChat interactive injection -- [ ] RAG document QA (partial): PDF parsing has file-size, page-count, extracted-text and encrypted-file guards; embedding model retrieval remains planned +- [ ] RAG document QA (partial): PDF parsing has file-size, page-count, extracted-text and encrypted-file guards; Chroma adapter scaffold exists; production embedding model retrieval remains planned - [ ] Vector store: FAISS local prototype, pgvector engineering version - [ ] Web UI: TypeScript + Vue3 / React, streaming chat, source panel - [ ] Observability: trace_id, token usage, latency, provider fallback logs diff --git a/docs/INTERVIEW_NOTES.md b/docs/INTERVIEW_NOTES.md index 30b9253..8a4a3bb 100644 --- a/docs/INTERVIEW_NOTES.md +++ b/docs/INTERVIEW_NOTES.md @@ -10,7 +10,7 @@ Study Agent 是一个本地优先的 AI 学习助手,重点在多 Provider 模 2. **长期记忆写入安全** — safe writer + preview/confirm 机制,防止不可逆的记忆污染 3. **联网搜索来源追溯** — Feed registry / RSS 多源聚合 → URL safety matrix → 文章正文三层提取 → LLM digest → pipeline trace 全过程来源可回溯 4. **Streamlit 重渲染性能优化** — 多层缓存策略、按模式批量落盘、主链路 token 预算控制 -5. **CI / Ruff / detect-secrets 工程检查** — 265 pytest tests、Ruff clean、GitHub Actions workflow、detect-secrets 对未豁免发现硬阻断 +5. **CI / Ruff / detect-secrets 工程检查** — 273 pytest tests、Ruff clean、GitHub Actions workflow、detect-secrets 对未豁免发现硬阻断 ## 可讲亮点 diff --git a/docs/RAG.md b/docs/RAG.md index 2cbe085..9729ce4 100644 --- a/docs/RAG.md +++ b/docs/RAG.md @@ -20,11 +20,13 @@ Implemented: - Optional single-chat and WeChat interactive reply injection through the `用于聊天回答` toggle - UI source blocks for retrieved file paths, line ranges, scores and matched terms - FastAPI endpoints: `GET /health`, `POST /rag`, `POST /rag/index`, `POST /rag/query` +- Streamlit knowledge/debug panel with index summary, document rows, chunk preview and score breakdowns +- Optional vector backend interface with local fallback and Chroma adapter scaffold Not implemented yet: -- Embedding model integration -- FAISS, pgvector or other vector stores +- Production embedding model integration +- FAISS, pgvector or managed vector stores - Automatic injection into every generation path; current injection covers single chat and WeChat interactive replies, but not news discussion or after-session feedback ## Module Map @@ -34,6 +36,9 @@ Not implemented yet: | `src/rag/loader.py` | Load supported local files into normalized `RagDocument` objects | | `src/rag/chunker.py` | Split documents into line-traceable `RagChunk` objects | | `src/rag/index.py` | Build, save, load and search a local JSON RAG index | +| `src/rag/embeddings.py` | Embedding provider contract and local hash embedding provider | +| `src/rag/backends.py` | Vector backend contract, local backend and environment-driven backend selection | +| `src/rag/chroma_backend.py` | Optional Chroma persistent backend adapter scaffold | | `src/rag/vector.py` | Deterministic local vector prototype and hybrid retrieval | | `src/rag/eval.py` | LLM-free retrieval quality evaluation over gold query fixtures | | `src/rag/service.py` | Application-facing helpers for indexing, querying and context formatting | @@ -62,6 +67,7 @@ Supported retrieval modes: - `lexical`: TF-IDF-style term scoring - `vector`: deterministic local hash-vector cosine similarity - `hybrid`: normalized lexical score plus vector similarity +- `backend_vector`: configured vector backend; defaults to local and can use the optional Chroma adapter Each result keeps: @@ -123,6 +129,22 @@ P4-B adds API/query diagnostics: - Per-result rank, chunk id, source path, matched terms and score breakdown - Optional one-query evaluation when `/rag/query` receives `expected_sources` +P4-C / P6 adds Streamlit inspection controls: + +- Current index path, document count and chunk count +- Indexed document table with file type, size, mtime, hash prefix and chunk count +- Chunk preview table with line range, character count and source path +- Retrieval controls for mode, `top_k`, `min_score` and debug visibility +- Score-breakdown table for retrieved chunks + +P5 adds the first vector-backend abstraction: + +- `EmbeddingProvider` protocol plus `LocalHashEmbeddingProvider` +- `VectorBackend` protocol plus `LocalVectorBackend` +- `RAG_VECTOR_BACKEND=local|chroma` +- Optional `ChromaVectorBackend` using lazy `chromadb` import, `PersistentClient`, collection `upsert` and vector query +- `tests/test_rag_backends.py` verifies local backend behavior, environment config and Chroma fake-client upsert/query behavior + ## Next Steps ### P4: Retrieval Quality Loop @@ -132,26 +154,27 @@ Goal: prove retrieval quality before expanding the stack. - [x] Add a small gold fixture set with queries, expected sources and expected terms. - [x] Track `recall@k`, mean reciprocal rank, source hit rate and empty-result rate. - [x] Surface retrieval debug data in tests and API responses before adding more UI polish. -- [ ] Add a Streamlit source/debug panel for inspecting score breakdowns. +- [x] Add a Streamlit source/debug panel for inspecting score breakdowns. - Keep the first evaluation layer LLM-free so CI can catch retrieval regressions deterministically. ### P5: Real Embedding Backend Goal: replace the local hash-vector prototype with optional real embeddings without breaking local-first defaults. -- Extract a retriever / vector-backend contract. -- Keep JSON + lexical / hybrid retrieval as the zero-infrastructure fallback. -- Add one optional backend first, likely Qdrant or Chroma; defer FAISS if Windows install friction is high. -- Make embedding provider selection explicit through config. +- [x] Extract an embedding-provider and vector-backend contract. +- [x] Keep JSON + lexical / hybrid retrieval as the zero-infrastructure fallback. +- [x] Add an optional Chroma adapter scaffold with lazy import and fake-client tests. +- [x] Make vector backend selection explicit through config. +- [ ] Add a production embedding provider; current Chroma adapter uses the local hash embedding provider by default. ### P6: Knowledge UI Goal: turn the Streamlit expander into a usable knowledge panel. -- List indexed documents with chunk count, mtime, hash and status. -- Add query debugging controls for mode, `top_k`, threshold and score preview. -- Add source preview with title, path, page or line range and matched terms. -- Add per-chat RAG scope selection instead of one global toggle only. +- [x] List indexed documents with chunk count, mtime, hash and status. +- [x] Add query debugging controls for mode, `top_k`, threshold and score preview. +- [x] Add source preview with title, path, page or line range and matched terms. +- [ ] Add per-chat RAG scope selection instead of one global toggle only. ### P7: Agentic RAG diff --git a/docs/TECH_STACK.md b/docs/TECH_STACK.md index 9354355..4c10217 100644 --- a/docs/TECH_STACK.md +++ b/docs/TECH_STACK.md @@ -35,7 +35,7 @@ Study Agent 是一个本地运行的 AI 学习助理系统,面向个人学习 | Long-term Memory | Markdown files | 用 `summary.md`、`current_focus.md`、`learner_profile.md` 等文件保存长期记忆 | | Context Control | fast / light / deep / archive tiers | 按性能模式选择不同记忆文件组,控制 token 成本 | | Routing | Rule-based router + optional LLM router | 根据任务类型、用户选择和性能模式决定角色、学习模式和模型档位 | -| RAG MVP | `src/rag/*`, `src/ui/rag_panel.py`, `src/api.py`, JSON index | 本地 Markdown / TXT / DOCX / PDF 加载、分块、关键词 / 本地向量原型 / hybrid 检索、引用上下文拼装、来源块、Streamlit 检索面板、聊天注入和 FastAPI RAG endpoints | +| RAG MVP | `src/rag/*`, `src/ui/rag_panel.py`, `src/api.py`, JSON index | 本地 Markdown / TXT / DOCX / PDF 加载、分块、关键词 / 本地向量原型 / hybrid / backend-vector 检索、引用上下文拼装、来源块、Streamlit 检索/调试面板、聊天注入和 FastAPI RAG endpoints | | News Search | Feed registry / RSS / Google News / Bing News / RSSHub-style sources | 多源新闻聚合、源健康记录、去重、排序、来源追溯 | | Article Extraction | `trafilatura`, `readability-lxml`, `lxml` | 新闻网页正文读取与降级解析 | | Security | URL safety matrix, SSRF validation, redirect checks, secret scanning | 防止读取本地/内网资源,降低密钥误提交风险 | @@ -273,16 +273,18 @@ User query - 带 `source_path`、标题、chunk 序号和行号范围的分块 - 本地关键词 / TF-IDF-style 检索 - deterministic hash-vector 本地向量原型与 hybrid 检索模式 +- `EmbeddingProvider` / `VectorBackend` 抽象,默认 local backend,可选 Chroma adapter scaffold - 简单中文 CJK bigram 匹配 - JSON index 保存与加载,默认路径为 `logs/rag_index.json` - `build_rag_context()` 将检索结果拼装为带引用的 LLM 上下文块 - Streamlit `本地资料检索` 面板支持上传资料、输入本地路径、建立索引、检索和查看引用上下文 +- Streamlit 面板显示当前索引、文档列表、chunk preview、检索参数和 score breakdown - 单人聊天和微信群互动回复可通过 `用于聊天回答` 开关把检索结果注入 system prompt,并显示 RAG 引用来源块 - FastAPI `GET /health`、`POST /rag`、`POST /rag/index`、`POST /rag/query` 未实现边界: -- 尚未接入 embedding model、FAISS、pgvector 或其他生产向量库 +- 尚未接入生产 embedding model、FAISS、pgvector 或其他生产向量库;Chroma 目前是 optional adapter scaffold - FastAPI 目前覆盖 health 和 RAG;`/chat`、`/memory` 仍是后续服务化任务 - 尚未自动注入所有生成路径;当前覆盖单人聊天和微信群互动回复,不覆盖新闻讨论或课后反馈 diff --git a/docs/TESTING.md b/docs/TESTING.md index ad0e41d..23f670d 100644 --- a/docs/TESTING.md +++ b/docs/TESTING.md @@ -6,7 +6,7 @@ Current verified baseline: | Check | Status | Evidence | |---|---|---| -| pytest | Passed | `265 passed` locally on 2026-06-05 | +| pytest | Passed | `273 passed` locally on 2026-06-05 | | Ruff | Passed | `python -m ruff check .` clean locally on 2026-06-04 | | Package helper | Passed | `python tools/package_project_helper.py . NUL 0` locally on 2026-06-04 | | mypy | Soft check, not clean | `python -m mypy --explicit-package-bases src/` reported 18 errors locally on 2026-06-04 | @@ -24,8 +24,9 @@ Current verified baseline: | **News URL safety** | `test_url_normalizer.py`, `test_link_resolver.py` | 28 | | **News pipeline trace / audit** | `test_news_pipeline_trace.py`, `test_news_audit.py` | 5 | | **Feed registry / health** | `test_feed_registry.py`, `test_feed_diagnostics.py` | 9 | -| **RAG MVP** | `test_rag.py` | 22 | +| **RAG MVP** | `test_rag.py` | 24 | | **RAG evaluation** | `test_rag_eval.py` | 5 | +| **RAG vector backends** | `test_rag_backends.py` | 6 | | **FastAPI RAG endpoints** | `test_api.py` | 6 | | **Architecture flows** | `test_architecture_flows.py` | 12 | | **WeChat decoupling** | `test_wechat_decoupling.py` | 4 | @@ -76,7 +77,7 @@ def test_flush_uses_safe_writer(): ## Running Tests ```bash -python -m pytest # current baseline: 265 passed +python -m pytest # current baseline: 273 passed pytest tests/ -v # Verbose pytest tests/ --cov=src # Coverage python -m ruff check . # Linting diff --git a/src/rag/__init__.py b/src/rag/__init__.py index aa4f200..1b3cecf 100644 --- a/src/rag/__init__.py +++ b/src/rag/__init__.py @@ -6,6 +6,14 @@ save_rag_index, search_rag_index, ) +from src.rag.backends import ( + LocalVectorBackend, + VectorBackendStatus, + get_vector_backend, + get_vector_backend_from_env, + vector_backend_config_from_env, +) +from src.rag.embeddings import LocalHashEmbeddingProvider from src.rag.eval import ( RagEvalCase, RagEvalResult, @@ -39,6 +47,10 @@ "evaluate_rag_index", "format_rag_sources", "index_documents", + "get_vector_backend", + "get_vector_backend_from_env", + "LocalHashEmbeddingProvider", + "LocalVectorBackend", "load_eval_cases", "load_rag_index", "query_documents", @@ -50,4 +62,6 @@ "search_rag_index", "search_rag_index_vector", "search_documents", + "VectorBackendStatus", + "vector_backend_config_from_env", ] diff --git a/src/rag/backends.py b/src/rag/backends.py new file mode 100644 index 0000000..d7abc8e --- /dev/null +++ b/src/rag/backends.py @@ -0,0 +1,120 @@ +from __future__ import annotations + +import os +from dataclasses import dataclass +from pathlib import Path +from typing import Protocol + +from src.rag.embeddings import EmbeddingProvider, LocalHashEmbeddingProvider +from src.rag.schema import RagIndex, RagSearchResult +from src.rag.vector import search_rag_index_vector + + +@dataclass(frozen=True) +class VectorBackendStatus: + name: str + available: bool + detail: str + path: str = "" + collection: str = "" + embedding_provider: str = "" + + def to_dict(self) -> dict[str, str | bool]: + return { + "name": self.name, + "available": self.available, + "detail": self.detail, + "path": self.path, + "collection": self.collection, + "embedding_provider": self.embedding_provider, + } + + +class VectorBackend(Protocol): + name: str + + def status(self) -> VectorBackendStatus: + """Return backend availability and configuration details.""" + + def upsert_index(self, index: RagIndex) -> None: + """Persist or refresh index chunks in the backend.""" + + def query( + self, + index: RagIndex, + query: str, + *, + top_k: int = 5, + min_score: float = 0.05, + ) -> list[RagSearchResult]: + """Return vector search results.""" + + +class LocalVectorBackend: + name = "local" + + def __init__(self, embedding_provider: EmbeddingProvider | None = None) -> None: + self.embedding_provider = embedding_provider or LocalHashEmbeddingProvider() + + def status(self) -> VectorBackendStatus: + return VectorBackendStatus( + name=self.name, + available=True, + detail="In-memory deterministic local vector prototype", + embedding_provider=self.embedding_provider.name, + ) + + def upsert_index(self, index: RagIndex) -> None: + _ = index + + def query( + self, + index: RagIndex, + query: str, + *, + top_k: int = 5, + min_score: float = 0.05, + ) -> list[RagSearchResult]: + return search_rag_index_vector(index, query, top_k=top_k, min_score=min_score) + + +def vector_backend_config_from_env() -> dict[str, str]: + return { + "name": os.getenv("RAG_VECTOR_BACKEND", "local").strip() or "local", + "path": os.getenv("RAG_CHROMA_PATH", "logs/chroma").strip() or "logs/chroma", + "collection": os.getenv("RAG_CHROMA_COLLECTION", "study_agent").strip() or "study_agent", + } + + +def get_vector_backend( + name: str = "local", + *, + path: str | Path = "logs/chroma", + collection: str = "study_agent", + embedding_provider: EmbeddingProvider | None = None, +) -> VectorBackend: + normalized = (name or "local").strip().lower() + if normalized == "local": + return LocalVectorBackend(embedding_provider=embedding_provider) + if normalized == "chroma": + from src.rag.chroma_backend import ChromaVectorBackend + + return ChromaVectorBackend( + path=path, + collection_name=collection, + embedding_provider=embedding_provider, + ) + raise ValueError(f"Unsupported vector backend: {name}") + + +def get_vector_backend_from_env( + *, + embedding_provider: EmbeddingProvider | None = None, +) -> VectorBackend: + config = vector_backend_config_from_env() + return get_vector_backend( + config["name"], + path=config["path"], + collection=config["collection"], + embedding_provider=embedding_provider, + ) diff --git a/src/rag/chroma_backend.py b/src/rag/chroma_backend.py new file mode 100644 index 0000000..4323039 --- /dev/null +++ b/src/rag/chroma_backend.py @@ -0,0 +1,144 @@ +from __future__ import annotations + +from pathlib import Path +from typing import Any + +from src.rag.backends import VectorBackendStatus +from src.rag.embeddings import EmbeddingProvider, LocalHashEmbeddingProvider +from src.rag.schema import RagChunk, RagIndex, RagSearchResult + + +class ChromaVectorBackend: + name = "chroma" + + def __init__( + self, + *, + path: str | Path = "logs/chroma", + collection_name: str = "study_agent", + embedding_provider: EmbeddingProvider | None = None, + client: Any | None = None, + ) -> None: + self.path = Path(path) + self.collection_name = collection_name + self.embedding_provider = embedding_provider or LocalHashEmbeddingProvider() + self._client = client + + def _get_client(self) -> Any: + if self._client is not None: + return self._client + try: + import chromadb # type: ignore[import-not-found] + except Exception as exc: + raise RuntimeError("chromadb is required for the Chroma vector backend") from exc + self._client = chromadb.PersistentClient(path=str(self.path)) + return self._client + + def _collection(self) -> Any: + return self._get_client().get_or_create_collection(name=self.collection_name) + + def status(self) -> VectorBackendStatus: + try: + self._get_client() + except RuntimeError as exc: + return VectorBackendStatus( + name=self.name, + available=False, + detail=str(exc), + path=str(self.path), + collection=self.collection_name, + embedding_provider=self.embedding_provider.name, + ) + return VectorBackendStatus( + name=self.name, + available=True, + detail="Chroma persistent vector backend", + path=str(self.path), + collection=self.collection_name, + embedding_provider=self.embedding_provider.name, + ) + + def upsert_index(self, index: RagIndex) -> None: + collection = self._collection() + ids = [chunk.chunk_id for chunk in index.chunks] + if not ids: + return + + collection.upsert( + ids=ids, + embeddings=[list(self.embedding_provider.embed(chunk.text)) for chunk in index.chunks], + documents=[chunk.text for chunk in index.chunks], + metadatas=[_chunk_metadata(chunk) for chunk in index.chunks], + ) + + def query( + self, + index: RagIndex, + query: str, + *, + top_k: int = 5, + min_score: float = 0.05, + ) -> list[RagSearchResult]: + _ = index + if top_k <= 0: + return [] + response = self._collection().query( + query_embeddings=[list(self.embedding_provider.embed(query))], + n_results=top_k, + ) + ids = _first(response.get("ids", [])) + distances = _first(response.get("distances", [])) + documents = _first(response.get("documents", [])) + metadatas = _first(response.get("metadatas", [])) + + results: list[RagSearchResult] = [] + for item_id, distance, text, metadata in zip( + ids, + distances, + documents, + metadatas, + strict=False, + ): + score = max(0.0, 1.0 - float(distance)) + if score < min_score: + continue + chunk = _chunk_from_metadata(str(item_id), str(text), dict(metadata or {})) + results.append(RagSearchResult(chunk=chunk, score=round(score, 6), matched_terms=())) + return results + + +def _first(value: list) -> list: + if not value: + return [] + first = value[0] + return first if isinstance(first, list) else value + + +def _chunk_metadata(chunk: RagChunk) -> dict[str, str | int]: + return { + "document_hash": chunk.document_hash, + "source_path": chunk.source_path, + "title": chunk.title, + "chunk_index": chunk.chunk_index, + "start_line": chunk.start_line, + "end_line": chunk.end_line, + "file_type": str(chunk.metadata.get("file_type", "")), + "char_count": int(chunk.metadata.get("char_count", len(chunk.text))), + } + + +def _chunk_from_metadata(chunk_id: str, text: str, metadata: dict) -> RagChunk: + return RagChunk( + chunk_id=chunk_id, + document_hash=str(metadata.get("document_hash", "")), + source_path=str(metadata.get("source_path", "")), + title=str(metadata.get("title", "")), + text=text, + chunk_index=int(metadata.get("chunk_index", 0)), + start_line=int(metadata.get("start_line", 1)), + end_line=int(metadata.get("end_line", 1)), + metadata={ + "file_type": metadata.get("file_type", ""), + "char_count": int(metadata.get("char_count", len(text))), + }, + ) diff --git a/src/rag/embeddings.py b/src/rag/embeddings.py new file mode 100644 index 0000000..eab2c18 --- /dev/null +++ b/src/rag/embeddings.py @@ -0,0 +1,21 @@ +from __future__ import annotations + +from typing import Protocol + +from src.rag.vector import VECTOR_DIMENSIONS, embed_text + + +class EmbeddingProvider(Protocol): + name: str + dimensions: int + + def embed(self, text: str) -> tuple[float, ...]: + """Embed one text string for vector retrieval.""" + + +class LocalHashEmbeddingProvider: + name = "local_hash" + dimensions = VECTOR_DIMENSIONS + + def embed(self, text: str) -> tuple[float, ...]: + return embed_text(text, dimensions=self.dimensions) diff --git a/src/rag/eval.py b/src/rag/eval.py index 3f53916..8733f9e 100644 --- a/src/rag/eval.py +++ b/src/rag/eval.py @@ -6,6 +6,7 @@ from typing import Any from src.rag.index import search_rag_index +from src.rag.backends import get_vector_backend_from_env from src.rag.schema import RagIndex, RagSearchResult from src.rag.vector import search_rag_index_hybrid, search_rag_index_vector @@ -81,6 +82,13 @@ def _search( return search_rag_index_vector(index, case.query, top_k=case.top_k, min_score=min_score) if case.retrieval_mode == "hybrid": return search_rag_index_hybrid(index, case.query, top_k=case.top_k, min_score=min_score) + if case.retrieval_mode == "backend_vector": + return get_vector_backend_from_env().query( + index, + case.query, + top_k=case.top_k, + min_score=min_score, + ) raise ValueError(f"Unsupported RAG retrieval mode: {case.retrieval_mode}") diff --git a/src/rag/service.py b/src/rag/service.py index 741e23d..37ec676 100644 --- a/src/rag/service.py +++ b/src/rag/service.py @@ -14,6 +14,7 @@ save_rag_index, search_rag_index, ) +from src.rag.backends import get_vector_backend_from_env from src.rag.schema import RagIndex, RagSearchResult from src.rag.vector import ( cosine_similarity, @@ -22,7 +23,7 @@ search_rag_index_vector, ) -RETRIEVAL_MODES = {"lexical", "vector", "hybrid"} +RETRIEVAL_MODES = {"lexical", "vector", "hybrid", "backend_vector"} HYBRID_LEXICAL_WEIGHT = 0.7 @@ -40,6 +41,7 @@ def index_documents( overlap_chars=overlap_chars, ) save_rag_index(index, index_path) + get_vector_backend_from_env().upsert_index(index) return index @@ -83,6 +85,13 @@ def search_documents( min_score=min_score, lexical_weight=HYBRID_LEXICAL_WEIGHT, ) + if retrieval_mode == "backend_vector": + return get_vector_backend_from_env().query( + index, + query, + top_k=top_k, + min_score=min_score, + ) raise ValueError(f"Unsupported RAG retrieval mode: {retrieval_mode}") @@ -123,6 +132,8 @@ def _score_breakdown( return {"lexical_score": round(lexical_score, 6)} if retrieval_mode == "vector": return {"vector_score": round(vector_score, 6)} + if retrieval_mode == "backend_vector": + return {"backend_score": round(result.score, 6)} if retrieval_mode == "hybrid": return { "lexical_weight": HYBRID_LEXICAL_WEIGHT, diff --git a/src/ui/rag_panel.py b/src/ui/rag_panel.py index f3d1aae..2adb74e 100644 --- a/src/ui/rag_panel.py +++ b/src/ui/rag_panel.py @@ -5,11 +5,22 @@ import streamlit as st -from src.rag import build_rag_context, format_rag_sources, index_documents, query_documents +from src.rag import ( + build_rag_context, + build_rag_debug, + format_rag_sources, + get_vector_backend_from_env, + index_documents, + load_rag_index, + search_documents, +) +from src.rag.index import DEFAULT_RAG_INDEX_PATH +from src.rag.schema import RagIndex ROOT = Path(__file__).resolve().parent.parent.parent RAG_UPLOAD_DIR = ROOT / "logs" / "rag_uploads" UPLOAD_TYPES = ["md", "markdown", "txt", "docx", "pdf"] +RETRIEVAL_MODES = ["lexical", "hybrid", "vector", "backend_vector"] def sanitize_upload_name(name: str) -> str: @@ -44,6 +55,81 @@ def _collect_document_paths(uploaded_files: list, path_text: str) -> list[Path]: return paths +def summarize_rag_index(index: RagIndex) -> dict: + chunk_counts: dict[str, int] = {} + for chunk in index.chunks: + chunk_counts[chunk.document_hash] = chunk_counts.get(chunk.document_hash, 0) + 1 + + documents = [] + for document in index.documents: + documents.append( + { + "title": document.title, + "file_type": document.file_type, + "source_path": document.source_path, + "size_bytes": int(document.metadata.get("size_bytes", 0)), + "mtime_ns": int(document.metadata.get("mtime_ns", 0)), + "content_hash": document.content_hash[:8], + "chunk_count": chunk_counts.get(document.content_hash, 0), + } + ) + + return { + "documents": len(index.documents), + "chunks": len(index.chunks), + "document_rows": documents, + } + + +def chunk_preview_rows(index: RagIndex, *, max_rows: int = 12) -> list[dict]: + rows: list[dict] = [] + for chunk in index.chunks[:max_rows]: + preview = " ".join(chunk.text.split()) + rows.append( + { + "title": chunk.title, + "lines": f"L{chunk.start_line}-L{chunk.end_line}", + "chars": int(chunk.metadata.get("char_count", len(chunk.text))), + "source_path": chunk.source_path, + "preview": preview[:180] + ("..." if len(preview) > 180 else ""), + } + ) + return rows + + +def format_rag_debug_summary(debug: dict) -> str: + if not debug: + return "" + terms = ", ".join(sorted(debug.get("query_terms", []))) or "-" + return ( + f"mode={debug.get('retrieval_mode', '-')}; " + f"top_k={debug.get('top_k', '-')}; " + f"min_score={debug.get('min_score', '-')}; " + f"candidates={debug.get('candidate_count', 0)}; " + f"returned={debug.get('returned_count', 0)}; " + f"terms={terms}" + ) + + +def format_score_breakdown(result_debug: dict) -> str: + breakdown = result_debug.get("score_breakdown", {}) + if not breakdown: + return "" + parts = [] + for key in [ + "lexical_weight", + "lexical_score", + "lexical_normalized", + "vector_score", + "combined_score", + "backend_score", + ]: + if key in breakdown: + value = breakdown[key] + parts.append(f"{key}={value:.3f}" if isinstance(value, float) else f"{key}={value}") + return "; ".join(parts) + + def _render_result_cards(results) -> None: for index, result in enumerate(results, start=1): chunk = result.chunk @@ -57,8 +143,81 @@ def _render_result_cards(results) -> None: st.markdown(chunk.text) +def _load_current_index() -> RagIndex | None: + try: + return load_rag_index() + except FileNotFoundError: + return None + + +def _render_index_overview(index: RagIndex | None) -> None: + if index is None: + st.caption(f"当前索引:{DEFAULT_RAG_INDEX_PATH}(尚未建立)") + return + + summary = summarize_rag_index(index) + st.caption( + f"当前索引:{DEFAULT_RAG_INDEX_PATH} · " + f"{summary['documents']} documents / {summary['chunks']} chunks" + ) + + with st.expander("已索引资料", expanded=False): + document_rows = summary["document_rows"] + if document_rows: + st.dataframe(document_rows, use_container_width=True, hide_index=True) + else: + st.caption("索引中没有文档。") + + with st.expander("Chunk 预览", expanded=False): + rows = chunk_preview_rows(index) + if rows: + st.dataframe(rows, use_container_width=True, hide_index=True) + else: + st.caption("索引中没有 chunk。") + + +def _render_rag_debug(debug: dict) -> None: + if not debug: + return + with st.expander("检索调试", expanded=False): + st.caption(format_rag_debug_summary(debug)) + rows = [] + for item in debug.get("results", []): + rows.append( + { + "rank": item.get("rank"), + "title": item.get("title"), + "score": item.get("score"), + "matched": ", ".join(item.get("matched_terms", [])) or "-", + "breakdown": format_score_breakdown(item), + "source_path": item.get("source_path"), + } + ) + if rows: + st.dataframe(rows, use_container_width=True, hide_index=True) + + +def _render_vector_backend_status() -> None: + try: + status = get_vector_backend_from_env().status() + except ValueError as exc: + st.caption(f"Vector backend: unavailable ({exc})") + return + availability = "available" if status.available else "unavailable" + detail = f" · {status.detail}" if status.detail else "" + location = f" · {status.path}/{status.collection}" if status.path or status.collection else "" + st.caption( + f"Vector backend: {status.name} / {status.embedding_provider} " + f"({availability}){location}{detail}" + ) + + def render_rag_panel() -> None: with st.expander("本地资料检索", expanded=False): + current_index = _load_current_index() + _render_index_overview(current_index) + _render_vector_backend_status() + uploaded_files = st.file_uploader( "上传资料", type=UPLOAD_TYPES, @@ -83,6 +242,7 @@ def render_rag_panel() -> None: st.session_state.rag_context = "" st.session_state.rag_source_block = "" st.session_state.rag_index_summary = "" + st.session_state.rag_debug = {} if build_clicked: paths = _collect_document_paths(uploaded_files or [], path_text) @@ -95,6 +255,8 @@ def render_rag_panel() -> None: st.session_state.rag_index_summary = ( f"{len(index.documents)} documents / {len(index.chunks)} chunks" ) + st.session_state.rag_debug = {} + current_index = index st.success("索引已更新。") except Exception as exc: st.warning(f"索引失败:{exc}") @@ -121,17 +283,17 @@ def render_rag_panel() -> None: st.selectbox( "检索模式", - options=["lexical", "hybrid", "vector"], + options=RETRIEVAL_MODES, format_func=lambda value: { "lexical": "关键词", "hybrid": "混合", "vector": "本地向量", + "backend_vector": "向量后端", }.get(value, value), - index=["lexical", "hybrid", "vector"].index( + index=RETRIEVAL_MODES.index( st.session_state.get("rag_retrieval_mode", "hybrid") ) - if st.session_state.get("rag_retrieval_mode", "hybrid") - in {"lexical", "hybrid", "vector"} + if st.session_state.get("rag_retrieval_mode", "hybrid") in set(RETRIEVAL_MODES) else 1, key="rag_retrieval_mode", ) @@ -142,19 +304,46 @@ def render_rag_panel() -> None: with query_cols[1]: top_k = st.number_input("条数", min_value=1, max_value=8, value=3, step=1, key="rag_top_k") + min_score = st.number_input( + "最低分", + min_value=0.0, + max_value=10.0, + value=float(st.session_state.get("rag_min_score", 0.01)), + step=0.01, + key="rag_min_score", + ) + st.checkbox( + "显示调试", + value=st.session_state.get("rag_debug_enabled", True), + key="rag_debug_enabled", + ) + if st.button("检索", key="rag_search", use_container_width=True): if not query.strip(): st.warning("先输入检索问题。") else: try: - results = query_documents( + index = load_rag_index() + retrieval_mode = st.session_state.get("rag_retrieval_mode", "hybrid") + results = search_documents( + index, + query, + top_k=int(top_k), + min_score=float(min_score), + retrieval_mode=retrieval_mode, + ) + debug = build_rag_debug( + index, query, + results, + retrieval_mode=retrieval_mode, top_k=int(top_k), - retrieval_mode=st.session_state.get("rag_retrieval_mode", "hybrid"), + min_score=float(min_score), ) st.session_state.rag_results = results st.session_state.rag_context = build_rag_context(results) st.session_state.rag_source_block = format_rag_sources(results) + st.session_state.rag_debug = debug except FileNotFoundError: st.warning("先建立索引。") except Exception as exc: @@ -167,3 +356,5 @@ def render_rag_panel() -> None: st.code(st.session_state.get("rag_source_block", ""), language="text") with st.expander("引用上下文", expanded=False): st.code(st.session_state.get("rag_context", ""), language="text") + if st.session_state.get("rag_debug_enabled", True): + _render_rag_debug(st.session_state.get("rag_debug", {})) diff --git a/tests/test_rag.py b/tests/test_rag.py index 7d1a292..4193dd6 100644 --- a/tests/test_rag.py +++ b/tests/test_rag.py @@ -17,7 +17,14 @@ from src.rag.index import build_rag_index, load_rag_index, search_rag_index from src.rag.loader import load_document from src.rag.vector import cosine_similarity, embed_text, search_rag_index_hybrid, search_rag_index_vector -from src.ui.rag_panel import parse_path_lines, sanitize_upload_name +from src.ui.rag_panel import ( + chunk_preview_rows, + format_rag_debug_summary, + format_score_breakdown, + parse_path_lines, + sanitize_upload_name, + summarize_rag_index, +) def _install_fake_pypdf( @@ -197,10 +204,16 @@ def test_query_documents_supports_retrieval_modes(tmp_path): lexical_results = query_documents("cited chunks", index_path=index_path, retrieval_mode="lexical") hybrid_results = query_documents("cited chunks", index_path=index_path, retrieval_mode="hybrid") vector_results = query_documents("cited chunks", index_path=index_path, retrieval_mode="vector") + backend_results = query_documents( + "cited chunks", + index_path=index_path, + retrieval_mode="backend_vector", + ) assert lexical_results assert hybrid_results assert vector_results + assert backend_results def test_query_documents_rejects_unknown_retrieval_mode(tmp_path): @@ -240,6 +253,48 @@ def test_build_rag_debug_explains_hybrid_scores(tmp_path): assert breakdown["vector_score"] > 0 +def test_rag_panel_index_summary_and_chunk_preview(tmp_path): + path = tmp_path / "notes.md" + path.write_text("First retrieval paragraph.\n\nSecond retrieval paragraph.", encoding="utf-8") + index = build_rag_index([path], max_chars=200, overlap_chars=0) + + summary = summarize_rag_index(index) + preview_rows = chunk_preview_rows(index) + + assert summary["documents"] == 1 + assert summary["chunks"] == 1 + assert summary["document_rows"][0]["title"] == "notes" + assert summary["document_rows"][0]["chunk_count"] == 1 + assert len(summary["document_rows"][0]["content_hash"]) == 8 + assert preview_rows[0]["lines"] == "L1-L3" + assert "retrieval paragraph" in preview_rows[0]["preview"] + + +def test_rag_panel_formats_debug_summary_and_breakdown(): + debug = { + "retrieval_mode": "hybrid", + "top_k": 3, + "min_score": 0.01, + "candidate_count": 8, + "returned_count": 2, + "query_terms": ["rag", "debug"], + } + result_debug = { + "score_breakdown": { + "lexical_weight": 0.7, + "lexical_score": 3.5, + "lexical_normalized": 1.0, + "vector_score": 0.25, + "combined_score": 0.775, + } + } + + assert format_rag_debug_summary(debug) == ( + "mode=hybrid; top_k=3; min_score=0.01; candidates=8; returned=2; terms=debug, rag" + ) + assert "combined_score=0.775" in format_score_breakdown(result_debug) + + def test_build_rag_debug_marks_empty_queries(tmp_path): path = tmp_path / "notes.md" path.write_text("Local retrieval.", encoding="utf-8") diff --git a/tests/test_rag_backends.py b/tests/test_rag_backends.py new file mode 100644 index 0000000..9ee87c8 --- /dev/null +++ b/tests/test_rag_backends.py @@ -0,0 +1,135 @@ +from __future__ import annotations + +import pytest + +from src.rag.backends import ( + LocalVectorBackend, + get_vector_backend, + get_vector_backend_from_env, + vector_backend_config_from_env, +) +from src.rag.chroma_backend import ChromaVectorBackend +from src.rag.index import build_rag_index + + +class _FakeCollection: + def __init__(self) -> None: + self.upserts = [] + self.response = { + "ids": [["chunk-1"]], + "distances": [[0.2]], + "documents": [["Stored chunk text"]], + "metadatas": [ + [ + { + "document_hash": "doc-hash", + "source_path": "notes.md", + "title": "notes", + "chunk_index": 0, + "start_line": 1, + "end_line": 2, + "file_type": "md", + "char_count": 17, + } + ] + ], + } + + def upsert(self, **kwargs) -> None: + self.upserts.append(kwargs) + + def query(self, **kwargs): + self.last_query = kwargs + return self.response + + +class _FakeClient: + def __init__(self) -> None: + self.collection = _FakeCollection() + self.collection_names = [] + + def get_or_create_collection(self, name: str): + self.collection_names.append(name) + return self.collection + + +def test_local_vector_backend_queries_index(tmp_path): + path = tmp_path / "notes.md" + path.write_text("Vector backend retrieves local chunks.", encoding="utf-8") + index = build_rag_index([path], max_chars=200, overlap_chars=0) + backend = LocalVectorBackend() + + results = backend.query(index, "local chunks", top_k=1, min_score=0.0) + + assert backend.status().available is True + assert backend.status().embedding_provider == "local_hash" + assert results + assert results[0].chunk.source_path == str(path) + + +def test_get_vector_backend_rejects_unknown_backend(): + with pytest.raises(ValueError, match="Unsupported vector backend"): + get_vector_backend("pinecone") + + +def test_vector_backend_config_reads_environment(monkeypatch): + monkeypatch.setenv("RAG_VECTOR_BACKEND", "chroma") + monkeypatch.setenv("RAG_CHROMA_PATH", "logs/test_chroma") + monkeypatch.setenv("RAG_CHROMA_COLLECTION", "study_agent_test") + + config = vector_backend_config_from_env() + + assert config == { + "name": "chroma", + "path": "logs/test_chroma", + "collection": "study_agent_test", + } + + +def test_get_vector_backend_from_env_defaults_to_local(monkeypatch): + monkeypatch.delenv("RAG_VECTOR_BACKEND", raising=False) + + backend = get_vector_backend_from_env() + + assert isinstance(backend, LocalVectorBackend) + + +def test_chroma_backend_upserts_chunks_with_embeddings(tmp_path): + path = tmp_path / "notes.md" + path.write_text("Chroma adapter stores local chunks.", encoding="utf-8") + index = build_rag_index([path], max_chars=200, overlap_chars=0) + fake_client = _FakeClient() + backend = ChromaVectorBackend( + path=tmp_path / "chroma", + collection_name="study_agent_test", + client=fake_client, + ) + + backend.upsert_index(index) + + assert fake_client.collection_names == ["study_agent_test"] + upsert = fake_client.collection.upserts[0] + assert upsert["ids"] == [index.chunks[0].chunk_id] + assert len(upsert["embeddings"][0]) == 256 + assert upsert["documents"] == [index.chunks[0].text] + assert upsert["metadatas"][0]["source_path"] == str(path) + + +def test_chroma_backend_query_reconstructs_search_results(tmp_path): + fake_client = _FakeClient() + backend = ChromaVectorBackend( + path=tmp_path / "chroma", + collection_name="study_agent_test", + client=fake_client, + ) + path = tmp_path / "notes.md" + path.write_text("Local placeholder.", encoding="utf-8") + index = build_rag_index([path], max_chars=200, overlap_chars=0) + + results = backend.query(index, "stored chunk", top_k=1, min_score=0.1) + + assert backend.status().available is True + assert fake_client.collection.last_query["n_results"] == 1 + assert results[0].score == 0.8 + assert results[0].chunk.source_path == "notes.md" + assert results[0].chunk.start_line == 1