From 1863cb54f34ce9ef126ca858690db34e3eba3187 Mon Sep 17 00:00:00 2001 From: Vaishnavi Desai Date: Mon, 22 Jun 2026 16:05:06 +0530 Subject: [PATCH] feat: add async and batch retrieval to VortexRAGRetriever Signed-off-by: Vaishnavi Desai --- integrations/langchain_retriever.py | 24 ++++++- tests/test_async_batch_retriever.py | 98 +++++++++++++++++++++++++++++ 2 files changed, 121 insertions(+), 1 deletion(-) create mode 100644 tests/test_async_batch_retriever.py diff --git a/integrations/langchain_retriever.py b/integrations/langchain_retriever.py index 9bd8032..fb968e2 100644 --- a/integrations/langchain_retriever.py +++ b/integrations/langchain_retriever.py @@ -19,7 +19,8 @@ """ from __future__ import annotations - +import asyncio +from concurrent.futures import ThreadPoolExecutor import sys import os from typing import List, Optional, Any @@ -68,6 +69,7 @@ def __init__( top_k: int = 5, config: Optional[VortexRAGConfig] = None, verbose: bool = False, + max_workers: int = 4, ): if config is not None: self._config = config @@ -78,6 +80,7 @@ def __init__( self._top_k = top_k self._rag: Optional[VortexRAG] = None self._texts: list[str] = [] + self._executor = ThreadPoolExecutor(max_workers=max_workers) # ------------------------------------------------------------------ # Document management @@ -187,6 +190,8 @@ def model_post_init(self, __context: Any) -> None: self._rag = None if not hasattr(self, "_texts"): self._texts = [] + if not hasattr(self, "_executor"): + self._executor = ThreadPoolExecutor(max_workers=4) def _get_relevant_documents( self, @@ -195,7 +200,24 @@ def _get_relevant_documents( run_manager: Optional["CallbackManagerForRetrieverRun"] = None, ) -> "List[Document]": return self.get_relevant_documents(query) # type: ignore[return-value] + + async def aget_relevant_documents(self, query): + loop = asyncio.get_event_loop() + return await loop.run_in_executor(self._executor, self.get_relevant_documents, query) + + async def ainvoke(self, query, **kwargs): + return await self.aget_relevant_documents(query) + def batch(self, queries): + futures = [self._executor.submit(self.get_relevant_documents, q) for q in queries] + return [f.result() for f in futures] + + async def abatch(self, queries): + return await asyncio.gather(*[self.aget_relevant_documents(q) for q in queries]) else: # Fallback: plain class, no LangChain dependency VortexRAGRetriever = _VortexRAGRetrieverBase # type: ignore[misc,assignment] + + + + diff --git a/tests/test_async_batch_retriever.py b/tests/test_async_batch_retriever.py new file mode 100644 index 0000000..275dd6d --- /dev/null +++ b/tests/test_async_batch_retriever.py @@ -0,0 +1,98 @@ +""" +Tests for async and batch extensions of VortexRAGRetriever. +""" + +import asyncio +import pytest +import sys +import os + +sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + +from integrations.langchain_retriever import VortexRAGRetriever + +SAMPLE_DOCS = [ + "Sepsis is a life-threatening condition caused by the body's extreme response to infection. " + "It occurs when chemicals released into the bloodstream to fight infection trigger widespread inflammation.", + + "Treatment of sepsis requires immediate administration of antibiotics and intravenous fluids. " + "Vasopressors may be needed to maintain blood pressure in cases of septic shock.", + + "The sequential organ failure assessment (SOFA) score is used to track a patient's status " + "and predict outcomes in intensive care units.", + + "Machine learning models have been applied to predict sepsis onset from electronic health records, " + "achieving AUROC scores above 0.85 on held-out test sets.", +] + + +@pytest.fixture +def retriever(): + r = VortexRAGRetriever(domain="medical", top_k=3) + r.add_documents(SAMPLE_DOCS) + return r + + +class TestAsyncRetrieval: + def test_aget_relevant_documents_returns_list(self, retriever): + results = asyncio.run(retriever.aget_relevant_documents("What causes sepsis?")) + assert isinstance(results, list) + assert len(results) > 0 + + def test_ainvoke_alias(self, retriever): + results = asyncio.run(retriever.ainvoke("sepsis treatment")) + assert isinstance(results, list) + + def test_async_respects_top_k(self, retriever): + results = asyncio.run(retriever.aget_relevant_documents("organ failure")) + assert len(results) <= 3 + + def test_async_result_has_page_content(self, retriever): + results = asyncio.run(retriever.aget_relevant_documents("infection")) + for doc in results: + if isinstance(doc, dict): + assert "page_content" in doc + else: + assert hasattr(doc, "page_content") + + def test_async_result_has_metadata(self, retriever): + results = asyncio.run(retriever.aget_relevant_documents("antibiotics")) + for doc in results: + meta = doc.metadata if hasattr(doc, "metadata") else doc["metadata"] + assert "phi_score" in meta + assert "rank" in meta + + def test_async_raises_without_documents(self): + r = VortexRAGRetriever(domain="medical", top_k=3) + with pytest.raises(ValueError, match="No documents indexed"): + asyncio.run(r.aget_relevant_documents("query")) + + def test_async_matches_sync_length(self, retriever): + query = "What causes sepsis?" + sync_results = retriever.get_relevant_documents(query) + async_results = asyncio.run(retriever.aget_relevant_documents(query)) + assert len(sync_results) == len(async_results) + + def test_abatch_concurrent(self, retriever): + queries = ["What causes sepsis?", "How is it treated?", "SOFA score"] + results = asyncio.run(retriever.abatch(queries)) + assert len(results) == 3 + for doc_list in results: + assert isinstance(doc_list, list) + + def test_abatch_empty_list(self, retriever): + results = asyncio.run(retriever.abatch([])) + assert results == [] + + +class TestBatchRetrieval: + def test_batch_returns_list_of_lists(self, retriever): + results = retriever.batch(["What causes sepsis?", "sepsis treatment"]) + assert len(results) == 2 + for doc_list in results: + assert isinstance(doc_list, list) + + def test_batch_order_preserved(self, retriever): + queries = ["infection", "antibiotics", "organ failure"] + results = retriever.batch(queries) + assert len(results) == len(queries) \ No newline at end of file