From bb69f37ae139bd0d3d99d74bfd95cc89b73315f5 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Sat, 9 May 2026 13:02:07 -0700 Subject: [PATCH 01/42] Implement chunk-level extraction and embedding. Split embedding strategy (uncached chunk, cached related terms). --- AGENTS.md | 1 + src/typeagent/knowpro/add_messages.py | 153 ++++++++++++++++++++++++++ 2 files changed, 154 insertions(+) create mode 100644 src/typeagent/knowpro/add_messages.py diff --git a/AGENTS.md b/AGENTS.md index 981fc974..f0691a38 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -47,6 +47,7 @@ In all cases show what you added to AGENTS.md. - Use `make format` to format all files using `black`. Do this before reporting success. - When validating changes, first run `pytest` only on new/modified test files, then run `make format check test` once at the end. - Keep ad-hoc and performance benchmarks under `tools/`, not `tests/`, so `make test` does not run them. +- In add-messages pipeline chunk processing, compute chunk-text embeddings with uncached model calls and related-term embeddings with cached model calls. ## Package Management with uv diff --git a/src/typeagent/knowpro/add_messages.py b/src/typeagent/knowpro/add_messages.py new file mode 100644 index 00000000..001ae2c4 --- /dev/null +++ b/src/typeagent/knowpro/add_messages.py @@ -0,0 +1,153 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +"""New modular implementation of add_messages_streaming with pipelined architecture.""" + +from dataclasses import dataclass + +import typechat + +from . import knowledge_schema as kplib +from ..aitools.embeddings import IEmbeddingModel, NormalizedEmbedding +from ..storage.memory.semrefindex import collect_action_terms, collect_entity_terms +from .interfaces_core import IKnowledgeExtractor + + +@dataclass +class ChunkProcessingResult: + """Result of processing a single chunk through extraction and embeddings. + + Attributes: + chunk_id: Identifier for the chunk being processed + message_id: Global message ordinal + extracted_knowledge: Extracted KnowledgeResponse, or None if extraction failed/wasn't run + chunk_embedding: Normalized embedding vector for the message chunk + related_terms: Lowercased related-term texts extracted from knowledge + related_term_embeddings: Embeddings for related_terms (same order) + error: Exception from the first failing operation, or None if successful + """ + + chunk_id: str + message_id: int + extracted_knowledge: kplib.KnowledgeResponse | None = None + chunk_embedding: NormalizedEmbedding | None = None + related_terms: list[str] | None = None + related_term_embeddings: list[NormalizedEmbedding] | None = None + error: Exception | None = None + + @property + def success(self) -> bool: + """True if both extraction and embedding succeeded.""" + return ( + self.extracted_knowledge is not None + and self.chunk_embedding is not None + and self.related_terms is not None + and self.related_term_embeddings is not None + and self.error is None + ) + + +def _collect_related_terms_for_fuzzy_index( + knowledge: kplib.KnowledgeResponse, +) -> list[str]: + """Collect canonical related-term texts for the fuzzy related-terms index. + + These terms are derived from the same knowledge that feeds semantic refs. + We lowercase and deduplicate while preserving order to match index behavior. + """ + seen: set[str] = set() + related_terms: list[str] = [] + + def _add_term(term: str) -> None: + canonical = term.strip().lower() + if canonical and canonical not in seen: + seen.add(canonical) + related_terms.append(canonical) + + for entity in knowledge.entities: + for term in collect_entity_terms(entity): + _add_term(term) + + for action in list(knowledge.actions) + list(knowledge.inverse_actions): + for term in collect_action_terms(action): + _add_term(term) + + for topic in knowledge.topics: + _add_term(topic) + + return related_terms + + +async def process_chunk_with_extraction_and_embeddings( + chunk_id: str, + message_id: int, + chunk_text: str, + knowledge_extractor: IKnowledgeExtractor, + message_embedding_model: IEmbeddingModel, + related_terms_embedding_model: IEmbeddingModel | None = None, +) -> ChunkProcessingResult: + """Process a single text chunk through knowledge extraction and embeddings. + + Runs both knowledge extraction and embedding in a single function call, + capturing the first failure and stopping processing if an error occurs. + + Extraction runs first; if it fails, embedding work is skipped. + + Chunk embeddings are computed uncached, while related-term embeddings are + computed using cache-aware model calls. + + Args: + chunk_id: Identifier for this chunk (e.g., for debugging/tracking) + message_id: Global message ordinal (1-based in SQLite context) + chunk_text: Text content of the chunk (stripped) + knowledge_extractor: IKnowledgeExtractor instance for LLM extraction + message_embedding_model: Embedding model for chunk text embeddings + related_terms_embedding_model: Optional embedding model for related-term + embeddings. If None, message_embedding_model is used. + + Returns: + ChunkProcessingResult with knowledge, chunk embedding, related-term + embeddings, or an error from the first failed operation. + """ + result = ChunkProcessingResult(chunk_id=chunk_id, message_id=message_id) + + # Step 1: Extract knowledge + try: + knowledge_result = await knowledge_extractor.extract(chunk_text) + if isinstance(knowledge_result, typechat.Success): + result.extracted_knowledge = knowledge_result.value + else: + # Extraction returned a Failure; treat as error and stop + result.error = RuntimeError( + f"Knowledge extraction failed: {knowledge_result.message}" + ) + return result + except Exception as e: + # Extraction raised an exception; stop processing + result.error = e + return result + + result.related_terms = _collect_related_terms_for_fuzzy_index( + result.extracted_knowledge + ) + + related_model = related_terms_embedding_model or message_embedding_model + + # Step 2: Generate embeddings (only if extraction succeeded) + try: + result.chunk_embedding = await message_embedding_model.get_embedding_nocache( + chunk_text + ) + if result.related_terms: + rel_embeddings = await related_model.get_embeddings( + result.related_terms, + ) + result.related_term_embeddings = [e for e in rel_embeddings] + else: + result.related_term_embeddings = [] + except Exception as e: + # Embedding failed; record error and return + result.error = e + return result + + return result From 5a60bad3e46617d2acd0318f5d8746b2ff60272a Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Sat, 9 May 2026 21:03:07 -0700 Subject: [PATCH 02/42] Add producer task. Chunk ID is TextLocation. --- AGENTS.md | 2 + src/typeagent/knowpro/add_messages.py | 88 +++++++++++++++++++++++++-- 2 files changed, 85 insertions(+), 5 deletions(-) diff --git a/AGENTS.md b/AGENTS.md index f0691a38..447604a6 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -48,6 +48,8 @@ In all cases show what you added to AGENTS.md. - When validating changes, first run `pytest` only on new/modified test files, then run `make format check test` once at the end. - Keep ad-hoc and performance benchmarks under `tools/`, not `tests/`, so `make test` does not run them. - In add-messages pipeline chunk processing, compute chunk-text embeddings with uncached model calls and related-term embeddings with cached model calls. +- In add-messages pipeline flow, lower stop_at_message_id to min(existing, failing_message_id), and always enqueue queue-1 sentinels even when the input iterator fails so workers can drain and exit cleanly. +- In add-messages pipeline data structures, use `TextLocation` as the chunk identifier instead of a formatted string chunk ID. ## Package Management with uv diff --git a/src/typeagent/knowpro/add_messages.py b/src/typeagent/knowpro/add_messages.py index 001ae2c4..b9c748de 100644 --- a/src/typeagent/knowpro/add_messages.py +++ b/src/typeagent/knowpro/add_messages.py @@ -3,6 +3,8 @@ """New modular implementation of add_messages_streaming with pipelined architecture.""" +import asyncio +from collections.abc import AsyncIterable from dataclasses import dataclass import typechat @@ -10,7 +12,83 @@ from . import knowledge_schema as kplib from ..aitools.embeddings import IEmbeddingModel, NormalizedEmbedding from ..storage.memory.semrefindex import collect_action_terms, collect_entity_terms -from .interfaces_core import IKnowledgeExtractor +from .interfaces_core import IKnowledgeExtractor, IMessage, TextLocation + + +@dataclass +class PipelineStopState: + """Shared stop marker for pipeline stages. + + A message ordinal greater than or equal to ``stop_at_message_id`` is + considered out-of-scope for further processing. + """ + + stop_at_message_id: int = 10**100 + + +@dataclass +class ProducerState: + """Mutable producer state shared with orchestrator/reporting.""" + + next_message_id: int + produced_messages: int = 0 + produced_chunks: int = 0 + exception: Exception | None = None + + +@dataclass +class ChunkWorkItem[TMessage: IMessage]: + """One chunk scheduled by the producer for worker processing.""" + + chunk_id: TextLocation + message_id: int + chunk_ordinal: int + chunk_count: int + chunk_text: str + message: TMessage + + +async def _producer_task[TMessage: IMessage]( + messages: AsyncIterable[TMessage], + chunk_queue: asyncio.Queue[ChunkWorkItem[TMessage] | None], + worker_count: int, + stop_state: PipelineStopState, + producer_state: ProducerState, +) -> None: + """Read input messages and enqueue chunk work items. + + The producer stops enqueueing once it reaches ``stop_at_message_id``. + It always sends one sentinel per worker, even if the input iterator raises. + """ + try: + async for message in messages: + message_id = producer_state.next_message_id + if message_id >= stop_state.stop_at_message_id: + break + + chunk_count = len(message.text_chunks) + for chunk_ordinal, chunk_text in enumerate(message.text_chunks): + if message_id >= stop_state.stop_at_message_id: + break + await chunk_queue.put( + ChunkWorkItem[TMessage]( + chunk_id=TextLocation(message_id, chunk_ordinal), + message_id=message_id, + chunk_ordinal=chunk_ordinal, + chunk_count=chunk_count, + chunk_text=chunk_text, + message=message, + ) + ) + producer_state.produced_chunks += 1 + + producer_state.produced_messages += 1 + producer_state.next_message_id += 1 + except Exception as exc: + producer_state.exception = exc + finally: + for _ in range(worker_count): + await chunk_queue.put(None) @dataclass @@ -18,7 +96,7 @@ class ChunkProcessingResult: """Result of processing a single chunk through extraction and embeddings. Attributes: - chunk_id: Identifier for the chunk being processed + chunk_id: Message/chunk location for the processed chunk message_id: Global message ordinal extracted_knowledge: Extracted KnowledgeResponse, or None if extraction failed/wasn't run chunk_embedding: Normalized embedding vector for the message chunk @@ -27,7 +105,7 @@ class ChunkProcessingResult: error: Exception from the first failing operation, or None if successful """ - chunk_id: str + chunk_id: TextLocation message_id: int extracted_knowledge: kplib.KnowledgeResponse | None = None chunk_embedding: NormalizedEmbedding | None = None @@ -79,7 +157,7 @@ def _add_term(term: str) -> None: async def process_chunk_with_extraction_and_embeddings( - chunk_id: str, + chunk_id: TextLocation, message_id: int, chunk_text: str, knowledge_extractor: IKnowledgeExtractor, @@ -97,7 +175,7 @@ async def process_chunk_with_extraction_and_embeddings( computed using cache-aware model calls. Args: - chunk_id: Identifier for this chunk (e.g., for debugging/tracking) + chunk_id: Message/chunk location for this chunk message_id: Global message ordinal (1-based in SQLite context) chunk_text: Text content of the chunk (stripped) knowledge_extractor: IKnowledgeExtractor instance for LLM extraction From 45100043a753a6522891834e48b77eabd8fd650b Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Sat, 9 May 2026 21:16:16 -0700 Subject: [PATCH 03/42] Add _worker_task(). --- AGENTS.md | 2 ++ src/typeagent/knowpro/add_messages.py | 52 +++++++++++++++++++++++++++ 2 files changed, 54 insertions(+) diff --git a/AGENTS.md b/AGENTS.md index 447604a6..9fbbddf5 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -46,10 +46,12 @@ In all cases show what you added to AGENTS.md. - Use `make check test` to run `make check` and if it passes also run `make test` - Use `make format` to format all files using `black`. Do this before reporting success. - When validating changes, first run `pytest` only on new/modified test files, then run `make format check test` once at the end. +- While building `add_messages.py` before dedicated tests exist, skip running the full test suite; run full tests after those tests are added. - Keep ad-hoc and performance benchmarks under `tools/`, not `tests/`, so `make test` does not run them. - In add-messages pipeline chunk processing, compute chunk-text embeddings with uncached model calls and related-term embeddings with cached model calls. - In add-messages pipeline flow, lower stop_at_message_id to min(existing, failing_message_id), and always enqueue queue-1 sentinels even when the input iterator fails so workers can drain and exit cleanly. - In add-messages pipeline data structures, use `TextLocation` as the chunk identifier instead of a formatted string chunk ID. +- In asyncio code, avoid locks for in-memory state updates that do not `await` between read/modify/write; use locks only when a critical section spans `await` points. ## Package Management with uv diff --git a/src/typeagent/knowpro/add_messages.py b/src/typeagent/knowpro/add_messages.py index b9c748de..2c5d91eb 100644 --- a/src/typeagent/knowpro/add_messages.py +++ b/src/typeagent/knowpro/add_messages.py @@ -91,6 +91,58 @@ async def _producer_task[TMessage: IMessage]( await chunk_queue.put(None) +async def _worker_task[TMessage: IMessage]( + chunk_queue: asyncio.Queue[ChunkWorkItem[TMessage] | None], + result_queue: asyncio.Queue[ChunkProcessingResult], + stop_state: PipelineStopState, + knowledge_extractor: IKnowledgeExtractor, + message_embedding_model: IEmbeddingModel, + related_terms_embedding_model: IEmbeddingModel | None = None, +) -> None: + """Consume chunk work items and produce chunk processing results. + + Workers stop when they receive a ``None`` sentinel from ``chunk_queue``. + Chunks at or beyond ``stop_at_message_id`` are skipped and reported as + error results so downstream code can account for them deterministically. + """ + while True: + work_item = await chunk_queue.get() + if work_item is None: + return + + stop_at = stop_state.stop_at_message_id + + if work_item.message_id >= stop_at: + await result_queue.put( + ChunkProcessingResult( + chunk_id=work_item.chunk_id, + message_id=work_item.message_id, + error=RuntimeError( + "Chunk skipped because stop_at_message_id " + f"is {stop_at} and message_id is {work_item.message_id}" + ), + ) + ) + continue + + result = await process_chunk_with_extraction_and_embeddings( + chunk_id=work_item.chunk_id, + message_id=work_item.message_id, + chunk_text=work_item.chunk_text, + knowledge_extractor=knowledge_extractor, + message_embedding_model=message_embedding_model, + related_terms_embedding_model=related_terms_embedding_model, + ) + + if result.error is not None: + stop_state.stop_at_message_id = min( + stop_state.stop_at_message_id, + work_item.message_id, + ) + + await result_queue.put(result) + + @dataclass class ChunkProcessingResult: """Result of processing a single chunk through extraction and embeddings. From c8db35dacdfeef3038cd7e517750ce760bda08dc Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Sat, 9 May 2026 22:16:57 -0700 Subject: [PATCH 04/42] Add reassembler task and simplify some data structures. --- AGENTS.md | 3 + src/typeagent/knowpro/add_messages.py | 194 +++++++++++++++++++++++--- 2 files changed, 176 insertions(+), 21 deletions(-) diff --git a/AGENTS.md b/AGENTS.md index 9fbbddf5..6b28c3b9 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -52,6 +52,9 @@ In all cases show what you added to AGENTS.md. - In add-messages pipeline flow, lower stop_at_message_id to min(existing, failing_message_id), and always enqueue queue-1 sentinels even when the input iterator fails so workers can drain and exit cleanly. - In add-messages pipeline data structures, use `TextLocation` as the chunk identifier instead of a formatted string chunk ID. - In asyncio code, avoid locks for in-memory state updates that do not `await` between read/modify/write; use locks only when a critical section spans `await` points. +- Name returned summary/value objects as `*Result`; reserve `*State` for mutable shared/internal state. +- Keep internal helper type naming consistent within a module; avoid mixing underscored and non-underscored helper class names without a clear API-boundary reason. +- Prefer ordinal type aliases (e.g., `MessageOrdinal`, `ChunkOrdinal`) over raw `int` in pipeline code for readability. ## Package Management with uv diff --git a/src/typeagent/knowpro/add_messages.py b/src/typeagent/knowpro/add_messages.py index 2c5d91eb..78846ea3 100644 --- a/src/typeagent/knowpro/add_messages.py +++ b/src/typeagent/knowpro/add_messages.py @@ -4,7 +4,7 @@ """New modular implementation of add_messages_streaming with pipelined architecture.""" import asyncio -from collections.abc import AsyncIterable +from collections.abc import AsyncIterable, Awaitable, Callable from dataclasses import dataclass import typechat @@ -12,7 +12,9 @@ from . import knowledge_schema as kplib from ..aitools.embeddings import IEmbeddingModel, NormalizedEmbedding from ..storage.memory.semrefindex import collect_action_terms, collect_entity_terms -from .interfaces_core import IKnowledgeExtractor, IMessage, TextLocation +from .interfaces_core import IKnowledgeExtractor, IMessage, MessageOrdinal, TextLocation + +type ChunkOrdinal = int @dataclass @@ -30,7 +32,7 @@ class PipelineStopState: class ProducerState: """Mutable producer state shared with orchestrator/reporting.""" - next_message_id: int + next_message_id: MessageOrdinal produced_messages: int = 0 produced_chunks: int = 0 exception: Exception | None = None @@ -41,8 +43,6 @@ class ChunkWorkItem[TMessage: IMessage]: """One chunk scheduled by the producer for worker processing.""" chunk_id: TextLocation - message_id: int - chunk_ordinal: int chunk_count: int chunk_text: str message: TMessage @@ -73,8 +73,6 @@ async def _producer_task[TMessage: IMessage]( await chunk_queue.put( ChunkWorkItem[TMessage]( chunk_id=TextLocation(message_id, chunk_ordinal), - message_id=message_id, - chunk_ordinal=chunk_ordinal, chunk_count=chunk_count, chunk_text=chunk_text, message=message, @@ -93,7 +91,7 @@ async def _producer_task[TMessage: IMessage]( async def _worker_task[TMessage: IMessage]( chunk_queue: asyncio.Queue[ChunkWorkItem[TMessage] | None], - result_queue: asyncio.Queue[ChunkProcessingResult], + result_queue: asyncio.Queue[ChunkProcessingResult[TMessage]], stop_state: PipelineStopState, knowledge_extractor: IKnowledgeExtractor, message_embedding_model: IEmbeddingModel, @@ -112,14 +110,15 @@ async def _worker_task[TMessage: IMessage]( stop_at = stop_state.stop_at_message_id - if work_item.message_id >= stop_at: + if work_item.chunk_id.message_ordinal >= stop_at: await result_queue.put( ChunkProcessingResult( chunk_id=work_item.chunk_id, - message_id=work_item.message_id, + chunk_count=work_item.chunk_count, + message=work_item.message, error=RuntimeError( "Chunk skipped because stop_at_message_id " - f"is {stop_at} and message_id is {work_item.message_id}" + f"is {stop_at} and message_id is {work_item.chunk_id.message_ordinal}" ), ) ) @@ -127,8 +126,9 @@ async def _worker_task[TMessage: IMessage]( result = await process_chunk_with_extraction_and_embeddings( chunk_id=work_item.chunk_id, - message_id=work_item.message_id, chunk_text=work_item.chunk_text, + chunk_count=work_item.chunk_count, + message=work_item.message, knowledge_extractor=knowledge_extractor, message_embedding_model=message_embedding_model, related_terms_embedding_model=related_terms_embedding_model, @@ -137,19 +137,18 @@ async def _worker_task[TMessage: IMessage]( if result.error is not None: stop_state.stop_at_message_id = min( stop_state.stop_at_message_id, - work_item.message_id, + work_item.chunk_id.message_ordinal, ) await result_queue.put(result) @dataclass -class ChunkProcessingResult: +class ChunkProcessingResult[TMessage: IMessage]: """Result of processing a single chunk through extraction and embeddings. Attributes: chunk_id: Message/chunk location for the processed chunk - message_id: Global message ordinal extracted_knowledge: Extracted KnowledgeResponse, or None if extraction failed/wasn't run chunk_embedding: Normalized embedding vector for the message chunk related_terms: Lowercased related-term texts extracted from knowledge @@ -158,7 +157,8 @@ class ChunkProcessingResult: """ chunk_id: TextLocation - message_id: int + chunk_count: int + message: TMessage extracted_knowledge: kplib.KnowledgeResponse | None = None chunk_embedding: NormalizedEmbedding | None = None related_terms: list[str] | None = None @@ -208,14 +208,15 @@ def _add_term(term: str) -> None: return related_terms -async def process_chunk_with_extraction_and_embeddings( +async def process_chunk_with_extraction_and_embeddings[TMessage: IMessage]( chunk_id: TextLocation, - message_id: int, chunk_text: str, + chunk_count: int, + message: TMessage, knowledge_extractor: IKnowledgeExtractor, message_embedding_model: IEmbeddingModel, related_terms_embedding_model: IEmbeddingModel | None = None, -) -> ChunkProcessingResult: +) -> ChunkProcessingResult[TMessage]: """Process a single text chunk through knowledge extraction and embeddings. Runs both knowledge extraction and embedding in a single function call, @@ -228,7 +229,6 @@ async def process_chunk_with_extraction_and_embeddings( Args: chunk_id: Message/chunk location for this chunk - message_id: Global message ordinal (1-based in SQLite context) chunk_text: Text content of the chunk (stripped) knowledge_extractor: IKnowledgeExtractor instance for LLM extraction message_embedding_model: Embedding model for chunk text embeddings @@ -239,7 +239,11 @@ async def process_chunk_with_extraction_and_embeddings( ChunkProcessingResult with knowledge, chunk embedding, related-term embeddings, or an error from the first failed operation. """ - result = ChunkProcessingResult(chunk_id=chunk_id, message_id=message_id) + result = ChunkProcessingResult( + chunk_id=chunk_id, + chunk_count=chunk_count, + message=message, + ) # Step 1: Extract knowledge try: @@ -281,3 +285,151 @@ async def process_chunk_with_extraction_and_embeddings( return result return result + + +@dataclass +class MessageAssembly[TMessage: IMessage]: + """In-memory chunk accumulation for one message.""" + + message_id: MessageOrdinal + chunk_count: int + message: TMessage + chunks: dict[ChunkOrdinal, ChunkProcessingResult[TMessage]] + has_error: bool = False + + def is_complete(self) -> bool: + return len(self.chunks) == self.chunk_count + + +@dataclass +class ReassemblerResult: + """Progress and counters produced by the reassembler stage.""" + + first_uncommitted_ordinal: MessageOrdinal + messages_committed: int = 0 + chunks_committed: int = 0 + chunk_failures: int = 0 + buffered_messages: int = 0 + + +async def _reassembler_task[TMessage: IMessage]( + result_queue: asyncio.Queue[ChunkProcessingResult[TMessage] | None], + stop_state: PipelineStopState, + first_uncommitted_ordinal: MessageOrdinal, + target_commit_chunk_count: int, + commit_batch: Callable[ + [list[TMessage], list[ChunkProcessingResult[TMessage]]], + Awaitable[None], + ], + on_batch_committed: Callable[[int, int], None] | None = None, +) -> ReassemblerResult: + """Reassemble chunks into messages and commit only consecutive complete ones. + + This stage consumes worker results until it sees a ``None`` sentinel. + It never commits out-of-order messages: if message N is incomplete or failed, + messages N+1 and later remain buffered. + """ + state = ReassemblerResult(first_uncommitted_ordinal=first_uncommitted_ordinal) + assemblies: dict[MessageOrdinal, MessageAssembly[TMessage]] = {} + + staged_messages: list[TMessage] = [] + staged_results: list[ChunkProcessingResult[TMessage]] = [] + staged_chunks = 0 + + async def _commit_if_needed(force: bool = False) -> None: + nonlocal staged_chunks + if not staged_messages: + return + if not force and staged_chunks < target_commit_chunk_count: + return + msg_count = len(staged_messages) + chunk_count = staged_chunks + await commit_batch(staged_messages, staged_results) + state.messages_committed += msg_count + state.chunks_committed += chunk_count + if on_batch_committed is not None: + on_batch_committed(msg_count, chunk_count) + staged_messages.clear() + staged_results.clear() + staged_chunks = 0 + + async def _drain_consecutive_complete() -> None: + nonlocal staged_chunks + while True: + assembly = assemblies.get(state.first_uncommitted_ordinal) + if assembly is None: + return + if not assembly.is_complete() or assembly.has_error: + return + + ordered_chunk_ordinals = sorted(assembly.chunks) + ordered_results = [assembly.chunks[i] for i in ordered_chunk_ordinals] + staged_messages.append(assembly.message) + staged_results.extend(ordered_results) + staged_chunks += len(ordered_results) + + del assemblies[state.first_uncommitted_ordinal] + state.first_uncommitted_ordinal += 1 + await _commit_if_needed() + + try: + while True: + item = await result_queue.get() + if item is None: + break + + chunk_ordinal = item.chunk_id.chunk_ordinal + message_id = item.chunk_id.message_ordinal + + try: + if chunk_ordinal < 0 or chunk_ordinal >= item.chunk_count: + raise RuntimeError( + f"Invalid chunk ordinal: message_id={message_id}, " + f"chunk_ordinal={chunk_ordinal}, chunk_count={item.chunk_count}" + ) + + existing = assemblies.get(message_id) + if existing is None: + existing = MessageAssembly[TMessage]( + message_id=message_id, + chunk_count=item.chunk_count, + message=item.message, + chunks={}, + ) + assemblies[message_id] = existing + elif existing.chunk_count != item.chunk_count: + raise RuntimeError( + f"Mismatched chunk count for message: message_id={message_id}, " + f"expected={existing.chunk_count}, got={item.chunk_count}" + ) + + if chunk_ordinal in existing.chunks: + raise RuntimeError( + f"Duplicate chunk: message_id={message_id}, " + f"chunk_ordinal={chunk_ordinal}, chunk_count={item.chunk_count}" + ) + + existing.chunks[chunk_ordinal] = item + except Exception: + # On validation error, set stop flag and re-raise + # The finally block will drain and commit consecutive complete messages + stop_state.stop_at_message_id = min( + stop_state.stop_at_message_id, message_id + ) + raise + + if item.error is not None: + existing.has_error = True + state.chunk_failures += 1 + stop_state.stop_at_message_id = min( + stop_state.stop_at_message_id, message_id + ) + + await _drain_consecutive_complete() + finally: + # Always drain and commit consecutive complete messages before raising + await _drain_consecutive_complete() + await _commit_if_needed(force=True) + + state.buffered_messages = len(assemblies) + return state From f4ece84df0023adc61bef82a4e96941a480c02a5 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Sat, 9 May 2026 22:26:13 -0700 Subject: [PATCH 05/42] Refactor index writes to support precomputed embeddings. Add precomputed-embedding write paths for message and related-term indexes, introducing explicit *_with_embeddings methods in interfaces and both memory/SQLite implementations. Refactor existing add methods to compute embeddings once and delegate, enabling pipeline commit paths to reuse worker-generated embeddings without recomputation. --- AGENTS.md | 1 + src/typeagent/knowpro/interfaces_indexes.py | 14 ++++ src/typeagent/knowpro/textlocindex.py | 34 ++++++++++ src/typeagent/storage/memory/messageindex.py | 68 +++++++++++++++---- src/typeagent/storage/memory/reltermsindex.py | 25 ++++++- src/typeagent/storage/sqlite/messageindex.py | 50 +++++++++++--- src/typeagent/storage/sqlite/reltermsindex.py | 34 +++++++++- 7 files changed, 200 insertions(+), 26 deletions(-) diff --git a/AGENTS.md b/AGENTS.md index 6b28c3b9..0623e0d1 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -51,6 +51,7 @@ In all cases show what you added to AGENTS.md. - In add-messages pipeline chunk processing, compute chunk-text embeddings with uncached model calls and related-term embeddings with cached model calls. - In add-messages pipeline flow, lower stop_at_message_id to min(existing, failing_message_id), and always enqueue queue-1 sentinels even when the input iterator fails so workers can drain and exit cleanly. - In add-messages pipeline data structures, use `TextLocation` as the chunk identifier instead of a formatted string chunk ID. +- When adding precomputed-embedding write paths, expose explicit `*_with_embeddings` methods and have existing methods compute embeddings then delegate to those methods. - In asyncio code, avoid locks for in-memory state updates that do not `await` between read/modify/write; use locks only when a critical section spans `await` points. - Name returned summary/value objects as `*Result`; reserve `*State` for mutable shared/internal state. - Keep internal helper type naming consistent within a module; avoid mixing underscored and non-underscored helper class names without a clear API-boundary reason. diff --git a/src/typeagent/knowpro/interfaces_indexes.py b/src/typeagent/knowpro/interfaces_indexes.py index c872fa7c..02663f27 100644 --- a/src/typeagent/knowpro/interfaces_indexes.py +++ b/src/typeagent/knowpro/interfaces_indexes.py @@ -9,6 +9,7 @@ from pydantic.dataclasses import dataclass +from ..aitools.embeddings import NormalizedEmbedding from .interfaces_core import ( DateRange, IMessage, @@ -131,6 +132,12 @@ async def size(self) -> int: ... async def add_terms(self, texts: list[str]) -> None: ... + async def add_terms_with_embeddings( + self, + texts: list[str], + embeddings: list[NormalizedEmbedding], + ) -> None: ... + async def lookup_term( self, text: str, @@ -214,6 +221,13 @@ async def add_messages_starting_at( messages: list[TMessage], ) -> None: ... + async def add_messages_starting_at_with_embeddings( + self, + start_message_ordinal: int, + messages: list[TMessage], + chunk_embeddings: list[NormalizedEmbedding], + ) -> None: ... + async def lookup_messages( self, message_text: str, diff --git a/src/typeagent/knowpro/textlocindex.py b/src/typeagent/knowpro/textlocindex.py index a6d95969..3c1082e6 100644 --- a/src/typeagent/knowpro/textlocindex.py +++ b/src/typeagent/knowpro/textlocindex.py @@ -5,6 +5,8 @@ from dataclasses import dataclass from typing import Protocol +import numpy as np + from ..aitools.embeddings import NormalizedEmbedding from ..aitools.vectorbase import TextEmbeddingIndexSettings from .fuzzyindex import EmbeddingIndex, ScoredInt @@ -27,6 +29,12 @@ async def add_text_locations( text_and_locations: list[tuple[str, TextLocation]], ) -> None: ... + async def add_text_locations_with_embeddings( + self, + text_locations: list[TextLocation], + embeddings: list[NormalizedEmbedding], + ) -> None: ... + async def lookup_text( self, text: str, @@ -71,6 +79,22 @@ async def add_text_locations( await self._embedding_index.add_texts([text for text, _ in text_and_locations]) self._text_locations.extend([loc for _, loc in text_and_locations]) + async def add_text_locations_with_embeddings( + self, + text_locations: list[TextLocation], + embeddings: list[NormalizedEmbedding], + ) -> None: + if len(text_locations) != len(embeddings): + raise ValueError( + "text_locations and embeddings must have the same length: " + f"{len(text_locations)} != {len(embeddings)}" + ) + if not text_locations: + return + embedding_array = np.stack(embeddings, axis=0).astype(np.float32, copy=False) + self._embedding_index.push(embedding_array) + self._text_locations.extend(text_locations) + async def lookup_text( self, text: str, @@ -112,6 +136,16 @@ async def generate_embedding( ) -> NormalizedEmbedding: return await self._embedding_index.get_embedding(text, cache) + async def generate_embeddings( + self, texts: list[str], cache: bool = True + ) -> list[NormalizedEmbedding]: + if not texts: + return [] + embeddings = await self._embedding_index._vector_base.get_embeddings( + texts, cache=cache + ) + return [embedding for embedding in embeddings] + def lookup_by_embedding( self, text_embedding: NormalizedEmbedding, diff --git a/src/typeagent/storage/memory/messageindex.py b/src/typeagent/storage/memory/messageindex.py index efcc4ddf..f8da5377 100644 --- a/src/typeagent/storage/memory/messageindex.py +++ b/src/typeagent/storage/memory/messageindex.py @@ -75,25 +75,67 @@ async def add_messages[TMessage: IMessage]( messages: Iterable[TMessage], ) -> None: base_message_ordinal: MessageOrdinal = await self.text_location_index.size() - all_chunks: list[tuple[str, TextLocation]] = [] - # Collect everything so we can batch efficiently. - for message_ordinal, message in enumerate(messages, base_message_ordinal): - for chunk_ordinal, chunk in enumerate(message.text_chunks): - all_chunks.append((chunk, TextLocation(message_ordinal, chunk_ordinal))) - await self.text_location_index.add_text_locations(all_chunks) - - async def add_messages_starting_at( + message_list = list(messages) + if not message_list: + return + + chunk_texts: list[str] = [] + for message in message_list: + chunk_texts.extend(message.text_chunks) + + chunk_embeddings = await self.text_location_index.generate_embeddings( + chunk_texts, + cache=True, + ) + await self.add_messages_starting_at_with_embeddings( + base_message_ordinal, + message_list, + chunk_embeddings, + ) + + async def add_messages_starting_at[TMessage: IMessage]( self, start_message_ordinal: int, - messages: list[IMessage], + messages: list[TMessage], ) -> None: """Add messages to the index starting at the given ordinal.""" - all_chunks: list[tuple[str, TextLocation]] = [] + chunk_texts: list[str] = [] + for message in messages: + chunk_texts.extend(message.text_chunks) + + chunk_embeddings = await self.text_location_index.generate_embeddings( + chunk_texts, + cache=True, + ) + await self.add_messages_starting_at_with_embeddings( + start_message_ordinal, + messages, + chunk_embeddings, + ) + + async def add_messages_starting_at_with_embeddings[TMessage: IMessage]( + self, + start_message_ordinal: int, + messages: list[TMessage], + chunk_embeddings: list[NormalizedEmbedding], + ) -> None: + """Add messages starting at an ordinal using precomputed chunk embeddings.""" + text_locations: list[TextLocation] = [] for idx, message in enumerate(messages): msg_ord = start_message_ordinal + idx - for chunk_ord, chunk in enumerate(message.text_chunks): - all_chunks.append((chunk, TextLocation(msg_ord, chunk_ord))) - await self.text_location_index.add_text_locations(all_chunks) + for chunk_ord, _chunk in enumerate(message.text_chunks): + text_locations.append(TextLocation(msg_ord, chunk_ord)) + + if len(text_locations) != len(chunk_embeddings): + raise ValueError( + "messages and chunk_embeddings produced different chunk counts: " + f"{len(text_locations)} != {len(chunk_embeddings)}" + ) + + await self.text_location_index.add_text_locations_with_embeddings( + text_locations, + chunk_embeddings, + ) async def lookup_messages( self, diff --git a/src/typeagent/storage/memory/reltermsindex.py b/src/typeagent/storage/memory/reltermsindex.py index d9e682c2..8ad984ef 100644 --- a/src/typeagent/storage/memory/reltermsindex.py +++ b/src/typeagent/storage/memory/reltermsindex.py @@ -4,6 +4,9 @@ from collections.abc import Callable from typing import Protocol, TYPE_CHECKING +import numpy as np + +from typeagent.aitools.embeddings import NormalizedEmbedding from typeagent.aitools.vectorbase import ( ScoredInt, TextEmbeddingIndexSettings, @@ -285,7 +288,27 @@ async def size(self) -> int: return len(self._vectorbase) async def add_terms(self, texts: list[str]) -> None: - await self._vectorbase.add_keys(texts) + if not texts: + return + embeddings = await self._vectorbase.get_embeddings(texts) + await self.add_terms_with_embeddings( + texts, [embedding for embedding in embeddings] + ) + + async def add_terms_with_embeddings( + self, + texts: list[str], + embeddings: list[NormalizedEmbedding], + ) -> None: + if len(texts) != len(embeddings): + raise ValueError( + "texts and embeddings must have the same length: " + f"{len(texts)} != {len(embeddings)}" + ) + if not texts: + return + embedding_array = np.stack(embeddings, axis=0).astype(np.float32, copy=False) + self._vectorbase.add_embeddings(texts, embedding_array) self._texts.extend(texts) async def lookup_term( diff --git a/src/typeagent/storage/sqlite/messageindex.py b/src/typeagent/storage/sqlite/messageindex.py index d48a9761..457c1ae5 100644 --- a/src/typeagent/storage/sqlite/messageindex.py +++ b/src/typeagent/storage/sqlite/messageindex.py @@ -58,24 +58,56 @@ async def add_messages_starting_at( messages: list[interfaces.IMessage], ) -> None: """Add messages to the text index starting at the given ordinal.""" - chunks_to_embed: list[tuple[int, int, str]] = [] - for msg_ord, message in enumerate(messages, start_message_ordinal): - for chunk_ord, chunk in enumerate(message.text_chunks): - chunks_to_embed.append((msg_ord, chunk_ord, chunk)) + chunks_to_embed: list[str] = [] + for _msg_ord, message in enumerate(messages, start_message_ordinal): + for _chunk_ord, chunk in enumerate(message.text_chunks): + chunks_to_embed.append(chunk) if not chunks_to_embed: return embeddings = await self._vectorbase.get_embeddings( - [chunk for _, _, chunk in chunks_to_embed], cache=False + chunks_to_embed, + cache=False, + ) + + await self.add_messages_starting_at_with_embeddings( + start_message_ordinal, + messages, + [embedding for embedding in embeddings], ) + async def add_messages_starting_at_with_embeddings( + self, + start_message_ordinal: int, + messages: list[interfaces.IMessage], + chunk_embeddings: list[NormalizedEmbedding], + ) -> None: + """Add messages to the text index using precomputed chunk embeddings.""" + chunk_locations: list[tuple[int, int]] = [] + for msg_ord, message in enumerate(messages, start_message_ordinal): + for chunk_ord, _chunk in enumerate(message.text_chunks): + chunk_locations.append((msg_ord, chunk_ord)) + + if len(chunk_locations) != len(chunk_embeddings): + raise ValueError( + "messages and chunk_embeddings produced different chunk counts: " + f"{len(chunk_locations)} != {len(chunk_embeddings)}" + ) + + if not chunk_locations: + return + + current_size = len(self._vectorbase) + embedding_array = np.stack(chunk_embeddings, axis=0).astype( + np.float32, copy=False + ) + self._vectorbase.add_embeddings(None, embedding_array) + insertion_data: list[tuple[int, int, bytes, int]] = [] - for idx, ((msg_ord, chunk_ord, _), embedding) in enumerate( - zip(chunks_to_embed, embeddings) + for idx, ((msg_ord, chunk_ord), embedding) in enumerate( + zip(chunk_locations, chunk_embeddings) ): - # Get the current VectorBase size to determine the index position - current_size = len(self._vectorbase) index_position = current_size + idx insertion_data.append( (msg_ord, chunk_ord, serialize_embedding(embedding), index_position) diff --git a/src/typeagent/storage/sqlite/reltermsindex.py b/src/typeagent/storage/sqlite/reltermsindex.py index cf5b201b..0b01e585 100644 --- a/src/typeagent/storage/sqlite/reltermsindex.py +++ b/src/typeagent/storage/sqlite/reltermsindex.py @@ -213,15 +213,43 @@ async def add_terms(self, texts: list[str]) -> None: new_terms = [t for t in texts if t not in self._added_terms] if not new_terms: return + embeddings = await self._vector_base.get_embeddings(new_terms) + await self.add_terms_with_embeddings( + new_terms, + [embedding for embedding in embeddings], + ) - embeddings = await self._vector_base.add_keys(new_terms) - assert embeddings is not None + async def add_terms_with_embeddings( + self, + texts: list[str], + embeddings: list[NormalizedEmbedding], + ) -> None: + if len(texts) != len(embeddings): + raise ValueError( + "texts and embeddings must have the same length: " + f"{len(texts)} != {len(embeddings)}" + ) + + new_term_embeddings = [ + (term, embedding) + for term, embedding in zip(texts, embeddings) + if term not in self._added_terms + ] + if not new_term_embeddings: + return + + new_terms = [term for term, _ in new_term_embeddings] + new_embeddings = [embedding for _, embedding in new_term_embeddings] + embedding_array = np.stack(new_embeddings, axis=0).astype( + np.float32, copy=False + ) + self._vector_base.add_embeddings(new_terms, embedding_array) cursor = self.db.cursor() cursor.executemany( "INSERT OR REPLACE INTO RelatedTermsFuzzy (term, term_embedding) VALUES (?, ?)", [ - (term, serialize_embedding(embeddings[i])) + (term, serialize_embedding(new_embeddings[i])) for i, term in enumerate(new_terms) ], ) From 423eb570a100b555e76dbe3d6ea7abd74ffaa283 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Sat, 9 May 2026 22:37:13 -0700 Subject: [PATCH 06/42] Fix test (add more mock methods) --- AGENTS.md | 1 + tests/test_messageindex.py | 34 +++++++++++++++++++--------------- 2 files changed, 20 insertions(+), 15 deletions(-) diff --git a/AGENTS.md b/AGENTS.md index 0623e0d1..db73bea6 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -56,6 +56,7 @@ In all cases show what you added to AGENTS.md. - Name returned summary/value objects as `*Result`; reserve `*State` for mutable shared/internal state. - Keep internal helper type naming consistent within a module; avoid mixing underscored and non-underscored helper class names without a clear API-boundary reason. - Prefer ordinal type aliases (e.g., `MessageOrdinal`, `ChunkOrdinal`) over raw `int` in pipeline code for readability. +- When the user asks to "fix the test only", update tests/mocks first and avoid adding production compatibility fallbacks unless explicitly requested. ## Package Management with uv diff --git a/tests/test_messageindex.py b/tests/test_messageindex.py index 7e00cc45..54eea485 100644 --- a/tests/test_messageindex.py +++ b/tests/test_messageindex.py @@ -30,7 +30,16 @@ def mock_text_location_index() -> MagicMock: mock_index = MagicMock(spec=TextToTextLocationIndex) # Empty index, so first message starts at ordinal 0 mock_index.size = AsyncMock(return_value=0) + + def _fake_generate_embeddings( + texts: list[str], cache: bool = True + ) -> list[np.ndarray]: + del cache + return [np.array([1.0, 0.0, 0.0], dtype=np.float32) for _ in texts] + + mock_index.generate_embeddings = AsyncMock(side_effect=_fake_generate_embeddings) mock_index.add_text_locations = AsyncMock(return_value=None) + mock_index.add_text_locations_with_embeddings = AsyncMock(return_value=None) mock_index.lookup_text = AsyncMock(return_value=[]) mock_index.lookup_text_in_subset = AsyncMock(return_value=[]) mock_index.serialize = MagicMock(return_value={"mock": "data"}) @@ -79,25 +88,20 @@ async def test_add_messages( await message_text_index.add_messages(messages) - # Check that add_text_locations was called with the expected text and location data + # Check that add_text_locations_with_embeddings was called with expected locations + # and one embedding per chunk. mock_text_loc_index = cast( MagicMock, cast(MessageTextIndex, message_text_index).text_location_index ) - call_args = mock_text_loc_index.add_text_locations.call_args + call_args = mock_text_loc_index.add_text_locations_with_embeddings.call_args assert call_args is not None - text_and_locations = call_args[0][0] # First positional argument - assert ( - len(text_and_locations) == 3 - ) # Two chunks from first message, one from second - assert text_and_locations[0] == ( - "chunk1", - TextLocation(0, 0), - ) # First message starts at ordinal 0 - assert text_and_locations[1] == ("chunk2", TextLocation(0, 1)) - assert text_and_locations[2] == ( - "chunk3", - TextLocation(1, 0), - ) # Second message at ordinal 1 + text_locations = call_args[0][0] # First positional argument + embeddings = call_args[0][1] # Second positional argument + assert len(text_locations) == 3 # Two chunks from first message, one from second + assert text_locations[0] == TextLocation(0, 0) # First message starts at ordinal 0 + assert text_locations[1] == TextLocation(0, 1) + assert text_locations[2] == TextLocation(1, 0) # Second message at ordinal 1 + assert len(embeddings) == 3 @pytest.mark.asyncio From 90091b0559863c2dbe2f5a9e08365386aecd5e42 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Sun, 10 May 2026 11:05:38 -0700 Subject: [PATCH 07/42] Use a dispatcher task that spawns one-shot workers. --- src/typeagent/knowpro/add_messages.py | 87 +++++++++++++++------------ 1 file changed, 49 insertions(+), 38 deletions(-) diff --git a/src/typeagent/knowpro/add_messages.py b/src/typeagent/knowpro/add_messages.py index 78846ea3..f30cee5a 100644 --- a/src/typeagent/knowpro/add_messages.py +++ b/src/typeagent/knowpro/add_messages.py @@ -51,14 +51,14 @@ class ChunkWorkItem[TMessage: IMessage]: async def _producer_task[TMessage: IMessage]( messages: AsyncIterable[TMessage], chunk_queue: asyncio.Queue[ChunkWorkItem[TMessage] | None], - worker_count: int, stop_state: PipelineStopState, producer_state: ProducerState, ) -> None: """Read input messages and enqueue chunk work items. The producer stops enqueueing once it reaches ``stop_at_message_id``. - It always sends one sentinel per worker, even if the input iterator raises. + It always sends a sentinel to shut down the dispatcher, even if the + input iterator raises. """ try: async for message in messages: @@ -85,62 +85,73 @@ async def _producer_task[TMessage: IMessage]( except Exception as exc: producer_state.exception = exc finally: - for _ in range(worker_count): - await chunk_queue.put(None) + await chunk_queue.put(None) -async def _worker_task[TMessage: IMessage]( +async def _dispatcher_task[TMessage: IMessage]( chunk_queue: asyncio.Queue[ChunkWorkItem[TMessage] | None], - result_queue: asyncio.Queue[ChunkProcessingResult[TMessage]], + result_queue: asyncio.Queue[ChunkProcessingResult[TMessage] | None], stop_state: PipelineStopState, knowledge_extractor: IKnowledgeExtractor, message_embedding_model: IEmbeddingModel, related_terms_embedding_model: IEmbeddingModel | None = None, + concurrency: int = 4, ) -> None: - """Consume chunk work items and produce chunk processing results. + """Dispatch chunk work items to bounded per-item worker tasks. - Workers stop when they receive a ``None`` sentinel from ``chunk_queue``. - Chunks at or beyond ``stop_at_message_id`` are skipped and reported as - error results so downstream code can account for them deterministically. - """ - while True: - work_item = await chunk_queue.get() - if work_item is None: - return + Reads work items from ``chunk_queue`` until it receives a ``None`` + sentinel, then awaits all in-flight tasks via a TaskGroup and puts a + ``None`` sentinel on ``result_queue`` to signal the reassembler. - stop_at = stop_state.stop_at_message_id + Concurrency is bounded by a semaphore so at most ``concurrency`` worker + tasks run simultaneously. Chunks at or beyond ``stop_at_message_id`` are + skipped and reported as error results so the reassembler can account for + them deterministically. + """ + sem = asyncio.Semaphore(concurrency) - if work_item.chunk_id.message_ordinal >= stop_at: - await result_queue.put( - ChunkProcessingResult( + async def _process_one(work_item: ChunkWorkItem[TMessage]) -> None: + try: + stop_at = stop_state.stop_at_message_id + if work_item.chunk_id.message_ordinal >= stop_at: + result: ChunkProcessingResult[TMessage] = ChunkProcessingResult( chunk_id=work_item.chunk_id, chunk_count=work_item.chunk_count, message=work_item.message, error=RuntimeError( "Chunk skipped because stop_at_message_id " - f"is {stop_at} and message_id is {work_item.chunk_id.message_ordinal}" + f"is {stop_at} and message_id is " + f"{work_item.chunk_id.message_ordinal}" ), ) - ) - continue - - result = await process_chunk_with_extraction_and_embeddings( - chunk_id=work_item.chunk_id, - chunk_text=work_item.chunk_text, - chunk_count=work_item.chunk_count, - message=work_item.message, - knowledge_extractor=knowledge_extractor, - message_embedding_model=message_embedding_model, - related_terms_embedding_model=related_terms_embedding_model, - ) + else: + result = await process_chunk_with_extraction_and_embeddings( + chunk_id=work_item.chunk_id, + chunk_text=work_item.chunk_text, + chunk_count=work_item.chunk_count, + message=work_item.message, + knowledge_extractor=knowledge_extractor, + message_embedding_model=message_embedding_model, + related_terms_embedding_model=related_terms_embedding_model, + ) + if result.error is not None: + stop_state.stop_at_message_id = min( + stop_state.stop_at_message_id, + work_item.chunk_id.message_ordinal, + ) + await result_queue.put(result) + finally: + sem.release() - if result.error is not None: - stop_state.stop_at_message_id = min( - stop_state.stop_at_message_id, - work_item.chunk_id.message_ordinal, - ) + async with asyncio.TaskGroup() as tg: + while True: + item = await chunk_queue.get() + if item is None: + break + await sem.acquire() + tg.create_task(_process_one(item)) - await result_queue.put(result) + await result_queue.put(None) @dataclass From 2b320dda63e96ddc61a265f390a392e6f556179e Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Sun, 10 May 2026 11:48:39 -0700 Subject: [PATCH 08/42] Chunk validation without try/except. --- AGENTS.md | 3 ++ src/typeagent/knowpro/add_messages.py | 71 +++++++++++++-------------- 2 files changed, 37 insertions(+), 37 deletions(-) diff --git a/AGENTS.md b/AGENTS.md index db73bea6..1abacc3b 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -51,10 +51,13 @@ In all cases show what you added to AGENTS.md. - In add-messages pipeline chunk processing, compute chunk-text embeddings with uncached model calls and related-term embeddings with cached model calls. - In add-messages pipeline flow, lower stop_at_message_id to min(existing, failing_message_id), and always enqueue queue-1 sentinels even when the input iterator fails so workers can drain and exit cleanly. - In add-messages pipeline data structures, use `TextLocation` as the chunk identifier instead of a formatted string chunk ID. +- In add-messages reassembler validation, prefer explicit guard checks over wrapping validation-only logic in `try/except` blocks. +- In add-messages reassembler validation, prefer a single `validation_error` variable with consistent `if/elif` checks over helper functions for simple message-only validation. - When adding precomputed-embedding write paths, expose explicit `*_with_embeddings` methods and have existing methods compute embeddings then delegate to those methods. - In asyncio code, avoid locks for in-memory state updates that do not `await` between read/modify/write; use locks only when a critical section spans `await` points. - Name returned summary/value objects as `*Result`; reserve `*State` for mutable shared/internal state. - Keep internal helper type naming consistent within a module; avoid mixing underscored and non-underscored helper class names without a clear API-boundary reason. +- Prefer variable names that reflect role rather than lifecycle; for accumulators like message assemblies, use neutral names (e.g., `assembly`) instead of state-qualified names (e.g., `existing`). - Prefer ordinal type aliases (e.g., `MessageOrdinal`, `ChunkOrdinal`) over raw `int` in pipeline code for readability. - When the user asks to "fix the test only", update tests/mocks first and avoid adding production compatibility fallbacks unless explicitly requested. diff --git a/src/typeagent/knowpro/add_messages.py b/src/typeagent/knowpro/add_messages.py index f30cee5a..7fbe604a 100644 --- a/src/typeagent/knowpro/add_messages.py +++ b/src/typeagent/knowpro/add_messages.py @@ -364,7 +364,7 @@ async def _commit_if_needed(force: bool = False) -> None: staged_results.clear() staged_chunks = 0 - async def _drain_consecutive_complete() -> None: + async def _drain_consecutive_complete(force: bool = False) -> None: nonlocal staged_chunks while True: assembly = assemblies.get(state.first_uncommitted_ordinal) @@ -381,7 +381,7 @@ async def _drain_consecutive_complete() -> None: del assemblies[state.first_uncommitted_ordinal] state.first_uncommitted_ordinal += 1 - await _commit_if_needed() + await _commit_if_needed(force) try: while True: @@ -392,45 +392,43 @@ async def _drain_consecutive_complete() -> None: chunk_ordinal = item.chunk_id.chunk_ordinal message_id = item.chunk_id.message_ordinal - try: - if chunk_ordinal < 0 or chunk_ordinal >= item.chunk_count: - raise RuntimeError( - f"Invalid chunk ordinal: message_id={message_id}, " - f"chunk_ordinal={chunk_ordinal}, chunk_count={item.chunk_count}" - ) - - existing = assemblies.get(message_id) - if existing is None: - existing = MessageAssembly[TMessage]( - message_id=message_id, - chunk_count=item.chunk_count, - message=item.message, - chunks={}, - ) - assemblies[message_id] = existing - elif existing.chunk_count != item.chunk_count: - raise RuntimeError( - f"Mismatched chunk count for message: message_id={message_id}, " - f"expected={existing.chunk_count}, got={item.chunk_count}" - ) - - if chunk_ordinal in existing.chunks: - raise RuntimeError( - f"Duplicate chunk: message_id={message_id}, " - f"chunk_ordinal={chunk_ordinal}, chunk_count={item.chunk_count}" - ) + validation_error: str | None = None + assembly = assemblies.get(message_id) + if chunk_ordinal < 0 or chunk_ordinal >= item.chunk_count: + validation_error = ( + f"Invalid chunk ordinal: message_id={message_id}, " + f"chunk_ordinal={chunk_ordinal}, chunk_count={item.chunk_count}" + ) + elif assembly is None: + assembly = MessageAssembly[TMessage]( + message_id=message_id, + chunk_count=item.chunk_count, + message=item.message, + chunks={}, + ) + assemblies[message_id] = assembly + elif assembly.chunk_count != item.chunk_count: + validation_error = ( + f"Mismatched chunk count for message: message_id={message_id}, " + f"expected={assembly.chunk_count}, got={item.chunk_count}" + ) + elif chunk_ordinal in assembly.chunks: + validation_error = ( + f"Duplicate chunk: message_id={message_id}, " + f"chunk_ordinal={chunk_ordinal}, chunk_count={item.chunk_count}" + ) - existing.chunks[chunk_ordinal] = item - except Exception: - # On validation error, set stop flag and re-raise - # The finally block will drain and commit consecutive complete messages + if validation_error is not None: stop_state.stop_at_message_id = min( stop_state.stop_at_message_id, message_id ) - raise + raise RuntimeError(validation_error) + + assert assembly is not None + assembly.chunks[chunk_ordinal] = item if item.error is not None: - existing.has_error = True + assembly.has_error = True state.chunk_failures += 1 stop_state.stop_at_message_id = min( stop_state.stop_at_message_id, message_id @@ -439,8 +437,7 @@ async def _drain_consecutive_complete() -> None: await _drain_consecutive_complete() finally: # Always drain and commit consecutive complete messages before raising - await _drain_consecutive_complete() - await _commit_if_needed(force=True) + await _drain_consecutive_complete(force=True) state.buffered_messages = len(assemblies) return state From 9f8f51d7b3c589727e2129931271373c7a8a2065 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Sun, 10 May 2026 12:02:33 -0700 Subject: [PATCH 09/42] Add _commit_batch_from_chunk_results to ConversationBase. --- AGENTS.md | 1 + src/typeagent/knowpro/conversation_base.py | 148 ++++++++++++++++++++- 2 files changed, 148 insertions(+), 1 deletion(-) diff --git a/AGENTS.md b/AGENTS.md index 1abacc3b..dbac2349 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -58,6 +58,7 @@ In all cases show what you added to AGENTS.md. - Name returned summary/value objects as `*Result`; reserve `*State` for mutable shared/internal state. - Keep internal helper type naming consistent within a module; avoid mixing underscored and non-underscored helper class names without a clear API-boundary reason. - Prefer variable names that reflect role rather than lifecycle; for accumulators like message assemblies, use neutral names (e.g., `assembly`) instead of state-qualified names (e.g., `existing`). +- Avoid potential import cycles between conversation orchestration and pipeline modules by using neutral payload protocols/arguments instead of importing concrete pipeline result classes across modules. - Prefer ordinal type aliases (e.g., `MessageOrdinal`, `ChunkOrdinal`) over raw `int` in pipeline code for readability. - When the user asks to "fix the test only", update tests/mocks first and avoid adding production compatibility fallbacks unless explicitly requested. diff --git a/src/typeagent/knowpro/conversation_base.py b/src/typeagent/knowpro/conversation_base.py index 673695d2..d724b93b 100644 --- a/src/typeagent/knowpro/conversation_base.py +++ b/src/typeagent/knowpro/conversation_base.py @@ -8,7 +8,7 @@ import contextlib from dataclasses import dataclass from datetime import datetime, timezone -from typing import Generic, Self, TypeVar +from typing import Generic, Protocol, Self, TypeVar import typechat @@ -24,6 +24,7 @@ ) from . import knowledge_schema as kplib from ..aitools import model_adapters, utils +from ..aitools.embeddings import NormalizedEmbedding from ..storage.memory import semrefindex from .convsettings import ConversationSettings from .interfaces import ( @@ -46,6 +47,16 @@ TMessage = TypeVar("TMessage", bound=IMessage) +class _ChunkCommitResult(Protocol): + """Neutral chunk commit payload shape used by pipeline batch commit.""" + + chunk_id: TextLocation + extracted_knowledge: kplib.KnowledgeResponse | None + chunk_embedding: NormalizedEmbedding | None + related_terms: list[str] | None + related_term_embeddings: list[NormalizedEmbedding] | None + + @dataclass(frozen=True) class _ExtractionResult: """Pre-extracted knowledge for a batch, ready to commit.""" @@ -427,6 +438,141 @@ async def _commit_batch_streaming( - start_points.semref_count, ) + async def _commit_batch_from_chunk_results( + self, + storage: IStorageProvider[TMessage], + messages_batch: list[TMessage], + chunk_results: list[_ChunkCommitResult], + ) -> AddMessagesResult: + """Commit one pipeline batch using precomputed extraction and embeddings.""" + if not messages_batch: + return AddMessagesResult() + + async with storage: + start_points = IndexingStartPoints( + message_count=await self.messages.size(), + semref_count=await self.semantic_refs.size(), + ) + + await self.messages.extend(messages_batch) + + source_ids = [m.source_id for m in messages_batch if m.source_id is not None] + if source_ids: + await storage.mark_sources_ingested_batch(source_ids) + + await self._add_metadata_knowledge_incremental(start_points.message_count) + + knowledge_items: list[tuple[MessageOrdinal, int, kplib.KnowledgeResponse]] = [] + chunk_embeddings: list[NormalizedEmbedding] = [] + fuzzy_terms: list[str] = [] + fuzzy_term_embeddings: list[NormalizedEmbedding] = [] + + for result in chunk_results: + if result.chunk_embedding is None: + raise ValueError( + "Chunk result missing chunk embedding for " + f"message={result.chunk_id.message_ordinal}, " + f"chunk={result.chunk_id.chunk_ordinal}" + ) + chunk_embeddings.append(result.chunk_embedding) + + if result.extracted_knowledge is None: + raise ValueError( + "Chunk result missing extracted knowledge for " + f"message={result.chunk_id.message_ordinal}, " + f"chunk={result.chunk_id.chunk_ordinal}" + ) + knowledge_items.append( + ( + result.chunk_id.message_ordinal, + result.chunk_id.chunk_ordinal, + result.extracted_knowledge, + ) + ) + + if result.related_terms is None or result.related_term_embeddings is None: + raise ValueError( + "Chunk result missing related-term embeddings for " + f"message={result.chunk_id.message_ordinal}, " + f"chunk={result.chunk_id.chunk_ordinal}" + ) + if len(result.related_terms) != len(result.related_term_embeddings): + raise ValueError( + "related_terms and related_term_embeddings length mismatch for " + f"message={result.chunk_id.message_ordinal}, " + f"chunk={result.chunk_id.chunk_ordinal}: " + f"{len(result.related_terms)} != " + f"{len(result.related_term_embeddings)}" + ) + fuzzy_terms.extend(result.related_terms) + fuzzy_term_embeddings.extend(result.related_term_embeddings) + + await semrefindex.add_knowledge_batch_to_semantic_ref_index( + self, + knowledge_items, + ) + + await self._update_secondary_indexes_incremental_with_embeddings( + start_points, + messages_batch, + chunk_embeddings, + fuzzy_terms, + fuzzy_term_embeddings, + ) + + await storage.update_conversation_timestamps( + updated_at=datetime.now(timezone.utc) + ) + + messages_added = await self.messages.size() - start_points.message_count + chunks_added = sum( + len(message.text_chunks) for message in messages_batch[:messages_added] + ) + return AddMessagesResult( + messages_added=messages_added, + chunks_added=chunks_added, + semrefs_added=await self.semantic_refs.size() + - start_points.semref_count, + ) + + async def _update_secondary_indexes_incremental_with_embeddings( + self, + start_points: IndexingStartPoints, + new_messages: list[TMessage], + chunk_embeddings: list[NormalizedEmbedding], + related_terms: list[str], + related_term_embeddings: list[NormalizedEmbedding], + ) -> None: + """Update secondary indexes using precomputed embeddings when available.""" + if self.secondary_indexes is None: + return + + from ..storage.memory import propindex + + await propindex.add_to_property_index(self, start_points.semref_count) + + await self._add_timestamps_for_messages( + new_messages, + start_points.message_count, + ) + + term_to_related = self.secondary_indexes.term_to_related_terms_index + if term_to_related is not None: + fuzzy_index = term_to_related.fuzzy_index + if fuzzy_index is not None and related_terms: + await fuzzy_index.add_terms_with_embeddings( + related_terms, + related_term_embeddings, + ) + + message_index = self.secondary_indexes.message_index + if message_index is not None: + await message_index.add_messages_starting_at_with_embeddings( + start_points.message_count, + new_messages, + chunk_embeddings, + ) + async def _add_metadata_knowledge_incremental( self, start_from_message_ordinal: int, From 4cae56cc0e294fab607be052dd6afcd01ce990d4 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Sun, 10 May 2026 17:48:12 -0700 Subject: [PATCH 10/42] Reformat conversation_base.py --- src/typeagent/knowpro/conversation_base.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/src/typeagent/knowpro/conversation_base.py b/src/typeagent/knowpro/conversation_base.py index d724b93b..39c98734 100644 --- a/src/typeagent/knowpro/conversation_base.py +++ b/src/typeagent/knowpro/conversation_base.py @@ -456,13 +456,17 @@ async def _commit_batch_from_chunk_results( await self.messages.extend(messages_batch) - source_ids = [m.source_id for m in messages_batch if m.source_id is not None] + source_ids = [ + m.source_id for m in messages_batch if m.source_id is not None + ] if source_ids: await storage.mark_sources_ingested_batch(source_ids) await self._add_metadata_knowledge_incremental(start_points.message_count) - knowledge_items: list[tuple[MessageOrdinal, int, kplib.KnowledgeResponse]] = [] + knowledge_items: list[ + tuple[MessageOrdinal, int, kplib.KnowledgeResponse] + ] = [] chunk_embeddings: list[NormalizedEmbedding] = [] fuzzy_terms: list[str] = [] fuzzy_term_embeddings: list[NormalizedEmbedding] = [] @@ -490,7 +494,10 @@ async def _commit_batch_from_chunk_results( ) ) - if result.related_terms is None or result.related_term_embeddings is None: + if ( + result.related_terms is None + or result.related_term_embeddings is None + ): raise ValueError( "Chunk result missing related-term embeddings for " f"message={result.chunk_id.message_ordinal}, " From 7f484a6de5ce54a2186d97a554e2f7e2c50fd315 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Sun, 10 May 2026 18:27:56 -0700 Subject: [PATCH 11/42] Add tests, 100% coverage. Fix one thing in add_messages.py. --- Makefile | 6 +- src/typeagent/knowpro/add_messages.py | 16 +- tests/test_add_messages_pipeline.py | 771 ++++++++++++++++++++++++++ 3 files changed, 779 insertions(+), 14 deletions(-) create mode 100644 tests/test_add_messages_pipeline.py diff --git a/Makefile b/Makefile index c2aeb33b..84a609d7 100644 --- a/Makefile +++ b/Makefile @@ -21,10 +21,10 @@ test: venv .PHONY: coverage coverage: venv - coverage erase + uv run coverage erase COVERAGE_PROCESS_START=.coveragerc uv run coverage run -m pytest $(FLAGS) - coverage combine - coverage report + uv run coverage combine + uv run coverage report .PHONY: demo demo: venv diff --git a/src/typeagent/knowpro/add_messages.py b/src/typeagent/knowpro/add_messages.py index 7fbe604a..86eca178 100644 --- a/src/typeagent/knowpro/add_messages.py +++ b/src/typeagent/knowpro/add_messages.py @@ -251,9 +251,7 @@ async def process_chunk_with_extraction_and_embeddings[TMessage: IMessage]( embeddings, or an error from the first failed operation. """ result = ChunkProcessingResult( - chunk_id=chunk_id, - chunk_count=chunk_count, - message=message, + chunk_id=chunk_id, chunk_count=chunk_count, message=message ) # Step 1: Extract knowledge @@ -284,9 +282,7 @@ async def process_chunk_with_extraction_and_embeddings[TMessage: IMessage]( chunk_text ) if result.related_terms: - rel_embeddings = await related_model.get_embeddings( - result.related_terms, - ) + rel_embeddings = await related_model.get_embeddings(result.related_terms) result.related_term_embeddings = [e for e in rel_embeddings] else: result.related_term_embeddings = [] @@ -329,8 +325,7 @@ async def _reassembler_task[TMessage: IMessage]( first_uncommitted_ordinal: MessageOrdinal, target_commit_chunk_count: int, commit_batch: Callable[ - [list[TMessage], list[ChunkProcessingResult[TMessage]]], - Awaitable[None], + [list[TMessage], list[ChunkProcessingResult[TMessage]]], Awaitable[None] ], on_batch_committed: Callable[[int, int], None] | None = None, ) -> ReassemblerResult: @@ -368,9 +363,8 @@ async def _drain_consecutive_complete(force: bool = False) -> None: nonlocal staged_chunks while True: assembly = assemblies.get(state.first_uncommitted_ordinal) - if assembly is None: - return - if not assembly.is_complete() or assembly.has_error: + if assembly is None or not assembly.is_complete() or assembly.has_error: + await _commit_if_needed(force) return ordered_chunk_ordinals = sorted(assembly.chunks) diff --git a/tests/test_add_messages_pipeline.py b/tests/test_add_messages_pipeline.py new file mode 100644 index 00000000..f33ba4d6 --- /dev/null +++ b/tests/test_add_messages_pipeline.py @@ -0,0 +1,771 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +from __future__ import annotations + +import asyncio +from collections.abc import AsyncIterator +from dataclasses import dataclass, field + +import numpy as np +import pytest + +import typechat + +from typeagent.aitools.embeddings import ( + NormalizedEmbedding, + NormalizedEmbeddings, +) +from typeagent.knowpro import knowledge_schema as kplib +from typeagent.knowpro.add_messages import ( + _collect_related_terms_for_fuzzy_index, + _dispatcher_task, + _producer_task, + _reassembler_task, + ChunkProcessingResult, + ChunkWorkItem, + PipelineStopState, + process_chunk_with_extraction_and_embeddings, + ProducerState, +) +from typeagent.knowpro.interfaces_core import DeletionInfo, IMessageMetadata, TextLocation + + +@dataclass +class _Message: + text_chunks: list[str] + tags: list[str] = field(default_factory=list) + timestamp: str | None = None + deletion_info: DeletionInfo | None = None + metadata: IMessageMetadata | None = None + source_id: str | None = None + + def get_knowledge(self) -> kplib.KnowledgeResponse: + return _empty_knowledge() + + +class _SequenceExtractor: + def __init__( + self, + outputs: list[typechat.Result[kplib.KnowledgeResponse] | Exception], + ) -> None: + self._outputs = outputs + self.calls: list[str] = [] + + async def extract(self, message: str) -> typechat.Result[kplib.KnowledgeResponse]: + self.calls.append(message) + output = self._outputs[len(self.calls) - 1] + if isinstance(output, Exception): + raise output + return output + + +class _StubEmbeddingModel: + def __init__( + self, + *, + chunk_error: Exception | None = None, + related_error: Exception | None = None, + ) -> None: + self.chunk_error = chunk_error + self.related_error = related_error + self.chunk_calls: list[str] = [] + self.related_calls: list[list[str]] = [] + self._cache: dict[str, NormalizedEmbedding] = {} + + @property + def model_name(self) -> str: + return "test-embedding" + + def add_embedding(self, key: str, embedding: NormalizedEmbedding) -> None: + self._cache[key] = embedding + + async def get_embedding_nocache(self, input: str) -> NormalizedEmbedding: + self.chunk_calls.append(input) + if self.chunk_error is not None: + raise self.chunk_error + return _embedding([1.0, 0.0]) + + async def get_embeddings_nocache(self, input: list[str]) -> NormalizedEmbeddings: + self.related_calls.append(input) + if self.related_error is not None: + raise self.related_error + return np.array([_embedding([0.0, 1.0]) for _ in input], dtype=np.float32) + + async def get_embedding(self, key: str) -> NormalizedEmbedding: + cached = self._cache.get(key) + if cached is not None: + return cached + embedding = await self.get_embedding_nocache(key) + self._cache[key] = embedding + return embedding + + async def get_embeddings(self, keys: list[str]) -> NormalizedEmbeddings: + if not keys: + raise ValueError("Cannot embed an empty list") + output: list[NormalizedEmbedding] = [] + missing: list[str] = [] + for key in keys: + cached = self._cache.get(key) + if cached is None: + missing.append(key) + else: + output.append(cached) + if missing: + fresh = await self.get_embeddings_nocache(missing) + for index, key in enumerate(missing): + self._cache[key] = fresh[index] + return np.array([self._cache[key] for key in keys], dtype=np.float32) + + +class _FailingAsyncMessages: + def __init__(self, messages: list[_Message], error: Exception) -> None: + self.messages = messages + self.error = error + + def __aiter__(self) -> AsyncIterator[_Message]: + return self._iter() + + async def _iter(self) -> AsyncIterator[_Message]: + for message in self.messages: + yield message + raise self.error + + +class _StopMutatingChunks(list[str]): + """Iterable that lowers stop_at_message_id after yielding first chunk.""" + + def __init__(self, stop_state: PipelineStopState, chunks: list[str]) -> None: + super().__init__(chunks) + self._stop_state = stop_state + + def __iter__(self): + for index, chunk in enumerate(super().__iter__()): + if index == 1: + self._stop_state.stop_at_message_id = 0 + yield chunk + + +def _embedding(values: list[float]) -> NormalizedEmbedding: + return np.array(values, dtype=np.float32) + + +def _empty_knowledge() -> kplib.KnowledgeResponse: + return kplib.KnowledgeResponse( + entities=[], actions=[], inverse_actions=[], topics=[] + ) + + +def _knowledge_with_terms() -> kplib.KnowledgeResponse: + entity = kplib.ConcreteEntity( + name=" Alice ", + type=["Person"], + facets=[kplib.Facet(name="Role", value="Engineer")], + ) + action = kplib.Action( + verbs=["Mentors"], + verb_tense="present", + subject_entity_name="Alice", + object_entity_name="Bob", + ) + return kplib.KnowledgeResponse( + entities=[entity], + actions=[action], + inverse_actions=[], + topics=["ALICE", " Mentorship "], + ) + + +async def _drain_result_queue( + queue: asyncio.Queue[ChunkProcessingResult[_Message] | None], +) -> list[ChunkProcessingResult[_Message] | None]: + items: list[ChunkProcessingResult[_Message] | None] = [] + while True: + item = await queue.get() + items.append(item) + if item is None: + return items + + +@pytest.mark.asyncio +async def test_collect_related_terms_lowercases_dedupes_and_preserves_order() -> None: + knowledge = _knowledge_with_terms() + + terms = _collect_related_terms_for_fuzzy_index(knowledge) + + assert terms[0] == "alice" + assert "mentorship" in terms + assert len(terms) == len(set(terms)) + + +@pytest.mark.asyncio +async def test_process_chunk_success_with_related_terms() -> None: + extractor = _SequenceExtractor([typechat.Success(_knowledge_with_terms())]) + message_model = _StubEmbeddingModel() + + result = await process_chunk_with_extraction_and_embeddings( + chunk_id=TextLocation(0, 0), + chunk_text="hello", + chunk_count=1, + message=_Message(["hello"]), + knowledge_extractor=extractor, + message_embedding_model=message_model, + ) + + assert result.error is None + assert result.success + assert result.extracted_knowledge is not None + assert result.chunk_embedding is not None + assert result.related_terms is not None + assert result.related_term_embeddings is not None + assert len(result.related_terms) == len(result.related_term_embeddings) + assert message_model.chunk_calls == ["hello"] + assert len(message_model.related_calls) == 1 + + +@pytest.mark.asyncio +async def test_process_chunk_extraction_failure_skips_embedding_calls() -> None: + extractor = _SequenceExtractor([typechat.Failure("bad extraction")]) + message_model = _StubEmbeddingModel() + + result = await process_chunk_with_extraction_and_embeddings( + chunk_id=TextLocation(0, 0), + chunk_text="hello", + chunk_count=1, + message=_Message(["hello"]), + knowledge_extractor=extractor, + message_embedding_model=message_model, + ) + + assert result.error is not None + assert "Knowledge extraction failed" in str(result.error) + assert message_model.chunk_calls == [] + assert message_model.related_calls == [] + + +@pytest.mark.asyncio +async def test_process_chunk_extraction_exception_returns_error() -> None: + extractor = _SequenceExtractor([RuntimeError("extract boom")]) + message_model = _StubEmbeddingModel() + + result = await process_chunk_with_extraction_and_embeddings( + chunk_id=TextLocation(0, 0), + chunk_text="hello", + chunk_count=1, + message=_Message(["hello"]), + knowledge_extractor=extractor, + message_embedding_model=message_model, + ) + + assert isinstance(result.error, RuntimeError) + assert "extract boom" in str(result.error) + + +@pytest.mark.asyncio +async def test_process_chunk_chunk_embedding_exception_returns_error() -> None: + extractor = _SequenceExtractor([typechat.Success(_empty_knowledge())]) + message_model = _StubEmbeddingModel(chunk_error=RuntimeError("embed boom")) + + result = await process_chunk_with_extraction_and_embeddings( + chunk_id=TextLocation(0, 0), + chunk_text="hello", + chunk_count=1, + message=_Message(["hello"]), + knowledge_extractor=extractor, + message_embedding_model=message_model, + ) + + assert isinstance(result.error, RuntimeError) + assert "embed boom" in str(result.error) + + +@pytest.mark.asyncio +async def test_process_chunk_related_term_embedding_exception_returns_error() -> None: + extractor = _SequenceExtractor([typechat.Success(_knowledge_with_terms())]) + message_model = _StubEmbeddingModel(related_error=RuntimeError("related boom")) + + result = await process_chunk_with_extraction_and_embeddings( + chunk_id=TextLocation(0, 0), + chunk_text="hello", + chunk_count=1, + message=_Message(["hello"]), + knowledge_extractor=extractor, + message_embedding_model=message_model, + ) + + assert isinstance(result.error, RuntimeError) + assert "related boom" in str(result.error) + + +@pytest.mark.asyncio +async def test_producer_enqueues_all_chunks_and_sentinel() -> None: + messages = [_Message(["a", "b"]), _Message(["c"])] + queue: asyncio.Queue[ChunkWorkItem[_Message] | None] = asyncio.Queue() + stop_state = PipelineStopState() + producer_state = ProducerState(next_message_id=0) + + async def _iter_messages() -> AsyncIterator[_Message]: + for message in messages: + yield message + + await _producer_task(_iter_messages(), queue, stop_state, producer_state) + + items: list[ChunkWorkItem[_Message] | None] = [ + await queue.get(), + await queue.get(), + await queue.get(), + await queue.get(), + ] + + assert [item.chunk_id for item in items[:-1] if item is not None] == [ + TextLocation(0, 0), + TextLocation(0, 1), + TextLocation(1, 0), + ] + assert items[-1] is None + assert producer_state.produced_messages == 2 + assert producer_state.produced_chunks == 3 + assert producer_state.exception is None + + +@pytest.mark.asyncio +async def test_producer_stops_at_stop_marker() -> None: + queue: asyncio.Queue[ChunkWorkItem[_Message] | None] = asyncio.Queue() + stop_state = PipelineStopState(stop_at_message_id=1) + producer_state = ProducerState(next_message_id=0) + + async def _iter_messages() -> AsyncIterator[_Message]: + yield _Message(["a"]) + yield _Message(["b"]) + + await _producer_task(_iter_messages(), queue, stop_state, producer_state) + + first = await queue.get() + sentinel = await queue.get() + + assert first is not None + assert first.chunk_id == TextLocation(0, 0) + assert sentinel is None + assert producer_state.produced_messages == 1 + + +@pytest.mark.asyncio +async def test_producer_sets_exception_and_still_sends_sentinel() -> None: + queue: asyncio.Queue[ChunkWorkItem[_Message] | None] = asyncio.Queue() + stop_state = PipelineStopState() + producer_state = ProducerState(next_message_id=0) + + failing_iter = _FailingAsyncMessages( + [_Message(["a"])], + RuntimeError("input boom"), + ) + + await _producer_task(failing_iter, queue, stop_state, producer_state) + + first = await queue.get() + sentinel = await queue.get() + + assert first is not None + assert first.chunk_id == TextLocation(0, 0) + assert sentinel is None + assert isinstance(producer_state.exception, RuntimeError) + assert "input boom" in str(producer_state.exception) + + +@pytest.mark.asyncio +async def test_producer_breaks_inside_chunk_loop_when_stop_marker_changes() -> None: + queue: asyncio.Queue[ChunkWorkItem[_Message] | None] = asyncio.Queue() + stop_state = PipelineStopState() + producer_state = ProducerState(next_message_id=0) + + message = _Message(["a", "b", "c"]) + message.text_chunks = _StopMutatingChunks(stop_state, ["a", "b", "c"]) + + async def _iter_messages() -> AsyncIterator[_Message]: + yield message + + await _producer_task(_iter_messages(), queue, stop_state, producer_state) + + first = await queue.get() + sentinel = await queue.get() + + assert first is not None + assert first.chunk_id == TextLocation(0, 0) + assert sentinel is None + assert producer_state.produced_chunks == 1 + + +@pytest.mark.asyncio +async def test_dispatcher_stops_on_sentinel_and_emits_result_sentinel() -> None: + chunk_queue: asyncio.Queue[ChunkWorkItem[_Message] | None] = asyncio.Queue() + result_queue: asyncio.Queue[ChunkProcessingResult[_Message] | None] = ( + asyncio.Queue() + ) + stop_state = PipelineStopState() + + message = _Message(["hello"]) + await chunk_queue.put( + ChunkWorkItem( + chunk_id=TextLocation(0, 0), + chunk_count=1, + chunk_text="hello", + message=message, + ) + ) + await chunk_queue.put(None) + + extractor = _SequenceExtractor([typechat.Success(_empty_knowledge())]) + model = _StubEmbeddingModel() + + await _dispatcher_task( + chunk_queue, + result_queue, + stop_state, + extractor, + model, + concurrency=2, + ) + + items = await _drain_result_queue(result_queue) + assert len(items) == 2 + assert items[-1] is None + assert items[0] is not None + assert items[0].error is None + + +@pytest.mark.asyncio +async def test_dispatcher_lowers_stop_and_skips_later_messages() -> None: + chunk_queue: asyncio.Queue[ChunkWorkItem[_Message] | None] = asyncio.Queue() + result_queue: asyncio.Queue[ChunkProcessingResult[_Message] | None] = ( + asyncio.Queue() + ) + stop_state = PipelineStopState() + + m0 = _Message(["first"]) + m1 = _Message(["second"]) + await chunk_queue.put( + ChunkWorkItem( + chunk_id=TextLocation(0, 0), + chunk_count=1, + chunk_text="first", + message=m0, + ) + ) + await chunk_queue.put( + ChunkWorkItem( + chunk_id=TextLocation(1, 0), + chunk_count=1, + chunk_text="second", + message=m1, + ) + ) + await chunk_queue.put(None) + + extractor = _SequenceExtractor([typechat.Failure("first failed")]) + model = _StubEmbeddingModel() + + await _dispatcher_task( + chunk_queue, + result_queue, + stop_state, + extractor, + model, + concurrency=1, + ) + + items = await _drain_result_queue(result_queue) + first = items[0] + second = items[1] + + assert first is not None + assert first.error is not None + assert "Knowledge extraction failed" in str(first.error) + + assert second is not None + assert second.error is not None + assert "Chunk skipped because stop_at_message_id is 0" in str(second.error) + + assert stop_state.stop_at_message_id == 0 + assert extractor.calls == ["first"] + + +def _chunk_result( + message: _Message, + message_ordinal: int, + chunk_ordinal: int, + chunk_count: int, + *, + error: Exception | None = None, +) -> ChunkProcessingResult[_Message]: + return ChunkProcessingResult( + chunk_id=TextLocation(message_ordinal, chunk_ordinal), + chunk_count=chunk_count, + message=message, + error=error, + ) + + +@pytest.mark.asyncio +async def test_reassembler_commits_out_of_order_after_gap_is_filled() -> None: + result_queue: asyncio.Queue[ChunkProcessingResult[_Message] | None] = ( + asyncio.Queue() + ) + stop_state = PipelineStopState() + + m0 = _Message(["m0"]) + m1 = _Message(["m1"]) + await result_queue.put(_chunk_result(m1, 1, 0, 1)) + await result_queue.put(_chunk_result(m0, 0, 0, 1)) + await result_queue.put(None) + + committed_batches: list[tuple[int, int]] = [] + + async def _commit( + messages: list[_Message], + results: list[ChunkProcessingResult[_Message]], + ) -> None: + committed_batches.append((len(messages), len(results))) + + state = await _reassembler_task( + result_queue, + stop_state, + first_uncommitted_ordinal=0, + target_commit_chunk_count=10, + commit_batch=_commit, + ) + + assert committed_batches == [(2, 2)] + assert state.messages_committed == 2 + assert state.chunks_committed == 2 + assert state.first_uncommitted_ordinal == 2 + assert state.buffered_messages == 0 + + +@pytest.mark.asyncio +async def test_reassembler_marks_failure_and_blocks_later_commits() -> None: + result_queue: asyncio.Queue[ChunkProcessingResult[_Message] | None] = ( + asyncio.Queue() + ) + stop_state = PipelineStopState() + + m0 = _Message(["m0"]) + m1 = _Message(["m1"]) + await result_queue.put(_chunk_result(m0, 0, 0, 1, error=RuntimeError("boom"))) + await result_queue.put(_chunk_result(m1, 1, 0, 1)) + await result_queue.put(None) + + commit_calls = 0 + + async def _commit( + messages: list[_Message], + results: list[ChunkProcessingResult[_Message]], + ) -> None: + nonlocal commit_calls + commit_calls += 1 + + state = await _reassembler_task( + result_queue, + stop_state, + first_uncommitted_ordinal=0, + target_commit_chunk_count=1, + commit_batch=_commit, + ) + + assert commit_calls == 0 + assert state.chunk_failures == 1 + assert state.messages_committed == 0 + assert state.buffered_messages == 2 + assert stop_state.stop_at_message_id == 0 + + +@pytest.mark.asyncio +async def test_reassembler_force_commits_small_staged_tail() -> None: + result_queue: asyncio.Queue[ChunkProcessingResult[_Message] | None] = ( + asyncio.Queue() + ) + stop_state = PipelineStopState() + + message = _Message(["m0"]) + await result_queue.put(_chunk_result(message, 0, 0, 1)) + await result_queue.put(None) + + commit_calls = 0 + + async def _commit( + messages: list[_Message], + results: list[ChunkProcessingResult[_Message]], + ) -> None: + nonlocal commit_calls + commit_calls += 1 + + state = await _reassembler_task( + result_queue, + stop_state, + first_uncommitted_ordinal=0, + target_commit_chunk_count=99, + commit_batch=_commit, + ) + + assert commit_calls == 1 + assert state.messages_committed == 1 + assert state.chunks_committed == 1 + + +@pytest.mark.asyncio +async def test_reassembler_raises_on_invalid_chunk_ordinal_and_sets_stop_marker() -> ( + None +): + result_queue: asyncio.Queue[ChunkProcessingResult[_Message] | None] = ( + asyncio.Queue() + ) + stop_state = PipelineStopState() + + message = _Message(["m0", "m0b"]) + await result_queue.put(_chunk_result(message, 3, 2, 2)) + await result_queue.put(None) + + async def _commit( + messages: list[_Message], + results: list[ChunkProcessingResult[_Message]], + ) -> None: + return None + + with pytest.raises(RuntimeError, match="Invalid chunk ordinal"): + await _reassembler_task( + result_queue, + stop_state, + first_uncommitted_ordinal=0, + target_commit_chunk_count=1, + commit_batch=_commit, + ) + + assert stop_state.stop_at_message_id == 3 + + +@pytest.mark.asyncio +async def test_reassembler_raises_on_duplicate_chunk_and_sets_stop_marker() -> None: + result_queue: asyncio.Queue[ChunkProcessingResult[_Message] | None] = ( + asyncio.Queue() + ) + stop_state = PipelineStopState() + + message = _Message(["m1-a", "m1-b"]) + await result_queue.put(_chunk_result(message, 5, 0, 2)) + await result_queue.put(_chunk_result(message, 5, 0, 2)) + await result_queue.put(None) + + async def _commit( + messages: list[_Message], + results: list[ChunkProcessingResult[_Message]], + ) -> None: + return None + + with pytest.raises(RuntimeError, match="Duplicate chunk"): + await _reassembler_task( + result_queue, + stop_state, + first_uncommitted_ordinal=0, + target_commit_chunk_count=1, + commit_batch=_commit, + ) + + assert stop_state.stop_at_message_id == 5 + + +@pytest.mark.asyncio +async def test_reassembler_on_batch_committed_callback_is_invoked() -> None: + result_queue: asyncio.Queue[ChunkProcessingResult[_Message] | None] = ( + asyncio.Queue() + ) + stop_state = PipelineStopState() + + message = _Message(["m0"]) + await result_queue.put(_chunk_result(message, 0, 0, 1)) + await result_queue.put(None) + + callback_calls: list[tuple[int, int]] = [] + + async def _commit( + messages: list[_Message], + results: list[ChunkProcessingResult[_Message]], + ) -> None: + return None + + await _reassembler_task( + result_queue, + stop_state, + first_uncommitted_ordinal=0, + target_commit_chunk_count=1, + commit_batch=_commit, + on_batch_committed=lambda msg_count, chunk_count: callback_calls.append( + (msg_count, chunk_count) + ), + ) + + assert callback_calls == [(1, 1)] + + +@pytest.mark.asyncio +async def test_reassembler_raises_on_mismatched_chunk_count_and_sets_stop_marker() -> ( + None +): + result_queue: asyncio.Queue[ChunkProcessingResult[_Message] | None] = ( + asyncio.Queue() + ) + stop_state = PipelineStopState() + + message = _Message(["m0-a", "m0-b", "m0-c"]) + await result_queue.put(_chunk_result(message, 4, 0, 2)) + await result_queue.put(_chunk_result(message, 4, 1, 3)) + await result_queue.put(None) + + async def _commit( + messages: list[_Message], + results: list[ChunkProcessingResult[_Message]], + ) -> None: + return None + + with pytest.raises(RuntimeError, match="Mismatched chunk count"): + await _reassembler_task( + result_queue, + stop_state, + first_uncommitted_ordinal=0, + target_commit_chunk_count=1, + commit_batch=_commit, + ) + + assert stop_state.stop_at_message_id == 4 + + +@pytest.mark.asyncio +async def test_reassembler_handles_existing_assembly_non_duplicate_chunk() -> None: + result_queue: asyncio.Queue[ChunkProcessingResult[_Message] | None] = ( + asyncio.Queue() + ) + stop_state = PipelineStopState() + + message = _Message(["m0-a", "m0-b"]) + await result_queue.put(_chunk_result(message, 0, 0, 2)) + await result_queue.put(_chunk_result(message, 0, 1, 2)) + await result_queue.put(None) + + commit_calls = 0 + + async def _commit( + messages: list[_Message], + results: list[ChunkProcessingResult[_Message]], + ) -> None: + nonlocal commit_calls + commit_calls += 1 + + state = await _reassembler_task( + result_queue, + stop_state, + first_uncommitted_ordinal=0, + target_commit_chunk_count=1, + commit_batch=_commit, + ) + + assert commit_calls == 1 + assert state.messages_committed == 1 + assert state.chunks_committed == 2 From c45a81463ae1d032ef6f7e82f810efbb8dc04799 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Mon, 11 May 2026 14:14:34 -0700 Subject: [PATCH 12/42] Add the new add_messages_streaming() --- src/typeagent/knowpro/add_messages.py | 78 ++++++++++++++++++++++ src/typeagent/knowpro/conversation_base.py | 2 +- 2 files changed, 79 insertions(+), 1 deletion(-) diff --git a/src/typeagent/knowpro/add_messages.py b/src/typeagent/knowpro/add_messages.py index 86eca178..88501841 100644 --- a/src/typeagent/knowpro/add_messages.py +++ b/src/typeagent/knowpro/add_messages.py @@ -6,14 +6,19 @@ import asyncio from collections.abc import AsyncIterable, Awaitable, Callable from dataclasses import dataclass +from typing import TYPE_CHECKING import typechat from . import knowledge_schema as kplib from ..aitools.embeddings import IEmbeddingModel, NormalizedEmbedding from ..storage.memory.semrefindex import collect_action_terms, collect_entity_terms +from .interfaces import AddMessagesResult from .interfaces_core import IKnowledgeExtractor, IMessage, MessageOrdinal, TextLocation +if TYPE_CHECKING: + from .conversation_base import ConversationBase + type ChunkOrdinal = int @@ -435,3 +440,76 @@ async def _drain_consecutive_complete(force: bool = False) -> None: state.buffered_messages = len(assemblies) return state + + +async def add_messages_streaming[TMessage: IMessage]( + conv: "ConversationBase[TMessage]", + messages: AsyncIterable[TMessage], + *, + batch_size: int = 100, + on_batch_committed: Callable[[AddMessagesResult], None] | None = None, +) -> AddMessagesResult: + """Add messages using the pipelined extraction+embedding architecture.""" + from . import convknowledge + + settings = conv.settings + sem_ref_settings = settings.semantic_ref_index_settings + storage = await settings.get_storage_provider() + knowledge_extractor = ( + sem_ref_settings.knowledge_extractor or convknowledge.KnowledgeExtractor() + ) + embedding_model = settings.embedding_model + + initial_message_id: MessageOrdinal = await conv.messages.size() + + total = AddMessagesResult() + + def _accumulate(result: AddMessagesResult) -> None: + total.messages_added += result.messages_added + total.semrefs_added += result.semrefs_added + total.chunks_added += result.chunks_added + if on_batch_committed: + on_batch_committed(result) + + async def _commit_batch( + messages_batch: list[TMessage], + chunk_results: list[ChunkProcessingResult[TMessage]], + ) -> None: + result = await conv._commit_batch_from_chunk_results( + storage, messages_batch, chunk_results + ) + _accumulate(result) + + chunk_queue: asyncio.Queue[ChunkWorkItem[TMessage] | None] = asyncio.Queue() + result_queue: asyncio.Queue[ChunkProcessingResult[TMessage] | None] = asyncio.Queue() + stop_state = PipelineStopState() + producer_state = ProducerState(next_message_id=initial_message_id) + + async with asyncio.TaskGroup() as tg: + tg.create_task( + _producer_task(messages, chunk_queue, stop_state, producer_state) + ) + tg.create_task( + _dispatcher_task( + chunk_queue, + result_queue, + stop_state, + knowledge_extractor, + embedding_model, + concurrency=sem_ref_settings.concurrency, + ) + ) + tg.create_task( + _reassembler_task( + result_queue, + stop_state, + first_uncommitted_ordinal=initial_message_id, + target_commit_chunk_count=batch_size, + commit_batch=_commit_batch, + ) + ) + + if producer_state.exception is not None: + raise producer_state.exception + + return total diff --git a/src/typeagent/knowpro/conversation_base.py b/src/typeagent/knowpro/conversation_base.py index 39c98734..a5f9f45a 100644 --- a/src/typeagent/knowpro/conversation_base.py +++ b/src/typeagent/knowpro/conversation_base.py @@ -442,7 +442,7 @@ async def _commit_batch_from_chunk_results( self, storage: IStorageProvider[TMessage], messages_batch: list[TMessage], - chunk_results: list[_ChunkCommitResult], + chunk_results: Sequence[_ChunkCommitResult], ) -> AddMessagesResult: """Commit one pipeline batch using precomputed extraction and embeddings.""" if not messages_batch: From 4558ebfee58b155d5cf893c61eff5f875d62bccf Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Tue, 12 May 2026 09:36:24 -0700 Subject: [PATCH 13/42] Treat extraction Failure as a hard error, same as raised exceptions Previously typechat.Failure from the extractor was a soft error: the message was still committed (with no knowledge) and the failure recorded. Since LLM responses are non-deterministic, a Failure is just as unreliable as a raised exception, so both now stop the pipeline at the failing message and propagate the error. - Remove extraction_failure_msg from ChunkProcessingResult and _ChunkCommitResult; simplify _commit_batch_from_chunk_results - Keep stop_state.exception in sync with stop_at_message_id so it always reflects the lowest-ordinal failing message - Update tests accordingly Co-Authored-By: Claude Sonnet 4.6 --- src/typeagent/knowpro/add_messages.py | 89 +++++++- src/typeagent/knowpro/conversation_base.py | 13 +- tests/test_add_messages_pipeline.py | 28 ++- tests/test_add_messages_streaming.py | 244 +++++---------------- 4 files changed, 156 insertions(+), 218 deletions(-) diff --git a/src/typeagent/knowpro/add_messages.py b/src/typeagent/knowpro/add_messages.py index 88501841..9d822560 100644 --- a/src/typeagent/knowpro/add_messages.py +++ b/src/typeagent/knowpro/add_messages.py @@ -21,6 +21,17 @@ type ChunkOrdinal = int +_EMPTY_KNOWLEDGE = kplib.KnowledgeResponse( + entities=[], actions=[], inverse_actions=[], topics=[] +) + + +class _NoOpKnowledgeExtractor: + """No-op extractor used when auto_extract_knowledge is False.""" + + async def extract(self, message: str) -> typechat.Result[kplib.KnowledgeResponse]: + return typechat.Success(_EMPTY_KNOWLEDGE) + @dataclass class PipelineStopState: @@ -28,9 +39,14 @@ class PipelineStopState: A message ordinal greater than or equal to ``stop_at_message_id`` is considered out-of-scope for further processing. + + ``exception`` holds the error from the lowest-ordinal message that + caused the stop, so the orchestrator can re-raise it after the pipeline + drains. """ stop_at_message_id: int = 10**100 + exception: Exception | None = None @dataclass @@ -58,6 +74,7 @@ async def _producer_task[TMessage: IMessage]( chunk_queue: asyncio.Queue[ChunkWorkItem[TMessage] | None], stop_state: PipelineStopState, producer_state: ProducerState, + result_queue: asyncio.Queue["ChunkProcessingResult[TMessage] | None"] | None = None, ) -> None: """Read input messages and enqueue chunk work items. @@ -72,6 +89,21 @@ async def _producer_task[TMessage: IMessage]( break chunk_count = len(message.text_chunks) + if chunk_count == 0: + # Zero-chunk message: nothing for the dispatcher to process. + # Emit a zero-chunk result directly to the reassembler. + if result_queue is not None: + await result_queue.put( + ChunkProcessingResult[TMessage]( + chunk_id=TextLocation(message_id, 0), + chunk_count=0, + message=message, + ) + ) + producer_state.produced_messages += 1 + producer_state.next_message_id += 1 + continue + for chunk_ordinal, chunk_text in enumerate(message.text_chunks): if message_id >= stop_state.stop_at_message_id: break @@ -140,10 +172,15 @@ async def _process_one(work_item: ChunkWorkItem[TMessage]) -> None: related_terms_embedding_model=related_terms_embedding_model, ) if result.error is not None: - stop_state.stop_at_message_id = min( + new_stop = min( stop_state.stop_at_message_id, work_item.chunk_id.message_ordinal, ) + if new_stop < stop_state.stop_at_message_id: + stop_state.stop_at_message_id = new_stop + stop_state.exception = result.error + elif stop_state.exception is None: + stop_state.exception = result.error await result_queue.put(result) finally: sem.release() @@ -265,18 +302,16 @@ async def process_chunk_with_extraction_and_embeddings[TMessage: IMessage]( if isinstance(knowledge_result, typechat.Success): result.extracted_knowledge = knowledge_result.value else: - # Extraction returned a Failure; treat as error and stop result.error = RuntimeError( f"Knowledge extraction failed: {knowledge_result.message}" ) return result except Exception as e: - # Extraction raised an exception; stop processing result.error = e return result result.related_terms = _collect_related_terms_for_fuzzy_index( - result.extracted_knowledge + result.extracted_knowledge # type: ignore[arg-type] ) related_model = related_terms_embedding_model or message_embedding_model @@ -372,6 +407,14 @@ async def _drain_consecutive_complete(force: bool = False) -> None: await _commit_if_needed(force) return + # Pre-flush: if staging this message would push staged_chunks past + # the target, commit the current batch first. + if ( + staged_messages + and staged_chunks + assembly.chunk_count > target_commit_chunk_count + ): + await _commit_if_needed(force=True) + ordered_chunk_ordinals = sorted(assembly.chunks) ordered_results = [assembly.chunks[i] for i in ordered_chunk_ordinals] staged_messages.append(assembly.message) @@ -393,7 +436,17 @@ async def _drain_consecutive_complete(force: bool = False) -> None: validation_error: str | None = None assembly = assemblies.get(message_id) - if chunk_ordinal < 0 or chunk_ordinal >= item.chunk_count: + if item.chunk_count == 0: + # Zero-chunk message: create an immediately-complete assembly. + if assembly is None: + assembly = MessageAssembly[TMessage]( + message_id=message_id, + chunk_count=0, + message=item.message, + chunks={}, + ) + assemblies[message_id] = assembly + elif chunk_ordinal < 0 or chunk_ordinal >= item.chunk_count: validation_error = ( f"Invalid chunk ordinal: message_id={message_id}, " f"chunk_ordinal={chunk_ordinal}, chunk_count={item.chunk_count}" @@ -423,10 +476,12 @@ async def _drain_consecutive_complete(force: bool = False) -> None: ) raise RuntimeError(validation_error) - assert assembly is not None - assembly.chunks[chunk_ordinal] = item + if item.chunk_count > 0: + assert assembly is not None + assembly.chunks[chunk_ordinal] = item if item.error is not None: + assert assembly is not None assembly.has_error = True state.chunk_failures += 1 stop_state.stop_at_message_id = min( @@ -455,9 +510,12 @@ async def add_messages_streaming[TMessage: IMessage]( settings = conv.settings sem_ref_settings = settings.semantic_ref_index_settings storage = await settings.get_storage_provider() - knowledge_extractor = ( - sem_ref_settings.knowledge_extractor or convknowledge.KnowledgeExtractor() - ) + if sem_ref_settings.auto_extract_knowledge: + knowledge_extractor: IKnowledgeExtractor = ( + sem_ref_settings.knowledge_extractor or convknowledge.KnowledgeExtractor() + ) + else: + knowledge_extractor = _NoOpKnowledgeExtractor() embedding_model = settings.embedding_model initial_message_id: MessageOrdinal = await conv.messages.size() @@ -481,13 +539,17 @@ async def _commit_batch( _accumulate(result) chunk_queue: asyncio.Queue[ChunkWorkItem[TMessage] | None] = asyncio.Queue() - result_queue: asyncio.Queue[ChunkProcessingResult[TMessage] | None] = asyncio.Queue() + result_queue: asyncio.Queue[ChunkProcessingResult[TMessage] | None] = ( + asyncio.Queue() + ) stop_state = PipelineStopState() producer_state = ProducerState(next_message_id=initial_message_id) async with asyncio.TaskGroup() as tg: tg.create_task( - _producer_task(messages, chunk_queue, stop_state, producer_state) + _producer_task( + messages, chunk_queue, stop_state, producer_state, result_queue + ) ) tg.create_task( _dispatcher_task( @@ -512,4 +574,7 @@ async def _commit_batch( if producer_state.exception is not None: raise producer_state.exception + if stop_state.exception is not None: + raise stop_state.exception + return total diff --git a/src/typeagent/knowpro/conversation_base.py b/src/typeagent/knowpro/conversation_base.py index a5f9f45a..03cd400c 100644 --- a/src/typeagent/knowpro/conversation_base.py +++ b/src/typeagent/knowpro/conversation_base.py @@ -51,6 +51,7 @@ class _ChunkCommitResult(Protocol): """Neutral chunk commit payload shape used by pipeline batch commit.""" chunk_id: TextLocation + chunk_count: int extracted_knowledge: kplib.KnowledgeResponse | None chunk_embedding: NormalizedEmbedding | None related_terms: list[str] | None @@ -472,6 +473,9 @@ async def _commit_batch_from_chunk_results( fuzzy_term_embeddings: list[NormalizedEmbedding] = [] for result in chunk_results: + if result.chunk_count == 0: + continue + if result.chunk_embedding is None: raise ValueError( "Chunk result missing chunk embedding for " @@ -572,13 +576,8 @@ async def _update_secondary_indexes_incremental_with_embeddings( related_term_embeddings, ) - message_index = self.secondary_indexes.message_index - if message_index is not None: - await message_index.add_messages_starting_at_with_embeddings( - start_points.message_count, - new_messages, - chunk_embeddings, - ) + # The message text index is already populated by messages.extend() during + # the commit, so no explicit update is needed here. async def _add_metadata_knowledge_incremental( self, diff --git a/tests/test_add_messages_pipeline.py b/tests/test_add_messages_pipeline.py index f33ba4d6..b88cfe84 100644 --- a/tests/test_add_messages_pipeline.py +++ b/tests/test_add_messages_pipeline.py @@ -28,7 +28,11 @@ process_chunk_with_extraction_and_embeddings, ProducerState, ) -from typeagent.knowpro.interfaces_core import DeletionInfo, IMessageMetadata, TextLocation +from typeagent.knowpro.interfaces_core import ( + DeletionInfo, + IMessageMetadata, + TextLocation, +) @dataclass @@ -224,7 +228,8 @@ async def test_process_chunk_success_with_related_terms() -> None: @pytest.mark.asyncio -async def test_process_chunk_extraction_failure_skips_embedding_calls() -> None: +async def test_process_chunk_extraction_failure_returns_error() -> None: + """A Failure result from the extractor sets error and skips embedding.""" extractor = _SequenceExtractor([typechat.Failure("bad extraction")]) message_model = _StubEmbeddingModel() @@ -237,10 +242,10 @@ async def test_process_chunk_extraction_failure_skips_embedding_calls() -> None: message_embedding_model=message_model, ) - assert result.error is not None - assert "Knowledge extraction failed" in str(result.error) + assert isinstance(result.error, RuntimeError) + assert "bad extraction" in str(result.error) + assert result.extracted_knowledge is None assert message_model.chunk_calls == [] - assert message_model.related_calls == [] @pytest.mark.asyncio @@ -434,7 +439,8 @@ async def test_dispatcher_stops_on_sentinel_and_emits_result_sentinel() -> None: @pytest.mark.asyncio -async def test_dispatcher_lowers_stop_and_skips_later_messages() -> None: +async def test_dispatcher_extraction_failure_lowers_stop() -> None: + """A Failure from the extractor sets error and lowers stop_at_message_id.""" chunk_queue: asyncio.Queue[ChunkWorkItem[_Message] | None] = asyncio.Queue() result_queue: asyncio.Queue[ChunkProcessingResult[_Message] | None] = ( asyncio.Queue() @@ -461,7 +467,9 @@ async def test_dispatcher_lowers_stop_and_skips_later_messages() -> None: ) await chunk_queue.put(None) - extractor = _SequenceExtractor([typechat.Failure("first failed")]) + extractor = _SequenceExtractor( + [typechat.Failure("first failed"), typechat.Success(_empty_knowledge())] + ) model = _StubEmbeddingModel() await _dispatcher_task( @@ -478,12 +486,12 @@ async def test_dispatcher_lowers_stop_and_skips_later_messages() -> None: second = items[1] assert first is not None - assert first.error is not None - assert "Knowledge extraction failed" in str(first.error) + assert isinstance(first.error, RuntimeError) + assert "first failed" in str(first.error) + # Second chunk is skipped because stop_at_message_id was lowered to 0. assert second is not None assert second.error is not None - assert "Chunk skipped because stop_at_message_id is 0" in str(second.error) assert stop_state.stop_at_message_id == 0 assert extractor.calls == ["first"] diff --git a/tests/test_add_messages_streaming.py b/tests/test_add_messages_streaming.py index 9a8b1fd6..cd1f15e2 100644 --- a/tests/test_add_messages_streaming.py +++ b/tests/test_add_messages_streaming.py @@ -3,7 +3,6 @@ """Tests for add_messages_streaming.""" -import asyncio from collections.abc import AsyncIterator import os import tempfile @@ -14,8 +13,9 @@ from typeagent.aitools.model_adapters import create_test_embedding_model from typeagent.knowpro import knowledge_schema as kplib +from typeagent.knowpro.add_messages import add_messages_streaming from typeagent.knowpro.convsettings import ConversationSettings -from typeagent.knowpro.interfaces_core import AddMessagesResult, IKnowledgeExtractor +from typeagent.knowpro.interfaces_core import IKnowledgeExtractor from typeagent.storage.sqlite.provider import SqliteStorageProvider from typeagent.transcripts.transcript import ( Transcript, @@ -132,7 +132,7 @@ async def test_streaming_basic() -> None: transcript, storage = await _create_transcript(db_path) msgs = [_make_message(f"msg-{i}") for i in range(5)] - result = await transcript.add_messages_streaming(_async_iter(msgs)) + result = await add_messages_streaming(transcript, _async_iter(msgs)) assert result.messages_added == 5 assert await transcript.messages.size() == 5 @@ -148,8 +148,8 @@ async def test_streaming_batching() -> None: transcript, storage = await _create_transcript(db_path) msgs = [_make_message(f"msg-{i}", source_id=f"s-{i}") for i in range(7)] - result = await transcript.add_messages_streaming( - _async_iter(msgs), batch_size=3 + result = await add_messages_streaming( + transcript, _async_iter(msgs), batch_size=3 ) # 3 batches: [0,1,2], [3,4,5], [6] @@ -169,7 +169,7 @@ async def test_streaming_no_source_id_always_ingested() -> None: transcript, storage = await _create_transcript(db_path) msgs = [_make_message(f"msg-{i}") for i in range(3)] - result = await transcript.add_messages_streaming(_async_iter(msgs)) + result = await add_messages_streaming(transcript, _async_iter(msgs)) assert result.messages_added == 3 assert _ingested_count(storage) == 0 # no source IDs to track @@ -178,11 +178,11 @@ async def test_streaming_no_source_id_always_ingested() -> None: @pytest.mark.asyncio -async def test_streaming_records_chunk_failures() -> None: - """Extraction Failure results are recorded, not raised.""" +async def test_streaming_extraction_failure_stops_at_failing_message() -> None: + """Extraction Failure raises and stops processing; messages before the failure are committed.""" with tempfile.TemporaryDirectory() as tmpdir: db_path = os.path.join(tmpdir, "test.db") - extractor = ControlledExtractor(fail_on={1}) # second chunk fails + extractor = ControlledExtractor(fail_on={1}) # second message fails transcript, storage = await _create_transcript( db_path, auto_extract=True, knowledge_extractor=extractor ) @@ -192,16 +192,10 @@ async def test_streaming_records_chunk_failures() -> None: _make_message("bad chunk 1"), _make_message("good chunk 2"), ] - result = await transcript.add_messages_streaming(_async_iter(msgs)) + with pytest.raises(RuntimeError): + await add_messages_streaming(transcript, _async_iter(msgs)) - assert result.messages_added == 3 - assert _failure_count(storage) == 1 - - failures = await storage.get_chunk_failures() - assert len(failures) == 1 - assert failures[0].message_ordinal == 1 - assert failures[0].chunk_ordinal == 0 - assert "Extraction failed" in failures[0].error_message + assert await transcript.messages.size() == 1 await storage.close() @@ -219,14 +213,8 @@ async def test_streaming_exception_stops_run() -> None: msgs = [_make_message(f"msg-{i}", source_id=f"s-{i}") for i in range(6)] - with pytest.raises(ExceptionGroup) as exc_info: - await transcript.add_messages_streaming(_async_iter(msgs), batch_size=3) - - # Verify the wrapped exception is our RuntimeError - assert any( - isinstance(e, RuntimeError) and "Systemic failure" in str(e) - for e in exc_info.value.exceptions - ) + with pytest.raises(RuntimeError, match="Systemic failure"): + await add_messages_streaming(transcript, _async_iter(msgs), batch_size=3) # First batch (3 messages, 3 extract calls 0-2) committed assert await transcript.messages.size() == 3 @@ -242,7 +230,7 @@ async def test_streaming_empty_iterable() -> None: db_path = os.path.join(tmpdir, "test.db") transcript, storage = await _create_transcript(db_path) - result = await transcript.add_messages_streaming(_async_iter([])) + result = await add_messages_streaming(transcript, _async_iter([])) assert result.messages_added == 0 assert result.semrefs_added == 0 @@ -265,7 +253,8 @@ async def test_streaming_on_batch_committed_fires_per_batch() -> None: msgs = [_make_message(f"msg-{i}", source_id=f"s-{i}") for i in range(7)] batch_results: list[int] = [] - result = await transcript.add_messages_streaming( + result = await add_messages_streaming( + transcript, _async_iter(msgs), batch_size=3, on_batch_committed=lambda r: batch_results.append(r.messages_added), @@ -289,8 +278,8 @@ async def test_streaming_extraction_with_multiple_batches() -> None: ) msgs = [_make_message(f"msg-{i}", source_id=f"s-{i}") for i in range(6)] - result = await transcript.add_messages_streaming( - _async_iter(msgs), batch_size=3 + result = await add_messages_streaming( + transcript, _async_iter(msgs), batch_size=3 ) assert result.messages_added == 6 @@ -304,27 +293,21 @@ async def test_streaming_extraction_with_multiple_batches() -> None: @pytest.mark.asyncio async def test_streaming_extraction_failure_across_batches() -> None: - """Extraction failures are recorded with correct global ordinals across batches.""" + """Extraction failure in a later batch leaves earlier batches committed.""" with tempfile.TemporaryDirectory() as tmpdir: db_path = os.path.join(tmpdir, "test.db") - # Fail on call index 1 (batch 0, msg 1) and 4 (batch 1, msg 1) - extractor = ControlledExtractor(fail_on={1, 4}) + # Fail on call index 3 (first message of second batch) + extractor = ControlledExtractor(fail_on={3}) transcript, storage = await _create_transcript( db_path, auto_extract=True, knowledge_extractor=extractor ) msgs = [_make_message(f"msg-{i}", source_id=f"s-{i}") for i in range(6)] - result = await transcript.add_messages_streaming( - _async_iter(msgs), batch_size=3 - ) - - assert result.messages_added == 6 - assert _failure_count(storage) == 2 + with pytest.raises(RuntimeError): + await add_messages_streaming(transcript, _async_iter(msgs), batch_size=3) - failures = await storage.get_chunk_failures() - failure_ordinals = sorted(f.message_ordinal for f in failures) - # msg 1 in batch 0 → global ordinal 1, msg 1 in batch 1 → global ordinal 4 - assert failure_ordinals == [1, 4] + # First batch (messages 0-2) committed; second batch stopped at message 3. + assert await transcript.messages.size() == 3 await storage.close() @@ -341,13 +324,8 @@ async def test_streaming_exception_in_later_batch_preserves_earlier() -> None: ) msgs = [_make_message(f"msg-{i}", source_id=f"s-{i}") for i in range(6)] - with pytest.raises(ExceptionGroup) as exc_info: - await transcript.add_messages_streaming(_async_iter(msgs), batch_size=3) - - assert any( - isinstance(e, RuntimeError) and "Systemic failure" in str(e) - for e in exc_info.value.exceptions - ) + with pytest.raises(RuntimeError, match="Systemic failure"): + await add_messages_streaming(transcript, _async_iter(msgs), batch_size=3) # Batch 0 committed (3 messages), batch 1 rolled back assert await transcript.messages.size() == 3 @@ -429,7 +407,7 @@ async def test_streaming_extraction_with_empty_text_chunks() -> None: ), _make_message("has content", source_id="has-content"), ] - result = await transcript.add_messages_streaming(_async_iter(msgs)) + result = await add_messages_streaming(transcript, _async_iter(msgs)) assert result.messages_added == 2 # Only the message with content triggers extraction @@ -470,7 +448,7 @@ async def test_streaming_multi_chunk_extraction() -> None: _make_multi_chunk_message(["c0", "c1", "c2"], source_id="s-0"), _make_message("single chunk", source_id="s-1"), ] - result = await transcript.add_messages_streaming(_async_iter(msgs)) + result = await add_messages_streaming(transcript, _async_iter(msgs)) assert result.messages_added == 2 assert result.chunks_added == 4 # 3 + 1 @@ -492,7 +470,8 @@ async def test_streaming_batch_size_counts_chunks() -> None: _make_message("d", source_id="s-1"), # 1 chunk ] batch_results: list[int] = [] - result = await transcript.add_messages_streaming( + result = await add_messages_streaming( + transcript, _async_iter(msgs), batch_size=3, on_batch_committed=lambda r: batch_results.append(r.messages_added), @@ -519,7 +498,8 @@ async def test_streaming_large_message_exceeds_batch_size() -> None: _make_message("small", source_id="s-small"), ] batch_results: list[int] = [] - result = await transcript.add_messages_streaming( + result = await add_messages_streaming( + transcript, _async_iter(msgs), batch_size=3, on_batch_committed=lambda r: batch_results.append(r.messages_added), @@ -549,7 +529,8 @@ async def test_streaming_mixed_chunk_sizes_batching() -> None: _make_message("e", source_id="s-4"), # 1 chunk, total=3 → flush ] batch_results: list[int] = [] - result = await transcript.add_messages_streaming( + result = await add_messages_streaming( + transcript, _async_iter(msgs), batch_size=3, on_batch_committed=lambda r: batch_results.append(r.messages_added), @@ -564,32 +545,25 @@ async def test_streaming_mixed_chunk_sizes_batching() -> None: @pytest.mark.asyncio -async def test_streaming_multi_chunk_failure_ordinals() -> None: - """Extraction failures in multi-chunk messages record correct ordinals.""" +async def test_streaming_multi_chunk_failure_stops_message() -> None: + """Extraction failure in a chunk stops that message and all later ones.""" with tempfile.TemporaryDirectory() as tmpdir: db_path = os.path.join(tmpdir, "test.db") - # Fail on call index 1 (chunk 1 of first message) and 3 (chunk 0 of second message) - extractor = ControlledExtractor(fail_on={1, 3}) + # Fail on call index 1 (chunk 1 of first message) + extractor = ControlledExtractor(fail_on={1}) transcript, storage = await _create_transcript( db_path, auto_extract=True, knowledge_extractor=extractor ) msgs = [ - _make_multi_chunk_message( - ["c0", "c1", "c2"], source_id="s-0" - ), # calls 0,1,2 - _make_multi_chunk_message(["d0", "d1"], source_id="s-1"), # calls 3,4 + _make_multi_chunk_message(["c0", "c1", "c2"], source_id="s-0"), + _make_multi_chunk_message(["d0", "d1"], source_id="s-1"), ] - result = await transcript.add_messages_streaming(_async_iter(msgs)) + with pytest.raises(RuntimeError): + await add_messages_streaming(transcript, _async_iter(msgs)) - assert result.messages_added == 2 - assert extractor.call_count == 5 - assert _failure_count(storage) == 2 - - failures = await storage.get_chunk_failures() - failure_locs = sorted((f.message_ordinal, f.chunk_ordinal) for f in failures) - # call 1 → msg 0, chunk 1; call 3 → msg 1, chunk 0 - assert failure_locs == [(0, 1), (1, 0)] + # Message 0 had a chunk failure so nothing is committed. + assert await transcript.messages.size() == 0 await storage.close() @@ -610,8 +584,8 @@ async def test_streaming_multi_chunk_exception_preserves_earlier_batch() -> None _make_multi_chunk_message(["d", "e"], source_id="s-1"), # batch 2 ] - with pytest.raises(ExceptionGroup): - await transcript.add_messages_streaming(_async_iter(msgs), batch_size=3) + with pytest.raises(RuntimeError, match="Systemic failure"): + await add_messages_streaming(transcript, _async_iter(msgs), batch_size=3) # Batch 1 committed (1 message, 3 chunks), batch 2 rolled back assert await transcript.messages.size() == 1 @@ -629,7 +603,8 @@ async def test_streaming_batch_size_1_separates_all() -> None: msgs = [_make_message(f"msg-{i}", source_id=f"s-{i}") for i in range(4)] batch_results: list[int] = [] - result = await transcript.add_messages_streaming( + result = await add_messages_streaming( + transcript, _async_iter(msgs), batch_size=1, on_batch_committed=lambda r: batch_results.append(r.messages_added), @@ -660,7 +635,8 @@ async def test_streaming_preflush_avoids_oversized_batch() -> None: for i in range(4) ] batch_chunks: list[int] = [] - result = await transcript.add_messages_streaming( + result = await add_messages_streaming( + transcript, _async_iter(msgs), batch_size=10, on_batch_committed=lambda r: batch_chunks.append(r.chunks_added), @@ -674,116 +650,6 @@ async def test_streaming_preflush_avoids_oversized_batch() -> None: await storage.close() -# --------------------------------------------------------------------------- -# Coverage gap tests -# --------------------------------------------------------------------------- - - -class SlowExtractor: - """Extractor that blocks on an event, allowing tests to control timing.""" - - def __init__(self, block_from: int) -> None: - self.call_count = 0 - self.block_from = block_from - self.blocked = asyncio.Event() - self.cancelled = False - - async def extract(self, message: str) -> typechat.Result[kplib.KnowledgeResponse]: - idx = self.call_count - self.call_count += 1 - if idx >= self.block_from: - self.blocked.set() - try: - await asyncio.sleep(60) - except asyncio.CancelledError: - self.cancelled = True - raise - return typechat.Success(_EMPTY_RESPONSE) - - -@pytest.mark.asyncio -async def test_streaming_pending_extraction_cancelled_on_commit_failure() -> None: - """pending_extraction is cancelled when a prior commit raises during _drain_commit. - - Timeline: - 1. Batch 0: extraction succeeds (calls 0-2, fast), commit task created - (pending_commit = failing_commit) - 2. Batch 1: extraction task created (pending_extraction, calls 3+, slow), - _drain_commit awaits batch 0's pending_commit which raises - 3. except block: pending_extraction (batch 1's) is still in-flight → cancelled - """ - with tempfile.TemporaryDirectory() as tmpdir: - db_path = os.path.join(tmpdir, "test.db") - # Block extraction starting from call 3 (first call of batch 1) - # so that pending_extraction is still running when the except fires - extractor = SlowExtractor(block_from=3) - transcript, storage = await _create_transcript( - db_path, auto_extract=True, knowledge_extractor=extractor - ) - - async def failing_commit(*args, **kwargs): - raise RuntimeError("Simulated commit failure") - - transcript._commit_batch_streaming = failing_commit # type: ignore[assignment] - - msgs = [_make_message(f"msg-{i}", source_id=f"s-{i}") for i in range(6)] - - with pytest.raises(RuntimeError, match="Simulated commit failure"): - await transcript.add_messages_streaming(_async_iter(msgs), batch_size=3) - - assert extractor.cancelled - - await storage.close() - - -@pytest.mark.asyncio -async def test_streaming_pending_commit_cancelled_on_iterator_error() -> None: - """pending_commit is cancelled when the message iterator raises. - - After batch 0 is submitted (pending_commit in flight), the async iterator - raises on the next message. The except block must cancel the still-running - pending_commit. - """ - - async def _error_after( - items: list[TranscriptMessage], error_after: int - ) -> AsyncIterator[TranscriptMessage]: - for i, item in enumerate(items): - if i == error_after: - # Yield to event loop so pending tasks start running - await asyncio.sleep(0) - raise ValueError("Iterator error") - yield item - - with tempfile.TemporaryDirectory() as tmpdir: - db_path = os.path.join(tmpdir, "test.db") - transcript, storage = await _create_transcript(db_path) - - commit_cancelled = False - - async def slow_commit(*args, **kwargs): - nonlocal commit_cancelled - try: - await asyncio.sleep(60) - except asyncio.CancelledError: - commit_cancelled = True - raise - return AddMessagesResult() - - transcript._commit_batch_streaming = slow_commit # type: ignore[assignment] - - msgs = [_make_message(f"msg-{i}", source_id=f"s-{i}") for i in range(6)] - - with pytest.raises(ValueError, match="Iterator error"): - await transcript.add_messages_streaming( - _error_after(msgs, error_after=4), batch_size=3 - ) - - assert commit_cancelled - - await storage.close() - - @pytest.mark.asyncio async def test_streaming_empty_iterator() -> None: """Streaming with an empty iterator returns zeros.""" @@ -793,11 +659,11 @@ async def test_streaming_empty_iterator() -> None: # Ingest one real message, then do a second call with an empty iterator msgs = [_make_message("msg-0", source_id="s-0")] - r1 = await transcript.add_messages_streaming(_async_iter(msgs)) + r1 = await add_messages_streaming(transcript, _async_iter(msgs)) assert r1.messages_added == 1 # Empty iterator → _submit_batch never called with content - r2 = await transcript.add_messages_streaming(_async_iter([])) + r2 = await add_messages_streaming(transcript, _async_iter([])) assert r2.messages_added == 0 assert r2.messages_skipped == 0 @@ -832,7 +698,7 @@ async def test_streaming_extraction_returns_none_for_empty_chunks() -> None: source_id="empty-1", ), ] - result = await transcript.add_messages_streaming(_async_iter(msgs)) + result = await add_messages_streaming(transcript, _async_iter(msgs)) assert result.messages_added == 2 assert result.chunks_added == 0 From 208529f47a7fd11969c46ed44e1ec0978fc4ac4f Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Tue, 12 May 2026 12:55:18 -0700 Subject: [PATCH 14/42] Fix forward reference and run pyright for 3.12/3.14 --- AGENTS.md | 2 ++ Makefile | 3 ++- src/typeagent/knowpro/add_messages.py | 4 ++-- 3 files changed, 6 insertions(+), 3 deletions(-) diff --git a/AGENTS.md b/AGENTS.md index dbac2349..ffca67e2 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -86,6 +86,8 @@ please follow these guidelines: * Assume Python 3.12 +* `from __future__ import annotations` is not allowed. + * Always strip trailing spaces * Keep class and type names in `PascalCase` diff --git a/Makefile b/Makefile index 84a609d7..83616532 100644 --- a/Makefile +++ b/Makefile @@ -13,7 +13,8 @@ format: venv .PHONY: check check: venv - uv run pyright src tests tools examples + uv run pyright --pythonversion 3.12 src tests tools examples + uv run pyright --pythonversion 3.14 src tests tools examples .PHONY: test test: venv diff --git a/src/typeagent/knowpro/add_messages.py b/src/typeagent/knowpro/add_messages.py index 9d822560..4f403f74 100644 --- a/src/typeagent/knowpro/add_messages.py +++ b/src/typeagent/knowpro/add_messages.py @@ -127,7 +127,7 @@ async def _producer_task[TMessage: IMessage]( async def _dispatcher_task[TMessage: IMessage]( chunk_queue: asyncio.Queue[ChunkWorkItem[TMessage] | None], - result_queue: asyncio.Queue[ChunkProcessingResult[TMessage] | None], + result_queue: asyncio.Queue["ChunkProcessingResult[TMessage] | None"], stop_state: PipelineStopState, knowledge_extractor: IKnowledgeExtractor, message_embedding_model: IEmbeddingModel, @@ -151,7 +151,7 @@ async def _process_one(work_item: ChunkWorkItem[TMessage]) -> None: try: stop_at = stop_state.stop_at_message_id if work_item.chunk_id.message_ordinal >= stop_at: - result: ChunkProcessingResult[TMessage] = ChunkProcessingResult( + result: "ChunkProcessingResult[TMessage]" = ChunkProcessingResult( chunk_id=work_item.chunk_id, chunk_count=work_item.chunk_count, message=work_item.message, From 69fa64b18506da0555756890c1dd95a599970575 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Tue, 12 May 2026 15:15:34 -0700 Subject: [PATCH 15/42] Simplify exception aggregation in add_messages_streaming - Replace nested try/except* handling with ExceptionGroup handling. - Preserve producer_state and stop_state exceptions and raise a combined ExceptionGroup when multiple distinct failures occur. - Complete ChunkProcessingResult docstring with all class fields and clarify success semantics. --- src/typeagent/knowpro/add_messages.py | 88 +++++++++++++++++---------- 1 file changed, 56 insertions(+), 32 deletions(-) diff --git a/src/typeagent/knowpro/add_messages.py b/src/typeagent/knowpro/add_messages.py index 4f403f74..cd573aff 100644 --- a/src/typeagent/knowpro/add_messages.py +++ b/src/typeagent/knowpro/add_messages.py @@ -201,12 +201,17 @@ class ChunkProcessingResult[TMessage: IMessage]: """Result of processing a single chunk through extraction and embeddings. Attributes: - chunk_id: Message/chunk location for the processed chunk - extracted_knowledge: Extracted KnowledgeResponse, or None if extraction failed/wasn't run - chunk_embedding: Normalized embedding vector for the message chunk - related_terms: Lowercased related-term texts extracted from knowledge - related_term_embeddings: Embeddings for related_terms (same order) - error: Exception from the first failing operation, or None if successful + chunk_id: Message/chunk location for the processed chunk. + chunk_count: Total number of chunks in the message that owns this chunk. + message: Original message object containing this chunk. + extracted_knowledge: Extracted KnowledgeResponse if extraction succeeded, else None. + chunk_embedding: Normalized embedding vector for the message chunk, or None if extraction or embedding failed. + related_terms: Lowercased, deduplicated related-term texts extracted from knowledge. + related_term_embeddings: Embeddings for related_terms in the same order, or [] when there are no related terms. + error: Exception from the first failing operation, or None if extraction and embedding succeeded. + + The ``success`` property is True only when extraction succeeded, chunk embedding was + generated, related-term embeddings were generated, and no error occurred. """ chunk_id: TextLocation @@ -476,12 +481,12 @@ async def _drain_consecutive_complete(force: bool = False) -> None: ) raise RuntimeError(validation_error) + assert assembly is not None + if item.chunk_count > 0: - assert assembly is not None assembly.chunks[chunk_ordinal] = item if item.error is not None: - assert assembly is not None assembly.has_error = True state.chunk_failures += 1 stop_state.stop_at_message_id = min( @@ -545,36 +550,55 @@ async def _commit_batch( stop_state = PipelineStopState() producer_state = ProducerState(next_message_id=initial_message_id) - async with asyncio.TaskGroup() as tg: - tg.create_task( - _producer_task( - messages, chunk_queue, stop_state, producer_state, result_queue + task_exceptions: list[Exception] = [] + try: + async with asyncio.TaskGroup() as tg: + tg.create_task( + _producer_task( + messages, chunk_queue, stop_state, producer_state, result_queue + ) ) - ) - tg.create_task( - _dispatcher_task( - chunk_queue, - result_queue, - stop_state, - knowledge_extractor, - embedding_model, - concurrency=sem_ref_settings.concurrency, + tg.create_task( + _dispatcher_task( + chunk_queue, + result_queue, + stop_state, + knowledge_extractor, + embedding_model, + concurrency=sem_ref_settings.concurrency, + ) ) - ) - tg.create_task( - _reassembler_task( - result_queue, - stop_state, - first_uncommitted_ordinal=initial_message_id, - target_commit_chunk_count=batch_size, - commit_batch=_commit_batch, + tg.create_task( + _reassembler_task( + result_queue, + stop_state, + first_uncommitted_ordinal=initial_message_id, + target_commit_chunk_count=batch_size, + commit_batch=_commit_batch, + ) ) - ) + except ExceptionGroup as eg: + task_exceptions.extend(eg.exceptions) + except Exception as exc: + task_exceptions.append(exc) if producer_state.exception is not None: - raise producer_state.exception + task_exceptions.append(producer_state.exception) if stop_state.exception is not None: - raise stop_state.exception + task_exceptions.append(stop_state.exception) + + if task_exceptions: + distinct_exceptions: list[Exception] = [] + for exc in task_exceptions: + if exc not in distinct_exceptions: + distinct_exceptions.append(exc) + + if len(distinct_exceptions) == 1: + raise distinct_exceptions[0] + raise ExceptionGroup( + "add_messages_streaming failed", + distinct_exceptions, + ) return total From c8e9a671ee6364998d02a29c98034b7678f2a1c8 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Wed, 13 May 2026 09:06:28 -0700 Subject: [PATCH 16/42] A.md update --- AGENTS.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/AGENTS.md b/AGENTS.md index ffca67e2..b292be23 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -90,6 +90,8 @@ please follow these guidelines: * Always strip trailing spaces +* Keep docstrings in sync with code when changing implementation. + * Keep class and type names in `PascalCase` * Use `python_case` for variable/field and function/method names From 36eec7a61acfec4b04593968e200a49811ed8fb0 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Wed, 13 May 2026 11:11:39 -0700 Subject: [PATCH 17/42] Add maxsize=concurrency*2 to Queue(); use only one embedding model (per ConversationSettings) --- src/typeagent/knowpro/add_messages.py | 36 ++++++++++----------------- 1 file changed, 13 insertions(+), 23 deletions(-) diff --git a/src/typeagent/knowpro/add_messages.py b/src/typeagent/knowpro/add_messages.py index cd573aff..e8409059 100644 --- a/src/typeagent/knowpro/add_messages.py +++ b/src/typeagent/knowpro/add_messages.py @@ -130,9 +130,8 @@ async def _dispatcher_task[TMessage: IMessage]( result_queue: asyncio.Queue["ChunkProcessingResult[TMessage] | None"], stop_state: PipelineStopState, knowledge_extractor: IKnowledgeExtractor, - message_embedding_model: IEmbeddingModel, - related_terms_embedding_model: IEmbeddingModel | None = None, - concurrency: int = 4, + embedding_model: IEmbeddingModel, + concurrency: int, ) -> None: """Dispatch chunk work items to bounded per-item worker tasks. @@ -168,8 +167,7 @@ async def _process_one(work_item: ChunkWorkItem[TMessage]) -> None: chunk_count=work_item.chunk_count, message=work_item.message, knowledge_extractor=knowledge_extractor, - message_embedding_model=message_embedding_model, - related_terms_embedding_model=related_terms_embedding_model, + embedding_model=embedding_model, ) if result.error is not None: new_stop = min( @@ -272,8 +270,7 @@ async def process_chunk_with_extraction_and_embeddings[TMessage: IMessage]( chunk_count: int, message: TMessage, knowledge_extractor: IKnowledgeExtractor, - message_embedding_model: IEmbeddingModel, - related_terms_embedding_model: IEmbeddingModel | None = None, + embedding_model: IEmbeddingModel, ) -> ChunkProcessingResult[TMessage]: """Process a single text chunk through knowledge extraction and embeddings. @@ -289,9 +286,7 @@ async def process_chunk_with_extraction_and_embeddings[TMessage: IMessage]( chunk_id: Message/chunk location for this chunk chunk_text: Text content of the chunk (stripped) knowledge_extractor: IKnowledgeExtractor instance for LLM extraction - message_embedding_model: Embedding model for chunk text embeddings - related_terms_embedding_model: Optional embedding model for related-term - embeddings. If None, message_embedding_model is used. + embedding_model: Embedding model for both chunk and related-term embeddings Returns: ChunkProcessingResult with knowledge, chunk embedding, related-term @@ -319,15 +314,11 @@ async def process_chunk_with_extraction_and_embeddings[TMessage: IMessage]( result.extracted_knowledge # type: ignore[arg-type] ) - related_model = related_terms_embedding_model or message_embedding_model - # Step 2: Generate embeddings (only if extraction succeeded) try: - result.chunk_embedding = await message_embedding_model.get_embedding_nocache( - chunk_text - ) + result.chunk_embedding = await embedding_model.get_embedding_nocache(chunk_text) if result.related_terms: - rel_embeddings = await related_model.get_embeddings(result.related_terms) + rel_embeddings = await embedding_model.get_embeddings(result.related_terms) result.related_term_embeddings = [e for e in rel_embeddings] else: result.related_term_embeddings = [] @@ -543,9 +534,11 @@ async def _commit_batch( ) _accumulate(result) - chunk_queue: asyncio.Queue[ChunkWorkItem[TMessage] | None] = asyncio.Queue() - result_queue: asyncio.Queue[ChunkProcessingResult[TMessage] | None] = ( - asyncio.Queue() + chunk_queue: asyncio.Queue[ChunkWorkItem[TMessage] | None] = asyncio.Queue( + maxsize=sem_ref_settings.concurrency * 2 + ) + result_queue: asyncio.Queue[ChunkProcessingResult[TMessage] | None] = asyncio.Queue( + maxsize=sem_ref_settings.concurrency * 2 ) stop_state = PipelineStopState() producer_state = ProducerState(next_message_id=initial_message_id) @@ -596,9 +589,6 @@ async def _commit_batch( if len(distinct_exceptions) == 1: raise distinct_exceptions[0] - raise ExceptionGroup( - "add_messages_streaming failed", - distinct_exceptions, - ) + raise ExceptionGroup("add_messages_streaming failed", distinct_exceptions) return total From 8cdbe68fceee827997b18f99c5dfbddc173413ee Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Wed, 13 May 2026 11:15:01 -0700 Subject: [PATCH 18/42] Oops, fix tests --- tests/test_add_messages_pipeline.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/tests/test_add_messages_pipeline.py b/tests/test_add_messages_pipeline.py index b88cfe84..7d76f80d 100644 --- a/tests/test_add_messages_pipeline.py +++ b/tests/test_add_messages_pipeline.py @@ -213,7 +213,7 @@ async def test_process_chunk_success_with_related_terms() -> None: chunk_count=1, message=_Message(["hello"]), knowledge_extractor=extractor, - message_embedding_model=message_model, + embedding_model=message_model, ) assert result.error is None @@ -239,7 +239,7 @@ async def test_process_chunk_extraction_failure_returns_error() -> None: chunk_count=1, message=_Message(["hello"]), knowledge_extractor=extractor, - message_embedding_model=message_model, + embedding_model=message_model, ) assert isinstance(result.error, RuntimeError) @@ -259,7 +259,7 @@ async def test_process_chunk_extraction_exception_returns_error() -> None: chunk_count=1, message=_Message(["hello"]), knowledge_extractor=extractor, - message_embedding_model=message_model, + embedding_model=message_model, ) assert isinstance(result.error, RuntimeError) @@ -277,7 +277,7 @@ async def test_process_chunk_chunk_embedding_exception_returns_error() -> None: chunk_count=1, message=_Message(["hello"]), knowledge_extractor=extractor, - message_embedding_model=message_model, + embedding_model=message_model, ) assert isinstance(result.error, RuntimeError) @@ -295,7 +295,7 @@ async def test_process_chunk_related_term_embedding_exception_returns_error() -> chunk_count=1, message=_Message(["hello"]), knowledge_extractor=extractor, - message_embedding_model=message_model, + embedding_model=message_model, ) assert isinstance(result.error, RuntimeError) From 83a252ddbe12f0f67e078764e82d990c94f2ac59 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Wed, 13 May 2026 11:15:19 -0700 Subject: [PATCH 19/42] Good docstring for add_messages_streaming() --- src/typeagent/knowpro/add_messages.py | 23 ++++++++++++++++++++++- 1 file changed, 22 insertions(+), 1 deletion(-) diff --git a/src/typeagent/knowpro/add_messages.py b/src/typeagent/knowpro/add_messages.py index e8409059..e57e48ae 100644 --- a/src/typeagent/knowpro/add_messages.py +++ b/src/typeagent/knowpro/add_messages.py @@ -500,7 +500,28 @@ async def add_messages_streaming[TMessage: IMessage]( batch_size: int = 100, on_batch_committed: Callable[[AddMessagesResult], None] | None = None, ) -> AddMessagesResult: - """Add messages using the pipelined extraction+embedding architecture.""" + """Ingest messages through a producer/dispatcher/reassembler pipeline. + + The function preserves message commit order while processing chunk extraction + and embedding concurrently. Batches are committed only for consecutive, + complete, non-failing messages. + + Args: + conv: Conversation receiving the new messages. + messages: Async iterable of messages to ingest. + batch_size: Target number of chunks per commit batch. + on_batch_committed: Optional callback invoked after each committed batch + with that batch's AddMessagesResult. + + Returns: + AddMessagesResult aggregating all committed batches. + + Raises: + Exception: If a single failure occurs during production, processing, + reassembly, or commit. + ExceptionGroup: If multiple distinct failures are observed across + pipeline stages. + """ from . import convknowledge settings = conv.settings From bf55d96c740703bd7f884413f9aaebaffbee3267 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Wed, 13 May 2026 11:24:05 -0700 Subject: [PATCH 20/42] Fix two more docstrings --- src/typeagent/knowpro/add_messages.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/src/typeagent/knowpro/add_messages.py b/src/typeagent/knowpro/add_messages.py index e57e48ae..28cf63d8 100644 --- a/src/typeagent/knowpro/add_messages.py +++ b/src/typeagent/knowpro/add_messages.py @@ -279,14 +279,16 @@ async def process_chunk_with_extraction_and_embeddings[TMessage: IMessage]( Extraction runs first; if it fails, embedding work is skipped. - Chunk embeddings are computed uncached, while related-term embeddings are - computed using cache-aware model calls. + Chunk embeddings are computed uncached; related-term embeddings use + cache-aware model calls on the same embedding model. Args: - chunk_id: Message/chunk location for this chunk - chunk_text: Text content of the chunk (stripped) - knowledge_extractor: IKnowledgeExtractor instance for LLM extraction - embedding_model: Embedding model for both chunk and related-term embeddings + chunk_id: Message/chunk location for this chunk. + chunk_text: Text content of the chunk (stripped). + chunk_count: Total number of chunks in the message. + message: Original message object containing this chunk. + knowledge_extractor: IKnowledgeExtractor instance for LLM extraction. + embedding_model: Embedding model for chunk and related-term embeddings. Returns: ChunkProcessingResult with knowledge, chunk embedding, related-term From ae844303090ef5da9023abb01199fb12f2b75793 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Wed, 13 May 2026 11:31:58 -0700 Subject: [PATCH 21/42] Eliminate unused 'success' property; rename {_,}NoOpKnowledgeExtractor --- src/typeagent/knowpro/add_messages.py | 18 +++++------------- tests/test_add_messages_pipeline.py | 1 - 2 files changed, 5 insertions(+), 14 deletions(-) diff --git a/src/typeagent/knowpro/add_messages.py b/src/typeagent/knowpro/add_messages.py index 28cf63d8..47d47a39 100644 --- a/src/typeagent/knowpro/add_messages.py +++ b/src/typeagent/knowpro/add_messages.py @@ -16,6 +16,8 @@ from .interfaces import AddMessagesResult from .interfaces_core import IKnowledgeExtractor, IMessage, MessageOrdinal, TextLocation +__all__ = ["add_messages_streaming"] + if TYPE_CHECKING: from .conversation_base import ConversationBase @@ -26,7 +28,7 @@ ) -class _NoOpKnowledgeExtractor: +class NoOpKnowledgeExtractor: """No-op extractor used when auto_extract_knowledge is False.""" async def extract(self, message: str) -> typechat.Result[kplib.KnowledgeResponse]: @@ -221,17 +223,6 @@ class ChunkProcessingResult[TMessage: IMessage]: related_term_embeddings: list[NormalizedEmbedding] | None = None error: Exception | None = None - @property - def success(self) -> bool: - """True if both extraction and embedding succeeded.""" - return ( - self.extracted_knowledge is not None - and self.chunk_embedding is not None - and self.related_terms is not None - and self.related_term_embeddings is not None - and self.error is None - ) - def _collect_related_terms_for_fuzzy_index( knowledge: kplib.KnowledgeResponse, @@ -264,6 +255,7 @@ def _add_term(term: str) -> None: return related_terms +# "Public", imported by tests async def process_chunk_with_extraction_and_embeddings[TMessage: IMessage]( chunk_id: TextLocation, chunk_text: str, @@ -534,7 +526,7 @@ async def add_messages_streaming[TMessage: IMessage]( sem_ref_settings.knowledge_extractor or convknowledge.KnowledgeExtractor() ) else: - knowledge_extractor = _NoOpKnowledgeExtractor() + knowledge_extractor = NoOpKnowledgeExtractor() embedding_model = settings.embedding_model initial_message_id: MessageOrdinal = await conv.messages.size() diff --git a/tests/test_add_messages_pipeline.py b/tests/test_add_messages_pipeline.py index 7d76f80d..dac2850f 100644 --- a/tests/test_add_messages_pipeline.py +++ b/tests/test_add_messages_pipeline.py @@ -217,7 +217,6 @@ async def test_process_chunk_success_with_related_terms() -> None: ) assert result.error is None - assert result.success assert result.extracted_knowledge is not None assert result.chunk_embedding is not None assert result.related_terms is not None From 1dea32b30f985294cca4827d0cf0fac381060cda Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Wed, 13 May 2026 12:58:53 -0700 Subject: [PATCH 22/42] Swap in new add_messages_streaming. Update message text index when messages are added. --- src/typeagent/knowpro/conversation_base.py | 232 +------------------- src/typeagent/podcasts/podcast.py | 8 - src/typeagent/storage/memory/collections.py | 27 +++ src/typeagent/storage/memory/provider.py | 6 +- src/typeagent/transcripts/transcript.py | 12 +- 5 files changed, 39 insertions(+), 246 deletions(-) diff --git a/src/typeagent/knowpro/conversation_base.py b/src/typeagent/knowpro/conversation_base.py index 03cd400c..3d44f9d9 100644 --- a/src/typeagent/knowpro/conversation_base.py +++ b/src/typeagent/knowpro/conversation_base.py @@ -3,9 +3,7 @@ """Base class for conversations with incremental indexing support.""" -import asyncio from collections.abc import AsyncIterable, Callable, Sequence -import contextlib from dataclasses import dataclass from datetime import datetime, timezone from typing import Generic, Protocol, Self, TypeVar @@ -41,7 +39,6 @@ Topic, ) from .interfaces_core import TextLocation -from .knowledge import extract_knowledge_from_text_batch from .messageutils import get_all_message_chunk_locations TMessage = TypeVar("TMessage", bound=IMessage) @@ -58,15 +55,6 @@ class _ChunkCommitResult(Protocol): related_term_embeddings: list[NormalizedEmbedding] | None -@dataclass(frozen=True) -class _ExtractionResult: - """Pre-extracted knowledge for a batch, ready to commit.""" - - messages: Sequence[IMessage] - text_locations: list[TextLocation] - knowledge_results: list[typechat.Result[kplib.KnowledgeResponse]] - - @dataclass(init=False) class ConversationBase( Generic[TMessage], IConversation[TMessage, ITermToSemanticRefIndex] @@ -230,215 +218,16 @@ async def add_messages_streaming( batch_size: int = 100, on_batch_committed: Callable[[AddMessagesResult], None] | None = None, ) -> AddMessagesResult: - """Add messages from an async iterable, committing in batches. - - Uses a two-stage pipeline: while batch N is being committed (DB writes, - embeddings, secondary indexes), batch N+1's LLM extraction runs - concurrently. LLM extraction is typically 95% of wall time, so this - nearly doubles throughput for multi-batch ingestions. - - **Source-ID tracking**: each message's ``source_id`` (if not ``None``) - is marked as ingested within the commit transaction. Callers are - responsible for filtering duplicates before yielding messages. + """Delegate to the pipelined add_messages implementation.""" + from . import add_messages - **Extraction failures**: when knowledge extraction returns a - ``Failure`` for a chunk, the failure is recorded via - ``storage.record_chunk_failure`` and processing continues with the - remaining chunks. Raised exceptions (HTTP errors, timeouts, etc.) - are treated as systemic and stop the run immediately — the current - batch is rolled back and the exception propagates. - - Args: - messages: An async iterable of messages to ingest. - batch_size: Target number of text chunks per commit batch. - Messages are never split across batches, so the actual - chunk count may exceed ``batch_size`` if a single message - has more chunks than that. - on_batch_committed: Optional callback invoked after each batch is - committed, receiving the batch's ``AddMessagesResult``. - - Returns: - Cumulative ``AddMessagesResult`` across all committed batches. - """ - storage = await self.settings.get_storage_provider() - should_extract = ( - self.settings.semantic_ref_index_settings.auto_extract_knowledge - ) - total = AddMessagesResult() - - def _accumulate(result: AddMessagesResult) -> None: - total.messages_added += result.messages_added - total.semrefs_added += result.semrefs_added - total.chunks_added += result.chunks_added - if on_batch_committed: - on_batch_committed(result) - - pending_commit: asyncio.Task[AddMessagesResult] | None = None - pending_extraction: asyncio.Task[_ExtractionResult | None] | None = None - - async def _drain_commit() -> None: - nonlocal pending_commit - if pending_commit is not None: - _accumulate(await pending_commit) - pending_commit = None - - async def _submit_batch(batch: list[TMessage]) -> None: - nonlocal pending_commit, pending_extraction - if not batch: - return - - if should_extract: - next_extraction = asyncio.create_task( - self._extract_knowledge_for_batch(batch) - ) - else: - next_extraction = None - pending_extraction = next_extraction - - await _drain_commit() - - extraction = await next_extraction if next_extraction is not None else None - pending_extraction = None - - pending_commit = asyncio.create_task( - self._commit_batch_streaming(storage, batch, extraction) - ) - - try: - batch: list[TMessage] = [] - batch_chunks = 0 - async for msg in messages: - msg_chunks = len(msg.text_chunks) - if batch and batch_chunks + msg_chunks > batch_size: - await _submit_batch(batch) - batch = [] - batch_chunks = 0 - batch.append(msg) - batch_chunks += msg_chunks - if batch_chunks >= batch_size: - await _submit_batch(batch) - batch = [] - batch_chunks = 0 - - if batch: - await _submit_batch(batch) - - await _drain_commit() - except BaseException: - if pending_extraction is not None and not pending_extraction.done(): - pending_extraction.cancel() - with contextlib.suppress(asyncio.CancelledError): - await pending_extraction - if pending_commit is not None and not pending_commit.done(): - pending_commit.cancel() - with contextlib.suppress(asyncio.CancelledError): - await pending_commit - raise - - return total - - async def _extract_knowledge_for_batch( - self, - messages: list[TMessage], - ) -> _ExtractionResult | None: - """Run LLM extraction on message texts — no DB access. - - Uses 0-based ordinals; the caller remaps to global ordinals at commit - time. Safe to run concurrently with a DB transaction on another batch. - """ - text_locations = get_all_message_chunk_locations(messages, 0) - if not text_locations: - return None - - settings = self.settings.semantic_ref_index_settings - knowledge_extractor = ( - settings.knowledge_extractor or convknowledge.KnowledgeExtractor() - ) - - text_batch = [ - messages[tl.message_ordinal].text_chunks[tl.chunk_ordinal].strip() - for tl in text_locations - ] - - knowledge_results = await extract_knowledge_from_text_batch( - knowledge_extractor, - text_batch, - settings.concurrency, - ) - return _ExtractionResult( - messages=messages, - text_locations=text_locations, - knowledge_results=knowledge_results, + return await add_messages.add_messages_streaming( + self, + messages, + batch_size=batch_size, + on_batch_committed=on_batch_committed, ) - async def _apply_extraction_results( - self, - storage: IStorageProvider[TMessage], - extraction: _ExtractionResult, - global_message_start: int, - ) -> None: - """Write pre-extracted knowledge into the DB. Must be inside a transaction.""" - bulk_items: list[tuple[int, int, kplib.KnowledgeResponse]] = [] - for i, knowledge_result in enumerate(extraction.knowledge_results): - tl = extraction.text_locations[i] - global_msg_ord = tl.message_ordinal + global_message_start - if isinstance(knowledge_result, typechat.Failure): - await storage.record_chunk_failure( - global_msg_ord, - tl.chunk_ordinal, - type(knowledge_result).__name__, - knowledge_result.message[:500], - ) - continue - bulk_items.append( - (global_msg_ord, tl.chunk_ordinal, knowledge_result.value) - ) - if bulk_items: - await semrefindex.add_knowledge_batch_to_semantic_ref_index( - self, bulk_items - ) - - async def _commit_batch_streaming( - self, - storage: IStorageProvider[TMessage], - filtered: list[TMessage], - extraction: _ExtractionResult | None, - ) -> AddMessagesResult: - """Commit a single batch with pre-extracted knowledge.""" - async with storage: - start_points = IndexingStartPoints( - message_count=await self.messages.size(), - semref_count=await self.semantic_refs.size(), - ) - - await self.messages.extend(filtered) - - source_ids = [m.source_id for m in filtered if m.source_id is not None] - if source_ids: - await storage.mark_sources_ingested_batch(source_ids) - - await self._add_metadata_knowledge_incremental(start_points.message_count) - - if extraction is not None: - await self._apply_extraction_results( - storage, extraction, start_points.message_count - ) - - await self._update_secondary_indexes_incremental(start_points) - - await storage.update_conversation_timestamps( - updated_at=datetime.now(timezone.utc) - ) - - messages_added = await self.messages.size() - start_points.message_count - chunks_added = sum(len(m.text_chunks) for m in filtered[:messages_added]) - return AddMessagesResult( - messages_added=messages_added, - chunks_added=chunks_added, - semrefs_added=await self.semantic_refs.size() - - start_points.semref_count, - ) - async def _commit_batch_from_chunk_results( self, storage: IStorageProvider[TMessage], @@ -468,7 +257,6 @@ async def _commit_batch_from_chunk_results( knowledge_items: list[ tuple[MessageOrdinal, int, kplib.KnowledgeResponse] ] = [] - chunk_embeddings: list[NormalizedEmbedding] = [] fuzzy_terms: list[str] = [] fuzzy_term_embeddings: list[NormalizedEmbedding] = [] @@ -482,7 +270,6 @@ async def _commit_batch_from_chunk_results( f"message={result.chunk_id.message_ordinal}, " f"chunk={result.chunk_id.chunk_ordinal}" ) - chunk_embeddings.append(result.chunk_embedding) if result.extracted_knowledge is None: raise ValueError( @@ -526,7 +313,6 @@ async def _commit_batch_from_chunk_results( await self._update_secondary_indexes_incremental_with_embeddings( start_points, messages_batch, - chunk_embeddings, fuzzy_terms, fuzzy_term_embeddings, ) @@ -550,7 +336,6 @@ async def _update_secondary_indexes_incremental_with_embeddings( self, start_points: IndexingStartPoints, new_messages: list[TMessage], - chunk_embeddings: list[NormalizedEmbedding], related_terms: list[str], related_term_embeddings: list[NormalizedEmbedding], ) -> None: @@ -576,9 +361,6 @@ async def _update_secondary_indexes_incremental_with_embeddings( related_term_embeddings, ) - # The message text index is already populated by messages.extend() during - # the commit, so no explicit update is needed here. - async def _add_metadata_knowledge_incremental( self, start_from_message_ordinal: int, diff --git a/src/typeagent/podcasts/podcast.py b/src/typeagent/podcasts/podcast.py index 8038ea60..dc53e305 100644 --- a/src/typeagent/podcasts/podcast.py +++ b/src/typeagent/podcasts/podcast.py @@ -15,7 +15,6 @@ from ..knowpro.interfaces import ConversationDataWithIndexes, SemanticRef, Term from ..knowpro.universal_message import ConversationMessage, ConversationMessageMeta from ..storage.memory.convthreads import ConversationThreads -from ..storage.memory.messageindex import MessageTextIndex # Type aliases for backward compatibility PodcastMessage = ConversationMessage @@ -119,16 +118,9 @@ async def deserialize( message_index_data = podcast_data.get("messageIndexData") if message_index_data is not None: secondary_indexes = self._get_secondary_indexes() - # Assert the message index is empty before deserializing assert ( secondary_indexes.message_index is not None ), "Message index should be initialized" - - if isinstance(secondary_indexes.message_index, MessageTextIndex): - index_size = await secondary_indexes.message_index.size() - assert ( - index_size == 0 - ), "Message index must be empty before deserializing" await secondary_indexes.message_index.deserialize(message_index_data) # Don't rebuild aliases/synonyms since they were deserialized from relatedTermsIndexData diff --git a/src/typeagent/storage/memory/collections.py b/src/typeagent/storage/memory/collections.py index 8a5b14eb..f7279037 100644 --- a/src/typeagent/storage/memory/collections.py +++ b/src/typeagent/storage/memory/collections.py @@ -8,6 +8,7 @@ from ...knowpro.interfaces import ( ICollection, IMessage, + IMessageTextIndex, MessageOrdinal, SemanticRef, SemanticRefMetadata, @@ -81,3 +82,29 @@ class MemoryMessageCollection[TMessage: IMessage]( MemoryCollection[TMessage, MessageOrdinal] ): """A collection of messages.""" + + def __init__( + self, + items: list[TMessage] | None = None, + message_text_index: IMessageTextIndex[TMessage] | None = None, + ): + super().__init__(items) + self.message_text_index = message_text_index + + async def append(self, item: TMessage) -> None: + msg_id = await self.size() + self.items.append(item) + if self.message_text_index is not None: + await self.message_text_index.add_messages_starting_at(msg_id, [item]) + + async def extend(self, items: Iterable[TMessage]) -> None: + items_list = list(items) + if not items_list: + return + current_size = await self.size() + self.items.extend(items_list) + if self.message_text_index is not None: + await self.message_text_index.add_messages_starting_at( + current_size, + items_list, + ) diff --git a/src/typeagent/storage/memory/provider.py b/src/typeagent/storage/memory/provider.py index e697fe01..230b5b0a 100644 --- a/src/typeagent/storage/memory/provider.py +++ b/src/typeagent/storage/memory/provider.py @@ -51,13 +51,15 @@ def __init__( ) -> None: """Create and initialize a MemoryStorageProvider with all indexes.""" self._metadata = metadata or ConversationMetadata() - self._message_collection = MemoryMessageCollection[TMessage]() + self._message_text_index = MessageTextIndex(message_text_settings) + self._message_collection = MemoryMessageCollection[TMessage]( + message_text_index=self._message_text_index + ) self._semantic_ref_collection = MemorySemanticRefCollection() self._conversation_index = TermToSemanticRefIndex() self._property_index = PropertyIndex() self._timestamp_index = TimestampToTextRangeIndex() - self._message_text_index = MessageTextIndex(message_text_settings) self._related_terms_index = RelatedTermsIndex(related_terms_settings) thread_settings = message_text_settings.embedding_index_settings self._conversation_threads = ConversationThreads(thread_settings) diff --git a/src/typeagent/transcripts/transcript.py b/src/typeagent/transcripts/transcript.py index 08c4fdae..2f889b70 100644 --- a/src/typeagent/transcripts/transcript.py +++ b/src/typeagent/transcripts/transcript.py @@ -14,7 +14,6 @@ from ..knowpro.interfaces import ConversationDataWithIndexes, SemanticRef, Term from ..knowpro.universal_message import ConversationMessage, ConversationMessageMeta from ..storage.memory.convthreads import ConversationThreads -from ..storage.memory.messageindex import MessageTextIndex # Type aliases for backward compatibility TranscriptMessage = ConversationMessage @@ -120,16 +119,9 @@ async def deserialize( message_index_data = transcript_data.get("messageIndexData") if message_index_data is not None: secondary_indexes = self._get_secondary_indexes() - # Assert the message index is empty before deserializing assert ( secondary_indexes.message_index is not None ), "Message index should be initialized" - - if isinstance(secondary_indexes.message_index, MessageTextIndex): - index_size = await secondary_indexes.message_index.size() - assert ( - index_size == 0 - ), "Message index must be empty before deserializing" await secondary_indexes.message_index.deserialize(message_index_data) # Don't rebuild aliases/synonyms since they were deserialized from relatedTermsIndexData @@ -180,9 +172,7 @@ def _read_conversation_data_from_file( @staticmethod async def read_from_file( - filename_prefix: str, - settings: ConversationSettings, - dbname: str | None = None, + filename_prefix: str, settings: ConversationSettings, dbname: str | None = None ) -> "Transcript": data = Transcript._read_conversation_data_from_file(filename_prefix) From 96e6934b218e72189c12479fcf4577bdfbfeecc9 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Wed, 13 May 2026 14:16:54 -0700 Subject: [PATCH 23/42] Add skip_failed_messages flag and use in ingest_email.py --- src/typeagent/knowpro/add_messages.py | 95 ++++++++++++++++---- src/typeagent/knowpro/conversation_base.py | 2 + tests/test_add_messages_pipeline.py | 100 +++++++++++++++++++++ tools/ingest_email.py | 1 + 4 files changed, 180 insertions(+), 18 deletions(-) diff --git a/src/typeagent/knowpro/add_messages.py b/src/typeagent/knowpro/add_messages.py index 47d47a39..001d8974 100644 --- a/src/typeagent/knowpro/add_messages.py +++ b/src/typeagent/knowpro/add_messages.py @@ -134,6 +134,7 @@ async def _dispatcher_task[TMessage: IMessage]( knowledge_extractor: IKnowledgeExtractor, embedding_model: IEmbeddingModel, concurrency: int, + skip_failed_messages: bool, ) -> None: """Dispatch chunk work items to bounded per-item worker tasks. @@ -145,6 +146,10 @@ async def _dispatcher_task[TMessage: IMessage]( tasks run simultaneously. Chunks at or beyond ``stop_at_message_id`` are skipped and reported as error results so the reassembler can account for them deterministically. + + Args: + skip_failed_messages: If True, don't halt producer on extraction/embedding + failures; continue processing. If False, halt on first failure. """ sem = asyncio.Semaphore(concurrency) @@ -172,15 +177,20 @@ async def _process_one(work_item: ChunkWorkItem[TMessage]) -> None: embedding_model=embedding_model, ) if result.error is not None: - new_stop = min( - stop_state.stop_at_message_id, - work_item.chunk_id.message_ordinal, - ) - if new_stop < stop_state.stop_at_message_id: - stop_state.stop_at_message_id = new_stop - stop_state.exception = result.error - elif stop_state.exception is None: - stop_state.exception = result.error + if skip_failed_messages: + print( + f"Skipping message {work_item.chunk_id.message_ordinal} " + f"due to extraction/embedding error: {result.error}" + ) + else: + new_stop = min( + stop_state.stop_at_message_id, + work_item.chunk_id.message_ordinal, + ) + if new_stop < stop_state.stop_at_message_id: + stop_state.stop_at_message_id = new_stop + if stop_state.exception is None: + stop_state.exception = result.error await result_queue.put(result) finally: sem.release() @@ -346,6 +356,7 @@ class ReassemblerResult: messages_committed: int = 0 chunks_committed: int = 0 chunk_failures: int = 0 + messages_skipped: int = 0 buffered_messages: int = 0 @@ -358,12 +369,17 @@ async def _reassembler_task[TMessage: IMessage]( [list[TMessage], list[ChunkProcessingResult[TMessage]]], Awaitable[None] ], on_batch_committed: Callable[[int, int], None] | None = None, + skip_failed_messages: bool = False, ) -> ReassemblerResult: """Reassemble chunks into messages and commit only consecutive complete ones. This stage consumes worker results until it sees a ``None`` sentinel. It never commits out-of-order messages: if message N is incomplete or failed, messages N+1 and later remain buffered. + + Args: + skip_failed_messages: If True, skip messages that fail extraction/embedding + and continue processing. If False, halt processing on first failure. """ state = ReassemblerResult(first_uncommitted_ordinal=first_uncommitted_ordinal) assemblies: dict[MessageOrdinal, MessageAssembly[TMessage]] = {} @@ -393,9 +409,34 @@ async def _drain_consecutive_complete(force: bool = False) -> None: nonlocal staged_chunks while True: assembly = assemblies.get(state.first_uncommitted_ordinal) - if assembly is None or not assembly.is_complete() or assembly.has_error: + if assembly is None: + await _commit_if_needed(force) + return + if not assembly.is_complete(): await _commit_if_needed(force) return + if assembly.has_error: + if skip_failed_messages: + # Skip this failed message and continue + # Find the error from one of the chunks for logging + error_msg = "Unknown error" + for chunk_result in assembly.chunks.values(): + if chunk_result.error is not None: + error_msg = str(chunk_result.error) + break + print( + f"Skipping message {state.first_uncommitted_ordinal} " + f"due to chunk processing error: {error_msg}" + ) + del assemblies[state.first_uncommitted_ordinal] + state.first_uncommitted_ordinal += 1 + state.messages_skipped += 1 + # Continue to check the next message + continue + else: + # Stop at failed message; halt processing + await _commit_if_needed(force) + return # Pre-flush: if staging this message would push staged_chunks past # the target, commit the current batch first. @@ -474,9 +515,10 @@ async def _drain_consecutive_complete(force: bool = False) -> None: if item.error is not None: assembly.has_error = True state.chunk_failures += 1 - stop_state.stop_at_message_id = min( - stop_state.stop_at_message_id, message_id - ) + if not skip_failed_messages: + stop_state.stop_at_message_id = min( + stop_state.stop_at_message_id, message_id + ) await _drain_consecutive_complete() finally: @@ -493,6 +535,7 @@ async def add_messages_streaming[TMessage: IMessage]( *, batch_size: int = 100, on_batch_committed: Callable[[AddMessagesResult], None] | None = None, + skip_failed_messages: bool = False, ) -> AddMessagesResult: """Ingest messages through a producer/dispatcher/reassembler pipeline. @@ -506,15 +549,19 @@ async def add_messages_streaming[TMessage: IMessage]( batch_size: Target number of chunks per commit batch. on_batch_committed: Optional callback invoked after each committed batch with that batch's AddMessagesResult. + skip_failed_messages: If True, skip messages that fail extraction or + embedding and continue processing. If False (default), halt on + first failure and raise an exception. Returns: - AddMessagesResult aggregating all committed batches. + AddMessagesResult aggregating all committed batches, including count + of messages_skipped when skip_failed_messages is True. Raises: Exception: If a single failure occurs during production, processing, - reassembly, or commit. + reassembly, or commit (when skip_failed_messages is False). ExceptionGroup: If multiple distinct failures are observed across - pipeline stages. + pipeline stages (when skip_failed_messages is False). """ from . import convknowledge @@ -559,6 +606,7 @@ async def _commit_batch( producer_state = ProducerState(next_message_id=initial_message_id) task_exceptions: list[Exception] = [] + reassembler_task: asyncio.Task[ReassemblerResult] | None = None try: async with asyncio.TaskGroup() as tg: tg.create_task( @@ -574,15 +622,17 @@ async def _commit_batch( knowledge_extractor, embedding_model, concurrency=sem_ref_settings.concurrency, + skip_failed_messages=skip_failed_messages, ) ) - tg.create_task( + reassembler_task = tg.create_task( _reassembler_task( result_queue, stop_state, first_uncommitted_ordinal=initial_message_id, target_commit_chunk_count=batch_size, commit_batch=_commit_batch, + skip_failed_messages=skip_failed_messages, ) ) except ExceptionGroup as eg: @@ -593,7 +643,7 @@ async def _commit_batch( if producer_state.exception is not None: task_exceptions.append(producer_state.exception) - if stop_state.exception is not None: + if stop_state.exception is not None and not skip_failed_messages: task_exceptions.append(stop_state.exception) if task_exceptions: @@ -606,4 +656,13 @@ async def _commit_batch( raise distinct_exceptions[0] raise ExceptionGroup("add_messages_streaming failed", distinct_exceptions) + # Collect messages_skipped from reassembler result if skip_failed_messages is True + if skip_failed_messages and reassembler_task is not None: + try: + reassembler_result = reassembler_task.result() + total.messages_skipped = reassembler_result.messages_skipped + except Exception: + # reassembler_task result may not be available if task group failed + pass + return total diff --git a/src/typeagent/knowpro/conversation_base.py b/src/typeagent/knowpro/conversation_base.py index 3d44f9d9..6d67bf88 100644 --- a/src/typeagent/knowpro/conversation_base.py +++ b/src/typeagent/knowpro/conversation_base.py @@ -217,6 +217,7 @@ async def add_messages_streaming( *, batch_size: int = 100, on_batch_committed: Callable[[AddMessagesResult], None] | None = None, + skip_failed_messages: bool = False, ) -> AddMessagesResult: """Delegate to the pipelined add_messages implementation.""" from . import add_messages @@ -226,6 +227,7 @@ async def add_messages_streaming( messages, batch_size=batch_size, on_batch_committed=on_batch_committed, + skip_failed_messages=skip_failed_messages, ) async def _commit_batch_from_chunk_results( diff --git a/tests/test_add_messages_pipeline.py b/tests/test_add_messages_pipeline.py index dac2850f..73261d7d 100644 --- a/tests/test_add_messages_pipeline.py +++ b/tests/test_add_messages_pipeline.py @@ -428,6 +428,7 @@ async def test_dispatcher_stops_on_sentinel_and_emits_result_sentinel() -> None: extractor, model, concurrency=2, + skip_failed_messages=False, ) items = await _drain_result_queue(result_queue) @@ -478,6 +479,7 @@ async def test_dispatcher_extraction_failure_lowers_stop() -> None: extractor, model, concurrency=1, + skip_failed_messages=False, ) items = await _drain_result_queue(result_queue) @@ -496,6 +498,64 @@ async def test_dispatcher_extraction_failure_lowers_stop() -> None: assert extractor.calls == ["first"] +@pytest.mark.asyncio +async def test_dispatcher_extraction_failure_skips_and_keeps_processing() -> None: + chunk_queue: asyncio.Queue[ChunkWorkItem[_Message] | None] = asyncio.Queue() + result_queue: asyncio.Queue[ChunkProcessingResult[_Message] | None] = ( + asyncio.Queue() + ) + stop_state = PipelineStopState() + + m0 = _Message(["first"]) + m1 = _Message(["second"]) + await chunk_queue.put( + ChunkWorkItem( + chunk_id=TextLocation(0, 0), + chunk_count=1, + chunk_text="first", + message=m0, + ) + ) + await chunk_queue.put( + ChunkWorkItem( + chunk_id=TextLocation(1, 0), + chunk_count=1, + chunk_text="second", + message=m1, + ) + ) + await chunk_queue.put(None) + + extractor = _SequenceExtractor( + [typechat.Failure("first failed"), typechat.Success(_empty_knowledge())] + ) + model = _StubEmbeddingModel() + + await _dispatcher_task( + chunk_queue, + result_queue, + stop_state, + extractor, + model, + concurrency=1, + skip_failed_messages=True, + ) + + items = await _drain_result_queue(result_queue) + first = items[0] + second = items[1] + + assert first is not None + assert isinstance(first.error, RuntimeError) + assert "first failed" in str(first.error) + + assert second is not None + assert second.error is None + + assert stop_state.stop_at_message_id == 10**100 + assert extractor.calls == ["first", "second"] + + def _chunk_result( message: _Message, message_ordinal: int, @@ -585,6 +645,46 @@ async def _commit( assert stop_state.stop_at_message_id == 0 +@pytest.mark.asyncio +async def test_reassembler_skips_failed_message_and_commits_later_messages() -> None: + result_queue: asyncio.Queue[ChunkProcessingResult[_Message] | None] = ( + asyncio.Queue() + ) + stop_state = PipelineStopState() + + m0 = _Message(["m0"]) + m1 = _Message(["m1"]) + await result_queue.put(_chunk_result(m0, 0, 0, 1, error=RuntimeError("boom"))) + await result_queue.put(_chunk_result(m1, 1, 0, 1)) + await result_queue.put(None) + + committed_batches: list[tuple[int, int]] = [] + + async def _commit( + messages: list[_Message], + results: list[ChunkProcessingResult[_Message]], + ) -> None: + committed_batches.append((len(messages), len(results))) + + state = await _reassembler_task( + result_queue, + stop_state, + first_uncommitted_ordinal=0, + target_commit_chunk_count=1, + commit_batch=_commit, + skip_failed_messages=True, + ) + + assert committed_batches == [(1, 1)] + assert state.chunk_failures == 1 + assert state.messages_skipped == 1 + assert state.messages_committed == 1 + assert state.chunks_committed == 1 + assert state.buffered_messages == 0 + assert state.first_uncommitted_ordinal == 2 + assert stop_state.stop_at_message_id == 10**100 + + @pytest.mark.asyncio async def test_reassembler_force_commits_small_staged_tail() -> None: result_queue: asyncio.Queue[ChunkProcessingResult[_Message] | None] = ( diff --git a/tools/ingest_email.py b/tools/ingest_email.py index 34c4c6ce..6b601c41 100644 --- a/tools/ingest_email.py +++ b/tools/ingest_email.py @@ -476,6 +476,7 @@ def on_batch_committed(result: AddMessagesResult) -> None: message_stream, batch_size=batch_size, on_batch_committed=on_batch_committed, + skip_failed_messages=True, ) except (KeyboardInterrupt, asyncio.CancelledError): interrupted = True From 927fdae93e34e668921b9899e127485b7e471a5f Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Wed, 13 May 2026 14:28:20 -0700 Subject: [PATCH 24/42] Print chunk summaries after clipping --- tools/ingest_email.py | 22 ++++++++++++---------- 1 file changed, 12 insertions(+), 10 deletions(-) diff --git a/tools/ingest_email.py b/tools/ingest_email.py index 6b601c41..3e6c27f2 100644 --- a/tools/ingest_email.py +++ b/tools/ingest_email.py @@ -273,7 +273,7 @@ def _iter_emails( yield str(email_file.resolve()), email_file, label -def _print_email_verbose(email: EmailMessage) -> None: +def _print_email_verbose(email: EmailMessage, original_chunk_count: int | None = None) -> None: """Print verbose details for an email.""" print(f" From: {decode_encoded_words(email.metadata.sender)}") if email.metadata.recipients: @@ -289,7 +289,10 @@ def _print_email_verbose(email: EmailMessage) -> None: f" Subject: {decode_encoded_words(email.metadata.subject).replace('\n', '\\n')}" ) print(f" Date: {email.timestamp}") - print(f" Body chunks: {len(email.text_chunks)}") + if original_chunk_count is not None and original_chunk_count != len(email.text_chunks): + print(f" Body chunks: {len(email.text_chunks)} (clipped from {original_chunk_count})") + else: + print(f" Body chunks: {len(email.text_chunks)}") MAIL_PREVIEW_LEN = 80 for chunk in email.text_chunks: preview = repr(chunk[: MAIL_PREVIEW_LEN + 1])[1:-1] @@ -341,15 +344,14 @@ async def _email_generator( print(f"{label} [Outside date range, skipping]") continue - if verbose: - print(label) - _print_email_verbose(email) - # Truncate chunks if --max-chunks is set + original_chunk_count: int | None = None if max_chunks is not None and len(email.text_chunks) > max_chunks: - if verbose: - print(f" Truncating {len(email.text_chunks)} chunks to {max_chunks}") + original_chunk_count = len(email.text_chunks) email.text_chunks = email.text_chunks[:max_chunks] + if verbose: + print(label) + _print_email_verbose(email, original_chunk_count) # Set source_id so streaming API handles dedup and tracking email.source_id = source_id @@ -452,10 +454,10 @@ def on_batch_committed(result: AddMessagesResult) -> None: f"+{result.semrefs_added} semrefs", ] print( - f"{' '.join(parts)} | " + f"\n{' '.join(parts)} | " f"{batch_secs:.1f}s ({per_chunk:.2f}s/chunk) | " f"{counters['ingested']} total ingested | " - f"{elapsed:.1f}s elapsed", + f"{elapsed:.1f}s elapsed\n", flush=True, ) From fd14b95b3f05ce4ce2949aa52d7174e2a3ce36f4 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Wed, 13 May 2026 14:41:17 -0700 Subject: [PATCH 25/42] [Incomplete] Handle ^C --- src/typeagent/knowpro/add_messages.py | 11 +++- src/typeagent/knowpro/conversation_base.py | 3 + tools/ingest_email.py | 73 +++++++++++++++++----- 3 files changed, 70 insertions(+), 17 deletions(-) diff --git a/src/typeagent/knowpro/add_messages.py b/src/typeagent/knowpro/add_messages.py index 001d8974..402c6310 100644 --- a/src/typeagent/knowpro/add_messages.py +++ b/src/typeagent/knowpro/add_messages.py @@ -77,6 +77,7 @@ async def _producer_task[TMessage: IMessage]( stop_state: PipelineStopState, producer_state: ProducerState, result_queue: asyncio.Queue["ChunkProcessingResult[TMessage] | None"] | None = None, + shutdown_event: asyncio.Event | None = None, ) -> None: """Read input messages and enqueue chunk work items. @@ -89,6 +90,8 @@ async def _producer_task[TMessage: IMessage]( message_id = producer_state.next_message_id if message_id >= stop_state.stop_at_message_id: break + if shutdown_event is not None and shutdown_event.is_set(): + break chunk_count = len(message.text_chunks) if chunk_count == 0: @@ -536,6 +539,7 @@ async def add_messages_streaming[TMessage: IMessage]( batch_size: int = 100, on_batch_committed: Callable[[AddMessagesResult], None] | None = None, skip_failed_messages: bool = False, + shutdown_event: asyncio.Event | None = None, ) -> AddMessagesResult: """Ingest messages through a producer/dispatcher/reassembler pipeline. @@ -611,7 +615,12 @@ async def _commit_batch( async with asyncio.TaskGroup() as tg: tg.create_task( _producer_task( - messages, chunk_queue, stop_state, producer_state, result_queue + messages, + chunk_queue, + stop_state, + producer_state, + result_queue, + shutdown_event=shutdown_event, ) ) tg.create_task( diff --git a/src/typeagent/knowpro/conversation_base.py b/src/typeagent/knowpro/conversation_base.py index 6d67bf88..d7632dfb 100644 --- a/src/typeagent/knowpro/conversation_base.py +++ b/src/typeagent/knowpro/conversation_base.py @@ -3,6 +3,7 @@ """Base class for conversations with incremental indexing support.""" +import asyncio from collections.abc import AsyncIterable, Callable, Sequence from dataclasses import dataclass from datetime import datetime, timezone @@ -218,6 +219,7 @@ async def add_messages_streaming( batch_size: int = 100, on_batch_committed: Callable[[AddMessagesResult], None] | None = None, skip_failed_messages: bool = False, + shutdown_event: asyncio.Event | None = None, ) -> AddMessagesResult: """Delegate to the pipelined add_messages implementation.""" from . import add_messages @@ -228,6 +230,7 @@ async def add_messages_streaming( batch_size=batch_size, on_batch_committed=on_batch_committed, skip_failed_messages=skip_failed_messages, + shutdown_event=shutdown_event, ) async def _commit_batch_from_chunk_results( diff --git a/tools/ingest_email.py b/tools/ingest_email.py index 3e6c27f2..f2eb4e05 100644 --- a/tools/ingest_email.py +++ b/tools/ingest_email.py @@ -22,6 +22,7 @@ from collections.abc import AsyncIterator from datetime import datetime from pathlib import Path +import signal import sys import time from typing import Iterable @@ -273,7 +274,9 @@ def _iter_emails( yield str(email_file.resolve()), email_file, label -def _print_email_verbose(email: EmailMessage, original_chunk_count: int | None = None) -> None: +def _print_email_verbose( + email: EmailMessage, original_chunk_count: int | None = None +) -> None: """Print verbose details for an email.""" print(f" From: {decode_encoded_words(email.metadata.sender)}") if email.metadata.recipients: @@ -289,8 +292,12 @@ def _print_email_verbose(email: EmailMessage, original_chunk_count: int | None = f" Subject: {decode_encoded_words(email.metadata.subject).replace('\n', '\\n')}" ) print(f" Date: {email.timestamp}") - if original_chunk_count is not None and original_chunk_count != len(email.text_chunks): - print(f" Body chunks: {len(email.text_chunks)} (clipped from {original_chunk_count})") + if original_chunk_count is not None and original_chunk_count != len( + email.text_chunks + ): + print( + f" Body chunks: {len(email.text_chunks)} (clipped from {original_chunk_count})" + ) else: print(f" Body chunks: {len(email.text_chunks)}") MAIL_PREVIEW_LEN = 80 @@ -473,15 +480,41 @@ def on_batch_committed(result: AddMessagesResult) -> None: result: AddMessagesResult | None = None interrupted = False + shutdown_event = asyncio.Event() + loop = asyncio.get_running_loop() + _sigint_count = 0 + _main_task = asyncio.current_task() + + def _on_sigint() -> None: + nonlocal _sigint_count + _sigint_count += 1 + if _sigint_count == 1: + print( + "\nInterrupt received; stopping after current batch completes " + "(press ^C again to force quit)...", + flush=True, + ) + shutdown_event.set() + else: + print("\nForce quit.", flush=True) + loop.remove_signal_handler(signal.SIGINT) + if _main_task is not None: + _main_task.cancel() + + loop.add_signal_handler(signal.SIGINT, _on_sigint) try: result = await email_memory.add_messages_streaming( message_stream, batch_size=batch_size, on_batch_committed=on_batch_committed, skip_failed_messages=True, + shutdown_event=shutdown_event, ) + interrupted = shutdown_event.is_set() except (KeyboardInterrupt, asyncio.CancelledError): interrupted = True + finally: + loop.remove_signal_handler(signal.SIGINT) # Final summary elapsed = time.time() - start_time @@ -543,20 +576,28 @@ def main() -> None: start_date = _parse_date(args.start_date) if args.start_date else None stop_date = _parse_date(args.stop_date) if args.stop_date else None - asyncio.run( - ingest_emails( - eml_paths=args.paths, - database=args.database, - verbose=args.verbose, - start_date=start_date, - stop_date=stop_date, - offset=args.offset, - limit=args.limit, - concurrency=args.concurrency, - batch_size=args.batch_size, - max_chunks=args.max_chunks, + try: + asyncio.run( + ingest_emails( + eml_paths=args.paths, + database=args.database, + verbose=args.verbose, + start_date=start_date, + stop_date=stop_date, + offset=args.offset, + limit=args.limit, + concurrency=args.concurrency, + batch_size=args.batch_size, + max_chunks=args.max_chunks, + ) ) - ) + except KeyboardInterrupt: + pass + except Exception as exc: + if args.verbose: + raise + print(f"Error: {exc}", file=sys.stderr) + sys.exit(1) if __name__ == "__main__": From 766b655d44bc5e2053cda3f107983bba9f4e2e26 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Wed, 13 May 2026 16:00:04 -0700 Subject: [PATCH 26/42] Make second ^C a hard exit --- AGENTS.md | 1 + tools/ingest_email.py | 7 ++++--- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/AGENTS.md b/AGENTS.md index b292be23..70822bad 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -36,6 +36,7 @@ In all cases show what you added to AGENTS.md. - Don't use '!' on the command line, it's some bash magic (even inside single quotes) - When running 'make' commands, do not use the venv (the Makefile uses 'uv run') +- If a term can refer to OS behavior or repository code behavior (for example, 'force quit'), prefer the in-repo meaning first and verify by searching the code. - To get API keys in ad-hoc code, call `load_dotenv()` - Use `pytest test` to run tests in test/ - Use `pyright` to check type annotations in src/, tools/, tests/, examples/ diff --git a/tools/ingest_email.py b/tools/ingest_email.py index f2eb4e05..0baa407f 100644 --- a/tools/ingest_email.py +++ b/tools/ingest_email.py @@ -21,6 +21,7 @@ import asyncio from collections.abc import AsyncIterator from datetime import datetime +import os from pathlib import Path import signal import sys @@ -497,9 +498,9 @@ def _on_sigint() -> None: shutdown_event.set() else: print("\nForce quit.", flush=True) - loop.remove_signal_handler(signal.SIGINT) - if _main_task is not None: - _main_task.cancel() + # Bypass cooperative cancellation on repeated Ctrl+C. + # Some pending async operations may not respond promptly and can hang. + os._exit(130) loop.add_signal_handler(signal.SIGINT, _on_sigint) try: From 35bb11ab03adf1b9841dacfe28e452b270708386 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Thu, 14 May 2026 10:16:42 -0700 Subject: [PATCH 27/42] No optional args for add_messages.py helper. Formatted test_add_messages_pipeline.py --- src/typeagent/knowpro/add_messages.py | 22 ++-- tests/test_add_messages_pipeline.py | 144 +++++++++++--------------- 2 files changed, 71 insertions(+), 95 deletions(-) diff --git a/src/typeagent/knowpro/add_messages.py b/src/typeagent/knowpro/add_messages.py index 402c6310..a1a5f683 100644 --- a/src/typeagent/knowpro/add_messages.py +++ b/src/typeagent/knowpro/add_messages.py @@ -76,8 +76,8 @@ async def _producer_task[TMessage: IMessage]( chunk_queue: asyncio.Queue[ChunkWorkItem[TMessage] | None], stop_state: PipelineStopState, producer_state: ProducerState, - result_queue: asyncio.Queue["ChunkProcessingResult[TMessage] | None"] | None = None, - shutdown_event: asyncio.Event | None = None, + result_queue: asyncio.Queue["ChunkProcessingResult[TMessage] | None"], + shutdown_event: asyncio.Event | None, ) -> None: """Read input messages and enqueue chunk work items. @@ -97,14 +97,13 @@ async def _producer_task[TMessage: IMessage]( if chunk_count == 0: # Zero-chunk message: nothing for the dispatcher to process. # Emit a zero-chunk result directly to the reassembler. - if result_queue is not None: - await result_queue.put( - ChunkProcessingResult[TMessage]( - chunk_id=TextLocation(message_id, 0), - chunk_count=0, - message=message, - ) + await result_queue.put( + ChunkProcessingResult[TMessage]( + chunk_id=TextLocation(message_id, 0), + chunk_count=0, + message=message, ) + ) producer_state.produced_messages += 1 producer_state.next_message_id += 1 continue @@ -371,8 +370,8 @@ async def _reassembler_task[TMessage: IMessage]( commit_batch: Callable[ [list[TMessage], list[ChunkProcessingResult[TMessage]]], Awaitable[None] ], - on_batch_committed: Callable[[int, int], None] | None = None, - skip_failed_messages: bool = False, + on_batch_committed: Callable[[int, int], None] | None, + skip_failed_messages: bool, ) -> ReassemblerResult: """Reassemble chunks into messages and commit only consecutive complete ones. @@ -641,6 +640,7 @@ async def _commit_batch( first_uncommitted_ordinal=initial_message_id, target_commit_chunk_count=batch_size, commit_batch=_commit_batch, + on_batch_committed=None, skip_failed_messages=skip_failed_messages, ) ) diff --git a/tests/test_add_messages_pipeline.py b/tests/test_add_messages_pipeline.py index 73261d7d..11c36b09 100644 --- a/tests/test_add_messages_pipeline.py +++ b/tests/test_add_messages_pipeline.py @@ -12,10 +12,7 @@ import typechat -from typeagent.aitools.embeddings import ( - NormalizedEmbedding, - NormalizedEmbeddings, -) +from typeagent.aitools.embeddings import NormalizedEmbedding, NormalizedEmbeddings from typeagent.knowpro import knowledge_schema as kplib from typeagent.knowpro.add_messages import ( _collect_related_terms_for_fuzzy_index, @@ -50,8 +47,7 @@ def get_knowledge(self) -> kplib.KnowledgeResponse: class _SequenceExtractor: def __init__( - self, - outputs: list[typechat.Result[kplib.KnowledgeResponse] | Exception], + self, outputs: list[typechat.Result[kplib.KnowledgeResponse] | Exception] ) -> None: self._outputs = outputs self.calls: list[str] = [] @@ -312,7 +308,10 @@ async def _iter_messages() -> AsyncIterator[_Message]: for message in messages: yield message - await _producer_task(_iter_messages(), queue, stop_state, producer_state) + result_queue = asyncio.Queue() + await _producer_task( + _iter_messages(), queue, stop_state, producer_state, result_queue, None + ) items: list[ChunkWorkItem[_Message] | None] = [ await queue.get(), @@ -342,7 +341,10 @@ async def _iter_messages() -> AsyncIterator[_Message]: yield _Message(["a"]) yield _Message(["b"]) - await _producer_task(_iter_messages(), queue, stop_state, producer_state) + result_queue = asyncio.Queue() + await _producer_task( + _iter_messages(), queue, stop_state, producer_state, result_queue, None + ) first = await queue.get() sentinel = await queue.get() @@ -359,12 +361,12 @@ async def test_producer_sets_exception_and_still_sends_sentinel() -> None: stop_state = PipelineStopState() producer_state = ProducerState(next_message_id=0) - failing_iter = _FailingAsyncMessages( - [_Message(["a"])], - RuntimeError("input boom"), - ) + failing_iter = _FailingAsyncMessages([_Message(["a"])], RuntimeError("input boom")) - await _producer_task(failing_iter, queue, stop_state, producer_state) + result_queue = asyncio.Queue() + await _producer_task( + failing_iter, queue, stop_state, producer_state, result_queue, None + ) first = await queue.get() sentinel = await queue.get() @@ -388,7 +390,10 @@ async def test_producer_breaks_inside_chunk_loop_when_stop_marker_changes() -> N async def _iter_messages() -> AsyncIterator[_Message]: yield message - await _producer_task(_iter_messages(), queue, stop_state, producer_state) + result_queue = asyncio.Queue() + await _producer_task( + _iter_messages(), queue, stop_state, producer_state, result_queue, None + ) first = await queue.get() sentinel = await queue.get() @@ -402,9 +407,7 @@ async def _iter_messages() -> AsyncIterator[_Message]: @pytest.mark.asyncio async def test_dispatcher_stops_on_sentinel_and_emits_result_sentinel() -> None: chunk_queue: asyncio.Queue[ChunkWorkItem[_Message] | None] = asyncio.Queue() - result_queue: asyncio.Queue[ChunkProcessingResult[_Message] | None] = ( - asyncio.Queue() - ) + result_queue = asyncio.Queue() stop_state = PipelineStopState() message = _Message(["hello"]) @@ -442,27 +445,19 @@ async def test_dispatcher_stops_on_sentinel_and_emits_result_sentinel() -> None: async def test_dispatcher_extraction_failure_lowers_stop() -> None: """A Failure from the extractor sets error and lowers stop_at_message_id.""" chunk_queue: asyncio.Queue[ChunkWorkItem[_Message] | None] = asyncio.Queue() - result_queue: asyncio.Queue[ChunkProcessingResult[_Message] | None] = ( - asyncio.Queue() - ) + result_queue = asyncio.Queue() stop_state = PipelineStopState() m0 = _Message(["first"]) m1 = _Message(["second"]) await chunk_queue.put( ChunkWorkItem( - chunk_id=TextLocation(0, 0), - chunk_count=1, - chunk_text="first", - message=m0, + chunk_id=TextLocation(0, 0), chunk_count=1, chunk_text="first", message=m0 ) ) await chunk_queue.put( ChunkWorkItem( - chunk_id=TextLocation(1, 0), - chunk_count=1, - chunk_text="second", - message=m1, + chunk_id=TextLocation(1, 0), chunk_count=1, chunk_text="second", message=m1 ) ) await chunk_queue.put(None) @@ -501,27 +496,19 @@ async def test_dispatcher_extraction_failure_lowers_stop() -> None: @pytest.mark.asyncio async def test_dispatcher_extraction_failure_skips_and_keeps_processing() -> None: chunk_queue: asyncio.Queue[ChunkWorkItem[_Message] | None] = asyncio.Queue() - result_queue: asyncio.Queue[ChunkProcessingResult[_Message] | None] = ( - asyncio.Queue() - ) + result_queue = asyncio.Queue() stop_state = PipelineStopState() m0 = _Message(["first"]) m1 = _Message(["second"]) await chunk_queue.put( ChunkWorkItem( - chunk_id=TextLocation(0, 0), - chunk_count=1, - chunk_text="first", - message=m0, + chunk_id=TextLocation(0, 0), chunk_count=1, chunk_text="first", message=m0 ) ) await chunk_queue.put( ChunkWorkItem( - chunk_id=TextLocation(1, 0), - chunk_count=1, - chunk_text="second", - message=m1, + chunk_id=TextLocation(1, 0), chunk_count=1, chunk_text="second", message=m1 ) ) await chunk_queue.put(None) @@ -574,9 +561,7 @@ def _chunk_result( @pytest.mark.asyncio async def test_reassembler_commits_out_of_order_after_gap_is_filled() -> None: - result_queue: asyncio.Queue[ChunkProcessingResult[_Message] | None] = ( - asyncio.Queue() - ) + result_queue = asyncio.Queue() stop_state = PipelineStopState() m0 = _Message(["m0"]) @@ -588,8 +573,7 @@ async def test_reassembler_commits_out_of_order_after_gap_is_filled() -> None: committed_batches: list[tuple[int, int]] = [] async def _commit( - messages: list[_Message], - results: list[ChunkProcessingResult[_Message]], + messages: list[_Message], results: list[ChunkProcessingResult[_Message]] ) -> None: committed_batches.append((len(messages), len(results))) @@ -599,6 +583,8 @@ async def _commit( first_uncommitted_ordinal=0, target_commit_chunk_count=10, commit_batch=_commit, + on_batch_committed=None, + skip_failed_messages=False, ) assert committed_batches == [(2, 2)] @@ -610,9 +596,7 @@ async def _commit( @pytest.mark.asyncio async def test_reassembler_marks_failure_and_blocks_later_commits() -> None: - result_queue: asyncio.Queue[ChunkProcessingResult[_Message] | None] = ( - asyncio.Queue() - ) + result_queue = asyncio.Queue() stop_state = PipelineStopState() m0 = _Message(["m0"]) @@ -624,8 +608,7 @@ async def test_reassembler_marks_failure_and_blocks_later_commits() -> None: commit_calls = 0 async def _commit( - messages: list[_Message], - results: list[ChunkProcessingResult[_Message]], + messages: list[_Message], results: list[ChunkProcessingResult[_Message]] ) -> None: nonlocal commit_calls commit_calls += 1 @@ -636,6 +619,8 @@ async def _commit( first_uncommitted_ordinal=0, target_commit_chunk_count=1, commit_batch=_commit, + on_batch_committed=None, + skip_failed_messages=False, ) assert commit_calls == 0 @@ -647,9 +632,7 @@ async def _commit( @pytest.mark.asyncio async def test_reassembler_skips_failed_message_and_commits_later_messages() -> None: - result_queue: asyncio.Queue[ChunkProcessingResult[_Message] | None] = ( - asyncio.Queue() - ) + result_queue = asyncio.Queue() stop_state = PipelineStopState() m0 = _Message(["m0"]) @@ -661,8 +644,7 @@ async def test_reassembler_skips_failed_message_and_commits_later_messages() -> committed_batches: list[tuple[int, int]] = [] async def _commit( - messages: list[_Message], - results: list[ChunkProcessingResult[_Message]], + messages: list[_Message], results: list[ChunkProcessingResult[_Message]] ) -> None: committed_batches.append((len(messages), len(results))) @@ -672,6 +654,7 @@ async def _commit( first_uncommitted_ordinal=0, target_commit_chunk_count=1, commit_batch=_commit, + on_batch_committed=None, skip_failed_messages=True, ) @@ -687,9 +670,7 @@ async def _commit( @pytest.mark.asyncio async def test_reassembler_force_commits_small_staged_tail() -> None: - result_queue: asyncio.Queue[ChunkProcessingResult[_Message] | None] = ( - asyncio.Queue() - ) + result_queue = asyncio.Queue() stop_state = PipelineStopState() message = _Message(["m0"]) @@ -699,8 +680,7 @@ async def test_reassembler_force_commits_small_staged_tail() -> None: commit_calls = 0 async def _commit( - messages: list[_Message], - results: list[ChunkProcessingResult[_Message]], + messages: list[_Message], results: list[ChunkProcessingResult[_Message]] ) -> None: nonlocal commit_calls commit_calls += 1 @@ -711,6 +691,8 @@ async def _commit( first_uncommitted_ordinal=0, target_commit_chunk_count=99, commit_batch=_commit, + on_batch_committed=None, + skip_failed_messages=False, ) assert commit_calls == 1 @@ -722,9 +704,7 @@ async def _commit( async def test_reassembler_raises_on_invalid_chunk_ordinal_and_sets_stop_marker() -> ( None ): - result_queue: asyncio.Queue[ChunkProcessingResult[_Message] | None] = ( - asyncio.Queue() - ) + result_queue = asyncio.Queue() stop_state = PipelineStopState() message = _Message(["m0", "m0b"]) @@ -732,8 +712,7 @@ async def test_reassembler_raises_on_invalid_chunk_ordinal_and_sets_stop_marker( await result_queue.put(None) async def _commit( - messages: list[_Message], - results: list[ChunkProcessingResult[_Message]], + messages: list[_Message], results: list[ChunkProcessingResult[_Message]] ) -> None: return None @@ -744,6 +723,8 @@ async def _commit( first_uncommitted_ordinal=0, target_commit_chunk_count=1, commit_batch=_commit, + on_batch_committed=None, + skip_failed_messages=False, ) assert stop_state.stop_at_message_id == 3 @@ -751,9 +732,7 @@ async def _commit( @pytest.mark.asyncio async def test_reassembler_raises_on_duplicate_chunk_and_sets_stop_marker() -> None: - result_queue: asyncio.Queue[ChunkProcessingResult[_Message] | None] = ( - asyncio.Queue() - ) + result_queue = asyncio.Queue() stop_state = PipelineStopState() message = _Message(["m1-a", "m1-b"]) @@ -762,8 +741,7 @@ async def test_reassembler_raises_on_duplicate_chunk_and_sets_stop_marker() -> N await result_queue.put(None) async def _commit( - messages: list[_Message], - results: list[ChunkProcessingResult[_Message]], + messages: list[_Message], results: list[ChunkProcessingResult[_Message]] ) -> None: return None @@ -774,6 +752,8 @@ async def _commit( first_uncommitted_ordinal=0, target_commit_chunk_count=1, commit_batch=_commit, + on_batch_committed=None, + skip_failed_messages=False, ) assert stop_state.stop_at_message_id == 5 @@ -781,9 +761,7 @@ async def _commit( @pytest.mark.asyncio async def test_reassembler_on_batch_committed_callback_is_invoked() -> None: - result_queue: asyncio.Queue[ChunkProcessingResult[_Message] | None] = ( - asyncio.Queue() - ) + result_queue = asyncio.Queue() stop_state = PipelineStopState() message = _Message(["m0"]) @@ -793,8 +771,7 @@ async def test_reassembler_on_batch_committed_callback_is_invoked() -> None: callback_calls: list[tuple[int, int]] = [] async def _commit( - messages: list[_Message], - results: list[ChunkProcessingResult[_Message]], + messages: list[_Message], results: list[ChunkProcessingResult[_Message]] ) -> None: return None @@ -807,6 +784,7 @@ async def _commit( on_batch_committed=lambda msg_count, chunk_count: callback_calls.append( (msg_count, chunk_count) ), + skip_failed_messages=False, ) assert callback_calls == [(1, 1)] @@ -816,9 +794,7 @@ async def _commit( async def test_reassembler_raises_on_mismatched_chunk_count_and_sets_stop_marker() -> ( None ): - result_queue: asyncio.Queue[ChunkProcessingResult[_Message] | None] = ( - asyncio.Queue() - ) + result_queue = asyncio.Queue() stop_state = PipelineStopState() message = _Message(["m0-a", "m0-b", "m0-c"]) @@ -827,8 +803,7 @@ async def test_reassembler_raises_on_mismatched_chunk_count_and_sets_stop_marker await result_queue.put(None) async def _commit( - messages: list[_Message], - results: list[ChunkProcessingResult[_Message]], + messages: list[_Message], results: list[ChunkProcessingResult[_Message]] ) -> None: return None @@ -839,6 +814,8 @@ async def _commit( first_uncommitted_ordinal=0, target_commit_chunk_count=1, commit_batch=_commit, + on_batch_committed=None, + skip_failed_messages=False, ) assert stop_state.stop_at_message_id == 4 @@ -846,9 +823,7 @@ async def _commit( @pytest.mark.asyncio async def test_reassembler_handles_existing_assembly_non_duplicate_chunk() -> None: - result_queue: asyncio.Queue[ChunkProcessingResult[_Message] | None] = ( - asyncio.Queue() - ) + result_queue = asyncio.Queue() stop_state = PipelineStopState() message = _Message(["m0-a", "m0-b"]) @@ -859,8 +834,7 @@ async def test_reassembler_handles_existing_assembly_non_duplicate_chunk() -> No commit_calls = 0 async def _commit( - messages: list[_Message], - results: list[ChunkProcessingResult[_Message]], + messages: list[_Message], results: list[ChunkProcessingResult[_Message]] ) -> None: nonlocal commit_calls commit_calls += 1 @@ -871,6 +845,8 @@ async def _commit( first_uncommitted_ordinal=0, target_commit_chunk_count=1, commit_batch=_commit, + on_batch_committed=None, + skip_failed_messages=False, ) assert commit_calls == 1 From 90f8cf5fddc2f8503a98aa555008bc546608941c Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Thu, 14 May 2026 10:59:19 -0700 Subject: [PATCH 28/42] Ensure pre-computed chunk embeddings are used and not recomputed --- src/typeagent/knowpro/conversation_base.py | 121 +++++++++++--------- src/typeagent/knowpro/interfaces_storage.py | 11 ++ src/typeagent/storage/memory/collections.py | 21 +++- src/typeagent/storage/sqlite/collections.py | 19 ++- 4 files changed, 109 insertions(+), 63 deletions(-) diff --git a/src/typeagent/knowpro/conversation_base.py b/src/typeagent/knowpro/conversation_base.py index d7632dfb..1c937bfe 100644 --- a/src/typeagent/knowpro/conversation_base.py +++ b/src/typeagent/knowpro/conversation_base.py @@ -243,14 +243,78 @@ async def _commit_batch_from_chunk_results( if not messages_batch: return AddMessagesResult() + # Process chunk results first to collect embeddings and knowledge items + knowledge_items: list[tuple[MessageOrdinal, int, kplib.KnowledgeResponse]] = [] + fuzzy_terms: list[str] = [] + fuzzy_term_embeddings: list[NormalizedEmbedding] = [] + chunk_embedding_map: dict[tuple[int, int], NormalizedEmbedding] = {} + + for result in chunk_results: + if result.chunk_count == 0: + continue + + if result.chunk_embedding is None: + raise ValueError( + "Chunk result missing chunk embedding for " + f"message={result.chunk_id.message_ordinal}, " + f"chunk={result.chunk_id.chunk_ordinal}" + ) + + if result.extracted_knowledge is None: + raise ValueError( + "Chunk result missing extracted knowledge for " + f"message={result.chunk_id.message_ordinal}, " + f"chunk={result.chunk_id.chunk_ordinal}" + ) + knowledge_items.append( + ( + result.chunk_id.message_ordinal, + result.chunk_id.chunk_ordinal, + result.extracted_knowledge, + ) + ) + + if result.related_terms is None or result.related_term_embeddings is None: + raise ValueError( + "Chunk result missing related-term embeddings for " + f"message={result.chunk_id.message_ordinal}, " + f"chunk={result.chunk_id.chunk_ordinal}" + ) + if len(result.related_terms) != len(result.related_term_embeddings): + raise ValueError( + "related_terms and related_term_embeddings length mismatch for " + f"message={result.chunk_id.message_ordinal}, " + f"chunk={result.chunk_id.chunk_ordinal}: " + f"{len(result.related_terms)} != " + f"{len(result.related_term_embeddings)}" + ) + fuzzy_terms.extend(result.related_terms) + fuzzy_term_embeddings.extend(result.related_term_embeddings) + # Store embedding for later retrieval in correct message/chunk order + chunk_embedding_map[ + (result.chunk_id.message_ordinal, result.chunk_id.chunk_ordinal) + ] = result.chunk_embedding + async with storage: start_points = IndexingStartPoints( message_count=await self.messages.size(), semref_count=await self.semantic_refs.size(), ) - await self.messages.extend(messages_batch) - + # Build chunk_embeddings in the correct order (matching message/chunk iteration) + chunk_embeddings: list[NormalizedEmbedding] = [] + for msg_ord, message in enumerate( + messages_batch, start_points.message_count + ): + for chunk_ord in range(len(message.text_chunks)): + embedding = chunk_embedding_map.get((msg_ord, chunk_ord)) + if embedding is not None: + chunk_embeddings.append(embedding) + + # Use precomputed embeddings to avoid redundant embedding work + await self.messages.extend( + messages_batch, chunk_embeddings=chunk_embeddings + ) source_ids = [ m.source_id for m in messages_batch if m.source_id is not None ] @@ -259,57 +323,6 @@ async def _commit_batch_from_chunk_results( await self._add_metadata_knowledge_incremental(start_points.message_count) - knowledge_items: list[ - tuple[MessageOrdinal, int, kplib.KnowledgeResponse] - ] = [] - fuzzy_terms: list[str] = [] - fuzzy_term_embeddings: list[NormalizedEmbedding] = [] - - for result in chunk_results: - if result.chunk_count == 0: - continue - - if result.chunk_embedding is None: - raise ValueError( - "Chunk result missing chunk embedding for " - f"message={result.chunk_id.message_ordinal}, " - f"chunk={result.chunk_id.chunk_ordinal}" - ) - - if result.extracted_knowledge is None: - raise ValueError( - "Chunk result missing extracted knowledge for " - f"message={result.chunk_id.message_ordinal}, " - f"chunk={result.chunk_id.chunk_ordinal}" - ) - knowledge_items.append( - ( - result.chunk_id.message_ordinal, - result.chunk_id.chunk_ordinal, - result.extracted_knowledge, - ) - ) - - if ( - result.related_terms is None - or result.related_term_embeddings is None - ): - raise ValueError( - "Chunk result missing related-term embeddings for " - f"message={result.chunk_id.message_ordinal}, " - f"chunk={result.chunk_id.chunk_ordinal}" - ) - if len(result.related_terms) != len(result.related_term_embeddings): - raise ValueError( - "related_terms and related_term_embeddings length mismatch for " - f"message={result.chunk_id.message_ordinal}, " - f"chunk={result.chunk_id.chunk_ordinal}: " - f"{len(result.related_terms)} != " - f"{len(result.related_term_embeddings)}" - ) - fuzzy_terms.extend(result.related_terms) - fuzzy_term_embeddings.extend(result.related_term_embeddings) - await semrefindex.add_knowledge_batch_to_semantic_ref_index( self, knowledge_items, @@ -320,6 +333,7 @@ async def _commit_batch_from_chunk_results( messages_batch, fuzzy_terms, fuzzy_term_embeddings, + chunk_embeddings, ) await storage.update_conversation_timestamps( @@ -343,6 +357,7 @@ async def _update_secondary_indexes_incremental_with_embeddings( new_messages: list[TMessage], related_terms: list[str], related_term_embeddings: list[NormalizedEmbedding], + chunk_embeddings: list[NormalizedEmbedding], ) -> None: """Update secondary indexes using precomputed embeddings when available.""" if self.secondary_indexes is None: diff --git a/src/typeagent/knowpro/interfaces_storage.py b/src/typeagent/knowpro/interfaces_storage.py index 9f17574d..de74bfe9 100644 --- a/src/typeagent/knowpro/interfaces_storage.py +++ b/src/typeagent/knowpro/interfaces_storage.py @@ -112,6 +112,17 @@ class IMessageCollection[TMessage: IMessage]( ): """A collection of Messages.""" + async def extend( + self, items: Iterable[TMessage], chunk_embeddings: list[Any] | None = None + ) -> None: + """Append multiple items to the collection. + + Args: + items: Messages to append. + chunk_embeddings: Optional precomputed embeddings for text chunks. + """ + ... + class ISemanticRefCollection(ICollection[SemanticRef, SemanticRefOrdinal], Protocol): """A collection of SemanticRefs.""" diff --git a/src/typeagent/storage/memory/collections.py b/src/typeagent/storage/memory/collections.py index f7279037..54d3de14 100644 --- a/src/typeagent/storage/memory/collections.py +++ b/src/typeagent/storage/memory/collections.py @@ -3,7 +3,7 @@ """Memory-based collection implementations.""" -from typing import Iterable +from typing import Any, Iterable from ...knowpro.interfaces import ( ICollection, @@ -97,14 +97,23 @@ async def append(self, item: TMessage) -> None: if self.message_text_index is not None: await self.message_text_index.add_messages_starting_at(msg_id, [item]) - async def extend(self, items: Iterable[TMessage]) -> None: + async def extend( + self, items: Iterable[TMessage], chunk_embeddings: list[Any] | None = None + ) -> None: items_list = list(items) if not items_list: return current_size = await self.size() self.items.extend(items_list) if self.message_text_index is not None: - await self.message_text_index.add_messages_starting_at( - current_size, - items_list, - ) + if chunk_embeddings is not None: + await self.message_text_index.add_messages_starting_at_with_embeddings( + current_size, + items_list, + chunk_embeddings, + ) + else: + await self.message_text_index.add_messages_starting_at( + current_size, + items_list, + ) diff --git a/src/typeagent/storage/sqlite/collections.py b/src/typeagent/storage/sqlite/collections.py index fe394dcb..6057e282 100644 --- a/src/typeagent/storage/sqlite/collections.py +++ b/src/typeagent/storage/sqlite/collections.py @@ -186,7 +186,11 @@ async def append(self, item: TMessage) -> None: if self.message_text_index is not None: await self.message_text_index.add_messages_starting_at(msg_id, [item]) - async def extend(self, items: typing.Iterable[TMessage]) -> None: + async def extend( + self, + items: typing.Iterable[TMessage], + chunk_embeddings: list[typing.Any] | None = None, + ) -> None: items_list = list(items) # Convert to list to iterate twice if not items_list: return @@ -230,9 +234,16 @@ async def extend(self, items: typing.Iterable[TMessage]) -> None: # Also add to message text index if available if self.message_text_index is not None: - await self.message_text_index.add_messages_starting_at( - current_size, items_list - ) + if chunk_embeddings is not None: + # Use precomputed embeddings (avoids redundant embedding work) + await self.message_text_index.add_messages_starting_at_with_embeddings( + current_size, items_list, chunk_embeddings + ) + else: + # Compute embeddings on the fly + await self.message_text_index.add_messages_starting_at( + current_size, items_list + ) class SqliteSemanticRefCollection(interfaces.ISemanticRefCollection): From db13cb052c26518d29d24d686369abd24c787789 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Fri, 15 May 2026 12:13:50 -0700 Subject: [PATCH 29/42] Remove dead chunk_embeddings param from _update_secondary_indexes_incremental_with_embeddings Chunk embeddings are already consumed by messages.extend(); passing them through to this function served no purpose. Co-Authored-By: Claude Sonnet 4.6 --- src/typeagent/knowpro/conversation_base.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/typeagent/knowpro/conversation_base.py b/src/typeagent/knowpro/conversation_base.py index 1c937bfe..944287b6 100644 --- a/src/typeagent/knowpro/conversation_base.py +++ b/src/typeagent/knowpro/conversation_base.py @@ -333,7 +333,6 @@ async def _commit_batch_from_chunk_results( messages_batch, fuzzy_terms, fuzzy_term_embeddings, - chunk_embeddings, ) await storage.update_conversation_timestamps( @@ -357,7 +356,6 @@ async def _update_secondary_indexes_incremental_with_embeddings( new_messages: list[TMessage], related_terms: list[str], related_term_embeddings: list[NormalizedEmbedding], - chunk_embeddings: list[NormalizedEmbedding], ) -> None: """Update secondary indexes using precomputed embeddings when available.""" if self.secondary_indexes is None: From efe1ed2c8633fc8a28576d92e1562f9a773769a4 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Fri, 15 May 2026 12:14:32 -0700 Subject: [PATCH 30/42] Replace type: ignore with assert for extracted_knowledge narrowing After the success-check early return, extracted_knowledge is guaranteed non-None; an assert communicates this to pyright and catches programmer errors at runtime. Co-Authored-By: Claude Sonnet 4.6 --- src/typeagent/knowpro/add_messages.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/typeagent/knowpro/add_messages.py b/src/typeagent/knowpro/add_messages.py index a1a5f683..9a7276f4 100644 --- a/src/typeagent/knowpro/add_messages.py +++ b/src/typeagent/knowpro/add_messages.py @@ -316,8 +316,9 @@ async def process_chunk_with_extraction_and_embeddings[TMessage: IMessage]( result.error = e return result + assert result.extracted_knowledge is not None result.related_terms = _collect_related_terms_for_fuzzy_index( - result.extracted_knowledge # type: ignore[arg-type] + result.extracted_knowledge ) # Step 2: Generate embeddings (only if extraction succeeded) From cffb4c3876b13445fe3b723436bd69dfbffdf41f Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Fri, 15 May 2026 12:18:27 -0700 Subject: [PATCH 31/42] Tighten chunk_embeddings type from list[Any] to list[NormalizedEmbedding] Both the IMessageCollection protocol and the MemoryMessageCollection concrete class now use the specific type. Co-Authored-By: Claude Sonnet 4.6 --- src/typeagent/knowpro/interfaces_storage.py | 5 ++++- src/typeagent/storage/memory/collections.py | 7 +++++-- 2 files changed, 9 insertions(+), 3 deletions(-) diff --git a/src/typeagent/knowpro/interfaces_storage.py b/src/typeagent/knowpro/interfaces_storage.py index de74bfe9..35e5ace1 100644 --- a/src/typeagent/knowpro/interfaces_storage.py +++ b/src/typeagent/knowpro/interfaces_storage.py @@ -10,6 +10,7 @@ from pydantic.dataclasses import dataclass +from ..aitools.embeddings import NormalizedEmbedding from .interfaces_core import ( IMessage, ITermToSemanticRefIndex, @@ -113,7 +114,9 @@ class IMessageCollection[TMessage: IMessage]( """A collection of Messages.""" async def extend( - self, items: Iterable[TMessage], chunk_embeddings: list[Any] | None = None + self, + items: Iterable[TMessage], + chunk_embeddings: list[NormalizedEmbedding] | None = None, ) -> None: """Append multiple items to the collection. diff --git a/src/typeagent/storage/memory/collections.py b/src/typeagent/storage/memory/collections.py index 54d3de14..dc7bb7b8 100644 --- a/src/typeagent/storage/memory/collections.py +++ b/src/typeagent/storage/memory/collections.py @@ -3,8 +3,9 @@ """Memory-based collection implementations.""" -from typing import Any, Iterable +from typing import Iterable +from ...aitools.embeddings import NormalizedEmbedding from ...knowpro.interfaces import ( ICollection, IMessage, @@ -98,7 +99,9 @@ async def append(self, item: TMessage) -> None: await self.message_text_index.add_messages_starting_at(msg_id, [item]) async def extend( - self, items: Iterable[TMessage], chunk_embeddings: list[Any] | None = None + self, + items: Iterable[TMessage], + chunk_embeddings: list[NormalizedEmbedding] | None = None, ) -> None: items_list = list(items) if not items_list: From b2fa6ce2caa1910b53f1ca5e42a06535fd2df381 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Fri, 15 May 2026 12:33:20 -0700 Subject: [PATCH 32/42] Avoid wasted embedding work when deserializing message collections Add index_messages=False flag to IMessageCollection.extend() so callers can skip message text index population. Deserialization uses it because message_index.deserialize() will replace the index from the sidecar file anyway. Also change generate_embeddings(cache=True) to cache=False since indexing embeddings should not be cached. Co-Authored-By: Claude Sonnet 4.6 --- src/typeagent/knowpro/interfaces_storage.py | 2 ++ src/typeagent/podcasts/podcast.py | 2 +- src/typeagent/storage/memory/collections.py | 3 ++- src/typeagent/storage/memory/messageindex.py | 4 ++-- src/typeagent/storage/sqlite/collections.py | 3 ++- src/typeagent/transcripts/transcript.py | 2 +- 6 files changed, 10 insertions(+), 6 deletions(-) diff --git a/src/typeagent/knowpro/interfaces_storage.py b/src/typeagent/knowpro/interfaces_storage.py index 35e5ace1..ea696ce4 100644 --- a/src/typeagent/knowpro/interfaces_storage.py +++ b/src/typeagent/knowpro/interfaces_storage.py @@ -117,12 +117,14 @@ async def extend( self, items: Iterable[TMessage], chunk_embeddings: list[NormalizedEmbedding] | None = None, + index_messages: bool = True, ) -> None: """Append multiple items to the collection. Args: items: Messages to append. chunk_embeddings: Optional precomputed embeddings for text chunks. + index_messages: If False, skip updating the message text index. """ ... diff --git a/src/typeagent/podcasts/podcast.py b/src/typeagent/podcasts/podcast.py index dc53e305..ac60f3ac 100644 --- a/src/typeagent/podcasts/podcast.py +++ b/src/typeagent/podcasts/podcast.py @@ -80,7 +80,7 @@ async def deserialize( self.name_tag = podcast_data["nameTag"] message_list = [PodcastMessage.deserialize(m) for m in podcast_data["messages"]] - await self.messages.extend(message_list) + await self.messages.extend(message_list, index_messages=False) semantic_refs_data = podcast_data.get("semanticRefs") if semantic_refs_data is not None: diff --git a/src/typeagent/storage/memory/collections.py b/src/typeagent/storage/memory/collections.py index dc7bb7b8..fb0e6b0e 100644 --- a/src/typeagent/storage/memory/collections.py +++ b/src/typeagent/storage/memory/collections.py @@ -102,13 +102,14 @@ async def extend( self, items: Iterable[TMessage], chunk_embeddings: list[NormalizedEmbedding] | None = None, + index_messages: bool = True, ) -> None: items_list = list(items) if not items_list: return current_size = await self.size() self.items.extend(items_list) - if self.message_text_index is not None: + if index_messages and self.message_text_index is not None: if chunk_embeddings is not None: await self.message_text_index.add_messages_starting_at_with_embeddings( current_size, diff --git a/src/typeagent/storage/memory/messageindex.py b/src/typeagent/storage/memory/messageindex.py index f8da5377..b8eca68c 100644 --- a/src/typeagent/storage/memory/messageindex.py +++ b/src/typeagent/storage/memory/messageindex.py @@ -85,7 +85,7 @@ async def add_messages[TMessage: IMessage]( chunk_embeddings = await self.text_location_index.generate_embeddings( chunk_texts, - cache=True, + cache=False, ) await self.add_messages_starting_at_with_embeddings( base_message_ordinal, @@ -105,7 +105,7 @@ async def add_messages_starting_at[TMessage: IMessage]( chunk_embeddings = await self.text_location_index.generate_embeddings( chunk_texts, - cache=True, + cache=False, ) await self.add_messages_starting_at_with_embeddings( start_message_ordinal, diff --git a/src/typeagent/storage/sqlite/collections.py b/src/typeagent/storage/sqlite/collections.py index 6057e282..25b979e8 100644 --- a/src/typeagent/storage/sqlite/collections.py +++ b/src/typeagent/storage/sqlite/collections.py @@ -190,6 +190,7 @@ async def extend( self, items: typing.Iterable[TMessage], chunk_embeddings: list[typing.Any] | None = None, + index_messages: bool = True, ) -> None: items_list = list(items) # Convert to list to iterate twice if not items_list: @@ -233,7 +234,7 @@ async def extend( ) # Also add to message text index if available - if self.message_text_index is not None: + if index_messages and self.message_text_index is not None: if chunk_embeddings is not None: # Use precomputed embeddings (avoids redundant embedding work) await self.message_text_index.add_messages_starting_at_with_embeddings( diff --git a/src/typeagent/transcripts/transcript.py b/src/typeagent/transcripts/transcript.py index 2f889b70..08840102 100644 --- a/src/typeagent/transcripts/transcript.py +++ b/src/typeagent/transcripts/transcript.py @@ -81,7 +81,7 @@ async def deserialize( message_list = [ TranscriptMessage.deserialize(m) for m in transcript_data["messages"] ] - await self.messages.extend(message_list) + await self.messages.extend(message_list, index_messages=False) semantic_refs_data = transcript_data.get("semanticRefs") if semantic_refs_data is not None: From a71f93a28d18093f4e7521971f8dbf5a02db7819 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Fri, 15 May 2026 17:33:19 -0700 Subject: [PATCH 33/42] Fix reassembler staged-state retry hazard on post-commit callback failure --- src/typeagent/knowpro/add_messages.py | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/src/typeagent/knowpro/add_messages.py b/src/typeagent/knowpro/add_messages.py index 9a7276f4..5ccfb54e 100644 --- a/src/typeagent/knowpro/add_messages.py +++ b/src/typeagent/knowpro/add_messages.py @@ -392,21 +392,27 @@ async def _reassembler_task[TMessage: IMessage]( staged_chunks = 0 async def _commit_if_needed(force: bool = False) -> None: - nonlocal staged_chunks + nonlocal staged_chunks, staged_messages, staged_results if not staged_messages: return if not force and staged_chunks < target_commit_chunk_count: return - msg_count = len(staged_messages) + pending_messages = staged_messages + pending_results = staged_results + msg_count = len(pending_messages) chunk_count = staged_chunks - await commit_batch(staged_messages, staged_results) + + # Clear staged state before awaiting commit/callback paths so a post-commit + # exception cannot trigger a duplicate retry during final drain. + staged_messages = [] + staged_results = [] + staged_chunks = 0 + + await commit_batch(pending_messages, pending_results) state.messages_committed += msg_count state.chunks_committed += chunk_count if on_batch_committed is not None: on_batch_committed(msg_count, chunk_count) - staged_messages.clear() - staged_results.clear() - staged_chunks = 0 async def _drain_consecutive_complete(force: bool = False) -> None: nonlocal staged_chunks From 72ceb905481e3c63f1a0dd7bcb334eadd977639e Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Fri, 15 May 2026 17:34:00 -0700 Subject: [PATCH 34/42] Fail fast when staged chunk embeddings are missing --- src/typeagent/knowpro/conversation_base.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/src/typeagent/knowpro/conversation_base.py b/src/typeagent/knowpro/conversation_base.py index 944287b6..ee3b0295 100644 --- a/src/typeagent/knowpro/conversation_base.py +++ b/src/typeagent/knowpro/conversation_base.py @@ -308,8 +308,12 @@ async def _commit_batch_from_chunk_results( ): for chunk_ord in range(len(message.text_chunks)): embedding = chunk_embedding_map.get((msg_ord, chunk_ord)) - if embedding is not None: - chunk_embeddings.append(embedding) + if embedding is None: + raise ValueError( + "Missing chunk embedding for staged message chunk: " + f"message={msg_ord}, chunk={chunk_ord}" + ) + chunk_embeddings.append(embedding) # Use precomputed embeddings to avoid redundant embedding work await self.messages.extend( From 18561c34a69a2e261aa1c0ebc2b8e77eeb5490fb Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Fri, 15 May 2026 17:34:57 -0700 Subject: [PATCH 35/42] Align sqlite message extend embedding typing with protocol --- src/typeagent/storage/sqlite/collections.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/typeagent/storage/sqlite/collections.py b/src/typeagent/storage/sqlite/collections.py index 25b979e8..71a256df 100644 --- a/src/typeagent/storage/sqlite/collections.py +++ b/src/typeagent/storage/sqlite/collections.py @@ -7,6 +7,7 @@ import sqlite3 import typing +from ...aitools.embeddings import NormalizedEmbedding from ...knowpro import interfaces, serialization from .schema import ShreddedMessage, ShreddedSemanticRef @@ -189,7 +190,7 @@ async def append(self, item: TMessage) -> None: async def extend( self, items: typing.Iterable[TMessage], - chunk_embeddings: list[typing.Any] | None = None, + chunk_embeddings: list[NormalizedEmbedding] | None = None, index_messages: bool = True, ) -> None: items_list = list(items) # Convert to list to iterate twice From 216ea03680f86a25c68cc32b404beb3e169dca6b Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Fri, 15 May 2026 17:35:47 -0700 Subject: [PATCH 36/42] Clarify deserialize message-index replacement semantics --- src/typeagent/podcasts/podcast.py | 2 ++ src/typeagent/transcripts/transcript.py | 2 ++ 2 files changed, 4 insertions(+) diff --git a/src/typeagent/podcasts/podcast.py b/src/typeagent/podcasts/podcast.py index ac60f3ac..e0d96ae1 100644 --- a/src/typeagent/podcasts/podcast.py +++ b/src/typeagent/podcasts/podcast.py @@ -80,6 +80,8 @@ async def deserialize( self.name_tag = podcast_data["nameTag"] message_list = [PodcastMessage.deserialize(m) for m in podcast_data["messages"]] + # Message index data is deserialized later and replaces prior state, + # so skip incremental indexing while bulk-loading messages. await self.messages.extend(message_list, index_messages=False) semantic_refs_data = podcast_data.get("semanticRefs") diff --git a/src/typeagent/transcripts/transcript.py b/src/typeagent/transcripts/transcript.py index 08840102..a4733cdf 100644 --- a/src/typeagent/transcripts/transcript.py +++ b/src/typeagent/transcripts/transcript.py @@ -81,6 +81,8 @@ async def deserialize( message_list = [ TranscriptMessage.deserialize(m) for m in transcript_data["messages"] ] + # Message index data is deserialized later and replaces prior state, + # so skip incremental indexing while bulk-loading messages. await self.messages.extend(message_list, index_messages=False) semantic_refs_data = transcript_data.get("semanticRefs") From fa513a5c91e4dcea4e33ce7074919868df411a23 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Fri, 15 May 2026 17:39:32 -0700 Subject: [PATCH 37/42] Use itertools.chain for related-action term collection --- src/typeagent/knowpro/add_messages.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/typeagent/knowpro/add_messages.py b/src/typeagent/knowpro/add_messages.py index 5ccfb54e..2b1d9081 100644 --- a/src/typeagent/knowpro/add_messages.py +++ b/src/typeagent/knowpro/add_messages.py @@ -6,6 +6,7 @@ import asyncio from collections.abc import AsyncIterable, Awaitable, Callable from dataclasses import dataclass +from itertools import chain from typing import TYPE_CHECKING import typechat @@ -257,7 +258,7 @@ def _add_term(term: str) -> None: for term in collect_entity_terms(entity): _add_term(term) - for action in list(knowledge.actions) + list(knowledge.inverse_actions): + for action in chain(knowledge.actions, knowledge.inverse_actions): for term in collect_action_terms(action): _add_term(term) From 55feb8da4bb0eab5b249a3940344e554816e4946 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Fri, 15 May 2026 17:40:22 -0700 Subject: [PATCH 38/42] Replace redundant embedding list-copy comprehensions --- src/typeagent/knowpro/add_messages.py | 2 +- src/typeagent/storage/memory/reltermsindex.py | 4 +--- src/typeagent/storage/sqlite/reltermsindex.py | 2 +- 3 files changed, 3 insertions(+), 5 deletions(-) diff --git a/src/typeagent/knowpro/add_messages.py b/src/typeagent/knowpro/add_messages.py index 2b1d9081..56c17fdf 100644 --- a/src/typeagent/knowpro/add_messages.py +++ b/src/typeagent/knowpro/add_messages.py @@ -327,7 +327,7 @@ async def process_chunk_with_extraction_and_embeddings[TMessage: IMessage]( result.chunk_embedding = await embedding_model.get_embedding_nocache(chunk_text) if result.related_terms: rel_embeddings = await embedding_model.get_embeddings(result.related_terms) - result.related_term_embeddings = [e for e in rel_embeddings] + result.related_term_embeddings = list(rel_embeddings) else: result.related_term_embeddings = [] except Exception as e: diff --git a/src/typeagent/storage/memory/reltermsindex.py b/src/typeagent/storage/memory/reltermsindex.py index 8ad984ef..ec074bf9 100644 --- a/src/typeagent/storage/memory/reltermsindex.py +++ b/src/typeagent/storage/memory/reltermsindex.py @@ -291,9 +291,7 @@ async def add_terms(self, texts: list[str]) -> None: if not texts: return embeddings = await self._vectorbase.get_embeddings(texts) - await self.add_terms_with_embeddings( - texts, [embedding for embedding in embeddings] - ) + await self.add_terms_with_embeddings(texts, list(embeddings)) async def add_terms_with_embeddings( self, diff --git a/src/typeagent/storage/sqlite/reltermsindex.py b/src/typeagent/storage/sqlite/reltermsindex.py index 0b01e585..1a27af71 100644 --- a/src/typeagent/storage/sqlite/reltermsindex.py +++ b/src/typeagent/storage/sqlite/reltermsindex.py @@ -216,7 +216,7 @@ async def add_terms(self, texts: list[str]) -> None: embeddings = await self._vector_base.get_embeddings(new_terms) await self.add_terms_with_embeddings( new_terms, - [embedding for embedding in embeddings], + list(embeddings), ) async def add_terms_with_embeddings( From 88e812a0dfd99d7333b243a67f991b3edfd1f8a2 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Fri, 15 May 2026 21:20:32 -0700 Subject: [PATCH 39/42] Refactor process_chunk_with_extraction_and_embeddings: parallel chunk embedding and extraction+related-embeddings with semaphore --- src/typeagent/knowpro/add_messages.py | 54 ++++++++++++++------------- tests/test_add_messages_pipeline.py | 8 +++- 2 files changed, 34 insertions(+), 28 deletions(-) diff --git a/src/typeagent/knowpro/add_messages.py b/src/typeagent/knowpro/add_messages.py index 56c17fdf..63a67a7b 100644 --- a/src/typeagent/knowpro/add_messages.py +++ b/src/typeagent/knowpro/add_messages.py @@ -279,11 +279,9 @@ async def process_chunk_with_extraction_and_embeddings[TMessage: IMessage]( ) -> ChunkProcessingResult[TMessage]: """Process a single text chunk through knowledge extraction and embeddings. - Runs both knowledge extraction and embedding in a single function call, + Runs extraction/related-term embedding and chunk embedding concurrently, capturing the first failure and stopping processing if an error occurs. - Extraction runs first; if it fails, embedding work is skipped. - Chunk embeddings are computed uncached; related-term embeddings use cache-aware model calls on the same embedding model. @@ -302,38 +300,42 @@ async def process_chunk_with_extraction_and_embeddings[TMessage: IMessage]( result = ChunkProcessingResult( chunk_id=chunk_id, chunk_count=chunk_count, message=message ) + sem = asyncio.Semaphore(1) # Avoid concurrent embedding requests - # Step 1: Extract knowledge - try: + async def _extract_knowledge_and_related_embeddings() -> None: knowledge_result = await knowledge_extractor.extract(chunk_text) - if isinstance(knowledge_result, typechat.Success): - result.extracted_knowledge = knowledge_result.value - else: - result.error = RuntimeError( + if isinstance(knowledge_result, typechat.Failure): + raise RuntimeError( f"Knowledge extraction failed: {knowledge_result.message}" ) - return result - except Exception as e: - result.error = e - return result - - assert result.extracted_knowledge is not None - result.related_terms = _collect_related_terms_for_fuzzy_index( - result.extracted_knowledge - ) + result.extracted_knowledge = knowledge_result.value - # Step 2: Generate embeddings (only if extraction succeeded) - try: - result.chunk_embedding = await embedding_model.get_embedding_nocache(chunk_text) + result.related_terms = _collect_related_terms_for_fuzzy_index( + result.extracted_knowledge + ) if result.related_terms: - rel_embeddings = await embedding_model.get_embeddings(result.related_terms) + async with sem: + rel_embeddings = await embedding_model.get_embeddings( + result.related_terms + ) result.related_term_embeddings = list(rel_embeddings) else: result.related_term_embeddings = [] - except Exception as e: - # Embedding failed; record error and return - result.error = e - return result + + async def _generate_chunk_embedding() -> None: + async with sem: + result.chunk_embedding = await embedding_model.get_embedding_nocache( + chunk_text + ) + + try: + async with asyncio.TaskGroup() as tg: + tg.create_task(_extract_knowledge_and_related_embeddings()) + tg.create_task(_generate_chunk_embedding()) + except Exception as error: + while isinstance(error, ExceptionGroup) and len(error.exceptions) == 1: + error = error.exceptions[0] + result.error = error return result diff --git a/tests/test_add_messages_pipeline.py b/tests/test_add_messages_pipeline.py index 11c36b09..8e24a94f 100644 --- a/tests/test_add_messages_pipeline.py +++ b/tests/test_add_messages_pipeline.py @@ -224,7 +224,11 @@ async def test_process_chunk_success_with_related_terms() -> None: @pytest.mark.asyncio async def test_process_chunk_extraction_failure_returns_error() -> None: - """A Failure result from the extractor sets error and skips embedding.""" + """A Failure result from the extractor sets error. + + Chunk embedding may still run because extraction and chunk embedding are + launched concurrently. + """ extractor = _SequenceExtractor([typechat.Failure("bad extraction")]) message_model = _StubEmbeddingModel() @@ -240,7 +244,7 @@ async def test_process_chunk_extraction_failure_returns_error() -> None: assert isinstance(result.error, RuntimeError) assert "bad extraction" in str(result.error) assert result.extracted_knowledge is None - assert message_model.chunk_calls == [] + assert message_model.chunk_calls == ["hello"] @pytest.mark.asyncio From 72ab410acf3cf2006077caa0b82ed88d8abf2cc2 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Fri, 15 May 2026 21:40:31 -0700 Subject: [PATCH 40/42] Move semaphore release before result queue put --- src/typeagent/knowpro/add_messages.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/typeagent/knowpro/add_messages.py b/src/typeagent/knowpro/add_messages.py index 63a67a7b..8bc95b46 100644 --- a/src/typeagent/knowpro/add_messages.py +++ b/src/typeagent/knowpro/add_messages.py @@ -194,10 +194,11 @@ async def _process_one(work_item: ChunkWorkItem[TMessage]) -> None: stop_state.stop_at_message_id = new_stop if stop_state.exception is None: stop_state.exception = result.error - await result_queue.put(result) finally: sem.release() + await result_queue.put(result) + async with asyncio.TaskGroup() as tg: while True: item = await chunk_queue.get() From ad157f33c1805d58df09bce92850c4d7a21de11c Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Fri, 15 May 2026 22:46:48 -0700 Subject: [PATCH 41/42] Remove unused on_batch_committed from _reassembler_task --- src/typeagent/knowpro/add_messages.py | 4 --- tests/test_add_messages_pipeline.py | 39 --------------------------- 2 files changed, 43 deletions(-) diff --git a/src/typeagent/knowpro/add_messages.py b/src/typeagent/knowpro/add_messages.py index 8bc95b46..89a7a245 100644 --- a/src/typeagent/knowpro/add_messages.py +++ b/src/typeagent/knowpro/add_messages.py @@ -375,7 +375,6 @@ async def _reassembler_task[TMessage: IMessage]( commit_batch: Callable[ [list[TMessage], list[ChunkProcessingResult[TMessage]]], Awaitable[None] ], - on_batch_committed: Callable[[int, int], None] | None, skip_failed_messages: bool, ) -> ReassemblerResult: """Reassemble chunks into messages and commit only consecutive complete ones. @@ -415,8 +414,6 @@ async def _commit_if_needed(force: bool = False) -> None: await commit_batch(pending_messages, pending_results) state.messages_committed += msg_count state.chunks_committed += chunk_count - if on_batch_committed is not None: - on_batch_committed(msg_count, chunk_count) async def _drain_consecutive_complete(force: bool = False) -> None: nonlocal staged_chunks @@ -651,7 +648,6 @@ async def _commit_batch( first_uncommitted_ordinal=initial_message_id, target_commit_chunk_count=batch_size, commit_batch=_commit_batch, - on_batch_committed=None, skip_failed_messages=skip_failed_messages, ) ) diff --git a/tests/test_add_messages_pipeline.py b/tests/test_add_messages_pipeline.py index 8e24a94f..41be8d5c 100644 --- a/tests/test_add_messages_pipeline.py +++ b/tests/test_add_messages_pipeline.py @@ -587,7 +587,6 @@ async def _commit( first_uncommitted_ordinal=0, target_commit_chunk_count=10, commit_batch=_commit, - on_batch_committed=None, skip_failed_messages=False, ) @@ -623,7 +622,6 @@ async def _commit( first_uncommitted_ordinal=0, target_commit_chunk_count=1, commit_batch=_commit, - on_batch_committed=None, skip_failed_messages=False, ) @@ -658,7 +656,6 @@ async def _commit( first_uncommitted_ordinal=0, target_commit_chunk_count=1, commit_batch=_commit, - on_batch_committed=None, skip_failed_messages=True, ) @@ -695,7 +692,6 @@ async def _commit( first_uncommitted_ordinal=0, target_commit_chunk_count=99, commit_batch=_commit, - on_batch_committed=None, skip_failed_messages=False, ) @@ -727,7 +723,6 @@ async def _commit( first_uncommitted_ordinal=0, target_commit_chunk_count=1, commit_batch=_commit, - on_batch_committed=None, skip_failed_messages=False, ) @@ -756,44 +751,12 @@ async def _commit( first_uncommitted_ordinal=0, target_commit_chunk_count=1, commit_batch=_commit, - on_batch_committed=None, skip_failed_messages=False, ) assert stop_state.stop_at_message_id == 5 -@pytest.mark.asyncio -async def test_reassembler_on_batch_committed_callback_is_invoked() -> None: - result_queue = asyncio.Queue() - stop_state = PipelineStopState() - - message = _Message(["m0"]) - await result_queue.put(_chunk_result(message, 0, 0, 1)) - await result_queue.put(None) - - callback_calls: list[tuple[int, int]] = [] - - async def _commit( - messages: list[_Message], results: list[ChunkProcessingResult[_Message]] - ) -> None: - return None - - await _reassembler_task( - result_queue, - stop_state, - first_uncommitted_ordinal=0, - target_commit_chunk_count=1, - commit_batch=_commit, - on_batch_committed=lambda msg_count, chunk_count: callback_calls.append( - (msg_count, chunk_count) - ), - skip_failed_messages=False, - ) - - assert callback_calls == [(1, 1)] - - @pytest.mark.asyncio async def test_reassembler_raises_on_mismatched_chunk_count_and_sets_stop_marker() -> ( None @@ -818,7 +781,6 @@ async def _commit( first_uncommitted_ordinal=0, target_commit_chunk_count=1, commit_batch=_commit, - on_batch_committed=None, skip_failed_messages=False, ) @@ -849,7 +811,6 @@ async def _commit( first_uncommitted_ordinal=0, target_commit_chunk_count=1, commit_batch=_commit, - on_batch_committed=None, skip_failed_messages=False, ) From 619b3b6c3d0e32b4513a033935b4ea89e6fb3908 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Fri, 15 May 2026 22:50:41 -0700 Subject: [PATCH 42/42] Consolidate skip-failed logging to message level --- src/typeagent/knowpro/add_messages.py | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/src/typeagent/knowpro/add_messages.py b/src/typeagent/knowpro/add_messages.py index 89a7a245..4418b972 100644 --- a/src/typeagent/knowpro/add_messages.py +++ b/src/typeagent/knowpro/add_messages.py @@ -180,12 +180,7 @@ async def _process_one(work_item: ChunkWorkItem[TMessage]) -> None: embedding_model=embedding_model, ) if result.error is not None: - if skip_failed_messages: - print( - f"Skipping message {work_item.chunk_id.message_ordinal} " - f"due to extraction/embedding error: {result.error}" - ) - else: + if not skip_failed_messages: new_stop = min( stop_state.stop_at_message_id, work_item.chunk_id.message_ordinal,