diff --git a/AGENTS.md b/AGENTS.md index dcc3b9a4..a89c7000 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -19,6 +19,7 @@ AGENTS.md. In all cases show what you added to AGENTS.md. - Don't use '!' on the command line, it's some bash magic (even inside single quotes) - When running 'make' commands, do not use the venv (the Makefile uses 'uv run') +- If a term can refer to OS behavior or repository code behavior (for example, 'force quit'), prefer the in-repo meaning first and verify by searching the code. - To get API keys in ad-hoc code, call `load_dotenv()` - Use `pytest test` to run tests in test/ - Use `pyright` to check type annotations in src/, tools/, tests/, examples/ @@ -29,7 +30,21 @@ AGENTS.md. In all cases show what you added to AGENTS.md. - Use `make check test` to run `make check` and if it passes also run `make test` - Use `make format` to format all files using `black`. Do this before reporting success. - When validating changes, first run `pytest` only on new/modified test files, then run `make format check test` once at the end. +- While building `add_messages.py` before dedicated tests exist, skip running the full test suite; run full tests after those tests are added. - Keep ad-hoc and performance benchmarks under `tools/`, not `tests/`, so `make test` does not run them. +- In add-messages pipeline chunk processing, compute chunk-text embeddings with uncached model calls and related-term embeddings with cached model calls. +- In add-messages pipeline flow, lower stop_at_message_id to min(existing, failing_message_id), and always enqueue queue-1 sentinels even when the input iterator fails so workers can drain and exit cleanly. +- In add-messages pipeline data structures, use `TextLocation` as the chunk identifier instead of a formatted string chunk ID. +- In add-messages reassembler validation, prefer explicit guard checks over wrapping validation-only logic in `try/except` blocks. +- In add-messages reassembler validation, prefer a single `validation_error` variable with consistent `if/elif` checks over helper functions for simple message-only validation. +- When adding precomputed-embedding write paths, expose explicit `*_with_embeddings` methods and have existing methods compute embeddings then delegate to those methods. +- In asyncio code, avoid locks for in-memory state updates that do not `await` between read/modify/write; use locks only when a critical section spans `await` points. +- Name returned summary/value objects as `*Result`; reserve `*State` for mutable shared/internal state. +- Keep internal helper type naming consistent within a module; avoid mixing underscored and non-underscored helper class names without a clear API-boundary reason. +- Prefer variable names that reflect role rather than lifecycle; for accumulators like message assemblies, use neutral names (e.g., `assembly`) instead of state-qualified names (e.g., `existing`). +- Avoid potential import cycles between conversation orchestration and pipeline modules by using neutral payload protocols/arguments instead of importing concrete pipeline result classes across modules. +- Prefer ordinal type aliases (e.g., `MessageOrdinal`, `ChunkOrdinal`) over raw `int` in pipeline code for readability. +- When the user asks to "fix the test only", update tests/mocks first and avoid adding production compatibility fallbacks unless explicitly requested. ## Package Management with uv @@ -55,8 +70,12 @@ please follow these guidelines: * Assume Python 3.12 +* `from __future__ import annotations` is not allowed. + * Always strip trailing spaces +* Keep docstrings in sync with code when changing implementation. + * Keep class and type names in `PascalCase` * Use `python_case` for variable/field and function/method names diff --git a/Makefile b/Makefile index c2aeb33b..83616532 100644 --- a/Makefile +++ b/Makefile @@ -13,7 +13,8 @@ format: venv .PHONY: check check: venv - uv run pyright src tests tools examples + uv run pyright --pythonversion 3.12 src tests tools examples + uv run pyright --pythonversion 3.14 src tests tools examples .PHONY: test test: venv @@ -21,10 +22,10 @@ test: venv .PHONY: coverage coverage: venv - coverage erase + uv run coverage erase COVERAGE_PROCESS_START=.coveragerc uv run coverage run -m pytest $(FLAGS) - coverage combine - coverage report + uv run coverage combine + uv run coverage report .PHONY: demo demo: venv diff --git a/src/typeagent/knowpro/add_messages.py b/src/typeagent/knowpro/add_messages.py new file mode 100644 index 00000000..4418b972 --- /dev/null +++ b/src/typeagent/knowpro/add_messages.py @@ -0,0 +1,679 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +"""New modular implementation of add_messages_streaming with pipelined architecture.""" + +import asyncio +from collections.abc import AsyncIterable, Awaitable, Callable +from dataclasses import dataclass +from itertools import chain +from typing import TYPE_CHECKING + +import typechat + +from . import knowledge_schema as kplib +from ..aitools.embeddings import IEmbeddingModel, NormalizedEmbedding +from ..storage.memory.semrefindex import collect_action_terms, collect_entity_terms +from .interfaces import AddMessagesResult +from .interfaces_core import IKnowledgeExtractor, IMessage, MessageOrdinal, TextLocation + +__all__ = ["add_messages_streaming"] + +if TYPE_CHECKING: + from .conversation_base import ConversationBase + +type ChunkOrdinal = int + +_EMPTY_KNOWLEDGE = kplib.KnowledgeResponse( + entities=[], actions=[], inverse_actions=[], topics=[] +) + + +class NoOpKnowledgeExtractor: + """No-op extractor used when auto_extract_knowledge is False.""" + + async def extract(self, message: str) -> typechat.Result[kplib.KnowledgeResponse]: + return typechat.Success(_EMPTY_KNOWLEDGE) + + +@dataclass +class PipelineStopState: + """Shared stop marker for pipeline stages. + + A message ordinal greater than or equal to ``stop_at_message_id`` is + considered out-of-scope for further processing. + + ``exception`` holds the error from the lowest-ordinal message that + caused the stop, so the orchestrator can re-raise it after the pipeline + drains. + """ + + stop_at_message_id: int = 10**100 + exception: Exception | None = None + + +@dataclass +class ProducerState: + """Mutable producer state shared with orchestrator/reporting.""" + + next_message_id: MessageOrdinal + produced_messages: int = 0 + produced_chunks: int = 0 + exception: Exception | None = None + + +@dataclass +class ChunkWorkItem[TMessage: IMessage]: + """One chunk scheduled by the producer for worker processing.""" + + chunk_id: TextLocation + chunk_count: int + chunk_text: str + message: TMessage + + +async def _producer_task[TMessage: IMessage]( + messages: AsyncIterable[TMessage], + chunk_queue: asyncio.Queue[ChunkWorkItem[TMessage] | None], + stop_state: PipelineStopState, + producer_state: ProducerState, + result_queue: asyncio.Queue["ChunkProcessingResult[TMessage] | None"], + shutdown_event: asyncio.Event | None, +) -> None: + """Read input messages and enqueue chunk work items. + + The producer stops enqueueing once it reaches ``stop_at_message_id``. + It always sends a sentinel to shut down the dispatcher, even if the + input iterator raises. + """ + try: + async for message in messages: + message_id = producer_state.next_message_id + if message_id >= stop_state.stop_at_message_id: + break + if shutdown_event is not None and shutdown_event.is_set(): + break + + chunk_count = len(message.text_chunks) + if chunk_count == 0: + # Zero-chunk message: nothing for the dispatcher to process. + # Emit a zero-chunk result directly to the reassembler. + await result_queue.put( + ChunkProcessingResult[TMessage]( + chunk_id=TextLocation(message_id, 0), + chunk_count=0, + message=message, + ) + ) + producer_state.produced_messages += 1 + producer_state.next_message_id += 1 + continue + + for chunk_ordinal, chunk_text in enumerate(message.text_chunks): + if message_id >= stop_state.stop_at_message_id: + break + await chunk_queue.put( + ChunkWorkItem[TMessage]( + chunk_id=TextLocation(message_id, chunk_ordinal), + chunk_count=chunk_count, + chunk_text=chunk_text, + message=message, + ) + ) + producer_state.produced_chunks += 1 + + producer_state.produced_messages += 1 + producer_state.next_message_id += 1 + except Exception as exc: + producer_state.exception = exc + finally: + await chunk_queue.put(None) + + +async def _dispatcher_task[TMessage: IMessage]( + chunk_queue: asyncio.Queue[ChunkWorkItem[TMessage] | None], + result_queue: asyncio.Queue["ChunkProcessingResult[TMessage] | None"], + stop_state: PipelineStopState, + knowledge_extractor: IKnowledgeExtractor, + embedding_model: IEmbeddingModel, + concurrency: int, + skip_failed_messages: bool, +) -> None: + """Dispatch chunk work items to bounded per-item worker tasks. + + Reads work items from ``chunk_queue`` until it receives a ``None`` + sentinel, then awaits all in-flight tasks via a TaskGroup and puts a + ``None`` sentinel on ``result_queue`` to signal the reassembler. + + Concurrency is bounded by a semaphore so at most ``concurrency`` worker + tasks run simultaneously. Chunks at or beyond ``stop_at_message_id`` are + skipped and reported as error results so the reassembler can account for + them deterministically. + + Args: + skip_failed_messages: If True, don't halt producer on extraction/embedding + failures; continue processing. If False, halt on first failure. + """ + sem = asyncio.Semaphore(concurrency) + + async def _process_one(work_item: ChunkWorkItem[TMessage]) -> None: + try: + stop_at = stop_state.stop_at_message_id + if work_item.chunk_id.message_ordinal >= stop_at: + result: "ChunkProcessingResult[TMessage]" = ChunkProcessingResult( + chunk_id=work_item.chunk_id, + chunk_count=work_item.chunk_count, + message=work_item.message, + error=RuntimeError( + "Chunk skipped because stop_at_message_id " + f"is {stop_at} and message_id is " + f"{work_item.chunk_id.message_ordinal}" + ), + ) + else: + result = await process_chunk_with_extraction_and_embeddings( + chunk_id=work_item.chunk_id, + chunk_text=work_item.chunk_text, + chunk_count=work_item.chunk_count, + message=work_item.message, + knowledge_extractor=knowledge_extractor, + embedding_model=embedding_model, + ) + if result.error is not None: + if not skip_failed_messages: + new_stop = min( + stop_state.stop_at_message_id, + work_item.chunk_id.message_ordinal, + ) + if new_stop < stop_state.stop_at_message_id: + stop_state.stop_at_message_id = new_stop + if stop_state.exception is None: + stop_state.exception = result.error + finally: + sem.release() + + await result_queue.put(result) + + async with asyncio.TaskGroup() as tg: + while True: + item = await chunk_queue.get() + if item is None: + break + await sem.acquire() + tg.create_task(_process_one(item)) + + await result_queue.put(None) + + +@dataclass +class ChunkProcessingResult[TMessage: IMessage]: + """Result of processing a single chunk through extraction and embeddings. + + Attributes: + chunk_id: Message/chunk location for the processed chunk. + chunk_count: Total number of chunks in the message that owns this chunk. + message: Original message object containing this chunk. + extracted_knowledge: Extracted KnowledgeResponse if extraction succeeded, else None. + chunk_embedding: Normalized embedding vector for the message chunk, or None if extraction or embedding failed. + related_terms: Lowercased, deduplicated related-term texts extracted from knowledge. + related_term_embeddings: Embeddings for related_terms in the same order, or [] when there are no related terms. + error: Exception from the first failing operation, or None if extraction and embedding succeeded. + + The ``success`` property is True only when extraction succeeded, chunk embedding was + generated, related-term embeddings were generated, and no error occurred. + """ + + chunk_id: TextLocation + chunk_count: int + message: TMessage + extracted_knowledge: kplib.KnowledgeResponse | None = None + chunk_embedding: NormalizedEmbedding | None = None + related_terms: list[str] | None = None + related_term_embeddings: list[NormalizedEmbedding] | None = None + error: Exception | None = None + + +def _collect_related_terms_for_fuzzy_index( + knowledge: kplib.KnowledgeResponse, +) -> list[str]: + """Collect canonical related-term texts for the fuzzy related-terms index. + + These terms are derived from the same knowledge that feeds semantic refs. + We lowercase and deduplicate while preserving order to match index behavior. + """ + seen: set[str] = set() + related_terms: list[str] = [] + + def _add_term(term: str) -> None: + canonical = term.strip().lower() + if canonical and canonical not in seen: + seen.add(canonical) + related_terms.append(canonical) + + for entity in knowledge.entities: + for term in collect_entity_terms(entity): + _add_term(term) + + for action in chain(knowledge.actions, knowledge.inverse_actions): + for term in collect_action_terms(action): + _add_term(term) + + for topic in knowledge.topics: + _add_term(topic) + + return related_terms + + +# "Public", imported by tests +async def process_chunk_with_extraction_and_embeddings[TMessage: IMessage]( + chunk_id: TextLocation, + chunk_text: str, + chunk_count: int, + message: TMessage, + knowledge_extractor: IKnowledgeExtractor, + embedding_model: IEmbeddingModel, +) -> ChunkProcessingResult[TMessage]: + """Process a single text chunk through knowledge extraction and embeddings. + + Runs extraction/related-term embedding and chunk embedding concurrently, + capturing the first failure and stopping processing if an error occurs. + + Chunk embeddings are computed uncached; related-term embeddings use + cache-aware model calls on the same embedding model. + + Args: + chunk_id: Message/chunk location for this chunk. + chunk_text: Text content of the chunk (stripped). + chunk_count: Total number of chunks in the message. + message: Original message object containing this chunk. + knowledge_extractor: IKnowledgeExtractor instance for LLM extraction. + embedding_model: Embedding model for chunk and related-term embeddings. + + Returns: + ChunkProcessingResult with knowledge, chunk embedding, related-term + embeddings, or an error from the first failed operation. + """ + result = ChunkProcessingResult( + chunk_id=chunk_id, chunk_count=chunk_count, message=message + ) + sem = asyncio.Semaphore(1) # Avoid concurrent embedding requests + + async def _extract_knowledge_and_related_embeddings() -> None: + knowledge_result = await knowledge_extractor.extract(chunk_text) + if isinstance(knowledge_result, typechat.Failure): + raise RuntimeError( + f"Knowledge extraction failed: {knowledge_result.message}" + ) + result.extracted_knowledge = knowledge_result.value + + result.related_terms = _collect_related_terms_for_fuzzy_index( + result.extracted_knowledge + ) + if result.related_terms: + async with sem: + rel_embeddings = await embedding_model.get_embeddings( + result.related_terms + ) + result.related_term_embeddings = list(rel_embeddings) + else: + result.related_term_embeddings = [] + + async def _generate_chunk_embedding() -> None: + async with sem: + result.chunk_embedding = await embedding_model.get_embedding_nocache( + chunk_text + ) + + try: + async with asyncio.TaskGroup() as tg: + tg.create_task(_extract_knowledge_and_related_embeddings()) + tg.create_task(_generate_chunk_embedding()) + except Exception as error: + while isinstance(error, ExceptionGroup) and len(error.exceptions) == 1: + error = error.exceptions[0] + result.error = error + + return result + + +@dataclass +class MessageAssembly[TMessage: IMessage]: + """In-memory chunk accumulation for one message.""" + + message_id: MessageOrdinal + chunk_count: int + message: TMessage + chunks: dict[ChunkOrdinal, ChunkProcessingResult[TMessage]] + has_error: bool = False + + def is_complete(self) -> bool: + return len(self.chunks) == self.chunk_count + + +@dataclass +class ReassemblerResult: + """Progress and counters produced by the reassembler stage.""" + + first_uncommitted_ordinal: MessageOrdinal + messages_committed: int = 0 + chunks_committed: int = 0 + chunk_failures: int = 0 + messages_skipped: int = 0 + buffered_messages: int = 0 + + +async def _reassembler_task[TMessage: IMessage]( + result_queue: asyncio.Queue[ChunkProcessingResult[TMessage] | None], + stop_state: PipelineStopState, + first_uncommitted_ordinal: MessageOrdinal, + target_commit_chunk_count: int, + commit_batch: Callable[ + [list[TMessage], list[ChunkProcessingResult[TMessage]]], Awaitable[None] + ], + skip_failed_messages: bool, +) -> ReassemblerResult: + """Reassemble chunks into messages and commit only consecutive complete ones. + + This stage consumes worker results until it sees a ``None`` sentinel. + It never commits out-of-order messages: if message N is incomplete or failed, + messages N+1 and later remain buffered. + + Args: + skip_failed_messages: If True, skip messages that fail extraction/embedding + and continue processing. If False, halt processing on first failure. + """ + state = ReassemblerResult(first_uncommitted_ordinal=first_uncommitted_ordinal) + assemblies: dict[MessageOrdinal, MessageAssembly[TMessage]] = {} + + staged_messages: list[TMessage] = [] + staged_results: list[ChunkProcessingResult[TMessage]] = [] + staged_chunks = 0 + + async def _commit_if_needed(force: bool = False) -> None: + nonlocal staged_chunks, staged_messages, staged_results + if not staged_messages: + return + if not force and staged_chunks < target_commit_chunk_count: + return + pending_messages = staged_messages + pending_results = staged_results + msg_count = len(pending_messages) + chunk_count = staged_chunks + + # Clear staged state before awaiting commit/callback paths so a post-commit + # exception cannot trigger a duplicate retry during final drain. + staged_messages = [] + staged_results = [] + staged_chunks = 0 + + await commit_batch(pending_messages, pending_results) + state.messages_committed += msg_count + state.chunks_committed += chunk_count + + async def _drain_consecutive_complete(force: bool = False) -> None: + nonlocal staged_chunks + while True: + assembly = assemblies.get(state.first_uncommitted_ordinal) + if assembly is None: + await _commit_if_needed(force) + return + if not assembly.is_complete(): + await _commit_if_needed(force) + return + if assembly.has_error: + if skip_failed_messages: + # Skip this failed message and continue + # Find the error from one of the chunks for logging + error_msg = "Unknown error" + for chunk_result in assembly.chunks.values(): + if chunk_result.error is not None: + error_msg = str(chunk_result.error) + break + print( + f"Skipping message {state.first_uncommitted_ordinal} " + f"due to chunk processing error: {error_msg}" + ) + del assemblies[state.first_uncommitted_ordinal] + state.first_uncommitted_ordinal += 1 + state.messages_skipped += 1 + # Continue to check the next message + continue + else: + # Stop at failed message; halt processing + await _commit_if_needed(force) + return + + # Pre-flush: if staging this message would push staged_chunks past + # the target, commit the current batch first. + if ( + staged_messages + and staged_chunks + assembly.chunk_count > target_commit_chunk_count + ): + await _commit_if_needed(force=True) + + ordered_chunk_ordinals = sorted(assembly.chunks) + ordered_results = [assembly.chunks[i] for i in ordered_chunk_ordinals] + staged_messages.append(assembly.message) + staged_results.extend(ordered_results) + staged_chunks += len(ordered_results) + + del assemblies[state.first_uncommitted_ordinal] + state.first_uncommitted_ordinal += 1 + await _commit_if_needed(force) + + try: + while True: + item = await result_queue.get() + if item is None: + break + + chunk_ordinal = item.chunk_id.chunk_ordinal + message_id = item.chunk_id.message_ordinal + + validation_error: str | None = None + assembly = assemblies.get(message_id) + if item.chunk_count == 0: + # Zero-chunk message: create an immediately-complete assembly. + if assembly is None: + assembly = MessageAssembly[TMessage]( + message_id=message_id, + chunk_count=0, + message=item.message, + chunks={}, + ) + assemblies[message_id] = assembly + elif chunk_ordinal < 0 or chunk_ordinal >= item.chunk_count: + validation_error = ( + f"Invalid chunk ordinal: message_id={message_id}, " + f"chunk_ordinal={chunk_ordinal}, chunk_count={item.chunk_count}" + ) + elif assembly is None: + assembly = MessageAssembly[TMessage]( + message_id=message_id, + chunk_count=item.chunk_count, + message=item.message, + chunks={}, + ) + assemblies[message_id] = assembly + elif assembly.chunk_count != item.chunk_count: + validation_error = ( + f"Mismatched chunk count for message: message_id={message_id}, " + f"expected={assembly.chunk_count}, got={item.chunk_count}" + ) + elif chunk_ordinal in assembly.chunks: + validation_error = ( + f"Duplicate chunk: message_id={message_id}, " + f"chunk_ordinal={chunk_ordinal}, chunk_count={item.chunk_count}" + ) + + if validation_error is not None: + stop_state.stop_at_message_id = min( + stop_state.stop_at_message_id, message_id + ) + raise RuntimeError(validation_error) + + assert assembly is not None + + if item.chunk_count > 0: + assembly.chunks[chunk_ordinal] = item + + if item.error is not None: + assembly.has_error = True + state.chunk_failures += 1 + if not skip_failed_messages: + stop_state.stop_at_message_id = min( + stop_state.stop_at_message_id, message_id + ) + + await _drain_consecutive_complete() + finally: + # Always drain and commit consecutive complete messages before raising + await _drain_consecutive_complete(force=True) + + state.buffered_messages = len(assemblies) + return state + + +async def add_messages_streaming[TMessage: IMessage]( + conv: "ConversationBase[TMessage]", + messages: AsyncIterable[TMessage], + *, + batch_size: int = 100, + on_batch_committed: Callable[[AddMessagesResult], None] | None = None, + skip_failed_messages: bool = False, + shutdown_event: asyncio.Event | None = None, +) -> AddMessagesResult: + """Ingest messages through a producer/dispatcher/reassembler pipeline. + + The function preserves message commit order while processing chunk extraction + and embedding concurrently. Batches are committed only for consecutive, + complete, non-failing messages. + + Args: + conv: Conversation receiving the new messages. + messages: Async iterable of messages to ingest. + batch_size: Target number of chunks per commit batch. + on_batch_committed: Optional callback invoked after each committed batch + with that batch's AddMessagesResult. + skip_failed_messages: If True, skip messages that fail extraction or + embedding and continue processing. If False (default), halt on + first failure and raise an exception. + + Returns: + AddMessagesResult aggregating all committed batches, including count + of messages_skipped when skip_failed_messages is True. + + Raises: + Exception: If a single failure occurs during production, processing, + reassembly, or commit (when skip_failed_messages is False). + ExceptionGroup: If multiple distinct failures are observed across + pipeline stages (when skip_failed_messages is False). + """ + from . import convknowledge + + settings = conv.settings + sem_ref_settings = settings.semantic_ref_index_settings + storage = await settings.get_storage_provider() + if sem_ref_settings.auto_extract_knowledge: + knowledge_extractor: IKnowledgeExtractor = ( + sem_ref_settings.knowledge_extractor or convknowledge.KnowledgeExtractor() + ) + else: + knowledge_extractor = NoOpKnowledgeExtractor() + embedding_model = settings.embedding_model + + initial_message_id: MessageOrdinal = await conv.messages.size() + + total = AddMessagesResult() + + def _accumulate(result: AddMessagesResult) -> None: + total.messages_added += result.messages_added + total.semrefs_added += result.semrefs_added + total.chunks_added += result.chunks_added + if on_batch_committed: + on_batch_committed(result) + + async def _commit_batch( + messages_batch: list[TMessage], + chunk_results: list[ChunkProcessingResult[TMessage]], + ) -> None: + result = await conv._commit_batch_from_chunk_results( + storage, messages_batch, chunk_results + ) + _accumulate(result) + + chunk_queue: asyncio.Queue[ChunkWorkItem[TMessage] | None] = asyncio.Queue( + maxsize=sem_ref_settings.concurrency * 2 + ) + result_queue: asyncio.Queue[ChunkProcessingResult[TMessage] | None] = asyncio.Queue( + maxsize=sem_ref_settings.concurrency * 2 + ) + stop_state = PipelineStopState() + producer_state = ProducerState(next_message_id=initial_message_id) + + task_exceptions: list[Exception] = [] + reassembler_task: asyncio.Task[ReassemblerResult] | None = None + try: + async with asyncio.TaskGroup() as tg: + tg.create_task( + _producer_task( + messages, + chunk_queue, + stop_state, + producer_state, + result_queue, + shutdown_event=shutdown_event, + ) + ) + tg.create_task( + _dispatcher_task( + chunk_queue, + result_queue, + stop_state, + knowledge_extractor, + embedding_model, + concurrency=sem_ref_settings.concurrency, + skip_failed_messages=skip_failed_messages, + ) + ) + reassembler_task = tg.create_task( + _reassembler_task( + result_queue, + stop_state, + first_uncommitted_ordinal=initial_message_id, + target_commit_chunk_count=batch_size, + commit_batch=_commit_batch, + skip_failed_messages=skip_failed_messages, + ) + ) + except ExceptionGroup as eg: + task_exceptions.extend(eg.exceptions) + except Exception as exc: + task_exceptions.append(exc) + + if producer_state.exception is not None: + task_exceptions.append(producer_state.exception) + + if stop_state.exception is not None and not skip_failed_messages: + task_exceptions.append(stop_state.exception) + + if task_exceptions: + distinct_exceptions: list[Exception] = [] + for exc in task_exceptions: + if exc not in distinct_exceptions: + distinct_exceptions.append(exc) + + if len(distinct_exceptions) == 1: + raise distinct_exceptions[0] + raise ExceptionGroup("add_messages_streaming failed", distinct_exceptions) + + # Collect messages_skipped from reassembler result if skip_failed_messages is True + if skip_failed_messages and reassembler_task is not None: + try: + reassembler_result = reassembler_task.result() + total.messages_skipped = reassembler_result.messages_skipped + except Exception: + # reassembler_task result may not be available if task group failed + pass + + return total diff --git a/src/typeagent/knowpro/conversation_base.py b/src/typeagent/knowpro/conversation_base.py index 673695d2..ee3b0295 100644 --- a/src/typeagent/knowpro/conversation_base.py +++ b/src/typeagent/knowpro/conversation_base.py @@ -5,10 +5,9 @@ import asyncio from collections.abc import AsyncIterable, Callable, Sequence -import contextlib from dataclasses import dataclass from datetime import datetime, timezone -from typing import Generic, Self, TypeVar +from typing import Generic, Protocol, Self, TypeVar import typechat @@ -24,6 +23,7 @@ ) from . import knowledge_schema as kplib from ..aitools import model_adapters, utils +from ..aitools.embeddings import NormalizedEmbedding from ..storage.memory import semrefindex from .convsettings import ConversationSettings from .interfaces import ( @@ -40,19 +40,20 @@ Topic, ) from .interfaces_core import TextLocation -from .knowledge import extract_knowledge_from_text_batch from .messageutils import get_all_message_chunk_locations TMessage = TypeVar("TMessage", bound=IMessage) -@dataclass(frozen=True) -class _ExtractionResult: - """Pre-extracted knowledge for a batch, ready to commit.""" +class _ChunkCommitResult(Protocol): + """Neutral chunk commit payload shape used by pipeline batch commit.""" - messages: Sequence[IMessage] - text_locations: list[TextLocation] - knowledge_results: list[typechat.Result[kplib.KnowledgeResponse]] + chunk_id: TextLocation + chunk_count: int + extracted_knowledge: kplib.KnowledgeResponse | None + chunk_embedding: NormalizedEmbedding | None + related_terms: list[str] | None + related_term_embeddings: list[NormalizedEmbedding] | None @dataclass(init=False) @@ -217,209 +218,135 @@ async def add_messages_streaming( *, batch_size: int = 100, on_batch_committed: Callable[[AddMessagesResult], None] | None = None, + skip_failed_messages: bool = False, + shutdown_event: asyncio.Event | None = None, ) -> AddMessagesResult: - """Add messages from an async iterable, committing in batches. + """Delegate to the pipelined add_messages implementation.""" + from . import add_messages - Uses a two-stage pipeline: while batch N is being committed (DB writes, - embeddings, secondary indexes), batch N+1's LLM extraction runs - concurrently. LLM extraction is typically 95% of wall time, so this - nearly doubles throughput for multi-batch ingestions. - - **Source-ID tracking**: each message's ``source_id`` (if not ``None``) - is marked as ingested within the commit transaction. Callers are - responsible for filtering duplicates before yielding messages. - - **Extraction failures**: when knowledge extraction returns a - ``Failure`` for a chunk, the failure is recorded via - ``storage.record_chunk_failure`` and processing continues with the - remaining chunks. Raised exceptions (HTTP errors, timeouts, etc.) - are treated as systemic and stop the run immediately — the current - batch is rolled back and the exception propagates. - - Args: - messages: An async iterable of messages to ingest. - batch_size: Target number of text chunks per commit batch. - Messages are never split across batches, so the actual - chunk count may exceed ``batch_size`` if a single message - has more chunks than that. - on_batch_committed: Optional callback invoked after each batch is - committed, receiving the batch's ``AddMessagesResult``. - - Returns: - Cumulative ``AddMessagesResult`` across all committed batches. - """ - storage = await self.settings.get_storage_provider() - should_extract = ( - self.settings.semantic_ref_index_settings.auto_extract_knowledge + return await add_messages.add_messages_streaming( + self, + messages, + batch_size=batch_size, + on_batch_committed=on_batch_committed, + skip_failed_messages=skip_failed_messages, + shutdown_event=shutdown_event, ) - total = AddMessagesResult() - - def _accumulate(result: AddMessagesResult) -> None: - total.messages_added += result.messages_added - total.semrefs_added += result.semrefs_added - total.chunks_added += result.chunks_added - if on_batch_committed: - on_batch_committed(result) - - pending_commit: asyncio.Task[AddMessagesResult] | None = None - pending_extraction: asyncio.Task[_ExtractionResult | None] | None = None - - async def _drain_commit() -> None: - nonlocal pending_commit - if pending_commit is not None: - _accumulate(await pending_commit) - pending_commit = None - - async def _submit_batch(batch: list[TMessage]) -> None: - nonlocal pending_commit, pending_extraction - if not batch: - return - - if should_extract: - next_extraction = asyncio.create_task( - self._extract_knowledge_for_batch(batch) - ) - else: - next_extraction = None - pending_extraction = next_extraction - - await _drain_commit() - - extraction = await next_extraction if next_extraction is not None else None - pending_extraction = None - - pending_commit = asyncio.create_task( - self._commit_batch_streaming(storage, batch, extraction) - ) - try: - batch: list[TMessage] = [] - batch_chunks = 0 - async for msg in messages: - msg_chunks = len(msg.text_chunks) - if batch and batch_chunks + msg_chunks > batch_size: - await _submit_batch(batch) - batch = [] - batch_chunks = 0 - batch.append(msg) - batch_chunks += msg_chunks - if batch_chunks >= batch_size: - await _submit_batch(batch) - batch = [] - batch_chunks = 0 - - if batch: - await _submit_batch(batch) - - await _drain_commit() - except BaseException: - if pending_extraction is not None and not pending_extraction.done(): - pending_extraction.cancel() - with contextlib.suppress(asyncio.CancelledError): - await pending_extraction - if pending_commit is not None and not pending_commit.done(): - pending_commit.cancel() - with contextlib.suppress(asyncio.CancelledError): - await pending_commit - raise - - return total - - async def _extract_knowledge_for_batch( + async def _commit_batch_from_chunk_results( self, - messages: list[TMessage], - ) -> _ExtractionResult | None: - """Run LLM extraction on message texts — no DB access. - - Uses 0-based ordinals; the caller remaps to global ordinals at commit - time. Safe to run concurrently with a DB transaction on another batch. - """ - text_locations = get_all_message_chunk_locations(messages, 0) - if not text_locations: - return None - - settings = self.settings.semantic_ref_index_settings - knowledge_extractor = ( - settings.knowledge_extractor or convknowledge.KnowledgeExtractor() - ) - - text_batch = [ - messages[tl.message_ordinal].text_chunks[tl.chunk_ordinal].strip() - for tl in text_locations - ] + storage: IStorageProvider[TMessage], + messages_batch: list[TMessage], + chunk_results: Sequence[_ChunkCommitResult], + ) -> AddMessagesResult: + """Commit one pipeline batch using precomputed extraction and embeddings.""" + if not messages_batch: + return AddMessagesResult() + + # Process chunk results first to collect embeddings and knowledge items + knowledge_items: list[tuple[MessageOrdinal, int, kplib.KnowledgeResponse]] = [] + fuzzy_terms: list[str] = [] + fuzzy_term_embeddings: list[NormalizedEmbedding] = [] + chunk_embedding_map: dict[tuple[int, int], NormalizedEmbedding] = {} + + for result in chunk_results: + if result.chunk_count == 0: + continue - knowledge_results = await extract_knowledge_from_text_batch( - knowledge_extractor, - text_batch, - settings.concurrency, - ) - return _ExtractionResult( - messages=messages, - text_locations=text_locations, - knowledge_results=knowledge_results, - ) + if result.chunk_embedding is None: + raise ValueError( + "Chunk result missing chunk embedding for " + f"message={result.chunk_id.message_ordinal}, " + f"chunk={result.chunk_id.chunk_ordinal}" + ) - async def _apply_extraction_results( - self, - storage: IStorageProvider[TMessage], - extraction: _ExtractionResult, - global_message_start: int, - ) -> None: - """Write pre-extracted knowledge into the DB. Must be inside a transaction.""" - bulk_items: list[tuple[int, int, kplib.KnowledgeResponse]] = [] - for i, knowledge_result in enumerate(extraction.knowledge_results): - tl = extraction.text_locations[i] - global_msg_ord = tl.message_ordinal + global_message_start - if isinstance(knowledge_result, typechat.Failure): - await storage.record_chunk_failure( - global_msg_ord, - tl.chunk_ordinal, - type(knowledge_result).__name__, - knowledge_result.message[:500], + if result.extracted_knowledge is None: + raise ValueError( + "Chunk result missing extracted knowledge for " + f"message={result.chunk_id.message_ordinal}, " + f"chunk={result.chunk_id.chunk_ordinal}" + ) + knowledge_items.append( + ( + result.chunk_id.message_ordinal, + result.chunk_id.chunk_ordinal, + result.extracted_knowledge, ) - continue - bulk_items.append( - (global_msg_ord, tl.chunk_ordinal, knowledge_result.value) - ) - if bulk_items: - await semrefindex.add_knowledge_batch_to_semantic_ref_index( - self, bulk_items ) - async def _commit_batch_streaming( - self, - storage: IStorageProvider[TMessage], - filtered: list[TMessage], - extraction: _ExtractionResult | None, - ) -> AddMessagesResult: - """Commit a single batch with pre-extracted knowledge.""" + if result.related_terms is None or result.related_term_embeddings is None: + raise ValueError( + "Chunk result missing related-term embeddings for " + f"message={result.chunk_id.message_ordinal}, " + f"chunk={result.chunk_id.chunk_ordinal}" + ) + if len(result.related_terms) != len(result.related_term_embeddings): + raise ValueError( + "related_terms and related_term_embeddings length mismatch for " + f"message={result.chunk_id.message_ordinal}, " + f"chunk={result.chunk_id.chunk_ordinal}: " + f"{len(result.related_terms)} != " + f"{len(result.related_term_embeddings)}" + ) + fuzzy_terms.extend(result.related_terms) + fuzzy_term_embeddings.extend(result.related_term_embeddings) + # Store embedding for later retrieval in correct message/chunk order + chunk_embedding_map[ + (result.chunk_id.message_ordinal, result.chunk_id.chunk_ordinal) + ] = result.chunk_embedding + async with storage: start_points = IndexingStartPoints( message_count=await self.messages.size(), semref_count=await self.semantic_refs.size(), ) - await self.messages.extend(filtered) - - source_ids = [m.source_id for m in filtered if m.source_id is not None] + # Build chunk_embeddings in the correct order (matching message/chunk iteration) + chunk_embeddings: list[NormalizedEmbedding] = [] + for msg_ord, message in enumerate( + messages_batch, start_points.message_count + ): + for chunk_ord in range(len(message.text_chunks)): + embedding = chunk_embedding_map.get((msg_ord, chunk_ord)) + if embedding is None: + raise ValueError( + "Missing chunk embedding for staged message chunk: " + f"message={msg_ord}, chunk={chunk_ord}" + ) + chunk_embeddings.append(embedding) + + # Use precomputed embeddings to avoid redundant embedding work + await self.messages.extend( + messages_batch, chunk_embeddings=chunk_embeddings + ) + source_ids = [ + m.source_id for m in messages_batch if m.source_id is not None + ] if source_ids: await storage.mark_sources_ingested_batch(source_ids) await self._add_metadata_knowledge_incremental(start_points.message_count) - if extraction is not None: - await self._apply_extraction_results( - storage, extraction, start_points.message_count - ) + await semrefindex.add_knowledge_batch_to_semantic_ref_index( + self, + knowledge_items, + ) - await self._update_secondary_indexes_incremental(start_points) + await self._update_secondary_indexes_incremental_with_embeddings( + start_points, + messages_batch, + fuzzy_terms, + fuzzy_term_embeddings, + ) await storage.update_conversation_timestamps( updated_at=datetime.now(timezone.utc) ) messages_added = await self.messages.size() - start_points.message_count - chunks_added = sum(len(m.text_chunks) for m in filtered[:messages_added]) + chunks_added = sum( + len(message.text_chunks) for message in messages_batch[:messages_added] + ) return AddMessagesResult( messages_added=messages_added, chunks_added=chunks_added, @@ -427,6 +354,35 @@ async def _commit_batch_streaming( - start_points.semref_count, ) + async def _update_secondary_indexes_incremental_with_embeddings( + self, + start_points: IndexingStartPoints, + new_messages: list[TMessage], + related_terms: list[str], + related_term_embeddings: list[NormalizedEmbedding], + ) -> None: + """Update secondary indexes using precomputed embeddings when available.""" + if self.secondary_indexes is None: + return + + from ..storage.memory import propindex + + await propindex.add_to_property_index(self, start_points.semref_count) + + await self._add_timestamps_for_messages( + new_messages, + start_points.message_count, + ) + + term_to_related = self.secondary_indexes.term_to_related_terms_index + if term_to_related is not None: + fuzzy_index = term_to_related.fuzzy_index + if fuzzy_index is not None and related_terms: + await fuzzy_index.add_terms_with_embeddings( + related_terms, + related_term_embeddings, + ) + async def _add_metadata_knowledge_incremental( self, start_from_message_ordinal: int, diff --git a/src/typeagent/knowpro/interfaces_indexes.py b/src/typeagent/knowpro/interfaces_indexes.py index c872fa7c..02663f27 100644 --- a/src/typeagent/knowpro/interfaces_indexes.py +++ b/src/typeagent/knowpro/interfaces_indexes.py @@ -9,6 +9,7 @@ from pydantic.dataclasses import dataclass +from ..aitools.embeddings import NormalizedEmbedding from .interfaces_core import ( DateRange, IMessage, @@ -131,6 +132,12 @@ async def size(self) -> int: ... async def add_terms(self, texts: list[str]) -> None: ... + async def add_terms_with_embeddings( + self, + texts: list[str], + embeddings: list[NormalizedEmbedding], + ) -> None: ... + async def lookup_term( self, text: str, @@ -214,6 +221,13 @@ async def add_messages_starting_at( messages: list[TMessage], ) -> None: ... + async def add_messages_starting_at_with_embeddings( + self, + start_message_ordinal: int, + messages: list[TMessage], + chunk_embeddings: list[NormalizedEmbedding], + ) -> None: ... + async def lookup_messages( self, message_text: str, diff --git a/src/typeagent/knowpro/interfaces_storage.py b/src/typeagent/knowpro/interfaces_storage.py index 9f17574d..ea696ce4 100644 --- a/src/typeagent/knowpro/interfaces_storage.py +++ b/src/typeagent/knowpro/interfaces_storage.py @@ -10,6 +10,7 @@ from pydantic.dataclasses import dataclass +from ..aitools.embeddings import NormalizedEmbedding from .interfaces_core import ( IMessage, ITermToSemanticRefIndex, @@ -112,6 +113,21 @@ class IMessageCollection[TMessage: IMessage]( ): """A collection of Messages.""" + async def extend( + self, + items: Iterable[TMessage], + chunk_embeddings: list[NormalizedEmbedding] | None = None, + index_messages: bool = True, + ) -> None: + """Append multiple items to the collection. + + Args: + items: Messages to append. + chunk_embeddings: Optional precomputed embeddings for text chunks. + index_messages: If False, skip updating the message text index. + """ + ... + class ISemanticRefCollection(ICollection[SemanticRef, SemanticRefOrdinal], Protocol): """A collection of SemanticRefs.""" diff --git a/src/typeagent/knowpro/textlocindex.py b/src/typeagent/knowpro/textlocindex.py index a6d95969..3c1082e6 100644 --- a/src/typeagent/knowpro/textlocindex.py +++ b/src/typeagent/knowpro/textlocindex.py @@ -5,6 +5,8 @@ from dataclasses import dataclass from typing import Protocol +import numpy as np + from ..aitools.embeddings import NormalizedEmbedding from ..aitools.vectorbase import TextEmbeddingIndexSettings from .fuzzyindex import EmbeddingIndex, ScoredInt @@ -27,6 +29,12 @@ async def add_text_locations( text_and_locations: list[tuple[str, TextLocation]], ) -> None: ... + async def add_text_locations_with_embeddings( + self, + text_locations: list[TextLocation], + embeddings: list[NormalizedEmbedding], + ) -> None: ... + async def lookup_text( self, text: str, @@ -71,6 +79,22 @@ async def add_text_locations( await self._embedding_index.add_texts([text for text, _ in text_and_locations]) self._text_locations.extend([loc for _, loc in text_and_locations]) + async def add_text_locations_with_embeddings( + self, + text_locations: list[TextLocation], + embeddings: list[NormalizedEmbedding], + ) -> None: + if len(text_locations) != len(embeddings): + raise ValueError( + "text_locations and embeddings must have the same length: " + f"{len(text_locations)} != {len(embeddings)}" + ) + if not text_locations: + return + embedding_array = np.stack(embeddings, axis=0).astype(np.float32, copy=False) + self._embedding_index.push(embedding_array) + self._text_locations.extend(text_locations) + async def lookup_text( self, text: str, @@ -112,6 +136,16 @@ async def generate_embedding( ) -> NormalizedEmbedding: return await self._embedding_index.get_embedding(text, cache) + async def generate_embeddings( + self, texts: list[str], cache: bool = True + ) -> list[NormalizedEmbedding]: + if not texts: + return [] + embeddings = await self._embedding_index._vector_base.get_embeddings( + texts, cache=cache + ) + return [embedding for embedding in embeddings] + def lookup_by_embedding( self, text_embedding: NormalizedEmbedding, diff --git a/src/typeagent/podcasts/podcast.py b/src/typeagent/podcasts/podcast.py index 8038ea60..e0d96ae1 100644 --- a/src/typeagent/podcasts/podcast.py +++ b/src/typeagent/podcasts/podcast.py @@ -15,7 +15,6 @@ from ..knowpro.interfaces import ConversationDataWithIndexes, SemanticRef, Term from ..knowpro.universal_message import ConversationMessage, ConversationMessageMeta from ..storage.memory.convthreads import ConversationThreads -from ..storage.memory.messageindex import MessageTextIndex # Type aliases for backward compatibility PodcastMessage = ConversationMessage @@ -81,7 +80,9 @@ async def deserialize( self.name_tag = podcast_data["nameTag"] message_list = [PodcastMessage.deserialize(m) for m in podcast_data["messages"]] - await self.messages.extend(message_list) + # Message index data is deserialized later and replaces prior state, + # so skip incremental indexing while bulk-loading messages. + await self.messages.extend(message_list, index_messages=False) semantic_refs_data = podcast_data.get("semanticRefs") if semantic_refs_data is not None: @@ -119,16 +120,9 @@ async def deserialize( message_index_data = podcast_data.get("messageIndexData") if message_index_data is not None: secondary_indexes = self._get_secondary_indexes() - # Assert the message index is empty before deserializing assert ( secondary_indexes.message_index is not None ), "Message index should be initialized" - - if isinstance(secondary_indexes.message_index, MessageTextIndex): - index_size = await secondary_indexes.message_index.size() - assert ( - index_size == 0 - ), "Message index must be empty before deserializing" await secondary_indexes.message_index.deserialize(message_index_data) # Don't rebuild aliases/synonyms since they were deserialized from relatedTermsIndexData diff --git a/src/typeagent/storage/memory/collections.py b/src/typeagent/storage/memory/collections.py index 8a5b14eb..fb0e6b0e 100644 --- a/src/typeagent/storage/memory/collections.py +++ b/src/typeagent/storage/memory/collections.py @@ -5,9 +5,11 @@ from typing import Iterable +from ...aitools.embeddings import NormalizedEmbedding from ...knowpro.interfaces import ( ICollection, IMessage, + IMessageTextIndex, MessageOrdinal, SemanticRef, SemanticRefMetadata, @@ -81,3 +83,41 @@ class MemoryMessageCollection[TMessage: IMessage]( MemoryCollection[TMessage, MessageOrdinal] ): """A collection of messages.""" + + def __init__( + self, + items: list[TMessage] | None = None, + message_text_index: IMessageTextIndex[TMessage] | None = None, + ): + super().__init__(items) + self.message_text_index = message_text_index + + async def append(self, item: TMessage) -> None: + msg_id = await self.size() + self.items.append(item) + if self.message_text_index is not None: + await self.message_text_index.add_messages_starting_at(msg_id, [item]) + + async def extend( + self, + items: Iterable[TMessage], + chunk_embeddings: list[NormalizedEmbedding] | None = None, + index_messages: bool = True, + ) -> None: + items_list = list(items) + if not items_list: + return + current_size = await self.size() + self.items.extend(items_list) + if index_messages and self.message_text_index is not None: + if chunk_embeddings is not None: + await self.message_text_index.add_messages_starting_at_with_embeddings( + current_size, + items_list, + chunk_embeddings, + ) + else: + await self.message_text_index.add_messages_starting_at( + current_size, + items_list, + ) diff --git a/src/typeagent/storage/memory/messageindex.py b/src/typeagent/storage/memory/messageindex.py index efcc4ddf..b8eca68c 100644 --- a/src/typeagent/storage/memory/messageindex.py +++ b/src/typeagent/storage/memory/messageindex.py @@ -75,25 +75,67 @@ async def add_messages[TMessage: IMessage]( messages: Iterable[TMessage], ) -> None: base_message_ordinal: MessageOrdinal = await self.text_location_index.size() - all_chunks: list[tuple[str, TextLocation]] = [] - # Collect everything so we can batch efficiently. - for message_ordinal, message in enumerate(messages, base_message_ordinal): - for chunk_ordinal, chunk in enumerate(message.text_chunks): - all_chunks.append((chunk, TextLocation(message_ordinal, chunk_ordinal))) - await self.text_location_index.add_text_locations(all_chunks) - - async def add_messages_starting_at( + message_list = list(messages) + if not message_list: + return + + chunk_texts: list[str] = [] + for message in message_list: + chunk_texts.extend(message.text_chunks) + + chunk_embeddings = await self.text_location_index.generate_embeddings( + chunk_texts, + cache=False, + ) + await self.add_messages_starting_at_with_embeddings( + base_message_ordinal, + message_list, + chunk_embeddings, + ) + + async def add_messages_starting_at[TMessage: IMessage]( self, start_message_ordinal: int, - messages: list[IMessage], + messages: list[TMessage], ) -> None: """Add messages to the index starting at the given ordinal.""" - all_chunks: list[tuple[str, TextLocation]] = [] + chunk_texts: list[str] = [] + for message in messages: + chunk_texts.extend(message.text_chunks) + + chunk_embeddings = await self.text_location_index.generate_embeddings( + chunk_texts, + cache=False, + ) + await self.add_messages_starting_at_with_embeddings( + start_message_ordinal, + messages, + chunk_embeddings, + ) + + async def add_messages_starting_at_with_embeddings[TMessage: IMessage]( + self, + start_message_ordinal: int, + messages: list[TMessage], + chunk_embeddings: list[NormalizedEmbedding], + ) -> None: + """Add messages starting at an ordinal using precomputed chunk embeddings.""" + text_locations: list[TextLocation] = [] for idx, message in enumerate(messages): msg_ord = start_message_ordinal + idx - for chunk_ord, chunk in enumerate(message.text_chunks): - all_chunks.append((chunk, TextLocation(msg_ord, chunk_ord))) - await self.text_location_index.add_text_locations(all_chunks) + for chunk_ord, _chunk in enumerate(message.text_chunks): + text_locations.append(TextLocation(msg_ord, chunk_ord)) + + if len(text_locations) != len(chunk_embeddings): + raise ValueError( + "messages and chunk_embeddings produced different chunk counts: " + f"{len(text_locations)} != {len(chunk_embeddings)}" + ) + + await self.text_location_index.add_text_locations_with_embeddings( + text_locations, + chunk_embeddings, + ) async def lookup_messages( self, diff --git a/src/typeagent/storage/memory/provider.py b/src/typeagent/storage/memory/provider.py index e697fe01..230b5b0a 100644 --- a/src/typeagent/storage/memory/provider.py +++ b/src/typeagent/storage/memory/provider.py @@ -51,13 +51,15 @@ def __init__( ) -> None: """Create and initialize a MemoryStorageProvider with all indexes.""" self._metadata = metadata or ConversationMetadata() - self._message_collection = MemoryMessageCollection[TMessage]() + self._message_text_index = MessageTextIndex(message_text_settings) + self._message_collection = MemoryMessageCollection[TMessage]( + message_text_index=self._message_text_index + ) self._semantic_ref_collection = MemorySemanticRefCollection() self._conversation_index = TermToSemanticRefIndex() self._property_index = PropertyIndex() self._timestamp_index = TimestampToTextRangeIndex() - self._message_text_index = MessageTextIndex(message_text_settings) self._related_terms_index = RelatedTermsIndex(related_terms_settings) thread_settings = message_text_settings.embedding_index_settings self._conversation_threads = ConversationThreads(thread_settings) diff --git a/src/typeagent/storage/memory/reltermsindex.py b/src/typeagent/storage/memory/reltermsindex.py index d9e682c2..ec074bf9 100644 --- a/src/typeagent/storage/memory/reltermsindex.py +++ b/src/typeagent/storage/memory/reltermsindex.py @@ -4,6 +4,9 @@ from collections.abc import Callable from typing import Protocol, TYPE_CHECKING +import numpy as np + +from typeagent.aitools.embeddings import NormalizedEmbedding from typeagent.aitools.vectorbase import ( ScoredInt, TextEmbeddingIndexSettings, @@ -285,7 +288,25 @@ async def size(self) -> int: return len(self._vectorbase) async def add_terms(self, texts: list[str]) -> None: - await self._vectorbase.add_keys(texts) + if not texts: + return + embeddings = await self._vectorbase.get_embeddings(texts) + await self.add_terms_with_embeddings(texts, list(embeddings)) + + async def add_terms_with_embeddings( + self, + texts: list[str], + embeddings: list[NormalizedEmbedding], + ) -> None: + if len(texts) != len(embeddings): + raise ValueError( + "texts and embeddings must have the same length: " + f"{len(texts)} != {len(embeddings)}" + ) + if not texts: + return + embedding_array = np.stack(embeddings, axis=0).astype(np.float32, copy=False) + self._vectorbase.add_embeddings(texts, embedding_array) self._texts.extend(texts) async def lookup_term( diff --git a/src/typeagent/storage/sqlite/collections.py b/src/typeagent/storage/sqlite/collections.py index fe394dcb..71a256df 100644 --- a/src/typeagent/storage/sqlite/collections.py +++ b/src/typeagent/storage/sqlite/collections.py @@ -7,6 +7,7 @@ import sqlite3 import typing +from ...aitools.embeddings import NormalizedEmbedding from ...knowpro import interfaces, serialization from .schema import ShreddedMessage, ShreddedSemanticRef @@ -186,7 +187,12 @@ async def append(self, item: TMessage) -> None: if self.message_text_index is not None: await self.message_text_index.add_messages_starting_at(msg_id, [item]) - async def extend(self, items: typing.Iterable[TMessage]) -> None: + async def extend( + self, + items: typing.Iterable[TMessage], + chunk_embeddings: list[NormalizedEmbedding] | None = None, + index_messages: bool = True, + ) -> None: items_list = list(items) # Convert to list to iterate twice if not items_list: return @@ -229,10 +235,17 @@ async def extend(self, items: typing.Iterable[TMessage]) -> None: ) # Also add to message text index if available - if self.message_text_index is not None: - await self.message_text_index.add_messages_starting_at( - current_size, items_list - ) + if index_messages and self.message_text_index is not None: + if chunk_embeddings is not None: + # Use precomputed embeddings (avoids redundant embedding work) + await self.message_text_index.add_messages_starting_at_with_embeddings( + current_size, items_list, chunk_embeddings + ) + else: + # Compute embeddings on the fly + await self.message_text_index.add_messages_starting_at( + current_size, items_list + ) class SqliteSemanticRefCollection(interfaces.ISemanticRefCollection): diff --git a/src/typeagent/storage/sqlite/messageindex.py b/src/typeagent/storage/sqlite/messageindex.py index d48a9761..457c1ae5 100644 --- a/src/typeagent/storage/sqlite/messageindex.py +++ b/src/typeagent/storage/sqlite/messageindex.py @@ -58,24 +58,56 @@ async def add_messages_starting_at( messages: list[interfaces.IMessage], ) -> None: """Add messages to the text index starting at the given ordinal.""" - chunks_to_embed: list[tuple[int, int, str]] = [] - for msg_ord, message in enumerate(messages, start_message_ordinal): - for chunk_ord, chunk in enumerate(message.text_chunks): - chunks_to_embed.append((msg_ord, chunk_ord, chunk)) + chunks_to_embed: list[str] = [] + for _msg_ord, message in enumerate(messages, start_message_ordinal): + for _chunk_ord, chunk in enumerate(message.text_chunks): + chunks_to_embed.append(chunk) if not chunks_to_embed: return embeddings = await self._vectorbase.get_embeddings( - [chunk for _, _, chunk in chunks_to_embed], cache=False + chunks_to_embed, + cache=False, + ) + + await self.add_messages_starting_at_with_embeddings( + start_message_ordinal, + messages, + [embedding for embedding in embeddings], ) + async def add_messages_starting_at_with_embeddings( + self, + start_message_ordinal: int, + messages: list[interfaces.IMessage], + chunk_embeddings: list[NormalizedEmbedding], + ) -> None: + """Add messages to the text index using precomputed chunk embeddings.""" + chunk_locations: list[tuple[int, int]] = [] + for msg_ord, message in enumerate(messages, start_message_ordinal): + for chunk_ord, _chunk in enumerate(message.text_chunks): + chunk_locations.append((msg_ord, chunk_ord)) + + if len(chunk_locations) != len(chunk_embeddings): + raise ValueError( + "messages and chunk_embeddings produced different chunk counts: " + f"{len(chunk_locations)} != {len(chunk_embeddings)}" + ) + + if not chunk_locations: + return + + current_size = len(self._vectorbase) + embedding_array = np.stack(chunk_embeddings, axis=0).astype( + np.float32, copy=False + ) + self._vectorbase.add_embeddings(None, embedding_array) + insertion_data: list[tuple[int, int, bytes, int]] = [] - for idx, ((msg_ord, chunk_ord, _), embedding) in enumerate( - zip(chunks_to_embed, embeddings) + for idx, ((msg_ord, chunk_ord), embedding) in enumerate( + zip(chunk_locations, chunk_embeddings) ): - # Get the current VectorBase size to determine the index position - current_size = len(self._vectorbase) index_position = current_size + idx insertion_data.append( (msg_ord, chunk_ord, serialize_embedding(embedding), index_position) diff --git a/src/typeagent/storage/sqlite/reltermsindex.py b/src/typeagent/storage/sqlite/reltermsindex.py index cf5b201b..1a27af71 100644 --- a/src/typeagent/storage/sqlite/reltermsindex.py +++ b/src/typeagent/storage/sqlite/reltermsindex.py @@ -213,15 +213,43 @@ async def add_terms(self, texts: list[str]) -> None: new_terms = [t for t in texts if t not in self._added_terms] if not new_terms: return + embeddings = await self._vector_base.get_embeddings(new_terms) + await self.add_terms_with_embeddings( + new_terms, + list(embeddings), + ) - embeddings = await self._vector_base.add_keys(new_terms) - assert embeddings is not None + async def add_terms_with_embeddings( + self, + texts: list[str], + embeddings: list[NormalizedEmbedding], + ) -> None: + if len(texts) != len(embeddings): + raise ValueError( + "texts and embeddings must have the same length: " + f"{len(texts)} != {len(embeddings)}" + ) + + new_term_embeddings = [ + (term, embedding) + for term, embedding in zip(texts, embeddings) + if term not in self._added_terms + ] + if not new_term_embeddings: + return + + new_terms = [term for term, _ in new_term_embeddings] + new_embeddings = [embedding for _, embedding in new_term_embeddings] + embedding_array = np.stack(new_embeddings, axis=0).astype( + np.float32, copy=False + ) + self._vector_base.add_embeddings(new_terms, embedding_array) cursor = self.db.cursor() cursor.executemany( "INSERT OR REPLACE INTO RelatedTermsFuzzy (term, term_embedding) VALUES (?, ?)", [ - (term, serialize_embedding(embeddings[i])) + (term, serialize_embedding(new_embeddings[i])) for i, term in enumerate(new_terms) ], ) diff --git a/src/typeagent/transcripts/transcript.py b/src/typeagent/transcripts/transcript.py index 08c4fdae..a4733cdf 100644 --- a/src/typeagent/transcripts/transcript.py +++ b/src/typeagent/transcripts/transcript.py @@ -14,7 +14,6 @@ from ..knowpro.interfaces import ConversationDataWithIndexes, SemanticRef, Term from ..knowpro.universal_message import ConversationMessage, ConversationMessageMeta from ..storage.memory.convthreads import ConversationThreads -from ..storage.memory.messageindex import MessageTextIndex # Type aliases for backward compatibility TranscriptMessage = ConversationMessage @@ -82,7 +81,9 @@ async def deserialize( message_list = [ TranscriptMessage.deserialize(m) for m in transcript_data["messages"] ] - await self.messages.extend(message_list) + # Message index data is deserialized later and replaces prior state, + # so skip incremental indexing while bulk-loading messages. + await self.messages.extend(message_list, index_messages=False) semantic_refs_data = transcript_data.get("semanticRefs") if semantic_refs_data is not None: @@ -120,16 +121,9 @@ async def deserialize( message_index_data = transcript_data.get("messageIndexData") if message_index_data is not None: secondary_indexes = self._get_secondary_indexes() - # Assert the message index is empty before deserializing assert ( secondary_indexes.message_index is not None ), "Message index should be initialized" - - if isinstance(secondary_indexes.message_index, MessageTextIndex): - index_size = await secondary_indexes.message_index.size() - assert ( - index_size == 0 - ), "Message index must be empty before deserializing" await secondary_indexes.message_index.deserialize(message_index_data) # Don't rebuild aliases/synonyms since they were deserialized from relatedTermsIndexData @@ -180,9 +174,7 @@ def _read_conversation_data_from_file( @staticmethod async def read_from_file( - filename_prefix: str, - settings: ConversationSettings, - dbname: str | None = None, + filename_prefix: str, settings: ConversationSettings, dbname: str | None = None ) -> "Transcript": data = Transcript._read_conversation_data_from_file(filename_prefix) diff --git a/tests/test_add_messages_pipeline.py b/tests/test_add_messages_pipeline.py new file mode 100644 index 00000000..41be8d5c --- /dev/null +++ b/tests/test_add_messages_pipeline.py @@ -0,0 +1,819 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +from __future__ import annotations + +import asyncio +from collections.abc import AsyncIterator +from dataclasses import dataclass, field + +import numpy as np +import pytest + +import typechat + +from typeagent.aitools.embeddings import NormalizedEmbedding, NormalizedEmbeddings +from typeagent.knowpro import knowledge_schema as kplib +from typeagent.knowpro.add_messages import ( + _collect_related_terms_for_fuzzy_index, + _dispatcher_task, + _producer_task, + _reassembler_task, + ChunkProcessingResult, + ChunkWorkItem, + PipelineStopState, + process_chunk_with_extraction_and_embeddings, + ProducerState, +) +from typeagent.knowpro.interfaces_core import ( + DeletionInfo, + IMessageMetadata, + TextLocation, +) + + +@dataclass +class _Message: + text_chunks: list[str] + tags: list[str] = field(default_factory=list) + timestamp: str | None = None + deletion_info: DeletionInfo | None = None + metadata: IMessageMetadata | None = None + source_id: str | None = None + + def get_knowledge(self) -> kplib.KnowledgeResponse: + return _empty_knowledge() + + +class _SequenceExtractor: + def __init__( + self, outputs: list[typechat.Result[kplib.KnowledgeResponse] | Exception] + ) -> None: + self._outputs = outputs + self.calls: list[str] = [] + + async def extract(self, message: str) -> typechat.Result[kplib.KnowledgeResponse]: + self.calls.append(message) + output = self._outputs[len(self.calls) - 1] + if isinstance(output, Exception): + raise output + return output + + +class _StubEmbeddingModel: + def __init__( + self, + *, + chunk_error: Exception | None = None, + related_error: Exception | None = None, + ) -> None: + self.chunk_error = chunk_error + self.related_error = related_error + self.chunk_calls: list[str] = [] + self.related_calls: list[list[str]] = [] + self._cache: dict[str, NormalizedEmbedding] = {} + + @property + def model_name(self) -> str: + return "test-embedding" + + def add_embedding(self, key: str, embedding: NormalizedEmbedding) -> None: + self._cache[key] = embedding + + async def get_embedding_nocache(self, input: str) -> NormalizedEmbedding: + self.chunk_calls.append(input) + if self.chunk_error is not None: + raise self.chunk_error + return _embedding([1.0, 0.0]) + + async def get_embeddings_nocache(self, input: list[str]) -> NormalizedEmbeddings: + self.related_calls.append(input) + if self.related_error is not None: + raise self.related_error + return np.array([_embedding([0.0, 1.0]) for _ in input], dtype=np.float32) + + async def get_embedding(self, key: str) -> NormalizedEmbedding: + cached = self._cache.get(key) + if cached is not None: + return cached + embedding = await self.get_embedding_nocache(key) + self._cache[key] = embedding + return embedding + + async def get_embeddings(self, keys: list[str]) -> NormalizedEmbeddings: + if not keys: + raise ValueError("Cannot embed an empty list") + output: list[NormalizedEmbedding] = [] + missing: list[str] = [] + for key in keys: + cached = self._cache.get(key) + if cached is None: + missing.append(key) + else: + output.append(cached) + if missing: + fresh = await self.get_embeddings_nocache(missing) + for index, key in enumerate(missing): + self._cache[key] = fresh[index] + return np.array([self._cache[key] for key in keys], dtype=np.float32) + + +class _FailingAsyncMessages: + def __init__(self, messages: list[_Message], error: Exception) -> None: + self.messages = messages + self.error = error + + def __aiter__(self) -> AsyncIterator[_Message]: + return self._iter() + + async def _iter(self) -> AsyncIterator[_Message]: + for message in self.messages: + yield message + raise self.error + + +class _StopMutatingChunks(list[str]): + """Iterable that lowers stop_at_message_id after yielding first chunk.""" + + def __init__(self, stop_state: PipelineStopState, chunks: list[str]) -> None: + super().__init__(chunks) + self._stop_state = stop_state + + def __iter__(self): + for index, chunk in enumerate(super().__iter__()): + if index == 1: + self._stop_state.stop_at_message_id = 0 + yield chunk + + +def _embedding(values: list[float]) -> NormalizedEmbedding: + return np.array(values, dtype=np.float32) + + +def _empty_knowledge() -> kplib.KnowledgeResponse: + return kplib.KnowledgeResponse( + entities=[], actions=[], inverse_actions=[], topics=[] + ) + + +def _knowledge_with_terms() -> kplib.KnowledgeResponse: + entity = kplib.ConcreteEntity( + name=" Alice ", + type=["Person"], + facets=[kplib.Facet(name="Role", value="Engineer")], + ) + action = kplib.Action( + verbs=["Mentors"], + verb_tense="present", + subject_entity_name="Alice", + object_entity_name="Bob", + ) + return kplib.KnowledgeResponse( + entities=[entity], + actions=[action], + inverse_actions=[], + topics=["ALICE", " Mentorship "], + ) + + +async def _drain_result_queue( + queue: asyncio.Queue[ChunkProcessingResult[_Message] | None], +) -> list[ChunkProcessingResult[_Message] | None]: + items: list[ChunkProcessingResult[_Message] | None] = [] + while True: + item = await queue.get() + items.append(item) + if item is None: + return items + + +@pytest.mark.asyncio +async def test_collect_related_terms_lowercases_dedupes_and_preserves_order() -> None: + knowledge = _knowledge_with_terms() + + terms = _collect_related_terms_for_fuzzy_index(knowledge) + + assert terms[0] == "alice" + assert "mentorship" in terms + assert len(terms) == len(set(terms)) + + +@pytest.mark.asyncio +async def test_process_chunk_success_with_related_terms() -> None: + extractor = _SequenceExtractor([typechat.Success(_knowledge_with_terms())]) + message_model = _StubEmbeddingModel() + + result = await process_chunk_with_extraction_and_embeddings( + chunk_id=TextLocation(0, 0), + chunk_text="hello", + chunk_count=1, + message=_Message(["hello"]), + knowledge_extractor=extractor, + embedding_model=message_model, + ) + + assert result.error is None + assert result.extracted_knowledge is not None + assert result.chunk_embedding is not None + assert result.related_terms is not None + assert result.related_term_embeddings is not None + assert len(result.related_terms) == len(result.related_term_embeddings) + assert message_model.chunk_calls == ["hello"] + assert len(message_model.related_calls) == 1 + + +@pytest.mark.asyncio +async def test_process_chunk_extraction_failure_returns_error() -> None: + """A Failure result from the extractor sets error. + + Chunk embedding may still run because extraction and chunk embedding are + launched concurrently. + """ + extractor = _SequenceExtractor([typechat.Failure("bad extraction")]) + message_model = _StubEmbeddingModel() + + result = await process_chunk_with_extraction_and_embeddings( + chunk_id=TextLocation(0, 0), + chunk_text="hello", + chunk_count=1, + message=_Message(["hello"]), + knowledge_extractor=extractor, + embedding_model=message_model, + ) + + assert isinstance(result.error, RuntimeError) + assert "bad extraction" in str(result.error) + assert result.extracted_knowledge is None + assert message_model.chunk_calls == ["hello"] + + +@pytest.mark.asyncio +async def test_process_chunk_extraction_exception_returns_error() -> None: + extractor = _SequenceExtractor([RuntimeError("extract boom")]) + message_model = _StubEmbeddingModel() + + result = await process_chunk_with_extraction_and_embeddings( + chunk_id=TextLocation(0, 0), + chunk_text="hello", + chunk_count=1, + message=_Message(["hello"]), + knowledge_extractor=extractor, + embedding_model=message_model, + ) + + assert isinstance(result.error, RuntimeError) + assert "extract boom" in str(result.error) + + +@pytest.mark.asyncio +async def test_process_chunk_chunk_embedding_exception_returns_error() -> None: + extractor = _SequenceExtractor([typechat.Success(_empty_knowledge())]) + message_model = _StubEmbeddingModel(chunk_error=RuntimeError("embed boom")) + + result = await process_chunk_with_extraction_and_embeddings( + chunk_id=TextLocation(0, 0), + chunk_text="hello", + chunk_count=1, + message=_Message(["hello"]), + knowledge_extractor=extractor, + embedding_model=message_model, + ) + + assert isinstance(result.error, RuntimeError) + assert "embed boom" in str(result.error) + + +@pytest.mark.asyncio +async def test_process_chunk_related_term_embedding_exception_returns_error() -> None: + extractor = _SequenceExtractor([typechat.Success(_knowledge_with_terms())]) + message_model = _StubEmbeddingModel(related_error=RuntimeError("related boom")) + + result = await process_chunk_with_extraction_and_embeddings( + chunk_id=TextLocation(0, 0), + chunk_text="hello", + chunk_count=1, + message=_Message(["hello"]), + knowledge_extractor=extractor, + embedding_model=message_model, + ) + + assert isinstance(result.error, RuntimeError) + assert "related boom" in str(result.error) + + +@pytest.mark.asyncio +async def test_producer_enqueues_all_chunks_and_sentinel() -> None: + messages = [_Message(["a", "b"]), _Message(["c"])] + queue: asyncio.Queue[ChunkWorkItem[_Message] | None] = asyncio.Queue() + stop_state = PipelineStopState() + producer_state = ProducerState(next_message_id=0) + + async def _iter_messages() -> AsyncIterator[_Message]: + for message in messages: + yield message + + result_queue = asyncio.Queue() + await _producer_task( + _iter_messages(), queue, stop_state, producer_state, result_queue, None + ) + + items: list[ChunkWorkItem[_Message] | None] = [ + await queue.get(), + await queue.get(), + await queue.get(), + await queue.get(), + ] + + assert [item.chunk_id for item in items[:-1] if item is not None] == [ + TextLocation(0, 0), + TextLocation(0, 1), + TextLocation(1, 0), + ] + assert items[-1] is None + assert producer_state.produced_messages == 2 + assert producer_state.produced_chunks == 3 + assert producer_state.exception is None + + +@pytest.mark.asyncio +async def test_producer_stops_at_stop_marker() -> None: + queue: asyncio.Queue[ChunkWorkItem[_Message] | None] = asyncio.Queue() + stop_state = PipelineStopState(stop_at_message_id=1) + producer_state = ProducerState(next_message_id=0) + + async def _iter_messages() -> AsyncIterator[_Message]: + yield _Message(["a"]) + yield _Message(["b"]) + + result_queue = asyncio.Queue() + await _producer_task( + _iter_messages(), queue, stop_state, producer_state, result_queue, None + ) + + first = await queue.get() + sentinel = await queue.get() + + assert first is not None + assert first.chunk_id == TextLocation(0, 0) + assert sentinel is None + assert producer_state.produced_messages == 1 + + +@pytest.mark.asyncio +async def test_producer_sets_exception_and_still_sends_sentinel() -> None: + queue: asyncio.Queue[ChunkWorkItem[_Message] | None] = asyncio.Queue() + stop_state = PipelineStopState() + producer_state = ProducerState(next_message_id=0) + + failing_iter = _FailingAsyncMessages([_Message(["a"])], RuntimeError("input boom")) + + result_queue = asyncio.Queue() + await _producer_task( + failing_iter, queue, stop_state, producer_state, result_queue, None + ) + + first = await queue.get() + sentinel = await queue.get() + + assert first is not None + assert first.chunk_id == TextLocation(0, 0) + assert sentinel is None + assert isinstance(producer_state.exception, RuntimeError) + assert "input boom" in str(producer_state.exception) + + +@pytest.mark.asyncio +async def test_producer_breaks_inside_chunk_loop_when_stop_marker_changes() -> None: + queue: asyncio.Queue[ChunkWorkItem[_Message] | None] = asyncio.Queue() + stop_state = PipelineStopState() + producer_state = ProducerState(next_message_id=0) + + message = _Message(["a", "b", "c"]) + message.text_chunks = _StopMutatingChunks(stop_state, ["a", "b", "c"]) + + async def _iter_messages() -> AsyncIterator[_Message]: + yield message + + result_queue = asyncio.Queue() + await _producer_task( + _iter_messages(), queue, stop_state, producer_state, result_queue, None + ) + + first = await queue.get() + sentinel = await queue.get() + + assert first is not None + assert first.chunk_id == TextLocation(0, 0) + assert sentinel is None + assert producer_state.produced_chunks == 1 + + +@pytest.mark.asyncio +async def test_dispatcher_stops_on_sentinel_and_emits_result_sentinel() -> None: + chunk_queue: asyncio.Queue[ChunkWorkItem[_Message] | None] = asyncio.Queue() + result_queue = asyncio.Queue() + stop_state = PipelineStopState() + + message = _Message(["hello"]) + await chunk_queue.put( + ChunkWorkItem( + chunk_id=TextLocation(0, 0), + chunk_count=1, + chunk_text="hello", + message=message, + ) + ) + await chunk_queue.put(None) + + extractor = _SequenceExtractor([typechat.Success(_empty_knowledge())]) + model = _StubEmbeddingModel() + + await _dispatcher_task( + chunk_queue, + result_queue, + stop_state, + extractor, + model, + concurrency=2, + skip_failed_messages=False, + ) + + items = await _drain_result_queue(result_queue) + assert len(items) == 2 + assert items[-1] is None + assert items[0] is not None + assert items[0].error is None + + +@pytest.mark.asyncio +async def test_dispatcher_extraction_failure_lowers_stop() -> None: + """A Failure from the extractor sets error and lowers stop_at_message_id.""" + chunk_queue: asyncio.Queue[ChunkWorkItem[_Message] | None] = asyncio.Queue() + result_queue = asyncio.Queue() + stop_state = PipelineStopState() + + m0 = _Message(["first"]) + m1 = _Message(["second"]) + await chunk_queue.put( + ChunkWorkItem( + chunk_id=TextLocation(0, 0), chunk_count=1, chunk_text="first", message=m0 + ) + ) + await chunk_queue.put( + ChunkWorkItem( + chunk_id=TextLocation(1, 0), chunk_count=1, chunk_text="second", message=m1 + ) + ) + await chunk_queue.put(None) + + extractor = _SequenceExtractor( + [typechat.Failure("first failed"), typechat.Success(_empty_knowledge())] + ) + model = _StubEmbeddingModel() + + await _dispatcher_task( + chunk_queue, + result_queue, + stop_state, + extractor, + model, + concurrency=1, + skip_failed_messages=False, + ) + + items = await _drain_result_queue(result_queue) + first = items[0] + second = items[1] + + assert first is not None + assert isinstance(first.error, RuntimeError) + assert "first failed" in str(first.error) + + # Second chunk is skipped because stop_at_message_id was lowered to 0. + assert second is not None + assert second.error is not None + + assert stop_state.stop_at_message_id == 0 + assert extractor.calls == ["first"] + + +@pytest.mark.asyncio +async def test_dispatcher_extraction_failure_skips_and_keeps_processing() -> None: + chunk_queue: asyncio.Queue[ChunkWorkItem[_Message] | None] = asyncio.Queue() + result_queue = asyncio.Queue() + stop_state = PipelineStopState() + + m0 = _Message(["first"]) + m1 = _Message(["second"]) + await chunk_queue.put( + ChunkWorkItem( + chunk_id=TextLocation(0, 0), chunk_count=1, chunk_text="first", message=m0 + ) + ) + await chunk_queue.put( + ChunkWorkItem( + chunk_id=TextLocation(1, 0), chunk_count=1, chunk_text="second", message=m1 + ) + ) + await chunk_queue.put(None) + + extractor = _SequenceExtractor( + [typechat.Failure("first failed"), typechat.Success(_empty_knowledge())] + ) + model = _StubEmbeddingModel() + + await _dispatcher_task( + chunk_queue, + result_queue, + stop_state, + extractor, + model, + concurrency=1, + skip_failed_messages=True, + ) + + items = await _drain_result_queue(result_queue) + first = items[0] + second = items[1] + + assert first is not None + assert isinstance(first.error, RuntimeError) + assert "first failed" in str(first.error) + + assert second is not None + assert second.error is None + + assert stop_state.stop_at_message_id == 10**100 + assert extractor.calls == ["first", "second"] + + +def _chunk_result( + message: _Message, + message_ordinal: int, + chunk_ordinal: int, + chunk_count: int, + *, + error: Exception | None = None, +) -> ChunkProcessingResult[_Message]: + return ChunkProcessingResult( + chunk_id=TextLocation(message_ordinal, chunk_ordinal), + chunk_count=chunk_count, + message=message, + error=error, + ) + + +@pytest.mark.asyncio +async def test_reassembler_commits_out_of_order_after_gap_is_filled() -> None: + result_queue = asyncio.Queue() + stop_state = PipelineStopState() + + m0 = _Message(["m0"]) + m1 = _Message(["m1"]) + await result_queue.put(_chunk_result(m1, 1, 0, 1)) + await result_queue.put(_chunk_result(m0, 0, 0, 1)) + await result_queue.put(None) + + committed_batches: list[tuple[int, int]] = [] + + async def _commit( + messages: list[_Message], results: list[ChunkProcessingResult[_Message]] + ) -> None: + committed_batches.append((len(messages), len(results))) + + state = await _reassembler_task( + result_queue, + stop_state, + first_uncommitted_ordinal=0, + target_commit_chunk_count=10, + commit_batch=_commit, + skip_failed_messages=False, + ) + + assert committed_batches == [(2, 2)] + assert state.messages_committed == 2 + assert state.chunks_committed == 2 + assert state.first_uncommitted_ordinal == 2 + assert state.buffered_messages == 0 + + +@pytest.mark.asyncio +async def test_reassembler_marks_failure_and_blocks_later_commits() -> None: + result_queue = asyncio.Queue() + stop_state = PipelineStopState() + + m0 = _Message(["m0"]) + m1 = _Message(["m1"]) + await result_queue.put(_chunk_result(m0, 0, 0, 1, error=RuntimeError("boom"))) + await result_queue.put(_chunk_result(m1, 1, 0, 1)) + await result_queue.put(None) + + commit_calls = 0 + + async def _commit( + messages: list[_Message], results: list[ChunkProcessingResult[_Message]] + ) -> None: + nonlocal commit_calls + commit_calls += 1 + + state = await _reassembler_task( + result_queue, + stop_state, + first_uncommitted_ordinal=0, + target_commit_chunk_count=1, + commit_batch=_commit, + skip_failed_messages=False, + ) + + assert commit_calls == 0 + assert state.chunk_failures == 1 + assert state.messages_committed == 0 + assert state.buffered_messages == 2 + assert stop_state.stop_at_message_id == 0 + + +@pytest.mark.asyncio +async def test_reassembler_skips_failed_message_and_commits_later_messages() -> None: + result_queue = asyncio.Queue() + stop_state = PipelineStopState() + + m0 = _Message(["m0"]) + m1 = _Message(["m1"]) + await result_queue.put(_chunk_result(m0, 0, 0, 1, error=RuntimeError("boom"))) + await result_queue.put(_chunk_result(m1, 1, 0, 1)) + await result_queue.put(None) + + committed_batches: list[tuple[int, int]] = [] + + async def _commit( + messages: list[_Message], results: list[ChunkProcessingResult[_Message]] + ) -> None: + committed_batches.append((len(messages), len(results))) + + state = await _reassembler_task( + result_queue, + stop_state, + first_uncommitted_ordinal=0, + target_commit_chunk_count=1, + commit_batch=_commit, + skip_failed_messages=True, + ) + + assert committed_batches == [(1, 1)] + assert state.chunk_failures == 1 + assert state.messages_skipped == 1 + assert state.messages_committed == 1 + assert state.chunks_committed == 1 + assert state.buffered_messages == 0 + assert state.first_uncommitted_ordinal == 2 + assert stop_state.stop_at_message_id == 10**100 + + +@pytest.mark.asyncio +async def test_reassembler_force_commits_small_staged_tail() -> None: + result_queue = asyncio.Queue() + stop_state = PipelineStopState() + + message = _Message(["m0"]) + await result_queue.put(_chunk_result(message, 0, 0, 1)) + await result_queue.put(None) + + commit_calls = 0 + + async def _commit( + messages: list[_Message], results: list[ChunkProcessingResult[_Message]] + ) -> None: + nonlocal commit_calls + commit_calls += 1 + + state = await _reassembler_task( + result_queue, + stop_state, + first_uncommitted_ordinal=0, + target_commit_chunk_count=99, + commit_batch=_commit, + skip_failed_messages=False, + ) + + assert commit_calls == 1 + assert state.messages_committed == 1 + assert state.chunks_committed == 1 + + +@pytest.mark.asyncio +async def test_reassembler_raises_on_invalid_chunk_ordinal_and_sets_stop_marker() -> ( + None +): + result_queue = asyncio.Queue() + stop_state = PipelineStopState() + + message = _Message(["m0", "m0b"]) + await result_queue.put(_chunk_result(message, 3, 2, 2)) + await result_queue.put(None) + + async def _commit( + messages: list[_Message], results: list[ChunkProcessingResult[_Message]] + ) -> None: + return None + + with pytest.raises(RuntimeError, match="Invalid chunk ordinal"): + await _reassembler_task( + result_queue, + stop_state, + first_uncommitted_ordinal=0, + target_commit_chunk_count=1, + commit_batch=_commit, + skip_failed_messages=False, + ) + + assert stop_state.stop_at_message_id == 3 + + +@pytest.mark.asyncio +async def test_reassembler_raises_on_duplicate_chunk_and_sets_stop_marker() -> None: + result_queue = asyncio.Queue() + stop_state = PipelineStopState() + + message = _Message(["m1-a", "m1-b"]) + await result_queue.put(_chunk_result(message, 5, 0, 2)) + await result_queue.put(_chunk_result(message, 5, 0, 2)) + await result_queue.put(None) + + async def _commit( + messages: list[_Message], results: list[ChunkProcessingResult[_Message]] + ) -> None: + return None + + with pytest.raises(RuntimeError, match="Duplicate chunk"): + await _reassembler_task( + result_queue, + stop_state, + first_uncommitted_ordinal=0, + target_commit_chunk_count=1, + commit_batch=_commit, + skip_failed_messages=False, + ) + + assert stop_state.stop_at_message_id == 5 + + +@pytest.mark.asyncio +async def test_reassembler_raises_on_mismatched_chunk_count_and_sets_stop_marker() -> ( + None +): + result_queue = asyncio.Queue() + stop_state = PipelineStopState() + + message = _Message(["m0-a", "m0-b", "m0-c"]) + await result_queue.put(_chunk_result(message, 4, 0, 2)) + await result_queue.put(_chunk_result(message, 4, 1, 3)) + await result_queue.put(None) + + async def _commit( + messages: list[_Message], results: list[ChunkProcessingResult[_Message]] + ) -> None: + return None + + with pytest.raises(RuntimeError, match="Mismatched chunk count"): + await _reassembler_task( + result_queue, + stop_state, + first_uncommitted_ordinal=0, + target_commit_chunk_count=1, + commit_batch=_commit, + skip_failed_messages=False, + ) + + assert stop_state.stop_at_message_id == 4 + + +@pytest.mark.asyncio +async def test_reassembler_handles_existing_assembly_non_duplicate_chunk() -> None: + result_queue = asyncio.Queue() + stop_state = PipelineStopState() + + message = _Message(["m0-a", "m0-b"]) + await result_queue.put(_chunk_result(message, 0, 0, 2)) + await result_queue.put(_chunk_result(message, 0, 1, 2)) + await result_queue.put(None) + + commit_calls = 0 + + async def _commit( + messages: list[_Message], results: list[ChunkProcessingResult[_Message]] + ) -> None: + nonlocal commit_calls + commit_calls += 1 + + state = await _reassembler_task( + result_queue, + stop_state, + first_uncommitted_ordinal=0, + target_commit_chunk_count=1, + commit_batch=_commit, + skip_failed_messages=False, + ) + + assert commit_calls == 1 + assert state.messages_committed == 1 + assert state.chunks_committed == 2 diff --git a/tests/test_add_messages_streaming.py b/tests/test_add_messages_streaming.py index 9a8b1fd6..cd1f15e2 100644 --- a/tests/test_add_messages_streaming.py +++ b/tests/test_add_messages_streaming.py @@ -3,7 +3,6 @@ """Tests for add_messages_streaming.""" -import asyncio from collections.abc import AsyncIterator import os import tempfile @@ -14,8 +13,9 @@ from typeagent.aitools.model_adapters import create_test_embedding_model from typeagent.knowpro import knowledge_schema as kplib +from typeagent.knowpro.add_messages import add_messages_streaming from typeagent.knowpro.convsettings import ConversationSettings -from typeagent.knowpro.interfaces_core import AddMessagesResult, IKnowledgeExtractor +from typeagent.knowpro.interfaces_core import IKnowledgeExtractor from typeagent.storage.sqlite.provider import SqliteStorageProvider from typeagent.transcripts.transcript import ( Transcript, @@ -132,7 +132,7 @@ async def test_streaming_basic() -> None: transcript, storage = await _create_transcript(db_path) msgs = [_make_message(f"msg-{i}") for i in range(5)] - result = await transcript.add_messages_streaming(_async_iter(msgs)) + result = await add_messages_streaming(transcript, _async_iter(msgs)) assert result.messages_added == 5 assert await transcript.messages.size() == 5 @@ -148,8 +148,8 @@ async def test_streaming_batching() -> None: transcript, storage = await _create_transcript(db_path) msgs = [_make_message(f"msg-{i}", source_id=f"s-{i}") for i in range(7)] - result = await transcript.add_messages_streaming( - _async_iter(msgs), batch_size=3 + result = await add_messages_streaming( + transcript, _async_iter(msgs), batch_size=3 ) # 3 batches: [0,1,2], [3,4,5], [6] @@ -169,7 +169,7 @@ async def test_streaming_no_source_id_always_ingested() -> None: transcript, storage = await _create_transcript(db_path) msgs = [_make_message(f"msg-{i}") for i in range(3)] - result = await transcript.add_messages_streaming(_async_iter(msgs)) + result = await add_messages_streaming(transcript, _async_iter(msgs)) assert result.messages_added == 3 assert _ingested_count(storage) == 0 # no source IDs to track @@ -178,11 +178,11 @@ async def test_streaming_no_source_id_always_ingested() -> None: @pytest.mark.asyncio -async def test_streaming_records_chunk_failures() -> None: - """Extraction Failure results are recorded, not raised.""" +async def test_streaming_extraction_failure_stops_at_failing_message() -> None: + """Extraction Failure raises and stops processing; messages before the failure are committed.""" with tempfile.TemporaryDirectory() as tmpdir: db_path = os.path.join(tmpdir, "test.db") - extractor = ControlledExtractor(fail_on={1}) # second chunk fails + extractor = ControlledExtractor(fail_on={1}) # second message fails transcript, storage = await _create_transcript( db_path, auto_extract=True, knowledge_extractor=extractor ) @@ -192,16 +192,10 @@ async def test_streaming_records_chunk_failures() -> None: _make_message("bad chunk 1"), _make_message("good chunk 2"), ] - result = await transcript.add_messages_streaming(_async_iter(msgs)) + with pytest.raises(RuntimeError): + await add_messages_streaming(transcript, _async_iter(msgs)) - assert result.messages_added == 3 - assert _failure_count(storage) == 1 - - failures = await storage.get_chunk_failures() - assert len(failures) == 1 - assert failures[0].message_ordinal == 1 - assert failures[0].chunk_ordinal == 0 - assert "Extraction failed" in failures[0].error_message + assert await transcript.messages.size() == 1 await storage.close() @@ -219,14 +213,8 @@ async def test_streaming_exception_stops_run() -> None: msgs = [_make_message(f"msg-{i}", source_id=f"s-{i}") for i in range(6)] - with pytest.raises(ExceptionGroup) as exc_info: - await transcript.add_messages_streaming(_async_iter(msgs), batch_size=3) - - # Verify the wrapped exception is our RuntimeError - assert any( - isinstance(e, RuntimeError) and "Systemic failure" in str(e) - for e in exc_info.value.exceptions - ) + with pytest.raises(RuntimeError, match="Systemic failure"): + await add_messages_streaming(transcript, _async_iter(msgs), batch_size=3) # First batch (3 messages, 3 extract calls 0-2) committed assert await transcript.messages.size() == 3 @@ -242,7 +230,7 @@ async def test_streaming_empty_iterable() -> None: db_path = os.path.join(tmpdir, "test.db") transcript, storage = await _create_transcript(db_path) - result = await transcript.add_messages_streaming(_async_iter([])) + result = await add_messages_streaming(transcript, _async_iter([])) assert result.messages_added == 0 assert result.semrefs_added == 0 @@ -265,7 +253,8 @@ async def test_streaming_on_batch_committed_fires_per_batch() -> None: msgs = [_make_message(f"msg-{i}", source_id=f"s-{i}") for i in range(7)] batch_results: list[int] = [] - result = await transcript.add_messages_streaming( + result = await add_messages_streaming( + transcript, _async_iter(msgs), batch_size=3, on_batch_committed=lambda r: batch_results.append(r.messages_added), @@ -289,8 +278,8 @@ async def test_streaming_extraction_with_multiple_batches() -> None: ) msgs = [_make_message(f"msg-{i}", source_id=f"s-{i}") for i in range(6)] - result = await transcript.add_messages_streaming( - _async_iter(msgs), batch_size=3 + result = await add_messages_streaming( + transcript, _async_iter(msgs), batch_size=3 ) assert result.messages_added == 6 @@ -304,27 +293,21 @@ async def test_streaming_extraction_with_multiple_batches() -> None: @pytest.mark.asyncio async def test_streaming_extraction_failure_across_batches() -> None: - """Extraction failures are recorded with correct global ordinals across batches.""" + """Extraction failure in a later batch leaves earlier batches committed.""" with tempfile.TemporaryDirectory() as tmpdir: db_path = os.path.join(tmpdir, "test.db") - # Fail on call index 1 (batch 0, msg 1) and 4 (batch 1, msg 1) - extractor = ControlledExtractor(fail_on={1, 4}) + # Fail on call index 3 (first message of second batch) + extractor = ControlledExtractor(fail_on={3}) transcript, storage = await _create_transcript( db_path, auto_extract=True, knowledge_extractor=extractor ) msgs = [_make_message(f"msg-{i}", source_id=f"s-{i}") for i in range(6)] - result = await transcript.add_messages_streaming( - _async_iter(msgs), batch_size=3 - ) - - assert result.messages_added == 6 - assert _failure_count(storage) == 2 + with pytest.raises(RuntimeError): + await add_messages_streaming(transcript, _async_iter(msgs), batch_size=3) - failures = await storage.get_chunk_failures() - failure_ordinals = sorted(f.message_ordinal for f in failures) - # msg 1 in batch 0 → global ordinal 1, msg 1 in batch 1 → global ordinal 4 - assert failure_ordinals == [1, 4] + # First batch (messages 0-2) committed; second batch stopped at message 3. + assert await transcript.messages.size() == 3 await storage.close() @@ -341,13 +324,8 @@ async def test_streaming_exception_in_later_batch_preserves_earlier() -> None: ) msgs = [_make_message(f"msg-{i}", source_id=f"s-{i}") for i in range(6)] - with pytest.raises(ExceptionGroup) as exc_info: - await transcript.add_messages_streaming(_async_iter(msgs), batch_size=3) - - assert any( - isinstance(e, RuntimeError) and "Systemic failure" in str(e) - for e in exc_info.value.exceptions - ) + with pytest.raises(RuntimeError, match="Systemic failure"): + await add_messages_streaming(transcript, _async_iter(msgs), batch_size=3) # Batch 0 committed (3 messages), batch 1 rolled back assert await transcript.messages.size() == 3 @@ -429,7 +407,7 @@ async def test_streaming_extraction_with_empty_text_chunks() -> None: ), _make_message("has content", source_id="has-content"), ] - result = await transcript.add_messages_streaming(_async_iter(msgs)) + result = await add_messages_streaming(transcript, _async_iter(msgs)) assert result.messages_added == 2 # Only the message with content triggers extraction @@ -470,7 +448,7 @@ async def test_streaming_multi_chunk_extraction() -> None: _make_multi_chunk_message(["c0", "c1", "c2"], source_id="s-0"), _make_message("single chunk", source_id="s-1"), ] - result = await transcript.add_messages_streaming(_async_iter(msgs)) + result = await add_messages_streaming(transcript, _async_iter(msgs)) assert result.messages_added == 2 assert result.chunks_added == 4 # 3 + 1 @@ -492,7 +470,8 @@ async def test_streaming_batch_size_counts_chunks() -> None: _make_message("d", source_id="s-1"), # 1 chunk ] batch_results: list[int] = [] - result = await transcript.add_messages_streaming( + result = await add_messages_streaming( + transcript, _async_iter(msgs), batch_size=3, on_batch_committed=lambda r: batch_results.append(r.messages_added), @@ -519,7 +498,8 @@ async def test_streaming_large_message_exceeds_batch_size() -> None: _make_message("small", source_id="s-small"), ] batch_results: list[int] = [] - result = await transcript.add_messages_streaming( + result = await add_messages_streaming( + transcript, _async_iter(msgs), batch_size=3, on_batch_committed=lambda r: batch_results.append(r.messages_added), @@ -549,7 +529,8 @@ async def test_streaming_mixed_chunk_sizes_batching() -> None: _make_message("e", source_id="s-4"), # 1 chunk, total=3 → flush ] batch_results: list[int] = [] - result = await transcript.add_messages_streaming( + result = await add_messages_streaming( + transcript, _async_iter(msgs), batch_size=3, on_batch_committed=lambda r: batch_results.append(r.messages_added), @@ -564,32 +545,25 @@ async def test_streaming_mixed_chunk_sizes_batching() -> None: @pytest.mark.asyncio -async def test_streaming_multi_chunk_failure_ordinals() -> None: - """Extraction failures in multi-chunk messages record correct ordinals.""" +async def test_streaming_multi_chunk_failure_stops_message() -> None: + """Extraction failure in a chunk stops that message and all later ones.""" with tempfile.TemporaryDirectory() as tmpdir: db_path = os.path.join(tmpdir, "test.db") - # Fail on call index 1 (chunk 1 of first message) and 3 (chunk 0 of second message) - extractor = ControlledExtractor(fail_on={1, 3}) + # Fail on call index 1 (chunk 1 of first message) + extractor = ControlledExtractor(fail_on={1}) transcript, storage = await _create_transcript( db_path, auto_extract=True, knowledge_extractor=extractor ) msgs = [ - _make_multi_chunk_message( - ["c0", "c1", "c2"], source_id="s-0" - ), # calls 0,1,2 - _make_multi_chunk_message(["d0", "d1"], source_id="s-1"), # calls 3,4 + _make_multi_chunk_message(["c0", "c1", "c2"], source_id="s-0"), + _make_multi_chunk_message(["d0", "d1"], source_id="s-1"), ] - result = await transcript.add_messages_streaming(_async_iter(msgs)) + with pytest.raises(RuntimeError): + await add_messages_streaming(transcript, _async_iter(msgs)) - assert result.messages_added == 2 - assert extractor.call_count == 5 - assert _failure_count(storage) == 2 - - failures = await storage.get_chunk_failures() - failure_locs = sorted((f.message_ordinal, f.chunk_ordinal) for f in failures) - # call 1 → msg 0, chunk 1; call 3 → msg 1, chunk 0 - assert failure_locs == [(0, 1), (1, 0)] + # Message 0 had a chunk failure so nothing is committed. + assert await transcript.messages.size() == 0 await storage.close() @@ -610,8 +584,8 @@ async def test_streaming_multi_chunk_exception_preserves_earlier_batch() -> None _make_multi_chunk_message(["d", "e"], source_id="s-1"), # batch 2 ] - with pytest.raises(ExceptionGroup): - await transcript.add_messages_streaming(_async_iter(msgs), batch_size=3) + with pytest.raises(RuntimeError, match="Systemic failure"): + await add_messages_streaming(transcript, _async_iter(msgs), batch_size=3) # Batch 1 committed (1 message, 3 chunks), batch 2 rolled back assert await transcript.messages.size() == 1 @@ -629,7 +603,8 @@ async def test_streaming_batch_size_1_separates_all() -> None: msgs = [_make_message(f"msg-{i}", source_id=f"s-{i}") for i in range(4)] batch_results: list[int] = [] - result = await transcript.add_messages_streaming( + result = await add_messages_streaming( + transcript, _async_iter(msgs), batch_size=1, on_batch_committed=lambda r: batch_results.append(r.messages_added), @@ -660,7 +635,8 @@ async def test_streaming_preflush_avoids_oversized_batch() -> None: for i in range(4) ] batch_chunks: list[int] = [] - result = await transcript.add_messages_streaming( + result = await add_messages_streaming( + transcript, _async_iter(msgs), batch_size=10, on_batch_committed=lambda r: batch_chunks.append(r.chunks_added), @@ -674,116 +650,6 @@ async def test_streaming_preflush_avoids_oversized_batch() -> None: await storage.close() -# --------------------------------------------------------------------------- -# Coverage gap tests -# --------------------------------------------------------------------------- - - -class SlowExtractor: - """Extractor that blocks on an event, allowing tests to control timing.""" - - def __init__(self, block_from: int) -> None: - self.call_count = 0 - self.block_from = block_from - self.blocked = asyncio.Event() - self.cancelled = False - - async def extract(self, message: str) -> typechat.Result[kplib.KnowledgeResponse]: - idx = self.call_count - self.call_count += 1 - if idx >= self.block_from: - self.blocked.set() - try: - await asyncio.sleep(60) - except asyncio.CancelledError: - self.cancelled = True - raise - return typechat.Success(_EMPTY_RESPONSE) - - -@pytest.mark.asyncio -async def test_streaming_pending_extraction_cancelled_on_commit_failure() -> None: - """pending_extraction is cancelled when a prior commit raises during _drain_commit. - - Timeline: - 1. Batch 0: extraction succeeds (calls 0-2, fast), commit task created - (pending_commit = failing_commit) - 2. Batch 1: extraction task created (pending_extraction, calls 3+, slow), - _drain_commit awaits batch 0's pending_commit which raises - 3. except block: pending_extraction (batch 1's) is still in-flight → cancelled - """ - with tempfile.TemporaryDirectory() as tmpdir: - db_path = os.path.join(tmpdir, "test.db") - # Block extraction starting from call 3 (first call of batch 1) - # so that pending_extraction is still running when the except fires - extractor = SlowExtractor(block_from=3) - transcript, storage = await _create_transcript( - db_path, auto_extract=True, knowledge_extractor=extractor - ) - - async def failing_commit(*args, **kwargs): - raise RuntimeError("Simulated commit failure") - - transcript._commit_batch_streaming = failing_commit # type: ignore[assignment] - - msgs = [_make_message(f"msg-{i}", source_id=f"s-{i}") for i in range(6)] - - with pytest.raises(RuntimeError, match="Simulated commit failure"): - await transcript.add_messages_streaming(_async_iter(msgs), batch_size=3) - - assert extractor.cancelled - - await storage.close() - - -@pytest.mark.asyncio -async def test_streaming_pending_commit_cancelled_on_iterator_error() -> None: - """pending_commit is cancelled when the message iterator raises. - - After batch 0 is submitted (pending_commit in flight), the async iterator - raises on the next message. The except block must cancel the still-running - pending_commit. - """ - - async def _error_after( - items: list[TranscriptMessage], error_after: int - ) -> AsyncIterator[TranscriptMessage]: - for i, item in enumerate(items): - if i == error_after: - # Yield to event loop so pending tasks start running - await asyncio.sleep(0) - raise ValueError("Iterator error") - yield item - - with tempfile.TemporaryDirectory() as tmpdir: - db_path = os.path.join(tmpdir, "test.db") - transcript, storage = await _create_transcript(db_path) - - commit_cancelled = False - - async def slow_commit(*args, **kwargs): - nonlocal commit_cancelled - try: - await asyncio.sleep(60) - except asyncio.CancelledError: - commit_cancelled = True - raise - return AddMessagesResult() - - transcript._commit_batch_streaming = slow_commit # type: ignore[assignment] - - msgs = [_make_message(f"msg-{i}", source_id=f"s-{i}") for i in range(6)] - - with pytest.raises(ValueError, match="Iterator error"): - await transcript.add_messages_streaming( - _error_after(msgs, error_after=4), batch_size=3 - ) - - assert commit_cancelled - - await storage.close() - - @pytest.mark.asyncio async def test_streaming_empty_iterator() -> None: """Streaming with an empty iterator returns zeros.""" @@ -793,11 +659,11 @@ async def test_streaming_empty_iterator() -> None: # Ingest one real message, then do a second call with an empty iterator msgs = [_make_message("msg-0", source_id="s-0")] - r1 = await transcript.add_messages_streaming(_async_iter(msgs)) + r1 = await add_messages_streaming(transcript, _async_iter(msgs)) assert r1.messages_added == 1 # Empty iterator → _submit_batch never called with content - r2 = await transcript.add_messages_streaming(_async_iter([])) + r2 = await add_messages_streaming(transcript, _async_iter([])) assert r2.messages_added == 0 assert r2.messages_skipped == 0 @@ -832,7 +698,7 @@ async def test_streaming_extraction_returns_none_for_empty_chunks() -> None: source_id="empty-1", ), ] - result = await transcript.add_messages_streaming(_async_iter(msgs)) + result = await add_messages_streaming(transcript, _async_iter(msgs)) assert result.messages_added == 2 assert result.chunks_added == 0 diff --git a/tests/test_messageindex.py b/tests/test_messageindex.py index 7e00cc45..54eea485 100644 --- a/tests/test_messageindex.py +++ b/tests/test_messageindex.py @@ -30,7 +30,16 @@ def mock_text_location_index() -> MagicMock: mock_index = MagicMock(spec=TextToTextLocationIndex) # Empty index, so first message starts at ordinal 0 mock_index.size = AsyncMock(return_value=0) + + def _fake_generate_embeddings( + texts: list[str], cache: bool = True + ) -> list[np.ndarray]: + del cache + return [np.array([1.0, 0.0, 0.0], dtype=np.float32) for _ in texts] + + mock_index.generate_embeddings = AsyncMock(side_effect=_fake_generate_embeddings) mock_index.add_text_locations = AsyncMock(return_value=None) + mock_index.add_text_locations_with_embeddings = AsyncMock(return_value=None) mock_index.lookup_text = AsyncMock(return_value=[]) mock_index.lookup_text_in_subset = AsyncMock(return_value=[]) mock_index.serialize = MagicMock(return_value={"mock": "data"}) @@ -79,25 +88,20 @@ async def test_add_messages( await message_text_index.add_messages(messages) - # Check that add_text_locations was called with the expected text and location data + # Check that add_text_locations_with_embeddings was called with expected locations + # and one embedding per chunk. mock_text_loc_index = cast( MagicMock, cast(MessageTextIndex, message_text_index).text_location_index ) - call_args = mock_text_loc_index.add_text_locations.call_args + call_args = mock_text_loc_index.add_text_locations_with_embeddings.call_args assert call_args is not None - text_and_locations = call_args[0][0] # First positional argument - assert ( - len(text_and_locations) == 3 - ) # Two chunks from first message, one from second - assert text_and_locations[0] == ( - "chunk1", - TextLocation(0, 0), - ) # First message starts at ordinal 0 - assert text_and_locations[1] == ("chunk2", TextLocation(0, 1)) - assert text_and_locations[2] == ( - "chunk3", - TextLocation(1, 0), - ) # Second message at ordinal 1 + text_locations = call_args[0][0] # First positional argument + embeddings = call_args[0][1] # Second positional argument + assert len(text_locations) == 3 # Two chunks from first message, one from second + assert text_locations[0] == TextLocation(0, 0) # First message starts at ordinal 0 + assert text_locations[1] == TextLocation(0, 1) + assert text_locations[2] == TextLocation(1, 0) # Second message at ordinal 1 + assert len(embeddings) == 3 @pytest.mark.asyncio diff --git a/tools/ingest_email.py b/tools/ingest_email.py index 34c4c6ce..0baa407f 100644 --- a/tools/ingest_email.py +++ b/tools/ingest_email.py @@ -21,7 +21,9 @@ import asyncio from collections.abc import AsyncIterator from datetime import datetime +import os from pathlib import Path +import signal import sys import time from typing import Iterable @@ -273,7 +275,9 @@ def _iter_emails( yield str(email_file.resolve()), email_file, label -def _print_email_verbose(email: EmailMessage) -> None: +def _print_email_verbose( + email: EmailMessage, original_chunk_count: int | None = None +) -> None: """Print verbose details for an email.""" print(f" From: {decode_encoded_words(email.metadata.sender)}") if email.metadata.recipients: @@ -289,7 +293,14 @@ def _print_email_verbose(email: EmailMessage) -> None: f" Subject: {decode_encoded_words(email.metadata.subject).replace('\n', '\\n')}" ) print(f" Date: {email.timestamp}") - print(f" Body chunks: {len(email.text_chunks)}") + if original_chunk_count is not None and original_chunk_count != len( + email.text_chunks + ): + print( + f" Body chunks: {len(email.text_chunks)} (clipped from {original_chunk_count})" + ) + else: + print(f" Body chunks: {len(email.text_chunks)}") MAIL_PREVIEW_LEN = 80 for chunk in email.text_chunks: preview = repr(chunk[: MAIL_PREVIEW_LEN + 1])[1:-1] @@ -341,15 +352,14 @@ async def _email_generator( print(f"{label} [Outside date range, skipping]") continue - if verbose: - print(label) - _print_email_verbose(email) - # Truncate chunks if --max-chunks is set + original_chunk_count: int | None = None if max_chunks is not None and len(email.text_chunks) > max_chunks: - if verbose: - print(f" Truncating {len(email.text_chunks)} chunks to {max_chunks}") + original_chunk_count = len(email.text_chunks) email.text_chunks = email.text_chunks[:max_chunks] + if verbose: + print(label) + _print_email_verbose(email, original_chunk_count) # Set source_id so streaming API handles dedup and tracking email.source_id = source_id @@ -452,10 +462,10 @@ def on_batch_committed(result: AddMessagesResult) -> None: f"+{result.semrefs_added} semrefs", ] print( - f"{' '.join(parts)} | " + f"\n{' '.join(parts)} | " f"{batch_secs:.1f}s ({per_chunk:.2f}s/chunk) | " f"{counters['ingested']} total ingested | " - f"{elapsed:.1f}s elapsed", + f"{elapsed:.1f}s elapsed\n", flush=True, ) @@ -471,14 +481,41 @@ def on_batch_committed(result: AddMessagesResult) -> None: result: AddMessagesResult | None = None interrupted = False + shutdown_event = asyncio.Event() + loop = asyncio.get_running_loop() + _sigint_count = 0 + _main_task = asyncio.current_task() + + def _on_sigint() -> None: + nonlocal _sigint_count + _sigint_count += 1 + if _sigint_count == 1: + print( + "\nInterrupt received; stopping after current batch completes " + "(press ^C again to force quit)...", + flush=True, + ) + shutdown_event.set() + else: + print("\nForce quit.", flush=True) + # Bypass cooperative cancellation on repeated Ctrl+C. + # Some pending async operations may not respond promptly and can hang. + os._exit(130) + + loop.add_signal_handler(signal.SIGINT, _on_sigint) try: result = await email_memory.add_messages_streaming( message_stream, batch_size=batch_size, on_batch_committed=on_batch_committed, + skip_failed_messages=True, + shutdown_event=shutdown_event, ) + interrupted = shutdown_event.is_set() except (KeyboardInterrupt, asyncio.CancelledError): interrupted = True + finally: + loop.remove_signal_handler(signal.SIGINT) # Final summary elapsed = time.time() - start_time @@ -540,20 +577,28 @@ def main() -> None: start_date = _parse_date(args.start_date) if args.start_date else None stop_date = _parse_date(args.stop_date) if args.stop_date else None - asyncio.run( - ingest_emails( - eml_paths=args.paths, - database=args.database, - verbose=args.verbose, - start_date=start_date, - stop_date=stop_date, - offset=args.offset, - limit=args.limit, - concurrency=args.concurrency, - batch_size=args.batch_size, - max_chunks=args.max_chunks, + try: + asyncio.run( + ingest_emails( + eml_paths=args.paths, + database=args.database, + verbose=args.verbose, + start_date=start_date, + stop_date=stop_date, + offset=args.offset, + limit=args.limit, + concurrency=args.concurrency, + batch_size=args.batch_size, + max_chunks=args.max_chunks, + ) ) - ) + except KeyboardInterrupt: + pass + except Exception as exc: + if args.verbose: + raise + print(f"Error: {exc}", file=sys.stderr) + sys.exit(1) if __name__ == "__main__":