diff --git a/src/basic_memory/repository/search_repository_base.py b/src/basic_memory/repository/search_repository_base.py index 39c03227..183ce8d2 100644 --- a/src/basic_memory/repository/search_repository_base.py +++ b/src/basic_memory/repository/search_repository_base.py @@ -1934,24 +1934,25 @@ def _log_vector_summary() -> None: # Build per-search_index_row similarity scores from chunk-level results. # Each chunk_key encodes the search_index row type and id. # Track the best similarity per row (for ranking) and all chunks (for context). - similarity_by_si_id: dict[int, float] = {} - chunks_by_si_id: dict[int, list[tuple[float, str]]] = {} + similarity_by_si_key: dict[tuple[str, int], float] = {} + chunks_by_si_key: dict[tuple[str, int], list[tuple[float, str]]] = {} for row in vector_rows: chunk_key = row.get("chunk_key", "") distance = float(row["best_distance"]) similarity = self._distance_to_similarity(distance) chunk_text = row.get("chunk_text", "") try: - _, si_id = self._parse_chunk_key(chunk_key) + si_type, si_id = self._parse_chunk_key(chunk_key) except (ValueError, IndexError): # Fallback: group by entity_id for chunks without parseable keys continue - current = similarity_by_si_id.get(si_id) + si_key = (si_type, si_id) + current = similarity_by_si_key.get(si_key) if current is None or similarity > current: - similarity_by_si_id[si_id] = similarity - chunks_by_si_id.setdefault(si_id, []).append((similarity, chunk_text)) + similarity_by_si_key[si_key] = similarity + chunks_by_si_key.setdefault(si_key, []).append((similarity, chunk_text)) - if not similarity_by_si_id: + if not similarity_by_si_key: hydrate_ms = (time.perf_counter() - hydrate_start) * 1000 _log_vector_summary() return [] @@ -1962,17 +1963,17 @@ def _log_vector_summary() -> None: min_similarity if min_similarity is not None else self._semantic_min_similarity ) if effective_min_similarity > 0.0: - similarity_by_si_id = { - k: v for k, v in similarity_by_si_id.items() if v >= effective_min_similarity + similarity_by_si_key = { + k: v for k, v in similarity_by_si_key.items() if v >= effective_min_similarity } - if not similarity_by_si_id: + if not similarity_by_si_key: hydrate_ms = (time.perf_counter() - hydrate_start) * 1000 _log_vector_summary() return [] # Fetch the actual search_index rows - si_ids = list(similarity_by_si_id.keys()) - search_index_rows = await self._fetch_search_index_rows_by_ids(si_ids) + si_keys = list(similarity_by_si_key.keys()) + search_index_rows = await self._fetch_search_index_rows_by_ids(si_keys) # Apply optional filters if requested filter_requested = any( @@ -2003,16 +2004,16 @@ def _log_vector_summary() -> None: limit=VECTOR_FILTER_SCAN_LIMIT, offset=0, ) - # Use (id, type) tuples to avoid collisions between different + # Use (type, id) tuples to avoid collisions between different # search_index row types that share the same auto-increment id. - allowed_keys = {(row.id, row.type) for row in filtered_rows if row.id is not None} + allowed_keys = {(row.type, row.id) for row in filtered_rows if row.id is not None} search_index_rows = { - k: v for k, v in search_index_rows.items() if (v.id, v.type) in allowed_keys + k: v for k, v in search_index_rows.items() if (v.type, v.id) in allowed_keys } ranked_rows: list[SearchIndexRow] = [] - for si_id, similarity in similarity_by_si_id.items(): - row = search_index_rows.get(si_id) + for si_key, similarity in similarity_by_si_key.items(): + row = search_index_rows.get(si_key) if row is None: continue @@ -2022,7 +2023,7 @@ def _log_vector_summary() -> None: if content_snippet and len(content_snippet) <= SMALL_NOTE_CONTENT_LIMIT: matched_chunk_text = content_snippet else: - si_chunks = chunks_by_si_id.get(si_id, []) + si_chunks = chunks_by_si_key.get(si_key, []) si_chunks.sort(key=lambda c: c[0], reverse=True) top_texts = [text for _, text in si_chunks[:TOP_CHUNKS_PER_RESULT]] matched_chunk_text = "\n---\n".join(top_texts) if top_texts else None @@ -2087,11 +2088,12 @@ async def _fetch_entity_rows_by_ids(self, entity_ids: list[int]) -> dict[int, Se return result async def _fetch_search_index_rows_by_ids( - self, row_ids: list[int] - ) -> dict[int, SearchIndexRow]: - """Fetch search_index rows by their primary key (id), any type.""" - if not row_ids: + self, row_keys: list[tuple[str, int]] + ) -> dict[tuple[str, int], SearchIndexRow]: + """Fetch search_index rows by primary key and return them keyed by (type, id).""" + if not row_keys: return {} + row_ids = sorted({row_id for _, row_id in row_keys}) placeholders = ",".join(f":id_{idx}" for idx in range(len(row_ids))) params: dict[str, Any] = { **{f"id_{idx}": rid for idx, rid in enumerate(row_ids)}, @@ -2106,11 +2108,11 @@ async def _fetch_search_index_rows_by_ids( WHERE project_id = :project_id AND id IN ({placeholders}) """ - result: dict[int, SearchIndexRow] = {} + result: dict[tuple[str, int], SearchIndexRow] = {} async with db.scoped_session(self.session_maker) as session: row_result = await session.execute(text(sql), params) for row in row_result.fetchall(): - result[row.id] = SearchIndexRow( + result[(row.type, row.id)] = SearchIndexRow( project_id=self.project_id, id=row.id, title=row.title, diff --git a/tests/repository/test_vector_pagination.py b/tests/repository/test_vector_pagination.py index 21bc98a8..09a001e3 100644 --- a/tests/repository/test_vector_pagination.py +++ b/tests/repository/test_vector_pagination.py @@ -133,7 +133,7 @@ async def test_page1_scores_gte_page2_scores(): repo._embedding_provider = _EmbeddingProvider() - fake_index_rows = {i: FakeRow(id=i) for i in range(20)} + fake_index_rows = {("entity", i): FakeRow(id=i) for i in range(20)} async def run_page(offset, limit): with ( diff --git a/tests/repository/test_vector_threshold.py b/tests/repository/test_vector_threshold.py index 266ee9ae..64ec4f81 100644 --- a/tests/repository/test_vector_threshold.py +++ b/tests/repository/test_vector_threshold.py @@ -160,7 +160,7 @@ async def test_threshold_zero_returns_all(): repo, "_fetch_search_index_rows_by_ids", new_callable=AsyncMock, - return_value={i: FakeRow(id=i) for i in range(3)}, + return_value={("entity", i): FakeRow(id=i) for i in range(3)}, ), ): results = await repo._search_vector_only(**COMMON_SEARCH_KWARGS) @@ -192,7 +192,7 @@ async def test_threshold_filters_low_scores(): "_fetch_search_index_rows_by_ids", new_callable=AsyncMock, # Only entity_0 (score=0.9) passes the threshold; the fetch only gets id 0 - return_value={0: FakeRow(id=0)}, + return_value={("entity", 0): FakeRow(id=0)}, ), ): results = await repo._search_vector_only(**COMMON_SEARCH_KWARGS) @@ -255,7 +255,7 @@ async def test_per_query_min_similarity_overrides_instance_default(): repo, "_fetch_search_index_rows_by_ids", new_callable=AsyncMock, - return_value={i: FakeRow(id=i) for i in range(3)}, + return_value={("entity", i): FakeRow(id=i) for i in range(3)}, ), ): # Override to 0.0 → all results pass through despite instance default of 0.6 @@ -289,7 +289,7 @@ async def test_per_query_min_similarity_tightens_threshold(): "_fetch_search_index_rows_by_ids", new_callable=AsyncMock, # Only id=0 (score=0.9) will be fetched after filtering - return_value={0: FakeRow(id=0)}, + return_value={("entity", 0): FakeRow(id=0)}, ), ): # Override to 0.8 → only score=0.9 passes @@ -321,7 +321,7 @@ async def test_matched_chunk_text_populated_on_vector_results(): repo, "_fetch_search_index_rows_by_ids", new_callable=AsyncMock, - return_value={i: FakeRow(id=i) for i in range(2)}, + return_value={("entity", i): FakeRow(id=i) for i in range(2)}, ), ): results = await repo._search_vector_only(**COMMON_SEARCH_KWARGS) @@ -333,6 +333,52 @@ async def test_matched_chunk_text_populated_on_vector_results(): assert results[1].matched_chunk_text == "chunk text for entity:1:0" +@pytest.mark.asyncio +async def test_entity_and_relation_with_same_id_both_returned(): + """Vector hydration keeps row type with id so entity/relation ids cannot collide.""" + repo = ConcreteSearchRepo() + repo._semantic_min_similarity = 0.0 + + fake_rows = [ + { + "chunk_key": "entity:4:0", + "best_distance": (1.0 / 0.9) - 1.0, + "chunk_text": "entity chunk", + }, + { + "chunk_key": "relation:4:0", + "best_distance": (1.0 / 0.8) - 1.0, + "chunk_text": "relation chunk", + }, + ] + + mock_embed = AsyncMock(return_value=[0.0] * 384) + repo._embedding_provider = _fake_embedding_provider(mock_embed) + + mock_fetch = AsyncMock( + return_value={ + ("entity", 4): FakeRow(id=4, type="entity"), + ("relation", 4): FakeRow(id=4, type="relation"), + } + ) + + with ( + patch( + "basic_memory.repository.search_repository_base.db.scoped_session", + fake_scoped_session, + ), + patch.object(repo, "_ensure_vector_tables", new_callable=AsyncMock), + patch.object(repo, "_prepare_vector_session", new_callable=AsyncMock), + patch.object(repo, "_run_vector_query", new_callable=AsyncMock, return_value=fake_rows), + patch.object(repo, "_fetch_search_index_rows_by_ids", mock_fetch), + ): + results = await repo._search_vector_only(**COMMON_SEARCH_KWARGS) + + mock_fetch.assert_awaited_once_with([("entity", 4), ("relation", 4)]) + assert [(result.type, result.id) for result in results] == [("entity", 4), ("relation", 4)] + assert [result.matched_chunk_text for result in results] == ["entity chunk", "relation chunk"] + + def _make_multi_chunk_vector_rows(si_id: int, scores: list[float]) -> list[dict]: """Build multiple fake vector chunks for a single search_index row. @@ -379,7 +425,7 @@ async def test_top_n_chunks_joined_in_matched_chunk_text(): repo, "_fetch_search_index_rows_by_ids", new_callable=AsyncMock, - return_value={0: FakeRow(id=0, content_snippet=large_content)}, + return_value={("entity", 0): FakeRow(id=0, content_snippet=large_content)}, ), ): results = await repo._search_vector_only(**COMMON_SEARCH_KWARGS) @@ -423,7 +469,7 @@ async def test_small_note_returns_full_content_as_matched_chunk(): repo, "_fetch_search_index_rows_by_ids", new_callable=AsyncMock, - return_value={0: FakeRow(id=0, content_snippet=small_content)}, + return_value={("entity", 0): FakeRow(id=0, content_snippet=small_content)}, ), ): results = await repo._search_vector_only(**COMMON_SEARCH_KWARGS) @@ -457,7 +503,7 @@ async def test_large_note_returns_chunks_not_full_content(): repo, "_fetch_search_index_rows_by_ids", new_callable=AsyncMock, - return_value={0: FakeRow(id=0, content_snippet=large_content)}, + return_value={("entity", 0): FakeRow(id=0, content_snippet=large_content)}, ), ): results = await repo._search_vector_only(**COMMON_SEARCH_KWARGS)