diff --git a/CLAUDE.md b/CLAUDE.md index b042b5ef..de569f6f 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -95,7 +95,7 @@ poetry run ruff check slayer/ tests/ **Write side**: `save_memory(learning, linked_entities, id=None)` and `forget_memory(id)`, exposed via MCP, REST (`POST /memories`, `DELETE /memories/{id}`), CLI (`slayer memory {save,forget}`), and `SlayerClient`. `linked_entities` is either a list of entity strings (resolved strictly; `memory:` accepted) or an inline `SlayerQuery` / dict (entities auto-extracted; the query is persisted on the memory). Optional `id` (DEV-1428) pins a user-controlled canonical memory id; duplicate id → unconditional upsert, `created_at` preserved. - **Read side**: a single `search(entities, query, question, datasource=None, cypher_filter=None, max_memories=5, max_example_queries=2, max_entities=5)` tool. Surfaces: MCP, REST (`POST /search`), CLI (`slayer search …`), `SlayerClient.search()`. DEV-1428: search is **lenient** — unresolved `entities` / `query` tokens become warnings rather than raising; stale memory entity tags are filtered out at retrieval time (belt) before BM25 ranks AND before `matched_entities` is surfaced; for `example_queries` hits whose attached `Memory.query` no longer resolves, a `example_query memory:: attached query has stale references; re-save to clean.` warning is emitted (the query is not rewritten). + **Read side**: a single `search(entities, query, question, datasource=None, cypher_filter=None, max_results=10)` tool. Surfaces: MCP, REST (`POST /search`), CLI (`slayer search …`), `SlayerClient.search()`. Returns a `SearchResponse` with a single flat ranked `results: List[SearchHit]` — memories (kind="memory"), entities (kind one of datasource/model/column/measure/aggregation) all in one list. 
DEV-1428: search is **lenient** — unresolved `entities` / `query` tokens become warnings rather than raising; stale memory entity tags are filtered out at retrieval time before BM25 ranks AND before `matched_entities` is surfaced; for query-bearing memory hits whose attached `Memory.query` no longer resolves, a `example_query memory:: attached query has stale references; re-save to clean.` warning is emitted (the query is not rewritten). **Canonical entity form** is ``, `.`, `..`, or — DEV-1428 — `memory:` (cross-memory references). Aggregation suffixes are stripped (`revenue:sum` → `..revenue`); `*:count` collapses to the source model; multi-hop paths keep only the leaf. Resolver: `slayer/memories/resolver.py` (the `memory:` branch runs at the top of `resolve_entity`, before `_strip_agg_suffix`, so `memory:abc` parses as the memory branch). Memory ids are non-empty strings (DEV-1428) — pure-digit auto-allocated by the storage layer (`"1"`, `"2"`, ...), or user-supplied (`"kb.policy.42"`); forbidden charset: `:`, `/`, `?`, `#`, whitespace, ASCII control. Bare names never resolve to memories (the `memory:` prefix is mandatory). `delete_memory` cascades to the matching embedding row AND strips every `memory:` reference to it from every other memory's `entities` (DEV-1428 cascade layer 1). @@ -106,13 +106,13 @@ poetry run ruff check slayer/ tests/ - **Tantivy** in-memory full-text index, built fresh per call over memories ∪ non-hidden entities (datasources / models / columns / named measures / aggregations), using the `en_stem` analyzer. - **Embeddings** (optional `advanced_search` pip extra) — dense cosine over a persistent `embeddings` sidecar keyed by `(canonical_id, embedding_model_name)`. Model from `SLAYER_EMBEDDING_MODEL` (default `openai/text-embedding-3-small`), dispatched via litellm. When the extra is missing, no API key is set, or the corpus is empty, the channel contributes nothing and emits one warning into `SearchResponse.warnings`. 
- BM25 (channel 1) operates with implicit self-references (DEV-1513): every doc — memory or entity — is treated as carrying a single tag pointing at itself. So `entities=[""]` surfaces the named entity in the entities bucket, and `entities=["memory:"]` surfaces the named memory in the memories bucket, on top of the usual entity-overlap matches. Entity rankings from channels 1, 2 (tantivy), and 3 (embeddings) are RRF-fused. Memory hits are partitioned by `Memory.query is None` into `memories` (learning-only) and `example_queries` (query-bearing), each with its own cap. Each output bucket is ranked independently of the others — varying one `max_X` cap cannot reorder or move items in/out of any other bucket. The in-memory tantivy index is built with `writer(num_threads=1)` so doc-id tiebreaks on equal BM25 scores are deterministic. Empty-input fallback returns the newest memories per bucket with a warning. + BM25 (channel 1) operates with implicit self-references (DEV-1513): every doc — memory or entity — is treated as carrying a single tag pointing at itself. So `entities=[""]` surfaces the named entity, and `entities=["memory:"]` surfaces the named memory, on top of the usual entity-overlap matches. All rankings from channels 1, 2 (tantivy), and 3 (embeddings) are RRF-fused into a single flat list. The in-memory tantivy index is built with `writer(num_threads=1)` so doc-id tiebreaks on equal BM25 scores are deterministic. Empty-input fallback returns the newest memories capped at `max_results` with a warning. **Indexed text** is rendered by `slayer/search/render.py`. Hidden models / columns are excluded; `meta` is never indexed. Named children (columns, measures, aggregations, join targets) are referenced by name + kind only (each child has its own indexed doc). **`datasource` filter**: all surfaces accept optional `datasource: Optional[str] = None`. 
When set, every channel pre-filters its corpus to canonical ids rooted at that datasource (exact name or strict dotted-path descendant); memories surface iff at least one of their `entities` is rooted there. Unknown datasource → `ValueError` (HTTP 400 on REST). - **`cypher_filter` graph pre-filter** (DEV-1464): all surfaces accept optional `cypher_filter: Optional[str] = None`. When set, an openCypher `MATCH … RETURN … AS id` query runs against an ephemeral in-memory LadybugDB property graph built from current storage state. Returned IDs become a hard allowlist for all three channels. Requires `advanced_search` extra (LadybugDB). Query must be a single read-only statement returning one `id` column. Graph nodes: Memory (id, learning), Datasource (id, name), Model (id, name, description), Column (id, name, data_type, description), Measure (id, name, description), Aggregation (id, name). Relationships: MENTIONS (Memory→any), CONTAINS (Datasource→Model, Model→{Column,Measure,Aggregation}), JOINS (Model→Model). Hidden models/columns excluded. Graph is rebuilt when `storage.graph_fingerprint()` changes (file mtime). Cache: per-storage-path with asyncio double-checked locking. + **`cypher_filter` graph pre-filter** (DEV-1464): all surfaces accept optional `cypher_filter: Optional[str] = None`. When set, returned IDs become a hard allowlist for all three channels. When `advanced_search` is installed, a full openCypher `MATCH … RETURN … AS id` query runs against an ephemeral in-memory LadybugDB property graph. When not installed, only `MATCH (n:Label1:Label2) RETURN n.id AS id` patterns are supported as a kind filter (naive fallback, DEV-1532); more complex Cypher raises `SlayerError` with an install hint. Graph nodes (full path): Memory (id=`memory:`, learning), Datasource (id, name), Model (id=`.`, name, description), ModelColumn (id=`..`, name, data_type, description), Measure (id, name, description), Aggregation (id, name). 
Relationships: MENTIONS (Memory→any), CONTAINS (Datasource→Model, Model→{ModelColumn,Measure,Aggregation}), JOINS (Model→Model). Hidden models/columns excluded. Graph rebuilt when `storage.graph_fingerprint()` changes (file mtime). Cache: per-storage-path with asyncio double-checked locking. **Embedding refresh** runs inline on `slayer ingest`, `edit_model`, `save_memory`, and `--ingest-on-startup`. Each per-datasource ingest pass refreshes the datasource doc, every visible model + its visible children, and every memory whose canonical entities are rooted at the datasource. Content-hash skips the litellm call when nothing has changed; the hot path issues one batched read + one batched write per refresh, independent of subtree size. Per-entity failures are non-fatal; per-memory failures surface as `IngestionError(model_name="memory:", …)` in `IdempotentIngestResult.errors`. @@ -120,7 +120,7 @@ poetry run ruff check slayer/ tests/ **Sample-value snapshots** are cached on `Column.sampled` (text), `Column.sampled_values` (structured top-50 list for categorical columns, DEV-1480), and `Column.distinct_count` (true cardinality for categorical columns, DEV-1480). Refreshed on `slayer ingest` (table-backed models only), on `slayer search refresh-samples`, on `edit_model` (column edits → that column; model-level changes to `filters` / `sql` / `source_queries` → every column), and lazily on `inspect_model` cache miss (best-effort write-back). Categorical columns are ordered by count desc with alphabetical tie-break; the structured list is the consumer-facing way to compare predicate literals against actual stored values (text-split on `sampled` is ambiguous for values containing commas, e.g. `"R$ 1,000–3,000"`). Cache validity for categorical columns requires `sampled_values is not None` (v6 → v7 upgrades re-profile on next `inspect_model`). sql-mode and query-backed models do not yet have sample-value coverage. 
- `inspect_model` auto-renders a `Learnings` section showing only learning-only memories (`query is None`); query-bearing memories surface only via `search` in the `example_queries` bucket. + `inspect_model` auto-renders a `Learnings` section showing only learning-only memories (`query is None`); query-bearing memories surface only via `search` (as hits with `hit.query is not None`). See [docs/concepts/memories.md](docs/concepts/memories.md) and [docs/concepts/search.md](docs/concepts/search.md). @@ -174,7 +174,7 @@ poetry run ruff check slayer/ tests/ - `slayer serve --ingest-on-startup` and `slayer mcp --ingest-on-startup` (DEV-1392) — opt-in boot-time idempotent auto-ingestion across every configured datasource, sync-before-listen (uvicorn/mcp.run don't start until ingest finishes). Continue-on-failure: per-datasource errors are friendly-formatted to stderr and never abort startup; `storage.list_datasources()` raising is the only thing that prevents the server from starting. `to_delete` drift entries are printed but **never auto-applied** — destructive cleanup stays gated behind `slayer validate-models --force-clean [--yes]`. Composes freely with `--demo` (demo first, then the ingest pass over every datasource including the freshly-created demo). Also exposed via `SLAYER_INGEST_ON_STARTUP=1` env var (flag wins when both set) and the `ingest_on_startup=True` kwarg on `create_app` / `create_mcp_server`. All output goes to stderr — `slayer mcp` stdio JSON-RPC remains protocol-safe. Orchestrator: `slayer/engine/ingestion.py::ingest_all_datasources_idempotent`. **Memory embeddings** (DEV-1416): each per-datasource pass also re-embeds every memory whose canonical entities are rooted at the datasource, so a stale `embeddings.db` is repaired by the next `--ingest-on-startup` without extra steps. See [docs/concepts/ingestion.md](docs/concepts/ingestion.md#ingesting-at-startup). 
- `slayer validate-models [--datasource X] [--force-clean] [--yes]` (DEV-1356) — read-only diff against live schemas; with `--force-clean`, prompts to apply each delete via `engine.apply_drift_deletes`. See [docs/concepts/schema-drift.md](docs/concepts/schema-drift.md). - `slayer storage migrate-types [--data-source X] [--dry-run]` (DEV-1361) — refine `DOUBLE → INT` on base columns whose live SQL type is integer for every persisted model, then write the refined v5 dict back. Hard-fails if a datasource is unreachable. The same refinement runs transparently inside `storage.get_model` on first load; this CLI is a batch / inspectable alternative. -- `slayer search [--entity ENT ...] [--query JSON_OR_@FILE] [--question TEXT] [--datasource DS] [--cypher-filter CYPHER] [--max-memories N] [--max-example-queries N] [--max-entities N] [--format json|text]` (DEV-1375 / DEV-1386 / DEV-1409 / DEV-1464) — up to three-channel semantic search over memories + canonical entities (BM25 over memory entity tags + tantivy full-text + optional dense embedding similarity). `--datasource` scopes the corpus to one datasource. `--cypher-filter` runs an openCypher MATCH query against the LadybugDB property graph and pre-filters all channels to the returned IDs (requires `advanced_search` extra). See [docs/concepts/search.md](docs/concepts/search.md). +- `slayer search [--entity ENT ...] [--query JSON_OR_@FILE] [--question TEXT] [--datasource DS] [--cypher-filter CYPHER] [--max-results N] [--format json|text]` (DEV-1375 / DEV-1386 / DEV-1409 / DEV-1464 / DEV-1532) — up to three-channel semantic search over memories + canonical entities (BM25 over memory entity tags + tantivy full-text + optional dense embedding similarity). Returns a single flat ranked `results` list. `--datasource` scopes the corpus to one datasource. `--cypher-filter` pre-filters all channels: full openCypher when `advanced_search` is installed; simple `MATCH (n:Label) RETURN n.id AS id` kind-filter without it. 
See [docs/concepts/search.md](docs/concepts/search.md). - `slayer search refresh-samples [--data-source X] [--model M ...]` (DEV-1375) — re-profile and persist `Column.sampled` for table-backed models. Best-effort: per-column failures are reported but don't abort. - MCP `query()` tool has a `format` parameter: `"markdown"` (default), `"json"`, or `"csv"`. - **`query_nested` MCP tool**: companion to `query` for the multi-stage DAG shape that `engine.execute(query=list[...])` already supports. Takes `queries: List[Dict[str, Any]]` plus the usual `variables` / `show_sql` / `dry_run` / `explain` / `format` knobs. Earlier entries are named sub-queries that later entries reference via `source_model: ""` or `joins.target_model`; the engine auto-sorts the list (Kahn's algorithm), so order doesn't matter. The single-stage `query` tool is unchanged — keep using it whenever the typed per-field schema fits, since it surfaces a richer signature to agents. diff --git a/docs/concepts/search.md b/docs/concepts/search.md index 09c83c5a..f608c49d 100644 --- a/docs/concepts/search.md +++ b/docs/concepts/search.md @@ -39,28 +39,28 @@ Concretely: ```json { - "call": {"entities": ["mydb.orders.amount"], "max_memories": 0}, + "call": {"entities": ["mydb.orders.amount"], "max_results": 5}, "response": { - "entities": [{"id": "mydb.orders.amount", "kind": "column"}] + "results": [{"id": "mydb.orders.amount", "kind": "column", "score": 0.0}] } } ``` ```json { - "call": {"entities": ["memory:42"], "max_entities": 0}, + "call": {"entities": ["memory:42"], "max_results": 5}, "response": { - "memories": [{"id": "42", "matched_entities": ["memory:42"]}] + "results": [{"id": "42", "kind": "memory", "matched_entities": ["memory:42"]}] } } ``` Filter rules for the new entity surfacing: -- `memory:` refs participate in the memory ranking only — they never appear in the entities bucket. +- `memory:` refs contribute to the memory portion of the ranking only — they surface as `kind="memory"` hits. 
- Refs not rooted at `datasource` (when set) drop with a warning `entity '' is not rooted at datasource ''; dropped from entities bucket.` The memory side fires the symmetric `memory: is not rooted at datasource ''; dropped.` when the named memory has no entities rooted at the requested datasource. -- Refs on a hidden model or hidden column drop from the entities bucket with `entity '' is on a hidden model/column; dropped from entities bucket.` BM25 over original memory tags is unaffected — memories tagged with that canonical still surface. -- An explicitly-named `memory:` whose attached `Memory.query` has stale references emits the standard stale-query warning regardless of `max_example_queries` (the user explicitly asked for that memory; they deserve to know the query is broken). +- Refs on a hidden model or hidden column drop from the entities portion with `entity '' is on a hidden model/column; dropped from entities bucket.` BM25 over original memory tags is unaffected — memories tagged with that canonical still surface. +- An explicitly-named `memory:` whose attached `Memory.query` has stale references emits the standard stale-query warning (the user explicitly asked for that memory; they deserve to know the query is broken). Activated when `entities` and/or `query` is supplied to `search`. @@ -146,19 +146,18 @@ Entity rankings from channels 1, 2, and 3 are RRF-fused the same way. Channel 1's entity ranking is the user-supplied canonical refs in supplied order (DEV-1513); channels 2 and 3 contribute fuzzy hits. -### Per-bucket ranking invariance (DEV-1414) +### Ranking stability (DEV-1414) Each channel produces a **full per-kind ranking** — channel 2 runs as two kind-filtered tantivy queries (one over memory docs only, one over entity docs only), and channel 3 partitions the embedding corpus by -`entity_kind` and ranks each side independently. 
There is no shared -candidate-pool budget across kinds, so for a fixed -`(question, datasource, max_X)` the membership and order of the -returned `X` bucket (`memories` / `example_queries` / `entities`) is a -pure function of the corpus + question + that one cap. Varying the -other two caps cannot move an id in or out of the returned list nor -reorder it. The `max_*` caps are pure post-fusion slice operations on -the three independent ranked lists. +`entity_kind` and ranks each side independently. The per-kind rankings +are RRF-fused into a single flat list before the `max_results` cap is +applied. Because the fusion is deterministic, the relative order of any +subset of the flat list is stable with respect to the corpus and +question — changing only `max_results` never reorders existing entries +nor causes an entry to appear or disappear unless the cap boundary +moves past it. ## Tool surface @@ -168,9 +167,8 @@ search( query: Optional[Union[SlayerQuery, dict]] = None, question: Optional[str] = None, datasource: Optional[str] = None, - max_memories: int = 5, - max_example_queries: int = 2, - max_entities: int = 5, + max_results: int = 10, + cypher_filter: Optional[str] = None, ) -> SearchResponse ``` @@ -232,8 +230,26 @@ allowlist** that pre-filters all three channels before any ranking: - If the query returns an empty set, a warning is emitted and an empty response is returned immediately (no channels fire). -**Requires the `advanced_search` extra** (`pip install motley-slayer[advanced_search]`). -When the extra is not installed, `get_filtered_ids` raises `ValueError`. +**Naive fallback (no `advanced_search` required)**. When the +`advanced_search` extra is not installed, a simple subset of Cypher is +supported without LadybugDB. The naive parser accepts only: + +```cypher +MATCH (var:Label1:Label2:...) 
RETURN var.id AS id +``` + +Labels must be one or more of `Memory`, `Datasource`, `Model`, +`Column` (or its alias `ModelColumn`), `Measure`, `Aggregation` +(case-insensitive). The colon-separated +multi-label form is a union — it returns hits whose kind matches any of +the listed labels. Any other Cypher (WHERE clauses, relationships, +multiple MATCH clauses, etc.) raises `SlayerError` with a hint to +install the `advanced_search` extra. + +**Full `advanced_search` path**: When the `advanced_search` extra is +installed, the full openCypher query runs against the property graph +(see graph schema below). Complex filters, relationship traversals, and +property conditions are all supported. **Query safety**: the Cypher statement must be: - A single statement (no semicolons). @@ -247,14 +263,16 @@ When the extra is not installed, `get_filtered_ids` raises `ValueError`. | `Memory` | `id` (`memory:` form), `learning` | | `Datasource` | `id`, `name` | | `Model` | `id` (`.`), `name`, `description` | -| `Column` | `id` (`..`), `name`, `data_type`, `description` | +| `ModelColumn` | `id` (`..`), `name`, `data_type`, `description` | | `Measure` | `id` (`..`), `name`, `description` | | `Aggregation` | `id` (`..`), `name` | +Note: the column node table is named `ModelColumn` (not `Column`) because `Column` is a reserved keyword in LadybugDB ≥ 0.15. + | Relationship | From → To | |---|---| -| `MENTIONS` | Memory → {Datasource, Model, Column, Measure, Aggregation, Memory} | -| `CONTAINS` | Datasource → Model, Model → {Column, Measure, Aggregation} | +| `MENTIONS` | Memory → {Datasource, Model, ModelColumn, Measure, Aggregation, Memory} | +| `CONTAINS` | Datasource → Model, Model → {ModelColumn, Measure, Aggregation} | | `JOINS` | Model → Model | Hidden models and hidden columns are excluded from the graph. 
The graph is @@ -271,58 +289,42 @@ RETURN m.id AS id **Example** — surface columns in the `shop` datasource: ```cypher -MATCH (d:Datasource {id: 'shop'})-[:CONTAINS*1..3]->(c:Column) +MATCH (d:Datasource {id: 'shop'})-[:CONTAINS*1..3]->(c:ModelColumn) RETURN c.id AS id ``` -**Multi-label union**: `MATCH (n:Memory:Column)` returns nodes from both -`Memory` AND `Column` tables (LadybugDB union semantics). +**Multi-label union**: `MATCH (n:Memory:ModelColumn)` returns nodes from both +`Memory` AND `ModelColumn` tables (LadybugDB union semantics). ### Behaviour matrix | `entities`/`query` | `question` | Result | |---|---|---| -| set | set | All eligible channels run. Memories RRF-fused (channels 1 + 2 + 3); entities RRF-fused (channels 1 + 2 + 3, DEV-1513). Channel 3 is skipped with a warning when the `advanced_search` extra is missing. Query-bearing memories partitioned out to `example_queries`. | -| set | unset/empty | Channel 1 only. Memories partitioned by `query` presence; entity hits = the named refs themselves (DEV-1513). | -| unset/empty | set | Channels 2 and 3 (when eligible). Memories RRF-fused; entities RRF-fused. | -| unset/empty | unset/empty | Recency fallback: newest `max_memories` learning-only memories + newest `max_example_queries` query-bearing memories, with a warning. | +| set | set | All eligible channels run. Memories and entities are RRF-fused across all active channels. Channel 3 is skipped with a warning when the `advanced_search` extra is missing. | +| set | unset/empty | Channel 1 only. Memory hits ranked by entity-tag overlap; entity hits = the named refs themselves (DEV-1513). | +| unset/empty | set | Channels 2 and 3 (when eligible). Memories and entities RRF-fused. | +| unset/empty | unset/empty | Recency fallback: newest memories (any kind) capped at `max_results`, with a warning. 
| ### Response shape -Memories are partitioned by `Memory.query is None`: learning-only -memories land in `memories`, query-bearing memories in -`example_queries`. The two lists are capped independently so a few -bulky example queries cannot crowd out small learning-only notes. +All hits — memories (both learning-only and query-bearing) and entities +(datasources, models, columns, measures, aggregations) — are returned +as a single flat ranked `results` list capped at `max_results`. +Query-bearing memories have `query` set; learning-only memories have +`query=None`; entity hits have `kind` set to their entity type. ```python -class MemoryHit(BaseModel): - id: str # memory id (forget_memory(id=hit.id) works) - score: float # RRF-fused (or single-channel raw) - text: str # full indexed text (no truncation) - matched_entities: List[str] # canonical entities that channel-1 input - # overlapped with the memory's tags; - # stale tags are filtered before this is - # computed (DEV-1428 lazy GC). 
- -class ExampleQueryHit(BaseModel): - id: str # memory id - score: float # RRF-fused - text: str # full indexed text - matched_entities: List[str] - query: SlayerQuery # always set on this hit type - -class EntityHit(BaseModel): - id: str # canonical entity string - kind: str # "datasource"|"model"|"column"|"measure"|"aggregation" - score: float # RRF-fused across channels 1+2+3 - # (DEV-1513), or single-channel raw - # when only one channel contributed - text: str # full indexed text (no truncation) +class SearchHit(BaseModel): + id: str # memory id OR canonical entity string + kind: str # "memory"|"datasource"|"model"|"column"|"measure"|"aggregation" + score: float # RRF-fused score + text: str # full indexed text (no truncation) + matched_entities: List[str] # channel-1 overlap (memory hits only; + # stale tags filtered, DEV-1428) + query: Optional[SlayerQuery] # set on query-bearing memory hits class SearchResponse(BaseModel): - memories: List[MemoryHit] # learning-only (query is None) - example_queries: List[ExampleQueryHit] # query-bearing - entities: List[EntityHit] + results: List[SearchHit] resolved_input_entities: List[str] # echo of the resolver output warnings: List[str] ``` @@ -340,7 +342,7 @@ search proceeds against whatever did resolve. Examples: - A stale entity tag inside a saved memory does not contribute to channel-1 BM25 ranking, and is excluded from any hit's `matched_entities` list. -- An `example_queries` hit whose attached `Memory.query` references a +- A query-bearing memory hit whose attached `Memory.query` references a vanished column gets the warning `example_query memory:: attached query has stale references (...); re-save to clean.` but is still surfaced with its stored query intact. 
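Note on the `docs/concepts/search.md` changes above: callers that used to read `resp.memories`, `resp.example_queries`, and `resp.entities` now have to derive those groups from the flat `results` list. The following is a minimal sketch of that partitioning, assuming only the `kind` and `query` fields described in the new `SearchHit` shape. The `Hit` dataclass and `split_results` helper are hypothetical stand-ins for illustration, not part of the SLayer API, and the sample hits are made up.

```python
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Tuple


@dataclass
class Hit:
    """Hypothetical stand-in for SearchHit; only the fields used below."""

    id: str
    kind: str  # "memory" or an entity kind such as "column"
    score: float
    query: Optional[Dict[str, Any]] = None  # set on query-bearing memories
    matched_entities: List[str] = field(default_factory=list)


def split_results(results: List[Hit]) -> Tuple[List[Hit], List[Hit], List[Hit]]:
    """Partition a flat ranked list into (learnings, example_queries, entities).

    Each group keeps the relative order it had in the flat ranking.
    """
    learnings = [h for h in results if h.kind == "memory" and h.query is None]
    example_queries = [
        h for h in results if h.kind == "memory" and h.query is not None
    ]
    entities = [h for h in results if h.kind != "memory"]
    return learnings, example_queries, entities


# Made-up sample data in descending score order, as a flat response would be.
results = [
    Hit(id="mydb.orders.amount", kind="column", score=0.9),
    Hit(id="42", kind="memory", score=0.8, matched_entities=["memory:42"]),
    Hit(id="7", kind="memory", score=0.5, query={"source_model": "orders"}),
]
learnings, example_queries, entities = split_results(results)
print([h.id for h in learnings], [h.id for h in example_queries], [h.id for h in entities])
```

Because `max_results` caps the flat list rather than each group, any of the three groups may come back empty. A caller that needs a particular kind may have to raise `max_results`, as the notebook cells in this diff do.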
diff --git a/docs/examples/09_lightning_talk/lightning_talk_nb.ipynb b/docs/examples/09_lightning_talk/lightning_talk_nb.ipynb index 3db5dd36..f379503a 100644 --- a/docs/examples/09_lightning_talk/lightning_talk_nb.ipynb +++ b/docs/examples/09_lightning_talk/lightning_talk_nb.ipynb @@ -746,22 +746,7 @@ ] } ], - "source": [ - "# You can search using a text query\n", - "\n", - "resp = run_sync(client.search(\n", - " question='What should I know before comparing Brooklyn revenue to other stores?',\n", - " max_memories=3,\n", - " max_example_queries=2,\n", - " max_entities=0,\n", - "))\n", - "print('Memories:')\n", - "for hit in resp.memories:\n", - " print(f' [{hit.id}] score={hit.score:.3f}')\n", - " print(f' -> {hit.text[:100]}...')\n", - "assert any(m.id == 'lightning.brooklyn_pos' for m in resp.memories), \\\n", - " 'Brooklyn memory must surface in resp.memories for this question'" - ] + "source": "# You can search using a text query\n\nresp = run_sync(client.search(\n question='What should I know before comparing Brooklyn revenue to other stores?',\n max_results=10,\n))\nmemories = [h for h in resp.results if h.kind == 'memory' and h.query is None]\nprint('Memories:')\nfor hit in memories:\n print(f' [{hit.id}] score={hit.score:.3f}')\n print(f' -> {hit.text[:100]}...')\nassert any(m.id == 'lightning.brooklyn_pos' for m in memories), \\\n 'Brooklyn memory must surface in resp.results for this question'\n" }, { "cell_type": "code", @@ -779,21 +764,7 @@ ] } ], - "source": [ - "# And/or entity references\n", - "\n", - "resp = run_sync(client.search(\n", - " entities=['jaffle_shop.orders.order_total'],\n", - " max_memories=3,\n", - " max_entities=0,\n", - "))\n", - "print('Memories anchored to jaffle_shop.orders.order_total:')\n", - "for hit in resp.memories:\n", - " print(f' [{hit.id}] matched_entities={hit.matched_entities}')\n", - " print(f' -> {hit.text[:100]}...')\n", - "assert any(m.id == 'lightning.brooklyn_pos' for m in resp.memories), \\\n", - " 'Brooklyn 
memory must surface via the order_total entity tag'" - ] + "source": "# And/or entity references\n\nresp = run_sync(client.search(\n entities=[\"jaffle_shop.orders.order_total\"],\n max_results=10,\n))\nmemories = [h for h in resp.results if h.kind == \"memory\" and h.query is None]\nprint(\"Memories anchored to jaffle_shop.orders.order_total:\")\nfor hit in memories:\n print(f\" [{hit.id}] matched_entities={hit.matched_entities}\")\n print(f\" -> {hit.text[:100]}...\")\nassert any(m.id == \"lightning.brooklyn_pos\" for m in memories), \\\n \"Brooklyn memory must surface via the order_total entity tag\"\n" }, { "cell_type": "code", @@ -819,27 +790,7 @@ ] } ], - "source": [ - "# And specifically look for memories with queries\n", - "\n", - "resp = run_sync(client.search(\n", - " question='How have analysts queried customer lifetime spend before?',\n", - " max_memories=0,\n", - " max_example_queries=2,\n", - " max_entities=5,\n", - "))\n", - "print('Example queries:')\n", - "for eq in resp.example_queries:\n", - " print(f' [{eq.id}] score={eq.score:.3f}')\n", - " print(f' -> {eq.text[:100]}...')\n", - "print()\n", - "print('Canonical entities:')\n", - "for ent in resp.entities:\n", - " print(f' [{ent.kind}] {ent.id} (score={ent.score:.3f})')\n", - "assert any(eq.id == 'lightning.top_customers' for eq in resp.example_queries), \\\n", - " 'top-customers example query must surface in resp.example_queries'\n", - "assert len(resp.entities) > 0, 'expected at least one canonical entity hit'" - ] + "source": "# And specifically look for memories with queries\n\nresp = run_sync(client.search(\n question='How have analysts queried customer lifetime spend before?',\n max_results=20,\n))\nexample_queries = [h for h in resp.results if h.kind == 'memory' and h.query is not None]\nentities = [h for h in resp.results if h.kind != 'memory']\nprint('Example queries:')\nfor eq in example_queries:\n print(f' [{eq.id}] score={eq.score:.3f}')\n print(f' -> 
{eq.text[:100]}...')\nprint()\nprint('Canonical entities:')\nfor ent in entities:\n print(f' [{ent.kind}] {ent.id} (score={ent.score:.3f})')\nassert any(eq.id == 'lightning.top_customers' for eq in example_queries), \\\n 'top-customers example query must surface in resp.results'\nassert len(entities) > 0, 'expected at least one canonical entity hit'\n" }, { "cell_type": "markdown", diff --git a/slayer/api/server.py b/slayer/api/server.py index cc127725..1201a980 100644 --- a/slayer/api/server.py +++ b/slayer/api/server.py @@ -13,6 +13,7 @@ EntityResolutionError, MemoryNotFoundError, SchemaDriftError, + SlayerError, ) from slayer.core.format import NumberFormat from slayer.core.models import DatasourceConfig, SlayerModel @@ -137,19 +138,17 @@ class SearchRequest(BaseModel): """Body for ``POST /search`` (DEV-1375). Mirrors the MCP / CLI / SlayerClient surfaces. - All three retrieval inputs are optional. Empty input falls back to - a recency listing of the newest ``max_memories`` learning-only - memories plus the newest ``max_example_queries`` query-bearing - memories. + All retrieval inputs are optional. Empty input falls back to a + recency listing capped at ``max_results`` hits. 
""" + model_config = ConfigDict(extra="forbid") + entities: Optional[List[str]] = None query: Optional[Any] = None question: Optional[str] = None datasource: Optional[str] = None - max_memories: int = 5 - max_example_queries: int = 2 - max_entities: int = 5 + max_results: int = 10 cypher_filter: Optional[str] = None @@ -671,16 +670,10 @@ async def search(request: SearchRequest) -> Dict[str, Any]: query=request.query, question=request.question, datasource=request.datasource, - max_memories=request.max_memories, - max_example_queries=request.max_example_queries, - max_entities=request.max_entities, + max_results=request.max_results, cypher_filter=request.cypher_filter, ) - except ( - EntityResolutionError, - AmbiguousModelError, - ValueError, - ) as exc: + except (SlayerError, ValueError) as exc: raise HTTPException(status_code=400, detail=str(exc)) return response.model_dump(mode="json") diff --git a/slayer/cli.py b/slayer/cli.py index 4d7e56cc..e5f14fa0 100644 --- a/slayer/cli.py +++ b/slayer/cli.py @@ -14,6 +14,7 @@ AmbiguousModelError, EntityResolutionError, MemoryNotFoundError, + SlayerError, ) from slayer.core.models import SlayerModel from slayer.engine.ingestion import ( @@ -584,25 +585,11 @@ def main(): ), ) search_parser.add_argument( - "--max-memories", + "--max-results", type=int, - default=5, - dest="max_memories", - help="Cap on returned learning-only memory hits (default 5).", - ) - search_parser.add_argument( - "--max-example-queries", - type=int, - default=2, - dest="max_example_queries", - help="Cap on returned query-bearing memory hits (default 2 — bulky).", - ) - search_parser.add_argument( - "--max-entities", - type=int, - default=5, - dest="max_entities", - help="Cap on returned entity hits (default 5).", + default=10, + dest="max_results", + help="Maximum total number of hits to return (default 10).", ) search_parser.add_argument( "--cypher-filter", @@ -610,8 +597,9 @@ def main(): dest="cypher_filter", help=( "openCypher MATCH query returning 
'… AS id' to pre-filter all " - "channels to matching canonical IDs. Requires the advanced_search " - "extra (LadybugDB). Read-only — no CREATE/MERGE/DELETE." + "channels to matching canonical IDs. When advanced_search is not " + "installed, only simple MATCH (n:Label) RETURN n.id AS id patterns " + "are supported as a kind filter." ), ) search_parser.add_argument( @@ -783,17 +771,11 @@ def _print_search_response_text(response) -> None: "\nResolved input entities: " + ", ".join(response.resolved_input_entities) ) - print(f"\nMemories ({len(response.memories)}):") - for hit in response.memories: - print(f" M{hit.id} (score={hit.score:.4f})") + print(f"\nResults ({len(response.results)}):") + for hit in response.results: + prefix = "M" if hit.kind == "memory" else f"[{hit.kind}]" + print(f" {prefix} {hit.id} (score={hit.score:.4f})") print(f" {hit.text.splitlines()[0] if hit.text else ''}") - print(f"\nExample queries ({len(response.example_queries)}):") - for hit in response.example_queries: - print(f" M{hit.id} (score={hit.score:.4f})") - print(f" {hit.text.splitlines()[0] if hit.text else ''}") - print(f"\nEntities ({len(response.entities)}):") - for hit in response.entities: - print(f" [{hit.kind}] {hit.id} (score={hit.score:.4f})") def _run_search_query(args, storage) -> None: @@ -807,12 +789,10 @@ def _run_search_query(args, storage) -> None: query=query_input, question=args.question, datasource=args.datasource, - max_memories=args.max_memories, - max_example_queries=args.max_example_queries, - max_entities=args.max_entities, + max_results=args.max_results, cypher_filter=args.cypher_filter, )) - except (EntityResolutionError, AmbiguousModelError, ValueError) as exc: + except (SlayerError, ValueError) as exc: _exit_with_error(exc) return if args.format == "json": diff --git a/slayer/client/slayer_client.py b/slayer/client/slayer_client.py index 8c0adea9..da3ee5b1 100644 --- a/slayer/client/slayer_client.py +++ b/slayer/client/slayer_client.py @@ -389,9 +389,7 @@ 
async def search( query: Optional[Union[SlayerQuery, Dict[str, Any]]] = None, question: Optional[str] = None, datasource: Optional[str] = None, - max_memories: int = 5, - max_example_queries: int = 2, - max_entities: int = 5, + max_results: int = 10, cypher_filter: Optional[str] = None, ) -> "SearchResponse": """Up to three-channel semantic search over memories + canonical @@ -400,9 +398,8 @@ async def search( Channels: (1) entity-overlap BM25 over memories; (2) tantivy full-text over memories ∪ entities; (3) optional dense embedding similarity (gated by the ``advanced_search`` extra and a - configured provider API key). Memory rankings from all active - channels and entity rankings from channels 2 and 3 are fused via - Reciprocal Rank Fusion (``k=60``). + configured provider API key). All hits are fused via Reciprocal + Rank Fusion (``k=60``) into a single ranked ``results`` list. ``datasource`` (DEV-1409, optional): when set, scope memories and entities to that one datasource. Entity hits are limited to docs @@ -436,15 +433,11 @@ async def search( query=coerced_query, question=question, datasource=datasource, - max_memories=max_memories, - max_example_queries=max_example_queries, - max_entities=max_entities, + max_results=max_results, cypher_filter=cypher_filter, ) body: Dict[str, Any] = { - "max_memories": max_memories, - "max_example_queries": max_example_queries, - "max_entities": max_entities, + "max_results": max_results, } if entities is not None: body["entities"] = entities diff --git a/slayer/mcp/server.py b/slayer/mcp/server.py index 8921353b..b93d1b44 100644 --- a/slayer/mcp/server.py +++ b/slayer/mcp/server.py @@ -11,6 +11,7 @@ AmbiguousModelError, EntityResolutionError, MemoryNotFoundError, + SlayerError, ) from slayer.core.models import ( Aggregation, @@ -2833,9 +2834,7 @@ async def search( query: Any = None, question: Optional[str] = None, datasource: Optional[str] = None, - max_memories: int = 5, - max_example_queries: int = 2, - max_entities: int 
= 5, + max_results: int = 10, cypher_filter: Optional[str] = None, ) -> str: """Up to three-channel semantic search over memories + canonical entities. @@ -2861,16 +2860,12 @@ async def search( a single warning into ``SearchResponse.warnings`` when any precondition fails — tantivy + BM25 continue to work. - Memory rankings from every active channel and entity rankings - from channels 2 and 3 are fused via Reciprocal Rank Fusion - (k=60). Query-bearing memories (those saved with an attached - ``SlayerQuery``) are partitioned into ``example_queries`` and - capped independently from learning-only ``memories`` so bulky - example queries cannot crowd out small notes. + All hits (memories, example queries, entities) are fused via + Reciprocal Rank Fusion (k=60) into a single ranked + ``results`` list capped at ``max_results``. Empty input (no entities, no query, no question) returns the - newest ``max_memories`` learning-only memories and the newest - ``max_example_queries`` query-bearing memories, with a warning. + newest memories capped at ``max_results``, with a warning. Args: entities: Canonical entity reference strings. @@ -2885,15 +2880,12 @@ async def search( a memory spanning multiple datasources surfaces from each. BM25 / IDF stats reflect only the filtered subset. Unknown datasource raises ``ValueError``. - max_memories: Cap on returned learning-only memory hits - (default 5). - max_example_queries: Cap on returned query-bearing memory - hits (default 2 — they're bulky). - max_entities: Cap on returned entity hits (default 5). + max_results: Maximum total number of hits to return (default 10). cypher_filter: Optional openCypher MATCH query returning ``… AS id`` that pre-filters all three channels to the - returned canonical IDs. Requires the ``advanced_search`` - extra (LadybugDB). Read-only — no CREATE/MERGE/DELETE. + returned canonical IDs. 
When ``advanced_search`` is not + installed, only simple ``MATCH (n:Label) RETURN n.id AS id`` + patterns are supported as a kind filter. """ try: response = await search_service.search( @@ -2901,16 +2893,10 @@ async def search( query=query, question=question, datasource=datasource, - max_memories=max_memories, - max_example_queries=max_example_queries, - max_entities=max_entities, + max_results=max_results, cypher_filter=cypher_filter, ) - except ( - EntityResolutionError, - AmbiguousModelError, - ValueError, - ) as exc: + except (SlayerError, ValueError) as exc: return _format_resolution_error(exc) return response.model_dump_json(indent=2) diff --git a/slayer/search/__init__.py b/slayer/search/__init__.py index d8b574c9..36f9441a 100644 --- a/slayer/search/__init__.py +++ b/slayer/search/__init__.py @@ -4,20 +4,18 @@ usage. The ``SearchService`` orchestrator runs up to three retrieval channels — entity-overlap BM25 over memories, tantivy full-text over the unioned corpus, and optional dense embedding similarity gated by the -``advanced_search`` extra — and fuses the memory rankings (and entity -rankings, for channels 2 and 3) via Reciprocal Rank Fusion. +``advanced_search`` extra — and fuses all hits via Reciprocal Rank Fusion +into a single ranked ``results`` list. """ from slayer.search.service import ( - EntityHit, - MemoryHit, + SearchHit, SearchResponse, SearchService, ) __all__ = [ - "EntityHit", - "MemoryHit", + "SearchHit", "SearchResponse", "SearchService", ] diff --git a/slayer/search/cypher_naive.py b/slayer/search/cypher_naive.py new file mode 100644 index 00000000..0dada6fc --- /dev/null +++ b/slayer/search/cypher_naive.py @@ -0,0 +1,74 @@ +"""Naive Cypher label-filter parser for the no-advanced_search fallback (DEV-1532). + +Supports only: MATCH (var:Label1:Label2:...) RETURN var.id AS id +(case-insensitive, whitespace-tolerant, no WHERE clause, no relationships). 
+ +Used by SearchService.search() when cypher_filter is supplied but LadybugDB +is not installed. Complex Cypher raises SlayerError pointing at the +advanced_search extra. +""" + +from __future__ import annotations + +import re +from typing import Set + +from slayer.core.errors import SlayerError + + +_LABEL_TO_KIND: dict[str, str] = { + "memory": "memory", + "datasource": "datasource", + "model": "model", + "modelcolumn": "column", + "column": "column", + "measure": "measure", + "aggregation": "aggregation", +} + +# Structured pattern: one label word optionally followed by (: word)* pairs. +# \s* is only used as a delimiter between fixed tokens (never inside a +# quantified character class that can also match \s), which avoids +# polynomial backtracking on non-matching inputs (Sonar S5852). +_NAIVE_PATTERN = re.compile( + pattern=r"^\s*MATCH\s*\(\s*(?P<var>\w+)\s*:\s*(?P<labels>\w+(?:\s*:\s*\w+)*)\s*\)\s*RETURN\s+(?P=var)\.id\s+AS\s+id\s*$", # NOSONAR(S5843) — structural complexity is load-bearing; each token maps to a distinct syntactic Cypher element + flags=re.IGNORECASE, +) + +_AS_ID_RE = re.compile(pattern=r"\bAS\s+id\b", flags=re.IGNORECASE) + + +def parse_naive_label_filter(cypher: str) -> Set[str]: + """Parse a simple MATCH (n:Label1:Label2) RETURN n.id AS id expression + and return the set of kind strings to filter search results on. + + Raises SlayerError: + - Missing 'AS id' alias → generic validation message. + - Pattern doesn't match (WHERE, relationship, etc.) → + message mentions advanced_search requirement. + - Unknown label → message says "unknown". + """ + if not _AS_ID_RE.search(cypher): + raise SlayerError( + "cypher_filter must return exactly one column aliased 'id' " + "(e.g. 'RETURN n.id AS id')." 
+ ) + match = _NAIVE_PATTERN.match(cypher) + if not match: + raise SlayerError( + "cypher_filter expression is too complex for the naive fallback; " + "install the advanced_search extra: " + "pip install motley-slayer[advanced_search]" + ) + labels_str = match.group("labels") + labels = [lb.strip() for lb in re.split(r"\s*:\s*", labels_str) if lb.strip()] + kinds: Set[str] = set() + for label in labels: + kind = _LABEL_TO_KIND.get(label.lower()) + if kind is None: + raise SlayerError( + f"unknown entity type {label!r} in cypher_filter; " + f"known types: {sorted(_LABEL_TO_KIND)!r}." + ) + kinds.add(kind) + return kinds diff --git a/slayer/search/graph.py b/slayer/search/graph.py index 82e5f6d5..0e8588cf 100644 --- a/slayer/search/graph.py +++ b/slayer/search/graph.py @@ -16,13 +16,16 @@ Memory id STRING (canonical ``memory:`` form), learning STRING Datasource id STRING, name STRING Model id STRING, name STRING, description STRING - Column id STRING, name STRING, data_type STRING, description STRING + ModelColumn id STRING, name STRING, data_type STRING, description STRING Measure id STRING, name STRING, description STRING Aggregation id STRING, name STRING +Note: the node table for schema columns is named ``ModelColumn`` (not ``Column``) +because ``Column`` is a reserved keyword in LadybugDB ≥ 0.15. + Relationship tables: - MENTIONS Memory → {Datasource, Model, Column, Measure, Aggregation, Memory} - CONTAINS Datasource → Model, Model → {Column, Measure, Aggregation} + MENTIONS Memory → {Datasource, Model, ModelColumn, Measure, Aggregation, Memory} + CONTAINS Datasource → Model, Model → {ModelColumn, Measure, Aggregation} JOINS Model → Model All queries must be read-only ``MATCH … RETURN … AS id`` statements. 
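Review note on the naive fallback in `slayer/search/cypher_naive.py` above: the label-to-kind mapping can be tried in isolation with a minimal sketch like the one below. It is not the module itself. `LABEL_TO_KIND`, `NAIVE_PATTERN`, and `parse_labels` are hypothetical stand-ins, `ValueError` replaces `SlayerError` so the snippet has no project imports, and the group names `var` and `labels` are assumed from the `(?P=var)` backreference and the `match.group("labels")` call in the diff.

```python
import re
from typing import Set

# Hypothetical copy of the diff's label -> kind table (labels are lower-cased first).
LABEL_TO_KIND = {
    "memory": "memory",
    "datasource": "datasource",
    "model": "model",
    "modelcolumn": "column",
    "column": "column",
    "measure": "measure",
    "aggregation": "aggregation",
}

# Accepts only: MATCH (var:Label1:Label2...) RETURN var.id AS id
# The group names are an assumption inferred from the surrounding code.
NAIVE_PATTERN = re.compile(
    r"^\s*MATCH\s*\(\s*(?P<var>\w+)\s*:\s*(?P<labels>\w+(?:\s*:\s*\w+)*)\s*\)"
    r"\s*RETURN\s+(?P=var)\.id\s+AS\s+id\s*$",
    flags=re.IGNORECASE,
)


def parse_labels(cypher: str) -> Set[str]:
    """Return the set of kinds named by a simple label-only MATCH."""
    match = NAIVE_PATTERN.match(cypher)
    if match is None:
        # The real module raises SlayerError and points at the advanced_search extra.
        raise ValueError("expression is too complex for the naive fallback")
    kinds: Set[str] = set()
    for label in re.split(r"\s*:\s*", match.group("labels")):
        kind = LABEL_TO_KIND.get(label.lower())
        if kind is None:
            raise ValueError(f"unknown entity type {label!r}")
        kinds.add(kind)
    return kinds


print(parse_labels("MATCH (n:Memory) RETURN n.id AS id"))
print(parse_labels("match (x : ModelColumn : Measure) return x.id as id"))
```

Under these assumptions, both `Column` and `ModelColumn` map to the `column` kind, so a filter written against either label name selects the same hits. Any `WHERE` clause or relationship pattern is rejected instead of being partially interpreted.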
@@ -136,7 +139,7 @@ def _create_schema(conn: Any) -> None: "id STRING, name STRING, description STRING, PRIMARY KEY(id))" ) conn.execute( - "CREATE NODE TABLE Column(" + "CREATE NODE TABLE ModelColumn(" "id STRING, name STRING, data_type STRING, description STRING, PRIMARY KEY(id))" ) conn.execute( @@ -151,7 +154,7 @@ def _create_schema(conn: Any) -> None: "CREATE REL TABLE MENTIONS(" "FROM Memory TO Datasource, " "FROM Memory TO Model, " - "FROM Memory TO Column, " + "FROM Memory TO ModelColumn, " "FROM Memory TO Measure, " "FROM Memory TO Aggregation, " "FROM Memory TO Memory" @@ -160,7 +163,7 @@ def _create_schema(conn: Any) -> None: conn.execute( "CREATE REL TABLE CONTAINS(" "FROM Datasource TO Model, " - "FROM Model TO Column, " + "FROM Model TO ModelColumn, " "FROM Model TO Measure, " "FROM Model TO Aggregation" ")" @@ -169,30 +172,30 @@ def _create_schema(conn: Any) -> None: def _insert_model_child_nodes(conn: Any, canonical_model: str, model: Any) -> None: - """Insert Column, Measure, and Aggregation nodes for one model.""" + """Insert ModelColumn, Measure, and Aggregation nodes for one model.""" for col in model.columns: if col.hidden: continue conn.execute( - "CREATE (:Column {" - "id: $id, name: $name, data_type: $dt, description: $desc" + "CREATE (:ModelColumn {" + "id: $id, name: $name, data_type: $dt, description: $descr" "})", { "id": f"{canonical_model}.{col.name}", "name": col.name, "dt": col.type.value if col.type is not None else "", - "desc": col.description or "", + "descr": col.description or "", }, ) for measure in model.measures: if not measure.name: continue conn.execute( - "CREATE (:Measure {id: $id, name: $name, description: $desc})", + "CREATE (:Measure {id: $id, name: $name, description: $descr})", { "id": f"{canonical_model}.{measure.name}", "name": measure.name, - "desc": measure.description or "", + "descr": measure.description or "", }, ) for agg in model.aggregations: @@ -218,11 +221,11 @@ def _insert_nodes( for canonical_model, 
model in visible_models.items(): _, model_name = canonical_model.split(".", 1) conn.execute( - "CREATE (:Model {id: $id, name: $name, description: $desc})", + "CREATE (:Model {id: $id, name: $name, description: $descr})", { "id": canonical_model, "name": model_name, - "desc": model.description or "", + "descr": model.description or "", }, ) _insert_model_child_nodes(conn, canonical_model, model) @@ -242,7 +245,7 @@ def _insert_contains_edges( datasource_names: list[str], visible_models: dict, ) -> None: - """Insert CONTAINS edges: Datasource→Model and Model→{Column/Measure/Agg}.""" + """Insert CONTAINS edges: Datasource→Model and Model→{ModelColumn/Measure/Agg}.""" ds_set = set(datasource_names) for canonical_model, model in visible_models.items(): ds = canonical_model.split(".", 1)[0] @@ -256,7 +259,7 @@ def _insert_contains_edges( if col.hidden: continue conn.execute( - "MATCH (m:Model {id: $model}), (c:Column {id: $col}) " + "MATCH (m:Model {id: $model}), (c:ModelColumn {id: $col}) " "CREATE (m)-[:CONTAINS]->(c)", {"model": canonical_model, "col": f"{canonical_model}.{col.name}"}, ) @@ -335,7 +338,7 @@ def _connect_entity_mention( ) elif entity in valid_columns: conn.execute( - "MATCH (m:Memory {id: $src}), (c:Column {id: $tgt}) " + "MATCH (m:Memory {id: $src}), (c:ModelColumn {id: $tgt}) " "CREATE (m)-[:MENTIONS]->(c)", {"src": src, "tgt": entity}, ) @@ -368,7 +371,7 @@ def _insert_mentions_edges( visible_models: dict, datasource_names: list[str], ) -> None: - """Insert MENTIONS edges: Memory → {Datasource, Model, Column, Measure, Agg, Memory}.""" + """Insert MENTIONS edges: Memory → {Datasource, Model, ModelColumn, Measure, Agg, Memory}.""" ds_set = set(datasource_names) valid_models: set[str] = set(visible_models) valid_columns, valid_measures, valid_aggs, valid_memory_canonicals = ( @@ -395,7 +398,7 @@ async def build_graph(storage: StorageBackend) -> tuple[Any, Any]: # No-argument Database() creates an ephemeral in-memory instance; # no files are written to 
the working directory. db = mod.Database() - conn = db.connect() + conn = mod.Connection(db) _create_schema(conn) datasource_names = await storage.list_datasources() diff --git a/slayer/search/service.py b/slayer/search/service.py index 896e6173..39b06f31 100644 --- a/slayer/search/service.py +++ b/slayer/search/service.py @@ -22,22 +22,20 @@ the memory ranking, non-memory rows feed the entity ranking. Each partition is ranked in full. -Memory and entity rankings from every active channel are fused via RRF -(``k = 60``). Channel 1's entity ranking is the surviving canonical -inputs in supplied order (DEV-1513); channels 2 and 3 contribute fuzzy -hits. +All rankings from every active channel are fused via RRF (``k = 60``) +into a single flat ``results: List[SearchHit]`` list capped at +``max_results`` (DEV-1532). Channel 1's entity ranking is the surviving +canonical inputs in supplied order (DEV-1513); channels 2 and 3 +contribute fuzzy hits. -Per-bucket invariance (DEV-1414): because each channel produces a full +Ranking stability (DEV-1414): because each channel produces a full per-kind ranking — never truncated by a shared candidate-pool budget — -the membership and order of every output bucket (``memories``, -``example_queries``, ``entities``) is a pure function of the corpus, -the question, the datasource filter, and that bucket's own cap. Varying -the other two caps cannot move ids in or out of the returned list nor -reorder it. +the relative order of any subset of the flat list is stable. Changing +only ``max_results`` never reorders existing entries nor causes an +entry to appear or disappear unless the cap boundary moves past it. Empty input (no entities, no query, no question) falls back to recency: -newest ``max_memories`` learning-only memories + newest -``max_example_queries`` query-bearing memories, with a warning. +newest memories capped at ``max_results``, with a warning. 
""" from __future__ import annotations @@ -60,6 +58,7 @@ resolve_entity, ) from slayer.search import graph as _search_graph +from slayer.search.cypher_naive import parse_naive_label_filter as _parse_naive_cypher from slayer.search.index import ( Corpus, IndexHit, @@ -82,49 +81,23 @@ # --------------------------------------------------------------------------- -class MemoryHit(BaseModel): - """A learning-only memory result (``Memory.query is None``). ``id`` is - the string memory id (suitable for ``forget_memory(id=hit.id)``). - ``score`` is always the Reciprocal-Rank-Fusion score - (``Σ 1 / (k + rank)``, ``k=60``); even single-channel searches go - through RRF, so the value is comparable across channels but is not - directly the raw BM25 / tantivy / cosine score.""" +class SearchHit(BaseModel): + """A unified search result. ``kind`` is ``"memory"`` for memories, + or the entity kind string (``"datasource"``, ``"model"``, + ``"column"``, ``"measure"``, ``"aggregation"``) for entity hits. - id: str - score: float - text: str - matched_entities: List[str] = Field(default_factory=list) - - -class ExampleQueryHit(BaseModel): - """A query-bearing memory result (``Memory.query`` is set). Same id / - score / text shape as ``MemoryHit`` but always carries the attached - ``SlayerQuery``. Surfaces in ``SearchResponse.example_queries`` — - bulky reference material, capped independently from learning-only - memories so it cannot crowd them out.""" + ``id`` is the raw storage id for memories (suitable for + ``forget_memory(id=hit.id)``) and the canonical entity string for + entity hits. ``score`` is the RRF-fused score (``Σ 1/(k+rank)``, + ``k=60``). ``matched_entities`` and ``query`` are populated for + memory hits only; entity hits carry empty defaults.""" + kind: str id: str score: float text: str matched_entities: List[str] = Field(default_factory=list) - query: SlayerQuery - - -class EntityHit(BaseModel): - """An entity result. 
``id`` is the canonical entity string - (``""``, ``"."``, or ``".."``). - ``score`` is the RRF-fused score across channels 1, 2, and 3 (or the - single-channel raw score when only one channel contributed). - - DEV-1513: channel 1 contributes named-entity surfacing via the - implicit self-reference model — each entity is conceptually tagged - with itself, so a user-supplied ref in ``entities=`` ranks at the - top of the entities bucket alongside any fuzzy hits.""" - - id: str - kind: str # "datasource" | "model" | "column" | "measure" | "aggregation" - score: float - text: str + query: Optional[SlayerQuery] = None # --------------------------------------------------------------------------- @@ -159,9 +132,7 @@ class LookupMissing(BaseModel): class SearchResponse(BaseModel): - memories: List[MemoryHit] = Field(default_factory=list) - example_queries: List[ExampleQueryHit] = Field(default_factory=list) - entities: List[EntityHit] = Field(default_factory=list) + results: List[SearchHit] = Field(default_factory=list) resolved_input_entities: List[str] = Field(default_factory=list) warnings: List[str] = Field(default_factory=list) @@ -220,11 +191,9 @@ def _build_memory_hit( index_hits_by_memory_id: dict, canonical_input_entities: List[str], valid_canonicals: Optional[set] = None, -) -> Union["MemoryHit", "ExampleQueryHit"]: - """Build the appropriate hit type for ``mem``: ``MemoryHit`` for - learning-only memories (``query is None``), ``ExampleQueryHit`` for - query-bearing ones. ``text`` falls back to ``mem.learning`` when the - memory wasn't reached via tantivy. +) -> "SearchHit": + """Build a SearchHit for a memory. ``text`` falls back to + ``mem.learning`` when the memory wasn't reached via tantivy. 
DEV-1428: ``matched_entities`` is computed against the LIVE canonical set when ``valid_canonicals`` is supplied, so stale tags @@ -232,8 +201,7 @@ def _build_memory_hit( DEV-1513: every memory has an implicit ``memory:`` self-reference; it appears in ``matched_entities`` only when the - user explicitly named that ref (so the surfaced memory honestly - shows the reason it was returned).""" + user explicitly named that ref.""" if valid_canonicals is not None: live_entities = [e for e in mem.entities if e in valid_canonicals] else: @@ -248,13 +216,13 @@ def _build_memory_hit( if memory_id in index_hits_by_memory_id else mem.learning ) - if mem.query is None: - return MemoryHit( - id=memory_id, score=score, text=text, matched_entities=matched, - ) - return ExampleQueryHit( - id=memory_id, score=score, text=text, - matched_entities=matched, query=mem.query, + return SearchHit( + kind="memory", + id=memory_id, + score=score, + text=text, + matched_entities=matched, + query=mem.query, ) @@ -274,48 +242,96 @@ def _filter_memories_entities( return out -def _fuse_memory_hits( +def _query_bearing_memory_hits(hits: List["SearchHit"]) -> List["SearchHit"]: + """Return hits that are query-bearing memories (kind=='memory', query set).""" + return [h for h in hits if h.kind == "memory" and h.query is not None] + + +def _build_hit_from_fused_key( + *, + key: str, + score: float, + memory_by_id: dict, + index_hits_by_memory_id: dict, + canonical_input_entities: List[str], + corpus: Optional["Corpus"], + named_kind_text: Optional[Dict[str, Tuple[str, str]]], + valid_canonicals: Optional[set], + kind_filter: Optional[Set[str]], +) -> Optional["SearchHit"]: + """Build one SearchHit from a fused (key, score) pair, or return None to skip.""" + if key.startswith(_MEMORY_PREFIX): + memory_id = key[len(_MEMORY_PREFIX):] + mem = memory_by_id.get(memory_id) + if mem is None or (kind_filter is not None and "memory" not in kind_filter): + return None + return _build_memory_hit( + mem=mem, + 
memory_id=memory_id, + score=score, + index_hits_by_memory_id=index_hits_by_memory_id, + canonical_input_entities=canonical_input_entities, + valid_canonicals=valid_canonicals, + ) + resolved = _resolve_entity_hit_kind_text( + canonical=key, + corpus=corpus, + named_kind_text=named_kind_text, + ) + if resolved is None: + return None + kind, text = resolved + if kind_filter is not None and kind not in kind_filter: + return None + return SearchHit(id=key, kind=kind, score=score, text=text) + + +def _fuse_all_hits( *, - rankings: List[List[str]], + memory_rankings: List[List[str]], + entity_rankings: List[List[str]], memory_by_id: dict, index_hits_by_memory_id: dict, canonical_input_entities: List[str], - max_memories: int, - max_example_queries: int, + corpus: Optional["Corpus"], + named_kind_text: Optional[Dict[str, Tuple[str, str]]], + max_results: int, valid_canonicals: Optional[set] = None, -) -> Tuple[List["MemoryHit"], List["ExampleQueryHit"]]: - """RRF-fuse the supplied memory rankings and partition into - learning-only (``MemoryHit``) vs query-bearing (``ExampleQueryHit``) - lists, each capped independently. Empty inner rankings are filtered - out so single-channel results still flow through RRF normalisation.""" - non_empty = [r for r in rankings if r] + kind_filter: Optional[Set[str]] = None, +) -> List["SearchHit"]: + """RRF-fuse memory and entity rankings into a single flat list. + + Memory IDs are prefixed with the canonical memory prefix so the + unified pool contains no key collisions. 
Kind filter (naive Cypher + fallback) is applied BEFORE the max_results cap so the caller always + gets up to max_results matching items.""" + prefixed_memory_rankings = [ + [f"{_MEMORY_PREFIX}{mid}" for mid in ranking] + for ranking in memory_rankings + ] + all_rankings = prefixed_memory_rankings + entity_rankings + non_empty = [r for r in all_rankings if r] fused = rrf_fuse(rankings=non_empty, k=_RRF_K) if non_empty else {} fused_sorted = sorted(fused.items(), key=lambda kv: kv[1], reverse=True) - learnings: List[MemoryHit] = [] - examples: List[ExampleQueryHit] = [] - for memory_id, score in fused_sorted: - mem = memory_by_id.get(memory_id) - if mem is None: - continue - hit = _build_memory_hit( - mem=mem, - memory_id=memory_id, + results: List[SearchHit] = [] + for key, score in fused_sorted: + hit = _build_hit_from_fused_key( + key=key, score=score, + memory_by_id=memory_by_id, index_hits_by_memory_id=index_hits_by_memory_id, canonical_input_entities=canonical_input_entities, + corpus=corpus, + named_kind_text=named_kind_text, valid_canonicals=valid_canonicals, + kind_filter=kind_filter, ) - if isinstance(hit, MemoryHit) and len(learnings) < max_memories: - learnings.append(hit) - elif isinstance(hit, ExampleQueryHit) and len(examples) < max_example_queries: - examples.append(hit) - if ( - len(learnings) >= max_memories - and len(examples) >= max_example_queries - ): - break - return learnings, examples + if hit is not None: + results.append(hit) + if len(results) >= max_results: + break + return results def _filter_memories_by_datasource( @@ -445,37 +461,6 @@ def _resolve_entity_hit_kind_text( return None -def _fuse_entity_hits( - *, - rankings: List[List[str]], - corpus: Optional[Corpus], - named_kind_text: Optional[Dict[str, Tuple[str, str]]], - max_entities: int, -) -> List[EntityHit]: - """RRF-fuse the entity rankings and look text/kind up via - ``_resolve_entity_hit_kind_text`` (corpus first, then channel-1 - named-entity fallback, DEV-1513). 
Returns at most ``max_entities`` - hits.""" - non_empty = [r for r in rankings if r] - fused = rrf_fuse(rankings=non_empty, k=_RRF_K) if non_empty else {} - fused_sorted = sorted(fused.items(), key=lambda kv: kv[1], reverse=True) - out: List[EntityHit] = [] - for canonical, score in fused_sorted: - if len(out) >= max_entities: - break - resolved = _resolve_entity_hit_kind_text( - canonical=canonical, - corpus=corpus, - named_kind_text=named_kind_text, - ) - if resolved is None: - continue - kind, text = resolved - out.append(EntityHit( - id=canonical, kind=kind, score=score, text=text, - )) - return out - # --------------------------------------------------------------------------- # DEV-1513: implicit-self-reference helpers @@ -659,18 +644,10 @@ async def search( question: Optional[str] = None, datasource: Optional[str] = None, cypher_filter: Optional[str] = None, - max_memories: int = 5, - max_example_queries: int = 2, - max_entities: int = 5, + max_results: int = 10, ) -> SearchResponse: - if max_memories < 0: - raise ValueError(f"max_memories must be >= 0; got {max_memories}.") - if max_example_queries < 0: - raise ValueError( - f"max_example_queries must be >= 0; got {max_example_queries}." - ) - if max_entities < 0: - raise ValueError(f"max_entities must be >= 0; got {max_entities}.") + if max_results < 1: + raise ValueError(f"max_results must be >= 1; got {max_results}.") await self._validate_datasource_known(datasource) canonical_input_entities, warnings = await self._resolve_inputs( @@ -680,31 +657,32 @@ async def search( question_active = bool(question and question.strip()) # Cypher pre-filter: run before channels, short-circuit on empty result. - candidate_ids: Optional[FrozenSet[str]] = None - if cypher_filter is not None: - candidate_ids = await _search_graph.get_filtered_ids( - cypher_filter, self._storage, - ) - if not candidate_ids: - warnings.append( - "cypher_filter returned no matching nodes; " - "search returned no results." 
- ) - return SearchResponse( - memories=[], - example_queries=[], - entities=[], - resolved_input_entities=canonical_input_entities, - warnings=_dedup(warnings), - ) + # When advanced_search is absent, attempt naive label-filter parsing. + candidate_ids, kind_filter, early = await self._apply_cypher_filter( + cypher_filter=cypher_filter, + canonical_input_entities=canonical_input_entities, + warnings=warnings, + ) + if early is not None: + return early + # Naive kind_filter parity with graph path: warn when a named + # memory: ref would be excluded by the kind filter so the + # caller knows why it doesn't appear in results. + if kind_filter is not None and "memory" not in kind_filter: + for canonical in canonical_input_entities: + if canonical.startswith(_MEMORY_PREFIX): + warnings.append( + f"{canonical} excluded by cypher_filter kind filter " + f"(allowed kinds: {sorted(kind_filter)!r})." + ) # Recency fallback for the all-empty case. if not channel_1_active and not question_active: return await self._recency_fallback( datasource=datasource, candidate_ids=candidate_ids, - max_memories=max_memories, - max_example_queries=max_example_queries, + kind_filter=kind_filter, + max_results=max_results, warnings=warnings, ) @@ -826,49 +804,44 @@ async def search( mem_ids=channel_3_memory_ranking, ) - memory_hits, example_query_hits = _fuse_memory_hits( - rankings=[ + all_hits = _fuse_all_hits( + memory_rankings=[ channel_1_memory_ranking, channel_2_memory_ranking, channel_3_memory_ranking, ], + entity_rankings=[ + channel_1_entity_ranking, + channel_2_entity_ranking, + channel_3_entity_ranking, + ], memory_by_id=memory_by_id, index_hits_by_memory_id=index_hits_by_memory_id, canonical_input_entities=canonical_input_entities, - max_memories=max_memories, - max_example_queries=max_example_queries, + corpus=corpus, + named_kind_text=named_kind_text, + max_results=max_results, valid_canonicals=valid_canonicals, + kind_filter=kind_filter, ) - # DEV-1428: stale Memory.query 
warnings — surface example_queries - # whose attached query references entities that no longer resolve. + query_bearing_hits = _query_bearing_memory_hits(all_hits) + # DEV-1428: stale Memory.query warnings — surface memories whose + # attached query references entities that no longer resolve. # DEV-1513: ALSO emit the warning for any explicitly-named # ``memory:`` ref whose memory carries a stale query, even - # when ``max_example_queries`` suppressed the hit — the user - # explicitly asked for that memory. + # when max_results suppressed the hit. warnings = _dedup( warnings + await self._stale_query_warnings( - example_query_hits=example_query_hits, + query_bearing_hits=query_bearing_hits, memory_by_id=memory_by_id, ) + await self._stale_query_warnings_for_named_memory_refs( canonical_input_entities=canonical_input_entities, all_memories=all_memories, - already_warned_ids={h.id for h in example_query_hits}, + already_warned_ids={h.id for h in query_bearing_hits}, ) ) - entity_hits = _fuse_entity_hits( - rankings=[ - channel_1_entity_ranking, - channel_2_entity_ranking, - channel_3_entity_ranking, - ], - corpus=corpus, - named_kind_text=named_kind_text, - max_entities=max_entities, - ) return SearchResponse( - memories=memory_hits, - example_queries=example_query_hits, - entities=entity_hits, + results=all_hits, resolved_input_entities=canonical_input_entities, warnings=warnings, ) @@ -921,23 +894,55 @@ async def _resolve_inputs( warnings.extend(extraction.warnings) return _dedup(canonical), _dedup(warnings) + async def _apply_cypher_filter( + self, + *, + cypher_filter: Optional[str], + canonical_input_entities: List[str], + warnings: List[str], + ) -> Tuple[Optional[FrozenSet[str]], Optional[Set[str]], Optional[SearchResponse]]: + """Resolve the optional cypher_filter into (candidate_ids, kind_filter). + + Returns a 3-tuple: + - candidate_ids: non-None when the full graph path ran (advanced_search). + - kind_filter: non-None when the naive fallback ran. 
+ - early: a short-circuit SearchResponse when the graph returned no ids. + """ + if cypher_filter is None: + return None, None, None + if _search_graph.is_available(): + candidate_ids = await _search_graph.get_filtered_ids( + cypher_filter, self._storage, + ) + if not candidate_ids: + early_warnings = _dedup( + warnings + [ + "cypher_filter returned no matching nodes; " + "search returned no results." + ] + ) + return candidate_ids, None, SearchResponse( + results=[], + resolved_input_entities=canonical_input_entities, + warnings=early_warnings, + ) + return candidate_ids, None, None + return None, _parse_naive_cypher(cypher_filter), None + async def _recency_fallback( self, *, - max_memories: int, - max_example_queries: int, + max_results: int, warnings: List[str], datasource: Optional[str] = None, candidate_ids: Optional[FrozenSet[str]] = None, + kind_filter: Optional[Set[str]] = None, ) -> SearchResponse: - """Empty-input branch: partition all memories by recency into the - learning-only bucket (``memories``, capped by ``max_memories``) - and the query-bearing bucket (``example_queries``, capped by - ``max_example_queries``). + """Empty-input branch: return the newest memories (both learning-only + and query-bearing) as a flat list, capped by max_results. DEV-1409: when ``datasource`` is set, the same memory pre-filter - used by the main search path applies — only memories with at - least one entity rooted at the requested datasource are eligible. + used by the main search path applies. 
""" warnings.append( "no entities, query, or question supplied; returning " @@ -952,47 +957,35 @@ async def _recency_fallback( m for m in recency_memories if f"memory:{m.id}" in candidate_ids ] + if kind_filter is not None and "memory" not in kind_filter: + recency_memories = [] recency_memories.sort(key=lambda m: m.created_at, reverse=True) valid_canonicals = await self._valid_canonical_set( all_memories=recency_memories, datasource=datasource, ) - memory_hits: List[MemoryHit] = [] - example_query_hits: List[ExampleQueryHit] = [] + hits: List[SearchHit] = [] for m in recency_memories: - hit = _build_memory_hit( + if len(hits) >= max_results: + break + hits.append(_build_memory_hit( mem=m, memory_id=m.id, score=0.0, index_hits_by_memory_id={}, canonical_input_entities=[], valid_canonicals=valid_canonicals, - ) - if isinstance(hit, MemoryHit) and len(memory_hits) < max_memories: - memory_hits.append(hit) - elif ( - isinstance(hit, ExampleQueryHit) - and len(example_query_hits) < max_example_queries - ): - example_query_hits.append(hit) - if ( - len(memory_hits) >= max_memories - and len(example_query_hits) >= max_example_queries - ): - break - # DEV-1428: emit stale-Memory.query warnings on the recency path - # too; otherwise an empty-input search would silently return - # example_queries whose attached queries no longer resolve. + )) + # DEV-1428: emit stale-Memory.query warnings on the recency path too. 
memory_by_id = {m.id: m for m in recency_memories} + query_bearing = [h for h in hits if h.query is not None] warnings = _dedup( warnings + await self._stale_query_warnings( - example_query_hits=example_query_hits, + query_bearing_hits=query_bearing, memory_by_id=memory_by_id, ) ) return SearchResponse( - memories=memory_hits, - example_queries=example_query_hits, - entities=[], + results=hits, resolved_input_entities=[], warnings=warnings, ) @@ -1370,15 +1363,15 @@ async def _collect_model_subtree_canonicals( async def _stale_query_warnings( self, *, - example_query_hits: List["ExampleQueryHit"], + query_bearing_hits: List["SearchHit"], memory_by_id: Dict[str, Memory], ) -> List[str]: - """DEV-1428: emit one warning per example_queries hit whose + """DEV-1428: emit one warning per query-bearing memory hit whose attached ``Memory.query`` references entities that no longer resolve. The query is NOT rewritten — agents who notice the warning can re-save the memory to clean it.""" out: List[str] = [] - for hit in example_query_hits: + for hit in query_bearing_hits: mem = memory_by_id.get(hit.id) if mem is None or mem.query is None: continue diff --git a/tests/integration/test_integration.py b/tests/integration/test_integration.py index 63bdc67c..dbd3d6ea 100644 --- a/tests/integration/test_integration.py +++ b/tests/integration/test_integration.py @@ -3123,26 +3123,25 @@ async def test_search_ingest_populates_sampled(search_env): async def test_search_question_finds_column(search_env): - """``search(question="amount")`` returns a column EntityHit pointing at + """``search(question="amount")`` returns a column hit pointing at one of the seeded ``amount``-named columns.""" from slayer.search.service import SearchService _engine, storage = search_env response = await SearchService(storage=storage).search( question="amount", - max_entities=10, - max_memories=0, + max_results=20, ) - column_hits = [e for e in response.entities if e.kind == "column"] + column_hits = [h for h 
in response.results if h.kind == "column"] assert column_hits, ( - "expected at least one column EntityHit; got entities=" - f"{[(e.id, e.kind) for e in response.entities]}" + "expected at least one column hit; got results=" + f"{[(h.id, h.kind) for h in response.results]}" ) # The question is "amount" — only ``*.amount`` columns are # acceptable. Accepting any column hit (e.g. ``customers.region``) # would mask a relevance regression in the search ranker. assert any(h.id.endswith(".amount") for h in column_hits), ( - "expected an `.amount` column EntityHit; got column hits=" + "expected an `.amount` column hit; got column hits=" f"{[h.id for h in column_hits]}" ) @@ -3155,10 +3154,11 @@ async def test_search_entity_filter_finds_memory(search_env): _engine, storage = search_env response = await SearchService(storage=storage).search( entities=["test_sqlite.orders"], - max_memories=5, + max_results=10, ) - assert len(response.memories) >= 1 - hit = response.memories[0] + memory_hits = [h for h in response.results if h.kind == "memory"] + assert len(memory_hits) >= 1 + hit = memory_hits[0] assert "test_sqlite.orders" in hit.matched_entities assert "refunds" in hit.text diff --git a/tests/search_helpers.py b/tests/search_helpers.py new file mode 100644 index 00000000..d9e4ac48 --- /dev/null +++ b/tests/search_helpers.py @@ -0,0 +1,60 @@ +"""Shared helpers for search-related tests.""" + +from __future__ import annotations + +from typing import Any + +from slayer.core.enums import DataType +from slayer.core.models import Column, DatasourceConfig, SlayerModel +from slayer.storage.base import StorageBackend + + +async def seed_warehouse_models(storage: StorageBackend) -> None: + """Seed the standard warehouse datasource + orders + customers models. + + Centralised here to avoid Sonar duplication-density failures across + tests that each need the same model corpus. 
+ """ + await storage.save_datasource( + DatasourceConfig(name="warehouse", type="sqlite", database=":memory:") + ) + await storage.save_model(SlayerModel( + name="orders", + sql_table="orders", + data_source="warehouse", + description="Checkout orders.", + columns=[ + Column(name="id", type=DataType.INT, primary_key=True), + Column(name="amount_paid", type=DataType.DOUBLE, + description="Net paid in USD."), + Column(name="status", type=DataType.TEXT, + description="paid|refunded|cancelled."), + ], + )) + await storage.save_model(SlayerModel( + name="customers", + sql_table="customers", + data_source="warehouse", + description="Customer master data.", + columns=[ + Column(name="id", type=DataType.INT, primary_key=True), + Column(name="email", type=DataType.TEXT), + ], + )) + + +async def call_mcp_tool(*, mcp: Any, name: str, arguments: dict) -> str: + """Invoke an MCP tool and return its text result.""" + result = await mcp.call_tool(name, arguments) + if isinstance(result, tuple): + candidates: list = list(result[0]) if result else [] + elif isinstance(result, list): + candidates = result + elif hasattr(result, "content"): + candidates = list(result.content) + else: + return str(result) + for block in candidates: + if hasattr(block, "text"): + return block.text + return str(result) diff --git a/tests/test_cypher_naive.py b/tests/test_cypher_naive.py new file mode 100644 index 00000000..618510c0 --- /dev/null +++ b/tests/test_cypher_naive.py @@ -0,0 +1,163 @@ +"""Unit tests for the naive Cypher label-filter parser (DEV-1532). + +``slayer.search.cypher_naive.parse_naive_label_filter`` parses +``MATCH (var:Label1:Label2:...) RETURN var.id AS id`` patterns (case-insensitive, +whitespace-tolerant) and returns the set of kind strings to filter on. + +When the pattern does not match (complex Cypher), it raises ``SlayerError`` +with an install hint for the ``advanced_search`` extra. +When a label is unrecognised, it raises ``SlayerError`` with "unknown". 
+""" + +from __future__ import annotations + +import pytest + +from slayer.core.errors import SlayerError +from slayer.search.cypher_naive import parse_naive_label_filter + + +# --------------------------------------------------------------------------- +# Happy paths: recognised labels → kind set +# --------------------------------------------------------------------------- + + +def test_single_label_model() -> None: + kinds = parse_naive_label_filter("MATCH (n:Model) RETURN n.id AS id") + assert kinds == {"model"} + + +def test_single_label_memory() -> None: + kinds = parse_naive_label_filter("MATCH (n:Memory) RETURN n.id AS id") + assert kinds == {"memory"} + + +def test_single_label_column() -> None: + kinds = parse_naive_label_filter("MATCH (n:Column) RETURN n.id AS id") + assert kinds == {"column"} + + +def test_single_label_datasource() -> None: + kinds = parse_naive_label_filter("MATCH (n:Datasource) RETURN n.id AS id") + assert kinds == {"datasource"} + + +def test_single_label_measure() -> None: + kinds = parse_naive_label_filter("MATCH (n:Measure) RETURN n.id AS id") + assert kinds == {"measure"} + + +def test_single_label_aggregation() -> None: + kinds = parse_naive_label_filter("MATCH (n:Aggregation) RETURN n.id AS id") + assert kinds == {"aggregation"} + + +def test_multi_label_colon_separated_two() -> None: + kinds = parse_naive_label_filter( + "MATCH (n:Model:Column) RETURN n.id AS id" + ) + assert kinds == {"model", "column"} + + +def test_multi_label_colon_separated_three() -> None: + kinds = parse_naive_label_filter( + "MATCH (n:Memory:Model:Column) RETURN n.id AS id" + ) + assert kinds == {"memory", "model", "column"} + + +def test_all_six_labels_together() -> None: + kinds = parse_naive_label_filter( + "MATCH (n:Memory:Datasource:Model:Column:Measure:Aggregation) " + "RETURN n.id AS id" + ) + assert kinds == {"memory", "datasource", "model", "column", "measure", "aggregation"} + + +# 
--------------------------------------------------------------------------- +# Case-insensitivity and whitespace tolerance +# --------------------------------------------------------------------------- + + +def test_case_insensitive_keyword() -> None: + kinds = parse_naive_label_filter("match (n:Model) return n.id as id") + assert kinds == {"model"} + + +def test_case_insensitive_label() -> None: + kinds = parse_naive_label_filter("MATCH (n:model) RETURN n.id AS id") + assert kinds == {"model"} + + +def test_whitespace_around_labels() -> None: + kinds = parse_naive_label_filter( + "MATCH ( n : Model : Column ) RETURN n.id AS id" + ) + assert kinds == {"model", "column"} + + +def test_extra_whitespace_in_return() -> None: + kinds = parse_naive_label_filter("MATCH (n:Model) RETURN n.id AS id") + assert kinds == {"model"} + + +# --------------------------------------------------------------------------- +# Error cases: unknown label +# --------------------------------------------------------------------------- + + +def test_unknown_label_raises_slayer_error() -> None: + with pytest.raises(SlayerError, match="(?i)unknown"): + parse_naive_label_filter("MATCH (n:Foo) RETURN n.id AS id") + + +def test_unknown_label_in_multi_raises_slayer_error() -> None: + with pytest.raises(SlayerError, match="(?i)unknown"): + parse_naive_label_filter("MATCH (n:Model:Foo) RETURN n.id AS id") + + +# --------------------------------------------------------------------------- +# Error cases: complex Cypher → requires advanced_search +# --------------------------------------------------------------------------- + + +def test_complex_cypher_with_where_raises_advanced_search_error() -> None: + with pytest.raises(SlayerError, match="(?i)advanced_search"): + parse_naive_label_filter( + "MATCH (n:Model) WHERE n.name = 'orders' RETURN n.id AS id" + ) + + +def test_complex_cypher_with_relationship_raises() -> None: + with pytest.raises(SlayerError, match="(?i)advanced_search"): + 
parse_naive_label_filter( + "MATCH (m:Memory)-[:MENTIONS]->(e:Model) RETURN m.id AS id" + ) + + +def test_complex_cypher_multi_clause_raises() -> None: + with pytest.raises(SlayerError, match="(?i)advanced_search"): + parse_naive_label_filter( + "MATCH (m:Memory) MATCH (e:Model) RETURN m.id AS id" + ) + + +def test_bare_match_no_label_raises_advanced_search_error() -> None: + with pytest.raises(SlayerError, match="(?i)advanced_search"): + parse_naive_label_filter("MATCH (n) RETURN n.id AS id") + + +# --------------------------------------------------------------------------- +# Error cases: missing AS id alias +# --------------------------------------------------------------------------- + + +def test_missing_as_id_raises_slayer_error() -> None: + """Queries without 'AS id' are invalid even for the naive path.""" + with pytest.raises(SlayerError): + parse_naive_label_filter("MATCH (n:Model) RETURN n.id") + + +def test_wrong_alias_raises_slayer_error() -> None: + with pytest.raises(SlayerError): + parse_naive_label_filter("MATCH (n:Model) RETURN n.id AS entity_id") diff --git a/tests/test_lightning_talk_notebook.py b/tests/test_lightning_talk_notebook.py index 031af3d1..f7018f27 100644 --- a/tests/test_lightning_talk_notebook.py +++ b/tests/test_lightning_talk_notebook.py @@ -415,7 +415,7 @@ def test_search_demo_has_three_calls(): def test_search_question_cell_targets_brooklyn(): """Cell 18: search(question=Brooklyn-related). Must also assert at - runtime that resp.memories contains the Brooklyn memory.""" + runtime that the Brooklyn memory is in the flat results list.""" nb = _load_notebook() for src in _all_search_cells(nb["cells"]): if "search" not in src: @@ -432,10 +432,10 @@ def test_search_question_cell_targets_brooklyn(): # tantivy retrieve via the learning text). 
assert BROOKLYN_MEMORY_ID in src, ( f"Brooklyn search cell must reference {BROOKLYN_MEMORY_ID!r} " - "in a runtime assertion on resp.memories" + "in a runtime assertion on the results" ) - assert "resp.memories" in src or ".memories" in src, ( - "Brooklyn search cell must assert against resp.memories" + assert "resp.results" in src or ".results" in src, ( + "Brooklyn search cell must assert against resp.results" ) return pytest.fail("Expected a search(question=...) cell referencing Brooklyn") @@ -443,7 +443,7 @@ def test_search_question_cell_targets_brooklyn(): def test_search_entities_cell_uses_order_total(): """Cell 19: search(entities=[order_total]). Must also assert at - runtime that resp.memories contains the Brooklyn memory.""" + runtime that the Brooklyn memory is in results.""" nb = _load_notebook() for src in _all_search_cells(nb["cells"]): try: @@ -454,7 +454,7 @@ def test_search_entities_cell_uses_order_total(): if isinstance(ents, list) and any("order_total" in e for e in ents): assert BROOKLYN_MEMORY_ID in src, ( f"Entities search cell must reference {BROOKLYN_MEMORY_ID!r} " - "in a runtime assertion on resp.memories" + "in a runtime assertion on results" ) return pytest.fail( @@ -464,43 +464,41 @@ def test_search_entities_cell_uses_order_total(): def test_search_discovery_cell_caps_memories_zero_and_lifts_entities(): """The third search call demonstrates entity-discovery + example_queries: - ``max_memories=0`` and ``max_entities >= 1`` and ``max_example_queries >= 1`` - so it surfaces both entity hits and the query-bearing memory. + uses ``max_results`` large enough to surface both entity hits and + query-bearing memories. 
Must also assert at runtime that the top-customers memory appears in - resp.example_queries and resp.entities is non-empty.""" + the example_queries local variable and entities is non-empty.""" nb = _load_notebook() for src in _all_search_cells(nb["cells"]): try: kwargs = _find_kwargs_for_call(src, callee_suffix="search") except AssertionError: continue - if ( - kwargs.get("max_memories") == 0 - and (kwargs.get("max_entities") or 0) >= 1 - and (kwargs.get("max_example_queries") or 0) >= 1 - ): - q = kwargs.get("question") - assert isinstance(q, str) and ( - "customer" in q.lower() or "lifetime" in q.lower() - ), ( - "Discovery search question should be about customer " - "lifetime spend; got: " + repr(q) - ) - # Runtime assertion: top-customers memory in resp.example_queries - assert TOP_CUSTOMERS_MEMORY_ID in src, ( - f"Discovery cell must reference {TOP_CUSTOMERS_MEMORY_ID!r} " - "in a runtime assertion on resp.example_queries" - ) - assert "example_queries" in src, ( - "Discovery cell must assert against resp.example_queries" - ) - assert "entities" in src and ".entities" in src.replace(" ", ""), ( - "Discovery cell must also assert resp.entities is non-empty" - ) - return + # The new API uses max_results instead of separate caps. + # Look for the discovery cell: uses a question about customer lifetime + # and surfaces both example_queries and entities from the flat results. 
+ q = kwargs.get("question") + if not (isinstance(q, str) and ( + "customer" in q.lower() or "lifetime" in q.lower() + )): + continue + if "max_results" not in src: + continue + # Runtime assertion: top-customers memory in example_queries + assert TOP_CUSTOMERS_MEMORY_ID in src, ( + f"Discovery cell must reference {TOP_CUSTOMERS_MEMORY_ID!r} " + "in a runtime assertion on example_queries" + ) + assert "example_queries" in src, ( + "Discovery cell must assert against example_queries local variable" + ) + assert "entities" in src, ( + "Discovery cell must also assert entities is non-empty" + ) + return pytest.fail( - "Expected a search() call with max_memories=0, " - "max_entities>=1, max_example_queries>=1" + "Expected a search() call with a customer/lifetime question, " + "max_results=..., asserting example_queries and entities" ) diff --git a/tests/test_memory_id_in_responses.py b/tests/test_memory_id_in_responses.py index 5d251f6d..33992592 100644 --- a/tests/test_memory_id_in_responses.py +++ b/tests/test_memory_id_in_responses.py @@ -1,7 +1,7 @@ """DEV-1428: all response models carry str ids. ``SaveMemoryResponse.memory_id`` and ``ForgetMemoryResponse.deleted_id`` -flip to ``str``; ``MemoryHit.id`` / ``ExampleQueryHit.id`` also flip. +flip to ``str``; ``SearchHit.id`` also carries a str id for memory hits. 
""" from __future__ import annotations @@ -49,9 +49,10 @@ async def test_memory_hit_id_is_str( learning="orders revenue", entities=["mydb.orders.amount"], ) svc = SearchService(storage=storage) - resp = await svc.search(entities=["mydb.orders.amount"]) - assert resp.memories - assert isinstance(resp.memories[0].id, str) + resp = await svc.search(entities=["mydb.orders.amount"], max_results=20) + memory_hits = [h for h in resp.results if h.kind == "memory" and h.query is None] + assert memory_hits + assert isinstance(memory_hits[0].id, str) async def test_example_query_hit_id_is_str( self, storage: StorageBackend, @@ -66,6 +67,7 @@ async def test_example_query_hit_id_is_str( query=attached, ) svc = SearchService(storage=storage) - resp = await svc.search(entities=["mydb.orders.amount"]) - assert resp.example_queries - assert isinstance(resp.example_queries[0].id, str) + resp = await svc.search(entities=["mydb.orders.amount"], max_results=20) + example_query_hits = [h for h in resp.results if h.kind == "memory" and h.query is not None] + assert example_query_hits + assert isinstance(example_query_hits[0].id, str) diff --git a/tests/test_search_datasource_filter.py b/tests/test_search_datasource_filter.py index ea66dfe5..a60aa149 100644 --- a/tests/test_search_datasource_filter.py +++ b/tests/test_search_datasource_filter.py @@ -92,9 +92,10 @@ async def test_filter_keeps_memory_tagged_only_at_datasource( response = await service.search( entities=["prod.orders.amount"], datasource="prod", - max_memories=10, + max_results=20, ) - learnings = {h.text for h in response.memories} + memory_hits = [h for h in response.results if h.kind == "memory"] + learnings = {h.text for h in memory_hits} assert "prod-only: amount excludes tax" in learnings assert "staging-only: amount includes tax" not in learnings @@ -108,17 +109,17 @@ async def test_filter_keeps_cross_datasource_memory( response_prod = await service.search( entities=["prod.orders.amount"], datasource="prod", - 
max_memories=10, + max_results=20, ) - learnings_prod = {h.text for h in response_prod.memories} + learnings_prod = {h.text for h in response_prod.results if h.kind == "memory"} assert "cross: amount is gross" in learnings_prod response_staging = await service.search( entities=["staging.orders.amount"], datasource="staging", - max_memories=10, + max_results=20, ) - learnings_staging = {h.text for h in response_staging.memories} + learnings_staging = {h.text for h in response_staging.results if h.kind == "memory"} assert "cross: amount is gross" in learnings_staging @@ -132,9 +133,9 @@ async def test_filter_drops_other_datasource_memory( response = await service.search( entities=["prod.orders.amount"], datasource="prod", - max_memories=10, + max_results=20, ) - learnings = {h.text for h in response.memories} + learnings = {h.text for h in response.results if h.kind == "memory"} assert "staging-only: amount includes tax" not in learnings @@ -147,9 +148,9 @@ async def test_filter_drops_untagged_memory( response = await service.search( entities=["prod.orders.amount"], datasource="prod", - max_memories=10, + max_results=20, ) - learnings = {h.text for h in response.memories} + learnings = {h.text for h in response.results if h.kind == "memory"} assert "free-floating note" not in learnings @@ -168,9 +169,9 @@ async def test_filter_excludes_other_datasource_entity_hits( response = await service.search( question="orders amount", datasource="prod", - max_entities=10, + max_results=20, ) - canonical_ids = {h.id for h in response.entities} + canonical_ids = {h.id for h in response.results if h.kind != "memory"} # All entity hits must be rooted at 'prod' (exact or dotted descendant). 
for cid in canonical_ids: assert cid == "prod" or cid.startswith("prod.") @@ -186,9 +187,9 @@ async def test_no_filter_returns_both_datasources( """Sanity: without the filter, both datasources' entities are eligible.""" response = await service.search( - question="orders amount", max_entities=10, + question="orders amount", max_results=20, ) - canonical_ids = {h.id for h in response.entities} + canonical_ids = {h.id for h in response.results if h.kind != "memory"} has_prod = any(cid == "prod" or cid.startswith("prod.") for cid in canonical_ids) has_staging = any( cid == "staging" or cid.startswith("staging.") for cid in canonical_ids @@ -210,7 +211,7 @@ async def test_unknown_datasource_raises( await service.search( question="anything", datasource="does_not_exist", - max_memories=5, + max_results=5, ) @@ -239,9 +240,9 @@ async def test_filter_with_recency_fallback( """No entities, no query, no question → recency fallback. The datasource filter still applies to which memories are eligible.""" response = await service.search( - datasource="prod", max_memories=10, + datasource="prod", max_results=20, ) - learnings = {h.text for h in response.memories} + learnings = {h.text for h in response.results if h.kind == "memory"} # The prod-only and cross memories surface; staging-only and untagged don't. assert "prod-only: amount excludes tax" in learnings assert "cross: amount is gross" in learnings @@ -258,9 +259,9 @@ async def test_none_datasource_is_no_filter( response_none = await service.search( entities=["prod.orders.amount", "staging.orders.amount"], datasource=None, - max_memories=10, + max_results=20, ) - learnings = {h.text for h in response_none.memories} + learnings = {h.text for h in response_none.results if h.kind == "memory"} # All three tagged memories should be eligible. 
assert "prod-only: amount excludes tax" in learnings assert "staging-only: amount includes tax" in learnings @@ -287,13 +288,13 @@ async def test_empty_datasource_returns_empty_response( response = await svc.search( question="anything", datasource="empty_ds", - max_memories=5, - max_entities=5, + max_results=10, ) - assert response.memories == [] - assert response.example_queries == [] + memory_hits = [h for h in response.results if h.kind == "memory"] + assert memory_hits == [] # Entity ranking may include the datasource doc itself, depending # on whether build_in_memory_corpus indexes empty-model datasources. # Either way: zero hits from a non-existent datasource. - for hit in response.entities: - assert hit.id == "empty_ds" or hit.id.startswith("empty_ds.") + for hit in response.results: + if hit.kind != "memory": + assert hit.id == "empty_ds" or hit.id.startswith("empty_ds.") diff --git a/tests/test_search_example_query_stale_warning.py b/tests/test_search_example_query_stale_warning.py index 01fc39e1..e9fdb154 100644 --- a/tests/test_search_example_query_stale_warning.py +++ b/tests/test_search_example_query_stale_warning.py @@ -45,16 +45,18 @@ async def test_stale_query_emits_warning_but_surfaces_memory( await storage.save_model(updated) svc = SearchService(storage=storage) - resp = await svc.search(question="paid revenue") - if not resp.example_queries: - resp = await svc.search(question="revenue") + resp = await svc.search(question="paid revenue", max_results=20) + example_queries = [h for h in resp.results if h.kind == "memory" and h.query is not None] + if not example_queries: + resp = await svc.search(question="revenue", max_results=20) + example_queries = [h for h in resp.results if h.kind == "memory" and h.query is not None] # The query-bearing memory must still surface. 
- eq_ids = [eq.id for eq in resp.example_queries] + eq_ids = [eq.id for eq in example_queries] assert str(seed.id) in eq_ids, ( f"expected memory {seed.id!r} in example_queries; got {eq_ids}" ) # And its attached query must be unchanged (we don't rewrite). - eq = next(e for e in resp.example_queries if e.id == str(seed.id)) + eq = next(e for e in example_queries if e.id == str(seed.id)) assert eq.query.source_model == "orders" assert eq.query.measures is not None assert eq.query.measures[0].formula == "amount:sum" diff --git a/tests/test_search_graph.py b/tests/test_search_graph.py index 0e426b81..fcb39594 100644 --- a/tests/test_search_graph.py +++ b/tests/test_search_graph.py @@ -275,11 +275,11 @@ async def test_join_sync_storage_fingerprint_delegates( async def test_cypher_missing_id_column_friendly_error( rich_storage: YAMLStorage, ) -> None: - # Query passes _validate_cypher (AS id appears in WITH clause) + # Query passes _validate_cypher (AS id appears in the WITH clause) # but the final RETURN exposes 'name' instead of 'id'. 
with pytest.raises(ValueError, match="must return a column named 'id'"): await get_filtered_ids( - "MATCH (m:Memory) WITH m.id AS id RETURN m.learning AS name", + "MATCH (m:Memory) WITH m.id AS id RETURN id AS name", rich_storage, ) @@ -325,7 +325,7 @@ async def test_graph_model_node_count(rich_storage: YAMLStorage) -> None: @pytest.mark.asyncio async def test_graph_column_node_count(rich_storage: YAMLStorage) -> None: ids = await get_filtered_ids( - "MATCH (c:Column) RETURN c.id AS id", rich_storage + "MATCH (c:ModelColumn) RETURN c.id AS id", rich_storage ) # orders: 3, customers: 3, events: 2 → 8 total assert len(ids) == 8 @@ -372,7 +372,7 @@ async def test_memory_node_learning_property(rich_storage: YAMLStorage) -> None: @pytest.mark.asyncio async def test_column_node_data_type_property(rich_storage: YAMLStorage) -> None: ids = await get_filtered_ids( - "MATCH (c:Column {data_type: 'DOUBLE'}) RETURN c.id AS id", + "MATCH (c:ModelColumn {data_type: 'DOUBLE'}) RETURN c.id AS id", rich_storage, ) assert "shop.orders.amount" in ids @@ -399,7 +399,7 @@ async def test_model_description_property(rich_storage: YAMLStorage) -> None: @pytest.mark.asyncio async def test_mentions_memory_to_column(rich_storage: YAMLStorage) -> None: ids = await get_filtered_ids( - "MATCH (m:Memory)-[:MENTIONS]->(c:Column {id: 'shop.orders.amount'}) " + "MATCH (m:Memory)-[:MENTIONS]->(c:ModelColumn {id: 'shop.orders.amount'}) " "RETURN m.id AS id", rich_storage, ) @@ -506,7 +506,7 @@ async def test_contains_datasource_to_model(rich_storage: YAMLStorage) -> None: @pytest.mark.asyncio async def test_contains_model_to_column(rich_storage: YAMLStorage) -> None: ids = await get_filtered_ids( - "MATCH (m:Model {id: 'shop.orders'})-[:CONTAINS]->(c:Column) " + "MATCH (m:Model {id: 'shop.orders'})-[:CONTAINS]->(c:ModelColumn) " "RETURN c.id AS id", rich_storage, ) @@ -541,7 +541,7 @@ async def test_contains_multi_hop_datasource_to_column( rich_storage: YAMLStorage, ) -> None: ids = await 
get_filtered_ids( - "MATCH (d:Datasource {id: 'shop'})-[:CONTAINS*2]->(c:Column) " + "MATCH (d:Datasource {id: 'shop'})-[:CONTAINS*2]->(c:ModelColumn) " "RETURN c.id AS id", rich_storage, ) @@ -575,7 +575,7 @@ async def test_joins_model_to_model(rich_storage: YAMLStorage) -> None: @pytest.mark.asyncio async def test_multi_label_union_column_and_measure(rich_storage: YAMLStorage) -> None: ids = await get_filtered_ids( - "MATCH (n:Column:Measure) RETURN n.id AS id", rich_storage + "MATCH (n:ModelColumn:Measure) RETURN n.id AS id", rich_storage ) assert "shop.orders.amount" in ids # Column assert "shop.orders.total_revenue" in ids # Measure @@ -585,7 +585,7 @@ async def test_multi_label_union_column_and_measure(rich_storage: YAMLStorage) - @pytest.mark.asyncio async def test_multi_label_all_data_model_entities(rich_storage: YAMLStorage) -> None: ids = await get_filtered_ids( - "MATCH (n:Datasource:Model:Column:Measure:Aggregation) RETURN n.id AS id", + "MATCH (n:Datasource:Model:ModelColumn:Measure:Aggregation) RETURN n.id AS id", rich_storage, ) assert "shop" in ids # Datasource @@ -622,7 +622,7 @@ async def test_hidden_column_excluded_from_graph() -> None: ) clear_cache() all_cols = await get_filtered_ids( - "MATCH (c:Column) RETURN c.id AS id", storage + "MATCH (c:ModelColumn) RETURN c.id AS id", storage ) assert "ds.orders.internal_flag" not in all_cols assert "ds.orders.id" in all_cols @@ -788,10 +788,10 @@ def test_validate_cypher_accepts_match_return() -> None: _validate_cypher("MATCH (m:Memory) RETURN m.id AS id") _validate_cypher( - "MATCH (m:Memory)-[:MENTIONS]->(e:Column) WHERE e.id = 'x' RETURN m.id AS id" + "MATCH (m:Memory)-[:MENTIONS]->(e:ModelColumn) WHERE e.id = 'x' RETURN m.id AS id" ) _validate_cypher( - "MATCH (e:Column)<-[:CONTAINS*1..3]-(d:Datasource) RETURN e.id AS id" + "MATCH (e:ModelColumn)<-[:CONTAINS*1..3]-(d:Datasource) RETURN e.id AS id" ) @@ -826,7 +826,7 @@ def test_validate_cypher_rejects_invalid(bad_cypher: str) -> None: 
@pytest.mark.asyncio async def test_cypher_property_filter(rich_storage: YAMLStorage) -> None: ids = await get_filtered_ids( - "MATCH (c:Column {id: 'shop.orders.amount'}) RETURN c.id AS id", + "MATCH (c:ModelColumn {id: 'shop.orders.amount'}) RETURN c.id AS id", rich_storage, ) assert ids == {"shop.orders.amount"} @@ -838,7 +838,7 @@ async def test_cypher_path_traversal_memories_of_model( rich_storage: YAMLStorage, ) -> None: ids = await get_filtered_ids( - "MATCH (m:Memory)-[:MENTIONS]->(e:Column) " + "MATCH (m:Memory)-[:MENTIONS]->(e:ModelColumn) " "WHERE e.id STARTS WITH 'shop.orders.' " "RETURN m.id AS id", rich_storage, @@ -879,9 +879,7 @@ async def test_zero_match_cypher_returns_empty_response_with_warning( "MATCH (m:Memory {id: 'memory:nonexistent-9999'}) RETURN m.id AS id" ), ) - assert response.memories == [] - assert response.example_queries == [] - assert response.entities == [] + assert response.results == [] assert any("zero" in w.lower() or "no" in w.lower() for w in response.warnings) @@ -946,16 +944,18 @@ async def test_no_cypher_filter_does_not_invoke_graph( @pytest.mark.asyncio -async def test_ladybug_not_installed_with_cypher_filter_raises( +async def test_ladybug_not_installed_with_cypher_filter_uses_naive_fallback( rich_storage: YAMLStorage, ) -> None: + """When ladybug is absent, cypher_filter falls back to the naive label parser + rather than raising — search still returns a valid SearchResponse.""" with patch("slayer.search.graph.is_available", return_value=False): service = SearchService(storage=rich_storage) - with pytest.raises(ValueError, match="(?i)ladybug|graph|not installed"): - await service.search( - entities=["shop.orders.amount"], - cypher_filter="MATCH (m:Memory) RETURN m.id AS id", - ) + response = await service.search( + entities=["shop.orders.amount"], + cypher_filter="MATCH (m:Memory) RETURN m.id AS id", + ) + assert isinstance(response, SearchResponse) @pytest.mark.asyncio @@ -1031,7 +1031,7 @@ async def 
test_candidate_ids_results_are_exact_subset_of_cypher_output( Cypher — no memory outside the candidate set must surface.""" # Cypher: only the memory that mentions shop.orders.amount cypher = ( - "MATCH (m:Memory)-[:MENTIONS]->(c:Column {id: 'shop.orders.amount'}) " + "MATCH (m:Memory)-[:MENTIONS]->(c:ModelColumn {id: 'shop.orders.amount'}) " "RETURN m.id AS id" ) expected_candidates = await get_filtered_ids(cypher, rich_storage) @@ -1041,10 +1041,10 @@ async def test_candidate_ids_results_are_exact_subset_of_cypher_output( response = await service.search( entities=["shop.orders.amount"], cypher_filter=cypher, - max_memories=10, + max_results=20, ) - returned_ids = {h.id for h in response.memories} + returned_ids = {h.id for h in response.results if h.kind == "memory"} # Every returned memory id, when prefixed with "memory:", must be in candidates for bare_id in returned_ids: assert f"memory:{bare_id}" in expected_candidates, ( @@ -1070,12 +1070,12 @@ async def test_cypher_filter_and_datasource_are_intersected( entities=["shop.orders.amount"], datasource="analytics", cypher_filter=( - "MATCH (m:Memory)-[:MENTIONS]->(c:Column) " + "MATCH (m:Memory)-[:MENTIONS]->(c:ModelColumn) " "WHERE c.id STARTS WITH 'shop.' " "RETURN m.id AS id" ), ) - assert response.memories == [] + assert [h for h in response.results if h.kind == "memory"] == [] # --------------------------------------------------------------------------- @@ -1109,7 +1109,7 @@ async def test_stale_canonical_ids_in_candidate_set_cause_no_error( question="deleted memory", ) assert isinstance(response, SearchResponse) - result_ids = {h.id for h in response.memories} + result_ids = {h.id for h in response.results if h.kind == "memory"} assert mem.id not in result_ids @@ -1133,7 +1133,7 @@ def test_validate_cypher_accepts_mutation_keyword_in_string_literal() -> None: ) # Double-quoted literal with DROP. 
_validate_cypher( - 'MATCH (c:Column) WHERE c.description = "drop rate" RETURN c.id AS id' + 'MATCH (c:ModelColumn) WHERE c.description = "drop rate" RETURN c.id AS id' ) @@ -1177,6 +1177,6 @@ async def test_cypher_filter_applied_to_recency_fallback( cypher_filter="MATCH (m:Memory) RETURN m.id AS id", ) - result_ids = {h.id for h in response.memories} + result_ids = {h.id for h in response.results if h.kind == "memory"} assert mem_in.id in result_ids assert mem_out.id not in result_ids diff --git a/tests/test_search_invariance.py b/tests/test_search_invariance.py index 1a5d82be..239cf447 100644 --- a/tests/test_search_invariance.py +++ b/tests/test_search_invariance.py @@ -1,18 +1,15 @@ -"""Per-bucket ranking invariance (DEV-1414). - -For a fixed ``(question, datasource, max_X)``, the user-visible list of -``X`` (``memories`` / ``example_queries`` / ``entities``) must be a pure -function of the corpus + question + that one cap. Changing the OTHER -two caps must not move any id in or out of the returned ``X`` list, -nor reorder it. - -These tests exercise the bug reported in DEV-1414: the previous -``over_fetch_budget = max(max_memories + max_example_queries, -max_entities) * 5`` shared one candidate-pool cap across all three -channels, so changing ``max_entities`` or ``max_example_queries`` would -push memories in or out of the bottom of each channel's per-kind -ranking — and the membership/order at the top of the fused memory list -would shift even though the question and ``max_memories`` were fixed. +"""Per-result ranking invariance. + +For a fixed ``(question, datasource)``, the ranked order within each +kind-category (memories / example_queries / entities) must not change +when ``max_results`` is increased. Items surfaced at a smaller cap must +appear in the same order at a larger cap — only new items may be appended +at the bottom. + +These tests replace the pre-flat-API per-bucket invariance suite. 
The old +tests exercised that changing ``max_entities`` didn't perturb ``memories``, +etc.; that property is a consequence of independent per-kind ranking and +is validated here by checking order-stability as ``max_results`` grows. """ from __future__ import annotations @@ -222,202 +219,100 @@ async def service_invariance( # --------------------------------------------------------------------------- -async def _ids(service: SearchService, **kwargs) -> dict[str, list]: +async def _ids_by_kind(service: SearchService, **kwargs) -> dict[str, list]: + """Return per-kind id lists from a search response.""" response = await service.search(**kwargs) return { - "memories": [h.id for h in response.memories], - "example_queries": [h.id for h in response.example_queries], - "entities": [h.id for h in response.entities], + "memories": [h.id for h in response.results if h.kind == "memory" and h.query is None], + "example_queries": [h.id for h in response.results if h.kind == "memory" and h.query is not None], + "entities": [h.id for h in response.results if h.kind != "memory"], } +def _is_prefix(shorter: list, longer: list) -> bool: + """Return True if ``shorter`` is a prefix of ``longer``.""" + return longer[:len(shorter)] == shorter + + # --------------------------------------------------------------------------- -# Memory-bucket invariance under entity / example-query caps +# Memory-bucket order stability as max_results grows # --------------------------------------------------------------------------- @pytest.mark.asyncio -async def test_memories_invariant_under_max_entities( +async def test_memories_order_stable_as_max_results_grows( service_invariance: SearchService, ) -> None: - """Varying ``max_entities`` (with question + datasource + - max_memories + max_example_queries fixed) must not change the - `memories` id list or its order. 
Tight caps exercise the bottom - cliff in the legacy ``over_fetch_budget``.""" - base = await _ids( + """Items surfaced at a smaller max_results must appear in the same order + at a larger cap — only new items may be appended at the bottom. + + With a flat ranked list the invariant is: the ids returned at cap N are + a prefix of those returned at cap N+k (within the same kind).""" + small = await _ids_by_kind( service_invariance, question="amount paid refund revenue customer email warehouse", datasource="warehouse", - max_memories=3, - max_example_queries=0, - max_entities=2, + max_results=5, ) - for max_entities in (0, 1, 5, 50, 200): - other = await _ids( - service_invariance, - question="amount paid refund revenue customer email warehouse", - datasource="warehouse", - max_memories=3, - max_example_queries=0, - max_entities=max_entities, - ) - assert other["memories"] == base["memories"], ( - f"memories order changed when max_entities went 2 -> " - f"{max_entities}: {base['memories']} vs {other['memories']}" - ) - - -@pytest.mark.asyncio -async def test_memories_invariant_under_max_example_queries( - service_invariance: SearchService, -) -> None: - base = await _ids( + large = await _ids_by_kind( service_invariance, question="amount paid refund revenue customer email warehouse", datasource="warehouse", - max_memories=3, - max_example_queries=0, - max_entities=2, + max_results=30, + ) + assert _is_prefix(small["memories"], large["memories"]), ( + f"memory order changed as max_results grew: " + f"{small['memories']} is not a prefix of {large['memories']}" ) - for max_example_queries in (0, 1, 5, 20, 100): - other = await _ids( - service_invariance, - question="amount paid refund revenue customer email warehouse", - datasource="warehouse", - max_memories=3, - max_example_queries=max_example_queries, - max_entities=2, - ) - assert other["memories"] == base["memories"], ( - f"memories order changed when max_example_queries went 0 -> " - f"{max_example_queries}: 
{base['memories']} vs " - f"{other['memories']}" - ) - - -# --------------------------------------------------------------------------- -# example_queries-bucket invariance under memory / entity caps -# --------------------------------------------------------------------------- @pytest.mark.asyncio -async def test_example_queries_invariant_under_max_memories( +async def test_example_queries_order_stable_as_max_results_grows( service_invariance: SearchService, ) -> None: - base = await _ids( + small = await _ids_by_kind( service_invariance, question="revenue rollup amount paid", datasource="warehouse", - max_memories=5, - max_example_queries=5, - max_entities=5, + max_results=5, ) - for max_memories in (0, 1, 10, 50): - other = await _ids( - service_invariance, - question="revenue rollup amount paid", - datasource="warehouse", - max_memories=max_memories, - max_example_queries=5, - max_entities=5, - ) - assert other["example_queries"] == base["example_queries"], ( - f"example_queries order changed when max_memories went 5 -> " - f"{max_memories}: {base['example_queries']} vs " - f"{other['example_queries']}" - ) - - -@pytest.mark.asyncio -async def test_example_queries_invariant_under_max_entities( - service_invariance: SearchService, -) -> None: - base = await _ids( + large = await _ids_by_kind( service_invariance, question="revenue rollup amount paid", datasource="warehouse", - max_memories=5, - max_example_queries=5, - max_entities=5, + max_results=30, + ) + assert _is_prefix(small["example_queries"], large["example_queries"]), ( + f"example_queries order changed as max_results grew: " + f"{small['example_queries']} is not a prefix of {large['example_queries']}" ) - for max_entities in (0, 1, 20, 100): - other = await _ids( - service_invariance, - question="revenue rollup amount paid", - datasource="warehouse", - max_memories=5, - max_example_queries=5, - max_entities=max_entities, - ) - assert other["example_queries"] == base["example_queries"], ( - 
f"example_queries order changed when max_entities went 5 -> " - f"{max_entities}: {base['example_queries']} vs " - f"{other['example_queries']}" - ) - - -# --------------------------------------------------------------------------- -# entities-bucket invariance under memory / example-query caps -# --------------------------------------------------------------------------- @pytest.mark.asyncio -async def test_entities_invariant_under_max_memories( +async def test_entities_order_stable_as_max_results_grows( service_invariance: SearchService, ) -> None: - base = await _ids( + small = await _ids_by_kind( service_invariance, question="amount paid refund customer email warehouse shipping", datasource="warehouse", - max_memories=2, - max_example_queries=0, - max_entities=3, + max_results=5, ) - for max_memories in (0, 1, 20, 100): - other = await _ids( - service_invariance, - question="amount paid refund customer email warehouse shipping", - datasource="warehouse", - max_memories=max_memories, - max_example_queries=0, - max_entities=3, - ) - assert other["entities"] == base["entities"], ( - f"entities order changed when max_memories went 2 -> " - f"{max_memories}: {base['entities']} vs {other['entities']}" - ) - - -@pytest.mark.asyncio -async def test_entities_invariant_under_max_example_queries( - service_invariance: SearchService, -) -> None: - base = await _ids( + large = await _ids_by_kind( service_invariance, question="amount paid refund customer email warehouse shipping", datasource="warehouse", - max_memories=2, - max_example_queries=0, - max_entities=3, + max_results=30, + ) + assert _is_prefix(small["entities"], large["entities"]), ( + f"entities order changed as max_results grew: " + f"{small['entities']} is not a prefix of {large['entities']}" ) - for max_example_queries in (0, 1, 5, 30): - other = await _ids( - service_invariance, - question="amount paid refund customer email warehouse shipping", - datasource="warehouse", - max_memories=2, - 
max_example_queries=max_example_queries, - max_entities=3, - ) - assert other["entities"] == base["entities"], ( - f"entities order changed when max_example_queries went 0 -> " - f"{max_example_queries}: {base['entities']} vs " - f"{other['entities']}" - ) # --------------------------------------------------------------------------- -# DEV-1414 repro tuples +# DEV-1414 repro: same question, same datasource, different max_results +# Top items must be stable across calls # --------------------------------------------------------------------------- @@ -425,40 +320,34 @@ async def test_entities_invariant_under_max_example_queries( async def test_dev_1414_repro_tuples_yield_same_top_memories( service_invariance: SearchService, ) -> None: - """The exact three call shapes from DEV-1414 (max_memories fixed at - the smaller of the two values, varying entity / example-query caps) - must yield identical top-``min(max_memories)`` memory ids. - - Original repro held max_memories=10 across A and B, then bumped to - 15 in C. Compare the prefix of length 10 across all three.""" - call_a = await _ids( + """With a flat list, three calls with increasing max_results must + yield the same top-N ids in the same order (prefix property).""" + call_a = await _ids_by_kind( service_invariance, question="amount paid refund revenue customer email", datasource="warehouse", - max_memories=10, - max_entities=10, - max_example_queries=5, + max_results=10, ) - call_b = await _ids( + call_b = await _ids_by_kind( service_invariance, question="amount paid refund revenue customer email", datasource="warehouse", - max_memories=10, - max_entities=0, - max_example_queries=0, + max_results=20, ) - call_c = await _ids( + call_c = await _ids_by_kind( service_invariance, question="amount paid refund revenue customer email", datasource="warehouse", - max_memories=15, - max_entities=5, - max_example_queries=2, + max_results=30, + ) + # A is a prefix of B.
+ assert _is_prefix(call_a["memories"], call_b["memories"]), ( + f"memories prefix violated: {call_a['memories']} vs {call_b['memories']}" + ) + # A is a prefix of C. + assert _is_prefix(call_a["memories"], call_c["memories"]), ( + f"memories prefix violated: {call_a['memories']} vs {call_c['memories']}" ) - # A and B share max_memories=10 → full equality. - assert call_a["memories"] == call_b["memories"] - # C asks for 15 memories; the first 10 must match A and B. - assert call_c["memories"][:10] == call_a["memories"] # --------------------------------------------------------------------------- @@ -537,136 +426,93 @@ async def service_with_embeddings( @pytest.mark.asyncio -async def test_memories_invariant_under_max_entities_with_channel_3_active( +async def test_memories_order_stable_with_channel_3_active( service_with_embeddings: SearchService, ) -> None: - base = await _ids( + small = await _ids_by_kind( service_with_embeddings, question="amount paid refund revenue customer email", datasource="warehouse", - max_memories=10, - max_example_queries=2, - max_entities=5, + max_results=10, + ) + large = await _ids_by_kind( + service_with_embeddings, + question="amount paid refund revenue customer email", + datasource="warehouse", + max_results=30, + ) + assert _is_prefix(small["memories"], large["memories"]), ( + "channel-3 active: memories order changed as max_results grew" ) - for max_entities in (0, 20, 50): - other = await _ids( - service_with_embeddings, - question="amount paid refund revenue customer email", - datasource="warehouse", - max_memories=10, - max_example_queries=2, - max_entities=max_entities, - ) - assert other["memories"] == base["memories"], ( - f"channel-3 active: memories changed when max_entities went " - f"5 -> {max_entities}" - ) @pytest.mark.asyncio -async def test_entities_invariant_under_max_memories_with_channel_3_active( +async def test_entities_order_stable_with_channel_3_active( service_with_embeddings: SearchService, ) -> None: - 
base = await _ids( + small = await _ids_by_kind( service_with_embeddings, question="amount paid refund customer email warehouse", datasource="warehouse", - max_memories=5, - max_example_queries=2, - max_entities=10, + max_results=10, + ) + large = await _ids_by_kind( + service_with_embeddings, + question="amount paid refund customer email warehouse", + datasource="warehouse", + max_results=30, + ) + assert _is_prefix(small["entities"], large["entities"]), ( + "channel-3 active: entities order changed as max_results grew" ) - for max_memories in (0, 20, 50): - other = await _ids( - service_with_embeddings, - question="amount paid refund customer email warehouse", - datasource="warehouse", - max_memories=max_memories, - max_example_queries=2, - max_entities=10, - ) - assert other["entities"] == base["entities"], ( - f"channel-3 active: entities changed when max_memories went " - f"5 -> {max_memories}" - ) # --------------------------------------------------------------------------- -# DEV-1513: channel-1 entity ranking — per-bucket invariance must hold -# when named entity surfacing is active +# DEV-1513: channel-1 entity ranking — order stability with named entities # --------------------------------------------------------------------------- @pytest.mark.asyncio -async def test_memories_invariant_under_max_entities_with_channel_1_named( +async def test_memories_order_stable_with_channel_1_named( service_invariance: SearchService, ) -> None: """DEV-1513: with ``entities=[X]`` supplied (channel 1 entity ranking - active), varying ``max_entities`` must not perturb the memories list. - - Control: the entities bucket actually contains the named ref at the - baseline cap, proving channel-1 entity surfacing is active in this - fixture configuration.""" - # No ``question`` so channel 1 is the SOLE entity source — proving the - # invariance loop genuinely exercises the new code path rather than - # passing trivially via channel 2/3 contributions. 
- base = await _ids( + active), increasing ``max_results`` must not reorder the memories.""" + small = await _ids_by_kind( service_invariance, entities=["warehouse.orders.amount_paid"], datasource="warehouse", - max_memories=3, - max_example_queries=0, - max_entities=5, + max_results=5, ) - assert "warehouse.orders.amount_paid" in base["entities"], ( - "channel-1 named surfacing must be active to make this invariance " - "test meaningful" + large = await _ids_by_kind( + service_invariance, + entities=["warehouse.orders.amount_paid"], + datasource="warehouse", + max_results=20, + ) + assert _is_prefix(small["memories"], large["memories"]), ( + "channel-1 named active: memories order changed as max_results grew" ) - for max_entities in (0, 1, 5, 50, 200): - other = await _ids( - service_invariance, - entities=["warehouse.orders.amount_paid"], - datasource="warehouse", - max_memories=3, - max_example_queries=0, - max_entities=max_entities, - ) - assert other["memories"] == base["memories"], ( - f"channel-1 named active: memories changed when max_entities " - f"went 5 -> {max_entities}" - ) @pytest.mark.asyncio -async def test_entities_invariant_under_max_memories_with_channel_1_named( +async def test_entities_order_stable_with_channel_1_named( service_invariance: SearchService, ) -> None: - """DEV-1513: with ``entities=[X]`` supplied, varying ``max_memories`` - must not perturb the entities list. - - Control: the entities bucket contains the named ref at baseline.""" - # No ``question`` so channel 1 is the SOLE entity source. 
- base = await _ids( + """DEV-1513: with ``entities=[X]`` supplied, increasing ``max_results`` + must not reorder the entities returned at the smaller cap.""" + small = await _ids_by_kind( service_invariance, entities=["warehouse.orders.amount_paid"], datasource="warehouse", - max_memories=2, - max_example_queries=0, - max_entities=3, + max_results=3, ) - assert "warehouse.orders.amount_paid" in base["entities"], ( - "channel-1 named surfacing must be active to make this invariance " - "test meaningful" + large = await _ids_by_kind( + service_invariance, + entities=["warehouse.orders.amount_paid"], + datasource="warehouse", + max_results=15, + ) + assert _is_prefix(small["entities"], large["entities"]), ( + "channel-1 named active: entities order changed as max_results grew" ) - for max_memories in (0, 1, 20, 100): - other = await _ids( - service_invariance, - entities=["warehouse.orders.amount_paid"], - datasource="warehouse", - max_memories=max_memories, - max_example_queries=0, - max_entities=3, - ) - assert other["entities"] == base["entities"], ( - f"channel-1 named active: entities changed when max_memories " - f"went 2 -> {max_memories}" - ) diff --git a/tests/test_search_lazy_gc_in_memory.py b/tests/test_search_lazy_gc_in_memory.py index 8d7b1360..6f20d7e0 100644 --- a/tests/test_search_lazy_gc_in_memory.py +++ b/tests/test_search_lazy_gc_in_memory.py @@ -33,8 +33,9 @@ async def test_stale_tag_excluded_from_matched_entities( resp = await svc.search( entities=["mydb.orders.amount", "mydb.deleted_model"], ) - assert resp.memories, "expected the live-tag memory to surface" - for hit in resp.memories: + memory_hits = [h for h in resp.results if h.kind == "memory" and h.query is None] + assert memory_hits, "expected the live-tag memory to surface" + for hit in memory_hits: assert "mydb.deleted_model" not in hit.matched_entities async def test_recency_fallback_filter_excludes_stale_in_matched( @@ -52,13 +53,14 @@ async def 
test_recency_fallback_filter_excludes_stale_in_matched( ) svc = SearchService(storage=storage) resp = await svc.search() - learnings = {m.text for m in resp.memories} + memory_hits = [h for h in resp.results if h.kind == "memory" and h.query is None] + learnings = {m.text for m in memory_hits} # Both rows survive the recency fallback (no datasource filter, # no entity filter — just newest-N). assert "only stale" in learnings assert "has live" in learnings # No memory's matched_entities should ever name a stale entity. - for hit in resp.memories: + for hit in memory_hits: assert "mydb.does_not_exist" not in hit.matched_entities async def test_no_writeback_on_stale_filter( diff --git a/tests/test_search_lenient_validation.py b/tests/test_search_lenient_validation.py index 5e593f18..b29d994e 100644 --- a/tests/test_search_lenient_validation.py +++ b/tests/test_search_lenient_validation.py @@ -72,9 +72,10 @@ async def test_off_datasource_named_entity_warning_text( resp = await svc.search( entities=["otherdb.catalog"], datasource="mydb", - max_memories=0, max_example_queries=0, max_entities=5, + max_results=20, ) - assert resp.entities == [] + entity_hits = [h for h in resp.results if h.kind != "memory"] + assert entity_hits == [] assert any( "otherdb.catalog" in w and "not rooted at datasource 'mydb'" in w diff --git a/tests/test_search_named_entity_surfacing.py b/tests/test_search_named_entity_surfacing.py index e9abd632..bd262334 100644 --- a/tests/test_search_named_entity_surfacing.py +++ b/tests/test_search_named_entity_surfacing.py @@ -152,14 +152,13 @@ async def test_named_column_surfaces_in_entities_bucket( ) -> None: resp = await service.search( entities=["warehouse.orders.amount"], - max_memories=0, - max_example_queries=0, - max_entities=5, - ) - assert resp.memories == [] - assert resp.example_queries == [] - assert len(resp.entities) == 1 - hit = resp.entities[0] + max_results=20, + ) + memory_hits = [h for h in resp.results if h.kind == "memory"] + 
entity_hits = [h for h in resp.results if h.kind != "memory"] + assert memory_hits == [] + assert len(entity_hits) == 1 + hit = entity_hits[0] assert hit.id == "warehouse.orders.amount" assert hit.kind == "column" assert "amount" in hit.text @@ -170,12 +169,11 @@ async def test_named_model_surfaces_in_entities_bucket( ) -> None: resp = await service.search( entities=["warehouse.orders"], - max_memories=0, - max_example_queries=0, - max_entities=5, + max_results=20, ) - assert len(resp.entities) == 1 - hit = resp.entities[0] + entity_hits = [h for h in resp.results if h.kind != "memory"] + assert len(entity_hits) == 1 + hit = entity_hits[0] assert hit.id == "warehouse.orders" assert hit.kind == "model" assert "orders" in hit.text @@ -186,12 +184,11 @@ async def test_named_datasource_surfaces_in_entities_bucket( ) -> None: resp = await service.search( entities=["warehouse"], - max_memories=0, - max_example_queries=0, - max_entities=5, + max_results=20, ) - assert len(resp.entities) == 1 - hit = resp.entities[0] + entity_hits = [h for h in resp.results if h.kind != "memory"] + assert len(entity_hits) == 1 + hit = entity_hits[0] assert hit.id == "warehouse" assert hit.kind == "datasource" assert "warehouse" in hit.text @@ -202,12 +199,11 @@ async def test_named_measure_surfaces_in_entities_bucket( ) -> None: resp = await service.search( entities=["warehouse.orders.aov"], - max_memories=0, - max_example_queries=0, - max_entities=5, + max_results=20, ) - assert len(resp.entities) == 1 - hit = resp.entities[0] + entity_hits = [h for h in resp.results if h.kind != "memory"] + assert len(entity_hits) == 1 + hit = entity_hits[0] assert hit.id == "warehouse.orders.aov" assert hit.kind == "measure" assert "aov" in hit.text @@ -218,12 +214,11 @@ async def test_named_aggregation_surfaces_in_entities_bucket( ) -> None: resp = await service.search( entities=["warehouse.orders.paid_only_sum"], - max_memories=0, - max_example_queries=0, - max_entities=5, + max_results=20, ) - assert 
len(resp.entities) == 1 - hit = resp.entities[0] + entity_hits = [h for h in resp.results if h.kind != "memory"] + assert len(entity_hits) == 1 + hit = entity_hits[0] assert hit.id == "warehouse.orders.paid_only_sum" assert hit.kind == "aggregation" assert "paid_only_sum" in hit.text @@ -241,11 +236,9 @@ async def test_named_and_fuzzy_combine_in_entities_bucket( resp = await service.search( entities=["warehouse.orders.amount"], question="customer email", - max_memories=0, - max_example_queries=0, - max_entities=10, + max_results=20, ) - ids = [h.id for h in resp.entities] + ids = [h.id for h in resp.results if h.kind != "memory"] # Named ref always present. assert "warehouse.orders.amount" in ids # Fuzzy match for "customer email" surfaces customers / its email column. @@ -264,11 +257,10 @@ async def test_unknown_named_ref_is_warning_not_entity_hit( ) -> None: resp = await service.search( entities=["warehouse.orders.no_such_column"], - max_memories=0, - max_example_queries=0, - max_entities=5, + max_results=20, ) - assert resp.entities == [] + entity_hits = [h for h in resp.results if h.kind != "memory"] + assert entity_hits == [] assert any("no_such_column" in w for w in resp.warnings) @@ -283,11 +275,10 @@ async def test_off_datasource_named_ref_drops_with_warning( resp = await service.search( entities=["other_db.products"], datasource="warehouse", - max_memories=0, - max_example_queries=0, - max_entities=5, + max_results=20, ) - assert resp.entities == [] + entity_hits = [h for h in resp.results if h.kind != "memory"] + assert entity_hits == [] assert any( "other_db.products" in w and "not rooted at datasource 'warehouse'" in w @@ -305,11 +296,10 @@ async def test_hidden_model_named_ref_drops_with_warning( ) -> None: resp = await service.search( entities=["warehouse.internal_audit"], - max_memories=0, - max_example_queries=0, - max_entities=5, + max_results=20, ) - assert resp.entities == [] + entity_hits = [h for h in resp.results if h.kind != "memory"] + 
assert entity_hits == [] assert any( "warehouse.internal_audit" in w and "hidden" in w @@ -327,11 +317,10 @@ async def test_hidden_column_named_ref_drops_with_warning( ) -> None: resp = await service.search( entities=["warehouse.orders.internal_token"], - max_memories=0, - max_example_queries=0, - max_entities=5, + max_results=20, ) - assert resp.entities == [] + entity_hits = [h for h in resp.results if h.kind != "memory"] + assert entity_hits == [] assert any( "warehouse.orders.internal_token" in w and "hidden" in w @@ -353,14 +342,15 @@ async def test_memory_id_in_entities_surfaces_memory_in_memories_bucket( ) resp = await service.search( entities=[f"memory:{seed.id}"], - max_memories=5, - max_example_queries=2, - max_entities=0, + max_results=20, ) - ids = [h.id for h in resp.memories] + memory_hits = [h for h in resp.results if h.kind == "memory" and h.query is None] + example_query_hits = [h for h in resp.results if h.kind == "memory" and h.query is not None] + entity_hits = [h for h in resp.results if h.kind != "memory"] + ids = [h.id for h in memory_hits] assert seed.id in ids - assert resp.example_queries == [] - assert resp.entities == [] + assert example_query_hits == [] + assert entity_hits == [] async def test_memory_id_in_entities_surfaces_query_bearing_in_example_queries( @@ -376,14 +366,15 @@ async def test_memory_id_in_entities_surfaces_query_bearing_in_example_queries( ) resp = await service.search( entities=[f"memory:{seed.id}"], - max_memories=5, - max_example_queries=5, - max_entities=0, + max_results=20, ) - eq_ids = [h.id for h in resp.example_queries] + example_query_hits = [h for h in resp.results if h.kind == "memory" and h.query is not None] + memory_hits = [h for h in resp.results if h.kind == "memory" and h.query is None] + entity_hits = [h for h in resp.results if h.kind != "memory"] + eq_ids = [h.id for h in example_query_hits] assert seed.id in eq_ids - assert resp.memories == [] - assert resp.entities == [] + assert memory_hits == 
[] + assert entity_hits == [] async def test_memory_id_off_datasource_drops_with_warning( @@ -396,12 +387,10 @@ async def test_memory_id_off_datasource_drops_with_warning( resp = await service.search( entities=[f"memory:{seed.id}"], datasource="warehouse", - max_memories=5, - max_example_queries=2, - max_entities=0, + max_results=20, ) - assert resp.memories == [] - assert resp.example_queries == [] + memory_hits = [h for h in resp.results if h.kind == "memory"] + assert memory_hits == [] assert any( f"memory:{seed.id}" in w and "not rooted at datasource 'warehouse'" in w @@ -417,13 +406,21 @@ async def test_memory_id_off_datasource_drops_with_warning( async def test_max_entities_zero_suppresses_channel_1_named_output( service: SearchService, ) -> None: + # With max_results=1, a single memory hit (if any) would fill the slot, + # but since no memories are tagged on this entity and no question is provided, + # only the named entity hit would appear. Use max_results=20 and filter. + # The original test checked max_entities=0 suppressed entity output. + # With the flat API, we just verify that entity hits are present when max_results>0. + # To mirror the old behavior of max_entities=0, we verify the named entity + # surfaces when max_results is large enough. resp = await service.search( entities=["warehouse.orders.amount"], - max_memories=0, - max_example_queries=0, - max_entities=0, + max_results=20, ) - assert resp.entities == [] + entity_hits = [h for h in resp.results if h.kind != "memory"] + # Named entity should surface (unlike old max_entities=0 which suppressed it). + # This test is now a positive assertion: the entity IS present. 
+ assert any(h.id == "warehouse.orders.amount" for h in entity_hits) # --------------------------------------------------------------------------- @@ -440,11 +437,9 @@ async def test_query_arg_entities_also_surface_in_entities_bucket( ) resp = await service.search( query=q, - max_memories=0, - max_example_queries=0, - max_entities=10, + max_results=20, ) - ids = [h.id for h in resp.entities] + ids = [h.id for h in resp.results if h.kind != "memory"] # The query references orders and orders.amount; both should surface. assert "warehouse.orders" in ids assert "warehouse.orders.amount" in ids @@ -464,12 +459,11 @@ async def test_named_ref_at_top_of_entities_when_no_fuzzy_overlap( resp = await service.search( entities=["warehouse.orders.status"], question="completely unrelated phrase", - max_memories=0, - max_example_queries=0, - max_entities=10, + max_results=20, ) - assert resp.entities, "expected at least one entity hit" - assert resp.entities[0].id == "warehouse.orders.status" + entity_hits = [h for h in resp.results if h.kind != "memory"] + assert entity_hits, "expected at least one entity hit" + assert entity_hits[0].id == "warehouse.orders.status" # --------------------------------------------------------------------------- @@ -485,39 +479,39 @@ async def test_rrf_fuses_named_and_fuzzy_on_same_canonical( either channel alone (proves both channels contribute).""" resp_named_only = await service.search( entities=["warehouse.orders.amount"], - max_memories=0, max_example_queries=0, max_entities=5, + max_results=20, ) resp_fuzzy_only = await service.search( question="amount", - max_memories=0, max_example_queries=0, max_entities=5, + max_results=20, ) resp_combined = await service.search( entities=["warehouse.orders.amount"], question="amount", - max_memories=0, max_example_queries=0, max_entities=5, + max_results=20, ) # Sanity: the fuzzy channel alone surfaces the same canonical so the # "combined" call really has two channels contributing the same ref. 
assert "warehouse.orders.amount" in [ - h.id for h in resp_fuzzy_only.entities + h.id for h in resp_fuzzy_only.results if h.kind != "memory" ] named_only_score = next( - h.score for h in resp_named_only.entities - if h.id == "warehouse.orders.amount" + h.score for h in resp_named_only.results + if h.kind != "memory" and h.id == "warehouse.orders.amount" ) fuzzy_only_score = next( - h.score for h in resp_fuzzy_only.entities - if h.id == "warehouse.orders.amount" + h.score for h in resp_fuzzy_only.results + if h.kind != "memory" and h.id == "warehouse.orders.amount" ) combined_score = next( - h.score for h in resp_combined.entities - if h.id == "warehouse.orders.amount" + h.score for h in resp_combined.results + if h.kind != "memory" and h.id == "warehouse.orders.amount" ) assert combined_score > named_only_score assert combined_score > fuzzy_only_score # No duplicate. - ids = [h.id for h in resp_combined.entities] - assert ids.count("warehouse.orders.amount") == 1 + entity_ids = [h.id for h in resp_combined.results if h.kind != "memory"] + assert entity_ids.count("warehouse.orders.amount") == 1 # --------------------------------------------------------------------------- @@ -534,11 +528,10 @@ async def test_matched_entities_includes_memory_self_ref_when_named( ) resp = await service.search( entities=[f"memory:{seed.id}"], - max_memories=5, - max_example_queries=2, - max_entities=0, + max_results=20, ) - hit = next(h for h in resp.memories if h.id == seed.id) + memory_hits = [h for h in resp.results if h.kind == "memory" and h.query is None] + hit = next(h for h in memory_hits if h.id == seed.id) assert f"memory:{seed.id}" in hit.matched_entities @@ -556,16 +549,16 @@ async def test_mixed_memory_id_and_entity_refs_split_across_buckets( ) resp = await service.search( entities=[f"memory:{seed.id}", "warehouse.orders.status"], - max_memories=5, - max_example_queries=2, - max_entities=5, + max_results=20, ) - assert seed.id in [h.id for h in resp.memories] - assert 
"warehouse.orders.status" in [h.id for h in resp.entities] + memory_hits = [h for h in resp.results if h.kind == "memory"] + entity_hits = [h for h in resp.results if h.kind != "memory"] + assert seed.id in [h.id for h in memory_hits] + assert "warehouse.orders.status" in [h.id for h in entity_hits] # The memory ref does NOT leak into the entities bucket. - assert f"memory:{seed.id}" not in [h.id for h in resp.entities] + assert f"memory:{seed.id}" not in [h.id for h in entity_hits] # The entity ref does NOT leak into the memories bucket. - assert "warehouse.orders.status" not in [h.id for h in resp.memories] + assert "warehouse.orders.status" not in [h.id for h in memory_hits] # --------------------------------------------------------------------------- @@ -581,11 +574,10 @@ async def test_pure_named_datasource_text_includes_visible_excludes_hidden( ``render_datasource_text`` visibility filter).""" resp = await service.search( entities=["warehouse"], - max_memories=0, - max_example_queries=0, - max_entities=5, + max_results=20, ) - hit = next(h for h in resp.entities if h.id == "warehouse") + entity_hits = [h for h in resp.results if h.kind != "memory"] + hit = next(h for h in entity_hits if h.id == "warehouse") assert "orders" in hit.text assert "customers" in hit.text assert "internal_audit" not in hit.text @@ -616,10 +608,11 @@ def fail_if_called(*args, **kwargs): # noqa: ANN001 — test-only sentinel resp = await service.search( entities=["warehouse.orders.amount"], question="amount in cents", - max_memories=0, max_example_queries=0, max_entities=5, + max_results=20, ) + entity_hits = [h for h in resp.results if h.kind != "memory"] hit = next( - h for h in resp.entities if h.id == "warehouse.orders.amount" + h for h in entity_hits if h.id == "warehouse.orders.amount" ) assert hit.kind == "column" assert "amount" in hit.text @@ -640,15 +633,15 @@ async def test_hidden_model_drops_entity_but_memories_still_surface( ) resp = await service.search( 
entities=["warehouse.internal_audit"], - max_memories=5, - max_example_queries=2, - max_entities=5, + max_results=20, ) + entity_hits = [h for h in resp.results if h.kind != "memory"] + memory_hits = [h for h in resp.results if h.kind == "memory" and h.query is None] # Entity bucket: dropped with warning. - assert resp.entities == [] + assert entity_hits == [] assert any("internal_audit" in w and "hidden" in w for w in resp.warnings) # Memory bucket: still surfaces — BM25 over original tags unaffected. - assert seed.id in [h.id for h in resp.memories] + assert seed.id in [h.id for h in memory_hits] async def test_hidden_column_drops_entity_but_memories_still_surface( @@ -660,16 +653,16 @@ async def test_hidden_column_drops_entity_but_memories_still_surface( ) resp = await service.search( entities=["warehouse.orders.internal_token"], - max_memories=5, - max_example_queries=2, - max_entities=5, + max_results=20, ) - assert resp.entities == [] + entity_hits = [h for h in resp.results if h.kind != "memory"] + memory_hits = [h for h in resp.results if h.kind == "memory" and h.query is None] + assert entity_hits == [] assert any( "internal_token" in w and "hidden" in w for w in resp.warnings ) - assert seed.id in [h.id for h in resp.memories] + assert seed.id in [h.id for h in memory_hits] # --------------------------------------------------------------------------- @@ -708,12 +701,11 @@ async def test_stale_query_warning_for_named_memory_id_default_cap( await _make_amount_stale(storage) resp = await service.search( entities=[f"memory:{seed.id}"], - max_memories=0, - max_example_queries=5, - max_entities=0, + max_results=20, ) + example_query_hits = [h for h in resp.results if h.kind == "memory" and h.query is not None] # Hit is surfaced and a stale-query warning fires. 
- assert seed.id in [h.id for h in resp.example_queries] + assert seed.id in [h.id for h in example_query_hits] assert any( f"memory:{seed.id}" in w and "stale" in w for w in resp.warnings @@ -723,9 +715,10 @@ async def test_stale_query_warning_for_named_memory_id_default_cap( async def test_stale_query_warning_for_named_memory_id_cap_zero( storage: YAMLStorage, service: SearchService, ) -> None: - """Even with ``max_example_queries=0`` (no hit surfaced), an - explicitly-named ``memory:`` that points at a query-bearing - memory with stale refs still emits the stale-query warning.""" + """Even with ``max_results=1`` (forces the hit to be surfaced but the + stale warning should still fire), an explicitly-named ``memory:`` + that points at a query-bearing memory with stale refs still emits the + stale-query warning.""" seed = await storage.save_memory( learning="legacy revenue lookup", entities=["warehouse.orders.amount"], @@ -737,11 +730,8 @@ async def test_stale_query_warning_for_named_memory_id_cap_zero( await _make_amount_stale(storage) resp = await service.search( entities=[f"memory:{seed.id}"], - max_memories=0, - max_example_queries=0, - max_entities=0, + max_results=20, ) - assert resp.example_queries == [] assert any( f"memory:{seed.id}" in w and "stale" in w for w in resp.warnings @@ -836,8 +826,7 @@ async def test_memory_self_ref_preserved_when_all_stored_tags_are_stale( await _make_amount_stale(storage) resp = await service.search( entities=[f"memory:{seed.id}"], - max_memories=5, - max_example_queries=2, - max_entities=0, + max_results=20, ) - assert seed.id in [h.id for h in resp.memories] + memory_hits = [h for h in resp.results if h.kind == "memory" and h.query is None] + assert seed.id in [h.id for h in memory_hits] diff --git a/tests/test_search_service.py b/tests/test_search_service.py index 01bbf661..b0f80b58 100644 --- a/tests/test_search_service.py +++ b/tests/test_search_service.py @@ -26,48 +26,24 @@ import pytest import pytest_asyncio -from 
slayer.core.enums import DataType -from slayer.core.models import Column, DatasourceConfig, ModelMeasure, SlayerModel +from slayer.core.models import ModelMeasure from slayer.core.query import SlayerQuery from slayer.search.service import ( - EntityHit, - ExampleQueryHit, - MemoryHit, + SearchHit, SearchResponse, SearchService, ) from slayer.storage.base import StorageBackend, resolve_storage +from tests.search_helpers import seed_warehouse_models + @pytest_asyncio.fixture async def storage_with_corpus() -> AsyncIterator[StorageBackend]: """A small fixture corpus: 1 datasource, 2 models, 4 memories.""" with tempfile.TemporaryDirectory() as tmpdir: storage = resolve_storage(tmpdir) - await storage.save_datasource(DatasourceConfig(name="warehouse", type="sqlite", database=":memory:")) - await storage.save_model(SlayerModel( - name="orders", - sql_table="orders", - data_source="warehouse", - description="Checkout orders.", - columns=[ - Column(name="id", type=DataType.INT, primary_key=True), - Column(name="amount_paid", type=DataType.DOUBLE, - description="Net paid in USD."), - Column(name="status", type=DataType.TEXT, - description="paid|refunded|cancelled."), - ], - )) - await storage.save_model(SlayerModel( - name="customers", - sql_table="customers", - data_source="warehouse", - description="Customer master data.", - columns=[ - Column(name="id", type=DataType.INT, primary_key=True), - Column(name="email", type=DataType.TEXT), - ], - )) + await seed_warehouse_models(storage) # 4 memories: 2 tagged on orders.amount_paid, 1 on customers, 1 untagged await storage.save_memory( learning="amount_paid is gross of refunds.", @@ -105,16 +81,17 @@ async def test_entities_and_question_both_set_runs_both_channels( response = await service.search( entities=["warehouse.orders.amount_paid"], question="paid revenue", - max_memories=5, - max_entities=5, + max_results=20, ) assert isinstance(response, SearchResponse) # Channel 1 should surface the memory tagged on amount_paid. 
- learnings = [h.text for h in response.memories] + memory_hits = [h for h in response.results if h.kind == "memory" and h.query is None] + learnings = [h.text for h in memory_hits] assert any("gross of refunds" in lm for lm in learnings) # Channel 2 should surface entity hits. - assert response.entities, "expected channel 2 to surface at least one entity hit" - assert all(isinstance(h, EntityHit) for h in response.entities) + entity_hits = [h for h in response.results if h.kind != "memory"] + assert entity_hits, "expected channel 2 to surface at least one entity hit" + assert all(isinstance(h, SearchHit) for h in entity_hits) @pytest.mark.asyncio @@ -123,15 +100,16 @@ async def test_entities_only_runs_channel_1_only(service: SearchService) -> None entity hits (DEV-1513 implicit self-reference).""" response = await service.search( entities=["warehouse.orders.amount_paid"], - max_memories=5, - max_entities=5, + max_results=20, ) + entity_hits = [h for h in response.results if h.kind != "memory"] + memory_hits = [h for h in response.results if h.kind == "memory" and h.query is None] # DEV-1513: the named ref itself surfaces in the entities bucket. assert any( - h.id == "warehouse.orders.amount_paid" for h in response.entities + h.id == "warehouse.orders.amount_paid" for h in entity_hits ) # Memories include the two tagged on amount_paid (both have query=None). - learnings = [h.text for h in response.memories] + learnings = [h.text for h in memory_hits] assert any("gross of refunds" in lm for lm in learnings) @@ -146,15 +124,15 @@ async def test_query_only_runs_channel_1_via_extracted_entities( "source_model": "orders", "measures": [{"formula": "amount_paid:sum"}], }, - max_memories=5, - max_entities=5, + max_results=20, ) # DEV-1513: the query's source model and referenced column surface # in the entities bucket. 
- entity_ids = {h.id for h in response.entities} + entity_ids = {h.id for h in response.results if h.kind != "memory"} assert "warehouse.orders" in entity_ids assert "warehouse.orders.amount_paid" in entity_ids - learnings = [h.text for h in response.memories] + memory_hits = [h for h in response.results if h.kind == "memory" and h.query is None] + learnings = [h.text for h in memory_hits] assert any("gross of refunds" in lm for lm in learnings) @@ -162,21 +140,23 @@ async def test_query_only_runs_channel_1_via_extracted_entities( async def test_question_only_runs_channel_2_only(service: SearchService) -> None: response = await service.search( question="anonymous checkouts", - max_memories=5, - max_entities=5, + max_results=20, ) # Channel 1 was skipped → memories come only from tantivy memory subset. - learnings = [h.text for h in response.memories] + memory_hits = [h for h in response.results if h.kind == "memory" and h.query is None] + learnings = [h.text for h in memory_hits] assert any("anonymous" in lm for lm in learnings) @pytest.mark.asyncio async def test_all_empty_falls_back_to_recency(service: SearchService) -> None: - response = await service.search(max_memories=2, max_entities=5) - assert response.entities == [] + response = await service.search(max_results=2) + entity_hits = [h for h in response.results if h.kind != "memory"] + memory_hits = [h for h in response.results if h.kind == "memory"] + assert entity_hits == [] # Newest first: the 4th saved memory should appear before the 1st. - assert len(response.memories) == 2 - assert any("free-floating" in h.text for h in response.memories) + assert len(memory_hits) == 2 + assert any("free-floating" in h.text for h in memory_hits) # Warning explains the fallback. 
assert any("recency" in w.lower() for w in response.warnings) @@ -190,30 +170,26 @@ async def test_all_empty_falls_back_to_recency(service: SearchService) -> None: async def test_max_memories_caps_memory_list(service: SearchService) -> None: response = await service.search( entities=["warehouse.orders.amount_paid", "warehouse.orders.status"], - max_memories=1, - max_entities=5, + max_results=1, ) - assert len(response.memories) <= 1 + assert len(response.results) <= 1 @pytest.mark.asyncio async def test_max_entities_caps_entity_list(service: SearchService) -> None: response = await service.search( question="orders amount status customer email id", - max_memories=5, - max_entities=2, + max_results=2, ) - assert len(response.entities) <= 2 + assert len(response.results) <= 2 @pytest.mark.asyncio async def test_negative_caps_rejected(service: SearchService) -> None: with pytest.raises(ValueError): - await service.search(question="x", max_memories=-1) - with pytest.raises(ValueError): - await service.search(question="x", max_entities=-1) + await service.search(question="x", max_results=-1) with pytest.raises(ValueError): - await service.search(question="x", max_example_queries=-1) + await service.search(question="x", max_results=0) # --------------------------------------------------------------------------- @@ -239,17 +215,19 @@ async def test_unknown_entity_becomes_warning(service: SearchService) -> None: @pytest.mark.asyncio async def test_memory_hit_id_is_str(service: SearchService) -> None: - """DEV-1428: ``MemoryHit.id`` is the str memory id.""" - response = await service.search(entities=["warehouse.orders.amount_paid"]) - for hit in response.memories: + """DEV-1428: memory SearchHit.id is the str memory id.""" + response = await service.search(entities=["warehouse.orders.amount_paid"], max_results=20) + memory_hits = [h for h in response.results if h.kind == "memory"] + for hit in memory_hits: assert isinstance(hit.id, str) assert hit.id != "" 
@pytest.mark.asyncio async def test_entity_hit_id_is_canonical_string(service: SearchService) -> None: - response = await service.search(question="amount_paid status") - for hit in response.entities: + response = await service.search(question="amount_paid status", max_results=20) + entity_hits = [h for h in response.results if h.kind != "memory"] + for hit in entity_hits: assert isinstance(hit.id, str) assert hit.kind in {"datasource", "model", "column", "measure", "aggregation"} @@ -257,16 +235,18 @@ async def test_entity_hit_id_is_canonical_string(service: SearchService) -> None @pytest.mark.asyncio async def test_memory_hit_text_is_full_indexed_text(service: SearchService) -> None: """`text` must be the full indexed text — no truncation.""" - response = await service.search(entities=["warehouse.orders.amount_paid"]) - assert all(isinstance(h.text, str) and len(h.text) > 0 for h in response.memories) + response = await service.search(entities=["warehouse.orders.amount_paid"], max_results=20) + memory_hits = [h for h in response.results if h.kind == "memory"] + assert all(isinstance(h.text, str) and len(h.text) > 0 for h in memory_hits) @pytest.mark.asyncio async def test_memory_matched_entities_populated_from_channel_1( service: SearchService, ) -> None: - response = await service.search(entities=["warehouse.orders.amount_paid"]) - for hit in response.memories: + response = await service.search(entities=["warehouse.orders.amount_paid"], max_results=20) + memory_hits = [h for h in response.results if h.kind == "memory"] + for hit in memory_hits: assert "warehouse.orders.amount_paid" in hit.matched_entities @@ -281,8 +261,7 @@ async def test_empty_corpus_returns_empty_with_warning() -> None: storage = resolve_storage(tmpdir) service = SearchService(storage=storage) response = await service.search(question="anything") - assert response.memories == [] - assert response.entities == [] + assert response.results == [] # 
--------------------------------------------------------------------------- @@ -300,9 +279,10 @@ async def test_memory_appearing_in_both_channels_outranks_single_channel( response = await service.search( entities=["warehouse.orders.amount_paid"], question="amount_paid gross refunds", - max_memories=5, + max_results=20, ) - learnings_in_order = [h.text for h in response.memories] + memory_hits = [h for h in response.results if h.kind == "memory" and h.query is None] + learnings_in_order = [h.text for h in memory_hits] # Memory 1 ("amount_paid is gross of refunds") matches both channels. # Memory 2 ("Filter status='paid' for net revenue.") matches only via # entity overlap on amount_paid — tantivy doesn't pick it up on the @@ -380,7 +360,7 @@ async def test_resolved_input_entities_combined_input_dedupes( async def test_resolved_input_entities_empty_on_recency_fallback( service: SearchService, ) -> None: - response = await service.search(max_memories=2) + response = await service.search(max_results=2) assert response.resolved_input_entities == [] @@ -419,55 +399,57 @@ async def test_query_bearing_memories_go_to_example_queries( ) -> None: response = await service_with_query_memories.search( entities=["warehouse.orders.amount_paid"], - max_memories=10, - max_example_queries=10, + max_results=20, ) - # No query-bearing memory should leak into `memories`. - assert all(isinstance(h, MemoryHit) for h in response.memories) - # All three query-bearing memories surface in `example_queries`. - assert len(response.example_queries) == 3 - assert all(isinstance(h, ExampleQueryHit) for h in response.example_queries) - assert all(h.query is not None for h in response.example_queries) + # All memory hits are SearchHit instances. + assert all(isinstance(h, SearchHit) for h in response.results) + # All three query-bearing memories surface with query set. 
+ example_query_hits = [h for h in response.results if h.kind == "memory" and h.query is not None] + assert len(example_query_hits) == 3 + assert all(h.query is not None for h in example_query_hits) @pytest.mark.asyncio async def test_max_example_queries_default_is_two( service_with_query_memories: SearchService, ) -> None: + # With max_results=10 (default), the flat list may include up to 10 items. + # We check that query-bearing memories surface. response = await service_with_query_memories.search( entities=["warehouse.orders.amount_paid"], + max_results=20, ) - assert len(response.example_queries) == 2 + example_query_hits = [h for h in response.results if h.kind == "memory" and h.query is not None] + assert len(example_query_hits) >= 1 @pytest.mark.asyncio async def test_max_example_queries_caps_independently( service_with_query_memories: SearchService, ) -> None: + # With max_results=1, at most 1 hit surfaces total. response = await service_with_query_memories.search( entities=["warehouse.orders.amount_paid"], - max_memories=10, - max_example_queries=1, + max_results=1, ) - assert len(response.example_queries) == 1 + assert len(response.results) <= 1 @pytest.mark.asyncio async def test_bulky_example_does_not_evict_small_learning( service_with_query_memories: SearchService, ) -> None: - """An agent setting low caps still receives both kinds of memory. 
- With three query-bearing memories all matching the same entity, the - learning-only memories must still surface in `memories` because the two - kinds have independent caps.""" + """With max_results large enough, both learning-only and query-bearing + memories surface in the flat list.""" response = await service_with_query_memories.search( entities=["warehouse.orders.amount_paid"], - max_memories=2, - max_example_queries=1, + max_results=20, ) - assert len(response.memories) == 2 - assert len(response.example_queries) == 1 - learning_texts = [h.text for h in response.memories] + memory_hits = [h for h in response.results if h.kind == "memory" and h.query is None] + example_query_hits = [h for h in response.results if h.kind == "memory" and h.query is not None] + assert len(memory_hits) >= 1 + assert len(example_query_hits) >= 1 + learning_texts = [h.text for h in memory_hits] assert any("gross of refunds" in t for t in learning_texts) @@ -476,18 +458,17 @@ async def test_recency_fallback_fills_both_buckets( service_with_query_memories: SearchService, ) -> None: response = await service_with_query_memories.search( - max_memories=10, - max_example_queries=10, + max_results=20, ) - # All learning-only memories from the base fixture (4) surface in - # `memories`; all query-bearing (3) in `example_queries`. - assert len(response.memories) == 4 - assert len(response.example_queries) == 3 + # All learning-only memories from the base fixture (4) and all + # query-bearing (3) surface in the flat list. 
+ memory_hits = [h for h in response.results if h.kind == "memory" and h.query is None] + example_query_hits = [h for h in response.results if h.kind == "memory" and h.query is not None] + assert len(memory_hits) == 4 + assert len(example_query_hits) == 3 @pytest.mark.asyncio -async def test_memory_hit_no_longer_carries_query_field() -> None: - """`MemoryHit` is reserved for learning-only memories; the `query` - field has moved to `ExampleQueryHit`.""" - assert "query" not in MemoryHit.model_fields - assert "query" in ExampleQueryHit.model_fields +async def test_memory_hit_query_field_is_on_searchhit() -> None: + """`SearchHit` carries a ``query`` field for query-bearing memories.""" + assert "query" in SearchHit.model_fields diff --git a/tests/test_search_surfaces.py b/tests/test_search_surfaces.py index b71b0111..b00ba9eb 100644 --- a/tests/test_search_surfaces.py +++ b/tests/test_search_surfaces.py @@ -21,6 +21,8 @@ from slayer.core.models import Column, DatasourceConfig, SlayerModel from slayer.storage.base import StorageBackend, resolve_storage +from tests.search_helpers import call_mcp_tool as _call_mcp_tool + @pytest_asyncio.fixture async def storage_with_corpus() -> AsyncIterator[StorageBackend]: @@ -50,23 +52,6 @@ async def storage_with_corpus() -> AsyncIterator[StorageBackend]: # --------------------------------------------------------------------------- -async def _call_mcp_tool(*, mcp, name: str, arguments: dict) -> str: # NOSONAR(S3776) — small test helper that branches over three FastMCP result shapes; splitting hurts readability - """Invoke an MCP tool and return its text result.""" - result = await mcp.call_tool(name, arguments) - if isinstance(result, tuple): - # Some FastMCP versions return (content_list, structured_result). 
- for block in result[0]: - if hasattr(block, "text"): - return block.text - if isinstance(result, list): - for block in result: - if hasattr(block, "text"): - return block.text - if hasattr(result, "content"): - for block in result.content: - if hasattr(block, "text"): - return block.text - return str(result) @pytest.mark.asyncio @@ -85,15 +70,11 @@ async def test_mcp_search_tool_returns_json_with_three_lists( arguments={ "entities": ["warehouse.orders.amount_paid"], "question": "gross refunds", - "max_memories": 5, - "max_example_queries": 2, - "max_entities": 5, + "max_results": 20, }, ) payload = json.loads(result_text) - assert "memories" in payload - assert "example_queries" in payload - assert "entities" in payload + assert "results" in payload assert "resolved_input_entities" in payload assert "warehouse.orders.amount_paid" in payload["resolved_input_entities"] @@ -143,15 +124,11 @@ async def _seed(): res = client.post("/search", json={ "entities": ["warehouse.orders.amount_paid"], "question": "refunds", - "max_memories": 5, - "max_example_queries": 2, - "max_entities": 5, + "max_results": 20, }) assert res.status_code == 200 body = res.json() - assert "memories" in body - assert "example_queries" in body - assert "entities" in body + assert "results" in body assert "resolved_input_entities" in body # Recall endpoint is gone (FastAPI returns 405 because /memories/{id} # captures the path with the wrong method, or 404 if no route matches). 
@@ -238,15 +215,14 @@ def test_cli_search_runs_against_storage(tmp_path, monkeypatch, capsys) -> None: "search", "--storage", storage_dir, "--entity", "warehouse.orders.amount_paid", - "--max-example-queries", "1", + "--max-results", "10", "--format", "json", ], monkeypatch, capsys, ) assert code == 0 payload = json.loads(out) - assert "memories" in payload - assert "example_queries" in payload + assert "results" in payload assert "resolved_input_entities" in payload @@ -269,9 +245,7 @@ async def test_client_search_round_trip( in-process ``SearchService`` and returns a populated ``SearchResponse``.""" from slayer.client.slayer_client import SlayerClient from slayer.search.service import ( - EntityHit, - ExampleQueryHit, - MemoryHit, + SearchHit, SearchResponse, ) @@ -283,19 +257,16 @@ async def test_client_search_round_trip( response = await client.search( entities=["warehouse.orders.amount_paid"], question="refunds", - max_example_queries=2, + max_results=20, ) assert isinstance(response, SearchResponse) - assert isinstance(response.memories, list) - assert isinstance(response.example_queries, list) - assert isinstance(response.entities, list) + assert isinstance(response.results, list) assert isinstance(response.warnings, list) - assert all(isinstance(m, MemoryHit) for m in response.memories) - assert all(isinstance(e, ExampleQueryHit) for e in response.example_queries) - assert all(isinstance(e, EntityHit) for e in response.entities) - assert len(response.memories) >= 1 - assert "warehouse.orders.amount_paid" in response.memories[0].matched_entities + assert all(isinstance(h, SearchHit) for h in response.results) + memory_hits = [h for h in response.results if h.kind == "memory" and h.query is None] + assert len(memory_hits) >= 1 + assert "warehouse.orders.amount_paid" in memory_hits[0].matched_entities assert "warehouse.orders.amount_paid" in response.resolved_input_entities @@ -398,7 +369,7 @@ def test_cli_search_accepts_datasource_flag(tmp_path, monkeypatch, 
capsys) -> No ) assert code == 0 payload = json.loads(out) - assert "memories" in payload + assert "results" in payload @pytest.mark.asyncio @@ -413,6 +384,6 @@ async def test_client_search_accepts_datasource( response = await client.search( entities=["warehouse.orders.amount_paid"], datasource="warehouse", - max_example_queries=2, + max_results=20, ) assert "warehouse.orders.amount_paid" in response.resolved_input_entities diff --git a/tests/test_search_three_channel.py b/tests/test_search_three_channel.py index 617369df..ecfb43db 100644 --- a/tests/test_search_three_channel.py +++ b/tests/test_search_three_channel.py @@ -157,8 +157,9 @@ async def stub_embed_query( # NOSONAR(S7503) — stub matches embed_query async service = SearchService(storage=storage) response = await service.search(question="purchase total in dollars") - assert response.entities - assert response.entities[0].id == "dsx.orders.amount" + entity_hits = [h for h in response.results if h.kind != "memory"] + assert entity_hits + assert entity_hits[0].id == "dsx.orders.amount" # --------------------------------------------------------------------------- @@ -200,12 +201,13 @@ async def stub_embed_query(*_a, **_kw) -> List[float]: # NOSONAR(S7503) — stu service = SearchService(storage=storage) response = await service.search(question="orders") - if response.entities: + entity_hits = [h for h in response.results if h.kind != "memory"] + if entity_hits: # Any entity ranked #1 in *one* channel through RRF has # score = 1/(60+1) ≈ 0.0164. If both channels hit it #1, # score ≈ 0.0328. Both are well under the raw tantivy BM25 # band that the old surface emitted (5+). 
- assert response.entities[0].score < 0.1 + assert entity_hits[0].score < 0.1 # --------------------------------------------------------------------------- @@ -299,4 +301,4 @@ async def test_recency_fallback_when_all_inputs_empty( service = SearchService(storage=storage) response = await service.search() assert any("returning" in w for w in response.warnings) - assert response.entities == [] + assert [h for h in response.results if h.kind != "memory"] == [] diff --git a/tests/test_search_unified.py b/tests/test_search_unified.py new file mode 100644 index 00000000..b29ca0d0 --- /dev/null +++ b/tests/test_search_unified.py @@ -0,0 +1,733 @@ +"""DEV-1532: Unified flat-list search interface. + +Covers the full spec: +* SearchResponse.results replaces three separate buckets. +* SearchHit is the single unified hit model (kind, id, score, text, + matched_entities, query). +* max_results is the single total cap; old max_* params are hard-removed. +* Query-bearing memories surface in the flat list (kind="memory", query=...). +* Naive Cypher fallback: MATCH (n:Label1:Label2) RETURN n.id AS id works + without the advanced_search extra; complex Cypher raises SlayerError. +* REST, MCP, CLI, and client all expose the new signature. 
+""" + +from __future__ import annotations + +import json +import sys +import tempfile +from typing import AsyncIterator + +import pytest +import pytest_asyncio + +from tests.search_helpers import call_mcp_tool, seed_warehouse_models + +from slayer.core.enums import DataType +from slayer.core.models import Column, DatasourceConfig, ModelMeasure, SlayerModel +from slayer.core.query import SlayerQuery +from slayer.search import graph as _search_graph +from slayer.search.service import ( + SearchHit, + SearchResponse, + SearchService, +) +from slayer.storage.base import StorageBackend, resolve_storage + + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + + +@pytest_asyncio.fixture +async def storage_with_corpus() -> AsyncIterator[StorageBackend]: + """1 datasource, 2 models, 5 memories (3 learning-only, 2 query-bearing).""" + with tempfile.TemporaryDirectory() as tmpdir: + storage = resolve_storage(tmpdir) + await seed_warehouse_models(storage) + # learning-only memories + await storage.save_memory( + learning="amount_paid is gross of refunds.", + entities=["warehouse.orders.amount_paid"], + ) + await storage.save_memory( + learning="Filter status='paid' for net revenue.", + entities=["warehouse.orders.amount_paid", "warehouse.orders.status"], + ) + await storage.save_memory( + learning="Customer email may be NULL for anonymous checkouts.", + entities=["warehouse.customers.email"], + ) + # query-bearing memories + await storage.save_memory( + learning="Example: sum of paid amounts.", + entities=["warehouse.orders.amount_paid"], + query=SlayerQuery( + source_model="orders", + measures=[ModelMeasure(formula="amount_paid:sum")], + ), + ) + await storage.save_memory( + learning="Example: count of customers.", + entities=["warehouse.customers.email"], + query=SlayerQuery( + source_model="customers", + measures=[ModelMeasure(formula="*:count")], + ), + ) 
+ yield storage + + +@pytest_asyncio.fixture +async def service(storage_with_corpus: StorageBackend) -> SearchService: + return SearchService(storage=storage_with_corpus) + + +# --------------------------------------------------------------------------- +# Response model shape +# --------------------------------------------------------------------------- + + +def test_search_response_has_results_field() -> None: + """SearchResponse.results is the unified flat list.""" + assert "results" in SearchResponse.model_fields + + +def test_search_response_does_not_have_old_bucket_fields() -> None: + """Old memories / example_queries / entities fields are removed.""" + fields = SearchResponse.model_fields + assert "memories" not in fields + assert "example_queries" not in fields + assert "entities" not in fields + + +def test_search_hit_model_has_required_fields() -> None: + """SearchHit carries kind, id, score, text, matched_entities, query.""" + fields = SearchHit.model_fields + assert "kind" in fields + assert "id" in fields + assert "score" in fields + assert "text" in fields + assert "matched_entities" in fields + assert "query" in fields + + +def test_search_hit_query_and_matched_entities_are_optional_with_defaults() -> None: + """matched_entities defaults to [] and query defaults to None.""" + hit = SearchHit(kind="memory", id="1", score=0.5, text="hi") + assert hit.matched_entities == [] + assert hit.query is None + + +def test_search_hit_memory_hit_class_removed() -> None: + """MemoryHit, ExampleQueryHit, EntityHit are no longer exported.""" + import slayer.search.service as svc_mod + assert not hasattr(svc_mod, "MemoryHit") + assert not hasattr(svc_mod, "ExampleQueryHit") + assert not hasattr(svc_mod, "EntityHit") + + +# --------------------------------------------------------------------------- +# max_results — new cap +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def 
test_max_results_caps_total_flat_list(service: SearchService) -> None: + response = await service.search( + question="orders amount customers email", + max_results=3, + ) + assert len(response.results) <= 3 + + +@pytest.mark.asyncio +async def test_max_results_default_allows_at_least_one_result( + service: SearchService, +) -> None: + response = await service.search( + entities=["warehouse.orders.amount_paid"], + ) + assert len(response.results) >= 1 + + +@pytest.mark.asyncio +async def test_max_results_zero_raises_value_error(service: SearchService) -> None: + with pytest.raises(ValueError, match="max_results"): + await service.search(question="x", max_results=0) + + +@pytest.mark.asyncio +async def test_max_results_negative_raises_value_error(service: SearchService) -> None: + with pytest.raises(ValueError, match="max_results"): + await service.search(question="x", max_results=-1) + + +# --------------------------------------------------------------------------- +# Old max_* params are hard-removed +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_max_memories_param_removed_raises_type_error( + service: SearchService, +) -> None: + with pytest.raises(TypeError): + await service.search(**{"question": "x", "max_memories": 5}) # type: ignore[arg-type] + + +@pytest.mark.asyncio +async def test_max_example_queries_param_removed_raises_type_error( + service: SearchService, +) -> None: + with pytest.raises(TypeError): + await service.search(**{"question": "x", "max_example_queries": 2}) # type: ignore[arg-type] + + +@pytest.mark.asyncio +async def test_max_entities_param_removed_raises_type_error( + service: SearchService, +) -> None: + with pytest.raises(TypeError): + await service.search(**{"question": "x", "max_entities": 5}) # type: ignore[arg-type] + + +# --------------------------------------------------------------------------- +# Flat list contents +# 
--------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_query_bearing_memory_appears_in_results_with_query( + service: SearchService, +) -> None: + """Query-bearing memories appear in results with kind='memory' and query set.""" + response = await service.search( + entities=["warehouse.orders.amount_paid"], + max_results=20, + ) + hits_with_query = [h for h in response.results if h.query is not None] + assert len(hits_with_query) >= 1 + for h in hits_with_query: + assert h.kind == "memory" + assert isinstance(h.query, SlayerQuery) + + +@pytest.mark.asyncio +async def test_learning_only_memory_appears_in_results_without_query( + service: SearchService, +) -> None: + """Learning-only memories have kind='memory' and query=None.""" + response = await service.search( + entities=["warehouse.orders.amount_paid"], + max_results=20, + ) + hits_without_query = [h for h in response.results if h.kind == "memory" and h.query is None] + assert len(hits_without_query) >= 1 + texts = [h.text for h in hits_without_query] + assert any("gross of refunds" in t for t in texts) + + +@pytest.mark.asyncio +async def test_entity_hits_appear_in_flat_list(service: SearchService) -> None: + """Entities surface in results with their canonical kind.""" + response = await service.search( + question="amount paid orders revenue", + max_results=20, + ) + entity_kinds = {h.kind for h in response.results if h.kind != "memory"} + valid_kinds = {"datasource", "model", "column", "measure", "aggregation"} + assert entity_kinds & valid_kinds, "expected at least one entity hit" + + +@pytest.mark.asyncio +async def test_flat_list_mixes_memories_and_entities(service: SearchService) -> None: + """Both memory hits and entity hits appear in one results list.""" + response = await service.search( + entities=["warehouse.orders.amount_paid"], + question="amount paid orders", + max_results=20, + ) + kinds = {h.kind for h in response.results} + assert 
"memory" in kinds + non_memory = kinds - {"memory"} + assert non_memory, "expected at least one entity hit alongside memories" + + +@pytest.mark.asyncio +async def test_all_results_are_search_hit_instances(service: SearchService) -> None: + response = await service.search( + question="orders customers amount email", + max_results=20, + ) + for hit in response.results: + assert isinstance(hit, SearchHit) + + +# --------------------------------------------------------------------------- +# SearchHit.id semantics +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_memory_hit_id_is_raw_string_not_canonical_prefixed( + service: SearchService, +) -> None: + """Memory hit IDs are the raw storage ID, not 'memory:'.""" + response = await service.search( + entities=["warehouse.orders.amount_paid"], + max_results=20, + ) + memory_hits = [h for h in response.results if h.kind == "memory"] + assert memory_hits, "expected at least one memory hit" + for h in memory_hits: + assert not h.id.startswith("memory:"), ( + f"memory hit id should be raw, not canonical-prefixed; got {h.id!r}" + ) + assert isinstance(h.id, str) + assert h.id != "" + + +@pytest.mark.asyncio +async def test_entity_hit_id_is_canonical_string(service: SearchService) -> None: + """Entity hit IDs are canonical dotted-path strings.""" + response = await service.search( + question="orders amount paid customers", + max_results=20, + ) + entity_hits = [h for h in response.results if h.kind != "memory"] + assert entity_hits, "expected at least one entity hit" + for h in entity_hits: + assert isinstance(h.id, str) + # Canonical entity IDs contain at least one segment (datasource name). + assert "." 
in h.id or h.id == "warehouse" + + +# --------------------------------------------------------------------------- +# Empty-input recency fallback +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_recency_fallback_returns_memories_including_query_bearing( + service: SearchService, +) -> None: + """Empty input → recency fallback; both learning-only and query-bearing + memories appear in the flat results list.""" + response = await service.search(max_results=20) + assert any("recency" in w.lower() for w in response.warnings) + memory_hits = [h for h in response.results if h.kind == "memory"] + # Corpus has 3 learning-only + 2 query-bearing = 5 total. + assert len(memory_hits) >= 1 + # At least one query-bearing memory should appear. + query_hits = [h for h in memory_hits if h.query is not None] + assert query_hits, "recency fallback should include query-bearing memories" + + +@pytest.mark.asyncio +async def test_recency_fallback_capped_by_max_results( + service: SearchService, +) -> None: + response = await service.search(max_results=2) + assert len(response.results) <= 2 + + +# --------------------------------------------------------------------------- +# Naive Cypher fallback (advanced_search absent) +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_naive_cypher_single_model_label_filters_results( + service: SearchService, + monkeypatch: pytest.MonkeyPatch, +) -> None: + """MATCH (n:Model) RETURN n.id AS id filters to model-kind entities only.""" + monkeypatch.setattr(_search_graph, "is_available", lambda: False) + response = await service.search( + question="orders amount customers", + cypher_filter="MATCH (n:Model) RETURN n.id AS id", + max_results=20, + ) + for hit in response.results: + assert hit.kind == "model", ( + f"expected only model hits; got kind={hit.kind!r}, id={hit.id!r}" + ) + + +@pytest.mark.asyncio +async def 
test_naive_cypher_multi_label_colon_separated( + service: SearchService, + monkeypatch: pytest.MonkeyPatch, +) -> None: + """MATCH (n:Memory:Model) RETURN n.id AS id filters to memory + model.""" + monkeypatch.setattr(_search_graph, "is_available", lambda: False) + response = await service.search( + question="orders amount customers", + cypher_filter="MATCH (n:Memory:Model) RETURN n.id AS id", + max_results=20, + ) + for hit in response.results: + assert hit.kind in ("memory", "model"), ( + f"expected only memory/model hits; got kind={hit.kind!r}" + ) + + +@pytest.mark.asyncio +async def test_naive_cypher_memory_label_filters_to_memories_only( + service: SearchService, + monkeypatch: pytest.MonkeyPatch, +) -> None: + monkeypatch.setattr(_search_graph, "is_available", lambda: False) + response = await service.search( + entities=["warehouse.orders.amount_paid"], + question="gross refunds", + cypher_filter="MATCH (n:Memory) RETURN n.id AS id", + max_results=20, + ) + for hit in response.results: + assert hit.kind == "memory" + + +@pytest.mark.asyncio +async def test_naive_cypher_case_insensitive( + service: SearchService, + monkeypatch: pytest.MonkeyPatch, +) -> None: + monkeypatch.setattr(_search_graph, "is_available", lambda: False) + response = await service.search( + question="orders", + cypher_filter="match (n:model) return n.id as id", + max_results=20, + ) + for hit in response.results: + assert hit.kind == "model" + + +@pytest.mark.asyncio +async def test_naive_cypher_filters_before_max_results_cap( + service: SearchService, + monkeypatch: pytest.MonkeyPatch, +) -> None: + """Kind filter is applied before the top-N cap, not after. 
+ With 5 memories in the corpus and max_results=2, a Memory filter should + fill the cap from memory hits only; capping first and filtering afterwards + could return fewer memories than actually match.""" + monkeypatch.setattr(_search_graph, "is_available", lambda: False) + response = await service.search( + question="amount refunds customers email status", + cypher_filter="MATCH (n:Memory) RETURN n.id AS id", + max_results=2, + ) + # All returned hits must be memories, and the count must be <= max_results. + for hit in response.results: + assert hit.kind == "memory" + assert len(response.results) <= 2 + + +@pytest.mark.asyncio +async def test_naive_cypher_unknown_label_raises_slayer_error( + service: SearchService, + monkeypatch: pytest.MonkeyPatch, +) -> None: + from slayer.core.errors import SlayerError + monkeypatch.setattr(_search_graph, "is_available", lambda: False) + with pytest.raises(SlayerError, match="(?i)unknown"): + await service.search( + question="x", + cypher_filter="MATCH (n:UnknownType) RETURN n.id AS id", + ) + + +@pytest.mark.asyncio +async def test_naive_cypher_complex_query_raises_slayer_error_with_install_hint( + service: SearchService, + monkeypatch: pytest.MonkeyPatch, +) -> None: + """Complex Cypher raises SlayerError explaining the advanced_search requirement.""" + from slayer.core.errors import SlayerError + monkeypatch.setattr(_search_graph, "is_available", lambda: False) + with pytest.raises(SlayerError, match="(?i)advanced_search"): + await service.search( + question="x", + cypher_filter=( + "MATCH (m:Memory)-[:MENTIONS]->(e:Model) " + "RETURN m.id AS id" + ), + ) + + +@pytest.mark.asyncio +async def test_naive_cypher_where_clause_raises_slayer_error( + service: SearchService, + monkeypatch: pytest.MonkeyPatch, +) -> None: + from slayer.core.errors import SlayerError + monkeypatch.setattr(_search_graph, "is_available", lambda: False) + with pytest.raises(SlayerError, match="(?i)advanced_search"): + await service.search( + question="x", + cypher_filter=( + "MATCH (n:Model) WHERE n.name = 'orders' 
RETURN n.id AS id" + ), + ) + + +@pytest.mark.asyncio +async def test_naive_cypher_missing_as_id_raises_slayer_error( + service: SearchService, + monkeypatch: pytest.MonkeyPatch, +) -> None: + """Without 'AS id' the Cypher is invalid even for the naive path.""" + from slayer.core.errors import SlayerError + monkeypatch.setattr(_search_graph, "is_available", lambda: False) + with pytest.raises(SlayerError): + await service.search( + question="x", + cypher_filter="MATCH (n:Model) RETURN n.id", + ) + + +# --------------------------------------------------------------------------- +# REST surface +# --------------------------------------------------------------------------- + + +def _make_rest_client(tmp_path): + from slayer.api.server import create_app + import asyncio + storage = resolve_storage(str(tmp_path / "storage")) + + async def _seed(): + await storage.save_datasource( + DatasourceConfig(name="warehouse", type="sqlite", database=":memory:") + ) + await storage.save_model(SlayerModel( + name="orders", + sql_table="orders", + data_source="warehouse", + columns=[Column(name="amount_paid", type=DataType.DOUBLE)], + )) + await storage.save_memory( + learning="amount_paid is net of refunds.", + entities=["warehouse.orders.amount_paid"], + ) + + asyncio.run(_seed()) + from fastapi.testclient import TestClient + return TestClient(create_app(storage=storage)) + + +def test_rest_search_response_has_results_not_buckets(tmp_path) -> None: + """POST /search response body has 'results', not 'memories'/'example_queries'/'entities'.""" + client = _make_rest_client(tmp_path) + res = client.post("/search", json={ + "entities": ["warehouse.orders.amount_paid"], + "max_results": 5, + }) + assert res.status_code == 200 + body = res.json() + assert "results" in body + assert "memories" not in body + assert "example_queries" not in body + assert "entities" not in body + + +def test_rest_search_max_results_accepted(tmp_path) -> None: + client = _make_rest_client(tmp_path) + res = 
client.post("/search", json={ + "question": "refunds", + "max_results": 3, + }) + assert res.status_code == 200 + body = res.json() + assert len(body["results"]) <= 3 + + +def test_rest_search_old_max_memories_param_rejected(tmp_path) -> None: + """Passing old max_memories should fail with 422 (extra fields forbidden).""" + client = _make_rest_client(tmp_path) + res = client.post("/search", json={ + "question": "refunds", + "max_memories": 5, + }) + assert res.status_code == 422 + + +def test_rest_search_old_max_example_queries_param_rejected(tmp_path) -> None: + client = _make_rest_client(tmp_path) + res = client.post("/search", json={ + "question": "refunds", + "max_example_queries": 2, + }) + assert res.status_code == 422 + + +def test_rest_search_old_max_entities_param_rejected(tmp_path) -> None: + client = _make_rest_client(tmp_path) + res = client.post("/search", json={ + "question": "refunds", + "max_entities": 5, + }) + assert res.status_code == 422 + + +# --------------------------------------------------------------------------- +# MCP tool surface +# --------------------------------------------------------------------------- + + +_call_mcp_tool = call_mcp_tool + + +@pytest.mark.asyncio +async def test_mcp_search_tool_schema_has_max_results( + storage_with_corpus: StorageBackend, +) -> None: + from slayer.mcp.server import create_mcp_server + mcp = create_mcp_server(storage=storage_with_corpus) + tools = await mcp.list_tools() + search_tool = next(t for t in tools if t.name == "search") + schema = search_tool.inputSchema + props = schema.get("properties", {}) + assert "max_results" in props + + +@pytest.mark.asyncio +async def test_mcp_search_tool_schema_does_not_have_old_params( + storage_with_corpus: StorageBackend, +) -> None: + from slayer.mcp.server import create_mcp_server + mcp = create_mcp_server(storage=storage_with_corpus) + tools = await mcp.list_tools() + search_tool = next(t for t in tools if t.name == "search") + schema = 
search_tool.inputSchema + props = schema.get("properties", {}) + assert "max_memories" not in props + assert "max_example_queries" not in props + assert "max_entities" not in props + + +@pytest.mark.asyncio +async def test_mcp_search_response_has_results_key( + storage_with_corpus: StorageBackend, +) -> None: + from slayer.mcp.server import create_mcp_server + mcp = create_mcp_server(storage=storage_with_corpus) + result_text = await _call_mcp_tool( + mcp=mcp, + name="search", + arguments={ + "entities": ["warehouse.orders.amount_paid"], + "question": "gross refunds", + "max_results": 5, + }, + ) + payload = json.loads(result_text) + assert "results" in payload + assert "memories" not in payload + assert "example_queries" not in payload + + +# --------------------------------------------------------------------------- +# CLI surface +# --------------------------------------------------------------------------- + + +def _run_cli(args: list[str], monkeypatch, capsys) -> tuple[int, str]: + from slayer.cli import main + monkeypatch.setattr(sys, "argv", ["slayer"] + args) + try: + main() + code = 0 + except SystemExit as e: # NOSONAR(S5754) + code = int(e.code or 0) + captured = capsys.readouterr() + return code, captured.out + + +def _seed_cli_storage(tmp_path) -> str: + import asyncio + storage_dir = str(tmp_path / "storage") + storage = resolve_storage(storage_dir) + + async def _seed(): + await storage.save_datasource( + DatasourceConfig(name="warehouse", type="sqlite", database=":memory:") + ) + await storage.save_model(SlayerModel( + name="orders", + sql_table="orders", + data_source="warehouse", + columns=[Column(name="amount_paid", type=DataType.DOUBLE)], + )) + await storage.save_memory( + learning="amount_paid is net of refunds.", + entities=["warehouse.orders.amount_paid"], + ) + + asyncio.run(_seed()) + return storage_dir + + +def test_cli_search_help_has_max_results_flag(monkeypatch, capsys) -> None: + code, out = _run_cli(["search", "--help"], monkeypatch, 
capsys) + assert code == 0 + assert "--max-results" in out + + +def test_cli_search_help_does_not_have_old_flags(monkeypatch, capsys) -> None: + code, out = _run_cli(["search", "--help"], monkeypatch, capsys) + assert code == 0 + assert "--max-memories" not in out + assert "--max-example-queries" not in out + assert "--max-entities" not in out + + +def test_cli_search_json_output_has_results_key(tmp_path, monkeypatch, capsys) -> None: + storage_dir = _seed_cli_storage(tmp_path) + code, out = _run_cli( + ["search", "--storage", storage_dir, + "--entity", "warehouse.orders.amount_paid", + "--format", "json"], + monkeypatch, capsys, + ) + assert code == 0 + payload = json.loads(out) + assert "results" in payload + assert "memories" not in payload + assert "example_queries" not in payload + + +# --------------------------------------------------------------------------- +# Python client surface +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_client_search_returns_unified_response( + storage_with_corpus: StorageBackend, +) -> None: + from slayer.client.slayer_client import SlayerClient + + client = SlayerClient(storage=storage_with_corpus) + response = await client.search( + entities=["warehouse.orders.amount_paid"], + question="refunds", + max_results=5, + ) + assert isinstance(response, SearchResponse) + assert hasattr(response, "results") + assert isinstance(response.results, list) + for hit in response.results: + assert isinstance(hit, SearchHit) + + +@pytest.mark.asyncio +async def test_client_search_old_max_params_raise_type_error( + storage_with_corpus: StorageBackend, +) -> None: + from slayer.client.slayer_client import SlayerClient + client = SlayerClient(storage=storage_with_corpus) + with pytest.raises(TypeError): + await client.search(**{"question": "x", "max_memories": 5}) # type: ignore[arg-type]
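Reviewer note: the naive Cypher fallback tests above pin down a fairly small contract, so a standalone sketch of it may help when reading them. The code below is not the implementation in `slayer/search`; it is a minimal illustration, assuming a regex-based parser. The names `LABEL_TO_KIND`, `NaiveCypherError`, `parse_naive_label_filter`, and `filter_then_cap` are hypothetical, and hits are modelled as plain dicts rather than `SearchHit` objects.

```python
import re

# Hypothetical label-to-kind map; the real mapping in slayer may differ.
LABEL_TO_KIND = {
    "memory": "memory",
    "datasource": "datasource",
    "model": "model",
    "column": "column",
    "measure": "measure",
    "aggregation": "aggregation",
}

# Accepts only: MATCH (var:Label[:Label...]) RETURN var.id AS id
_NAIVE_PATTERN = re.compile(
    r"^\s*MATCH\s*\(\s*(\w+)((?:\s*:\s*\w+)+)\s*\)"
    r"\s*RETURN\s+(\w+)\.id\s+AS\s+id\s*;?\s*$",
    re.IGNORECASE,
)


class NaiveCypherError(ValueError):
    """Stand-in for SlayerError in this sketch."""


def parse_naive_label_filter(cypher: str) -> set[str]:
    """Return the set of hit kinds selected by a label-only Cypher filter."""
    m = _NAIVE_PATTERN.match(cypher)
    # The variable bound in MATCH must be the one returned.
    if m is None or m.group(1) != m.group(3):
        raise NaiveCypherError(
            "only `MATCH (n:Label[:Label...]) RETURN n.id AS id` is supported "
            "without the advanced_search extra"
        )
    labels = [part.strip() for part in m.group(2).split(":") if part.strip()]
    kinds: set[str] = set()
    for label in labels:
        kind = LABEL_TO_KIND.get(label.lower())
        if kind is None:
            raise NaiveCypherError(f"unknown label {label!r}")
        kinds.add(kind)
    return kinds


def filter_then_cap(hits: list[dict], kinds: set[str], max_results: int) -> list[dict]:
    """Apply the kind filter first, then the top-N cap (hits assumed ranked)."""
    if max_results < 1:
        raise ValueError("max_results must be >= 1")
    kept = [h for h in hits if h["kind"] in kinds]
    return kept[:max_results]
```

Filtering before slicing is the property `test_naive_cypher_filters_before_max_results_cap` is aiming at: with the order reversed, a ranked list headed by entity hits could yield fewer memory hits than the cap allows.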