Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 23 additions & 1 deletion integrations/langchain_retriever.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,8 @@
"""

from __future__ import annotations

import asyncio
from concurrent.futures import ThreadPoolExecutor
import sys
import os
from typing import List, Optional, Any
Expand Down Expand Up @@ -68,6 +69,7 @@ def __init__(
top_k: int = 5,
config: Optional[VortexRAGConfig] = None,
verbose: bool = False,
max_workers: int = 4,
):
if config is not None:
self._config = config
Expand All @@ -78,6 +80,7 @@ def __init__(
self._top_k = top_k
self._rag: Optional[VortexRAG] = None
self._texts: list[str] = []
self._executor = ThreadPoolExecutor(max_workers=max_workers)

# ------------------------------------------------------------------
# Document management
Expand Down Expand Up @@ -187,6 +190,8 @@ def model_post_init(self, __context: Any) -> None:
self._rag = None
if not hasattr(self, "_texts"):
self._texts = []
if not hasattr(self, "_executor"):
self._executor = ThreadPoolExecutor(max_workers=4)

def _get_relevant_documents(
self,
Expand All @@ -195,7 +200,24 @@ def _get_relevant_documents(
run_manager: Optional["CallbackManagerForRetrieverRun"] = None,
) -> "List[Document]":
return self.get_relevant_documents(query) # type: ignore[return-value]

async def aget_relevant_documents(self, query):
loop = asyncio.get_event_loop()
return await loop.run_in_executor(self._executor, self.get_relevant_documents, query)

async def ainvoke(self, query, **kwargs):
return await self.aget_relevant_documents(query)

def batch(self, queries):
futures = [self._executor.submit(self.get_relevant_documents, q) for q in queries]
return [f.result() for f in futures]

async def abatch(self, queries):
return await asyncio.gather(*[self.aget_relevant_documents(q) for q in queries])
else:
# Fallback: plain class, no LangChain dependency
VortexRAGRetriever = _VortexRAGRetrieverBase # type: ignore[misc,assignment]




98 changes: 98 additions & 0 deletions tests/test_async_batch_retriever.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
"""
Tests for async and batch extensions of VortexRAGRetriever.
"""

import asyncio
import pytest
import sys
import os

sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from integrations.langchain_retriever import VortexRAGRetriever

SAMPLE_DOCS = [
"Sepsis is a life-threatening condition caused by the body's extreme response to infection. "
"It occurs when chemicals released into the bloodstream to fight infection trigger widespread inflammation.",

"Treatment of sepsis requires immediate administration of antibiotics and intravenous fluids. "
"Vasopressors may be needed to maintain blood pressure in cases of septic shock.",

"The sequential organ failure assessment (SOFA) score is used to track a patient's status "
"and predict outcomes in intensive care units.",

"Machine learning models have been applied to predict sepsis onset from electronic health records, "
"achieving AUROC scores above 0.85 on held-out test sets.",
]


@pytest.fixture
def retriever():
r = VortexRAGRetriever(domain="medical", top_k=3)
r.add_documents(SAMPLE_DOCS)
return r


class TestAsyncRetrieval:
def test_aget_relevant_documents_returns_list(self, retriever):
results = asyncio.run(retriever.aget_relevant_documents("What causes sepsis?"))
assert isinstance(results, list)
assert len(results) > 0

def test_ainvoke_alias(self, retriever):
results = asyncio.run(retriever.ainvoke("sepsis treatment"))
assert isinstance(results, list)

def test_async_respects_top_k(self, retriever):
results = asyncio.run(retriever.aget_relevant_documents("organ failure"))
assert len(results) <= 3

def test_async_result_has_page_content(self, retriever):
results = asyncio.run(retriever.aget_relevant_documents("infection"))
for doc in results:
if isinstance(doc, dict):
assert "page_content" in doc
else:
assert hasattr(doc, "page_content")

def test_async_result_has_metadata(self, retriever):
results = asyncio.run(retriever.aget_relevant_documents("antibiotics"))
for doc in results:
meta = doc.metadata if hasattr(doc, "metadata") else doc["metadata"]
assert "phi_score" in meta
assert "rank" in meta

def test_async_raises_without_documents(self):
r = VortexRAGRetriever(domain="medical", top_k=3)
with pytest.raises(ValueError, match="No documents indexed"):
asyncio.run(r.aget_relevant_documents("query"))

def test_async_matches_sync_length(self, retriever):
query = "What causes sepsis?"
sync_results = retriever.get_relevant_documents(query)
async_results = asyncio.run(retriever.aget_relevant_documents(query))
assert len(sync_results) == len(async_results)

def test_abatch_concurrent(self, retriever):
queries = ["What causes sepsis?", "How is it treated?", "SOFA score"]
results = asyncio.run(retriever.abatch(queries))
assert len(results) == 3
for doc_list in results:
assert isinstance(doc_list, list)

def test_abatch_empty_list(self, retriever):
results = asyncio.run(retriever.abatch([]))
assert results == []


class TestBatchRetrieval:
def test_batch_returns_list_of_lists(self, retriever):
results = retriever.batch(["What causes sepsis?", "sepsis treatment"])
assert len(results) == 2
for doc_list in results:
assert isinstance(doc_list, list)

def test_batch_order_preserved(self, retriever):
queries = ["infection", "antibiotics", "organ failure"]
results = retriever.batch(queries)
assert len(results) == len(queries)