Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 26 additions & 24 deletions src/basic_memory/repository/search_repository_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -1934,24 +1934,25 @@ def _log_vector_summary() -> None:
# Build per-search_index_row similarity scores from chunk-level results.
# Each chunk_key encodes the search_index row type and id.
# Track the best similarity per row (for ranking) and all chunks (for context).
similarity_by_si_id: dict[int, float] = {}
chunks_by_si_id: dict[int, list[tuple[float, str]]] = {}
similarity_by_si_key: dict[tuple[str, int], float] = {}
chunks_by_si_key: dict[tuple[str, int], list[tuple[float, str]]] = {}
for row in vector_rows:
chunk_key = row.get("chunk_key", "")
distance = float(row["best_distance"])
similarity = self._distance_to_similarity(distance)
chunk_text = row.get("chunk_text", "")
try:
_, si_id = self._parse_chunk_key(chunk_key)
si_type, si_id = self._parse_chunk_key(chunk_key)
except (ValueError, IndexError):
# Fallback: group by entity_id for chunks without parseable keys
continue
current = similarity_by_si_id.get(si_id)
si_key = (si_type, si_id)
current = similarity_by_si_key.get(si_key)
if current is None or similarity > current:
similarity_by_si_id[si_id] = similarity
chunks_by_si_id.setdefault(si_id, []).append((similarity, chunk_text))
similarity_by_si_key[si_key] = similarity
chunks_by_si_key.setdefault(si_key, []).append((similarity, chunk_text))

if not similarity_by_si_id:
if not similarity_by_si_key:
hydrate_ms = (time.perf_counter() - hydrate_start) * 1000
_log_vector_summary()
return []
Expand All @@ -1962,17 +1963,17 @@ def _log_vector_summary() -> None:
min_similarity if min_similarity is not None else self._semantic_min_similarity
)
if effective_min_similarity > 0.0:
similarity_by_si_id = {
k: v for k, v in similarity_by_si_id.items() if v >= effective_min_similarity
similarity_by_si_key = {
k: v for k, v in similarity_by_si_key.items() if v >= effective_min_similarity
}
if not similarity_by_si_id:
if not similarity_by_si_key:
hydrate_ms = (time.perf_counter() - hydrate_start) * 1000
_log_vector_summary()
return []

# Fetch the actual search_index rows
si_ids = list(similarity_by_si_id.keys())
search_index_rows = await self._fetch_search_index_rows_by_ids(si_ids)
si_keys = list(similarity_by_si_key.keys())
search_index_rows = await self._fetch_search_index_rows_by_ids(si_keys)

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Badge Carry typed keys through hybrid fusion

When this vector path is invoked from _search_hybrid (retrieval_mode=hybrid), the typed duplicate ids returned here are immediately collapsed because hybrid fusion still stores rows_by_id, fts_scores, vec_scores, and fused_scores by bare row.id. A query whose vector hits include both entity:4:0 and relation:4:0 will now return both in vector mode, but hybrid mode overwrites one of them, so the collision fix does not apply to MCP/API hybrid searches. Please carry (row.type, row.id) through the hybrid fusion maps as well.

Useful? React with 👍 / 👎.


# Apply optional filters if requested
filter_requested = any(
Expand Down Expand Up @@ -2003,16 +2004,16 @@ def _log_vector_summary() -> None:
limit=VECTOR_FILTER_SCAN_LIMIT,
offset=0,
)
# Use (id, type) tuples to avoid collisions between different
# Use (type, id) tuples to avoid collisions between different
# search_index row types that share the same auto-increment id.
allowed_keys = {(row.id, row.type) for row in filtered_rows if row.id is not None}
allowed_keys = {(row.type, row.id) for row in filtered_rows if row.id is not None}
search_index_rows = {
k: v for k, v in search_index_rows.items() if (v.id, v.type) in allowed_keys
k: v for k, v in search_index_rows.items() if (v.type, v.id) in allowed_keys
}

ranked_rows: list[SearchIndexRow] = []
for si_id, similarity in similarity_by_si_id.items():
row = search_index_rows.get(si_id)
for si_key, similarity in similarity_by_si_key.items():
row = search_index_rows.get(si_key)
if row is None:
continue

Expand All @@ -2022,7 +2023,7 @@ def _log_vector_summary() -> None:
if content_snippet and len(content_snippet) <= SMALL_NOTE_CONTENT_LIMIT:
matched_chunk_text = content_snippet
else:
si_chunks = chunks_by_si_id.get(si_id, [])
si_chunks = chunks_by_si_key.get(si_key, [])
si_chunks.sort(key=lambda c: c[0], reverse=True)
top_texts = [text for _, text in si_chunks[:TOP_CHUNKS_PER_RESULT]]
matched_chunk_text = "\n---\n".join(top_texts) if top_texts else None
Expand Down Expand Up @@ -2087,11 +2088,12 @@ async def _fetch_entity_rows_by_ids(self, entity_ids: list[int]) -> dict[int, Se
return result

async def _fetch_search_index_rows_by_ids(
self, row_ids: list[int]
) -> dict[int, SearchIndexRow]:
"""Fetch search_index rows by their primary key (id), any type."""
if not row_ids:
self, row_keys: list[tuple[str, int]]
) -> dict[tuple[str, int], SearchIndexRow]:
"""Fetch search_index rows by primary key and return them keyed by (type, id)."""
if not row_keys:
return {}
row_ids = sorted({row_id for _, row_id in row_keys})
placeholders = ",".join(f":id_{idx}" for idx in range(len(row_ids)))
params: dict[str, Any] = {
**{f"id_{idx}": rid for idx, rid in enumerate(row_ids)},
Expand All @@ -2106,11 +2108,11 @@ async def _fetch_search_index_rows_by_ids(
WHERE project_id = :project_id
AND id IN ({placeholders})
"""
result: dict[int, SearchIndexRow] = {}
result: dict[tuple[str, int], SearchIndexRow] = {}
async with db.scoped_session(self.session_maker) as session:
row_result = await session.execute(text(sql), params)
for row in row_result.fetchall():
result[row.id] = SearchIndexRow(
result[(row.type, row.id)] = SearchIndexRow(
project_id=self.project_id,
id=row.id,
title=row.title,
Expand Down
2 changes: 1 addition & 1 deletion tests/repository/test_vector_pagination.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ async def test_page1_scores_gte_page2_scores():

repo._embedding_provider = _EmbeddingProvider()

fake_index_rows = {i: FakeRow(id=i) for i in range(20)}
fake_index_rows = {("entity", i): FakeRow(id=i) for i in range(20)}

async def run_page(offset, limit):
with (
Expand Down
62 changes: 54 additions & 8 deletions tests/repository/test_vector_threshold.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ async def test_threshold_zero_returns_all():
repo,
"_fetch_search_index_rows_by_ids",
new_callable=AsyncMock,
return_value={i: FakeRow(id=i) for i in range(3)},
return_value={("entity", i): FakeRow(id=i) for i in range(3)},
),
):
results = await repo._search_vector_only(**COMMON_SEARCH_KWARGS)
Expand Down Expand Up @@ -192,7 +192,7 @@ async def test_threshold_filters_low_scores():
"_fetch_search_index_rows_by_ids",
new_callable=AsyncMock,
# Only entity_0 (score=0.9) passes the threshold; the fetch only gets id 0
return_value={0: FakeRow(id=0)},
return_value={("entity", 0): FakeRow(id=0)},
),
):
results = await repo._search_vector_only(**COMMON_SEARCH_KWARGS)
Expand Down Expand Up @@ -255,7 +255,7 @@ async def test_per_query_min_similarity_overrides_instance_default():
repo,
"_fetch_search_index_rows_by_ids",
new_callable=AsyncMock,
return_value={i: FakeRow(id=i) for i in range(3)},
return_value={("entity", i): FakeRow(id=i) for i in range(3)},
),
):
# Override to 0.0 → all results pass through despite instance default of 0.6
Expand Down Expand Up @@ -289,7 +289,7 @@ async def test_per_query_min_similarity_tightens_threshold():
"_fetch_search_index_rows_by_ids",
new_callable=AsyncMock,
# Only id=0 (score=0.9) will be fetched after filtering
return_value={0: FakeRow(id=0)},
return_value={("entity", 0): FakeRow(id=0)},
),
):
# Override to 0.8 → only score=0.9 passes
Expand Down Expand Up @@ -321,7 +321,7 @@ async def test_matched_chunk_text_populated_on_vector_results():
repo,
"_fetch_search_index_rows_by_ids",
new_callable=AsyncMock,
return_value={i: FakeRow(id=i) for i in range(2)},
return_value={("entity", i): FakeRow(id=i) for i in range(2)},
),
):
results = await repo._search_vector_only(**COMMON_SEARCH_KWARGS)
Expand All @@ -333,6 +333,52 @@ async def test_matched_chunk_text_populated_on_vector_results():
assert results[1].matched_chunk_text == "chunk text for entity:1:0"


@pytest.mark.asyncio
async def test_entity_and_relation_with_same_id_both_returned():
"""Vector hydration keeps row type with id so entity/relation ids cannot collide."""
repo = ConcreteSearchRepo()
repo._semantic_min_similarity = 0.0

fake_rows = [
{
"chunk_key": "entity:4:0",
"best_distance": (1.0 / 0.9) - 1.0,
"chunk_text": "entity chunk",
},
{
"chunk_key": "relation:4:0",
"best_distance": (1.0 / 0.8) - 1.0,
"chunk_text": "relation chunk",
},
]

mock_embed = AsyncMock(return_value=[0.0] * 384)
repo._embedding_provider = _fake_embedding_provider(mock_embed)

mock_fetch = AsyncMock(
return_value={
("entity", 4): FakeRow(id=4, type="entity"),
("relation", 4): FakeRow(id=4, type="relation"),
}
)

with (
patch(
"basic_memory.repository.search_repository_base.db.scoped_session",
fake_scoped_session,
),
patch.object(repo, "_ensure_vector_tables", new_callable=AsyncMock),
patch.object(repo, "_prepare_vector_session", new_callable=AsyncMock),
patch.object(repo, "_run_vector_query", new_callable=AsyncMock, return_value=fake_rows),
patch.object(repo, "_fetch_search_index_rows_by_ids", mock_fetch),
):
results = await repo._search_vector_only(**COMMON_SEARCH_KWARGS)

mock_fetch.assert_awaited_once_with([("entity", 4), ("relation", 4)])
assert [(result.type, result.id) for result in results] == [("entity", 4), ("relation", 4)]
assert [result.matched_chunk_text for result in results] == ["entity chunk", "relation chunk"]


def _make_multi_chunk_vector_rows(si_id: int, scores: list[float]) -> list[dict]:
"""Build multiple fake vector chunks for a single search_index row.

Expand Down Expand Up @@ -379,7 +425,7 @@ async def test_top_n_chunks_joined_in_matched_chunk_text():
repo,
"_fetch_search_index_rows_by_ids",
new_callable=AsyncMock,
return_value={0: FakeRow(id=0, content_snippet=large_content)},
return_value={("entity", 0): FakeRow(id=0, content_snippet=large_content)},
),
):
results = await repo._search_vector_only(**COMMON_SEARCH_KWARGS)
Expand Down Expand Up @@ -423,7 +469,7 @@ async def test_small_note_returns_full_content_as_matched_chunk():
repo,
"_fetch_search_index_rows_by_ids",
new_callable=AsyncMock,
return_value={0: FakeRow(id=0, content_snippet=small_content)},
return_value={("entity", 0): FakeRow(id=0, content_snippet=small_content)},
),
):
results = await repo._search_vector_only(**COMMON_SEARCH_KWARGS)
Expand Down Expand Up @@ -457,7 +503,7 @@ async def test_large_note_returns_chunks_not_full_content():
repo,
"_fetch_search_index_rows_by_ids",
new_callable=AsyncMock,
return_value={0: FakeRow(id=0, content_snippet=large_content)},
return_value={("entity", 0): FakeRow(id=0, content_snippet=large_content)},
),
):
results = await repo._search_vector_only(**COMMON_SEARCH_KWARGS)
Expand Down
Loading