Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
95 changes: 53 additions & 42 deletions src/basic_memory/repository/search_repository_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,11 @@
OVERSIZED_ENTITY_VECTOR_SHARD_SIZE = 256
_SQLITE_MAX_PREPARE_WINDOW = 8

# Entity, observation, and relation rows in search_index carry ids from independent
# auto-increment sequences, so a bare id is ambiguous across row types. Every map in
# the vector/hybrid retrieval path must key rows by (type, id) to avoid collisions.
# The tuple order is (row type, numeric id), e.g. ("observation", 5).
# NOTE: the `type` statement is PEP 695 syntax and requires Python 3.12+.
type SearchIndexKey = tuple[str, int]


@dataclass
class VectorSyncBatchResult:
Expand Down Expand Up @@ -1857,7 +1862,7 @@ async def _dispatch_retrieval_mode(
# ------------------------------------------------------------------

@staticmethod
def _parse_chunk_key(chunk_key: str) -> tuple[str, int]:
def _parse_chunk_key(chunk_key: str) -> SearchIndexKey:
"""Parse a chunk_key like 'observation:5:0' into (type, search_index_id)."""
parts = chunk_key.split(":")
return parts[0], int(parts[1])
Expand Down Expand Up @@ -1932,26 +1937,27 @@ def _log_vector_summary() -> None:

hydrate_start = time.perf_counter()
# Build per-search_index_row similarity scores from chunk-level results.
# Each chunk_key encodes the search_index row type and id.
# Each chunk_key encodes the search_index row type and id; keep both as the
# key because different row types can share the same numeric id (#982).
# Track the best similarity per row (for ranking) and all chunks (for context).
similarity_by_si_id: dict[int, float] = {}
chunks_by_si_id: dict[int, list[tuple[float, str]]] = {}
similarity_by_si_key: dict[SearchIndexKey, float] = {}
chunks_by_si_key: dict[SearchIndexKey, list[tuple[float, str]]] = {}
for row in vector_rows:
chunk_key = row.get("chunk_key", "")
distance = float(row["best_distance"])
similarity = self._distance_to_similarity(distance)
chunk_text = row.get("chunk_text", "")
try:
_, si_id = self._parse_chunk_key(chunk_key)
si_key = self._parse_chunk_key(chunk_key)
except (ValueError, IndexError):
# Skip chunks whose key cannot be parsed: they cannot be mapped to a search_index row.
continue
current = similarity_by_si_id.get(si_id)
current = similarity_by_si_key.get(si_key)
if current is None or similarity > current:
similarity_by_si_id[si_id] = similarity
chunks_by_si_id.setdefault(si_id, []).append((similarity, chunk_text))
similarity_by_si_key[si_key] = similarity
chunks_by_si_key.setdefault(si_key, []).append((similarity, chunk_text))

if not similarity_by_si_id:
if not similarity_by_si_key:
hydrate_ms = (time.perf_counter() - hydrate_start) * 1000
_log_vector_summary()
return []
Expand All @@ -1962,16 +1968,17 @@ def _log_vector_summary() -> None:
min_similarity if min_similarity is not None else self._semantic_min_similarity
)
if effective_min_similarity > 0.0:
similarity_by_si_id = {
k: v for k, v in similarity_by_si_id.items() if v >= effective_min_similarity
similarity_by_si_key = {
k: v for k, v in similarity_by_si_key.items() if v >= effective_min_similarity
}
if not similarity_by_si_id:
if not similarity_by_si_key:
hydrate_ms = (time.perf_counter() - hydrate_start) * 1000
_log_vector_summary()
return []

# Fetch the actual search_index rows
si_ids = list(similarity_by_si_id.keys())
# Fetch the actual search_index rows. Colliding (type, id) keys share one
# bare id, so deduplicate while preserving first-seen order.
si_ids = list(dict.fromkeys(si_id for _, si_id in similarity_by_si_key))
search_index_rows = await self._fetch_search_index_rows_by_ids(si_ids)

# Apply optional filters if requested
Expand Down Expand Up @@ -2003,16 +2010,14 @@ def _log_vector_summary() -> None:
limit=VECTOR_FILTER_SCAN_LIMIT,
offset=0,
)
# Use (id, type) tuples to avoid collisions between different
# Use (type, id) tuples to avoid collisions between different
# search_index row types that share the same auto-increment id.
allowed_keys = {(row.id, row.type) for row in filtered_rows if row.id is not None}
search_index_rows = {
k: v for k, v in search_index_rows.items() if (v.id, v.type) in allowed_keys
}
allowed_keys = {(row.type, row.id) for row in filtered_rows if row.id is not None}
search_index_rows = {k: v for k, v in search_index_rows.items() if k in allowed_keys}

ranked_rows: list[SearchIndexRow] = []
for si_id, similarity in similarity_by_si_id.items():
row = search_index_rows.get(si_id)
for si_key, similarity in similarity_by_si_key.items():
row = search_index_rows.get(si_key)
if row is None:
continue

Expand All @@ -2022,7 +2027,7 @@ def _log_vector_summary() -> None:
if content_snippet and len(content_snippet) <= SMALL_NOTE_CONTENT_LIMIT:
matched_chunk_text = content_snippet
else:
si_chunks = chunks_by_si_id.get(si_id, [])
si_chunks = chunks_by_si_key.get(si_key, [])
si_chunks.sort(key=lambda c: c[0], reverse=True)
top_texts = [text for _, text in si_chunks[:TOP_CHUNKS_PER_RESULT]]
matched_chunk_text = "\n---\n".join(top_texts) if top_texts else None
Expand Down Expand Up @@ -2088,8 +2093,12 @@ async def _fetch_entity_rows_by_ids(self, entity_ids: list[int]) -> dict[int, Se

async def _fetch_search_index_rows_by_ids(
self, row_ids: list[int]
) -> dict[int, SearchIndexRow]:
"""Fetch search_index rows by their primary key (id), any type."""
) -> dict[SearchIndexKey, SearchIndexRow]:
"""Fetch search_index rows by id, keyed by (type, id) to disambiguate types.

A bare id can match one row per type (independent id sequences), so the
result must carry every matching row rather than letting one clobber another.
"""
if not row_ids:
return {}
placeholders = ",".join(f":id_{idx}" for idx in range(len(row_ids)))
Expand All @@ -2106,11 +2115,11 @@ async def _fetch_search_index_rows_by_ids(
WHERE project_id = :project_id
AND id IN ({placeholders})
"""
result: dict[int, SearchIndexRow] = {}
result: dict[SearchIndexKey, SearchIndexRow] = {}
async with db.scoped_session(self.session_maker) as session:
row_result = await session.execute(text(sql), params)
for row in row_result.fetchall():
result[row.id] = SearchIndexRow(
result[(row.type, row.id)] = SearchIndexRow(
project_id=self.project_id,
id=row.id,
title=row.title,
Expand Down Expand Up @@ -2156,7 +2165,7 @@ async def _search_hybrid(
) -> List[SearchIndexRow]:
"""Fuse FTS and vector results using score-based fusion.

Uses search_index row id as the fusion key. The formula
Uses the search_index (type, id) pair as the fusion key. The formula
``max(vec, fts) + FUSION_BONUS * min(vec, fts)`` preserves
the dominant signal and rewards dual-source agreement.
"""
Expand Down Expand Up @@ -2199,50 +2208,52 @@ async def _search_hybrid(
vector_ms = (time.perf_counter() - vector_start) * 1000
fusion_start = time.perf_counter()

# --- Score-based fusion keyed on search_index row id ---
# --- Score-based fusion keyed on (type, id) ---
# A bare row id collides across row types (independent id sequences), so
# fusion must key on (type, id) or distinct rows would merge (#982).
# FTS scores are normalized to [0, 1] (BM25 is unbounded).
# Vector scores are used raw — already calibrated [0, 1] by _distance_to_similarity().
rows_by_id: dict[int, SearchIndexRow] = {}
rows_by_key: dict[SearchIndexKey, SearchIndexRow] = {}

# Normalize FTS scores to [0, 1] — handles both SQLite (negative bm25)
# and Postgres (positive ts_rank) by using absolute values
fts_abs = [abs(row.score or 0.0) for row in fts_results]
fts_max = max(fts_abs) if fts_abs else 1.0

fts_scores: dict[int, float] = {}
fts_scores: dict[SearchIndexKey, float] = {}
for row in fts_results:
if row.id is None:
continue
norm = abs(row.score or 0.0) / fts_max if fts_max > 0 else 0.0
# Gate: FTS scores below threshold contribute zero
if norm < FTS_GATE_THRESHOLD:
norm = 0.0
fts_scores[row.id] = norm
rows_by_id[row.id] = row
fts_scores[(row.type, row.id)] = norm
rows_by_key[(row.type, row.id)] = row

vec_scores: dict[int, float] = {}
vec_scores: dict[SearchIndexKey, float] = {}
for row in vector_results:
if row.id is None:
continue
# Trigger: no re-normalization by vec_max
# Why: vector similarity is already calibrated [0, 1]; re-normalizing
# inflates weak matches when the entire result set is mediocre
vec_scores[row.id] = row.score or 0.0
rows_by_id[row.id] = row
vec_scores[(row.type, row.id)] = row.score or 0.0
rows_by_key[(row.type, row.id)] = row

# Fuse: max(v, f) + FUSION_BONUS * min(v, f)
# Preserves the dominant signal; bonus rewards dual-source agreement.
# Output range: [0, 1.3] for dual-source, [0, 1.0] for single-source.
fused_scores: dict[int, float] = {}
for row_id in fts_scores.keys() | vec_scores.keys():
v = vec_scores.get(row_id, 0.0)
f = fts_scores.get(row_id, 0.0)
fused_scores[row_id] = max(v, f) + FUSION_BONUS * min(v, f)
fused_scores: dict[SearchIndexKey, float] = {}
for row_key in fts_scores.keys() | vec_scores.keys():
v = vec_scores.get(row_key, 0.0)
f = fts_scores.get(row_key, 0.0)
fused_scores[row_key] = max(v, f) + FUSION_BONUS * min(v, f)

ranked = sorted(fused_scores.items(), key=lambda item: item[1], reverse=True)
output: list[SearchIndexRow] = []
for row_id, fused_score in ranked[offset : offset + limit]:
row = rows_by_id[row_id]
for row_key, fused_score in ranked[offset : offset + limit]:
row = rows_by_key[row_key]
# Trigger: FTS-only results have no matched_chunk_text from vector search.
# Why: without chunk text, API falls back to truncated content, losing answer text.
# Outcome: FTS-only results get full content_snippet as matched_chunk.
Expand Down
28 changes: 28 additions & 0 deletions tests/repository/test_hybrid_fusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,34 @@ async def test_zero_score_produces_zero_fused():
assert results[0].score == pytest.approx(0.0, rel=1e-6)


@pytest.mark.asyncio
async def test_cross_type_id_collision_keeps_both_results():
    """Hybrid fusion must not merge rows of different types that share an id (#982).

    The FTS side returns an entity with id=1 and the vector side returns a
    relation with id=1. Both must appear in the fused output, each carrying only
    its own single-source score.
    """
    repo = ConcreteSearchRepo()

    entity_hit = FakeRow(id=1, type="entity", score=5.0, title="entity-row")
    relation_hit = FakeRow(id=1, type="relation", score=0.8, title="relation-row")

    fts_patch = patch.object(repo, "search", new_callable=AsyncMock, return_value=[entity_hit])
    vector_patch = patch.object(
        repo, "_search_vector_only", new_callable=AsyncMock, return_value=[relation_hit]
    )
    with fts_patch, vector_patch:
        results = await repo._search_hybrid(**HYBRID_KWARGS)

    def first_of_type(row_type):
        return next(r for r in results if r.type == row_type)

    seen_keys = {(r.type, r.id) for r in results}
    assert seen_keys == {("entity", 1), ("relation", 1)}

    # Neither row has a counterpart from the other source, so no fusion bonus
    # applies: the lone FTS hit normalizes to 1.0 and the vector hit keeps 0.8.
    assert first_of_type("entity").score == pytest.approx(1.0, rel=1e-6)
    assert first_of_type("relation").score == pytest.approx(0.8, rel=1e-6)


@pytest.mark.asyncio
async def test_fts_only_result_gets_matched_chunk_from_content_snippet():
"""FTS-only results should have matched_chunk_text populated from content_snippet."""
Expand Down
91 changes: 91 additions & 0 deletions tests/repository/test_sqlite_vector_search_repository.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,32 @@ def _entity_row(
)


def _relation_row(
    *,
    project_id: int,
    row_id: int,
    entity_id: int,
    title: str,
    permalink: str,
    relation_type: str,
) -> SearchIndexRow:
    """Build a relation-typed SearchIndexRow fixture.

    The given entity_id is used both as the row's entity_id and as its from_id,
    the file path is derived from the permalink, and created_at/updated_at share
    a single UTC timestamp taken at call time.
    """
    stamp = datetime.now(timezone.utc)
    markdown_path = f"{permalink}.md"
    return SearchIndexRow(
        id=row_id,
        project_id=project_id,
        type=SearchItemType.RELATION.value,
        relation_type=relation_type,
        title=title,
        permalink=permalink,
        file_path=markdown_path,
        metadata=None,
        entity_id=entity_id,
        from_id=entity_id,
        created_at=stamp,
        updated_at=stamp,
    )


def _enable_semantic(
search_repository: SQLiteSearchRepository,
embedding_provider: StubEmbeddingProvider | None = None,
Expand Down Expand Up @@ -498,6 +524,71 @@ async def test_sqlite_vector_search_returns_ranked_entities(search_repository):
assert all(result.type == SearchItemType.ENTITY.value for result in results)


@pytest.mark.asyncio
async def test_sqlite_vector_search_survives_cross_type_id_collision(search_repository):
    """Vector search must hydrate both rows when two row types share an id (#982).

    An entity row and a relation row are indexed with the same numeric id (7).
    Search_index ids come from independent per-type sequences, so this overlap
    is normal. Hydration keyed on the bare id would keep only one of the two
    hits; this test pins that both survive, and that a type filter still
    returns the entity.
    """
    if not isinstance(search_repository, SQLiteSearchRepository):
        pytest.skip("sqlite-vec repository behavior is local SQLite-only.")

    _enable_semantic(search_repository)
    await search_repository.init_search_index()

    entity_row = _entity_row(
        project_id=search_repository.project_id,
        row_id=7,
        entity_id=701,
        title="Auth Token Design",
        permalink="specs/auth-token-design",
        content_stems="auth token session login design",
    )
    # Deliberately reuses row_id=7 under a different row type.
    relation_row = _relation_row(
        project_id=search_repository.project_id,
        row_id=7,
        entity_id=702,
        title="login flow relates to auth token design",
        permalink="specs/login-flow/relates-to/auth-token-design",
        relation_type="relates_to",
    )
    await search_repository.bulk_index_items([entity_row, relation_row])
    for owner_entity_id in (701, 702):
        await search_repository.sync_entity_vectors(owner_entity_id)

    results = await search_repository.search(
        search_text="session token auth",
        retrieval_mode=SearchRetrievalMode.VECTOR,
        limit=5,
        offset=0,
    )

    # Two hits, one per row type, even though both carry id=7.
    assert len(results) == 2
    expected_types = {SearchItemType.ENTITY.value, SearchItemType.RELATION.value}
    assert {hit.type for hit in results} == expected_types
    entity_hit = next(hit for hit in results if hit.type == SearchItemType.ENTITY.value)
    assert entity_hit.permalink == "specs/auth-token-design"

    # Restricting to entities must not drop the entity because a relation shares its id.
    entity_only = await search_repository.search(
        search_text="session token auth",
        search_item_types=[SearchItemType.ENTITY],
        retrieval_mode=SearchRetrievalMode.VECTOR,
        limit=5,
        offset=0,
    )
    assert [hit.permalink for hit in entity_only] == ["specs/auth-token-design"]


@pytest.mark.asyncio
async def test_sqlite_hybrid_search_combines_fts_and_vector(search_repository):
"""Hybrid mode fuses FTS and vector results with score-based fusion."""
Expand Down
2 changes: 1 addition & 1 deletion tests/repository/test_vector_pagination.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ async def test_page1_scores_gte_page2_scores():

repo._embedding_provider = _EmbeddingProvider()

fake_index_rows = {i: FakeRow(id=i) for i in range(20)}
fake_index_rows = {("entity", i): FakeRow(id=i) for i in range(20)}

async def run_page(offset, limit):
with (
Expand Down
Loading
Loading