Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,15 @@ In the above script modify,
- `results_save_path` to specify where the index and metadata file will be saved
- `embedded_chunks_json_file` to specify where the `embedded_chunks.json` is present

### Context Retriever Module

Run `python .\atlas\core\retriever\context.py`

In the above script modify,
- `results_load_path` to specify the directory from which the index and metadata files are loaded
- `user_query` to specify the user prompt/query
- `k` to specify the number of most relevant chunks as the context for the user query

### Tests

Run unit tests via VS Code
Expand Down
83 changes: 62 additions & 21 deletions atlas/core/indexer/run_indexer.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,18 +8,21 @@

LOGGER = LoggerConfig().logger

if __name__ == "__main__":
LOGGER.info("Running indexer to save the chunk embeddings to a vector index")
# this is the root folder which saves the following 2 files:
# 1. index file
# 2. metadata json
results_save_path = r"D:\\Deep learning\\Atlas\\Resources"
store = FaissVectorStore(
dim=384
) # the encoder model we used generated embeddings of size 384
embedded_chunks_json_file = (
r"D:\\Deep learning\\Atlas\\Resources\\embedded_chunks.json"
)

def build_and_save_index(
store: FaissVectorStore, results_save_path: str, embedded_chunks_json_file: str
) -> None:
"""
Build the vector index using all the chunk embeddings and save the results.
Save the following two files:
1. index file -> index.faiss
2. chunk metadata -> metadata.json

Args:
store (FaissVectorStore): Instance of FAISS Vector Store from Facebook AI Semantic Search.
results_save_path (str): Directory to save the above mentioned two result files.
embedded_chunks_json_file (str): The path to the embedded chunks json file.
"""
embedded_chunks = load_embedded_chunks(embedded_chunks_json_file)

store.add(
Expand All @@ -29,19 +32,57 @@

store.save(results_save_path)

# Sanity checks
# query_text = "Role of luck in life" # exact phrase query
# query_text = "Folks who inspire me" # paraphrasing
query_text = (
"Journey is more important that the final result in life" # paraphrasing
)
encoder_config_path = os.path.join(
os.getcwd(), "atlas", "core", "configs", "sentence_transformer_config.yaml"
)

def sanity_check(
    store: FaissVectorStore, query_text: str, encoder_config_path: str
) -> None:
    """
    Smoke-test the retrieval path by logging the top matches for one query.

    The query text is embedded with ``generate_embedding``, the store is asked for
    ``k=5`` results, and the number of hits plus each hit's score and chunk id are
    written to the log.

    Args:
        store (FaissVectorStore): Vector store instance to run the search against.
        query_text (str): User query to retrieve context for.
        encoder_config_path (str): Path to the encoder configuration file.
    """
    hits = store.search(generate_embedding(query_text, encoder_config_path), k=5)

    # how many results came back
    LOGGER.info(len(hits))

    for hit in hits:
        LOGGER.info(f"score: {hit['score']}")
        LOGGER.info(f"Note title: {hit['chunk_id']}")
        LOGGER.info("===\n")


if __name__ == "__main__":
    LOGGER.info("Running indexer to save the chunk embeddings to a vector index")

    # Root folder that receives the two output files:
    #   1. index file
    #   2. metadata json
    results_save_path = r"D:\\Deep learning\\Atlas\\Resources"
    embedded_chunks_json_file = (
        r"D:\\Deep learning\\Atlas\\Resources\\embedded_chunks.json"
    )

    # the encoder model we used generated embeddings of size 384
    store = FaissVectorStore(dim=384)

    build_and_save_index(store, results_save_path, embedded_chunks_json_file)

    # Optional sanity checks for testing; set to True to run a sample query
    do_sanity_test = False

    if do_sanity_test:
        # Other queries that were tried:
        #   "Role of luck in life"   (exact phrase query)
        #   "Folks who inspire me"   (paraphrasing)
        query_text = (
            "Journey is more important that the final result in life"  # paraphrasing
        )

        config_parts = ("atlas", "core", "configs", "sentence_transformer_config.yaml")
        encoder_config_path = os.path.join(os.getcwd(), *config_parts)

        sanity_check(store, query_text, encoder_config_path)
81 changes: 81 additions & 0 deletions atlas/core/retriever/context.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
import os

from atlas.utils.embedder_utils import generate_embedding
from atlas.core.indexer.faiss_vector_store import FaissVectorStore

from atlas.utils.logger import LoggerConfig

LOGGER = LoggerConfig().logger


def retrieve_context(results_load_path: str, user_query: str, k: int = 5) -> str | None:
    """
    Retrieve the context for the user query. The context is the concatenated text of the
    chunks returned by the vector store search for the embedded query.

    Args:
        results_load_path (str): Directory from which the saved index file and metadata
            json are loaded.
        user_query (str): User query to retrieve context for.
        k (int): Number of most similar embeddings (aka neighbors) to the query vector.
            Default is 5.

    Returns:
        str | None: The chunk texts, each under a ``[Context N]`` header (N starts at 1)
        and separated by a blank line. An empty string when the search yields no
        results. None when loading the store or searching it raises an exception.

    Note:
        Exceptions raised while embedding the query are not caught here and propagate
        to the caller.
    """

    # 1. load the vector store
    store = FaissVectorStore(
        dim=384
    )  # the encoder model we used generated embeddings of size 384
    try:
        store.load(results_load_path)
    except Exception as e:
        LOGGER.error(f"Error while retrieving context : {repr(e)}")
        return None

    # 2. embed user query
    encoder_config_path = os.path.join(
        os.getcwd(), "atlas", "core", "configs", "sentence_transformer_config.yaml"
    )
    query_vector = generate_embedding(user_query, encoder_config_path)

    # 3. search for k top neighbors
    try:
        results = store.search(query_vector, k)
    except Exception as e:
        LOGGER.error(f"Error while retrieving context : {repr(e)}")
        return None

    # 4. build and return context
    # Truthiness check (rather than `== []`) so any empty/None result maps to "".
    if not results:
        return ""

    return "\n\n".join(
        f"[Context {rank}]\n{result['text'].strip()}"
        for rank, result in enumerate(results, start=1)
    )


if __name__ == "__main__":
    LOGGER.info("-" * 20)
    LOGGER.info("Retrieve context for user prompt")

    # Root folder holding the two files to load:
    #   1. index file
    #   2. metadata json
    results_load_path = r"D:\\Deep learning\\Atlas\\Resources"

    # retrieve 3 most relevant chunks as the context for the user query
    k = 3
    user_query = (
        "Journey is more important that the final result in life"  # paraphrasing
    )

    context = retrieve_context(results_load_path, user_query, k)

    # None signals a failure, "" signals an empty search result,
    # anything else is the assembled context text.
    if context is None:
        LOGGER.error("Error while trying to find context!!!")
    elif context == "":
        LOGGER.warning("No context found!!!")
    else:
        LOGGER.info("context found: \n\n")
        LOGGER.info(context)
35 changes: 35 additions & 0 deletions tests/unittests/scripts/test_context.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
import json
import pytest
import numpy as np
from pathlib import Path
import faiss

from atlas.core.indexer.faiss_vector_store import FaissVectorStore
from atlas.core.retriever.context import retrieve_context
from atlas.utils.embedder_utils import load_embedded_chunks


@pytest.mark.unittest
@pytest.mark.runonci
def test_retrieve_context(tmp_path: Path, dummy_embedded_chunk_data_path: Path) -> None:
    """
    Test context retrieval functionality given a user query.

    Builds a one-vector store from the dummy embedded chunks, saves it under the
    temporary directory, then checks that ``retrieve_context`` loads it back and
    returns a non-empty context for the query.

    Args:
        tmp_path (Path): Temporary path provided by pytest.
        dummy_embedded_chunk_data_path (Path): The path to the dummy embedded chunks json file.
    """
    # Arrange: a single 384-dimensional vector (values 0..383) paired with the dummy chunks
    vectors = np.array([list(range(384))])
    embedded_chunks = load_embedded_chunks(str(dummy_embedded_chunk_data_path))
    store = FaissVectorStore(dim=384)
    store.add(vectors, embedded_chunks)
    results_save_path = tmp_path / "Results"
    store.save(str(results_save_path))

    # Act
    user_query = "test note is used for testing"
    k = 1
    context = retrieve_context(
        results_load_path=str(results_save_path), user_query=user_query, k=k
    )

    # Assert: None means load/search failed, "" means nothing was retrieved
    assert context is not None
    assert len(context) != 0
    # every non-empty context begins with the header of the first chunk
    assert context.startswith("[Context 1]\n")