diff --git a/README.md b/README.md index e1c3121..d1e6e0c 100644 --- a/README.md +++ b/README.md @@ -166,6 +166,15 @@ In the above script modify, - `results_save_path` to specify where the index and metadata file will be saved - `embedded_chunks_json_file` to specify where the `embedded_chunks.json` is present +### Context Retriever Module + +Run `python .\atlas\core\retriever\context.py` + +In the above script modify, +- `results_load_path` to specify where the index and metadata file are present and will be loaded from +- `user_query` to specify the user prompt/query +- `k` to specify the number of most relevant chunks as the context for the user query + ### Tests Run unit tests via VS Code diff --git a/atlas/core/indexer/run_indexer.py b/atlas/core/indexer/run_indexer.py index f76811a..62ce03b 100644 --- a/atlas/core/indexer/run_indexer.py +++ b/atlas/core/indexer/run_indexer.py @@ -8,18 +8,21 @@ LOGGER = LoggerConfig().logger -if __name__ == "__main__": - LOGGER.info("Running indexer to save the chunk embeddings to a vector index") - # this is the root folder which saves the following 2 files: - # 1. index file - # 2. metadata json - results_save_path = r"D:\\Deep learning\\Atlas\\Resources" - store = FaissVectorStore( - dim=384 - ) # the encoder model we used generated embeddings of size 384 - embedded_chunks_json_file = ( - r"D:\\Deep learning\\Atlas\\Resources\\embedded_chunks.json" - ) + +def build_and_save_index( + store: FaissVectorStore, results_save_path: str, embedded_chunks_json_file: str +) -> None: + """ + Build the vector index using all the chunk embeddings and save the results. + Save the following two files: + 1. index file -> index.faiss + 2. chunk metadata -> metadata.json + + Args: + store (FaissVectorStore): Instance of FAISS Vector Store from Facebook AI Semantic Search. + results_save_path (str): Directory to save the above mentioned two result files. + embedded_chunks_json_file (str): The path to the embedded chunks json file. + """ embedded_chunks = load_embedded_chunks(embedded_chunks_json_file) store.add( @@ -29,15 +32,19 @@ store.save(results_save_path) - # Sanity checks - # query_text = "Role of luck in life" # exact phrase query - # query_text = "Folks who inspire me" # paraphrasing - query_text = ( - "Journey is more important that the final result in life" # paraphrasing - ) - encoder_config_path = os.path.join( - os.getcwd(), "atlas", "core", "configs", "sentence_transformer_config.yaml" - ) + +def sanity_check( + store: FaissVectorStore, query_text: str, encoder_config_path: str +) -> None: + """ + Retrieve 5 most relevant chunks for given user query. Used for sanity testing retrieval + process. + + Args: + store (FaissVectorStore): Instance of FAISS Vector Store from Facebook AI Semantic Search. + query_text (str): User query to retrieve context for. + encoder_config_path (str): Path to the encoder configuration file. + """ query_vector = generate_embedding(query_text, encoder_config_path) results = store.search(query_vector, k=5) LOGGER.info(len(results)) @@ -45,3 +52,37 @@ LOGGER.info(f"score: {res['score']}") LOGGER.info(f"Note title: {res['chunk_id']}") LOGGER.info("===\n") + + +if __name__ == "__main__": + LOGGER.info("Running indexer to save the chunk embeddings to a vector index") + # this is the root folder which saves the following 2 files: + # 1. index file + # 2. metadata json + results_save_path = r"D:\\Deep learning\\Atlas\\Resources" + + store = FaissVectorStore( + dim=384 + ) # the encoder model we used generated embeddings of size 384 + + embedded_chunks_json_file = ( + r"D:\\Deep learning\\Atlas\\Resources\\embedded_chunks.json" + ) + + build_and_save_index(store, results_save_path, embedded_chunks_json_file) + + # Optional sanity checks for testing + do_sanity_test = False + + if do_sanity_test: + # query_text = "Role of luck in life" # exact phrase query + # query_text = "Folks who inspire me" # paraphrasing + query_text = ( + "Journey is more important that the final result in life" # paraphrasing + ) + + encoder_config_path = os.path.join( + os.getcwd(), "atlas", "core", "configs", "sentence_transformer_config.yaml" + ) + + sanity_check(store, query_text, encoder_config_path) diff --git a/atlas/core/retriever/context.py b/atlas/core/retriever/context.py new file mode 100644 index 0000000..00ad269 --- /dev/null +++ b/atlas/core/retriever/context.py @@ -0,0 +1,81 @@ +import os + +from atlas.utils.embedder_utils import generate_embedding +from atlas.core.indexer.faiss_vector_store import FaissVectorStore + +from atlas.utils.logger import LoggerConfig + +LOGGER = LoggerConfig().logger + + +def retrieve_context(results_load_path: str, user_query: str, k: int = 5) -> str | None: + """ + Retrieve the context for the user query. The context is the concatenated text of the most + relevant chunks associated with the user query. + + Args: + results_load_path (str): Directory to load the above mentioned two result files from. + user_query (str): User query to retrieve context for. + k (int): Number of most similar embeddings (aka neighbors) to the query vector. + Default is 5. + + Returns: + str | None: The context associated with the user query. + """ + + # 1. load the vector store + store = FaissVectorStore( + dim=384 + ) # the encoder model we used generated embeddings of size 384 + try: + store.load(results_load_path) + except Exception as e: + LOGGER.error(f"Error while retrieving context : {repr(e)}") + return None + + # 2. embded user query + encoder_config_path = os.path.join( + os.getcwd(), "atlas", "core", "configs", "sentence_transformer_config.yaml" + ) + query_vector = generate_embedding(user_query, encoder_config_path) + + # 3. search for k top neighbors + try: + results = store.search(query_vector, k) + except Exception as e: + LOGGER.error(f"Error while retrieving context : {repr(e)}") + return None + + # 4. build and return context + context_parts = [] + if results == []: + return "" + + for rank, result in enumerate(results): + context_parts.append(f"[Context {rank + 1}]\n{result['text'].strip()}") + + return "\n\n".join(context_parts) + + +if __name__ == "__main__": + LOGGER.info("-" * 20) + LOGGER.info("Retrieve context for user prompt") + + # this is the root folder which loads the following 2 files: + # 1. index file + # 2. metadata json + results_load_path = r"D:\\Deep learning\\Atlas\\Resources" + user_query = ( + "Journey is more important that the final result in life" # paraphrasing + ) + k = 3 # retrieve 3 most relevant chunks as the context for the user query + + context = retrieve_context(results_load_path, user_query, k) + + if context: + LOGGER.info("context found: \n\n") + LOGGER.info(context) + elif context == "": + LOGGER.warning("No context found!!!") + else: + LOGGER.error("Error while trying to find context!!!") diff --git a/tests/unittests/scripts/test_context.py b/tests/unittests/scripts/test_context.py new file mode 100644 index 0000000..bcd8726 --- /dev/null +++ b/tests/unittests/scripts/test_context.py @@ -0,0 +1,35 @@ +import json +import pytest +import numpy as np +from pathlib import Path +import faiss + +from atlas.core.indexer.faiss_vector_store import FaissVectorStore +from atlas.core.retriever.context import retrieve_context +from atlas.utils.embedder_utils import load_embedded_chunks + + +@pytest.mark.unittest +@pytest.mark.runonci +def test_retrieve_context(tmp_path: Path, dummy_embedded_chunk_data_path: Path) -> None: + """ + Test context retrieval functionality given a user query. + + Args: + tmp_path (Path): Temporary path provided by pytest. + dummy_embedded_chunk_data_path (Path): The path to the dummy embedded chunks json file. + """ + vectors = np.array([[i for i in range(384)]]) + embedded_chunks = load_embedded_chunks(str(dummy_embedded_chunk_data_path)) + store = FaissVectorStore(dim=384) + store.add(vectors, embedded_chunks) + results_save_path = tmp_path / "Results" + store.save(str(results_save_path)) + + user_query = "test note is used for testing" + k = 1 + context = retrieve_context( + results_load_path=str(results_save_path), user_query=user_query, k=k + ) + + assert context != None and len(context) != 0