From 8f60d1c95f21f2f4584007dd3b34824d2b81eafb Mon Sep 17 00:00:00 2001 From: jschweiz Date: Fri, 17 Apr 2026 15:40:53 +0200 Subject: [PATCH 1/4] [feat]: add checkpointing APIs --- docs/wayflowcore/source/core/changelog.rst | 11 +- .../core/code_examples/howto_checkpointing.py | 69 +++++++ .../core/howtoguides/howto_checkpointing.rst | 119 +++++++++++ .../source/core/howtoguides/index.rst | 1 + .../agentserver/_storagehelpers.py | 94 ++------- .../services/wayflowservice.py | 189 ++++++++---------- .../agentserver/serverstorageconfig.py | 47 +---- .../contextproviders/flowcontextprovider.py | 1 + .../wayflowcore/conversationalcomponent.py | 104 ++++++++++ .../wayflowcore/executors/_agentexecutor.py | 17 +- .../wayflowcore/executors/_flowexecutor.py | 2 +- .../executors/_managerworkersconversation.py | 3 +- .../tokenlimitexecutioninterrupt.py | 2 +- .../src/wayflowcore/models/llmmodel.py | 4 +- .../src/wayflowcore/serialization/context.py | 29 ++- .../wayflowcore/steps/agentexecutionstep.py | 4 +- .../src/wayflowcore/tools/servertools.py | 1 + wayflowcore/tests/test_managerworkers.py | 4 +- wayflowcore/tests/test_swarm.py | 4 +- 19 files changed, 449 insertions(+), 256 deletions(-) create mode 100644 docs/wayflowcore/source/core/code_examples/howto_checkpointing.py create mode 100644 docs/wayflowcore/source/core/howtoguides/howto_checkpointing.rst diff --git a/docs/wayflowcore/source/core/changelog.rst b/docs/wayflowcore/source/core/changelog.rst index 5e62fc465..c14515946 100644 --- a/docs/wayflowcore/source/core/changelog.rst +++ b/docs/wayflowcore/source/core/changelog.rst @@ -39,14 +39,23 @@ New features For more information read the :doc:`API Reference on LLM models ` and the guide on :doc:`how to use LLMs from different providers `. 
- * **Logprob support in `LlmGenerationConfig` and `PromptExecutionStep`** Add per-token log-probabilities support with the ``top_logprobs`` generation config parameter and support returning per-token log-probabilities in the ``PromptExecutionStep``. For more information please read the guide on :ref:`How to request per-token log-probabilities ` +* **First-class conversation checkpointing** + + Added shared conversation checkpointing for Agents, Flows, Swarms, ManagerWorkers, and A2A agents through + ``ConversationCheckpoint``, ``Checkpointer``, ``InMemoryCheckpointer``, ``PostgresCheckpointer``, and + ``OracleDatabaseCheckpointer``. Conversations can now resume from ``conversation_id``, load specific checkpoints for + time-travel debugging, and choose checkpoint save frequency with ``CheckpointingInterval``. + + The OpenAI Responses server path now uses this shared checkpointing subsystem as well, so persisted + ``previous_response_id`` and ``conversation`` behavior is handled through the same checkpoint model. + For more information, see :doc:`how to checkpoint and resume conversations <howtoguides/howto_checkpointing>`. Improvements ^^^^^^^^^^^^ diff --git a/docs/wayflowcore/source/core/code_examples/howto_checkpointing.py b/docs/wayflowcore/source/core/code_examples/howto_checkpointing.py new file mode 100644 index 000000000..bcfc88cc4 --- /dev/null +++ b/docs/wayflowcore/source/core/code_examples/howto_checkpointing.py @@ -0,0 +1,69 @@ +# Copyright © 2025, 2026 Oracle and/or its affiliates. +# +# This software is under the Apache License 2.0 +# (LICENSE-APACHE or http://www.apache.org/licenses/LICENSE-2.0) or Universal Permissive License +# (UPL) 1.0 (LICENSE-UPL or https://oss.oracle.com/licenses/upl), at your option. + +# isort:skip_file +# fmt: off +# mypy: ignore-errors +# docs-title: Code Example - How to Checkpoint and Resume Conversations + +# .. 
start-##_Configure_your_LLM +from wayflowcore.models import VllmModel + +llm = VllmModel( + model_id="LLAMA_MODEL_ID", + host_port="LLAMA_API_URL", +) +# .. end-##_Configure_your_LLM + +llm: VllmModel # docs-skiprow +(llm,) = _update_globals(["llm_small"]) # docs-skiprow # type: ignore + +# .. start-##_Start_a_checkpointed_conversation +from wayflowcore import Agent +from wayflowcore.checkpointing import InMemoryCheckpointer + +agent = Agent(llm=llm) +checkpointer = InMemoryCheckpointer() + +conversation = agent.start_conversation( + conversation_id="support-thread-1", + checkpointer=checkpointer, +) + +status = conversation.execute() +# .. end-##_Start_a_checkpointed_conversation + +# .. start-##_Resume_the_latest_checkpoint +restored_conversation = agent.start_conversation( + conversation_id="support-thread-1", + checkpointer=checkpointer, +) + +restored_conversation.append_user_message("Continue from where you left off.") +status = restored_conversation.execute() +# .. end-##_Resume_the_latest_checkpoint + +# .. start-##_Load_a_specific_checkpoint +checkpoints = checkpointer.list_checkpoints("support-thread-1") + +previous_checkpoint = checkpoints[-2] +rewound_conversation = agent.start_conversation( + conversation_id="support-thread-1", + checkpoint_id=previous_checkpoint.checkpoint_id, + checkpointer=checkpointer, +) + +rewound_conversation.append_user_message("Try a different path from here.") +status = rewound_conversation.execute() +# .. end-##_Load_a_specific_checkpoint + +# .. start-##_Control_checkpoint_frequency +from wayflowcore.checkpointing import CheckpointingInterval, InMemoryCheckpointer + +checkpointer = InMemoryCheckpointer( + checkpointing_interval=CheckpointingInterval.ALL_INTERNAL_TURNS, +) +# .. 
end-##_Control_checkpoint_frequency diff --git a/docs/wayflowcore/source/core/howtoguides/howto_checkpointing.rst b/docs/wayflowcore/source/core/howtoguides/howto_checkpointing.rst new file mode 100644 index 000000000..e22d9a64e --- /dev/null +++ b/docs/wayflowcore/source/core/howtoguides/howto_checkpointing.rst @@ -0,0 +1,119 @@ +.. _top-howtocheckpointing: + +========================================= +How to Checkpoint and Resume Conversations +========================================= + +.. admonition:: Prerequisites + + This guide assumes familiarity with: + + - :doc:`Agents <../tutorials/basic_agent>` + - :doc:`Flows <../tutorials/basic_flow>` + - :doc:`Serve Agents with WayFlow ` + +WayFlow can now checkpoint the runtime state of a conversation and restore it later by +conversation id. This is useful when you want to: + +- resume after a crash or restart +- pause and continue a long-running workflow +- inspect prior checkpoints for debugging +- reload an earlier state and branch from it + + +Choose a checkpointer +===================== + +WayFlow exposes a shared checkpointing subsystem in ``wayflowcore.checkpointing``. +You can use: + +- ``InMemoryCheckpointer`` for tests and local experimentation +- ``PostgresCheckpointer`` for PostgreSQL-backed persistence +- ``OracleDatabaseCheckpointer`` for Oracle-backed persistence + +All checkpointers share the same API for saving, loading, listing, and deleting checkpoints. + + +Start a checkpointed conversation +================================= + +Attach a checkpointer when you start the conversation. ``conversation_id`` becomes the durable key +used to look up the conversation later. + +.. literalinclude:: ../code_examples/howto_checkpointing.py + :language: python + :start-after: .. start-##_Start_a_checkpointed_conversation + :end-before: .. end-##_Start_a_checkpointed_conversation + +Once checkpointing is enabled, WayFlow saves the root conversation automatically at the configured +checkpoint boundaries. 
For nested execution lineage without checkpoint restore, pass +``root_conversation_id`` explicitly. + + +Resume the latest checkpoint +============================ + +To restore the latest saved state, call ``start_conversation()`` again with the same +``conversation_id`` +and checkpointer. + +.. literalinclude:: ../code_examples/howto_checkpointing.py + :language: python + :start-after: .. start-##_Resume_the_latest_checkpoint + :end-before: .. end-##_Resume_the_latest_checkpoint + +If no checkpoint exists for that id, WayFlow creates a new conversation instead. + + +Load a specific checkpoint +========================== + +You can inspect checkpoint history and reload an older checkpoint for replay or time-travel +debugging. + +.. literalinclude:: ../code_examples/howto_checkpointing.py + :language: python + :start-after: .. start-##_Load_a_specific_checkpoint + :end-before: .. end-##_Load_a_specific_checkpoint + +``list_checkpoints()`` returns ordered checkpoint metadata, including the checkpoint id, +creation timestamp, and save metadata recorded at the boundary. + + +Control checkpoint frequency +============================ + +Use ``CheckpointingInterval`` to decide how often WayFlow should persist state. + +.. literalinclude:: ../code_examples/howto_checkpointing.py + :language: python + :start-after: .. start-##_Control_checkpoint_frequency + :end-before: .. end-##_Control_checkpoint_frequency + +The available options are: + +- ``CONVERSATION_TURNS``: save after the outermost ``conversation.execute()`` call returns +- ``LLM_TURNS``: also save at internal turn boundaries after turns that used an LLM +- ``ALL_INTERNAL_TURNS``: also save at every internal agent/flow turn boundary + +Saving more frequently improves restart fidelity, but it also increases write volume. 
+ + +Use checkpointing with the OpenAI Responses server +================================================== + +The OpenAI Responses server path now uses the shared checkpointing subsystem behind +``ServerStorageConfig``. That means the existing OpenAI-compatible features such as +``previous_response_id``, ``conversation``, ``get_response()``, ``delete_response()``, +and ``store=False`` all run through the same shared checkpoint model. + +If you are serving agents, keep using :doc:`Serve Agents with WayFlow ` to +configure the storage backend. The server will use the matching shared checkpointer internally. + + +Next steps +========== + +- :doc:`Serialize and Deserialize Conversations ` +- :doc:`Serve Agents with WayFlow ` +- :doc:`Build a Swarm of Agents ` diff --git a/docs/wayflowcore/source/core/howtoguides/index.rst b/docs/wayflowcore/source/core/howtoguides/index.rst index 234da7699..d8449b032 100644 --- a/docs/wayflowcore/source/core/howtoguides/index.rst +++ b/docs/wayflowcore/source/core/howtoguides/index.rst @@ -105,6 +105,7 @@ These guides demonstrate how to configure the components of assistants built wit :maxdepth: 1 Load and Execute an Agent Spec Configuration + Checkpoint and Resume Conversations Serialize and Deserialize Flows and Agents Serialize and Deserialize Conversations Build a New WayFlow Component diff --git a/wayflowcore/src/wayflowcore/agentserver/_storagehelpers.py b/wayflowcore/src/wayflowcore/agentserver/_storagehelpers.py index 58b12d97d..6a09bad18 100644 --- a/wayflowcore/src/wayflowcore/agentserver/_storagehelpers.py +++ b/wayflowcore/src/wayflowcore/agentserver/_storagehelpers.py @@ -4,19 +4,18 @@ # (LICENSE-APACHE or http://www.apache.org/licenses/LICENSE-2.0) or Universal Permissive License # (UPL) 1.0 (LICENSE-UPL or https://oss.oracle.com/licenses/upl), at your option. 
import logging -from textwrap import dedent -from typing import Dict, Optional, cast +from typing import Dict, Optional from wayflowcore.agentserver.serverstorageconfig import ServerStorageConfig +from wayflowcore.checkpointing.datastore import ( + _prepare_oracle_checkpoint_datastore, + _prepare_postgres_checkpoint_datastore, +) +from wayflowcore.checkpointing.serialization import _deserialize_conversation_checkpoint_state from wayflowcore.component import Component from wayflowcore.conversation import Conversation -from wayflowcore.datastore.oracle import OracleDatabaseConnectionConfig, _execute_query_on_oracle_db -from wayflowcore.datastore.postgres import ( - PostgresDatabaseConnectionConfig, - _execute_query_on_postgres_db, -) -from wayflowcore.serialization import autodeserialize -from wayflowcore.serialization.context import DeserializationContext +from wayflowcore.datastore.oracle import OracleDatabaseConnectionConfig +from wayflowcore.datastore.postgres import PostgresDatabaseConnectionConfig from wayflowcore.tools import Tool logger = logging.getLogger(__name__) @@ -25,53 +24,13 @@ def _prepare_postgres_datastore( connection_config: PostgresDatabaseConnectionConfig, storage_config: ServerStorageConfig ) -> None: - from sqlalchemy.exc import ProgrammingError - - create_table_query = dedent(f""" - CREATE TABLE {storage_config.table_name} ( - {storage_config.turn_id_column_name} VARCHAR(255) PRIMARY KEY, - {storage_config.agent_id_column_name} VARCHAR(255) NOT NULL, - {storage_config.conversation_id_column_name} VARCHAR(255) NOT NULL, - {storage_config.created_at_column_name} INTEGER NOT NULL, - {storage_config.conversation_turn_state_column_name} TEXT NOT NULL, - {storage_config.is_last_turn_column_name} INTEGER NOT NULL, - {storage_config.extra_metadata_column_name} TEXT NOT NULL - ); - """) - try: - _execute_query_on_postgres_db(connection_config, create_table_query) - except ProgrammingError as e: - if f'relation "{storage_config.table_name}" already 
exists' in str(e): - raise ValueError( - f'The datastore is already setup. Either delete the existing "{storage_config.table_name}" table or start the server with `--setup-datastore=no`.' - ) from e - else: - raise e + _prepare_postgres_checkpoint_datastore(connection_config, storage_config) def _prepare_oracle_datastore( connection_config: OracleDatabaseConnectionConfig, storage_config: ServerStorageConfig ) -> None: - create_table_query = dedent(f""" - CREATE TABLE {storage_config.table_name} ( - {storage_config.turn_id_column_name} VARCHAR2(255) PRIMARY KEY, - {storage_config.agent_id_column_name} VARCHAR2(255) NOT NULL, - {storage_config.conversation_id_column_name} VARCHAR2(255) NOT NULL, - {storage_config.created_at_column_name} INTEGER NOT NULL, - {storage_config.conversation_turn_state_column_name} CLOB NOT NULL, - {storage_config.is_last_turn_column_name} INTEGER NOT NULL, - {storage_config.extra_metadata_column_name} CLOB NOT NULL - ); - """) - try: - _execute_query_on_oracle_db(connection_config, query=create_table_query) - except Exception as e: - if "already exists" in str(e): - raise ValueError( - f'The datastore is already setup. Either delete the existing "{storage_config.table_name}" table or start the server with `--setup-datastore=no`.' - ) from e - else: - raise e + _prepare_oracle_checkpoint_datastore(connection_config, storage_config) def _deserialize_conversation_safely( @@ -79,31 +38,8 @@ def _deserialize_conversation_safely( tool_registry: Optional[Dict[str, Tool]] = None, component: Optional[Component] = None, ) -> Conversation: - """ - Tries to deserialize the conversation. If it does not work, try to deserialize it by considering - the component as a disaggregated component, and will use the already instantiated agent instead of deserializing - it from scratch. 
- """ - deserialization_context = DeserializationContext() - deserialization_context.registered_tools = tool_registry.copy() if tool_registry else {} - try: - conversation = autodeserialize( - serialized_state, deserialization_context=deserialization_context - ) - except (TypeError, ValueError) as e: - if component is None: - raise e - # we try adding the ref to the agent itself, so that we fall back if - # something went wrong during agent deserialization - logger.warning( - "Failed to deserialize conversation by itself: %s. Using a fallback approach that leverages the provided agent as a disaggregated one to deserialize the conversation.", - e, - ) - deserialization_context = DeserializationContext() - deserialization_context.registered_tools = tool_registry.copy() if tool_registry else {} - deserialization_context._add_component_to_context(component) - - conversation = autodeserialize( - serialized_state, deserialization_context=deserialization_context - ) - return cast(Conversation, conversation) + return _deserialize_conversation_checkpoint_state( + serialized_state, + tool_registry=tool_registry, + component=component, + ) diff --git a/wayflowcore/src/wayflowcore/agentserver/openairesponses/services/wayflowservice.py b/wayflowcore/src/wayflowcore/agentserver/openairesponses/services/wayflowservice.py index c9d270573..b7a6f7987 100644 --- a/wayflowcore/src/wayflowcore/agentserver/openairesponses/services/wayflowservice.py +++ b/wayflowcore/src/wayflowcore/agentserver/openairesponses/services/wayflowservice.py @@ -4,8 +4,6 @@ # (LICENSE-APACHE or http://www.apache.org/licenses/LICENSE-2.0) or Universal Permissive License # (UPL) 1.0 (LICENSE-UPL or https://oss.oracle.com/licenses/upl), at your option. 
- -import json import logging import time from typing import Any, AsyncIterable, Dict, List, Optional, Union, cast @@ -16,16 +14,18 @@ from fastapi import status as http_status_code from wayflowcore.agentserver.serverstorageconfig import ServerStorageConfig +from wayflowcore.checkpointing import ConversationCheckpoint, DatastoreCheckpointer +from wayflowcore.checkpointing.runtime import ( + _detach_checkpointer_from_conversation, + _set_conversation_final_checkpoint_overrides, +) from wayflowcore.conversation import Conversation from wayflowcore.conversationalcomponent import ConversationalComponent from wayflowcore.datastore import Datastore, InMemoryDatastore -from wayflowcore.datastore._relational import RelationalDatastore from wayflowcore.events import register_event_listeners from wayflowcore.executors.executionstatus import ExecutionStatus, ToolRequestStatus from wayflowcore.idgeneration import IdGenerator -from wayflowcore.serialization import serialize -from ..._storagehelpers import _deserialize_conversation_safely from ..models.openairesponsespydanticmodels import ( Conversation2, CreateResponse, @@ -66,7 +66,12 @@ def __init__( self.agents = agents self.storage_config = storage_config or ServerStorageConfig() self.storage = storage or InMemoryDatastore(schema=self.storage_config.to_schema()) + self.checkpointer = DatastoreCheckpointer( + datastore=self.storage, + storage_config=self.storage_config, + ) self.created_at = int(time.time()) + self._response_conversation_ids: Dict[str, str] = {} self.tool_registries = { agent_name: {t.name: t for t in agent._referenced_tools()} for agent_name, agent in self.agents.items() @@ -121,23 +126,23 @@ async def get_response( detail="Get endpoint for wayflow server only supports non-streaming requests", ) - try: - metadata = self._lookup_conversation( - where={self.storage_config.turn_id_column_name: response_id}, - what=self.storage_config.extra_metadata_column_name, + checkpoint = 
self._lookup_checkpoint_by_response_id(response_id) + if checkpoint is None: + raise HTTPException( + status_code=http_status_code.HTTP_404_NOT_FOUND, detail="Response not found" ) - except ValueError: + response_as_txt = checkpoint.metadata.get("response") + if not isinstance(response_as_txt, str): raise HTTPException( status_code=http_status_code.HTTP_404_NOT_FOUND, detail="Response not found" ) - response_as_txt = json.loads(metadata)["response"] return Response.model_validate_json(response_as_txt) async def delete_response(self, response_id: str) -> Optional[ResponseError]: - self.storage.delete( - collection_name=self.storage_config.table_name, - where={self.storage_config.turn_id_column_name: response_id}, - ) + checkpoint = self._lookup_checkpoint_by_response_id(response_id) + if checkpoint is not None: + self.checkpointer.delete(checkpoint.conversation_id, checkpoint.checkpoint_id) + self._response_conversation_ids.pop(response_id, None) return None async def cancel_response(self, response_id: str) -> Union[Response, ResponseError]: @@ -193,14 +198,19 @@ async def create_response(self, body: CreateResponse) -> AsyncIterable[ResponseS agent_id=model, ) + response_id = IdGenerator.get_or_generate_id() state = await self._create_state( agent=agent, state=state, request=body, ) + if body.store is None or body.store is True: + _set_conversation_final_checkpoint_overrides(state, checkpoint_id=response_id) + else: + _detach_checkpointer_from_conversation(state) current_response = Response( - id=IdGenerator.get_or_generate_id(), + id=response_id, created_at=int(time.time()), error=None, incomplete_details=None, @@ -295,11 +305,13 @@ async def runner(conversation: Conversation) -> None: token_usage_listener.usage ) - if body.store is None or body.store is True: - self._save_state( - state=state, - response=current_response, + if (body.store is None or body.store is True) and state.checkpointer is not None: + self.checkpointer.save_conversation( + state, + 
checkpoint_id=current_response.id, + metadata={"response": current_response.model_dump_json()}, ) + self._response_conversation_ids[current_response.id] = state.id if current_response.error is not None: yield ResponseFailedEvent( @@ -355,106 +367,58 @@ def _load_state( agent_id: str, ) -> Optional[Conversation]: if previous_response_id: - try: - serialized_conversation = self._lookup_conversation( - where={self.storage_config.turn_id_column_name: previous_response_id}, - what=self.storage_config.conversation_turn_state_column_name, - ) - except ValueError: + checkpoint = self._lookup_checkpoint_by_response_id(previous_response_id) + if checkpoint is None: raise HTTPException( status_code=http_status_code.HTTP_404_NOT_FOUND, detail=f"No previous response with id `{previous_response_id}` was found", ) - elif conversation_id: + self._response_conversation_ids[checkpoint.checkpoint_id] = checkpoint.conversation_id try: - serialized_conversation = self._lookup_conversation( - where={ - self.storage_config.conversation_id_column_name: conversation_id, - self.storage_config.is_last_turn_column_name: 1, # only latest round - }, - what=self.storage_config.conversation_turn_state_column_name, + return self.agents[agent_id].start_conversation( + conversation_id=checkpoint.conversation_id, + checkpoint_id=checkpoint.checkpoint_id, + checkpointer=self.checkpointer, ) - except ValueError: + except (TypeError, ValueError) as e: + raise HTTPException( + status_code=http_status_code.HTTP_500_INTERNAL_SERVER_ERROR, + detail=f"Conversation state is corrupted, it cannot be de-serialized: {e}", + ) from e + elif conversation_id: + checkpoint = self.checkpointer.load_latest(conversation_id) + if checkpoint is None: raise HTTPException( status_code=http_status_code.HTTP_404_NOT_FOUND, detail=f"No conversation with id `{conversation_id}` was found", ) + self._response_conversation_ids[checkpoint.checkpoint_id] = checkpoint.conversation_id + try: + return 
self.agents[agent_id].start_conversation( + conversation_id=conversation_id, + checkpointer=self.checkpointer, + ) + except (TypeError, ValueError) as e: + raise HTTPException( + status_code=http_status_code.HTTP_500_INTERNAL_SERVER_ERROR, + detail=f"Conversation state is corrupted, it cannot be de-serialized: {e}", + ) from e else: return None - try: - return _deserialize_conversation_safely( - serialized_state=serialized_conversation, - tool_registry=self.tool_registries[agent_id], - component=self.agents[agent_id], - ) - except (TypeError, ValueError) as e: - raise HTTPException( - status_code=http_status_code.HTTP_500_INTERNAL_SERVER_ERROR, - detail=f"Conversation state is corrupted, it cannot be de-serialized: {e}", - ) - - def _save_state( - self, - response: Response, - state: Conversation, - ) -> None: - conversation_model = response.conversation - if conversation_model is None: - raise ValueError("Internal Error: Conversation should not be None") - conversation_id = conversation_model.id - if conversation_id is None: - raise ValueError("Internal Error: Conversation ID should not be None") - - updates = {self.storage_config.is_last_turn_column_name: 0} - updates_where = { - self.storage_config.conversation_id_column_name: conversation_id, - self.storage_config.is_last_turn_column_name: 1, - } - serialized_state = serialize(state) - new_entity = { - self.storage_config.agent_id_column_name: response.model, - self.storage_config.conversation_id_column_name: conversation_id, - self.storage_config.turn_id_column_name: response.id, - self.storage_config.created_at_column_name: int(time.time()), - self.storage_config.conversation_turn_state_column_name: serialized_state, - self.storage_config.is_last_turn_column_name: 1, - self.storage_config.extra_metadata_column_name: json.dumps( - {"response": response.model_dump_json()} - ), - } - if isinstance(self.storage, RelationalDatastore): - # for relational datastores, we prefer making a single - # transaction, to 
avoid corrupting the state of the DB - # if the process crashes between the update and the insert - data_table = self.storage.data_tables[self.storage_config.table_name] - sql_update_stmt = data_table._update_query( - where=updates_where, - update=updates, - ) - sql_create_stmt, new_entities = data_table._create_query([new_entity]) - with data_table.engine.connect() as connection: - connection.execute(sql_update_stmt) - connection.execute(sql_create_stmt, new_entities) - connection.commit() - - else: - self.storage.update( - collection_name=self.storage_config.table_name, - where=updates_where, - update=updates, - ) - self.storage.create( - collection_name=self.storage_config.table_name, - entities=[new_entity], - ) - def _lookup_conversation(self, where: Dict[str, Any], what: str) -> Any: - serialized_conversations = self.storage.list( - collection_name=self.storage_config.table_name, where=where - ) - if len(serialized_conversations) != 1: - raise ValueError(f"No conversation with: {where}") - return serialized_conversations[0][what] + def _lookup_checkpoint_by_response_id( + self, response_id: str + ) -> Optional[ConversationCheckpoint]: + conversation_id = self._response_conversation_ids.get(response_id) + if conversation_id is not None: + try: + return self.checkpointer.load(conversation_id, response_id) + except ValueError: + self._response_conversation_ids.pop(response_id, None) + checkpoint = self.checkpointer._find_checkpoint_by_id(response_id) + if checkpoint is not None: + self._response_conversation_ids[response_id] = checkpoint.conversation_id + return checkpoint async def _create_state( self, @@ -481,13 +445,22 @@ async def _create_state( detail="Agent should have an `instructions` input descriptor to be able to take instructions as input", ) inputs = {"instructions": instructions} - state = agent.start_conversation(inputs=inputs, messages=new_messages) + if request.store is None or request.store is True: + state = agent.start_conversation( + 
inputs=inputs, + messages=new_messages, + checkpointer=self.checkpointer, + ) + else: + state = agent.start_conversation(inputs=inputs, messages=new_messages) else: # later: implement context provider for custom instructions if instructions is not None: raise NotImplementedError( "Instructions are only supported when creating a conversation" ) + if request.store is False: + _detach_checkpointer_from_conversation(state) # Add the new messages to the conversation for message in new_messages: state.append_message(message) diff --git a/wayflowcore/src/wayflowcore/agentserver/serverstorageconfig.py b/wayflowcore/src/wayflowcore/agentserver/serverstorageconfig.py index 4a27344e7..f2701cacb 100644 --- a/wayflowcore/src/wayflowcore/agentserver/serverstorageconfig.py +++ b/wayflowcore/src/wayflowcore/agentserver/serverstorageconfig.py @@ -4,52 +4,11 @@ # (LICENSE-APACHE or http://www.apache.org/licenses/LICENSE-2.0) or Universal Permissive License # (UPL) 1.0 (LICENSE-UPL or https://oss.oracle.com/licenses/upl), at your option. 
- from dataclasses import dataclass -from typing import Dict, Optional -from wayflowcore.datastore import Datastore, Entity -from wayflowcore.property import IntegerProperty, StringProperty +from wayflowcore.checkpointing import StorageConfig @dataclass -class ServerStorageConfig: - """Configuration for server storage management.""" - - datastore: Optional[Datastore] = None - """Datastore to use for persistence""" - - table_name: str = "conversations" - """Name of the table in which the states are stored""" - agent_id_column_name: str = "agent_id" - """Name of the column where the agent id of the state is stored""" - conversation_id_column_name: str = "conversation_id" - """Name of the column where the id of the conversation is stored""" - turn_id_column_name: str = "turn_id" - """Name of the column where the turn id / response id is stored""" - created_at_column_name: str = "created_at" - """Name of the column where the creation timestamp is stored""" - conversation_turn_state_column_name: str = "conversation_turn_state" - """Name of the column where the serialized state of turn is store""" - is_last_turn_column_name: str = "is_last_turn" - """Name of the column where the marker for the most recent turn of a given conversation is stored""" - extra_metadata_column_name: str = "extra_metadata" - """Name of the column where the server stores its own attributes""" - - max_retention: Optional[int] = None - """Number of seconds for which to retain a conversation before discarding it""" - - def to_schema(self) -> Dict[str, Entity]: - return { - self.table_name: Entity( - properties={ - self.agent_id_column_name: StringProperty(), - self.conversation_id_column_name: StringProperty(), - self.turn_id_column_name: StringProperty(), - self.is_last_turn_column_name: IntegerProperty(), - self.conversation_turn_state_column_name: StringProperty(), - self.created_at_column_name: IntegerProperty(), - self.extra_metadata_column_name: StringProperty(), - } - ), - } +class 
ServerStorageConfig(StorageConfig): + """Configuration for agent-server conversation storage.""" diff --git a/wayflowcore/src/wayflowcore/contextproviders/flowcontextprovider.py b/wayflowcore/src/wayflowcore/contextproviders/flowcontextprovider.py index c98d1db59..bd622d0ca 100644 --- a/wayflowcore/src/wayflowcore/contextproviders/flowcontextprovider.py +++ b/wayflowcore/src/wayflowcore/contextproviders/flowcontextprovider.py @@ -95,6 +95,7 @@ async def call_async(self, conversation: "Conversation") -> Any: conversation = self.flow.start_conversation( inputs={}, messages=conversation.message_list, + root_conversation_id=conversation.root_conversation_id, ) status = await conversation.execute_async() if status._requires_yielding: diff --git a/wayflowcore/src/wayflowcore/conversationalcomponent.py b/wayflowcore/src/wayflowcore/conversationalcomponent.py index f6eef9d00..2d15b20c2 100644 --- a/wayflowcore/src/wayflowcore/conversationalcomponent.py +++ b/wayflowcore/src/wayflowcore/conversationalcomponent.py @@ -10,13 +10,17 @@ from typing import TYPE_CHECKING, Any, Dict, Generic, List, Optional, Set, Type, TypeVar, Union from wayflowcore._metadata import MetadataType +from wayflowcore.checkpointing.runtime import _attach_checkpointer_to_conversation +from wayflowcore.checkpointing.serialization import _deserialize_conversation_checkpoint_state from wayflowcore.componentwithio import ComponentWithInputsOutputs +from wayflowcore.idgeneration import IdGenerator from wayflowcore.property import Property logger = logging.getLogger(__name__) if TYPE_CHECKING: + from wayflowcore.checkpointing import Checkpointer from wayflowcore.conversation import Conversation from wayflowcore.executors._executor import ConversationExecutor from wayflowcore.messagelist import Message, MessageList @@ -66,6 +70,11 @@ def start_conversation( self, inputs: Optional[Dict[str, Any]] = None, messages: Union[None, str, "Message", List["Message"], "MessageList"] = None, + conversation_id: 
Optional[str] = None, + *, + root_conversation_id: Optional[str] = None, + checkpointer: Optional["Checkpointer"] = None, + checkpoint_id: Optional[str] = None, ) -> "Conversation": pass @@ -115,6 +124,101 @@ def _update_internal_state(self) -> None: Method to update the attributes inside. """ + @staticmethod + def _messages_or_inputs_were_passed( + inputs: Optional[Dict[str, Any]], + messages: Union[None, str, "Message", List["Message"], "MessageList"], + ) -> bool: + """Return True when the caller supplied non-empty ``inputs`` or ``messages``.""" + if inputs: + return True + if messages is None: + return False + # Imported locally: at module level ``Message`` is only imported under TYPE_CHECKING. + from wayflowcore.messagelist import Message + # A single Message always counts; str, list and MessageList count when non-empty. + return isinstance(messages, Message) or len(messages) > 0 + + def _restore_or_prepare_checkpoint_conversation( + self, + *, + inputs: Optional[Dict[str, Any]], + messages: Union[None, str, "Message", List["Message"], "MessageList"], + conversation_id: Optional[str], + root_conversation_id: Optional[str], + checkpointer: Optional["Checkpointer"], + checkpoint_id: Optional[str], + ) -> tuple[Optional["Conversation"], Optional[str]]: + if checkpointer is None: + if checkpoint_id is not None: + raise ValueError("`checkpoint_id` requires a `checkpointer`.") + return None, conversation_id + + if ( + root_conversation_id is not None + and conversation_id is not None + and root_conversation_id != conversation_id + ): + raise ValueError( + "`root_conversation_id` and `conversation_id` cannot differ when checkpointing is enabled."
+ ) + + resolved_conversation_id = conversation_id or root_conversation_id + if resolved_conversation_id is None and checkpoint_id is not None: + raise ValueError("`checkpoint_id` requires a `conversation_id`.") + if resolved_conversation_id is None: + resolved_conversation_id = IdGenerator.get_or_generate_id() + + checkpoint = ( + checkpointer.load(resolved_conversation_id, checkpoint_id) + if checkpoint_id is not None + else checkpointer.load_latest(resolved_conversation_id) + ) + if checkpoint is None: + return None, resolved_conversation_id + + if self._messages_or_inputs_were_passed(inputs=inputs, messages=messages): + raise ValueError( + "Cannot restore a checkpoint while also passing new `inputs` or `messages`. " + "Load the conversation first, then append new user input explicitly." + ) + + conversation = _deserialize_conversation_checkpoint_state( + checkpoint.state, + tool_registry={tool.name: tool for tool in self._referenced_tools()}, + component=self, + ) + return ( + _attach_checkpointer_to_conversation( + conversation, + checkpointer=checkpointer, + checkpoint_id=checkpoint.checkpoint_id, + ), + resolved_conversation_id, + ) + + @staticmethod + def _resolve_runtime_and_root_conversation_ids( + *, + conversation_id: Optional[str], + root_conversation_id: Optional[str], + checkpointer: Optional["Checkpointer"], + restored_conversation_id: Optional[str], + ) -> tuple[str, str]: + if checkpointer is not None: + runtime_conversation_id = restored_conversation_id or IdGenerator.get_or_generate_id( + conversation_id or root_conversation_id + ) + resolved_root_conversation_id = root_conversation_id or runtime_conversation_id + if resolved_root_conversation_id != runtime_conversation_id: + raise ValueError( + "`root_conversation_id` and `conversation_id` cannot differ when checkpointing is enabled." 
+ ) + return runtime_conversation_id, runtime_conversation_id + + runtime_conversation_id = IdGenerator.get_or_generate_id(conversation_id) + return runtime_conversation_id, root_conversation_id or runtime_conversation_id + # Define a TypeVar that represents the component's type ConversationalComponentTypeT = TypeVar( diff --git a/wayflowcore/src/wayflowcore/executors/_agentexecutor.py b/wayflowcore/src/wayflowcore/executors/_agentexecutor.py index f84f63070..e82bce27e 100644 --- a/wayflowcore/src/wayflowcore/executors/_agentexecutor.py +++ b/wayflowcore/src/wayflowcore/executors/_agentexecutor.py @@ -462,7 +462,9 @@ def _get_or_create_expert_agent_subconversation( init_messages.append_message(caller_request_message) sub_agent_conversation = expert_agent.start_conversation( - messages=init_messages, inputs=inputs + messages=init_messages, + inputs=inputs, + root_conversation_id=caller_conv.root_conversation_id, ) return sub_agent_conversation @@ -519,6 +521,7 @@ async def _execute_flow( messages: MessageList, flow: Flow, inputs: Dict[str, Any], + root_conversation_id: Optional[str], ) -> Tuple[Any, str, ExecutionStatus]: """ Execute a flow and return its outputs and its execution status. 
@@ -531,6 +534,7 @@ async def _execute_flow( state.current_flow_conversation = flow.start_conversation( inputs=inputs, messages=messages, + root_conversation_id=root_conversation_id, ) messages.append_message( Message( @@ -593,7 +597,7 @@ async def _execute_next_subcall( _descriptors_to_json_schema_map(flow.input_descriptors_dict.values()), ) return await AgentConversationExecutor._handle_flow_call( - config, state, flow, tool_request, messages + config, state, flow, tool_request, messages, conversation ) if state.current_retrieved_tools is None: @@ -746,6 +750,7 @@ async def _handle_flow_call( flow: Flow, tool_request: ToolRequest, messages: MessageList, + conversation: "AgentConversation", ) -> Optional[ExecutionStatus]: logger.debug( 'Agent executing flow "%s" (id=%s) with arguments: %s', @@ -754,7 +759,13 @@ async def _handle_flow_call( tool_request.args, ) output, serialized_output, flow_execution_status = ( - await AgentConversationExecutor._execute_flow(state, messages, flow, tool_request.args) + await AgentConversationExecutor._execute_flow( + state, + messages, + flow, + tool_request.args, + conversation.root_conversation_id, + ) ) logger.debug( diff --git a/wayflowcore/src/wayflowcore/executors/_flowexecutor.py b/wayflowcore/src/wayflowcore/executors/_flowexecutor.py index 8f164b8ea..8b57ba98a 100644 --- a/wayflowcore/src/wayflowcore/executors/_flowexecutor.py +++ b/wayflowcore/src/wayflowcore/executors/_flowexecutor.py @@ -288,7 +288,7 @@ def create_sub_conversation( sub_conversation = flow.start_conversation( inputs_not_from_context_providers, - conversation_id=conversation.conversation_id, + root_conversation_id=conversation.root_conversation_id, messages=conversation.message_list, nesting_level=conversation.state.nesting_level + 1, context_providers_from_parent_flow=all_context_provider_keys, diff --git a/wayflowcore/src/wayflowcore/executors/_managerworkersconversation.py b/wayflowcore/src/wayflowcore/executors/_managerworkersconversation.py index 
980cebdff..a242caa65 100644 --- a/wayflowcore/src/wayflowcore/executors/_managerworkersconversation.py +++ b/wayflowcore/src/wayflowcore/executors/_managerworkersconversation.py @@ -26,11 +26,12 @@ class ManagerWorkersConversationExecutionState(ConversationExecutionState): current_agent_name: str subconversations: Dict[str, Union["AgentConversation", "ManagerWorkersConversation"]] + root_conversation_id: str = "" def _create_subconversation_for_agent( self, agent: Union[Agent, ManagerWorkers] ) -> Union["AgentConversation", "ManagerWorkersConversation"]: - subconv = agent.start_conversation() + subconv = agent.start_conversation(root_conversation_id=self.root_conversation_id or None) self.subconversations[agent.name] = subconv return subconv diff --git a/wayflowcore/src/wayflowcore/executors/interrupts/tokenlimitexecutioninterrupt.py b/wayflowcore/src/wayflowcore/executors/interrupts/tokenlimitexecutioninterrupt.py index e96035453..fa7ce580e 100644 --- a/wayflowcore/src/wayflowcore/executors/interrupts/tokenlimitexecutioninterrupt.py +++ b/wayflowcore/src/wayflowcore/executors/interrupts/tokenlimitexecutioninterrupt.py @@ -126,7 +126,7 @@ def _return_status_if_condition_is_met( self, state: ConversationExecutionState, conversation: "Conversation" ) -> Optional[InterruptedExecutionStatus]: - conversation_id = conversation.conversation_id + conversation_id = conversation.root_conversation_id # We first check the global token limit, then we go over the llm-wise limits # Note that we must do the checks separately, because the list of all models diff --git a/wayflowcore/src/wayflowcore/models/llmmodel.py b/wayflowcore/src/wayflowcore/models/llmmodel.py index 2b6530a69..8652682da 100644 --- a/wayflowcore/src/wayflowcore/models/llmmodel.py +++ b/wayflowcore/src/wayflowcore/models/llmmodel.py @@ -346,11 +346,11 @@ def _update_token_usage( if isinstance(conversation, FlowConversation): # generate with flow - self.token_usages_flow[conversation.conversation_id][ + 
self.token_usages_flow[conversation.root_conversation_id][ conversation.current_step_name ] += token_usage else: - self.token_usages_flexible[conversation.conversation_id] += token_usage + self.token_usages_flexible[conversation.root_conversation_id] += token_usage def get_total_token_consumption(self, conversation_id: str) -> TokenUsage: """Calculate and return the total token consumption for a given conversation. diff --git a/wayflowcore/src/wayflowcore/serialization/context.py b/wayflowcore/src/wayflowcore/serialization/context.py index a86d6cf17..232215778 100644 --- a/wayflowcore/src/wayflowcore/serialization/context.py +++ b/wayflowcore/src/wayflowcore/serialization/context.py @@ -312,6 +312,21 @@ def _add_component_to_context(self, component: "Component") -> None: """ from wayflowcore.component import Component + def _iter_nested_components(value: Any) -> List["Component"]: + if isinstance(value, Component): + return [value] + if isinstance(value, dict): + nested_components: List["Component"] = [] + for nested_value in value.values(): + nested_components.extend(_iter_nested_components(nested_value)) + return nested_components + if isinstance(value, (list, tuple, set)): + nested_components = [] + for nested_value in value: + nested_components.extend(_iter_nested_components(nested_value)) + return nested_components + return [] + component_ref = SerializationContext.get_reference(component) if component_ref in self._deserialized_objects: @@ -322,14 +337,6 @@ def _add_component_to_context(self, component: "Component") -> None: all_public_attrs = { name: value for name, value in vars(component).items() if not name.startswith("_") } - for attr_name, attr in all_public_attrs.items(): - if isinstance(attr, Component): - self._add_component_to_context(attr) - if isinstance(attr, dict): - for value in attr.values(): - if isinstance(value, Component): - self._add_component_to_context(value) - if isinstance(attr, list): - for value in attr: - if isinstance(value, 
Component): - self._add_component_to_context(value) + for attr in all_public_attrs.values(): + for nested_component in _iter_nested_components(attr): + self._add_component_to_context(nested_component) diff --git a/wayflowcore/src/wayflowcore/steps/agentexecutionstep.py b/wayflowcore/src/wayflowcore/steps/agentexecutionstep.py index 893877282..25c4efe58 100644 --- a/wayflowcore/src/wayflowcore/steps/agentexecutionstep.py +++ b/wayflowcore/src/wayflowcore/steps/agentexecutionstep.py @@ -294,7 +294,9 @@ def _get_or_create_agent_subconversation( caller_conv.message_list if self._share_conversation else MessageList.from_messages([]) ) agent_sub_conversation = self.agent.start_conversation( - inputs=inputs, messages=init_messages + inputs=inputs, + messages=init_messages, + root_conversation_id=caller_conv.root_conversation_id, ) return agent_sub_conversation diff --git a/wayflowcore/src/wayflowcore/tools/servertools.py b/wayflowcore/src/wayflowcore/tools/servertools.py index d6216c97f..3d6bd6daf 100644 --- a/wayflowcore/src/wayflowcore/tools/servertools.py +++ b/wayflowcore/src/wayflowcore/tools/servertools.py @@ -639,6 +639,7 @@ async def __call__(self, **inputs: Any) -> Any: conversation = self.flow.start_conversation( inputs=inputs, messages=self._parent_conversation.message_list, + root_conversation_id=self._parent_conversation.root_conversation_id, ) interrupts = self._parent_conversation._get_interrupts() diff --git a/wayflowcore/tests/test_managerworkers.py b/wayflowcore/tests/test_managerworkers.py index 9cbed15b7..8141e4d5c 100644 --- a/wayflowcore/tests/test_managerworkers.py +++ b/wayflowcore/tests/test_managerworkers.py @@ -184,14 +184,14 @@ def test_managerworkers_can_execute_with_initial_params_passed_in_start_conversa conversation = group.start_conversation( messages=[Message(content="Please compute 3*4 + 2", message_type=MessageType.USER)], inputs={"USER": "Iris"}, - conversation_id="12345", + root_conversation_id="12345", ) conversation.execute() # The 
first message must be not the default message as the init messages are passed. assert conversation.get_last_message().content != DEFAULT_INITIAL_MESSAGE - assert conversation.conversation_id == "12345" + assert conversation.root_conversation_id == "12345" @retry_test(max_attempts=2) diff --git a/wayflowcore/tests/test_swarm.py b/wayflowcore/tests/test_swarm.py index 29533d36c..330654872 100644 --- a/wayflowcore/tests/test_swarm.py +++ b/wayflowcore/tests/test_swarm.py @@ -226,14 +226,14 @@ def test_can_execute_swarm_with_initial_params_passed_in_start_conversation( ) ], inputs={"USER": "Iris"}, - conversation_id="12345", + root_conversation_id="12345", ) conversation.execute() # The first message must be not the default message as the init messages are passed. assert conversation.get_last_message().content != "Hi! How can I help you?" - assert conversation.conversation_id == "12345" + assert conversation.root_conversation_id == "12345" def test_can_create_swarm(example_medical_agents): From 390bb89e18c8eacf95fe8c0fc68c0067bd45f283 Mon Sep 17 00:00:00 2001 From: jschweiz Date: Tue, 21 Apr 2026 15:18:18 +0200 Subject: [PATCH 2/4] [fix]: fix tests --- wayflowcore/src/wayflowcore/a2a/a2aagent.py | 53 +- wayflowcore/src/wayflowcore/agent.py | 49 +- .../services/wayflowservice.py | 22 +- .../src/wayflowcore/checkpointing/__init__.py | 27 + .../wayflowcore/checkpointing/checkpointer.py | 155 +++++ .../checkpointing/checkpointeventlistener.py | 178 +++++ .../checkpointing/datastorecheckpointer.py | 399 +++++++++++ .../checkpointing/serialization.py | 143 ++++ .../contextproviders/flowcontextprovider.py | 2 +- wayflowcore/src/wayflowcore/conversation.py | 55 +- .../wayflowcore/conversationalcomponent.py | 132 ++-- wayflowcore/src/wayflowcore/events/event.py | 4 +- .../wayflowcore/executors/_agentexecutor.py | 4 +- .../wayflowcore/executors/_flowexecutor.py | 2 +- .../executors/_managerworkersconversation.py | 2 +- .../executors/_swarmconversation.py | 2 + 
wayflowcore/src/wayflowcore/flow.py | 58 +- wayflowcore/src/wayflowcore/managerworkers.py | 56 +- wayflowcore/src/wayflowcore/ociagent.py | 47 +- .../src/wayflowcore/serialization/context.py | 17 +- .../wayflowcore/serialization/serializer.py | 34 +- .../wayflowcore/steps/agentexecutionstep.py | 2 +- wayflowcore/src/wayflowcore/swarm.py | 57 +- .../src/wayflowcore/tools/servertools.py | 2 +- wayflowcore/src/wayflowcore/tracing/span.py | 2 +- .../steps/test_prompt_execution_step.py | 2 +- .../test_conversation_checkpointing.py | 654 ++++++++++++++++++ wayflowcore/tests/test_managerworkers.py | 2 +- wayflowcore/tests/test_swarm.py | 2 +- .../tracing/spans/test_conversation_span.py | 2 +- 30 files changed, 2015 insertions(+), 151 deletions(-) create mode 100644 wayflowcore/src/wayflowcore/checkpointing/__init__.py create mode 100644 wayflowcore/src/wayflowcore/checkpointing/checkpointer.py create mode 100644 wayflowcore/src/wayflowcore/checkpointing/checkpointeventlistener.py create mode 100644 wayflowcore/src/wayflowcore/checkpointing/datastorecheckpointer.py create mode 100644 wayflowcore/src/wayflowcore/checkpointing/serialization.py create mode 100644 wayflowcore/tests/serialization/test_conversation_checkpointing.py diff --git a/wayflowcore/src/wayflowcore/a2a/a2aagent.py b/wayflowcore/src/wayflowcore/a2a/a2aagent.py index 4dee6ae11..7a654ad3f 100644 --- a/wayflowcore/src/wayflowcore/a2a/a2aagent.py +++ b/wayflowcore/src/wayflowcore/a2a/a2aagent.py @@ -19,6 +19,7 @@ from wayflowcore.tools import Tool if TYPE_CHECKING: + from wayflowcore.checkpointing import Checkpointer from wayflowcore.executors._a2aagentconversation import A2AAgentConversation logger = logging.getLogger(__name__) @@ -248,43 +249,71 @@ def start_conversation( self, inputs: Optional[Dict[str, Any]] = None, messages: Union[None, str, Message, List[Message], MessageList] = None, + conversation_id: Optional[str] = None, + *, + checkpointer: Optional["Checkpointer"] = None, + checkpoint_id: 
Optional[str] = None, + _root_conversation_id: Optional[str] = None, + _attach_checkpointer: bool = True, ) -> "A2AAgentConversation": """ - Initiates a new conversation with the remote server agent. - - Creates and returns a conversation instance tied to this agent, optionally initialized - with input data and a message history. + Start a conversation with the remote A2A agent. Parameters ---------- inputs: - Optional dictionary of initial input data for the conversation. Defaults to an empty - dictionary if not provided. + Optional structured inputs stored on the conversation for interface compatibility. + The A2A runtime currently executes from messages rather than these inputs. messages: - Optional initial message list for the conversation. Can be either a ``MessageList`` - or a list of ``Message`` objects. Defaults to an empty ``MessageList`` if not provided. + Optional initial message history for the remote conversation. + conversation_id: + Optional identifier for this A2A conversation. + checkpointer: + Optional checkpoint backend used to restore and persist this conversation. + checkpoint_id: + Optional checkpoint identifier to restore. Requires ``checkpointer``. + _root_conversation_id: + Internal lineage identifier shared with nested or parent conversations. Returns ------- - Conversation: - A new conversation object associated with this agent. + A2AAgentConversation + A new or restored A2A agent conversation. 
""" from wayflowcore.executors._a2aagentconversation import A2AAgentConversation from wayflowcore.executors._a2aagentexecutor import A2AAgentState + restored_conversation, conversation_runtime_id, conversation_root_id = ( + self._prepare_conversation_start( + inputs=inputs, + messages=messages, + conversation_id=conversation_id, + checkpointer=checkpointer, + checkpoint_id=checkpoint_id, + _root_conversation_id=_root_conversation_id, + expected_conversation_type=A2AAgentConversation, + attach_checkpointer=_attach_checkpointer, + ) + ) + if restored_conversation is not None: + return restored_conversation + if not isinstance(messages, MessageList): messages = MessageList.from_messages(messages=messages) - return A2AAgentConversation( + conversation = A2AAgentConversation( component=self, state=A2AAgentState(last_message_idx=-1), inputs=inputs or {}, # Inputs are ignored in execution message_list=messages, status=None, - conversation_id=IdGenerator.get_or_generate_id(None), + id=conversation_runtime_id, + checkpointer=checkpointer, name="a2a_conversation", + root_conversation_id=conversation_root_id, __metadata_info__={}, ) + return conversation @property def agent_id(self) -> str: diff --git a/wayflowcore/src/wayflowcore/agent.py b/wayflowcore/src/wayflowcore/agent.py index eff51b387..a74d66fe0 100644 --- a/wayflowcore/src/wayflowcore/agent.py +++ b/wayflowcore/src/wayflowcore/agent.py @@ -26,6 +26,7 @@ from wayflowcore.transforms import MessageTransform if TYPE_CHECKING: + from wayflowcore.checkpointing import Checkpointer from wayflowcore.contextproviders import ContextProvider from wayflowcore.executors._agentconversation import AgentConversation from wayflowcore.flow import Flow @@ -386,29 +387,54 @@ def start_conversation( inputs: Optional[Dict[str, Any]] = None, messages: Union[None, str, "Message", List["Message"], "MessageList"] = None, conversation_id: Optional[str] = None, + *, + checkpointer: Optional["Checkpointer"] = None, + checkpoint_id: Optional[str] 
= None, + _root_conversation_id: Optional[str] = None, + _attach_checkpointer: bool = True, ) -> "AgentConversation": """ - Initializes a conversation with the agent. + Start a conversation with the agent. Parameters ---------- inputs: - This argument is not used. - It is included for compatibility with the Flow class. + Optional input values for the agent's declared input descriptors. messages: - Message list to which the agent will participate + Optional message history for the conversation. conversation_id: - Conversation id of the parent conversation. + Optional identifier for this agent conversation. + checkpointer: + Optional checkpoint backend used to restore and persist this conversation. + checkpoint_id: + Optional checkpoint identifier to restore. Requires ``checkpointer``. + _root_conversation_id: + Internal lineage identifier shared with nested or parent conversations. Returns ------- - Conversation: - The conversation object of the agent. + AgentConversation + A new or restored agent conversation. 
""" from wayflowcore.events.event import ConversationCreatedEvent from wayflowcore.events.eventlistener import record_event from wayflowcore.executors._agentconversation import AgentConversation + restored_conversation, conversation_runtime_id, conversation_root_id = ( + self._prepare_conversation_start( + inputs=inputs, + messages=messages, + conversation_id=conversation_id, + checkpointer=checkpointer, + checkpoint_id=checkpoint_id, + _root_conversation_id=_root_conversation_id, + expected_conversation_type=AgentConversation, + attach_checkpointer=_attach_checkpointer, + ) + ) + if restored_conversation is not None: + return restored_conversation + if not isinstance(messages, MessageList): messages = MessageList.from_messages(messages=messages) @@ -451,23 +477,26 @@ def start_conversation( conversational_component=self, inputs=inputs, messages=messages, - conversation_id=conversation_id, + conversation_id=conversation_runtime_id, nesting_level=None, ) ) from wayflowcore.executors._agentexecutor import AgentConversationExecutionState - return AgentConversation( + conversation = AgentConversation( component=self, message_list=messages, - conversation_id=IdGenerator.get_or_generate_id(conversation_id), + id=conversation_runtime_id, + checkpointer=checkpointer, inputs=inputs or {}, name="agent_conversation", state=AgentConversationExecutionState(), status=None, + root_conversation_id=conversation_root_id, __metadata_info__={}, ) + return conversation @property def llms(self) -> List["LlmModel"]: diff --git a/wayflowcore/src/wayflowcore/agentserver/openairesponses/services/wayflowservice.py b/wayflowcore/src/wayflowcore/agentserver/openairesponses/services/wayflowservice.py index b7a6f7987..58a4d13ea 100644 --- a/wayflowcore/src/wayflowcore/agentserver/openairesponses/services/wayflowservice.py +++ b/wayflowcore/src/wayflowcore/agentserver/openairesponses/services/wayflowservice.py @@ -15,10 +15,6 @@ from wayflowcore.agentserver.serverstorageconfig import 
ServerStorageConfig from wayflowcore.checkpointing import ConversationCheckpoint, DatastoreCheckpointer -from wayflowcore.checkpointing.runtime import ( - _detach_checkpointer_from_conversation, - _set_conversation_final_checkpoint_overrides, -) from wayflowcore.conversation import Conversation from wayflowcore.conversationalcomponent import ConversationalComponent from wayflowcore.datastore import Datastore, InMemoryDatastore @@ -192,10 +188,13 @@ async def create_response(self, body: CreateResponse) -> AsyncIterable[ResponseS if conversation_id is not None and not isinstance(conversation_id, str): conversation_id = conversation_id.id + should_store_response = body.store is None or body.store is True + state = self._load_state( previous_response_id=previous_response_id, conversation_id=conversation_id, agent_id=model, + attach_checkpointer=should_store_response, ) response_id = IdGenerator.get_or_generate_id() @@ -204,10 +203,6 @@ async def create_response(self, body: CreateResponse) -> AsyncIterable[ResponseS state=state, request=body, ) - if body.store is None or body.store is True: - _set_conversation_final_checkpoint_overrides(state, checkpoint_id=response_id) - else: - _detach_checkpointer_from_conversation(state) current_response = Response( id=response_id, @@ -261,7 +256,9 @@ async def runner(conversation: Conversation) -> None: nonlocal status try: with register_event_listeners([token_usage_listener, yielding_listener]): - status = await conversation.execute_async() + status = await conversation.execute_async( + _final_checkpoint_id=response_id if should_store_response else None, + ) except Exception as e: nonlocal raised_exception raised_exception = e @@ -305,7 +302,7 @@ async def runner(conversation: Conversation) -> None: token_usage_listener.usage ) - if (body.store is None or body.store is True) and state.checkpointer is not None: + if should_store_response and state.checkpointer is not None: self.checkpointer.save_conversation( state, 
checkpoint_id=current_response.id, @@ -365,6 +362,7 @@ def _load_state( previous_response_id: Optional[str], conversation_id: Optional[str], agent_id: str, + attach_checkpointer: bool = True, ) -> Optional[Conversation]: if previous_response_id: checkpoint = self._lookup_checkpoint_by_response_id(previous_response_id) @@ -379,6 +377,7 @@ def _load_state( conversation_id=checkpoint.conversation_id, checkpoint_id=checkpoint.checkpoint_id, checkpointer=self.checkpointer, + _attach_checkpointer=attach_checkpointer, ) except (TypeError, ValueError) as e: raise HTTPException( @@ -397,6 +396,7 @@ def _load_state( return self.agents[agent_id].start_conversation( conversation_id=conversation_id, checkpointer=self.checkpointer, + _attach_checkpointer=attach_checkpointer, ) except (TypeError, ValueError) as e: raise HTTPException( @@ -459,8 +459,6 @@ async def _create_state( raise NotImplementedError( "Instructions are only supported when creating a conversation" ) - if request.store is False: - _detach_checkpointer_from_conversation(state) # Add the new messages to the conversation for message in new_messages: state.append_message(message) diff --git a/wayflowcore/src/wayflowcore/checkpointing/__init__.py b/wayflowcore/src/wayflowcore/checkpointing/__init__.py new file mode 100644 index 000000000..7db08002e --- /dev/null +++ b/wayflowcore/src/wayflowcore/checkpointing/__init__.py @@ -0,0 +1,27 @@ +# Copyright © 2025, 2026 Oracle and/or its affiliates. +# +# This software is under the Apache License 2.0 +# (LICENSE-APACHE or http://www.apache.org/licenses/LICENSE-2.0) or Universal Permissive License +# (UPL) 1.0 (LICENSE-UPL or https://oss.oracle.com/licenses/upl), at your option. 
+ +from importlib import import_module +from typing import Any + +from .checkpointer import Checkpointer, CheckpointingInterval, ConversationCheckpoint, StorageConfig +from .datastorecheckpointer import ( + DatastoreCheckpointer, + InMemoryCheckpointer, + OracleDatabaseCheckpointer, + PostgresCheckpointer, +) + +__all__ = [ + "CheckpointingInterval", + "Checkpointer", + "ConversationCheckpoint", + "DatastoreCheckpointer", + "InMemoryCheckpointer", + "OracleDatabaseCheckpointer", + "PostgresCheckpointer", + "StorageConfig", +] diff --git a/wayflowcore/src/wayflowcore/checkpointing/checkpointer.py b/wayflowcore/src/wayflowcore/checkpointing/checkpointer.py new file mode 100644 index 000000000..0ecafb48f --- /dev/null +++ b/wayflowcore/src/wayflowcore/checkpointing/checkpointer.py @@ -0,0 +1,155 @@ +# Copyright © 2025, 2026 Oracle and/or its affiliates. +# +# This software is under the Apache License 2.0 +# (LICENSE-APACHE or http://www.apache.org/licenses/LICENSE-2.0) or Universal Permissive License +# (UPL) 1.0 (LICENSE-UPL or https://oss.oracle.com/licenses/upl), at your option. + +import time +from abc import ABC, abstractmethod +from dataclasses import dataclass, field +from enum import Enum +from typing import TYPE_CHECKING, Any, Dict, List, Optional + +from wayflowcore.idgeneration import IdGenerator + +from .serialization import _serialize_conversation_checkpoint_state + +if TYPE_CHECKING: + from wayflowcore.conversation import Conversation + from wayflowcore.datastore import Datastore + + +@dataclass(frozen=True) +class ConversationCheckpoint: + """Durable snapshot of a conversation at a checkpoint boundary.""" + + checkpoint_id: str + conversation_id: str + component_id: str + created_at: int + state: str + metadata: Dict[str, Any] = field(default_factory=dict) + + @property + def id(self) -> str: + return self.checkpoint_id + + +class CheckpointingInterval(Enum): + """ + Configure when the conversation is saved during execution. 
+ """ + + CONVERSATION_TURNS = "conversation_turns" + LLM_TURNS = "llm_turns" + ALL_INTERNAL_TURNS = "all_internal_turns" + + +@dataclass +class StorageConfig: + """Configuration for checkpoint storage.""" + + datastore: Optional["Datastore"] = None + table_name: str = "conversations" + agent_id_column_name: str = "agent_id" + conversation_id_column_name: str = "conversation_id" + turn_id_column_name: str = "turn_id" + created_at_column_name: str = "created_at" + remove_by_column_name: str = "remove_by" + conversation_turn_state_column_name: str = "conversation_turn_state" + is_last_turn_column_name: str = "is_last_turn" + extra_metadata_column_name: str = "extra_metadata" + max_retention: Optional[int] = None + + def to_schema(self) -> Dict[str, Any]: + from wayflowcore.datastore import Entity, nullable + from wayflowcore.property import IntegerProperty, StringProperty + + properties = { + self.agent_id_column_name: StringProperty(), + self.conversation_id_column_name: StringProperty(), + self.turn_id_column_name: StringProperty(), + self.is_last_turn_column_name: IntegerProperty(), + self.conversation_turn_state_column_name: StringProperty(), + self.created_at_column_name: IntegerProperty(), + self.extra_metadata_column_name: StringProperty(), + } + if self.max_retention is not None: + properties[self.remove_by_column_name] = nullable(IntegerProperty()) + return { + self.table_name: Entity( + properties=properties, + ), + } + + +class Checkpointer(ABC): + """Backend that can persist and restore checkpoints for conversations.""" + + def __init__( + self, + checkpointing_interval: CheckpointingInterval = CheckpointingInterval.CONVERSATION_TURNS, + ) -> None: + self.checkpointing_interval = checkpointing_interval + self._save_sequence_by_conversation: Dict[str, int] = {} + + @abstractmethod + def load_latest(self, conversation_id: str) -> Optional[ConversationCheckpoint]: + raise NotImplementedError() + + @abstractmethod + def load(self, conversation_id: str, 
checkpoint_id: str) -> ConversationCheckpoint: + raise NotImplementedError() + + def save(self, checkpoint: Any) -> None: + from wayflowcore.conversation import Conversation + + if isinstance(checkpoint, Conversation): + self.save_conversation(checkpoint) + return + if not isinstance(checkpoint, ConversationCheckpoint): + raise TypeError( + f"Expected a Conversation or ConversationCheckpoint, got {type(checkpoint).__name__}." + ) + self._save_checkpoint(checkpoint) + + async def save_async(self, checkpoint: Any) -> None: + self.save(checkpoint) + + def save_conversation( + self, + conversation: "Conversation", + *, + checkpoint_id: Optional[str] = None, + metadata: Optional[Dict[str, Any]] = None, + ) -> ConversationCheckpoint: + next_save_sequence = self._save_sequence_by_conversation.get(conversation.id, 0) + 1 + self._save_sequence_by_conversation[conversation.id] = next_save_sequence + checkpoint_metadata = {"save_sequence": next_save_sequence} + if metadata: + checkpoint_metadata.update(metadata) + checkpoint = ConversationCheckpoint( + checkpoint_id=checkpoint_id or IdGenerator.get_or_generate_id(), + conversation_id=conversation.id, + component_id=conversation.component.id, + created_at=int(time.time()), + state=_serialize_conversation_checkpoint_state(conversation), + metadata=checkpoint_metadata, + ) + self._save_checkpoint(checkpoint) + conversation.checkpoint_id = checkpoint.checkpoint_id + return checkpoint + + @abstractmethod + def _save_checkpoint(self, checkpoint: ConversationCheckpoint) -> None: + raise NotImplementedError() + + @abstractmethod + def list_checkpoints( + self, conversation_id: str, limit: Optional[int] = 50 + ) -> List[ConversationCheckpoint]: + raise NotImplementedError() + + @abstractmethod + def delete(self, conversation_id: str, checkpoint_id: str) -> None: + raise NotImplementedError() diff --git a/wayflowcore/src/wayflowcore/checkpointing/checkpointeventlistener.py 
b/wayflowcore/src/wayflowcore/checkpointing/checkpointeventlistener.py new file mode 100644 index 000000000..af750a3b2 --- /dev/null +++ b/wayflowcore/src/wayflowcore/checkpointing/checkpointeventlistener.py @@ -0,0 +1,178 @@ +# Copyright © 2025, 2026 Oracle and/or its affiliates. +# +# This software is under the Apache License 2.0 +# (LICENSE-APACHE or http://www.apache.org/licenses/LICENSE-2.0) or Universal Permissive License +# (UPL) 1.0 (LICENSE-UPL or https://oss.oracle.com/licenses/upl), at your option. + +from contextlib import contextmanager, nullcontext +from typing import TYPE_CHECKING, Any, Dict, Iterator, Optional + +from ..events import EventListener +from .checkpointer import CheckpointingInterval + +if TYPE_CHECKING: + from wayflowcore.conversation import Conversation + + +def _build_checkpoint_metadata( + conversation: "Conversation", + *, + save_reason: str, + event: Optional[Any] = None, + metadata: Optional[Dict[str, Any]] = None, +) -> Dict[str, Any]: + checkpoint_metadata: Dict[str, Any] = { + "save_reason": save_reason, + "current_step_name": conversation.current_step_name, + "message_count": len(conversation.message_list.messages), + } + if conversation.status is not None: + checkpoint_metadata["status_type"] = type(conversation.status).__name__ + if event is not None: + checkpoint_metadata["event_type"] = event.__class__.__name__ + execution_state = getattr(event, "execution_state", None) + if execution_state is not None: + if hasattr(execution_state, "curr_iter"): + checkpoint_metadata["agent_iteration"] = execution_state.curr_iter + if hasattr(execution_state, "current_step_name"): + checkpoint_metadata["flow_step_name"] = execution_state.current_step_name + if hasattr(execution_state, "nesting_level"): + checkpoint_metadata["nesting_level"] = execution_state.nesting_level + if metadata: + checkpoint_metadata.update(metadata) + return checkpoint_metadata + + +def _save_conversation_checkpoint( + conversation: "Conversation", + *, + 
save_reason: str, + event: Optional[Any] = None, + metadata: Optional[Dict[str, Any]] = None, + checkpoint_id: Optional[str] = None, +) -> None: + checkpointer = conversation.checkpointer + if checkpointer is None: + return + + checkpoint_metadata = _build_checkpoint_metadata( + conversation, + save_reason=save_reason, + event=event, + metadata=metadata, + ) + + checkpointer.save_conversation( + conversation, + checkpoint_id=checkpoint_id, + metadata=checkpoint_metadata, + ) + + +class _ConversationCheckpointEventListener(EventListener): + def __init__(self, conversation: "Conversation") -> None: + self.conversation = conversation + self._llm_was_used_since_last_internal_turn = False + self._last_internal_turn_start_event: Optional[Any] = None + + def __call__(self, event: Any) -> None: + from wayflowcore.events.event import ( + AgentExecutionIterationStartedEvent, + FlowExecutionIterationStartedEvent, + LlmGenerationResponseEvent, + ) + + checkpointer = self.conversation.checkpointer + if checkpointer is None: + return + + if isinstance(event, LlmGenerationResponseEvent): + self._llm_was_used_since_last_internal_turn = True + return + + if not isinstance( + event, (AgentExecutionIterationStartedEvent, FlowExecutionIterationStartedEvent) + ): + return + + checkpointing_interval = checkpointer.checkpointing_interval + + if checkpointing_interval == CheckpointingInterval.CONVERSATION_TURNS: + self._last_internal_turn_start_event = event + return + + if checkpointing_interval == CheckpointingInterval.ALL_INTERNAL_TURNS: + _save_conversation_checkpoint( + self.conversation, + save_reason="internal_turn_boundary", + event=event, + metadata={ + "llm_used_in_previous_turn": self._llm_was_used_since_last_internal_turn, + }, + ) + self._llm_was_used_since_last_internal_turn = False + self._last_internal_turn_start_event = event + return + + if self._llm_was_used_since_last_internal_turn: + _save_conversation_checkpoint( + self.conversation, + 
save_reason="internal_turn_boundary", + event=event, + metadata={ + "llm_used_in_previous_turn": self._llm_was_used_since_last_internal_turn, + }, + ) + self._llm_was_used_since_last_internal_turn = False + + self._last_internal_turn_start_event = event + + def flush_pending_checkpoint(self) -> None: + checkpointer = self.conversation.checkpointer + if checkpointer is None: + return + if checkpointer.checkpointing_interval != CheckpointingInterval.LLM_TURNS: + return + if not self._llm_was_used_since_last_internal_turn: + return + + _save_conversation_checkpoint( + self.conversation, + save_reason="internal_turn_boundary", + event=self._last_internal_turn_start_event, + metadata={ + "llm_used_in_previous_turn": True, + }, + ) + self._llm_was_used_since_last_internal_turn = False + + +@contextmanager +def get_conversation_checkpoint_execution_context( + conversation: "Conversation", + *, + is_outermost_execution: bool, + final_checkpoint_id: Optional[str] = None, + final_checkpoint_metadata: Optional[Dict[str, Any]] = None, +) -> Iterator[None]: + if conversation.checkpointer is None or not is_outermost_execution: + with nullcontext(): + yield + return + + from wayflowcore.events.eventlistener import register_event_listeners + + listener = _ConversationCheckpointEventListener(conversation) + with register_event_listeners([listener]): + try: + yield + except Exception: + raise + else: + listener.flush_pending_checkpoint() + _save_conversation_checkpoint( + conversation, + save_reason="conversation_turn", + checkpoint_id=final_checkpoint_id, + metadata=final_checkpoint_metadata, + ) diff --git a/wayflowcore/src/wayflowcore/checkpointing/datastorecheckpointer.py b/wayflowcore/src/wayflowcore/checkpointing/datastorecheckpointer.py new file mode 100644 index 000000000..27bb5a2bb --- /dev/null +++ b/wayflowcore/src/wayflowcore/checkpointing/datastorecheckpointer.py @@ -0,0 +1,399 @@ +# Copyright © 2025, 2026 Oracle and/or its affiliates. 
+# +# This software is under the Apache License 2.0 +# (LICENSE-APACHE or http://www.apache.org/licenses/LICENSE-2.0) or Universal Permissive License +# (UPL) 1.0 (LICENSE-UPL or https://oss.oracle.com/licenses/upl), at your option. + +import json +import warnings +from textwrap import dedent +from typing import Any, Dict, List, Optional, Sequence + +from wayflowcore.datastore import ( + Datastore, + InMemoryDatastore, + OracleDatabaseConnectionConfig, + OracleDatabaseDatastore, + PostgresDatabaseConnectionConfig, + PostgresDatabaseDatastore, +) +from wayflowcore.datastore._relational import RelationalDatastore +from wayflowcore.datastore.inmemory import _INMEMORY_USER_WARNING +from wayflowcore.datastore.oracle import _execute_query_on_oracle_db +from wayflowcore.datastore.postgres import _execute_query_on_postgres_db + +from .checkpointer import Checkpointer, CheckpointingInterval, ConversationCheckpoint, StorageConfig + + +def _build_checkpoint_create_table_columns( + storage_config: StorageConfig, + *, + is_oracle: bool, +) -> List[str]: + text_type = "CLOB" if is_oracle else "TEXT" + varchar_type = "VARCHAR2(255)" if is_oracle else "VARCHAR(255)" + + columns = [ + f"{storage_config.turn_id_column_name} {varchar_type} PRIMARY KEY", + f"{storage_config.agent_id_column_name} {varchar_type} NOT NULL", + f"{storage_config.conversation_id_column_name} {varchar_type} NOT NULL", + f"{storage_config.created_at_column_name} INTEGER NOT NULL", + f"{storage_config.conversation_turn_state_column_name} {text_type} NOT NULL", + f"{storage_config.is_last_turn_column_name} INTEGER NOT NULL", + f"{storage_config.extra_metadata_column_name} {text_type} NOT NULL", + ] + if storage_config.max_retention is not None: + columns.append(f"{storage_config.remove_by_column_name} INTEGER") + return columns + + +def _prepare_postgres_checkpoint_datastore( + connection_config: PostgresDatabaseConnectionConfig, + storage_config: StorageConfig, +) -> None: + from sqlalchemy.exc import 
ProgrammingError + + create_table_query = dedent( + f""" + CREATE TABLE {storage_config.table_name} ( + {", ".join(_build_checkpoint_create_table_columns(storage_config, is_oracle=False))} + ); + """ + ) + try: + _execute_query_on_postgres_db(connection_config, create_table_query) + except ProgrammingError as e: + if f'relation "{storage_config.table_name}" already exists' in str(e): + raise ValueError( + f'The datastore is already setup. Either delete the existing "{storage_config.table_name}" table or start the server with `--setup-datastore=no`.' + ) from e + raise + + +def _prepare_oracle_checkpoint_datastore( + connection_config: OracleDatabaseConnectionConfig, + storage_config: StorageConfig, +) -> None: + create_table_query = dedent( + f""" + CREATE TABLE {storage_config.table_name} ( + {", ".join(_build_checkpoint_create_table_columns(storage_config, is_oracle=True))} + ); + """ + ) + try: + _execute_query_on_oracle_db(connection_config, query=create_table_query) + except Exception as e: + if "already exists" in str(e): + raise ValueError( + f'The datastore is already setup. Either delete the existing "{storage_config.table_name}" table or start the server with `--setup-datastore=no`.' 
+ ) from e + raise + + +class DatastoreCheckpointer(Checkpointer): + """Checkpointer backed by a WayFlow datastore.""" + + def __init__( + self, + datastore: Datastore, + storage_config: Optional[StorageConfig] = None, + checkpointing_interval: CheckpointingInterval = CheckpointingInterval.CONVERSATION_TURNS, + ) -> None: + super().__init__(checkpointing_interval=checkpointing_interval) + self.datastore = datastore + self.storage_config = storage_config or StorageConfig() + + def _entity_to_checkpoint(self, entity: Dict[str, Any]) -> ConversationCheckpoint: + raw_metadata = entity.get(self.storage_config.extra_metadata_column_name, "{}") + metadata = raw_metadata if isinstance(raw_metadata, dict) else json.loads(raw_metadata) + return ConversationCheckpoint( + checkpoint_id=str(entity[self.storage_config.turn_id_column_name]), + conversation_id=str(entity[self.storage_config.conversation_id_column_name]), + component_id=str(entity[self.storage_config.agent_id_column_name]), + created_at=int(entity[self.storage_config.created_at_column_name]), + state=str(entity[self.storage_config.conversation_turn_state_column_name]), + metadata=metadata, + ) + + def _checkpoint_to_entity(self, checkpoint: ConversationCheckpoint) -> Dict[str, Any]: + entity = { + self.storage_config.agent_id_column_name: checkpoint.component_id, + self.storage_config.conversation_id_column_name: checkpoint.conversation_id, + self.storage_config.turn_id_column_name: checkpoint.checkpoint_id, + self.storage_config.created_at_column_name: checkpoint.created_at, + self.storage_config.conversation_turn_state_column_name: checkpoint.state, + self.storage_config.is_last_turn_column_name: 1, + self.storage_config.extra_metadata_column_name: json.dumps(checkpoint.metadata), + } + if self.storage_config.max_retention is not None: + entity[self.storage_config.remove_by_column_name] = ( + checkpoint.created_at + self.storage_config.max_retention + ) + return entity + + @staticmethod + def _sort_checkpoints( + 
checkpoints: Sequence[ConversationCheckpoint], + ) -> List[ConversationCheckpoint]: + return sorted( + checkpoints, + key=lambda checkpoint: ( + checkpoint.created_at, + checkpoint.metadata.get("save_sequence", -1), + checkpoint.id, + ), + ) + + def _find_checkpoint( + self, + *, + conversation_id: str, + checkpoint_id: str, + ) -> Optional[ConversationCheckpoint]: + entities = self.datastore.list( + collection_name=self.storage_config.table_name, + where={ + self.storage_config.conversation_id_column_name: conversation_id, + self.storage_config.turn_id_column_name: checkpoint_id, + }, + limit=1, + ) + if len(entities) == 0: + return None + return self._entity_to_checkpoint(entities[0]) + + def _find_checkpoint_by_id(self, checkpoint_id: str) -> Optional[ConversationCheckpoint]: + entities = self.datastore.list( + collection_name=self.storage_config.table_name, + where={self.storage_config.turn_id_column_name: checkpoint_id}, + limit=1, + ) + if len(entities) == 0: + return None + return self._entity_to_checkpoint(entities[0]) + + def load_latest(self, conversation_id: str) -> Optional[ConversationCheckpoint]: + entities = self.datastore.list( + collection_name=self.storage_config.table_name, + where={ + self.storage_config.conversation_id_column_name: conversation_id, + self.storage_config.is_last_turn_column_name: 1, + }, + ) + if len(entities) == 0: + return None + checkpoints = self._sort_checkpoints( + [self._entity_to_checkpoint(entity) for entity in entities] + ) + return checkpoints[-1] + + def load(self, conversation_id: str, checkpoint_id: str) -> ConversationCheckpoint: + checkpoint = self._find_checkpoint( + conversation_id=conversation_id, checkpoint_id=checkpoint_id + ) + if checkpoint is None: + raise ValueError( + f"Checkpoint `{checkpoint_id}` was not found for conversation `{conversation_id}`." 
+ ) + return checkpoint + + def _save_checkpoint(self, checkpoint: ConversationCheckpoint) -> None: + existing_checkpoint = self._find_checkpoint( + conversation_id=checkpoint.conversation_id, + checkpoint_id=checkpoint.checkpoint_id, + ) + if existing_checkpoint is not None: + checkpoint = ConversationCheckpoint( + checkpoint_id=checkpoint.checkpoint_id, + conversation_id=checkpoint.conversation_id, + component_id=checkpoint.component_id, + created_at=checkpoint.created_at, + state=checkpoint.state, + metadata=existing_checkpoint.metadata | checkpoint.metadata, + ) + + update_latest_where = { + self.storage_config.conversation_id_column_name: checkpoint.conversation_id, + self.storage_config.is_last_turn_column_name: 1, + } + update_latest_values = {self.storage_config.is_last_turn_column_name: 0} + entity = self._checkpoint_to_entity(checkpoint) + + if isinstance(self.datastore, RelationalDatastore): + data_table = self.datastore.data_tables[self.storage_config.table_name] + with data_table.engine.connect() as connection: + connection.execute( + data_table._update_query( + where=update_latest_where, + update=update_latest_values, + ) + ) + if existing_checkpoint is None: + sql_create_stmt, new_entities = data_table._create_query([entity]) + connection.execute(sql_create_stmt, new_entities) + else: + update_checkpoint_where = { + self.storage_config.conversation_id_column_name: checkpoint.conversation_id, + self.storage_config.turn_id_column_name: checkpoint.checkpoint_id, + } + update_checkpoint_values = { + self.storage_config.agent_id_column_name: checkpoint.component_id, + self.storage_config.created_at_column_name: checkpoint.created_at, + self.storage_config.conversation_turn_state_column_name: checkpoint.state, + self.storage_config.is_last_turn_column_name: 1, + self.storage_config.extra_metadata_column_name: json.dumps( + checkpoint.metadata + ), + } + if self.storage_config.max_retention is not None: + 
update_checkpoint_values[self.storage_config.remove_by_column_name] = ( + checkpoint.created_at + self.storage_config.max_retention + ) + connection.execute( + data_table._update_query( + where=update_checkpoint_where, + update=update_checkpoint_values, + ) + ) + connection.commit() + else: + self.datastore.update( + collection_name=self.storage_config.table_name, + where=update_latest_where, + update=update_latest_values, + ) + if existing_checkpoint is None: + self.datastore.create( + collection_name=self.storage_config.table_name, + entities=[entity], + ) + else: + update_checkpoint_values = { + self.storage_config.agent_id_column_name: checkpoint.component_id, + self.storage_config.created_at_column_name: checkpoint.created_at, + self.storage_config.conversation_turn_state_column_name: checkpoint.state, + self.storage_config.is_last_turn_column_name: 1, + self.storage_config.extra_metadata_column_name: json.dumps(checkpoint.metadata), + } + if self.storage_config.max_retention is not None: + update_checkpoint_values[self.storage_config.remove_by_column_name] = ( + checkpoint.created_at + self.storage_config.max_retention + ) + self.datastore.update( + collection_name=self.storage_config.table_name, + where={ + self.storage_config.conversation_id_column_name: checkpoint.conversation_id, + self.storage_config.turn_id_column_name: checkpoint.checkpoint_id, + }, + update=update_checkpoint_values, + ) + + def list_checkpoints( + self, conversation_id: str, limit: Optional[int] = 50 + ) -> List[ConversationCheckpoint]: + checkpoints = self._sort_checkpoints( + [ + self._entity_to_checkpoint(entity) + for entity in self.datastore.list( + collection_name=self.storage_config.table_name, + where={self.storage_config.conversation_id_column_name: conversation_id}, + ) + ] + ) + if limit is not None and len(checkpoints) > limit: + checkpoints = checkpoints[-limit:] + return checkpoints + + def delete(self, conversation_id: str, checkpoint_id: str) -> None: + 
latest_checkpoint = self.load_latest(conversation_id) + checkpoints = self.list_checkpoints(conversation_id, limit=None) + checkpoint_to_promote: Optional[ConversationCheckpoint] = None + if latest_checkpoint is not None and latest_checkpoint.checkpoint_id == checkpoint_id: + remaining_checkpoints = [ + checkpoint + for checkpoint in checkpoints + if checkpoint.checkpoint_id != checkpoint_id + ] + checkpoint_to_promote = remaining_checkpoints[-1] if remaining_checkpoints else None + + self.datastore.delete( + collection_name=self.storage_config.table_name, + where={ + self.storage_config.conversation_id_column_name: conversation_id, + self.storage_config.turn_id_column_name: checkpoint_id, + }, + ) + + if checkpoint_to_promote is not None: + self.datastore.update( + collection_name=self.storage_config.table_name, + where={ + self.storage_config.conversation_id_column_name: conversation_id, + self.storage_config.turn_id_column_name: checkpoint_to_promote.checkpoint_id, + }, + update={self.storage_config.is_last_turn_column_name: 1}, + ) + + +class InMemoryCheckpointer(DatastoreCheckpointer): + """Checkpointer backed by an in-memory datastore.""" + + def __init__( + self, + storage_config: Optional[StorageConfig] = None, + checkpointing_interval: CheckpointingInterval = CheckpointingInterval.CONVERSATION_TURNS, + ) -> None: + resolved_storage_config = storage_config or StorageConfig() + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", message=f"{_INMEMORY_USER_WARNING}*") + datastore = InMemoryDatastore(schema=resolved_storage_config.to_schema()) + super().__init__( + datastore=datastore, + storage_config=resolved_storage_config, + checkpointing_interval=checkpointing_interval, + ) + + +class PostgresCheckpointer(DatastoreCheckpointer): + """Checkpointer backed by PostgreSQL.""" + + def __init__( + self, + connection_config: PostgresDatabaseConnectionConfig, + storage_config: Optional[StorageConfig] = None, + checkpointing_interval: 
CheckpointingInterval = CheckpointingInterval.CONVERSATION_TURNS, + ) -> None: + resolved_storage_config = storage_config or StorageConfig() + datastore = PostgresDatabaseDatastore( + schema=resolved_storage_config.to_schema(), + connection_config=connection_config, + ) + super().__init__( + datastore=datastore, + storage_config=resolved_storage_config, + checkpointing_interval=checkpointing_interval, + ) + self.connection_config = connection_config + + +class OracleDatabaseCheckpointer(DatastoreCheckpointer): + """Checkpointer backed by Oracle Database.""" + + def __init__( + self, + connection_config: OracleDatabaseConnectionConfig, + storage_config: Optional[StorageConfig] = None, + checkpointing_interval: CheckpointingInterval = CheckpointingInterval.CONVERSATION_TURNS, + ) -> None: + resolved_storage_config = storage_config or StorageConfig() + datastore = OracleDatabaseDatastore( + schema=resolved_storage_config.to_schema(), + connection_config=connection_config, + ) + super().__init__( + datastore=datastore, + storage_config=resolved_storage_config, + checkpointing_interval=checkpointing_interval, + ) + self.connection_config = connection_config diff --git a/wayflowcore/src/wayflowcore/checkpointing/serialization.py b/wayflowcore/src/wayflowcore/checkpointing/serialization.py new file mode 100644 index 000000000..d757a3bc0 --- /dev/null +++ b/wayflowcore/src/wayflowcore/checkpointing/serialization.py @@ -0,0 +1,143 @@ +# Copyright © 2025, 2026 Oracle and/or its affiliates. +# +# This software is under the Apache License 2.0 +# (LICENSE-APACHE or http://www.apache.org/licenses/LICENSE-2.0) or Universal Permissive License +# (UPL) 1.0 (LICENSE-UPL or https://oss.oracle.com/licenses/upl), at your option. 
+ +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, cast + +import yaml + +from wayflowcore.serialization import autodeserialize, serialize_to_dict +from wayflowcore.serialization.context import DeserializationContext, SerializationContext +from wayflowcore.serialization.serializer import autodeserialize_from_dict + +if TYPE_CHECKING: + from wayflowcore.component import Component + from wayflowcore.conversation import Conversation + + +_CHECKPOINT_ENVELOPE_FORMAT = "wayflow-conversation-checkpoint" +_CHECKPOINT_ENVELOPE_VERSION = 1 + + +def _iter_conversation_graph(root_conversation: "Conversation") -> Sequence["Conversation"]: + visited_conversation_ids: set[str] = set() + queue: List["Conversation"] = [root_conversation] + ordered_conversations: List["Conversation"] = [] + + while queue: + conversation = queue.pop() + if conversation.id in visited_conversation_ids: + continue + visited_conversation_ids.add(conversation.id) + ordered_conversations.append(conversation) + queue.extend(conversation._get_all_sub_conversations()) + + return ordered_conversations + + +def _ensure_checkpointing_supported(conversation: "Conversation") -> None: + from wayflowcore.ociagent import OciAgent + + for sub_conversation in _iter_conversation_graph(conversation): + if isinstance(sub_conversation.component, OciAgent): + raise NotImplementedError( + "Checkpointing conversations that contain `OciAgent` is not supported yet." 
+ ) + + +def _iter_component_tree(component: "Component") -> Sequence["Component"]: + from wayflowcore.component import Component + + def _iter_nested_components(value: Any) -> List["Component"]: + if isinstance(value, Component): + return [value] + if isinstance(value, dict): + nested_components: List["Component"] = [] + for nested_value in value.values(): + nested_components.extend(_iter_nested_components(nested_value)) + return nested_components + if isinstance(value, (list, tuple, set)): + nested_components = [] + for nested_value in value: + nested_components.extend(_iter_nested_components(nested_value)) + return nested_components + return [] + + visited_component_ids: set[str] = set() + ordered_components: List["Component"] = [] + queue: List["Component"] = [component] + + while queue: + current_component = queue.pop() + current_component_ref = SerializationContext.get_reference(current_component) + if current_component_ref in visited_component_ids: + continue + visited_component_ids.add(current_component_ref) + ordered_components.append(current_component) + + all_public_attrs = { + name: value + for name, value in vars(current_component).items() + if not name.startswith("_") + } + for attr in all_public_attrs.values(): + queue.extend(_iter_nested_components(attr)) + + return ordered_components + + +def _build_checkpoint_serialization_context(conversation: "Conversation") -> SerializationContext: + serialization_context = SerializationContext(root=conversation) + for component in _iter_component_tree(conversation.component): + serialization_context.register_external_reference(component) + return serialization_context + + +def _serialize_conversation_checkpoint_state(conversation: "Conversation") -> str: + _ensure_checkpointing_supported(conversation) + + serialized_conversation = serialize_to_dict( + conversation, + serialization_context=_build_checkpoint_serialization_context(conversation), + ) + + envelope = { + "checkpoint_format": 
_CHECKPOINT_ENVELOPE_FORMAT, + "version": _CHECKPOINT_ENVELOPE_VERSION, + "conversation": serialized_conversation, + } + return yaml.safe_dump(envelope) + + +def _deserialize_conversation_checkpoint_state( + serialized_state: str, + *, + tool_registry: Optional[Dict[str, Any]] = None, + component: Optional["Component"] = None, +) -> "Conversation": + deserialization_context = DeserializationContext() + deserialization_context.registered_tools = tool_registry.copy() if tool_registry else {} + + if component is not None: + deserialization_context._add_component_to_context(component) + + state_payload = yaml.safe_load(serialized_state) + if ( + isinstance(state_payload, dict) + and state_payload.get("checkpoint_format") == _CHECKPOINT_ENVELOPE_FORMAT + and state_payload.get("version") == _CHECKPOINT_ENVELOPE_VERSION + and "conversation" in state_payload + ): + conversation = autodeserialize_from_dict( + state_payload["conversation"], + deserialization_context=deserialization_context, + ) + else: + conversation = autodeserialize( + serialized_state, + deserialization_context=deserialization_context, + ) + + return cast("Conversation", conversation) diff --git a/wayflowcore/src/wayflowcore/contextproviders/flowcontextprovider.py b/wayflowcore/src/wayflowcore/contextproviders/flowcontextprovider.py index bd622d0ca..0bc90f4d1 100644 --- a/wayflowcore/src/wayflowcore/contextproviders/flowcontextprovider.py +++ b/wayflowcore/src/wayflowcore/contextproviders/flowcontextprovider.py @@ -95,7 +95,7 @@ async def call_async(self, conversation: "Conversation") -> Any: conversation = self.flow.start_conversation( inputs={}, messages=conversation.message_list, - root_conversation_id=conversation.root_conversation_id, + _root_conversation_id=conversation.root_conversation_id, ) status = await conversation.execute_async() if status._requires_yielding: diff --git a/wayflowcore/src/wayflowcore/conversation.py b/wayflowcore/src/wayflowcore/conversation.py index 14fa440c2..c2c2dadd2 
100644 --- a/wayflowcore/src/wayflowcore/conversation.py +++ b/wayflowcore/src/wayflowcore/conversation.py @@ -24,7 +24,6 @@ from wayflowcore._utils.async_helpers import run_async_in_sync from wayflowcore.component import DataclassComponent -from wayflowcore.conversationalcomponent import ConversationalComponent from wayflowcore.executors._events.event import Event from wayflowcore.executors.executionstatus import ( ExecutionStatus, @@ -37,7 +36,9 @@ from wayflowcore.tokenusage import TokenUsage if TYPE_CHECKING: + from wayflowcore.checkpointing import Checkpointer from wayflowcore.contextproviders import ContextProvider + from wayflowcore.conversationalcomponent import ConversationalComponent from wayflowcore.executors._executionstate import ConversationExecutionState from wayflowcore.executors.interrupts.executioninterrupt import ExecutionInterrupt from wayflowcore.models._requesthelpers import TaggedMessageChunkType @@ -62,6 +63,10 @@ def _get_active_conversations(return_copy: bool = True) -> List["Conversation"]: return copy(active_conversations) if return_copy else active_conversations +def is_outermost_execution() -> bool: + return len(_get_active_conversations(return_copy=False)) == 0 + + def _get_current_conversation_id() -> Optional[str]: active_conversations = _get_active_conversations(return_copy=True) if not active_conversations: @@ -85,13 +90,20 @@ def _register_conversation(conversation: "Conversation") -> Generator[None, Any, @dataclass class Conversation(DataclassComponent): - component: ConversationalComponent + component: "ConversationalComponent" state: "ConversationExecutionState" inputs: Dict[str, Any] message_list: MessageList status: Optional[ExecutionStatus] token_usage: TokenUsage = field(default_factory=TokenUsage, init=False) - conversation_id: str = "" # deprecated + root_conversation_id: str = "" + checkpointer: Optional["Checkpointer"] = field( + default=None, + repr=False, + compare=False, + metadata={"serialize": False}, + ) + 
checkpoint_id: Optional[str] = field(default=None, init=False, repr=False, compare=False) status_handled: bool = False """Whether the current status associated to this conversation was already handled or not @@ -100,6 +112,8 @@ class Conversation(DataclassComponent): def __post_init__(self) -> None: if self.inputs is None: self.inputs = {} + if not self.root_conversation_id: + self.root_conversation_id = self.id @property def plan(self) -> Optional[ExecutionPlan]: @@ -114,6 +128,9 @@ def _register_event(self, event: Event) -> None: def execute( self, execution_interrupts: Optional[Sequence["ExecutionInterrupt"]] = None, + *, + _final_checkpoint_id: Optional[str] = None, + _final_checkpoint_metadata: Optional[Dict[str, Any]] = None, ) -> "ExecutionStatus": """ Execute the conversation and get its ``ExecutionStatus`` based on the outcome. @@ -121,13 +138,22 @@ def execute( The ``Execution`` status is returned by the Assistant and indicates if the assistant yielded, finished the conversation. """ - return run_async_in_sync( - self.execute_async, execution_interrupts, method_name="execute_async" - ) + + async def _execute_async_wrapper() -> "ExecutionStatus": + return await self.execute_async( + execution_interrupts, + _final_checkpoint_id=_final_checkpoint_id, + _final_checkpoint_metadata=_final_checkpoint_metadata, + ) + + return run_async_in_sync(_execute_async_wrapper, method_name="execute_async") async def execute_async( self, execution_interrupts: Optional[Sequence["ExecutionInterrupt"]] = None, + *, + _final_checkpoint_id: Optional[str] = None, + _final_checkpoint_metadata: Optional[Dict[str, Any]] = None, ) -> "ExecutionStatus": """ Execute the conversation and get its ``ExecutionStatus`` based on the outcome. 
@@ -138,11 +164,20 @@ async def execute_async( if self.status_handled is False: self._update_conversation_with_status() - with _register_conversation(self): - new_status = await self.component.runner.execute_async(self, execution_interrupts) + from wayflowcore.checkpointing.checkpointeventlistener import ( + get_conversation_checkpoint_execution_context, + ) - self.status = new_status - self.status_handled = False + with get_conversation_checkpoint_execution_context( + self, + is_outermost_execution=is_outermost_execution(), + final_checkpoint_id=_final_checkpoint_id, + final_checkpoint_metadata=_final_checkpoint_metadata, + ): + with _register_conversation(self): + new_status = await self.component.runner.execute_async(self, execution_interrupts) + self.status = new_status + self.status_handled = False return self.status @property diff --git a/wayflowcore/src/wayflowcore/conversationalcomponent.py b/wayflowcore/src/wayflowcore/conversationalcomponent.py index 2d15b20c2..ee4ae5105 100644 --- a/wayflowcore/src/wayflowcore/conversationalcomponent.py +++ b/wayflowcore/src/wayflowcore/conversationalcomponent.py @@ -7,11 +7,20 @@ import logging from abc import ABC, abstractmethod from pathlib import Path -from typing import TYPE_CHECKING, Any, Dict, Generic, List, Optional, Set, Type, TypeVar, Union +from typing import ( + TYPE_CHECKING, + Any, + Dict, + Generic, + List, + Optional, + Set, + Type, + TypeVar, + Union, +) from wayflowcore._metadata import MetadataType -from wayflowcore.checkpointing.runtime import _attach_checkpointer_to_conversation -from wayflowcore.checkpointing.serialization import _deserialize_conversation_checkpoint_state from wayflowcore.componentwithio import ComponentWithInputsOutputs from wayflowcore.idgeneration import IdGenerator from wayflowcore.property import Property @@ -21,6 +30,7 @@ if TYPE_CHECKING: from wayflowcore.checkpointing import Checkpointer + from wayflowcore.checkpointing.checkpointer import ConversationCheckpoint from 
wayflowcore.conversation import Conversation from wayflowcore.executors._executor import ConversationExecutor from wayflowcore.messagelist import Message, MessageList @@ -28,6 +38,7 @@ from wayflowcore.tools import Tool _HUMAN_ENTITY_ID = "human_user" +ConversationTypeT = TypeVar("ConversationTypeT", bound="Conversation") class ConversationalComponent(ComponentWithInputsOutputs, ABC): @@ -72,11 +83,36 @@ def start_conversation( messages: Union[None, str, "Message", List["Message"], "MessageList"] = None, conversation_id: Optional[str] = None, *, - root_conversation_id: Optional[str] = None, checkpointer: Optional["Checkpointer"] = None, checkpoint_id: Optional[str] = None, + _root_conversation_id: Optional[str] = None, + _attach_checkpointer: bool = True, ) -> "Conversation": - pass + """ + Start a conversation for this component. + + Parameters + ---------- + inputs: + Optional structured inputs used to initialize the conversation. + messages: + Optional initial message history. Concrete implementations normalize this into a + ``MessageList`` when needed. + conversation_id: + Optional identifier for the concrete conversation instance. + checkpointer: + Optional checkpoint backend used to restore and persist conversation state. + checkpoint_id: + Optional checkpoint identifier to restore. Requires ``checkpointer``. + _root_conversation_id: + Internal lineage identifier shared by nested conversations for usage accounting, + execution limits, and checkpoint lineage. + + Returns + ------- + Conversation + A new or restored conversation instance ready for execution. 
+ """ @property def llms(self) -> List["LlmModel"]: @@ -137,33 +173,37 @@ def _messages_or_inputs_were_passed( return len(messages) > 0 if isinstance(messages, Message): return True - return len(message) > 0 + return len(messages) > 0 - def _restore_or_prepare_checkpoint_conversation( + def _prepare_conversation_start( self, *, inputs: Optional[Dict[str, Any]], messages: Union[None, str, "Message", List["Message"], "MessageList"], conversation_id: Optional[str], - root_conversation_id: Optional[str], + _root_conversation_id: Optional[str], checkpointer: Optional["Checkpointer"], checkpoint_id: Optional[str], - ) -> tuple[Optional["Conversation"], Optional[str]]: + expected_conversation_type: Type[ConversationTypeT], + attach_checkpointer: bool, + ) -> tuple[Optional[ConversationTypeT], str, str]: if checkpointer is None: if checkpoint_id is not None: raise ValueError("`checkpoint_id` requires a `checkpointer`.") - return None, conversation_id + + runtime_conversation_id = IdGenerator.get_or_generate_id(conversation_id) + return None, runtime_conversation_id, _root_conversation_id or runtime_conversation_id if ( - root_conversation_id is not None + _root_conversation_id is not None and conversation_id is not None - and root_conversation_id != conversation_id + and _root_conversation_id != conversation_id ): raise ValueError( "`root_conversation_id` and `conversation_id` cannot differ when checkpointing is enabled." 
) - resolved_conversation_id = conversation_id or root_conversation_id + resolved_conversation_id = conversation_id or _root_conversation_id if resolved_conversation_id is None and checkpoint_id is not None: raise ValueError("`checkpoint_id` requires a `conversation_id`.") if resolved_conversation_id is None: @@ -175,7 +215,7 @@ def _restore_or_prepare_checkpoint_conversation( else checkpointer.load_latest(resolved_conversation_id) ) if checkpoint is None: - return None, resolved_conversation_id + return None, resolved_conversation_id, resolved_conversation_id if self._messages_or_inputs_were_passed(inputs=inputs, messages=messages): raise ValueError( @@ -183,41 +223,47 @@ def _restore_or_prepare_checkpoint_conversation( "Load the conversation first, then append new user input explicitly." ) + conversation = self._restore_checkpointed_conversation( + checkpoint=checkpoint, + checkpointer=checkpointer, + expected_conversation_type=expected_conversation_type, + attach_checkpointer=attach_checkpointer, + ) + return conversation, resolved_conversation_id, resolved_conversation_id + + def _restore_checkpointed_conversation( + self, + *, + checkpoint: "ConversationCheckpoint", + checkpointer: "Checkpointer", + expected_conversation_type: Type[ConversationTypeT], + attach_checkpointer: bool, + ) -> ConversationTypeT: + from wayflowcore.checkpointing.serialization import ( + _deserialize_conversation_checkpoint_state, + ) + + if checkpoint.component_id != self.id: + raise ValueError( + "Cannot restore this checkpoint because this conversation was started with another " + f"component. Checkpoint component id: `{checkpoint.component_id}`. Current component id: `{self.id}`." 
+ ) + conversation = _deserialize_conversation_checkpoint_state( checkpoint.state, tool_registry={tool.name: tool for tool in self._referenced_tools()}, component=self, ) - return ( - _attach_checkpointer_to_conversation( - conversation, - checkpointer=checkpointer, - checkpoint_id=checkpoint.checkpoint_id, - ), - resolved_conversation_id, - ) - - @staticmethod - def _resolve_runtime_and_root_conversation_ids( - *, - conversation_id: Optional[str], - root_conversation_id: Optional[str], - checkpointer: Optional["Checkpointer"], - restored_conversation_id: Optional[str], - ) -> tuple[str, str]: - if checkpointer is not None: - runtime_conversation_id = restored_conversation_id or IdGenerator.get_or_generate_id( - conversation_id or root_conversation_id + if not isinstance(conversation, expected_conversation_type): + raise ValueError( + "Cannot restore this checkpoint because this conversation was started with another " + f"component. Expected `{expected_conversation_type.__name__}`, got `{type(conversation).__name__}`." ) - resolved_root_conversation_id = root_conversation_id or runtime_conversation_id - if resolved_root_conversation_id != runtime_conversation_id: - raise ValueError( - "`root_conversation_id` and `conversation_id` cannot differ when checkpointing is enabled." 
- ) - return runtime_conversation_id, runtime_conversation_id - - runtime_conversation_id = IdGenerator.get_or_generate_id(conversation_id) - return runtime_conversation_id, root_conversation_id or runtime_conversation_id + + if attach_checkpointer: + conversation.checkpointer = checkpointer + conversation.checkpoint_id = checkpoint.checkpoint_id + return conversation # Define a TypeVar that represents the component's type diff --git a/wayflowcore/src/wayflowcore/events/event.py b/wayflowcore/src/wayflowcore/events/event.py index 4c9c39f1c..c6428cbfa 100644 --- a/wayflowcore/src/wayflowcore/events/event.py +++ b/wayflowcore/src/wayflowcore/events/event.py @@ -767,7 +767,7 @@ class ConversationExecutionStartedEvent(StartSpanEvent["ConversationSpan"]): def to_tracing_info(self, mask_sensitive_information: bool = True) -> Dict[str, Any]: return { **super().to_tracing_info(mask_sensitive_information=mask_sensitive_information), - "conversation.id": self.conversation.conversation_id, + "conversation.id": self.conversation.id, "conversation.name": self.conversation.name, } @@ -788,7 +788,7 @@ class ConversationExecutionFinishedEvent(EndSpanEvent["ConversationSpan"]): def to_tracing_info(self, mask_sensitive_information: bool = True) -> Dict[str, Any]: return { **super().to_tracing_info(mask_sensitive_information=mask_sensitive_information), - "conversation.id": self.conversation.conversation_id, + "conversation.id": self.conversation.id, "conversation.name": self.conversation.name, "execution_status": self.execution_status.__class__.__name__, } diff --git a/wayflowcore/src/wayflowcore/executors/_agentexecutor.py b/wayflowcore/src/wayflowcore/executors/_agentexecutor.py index e82bce27e..da62f6a8f 100644 --- a/wayflowcore/src/wayflowcore/executors/_agentexecutor.py +++ b/wayflowcore/src/wayflowcore/executors/_agentexecutor.py @@ -464,7 +464,7 @@ def _get_or_create_expert_agent_subconversation( sub_agent_conversation = expert_agent.start_conversation( 
messages=init_messages, inputs=inputs, - root_conversation_id=caller_conv.root_conversation_id, + _root_conversation_id=caller_conv.root_conversation_id, ) return sub_agent_conversation @@ -534,7 +534,7 @@ async def _execute_flow( state.current_flow_conversation = flow.start_conversation( inputs=inputs, messages=messages, - root_conversation_id=root_conversation_id, + _root_conversation_id=root_conversation_id, ) messages.append_message( Message( diff --git a/wayflowcore/src/wayflowcore/executors/_flowexecutor.py b/wayflowcore/src/wayflowcore/executors/_flowexecutor.py index 8b57ba98a..1e5da0b93 100644 --- a/wayflowcore/src/wayflowcore/executors/_flowexecutor.py +++ b/wayflowcore/src/wayflowcore/executors/_flowexecutor.py @@ -288,7 +288,7 @@ def create_sub_conversation( sub_conversation = flow.start_conversation( inputs_not_from_context_providers, - root_conversation_id=conversation.root_conversation_id, + _root_conversation_id=conversation.root_conversation_id, messages=conversation.message_list, nesting_level=conversation.state.nesting_level + 1, context_providers_from_parent_flow=all_context_provider_keys, diff --git a/wayflowcore/src/wayflowcore/executors/_managerworkersconversation.py b/wayflowcore/src/wayflowcore/executors/_managerworkersconversation.py index a242caa65..e298be4a7 100644 --- a/wayflowcore/src/wayflowcore/executors/_managerworkersconversation.py +++ b/wayflowcore/src/wayflowcore/executors/_managerworkersconversation.py @@ -31,7 +31,7 @@ class ManagerWorkersConversationExecutionState(ConversationExecutionState): def _create_subconversation_for_agent( self, agent: Union[Agent, ManagerWorkers] ) -> Union["AgentConversation", "ManagerWorkersConversation"]: - subconv = agent.start_conversation(root_conversation_id=self.root_conversation_id or None) + subconv = agent.start_conversation(_root_conversation_id=self.root_conversation_id or None) self.subconversations[agent.name] = subconv return subconv diff --git 
a/wayflowcore/src/wayflowcore/executors/_swarmconversation.py b/wayflowcore/src/wayflowcore/executors/_swarmconversation.py index 1a169f271..234d2ebe1 100644 --- a/wayflowcore/src/wayflowcore/executors/_swarmconversation.py +++ b/wayflowcore/src/wayflowcore/executors/_swarmconversation.py @@ -55,6 +55,7 @@ class SwarmConversationExecutionState(ConversationExecutionState): main_thread: SwarmThread agents_and_threads: Dict[str, Dict[str, SwarmThread]] context_providers: List["ContextProvider"] + root_conversation_id: str = "" current_thread: Optional["SwarmThread"] = None thread_stack: List["SwarmThread"] = field(default_factory=list) @@ -89,6 +90,7 @@ def _create_subconversation_for_thread( conversation = thread.recipient_agent.start_conversation( inputs=inputs, messages=thread.message_list, + _root_conversation_id=self.root_conversation_id or None, ) self.thread_subconversations[thread_id] = conversation diff --git a/wayflowcore/src/wayflowcore/flow.py b/wayflowcore/src/wayflowcore/flow.py index 2c77ada85..9710ca7ce 100644 --- a/wayflowcore/src/wayflowcore/flow.py +++ b/wayflowcore/src/wayflowcore/flow.py @@ -45,6 +45,7 @@ from wayflowcore.tools import Tool if TYPE_CHECKING: + from wayflowcore.checkpointing import Checkpointer from wayflowcore.executors._flowconversation import FlowConversation from wayflowcore.executors._flowexecutor import _IoKeyType from wayflowcore.messagelist import Message @@ -1166,33 +1167,59 @@ def start_conversation( conversation_id: Optional[str] = None, nesting_level: int = 0, context_providers_from_parent_flow: Optional[Set[str]] = None, + *, + checkpointer: Optional["Checkpointer"] = None, + checkpoint_id: Optional[str] = None, + _root_conversation_id: Optional[str] = None, + _attach_checkpointer: bool = True, ) -> "FlowConversation": """ - Start the conversation. + Start a conversation for this flow. Parameters ---------- inputs: - Dictionary of inputs. 
Keys are the variable identifiers and - values are the actual inputs to start the conversation. - conversation_id: - Conversation id of the parent conversation. + Optional input values used to initialize flow execution. messages: - List of messages (``MessageList`` object) before starting the conversation. - context_providers_from_parent_flow: - Context provider that don't need to be checked when validating existing inputs. + Optional message history available to the flow at startup. + conversation_id: + Optional identifier for this flow conversation. nesting_level: - Nesting level of the conversation. + Nesting level of the flow execution. Nested subflows increase this value. + context_providers_from_parent_flow: + Names of inputs already provided by parent-flow context providers when validating + required inputs for nested execution. + checkpointer: + Optional checkpoint backend used to restore and persist this conversation. + checkpoint_id: + Optional checkpoint identifier to restore. Requires ``checkpointer``. + _root_conversation_id: + Internal lineage identifier shared with nested or parent conversations. Returns ------- - Conversation: - A Flow Conversation object. + FlowConversation + A new or restored flow conversation. 
""" from wayflowcore.events.event import ConversationCreatedEvent from wayflowcore.events.eventlistener import record_event from wayflowcore.executors._flowconversation import FlowConversation + restored_conversation, conversation_runtime_id, conversation_root_id = ( + self._prepare_conversation_start( + inputs=inputs, + messages=messages, + conversation_id=conversation_id, + checkpointer=checkpointer, + checkpoint_id=checkpoint_id, + _root_conversation_id=_root_conversation_id, + expected_conversation_type=FlowConversation, + attach_checkpointer=_attach_checkpointer, + ) + ) + if restored_conversation is not None: + return restored_conversation + context_providers_from_parent_flow = context_providers_from_parent_flow or set() if inputs is None: inputs = {} @@ -1242,7 +1269,7 @@ def start_conversation( conversational_component=self, inputs=inputs, messages=messages, - conversation_id=conversation_id, + conversation_id=conversation_runtime_id, nesting_level=nesting_level, ) ) @@ -1271,16 +1298,19 @@ def start_conversation( nesting_level=nesting_level, ) - return FlowConversation( + conversation = FlowConversation( component=self, inputs=inputs, - conversation_id=IdGenerator.get_or_generate_id(conversation_id), + id=conversation_runtime_id, + checkpointer=checkpointer, message_list=messages, __metadata_info__={}, status=None, name="flow_conversation", state=state, + root_conversation_id=conversation_root_id, ) + return conversation @property def llms(self) -> List["LlmModel"]: diff --git a/wayflowcore/src/wayflowcore/managerworkers.py b/wayflowcore/src/wayflowcore/managerworkers.py index f59ace30b..a64c47fd5 100644 --- a/wayflowcore/src/wayflowcore/managerworkers.py +++ b/wayflowcore/src/wayflowcore/managerworkers.py @@ -22,6 +22,7 @@ from wayflowcore.transforms import MessageTransform if TYPE_CHECKING: + from wayflowcore.checkpointing import Checkpointer from wayflowcore.executors._managerworkersconversation import ManagerWorkersConversation from 
wayflowcore.messagelist import Message @@ -222,24 +223,36 @@ def start_conversation( messages: Union[None, str, "Message", List["Message"], "MessageList"] = None, conversation_id: Optional[str] = None, conversation_name: Optional[str] = None, + *, + checkpointer: Optional["Checkpointer"] = None, + checkpoint_id: Optional[str] = None, + _root_conversation_id: Optional[str] = None, + _attach_checkpointer: bool = True, ) -> "ManagerWorkersConversation": """ - Initializes a conversation with the managerworkers. + Start a conversation for the manager-workers group. Parameters ---------- inputs: - Dictionary of inputs. Keys are the variable identifiers and - values are the actual inputs to start the main conversation. + Optional input values passed to the manager's main conversation. messages: - Message list of the manager agent and the end-user. + Optional shared message history between the user and the manager. conversation_id: - Conversation id of the main conversation. + Optional identifier for this manager-workers conversation. + conversation_name: + Optional display name used for the created conversation object. + checkpointer: + Optional checkpoint backend used to restore and persist this conversation. + checkpoint_id: + Optional checkpoint identifier to restore. Requires ``checkpointer``. + _root_conversation_id: + Internal lineage identifier shared with nested or parent conversations. Returns ------- - Conversation: - The conversation object of the managerworkers. + ManagerWorkersConversation + A new or restored manager-workers conversation. 
""" from wayflowcore.agentconversation import AgentConversation from wayflowcore.events.event import ConversationCreatedEvent @@ -249,18 +262,30 @@ def start_conversation( ManagerWorkersConversationExecutionState, ) + restored_conversation, conversation_runtime_id, conversation_root_id = ( + self._prepare_conversation_start( + inputs=inputs, + messages=messages, + conversation_id=conversation_id, + checkpointer=checkpointer, + checkpoint_id=checkpoint_id, + _root_conversation_id=_root_conversation_id, + expected_conversation_type=ManagerWorkersConversation, + attach_checkpointer=_attach_checkpointer, + ) + ) + if restored_conversation is not None: + return restored_conversation + if not isinstance(messages, MessageList): messages = MessageList.from_messages(messages=messages) - if conversation_id is None: - conversation_id = IdGenerator.get_or_generate_id(conversation_id) - record_event( ConversationCreatedEvent( conversational_component=self, inputs=inputs or {}, messages=messages, - conversation_id=conversation_id, + conversation_id=conversation_runtime_id, nesting_level=None, ) ) @@ -269,23 +294,28 @@ def start_conversation( subconversations[self.manager_agent.name] = self.manager_agent.start_conversation( inputs=inputs, messages=messages, + _root_conversation_id=conversation_root_id, ) state = ManagerWorkersConversationExecutionState( current_agent_name=self.manager_agent.name, subconversations=subconversations, + root_conversation_id=conversation_root_id, ) - return ManagerWorkersConversation( + conversation = ManagerWorkersConversation( component=self, inputs={}, message_list=messages, + id=conversation_runtime_id, name=conversation_name or "managerworkers_conversation", state=state, status=None, - conversation_id=conversation_id, + checkpointer=checkpointer, + root_conversation_id=conversation_root_id, __metadata_info__={}, ) + return conversation def _referenced_tools_dict_inner( self, recursive: bool, visited_set: Set[str] diff --git 
a/wayflowcore/src/wayflowcore/ociagent.py b/wayflowcore/src/wayflowcore/ociagent.py index 4bfa36e92..ad74efde0 100644 --- a/wayflowcore/src/wayflowcore/ociagent.py +++ b/wayflowcore/src/wayflowcore/ociagent.py @@ -19,6 +19,7 @@ from wayflowcore.tools import Tool if TYPE_CHECKING: + from wayflowcore.checkpointing import Checkpointer from wayflowcore.conversation import Conversation @@ -106,22 +107,37 @@ def start_conversation( self, inputs: Optional[Dict[str, Any]] = None, messages: Union[None, str, Message, List[Message], MessageList] = None, + conversation_id: Optional[str] = None, + *, + checkpointer: Optional["Checkpointer"] = None, + checkpoint_id: Optional[str] = None, + _root_conversation_id: Optional[str] = None, + _attach_checkpointer: bool = True, ) -> "Conversation": """ - Initializes a conversation with the agent. + Start a conversation with the OCI agent. Parameters ---------- inputs: - This argument is not used. - It is included for compatibility with the Flow class. + Optional structured inputs stored on the conversation for interface compatibility. messages: - Message list to which the agent will participate + Optional initial message history for the OCI agent session. + conversation_id: + Optional identifier for this OCI agent conversation. + checkpointer: + Optional checkpoint backend. ``OciAgent`` does not support checkpoint restore yet, so + passing this raises ``NotImplementedError``. + checkpoint_id: + Optional checkpoint identifier. ``OciAgent`` does not support checkpoint restore yet, + so passing this raises ``NotImplementedError``. + _root_conversation_id: + Internal lineage identifier shared with nested or parent conversations. Returns ------- - Conversation: - The conversation object of the agent. + Conversation + A new OCI agent conversation. 
""" from wayflowcore.executors._ociagentconversation import OciAgentConversation from wayflowcore.executors._ociagentexecutor import ( @@ -130,9 +146,25 @@ def start_conversation( _init_oci_agent_session, ) + if any(value is not None for value in (checkpointer, checkpoint_id)): + raise NotImplementedError("`OciAgent` checkpoint restore is not supported yet.") + if not isinstance(messages, MessageList): messages = MessageList.from_messages(messages=messages) + _restored_conversation, conversation_runtime_id, conversation_root_id = ( + self._prepare_conversation_start( + inputs=inputs, + messages=messages, + conversation_id=conversation_id, + checkpointer=None, + checkpoint_id=None, + _root_conversation_id=_root_conversation_id, + expected_conversation_type=OciAgentConversation, + attach_checkpointer=_attach_checkpointer, + ) + ) + _client = _init_oci_agent_client(self) return OciAgentConversation( @@ -145,8 +177,9 @@ def start_conversation( inputs=inputs or {}, message_list=messages, status=None, - conversation_id=IdGenerator.get_or_generate_id(None), + id=conversation_runtime_id, name="oci_conversation", + root_conversation_id=conversation_root_id, __metadata_info__={}, ) diff --git a/wayflowcore/src/wayflowcore/serialization/context.py b/wayflowcore/src/wayflowcore/serialization/context.py index 232215778..186b3406d 100644 --- a/wayflowcore/src/wayflowcore/serialization/context.py +++ b/wayflowcore/src/wayflowcore/serialization/context.py @@ -47,6 +47,7 @@ def __init__(self, root: Any, plugins: Optional[List["WayflowSerializationPlugin """ self.root = root self._serialized_objects: Dict[str, Any] = {} + self._external_references: set[str] = set() self._started_serialization: Dict[str, bool] = {} self.plugins = plugins or [] @@ -113,6 +114,16 @@ def record_obj_dict(self, obj: Any, obj_as_dict: Dict[Any, Any]) -> None: """ self._serialized_objects[self.get_reference(obj)] = obj_as_dict + def register_external_reference(self, obj: Any) -> None: + """ + Registers an 
object as provided externally to the serialized payload. + + The serializer will emit a ``$ref`` for this object, but it will not add the object to + the root ``_referenced_objects`` section because the deserialization context is expected + to already contain it. + """ + self._external_references.add(self.get_reference(obj)) + def check_obj_is_already_serialized(self, obj: Any) -> bool: """ Returns True if the object has already been serialized @@ -122,7 +133,11 @@ def check_obj_is_already_serialized(self, obj: Any) -> bool: obj: The original, non-serialized object """ - return self._serialized_objects.get(self.get_reference(obj)) is not None + obj_ref = self.get_reference(obj) + return ( + obj_ref in self._external_references + or self._serialized_objects.get(obj_ref) is not None + ) def get_reference_dict(self, obj: Any) -> Dict[str, str]: """ diff --git a/wayflowcore/src/wayflowcore/serialization/serializer.py b/wayflowcore/src/wayflowcore/serialization/serializer.py index 37ca23b45..e92a99023 100644 --- a/wayflowcore/src/wayflowcore/serialization/serializer.py +++ b/wayflowcore/src/wayflowcore/serialization/serializer.py @@ -140,13 +140,18 @@ class MyDataclass: type_3: "MySecondCustomAttr" <--- resolves the actual type of this kind of attribute """ dataclass_fields: Dict[str, Any] = { - param.name: param.type for param in fields(cls) if param.init + param.name: param.type for param in fields(cls) if _should_serialize_dataclass_field(param) } # we resolve the forwards references (e.g. 
dataclasses with type annotations specified "between quotes") if any(isinstance(t, str) for t in dataclass_fields.values()): try: - dataclass_fields = get_type_hints(cls) + resolved_type_hints = get_type_hints(cls) + dataclass_fields = { + field_name: resolved_type_hints[field_name] + for field_name in dataclass_fields + if field_name in resolved_type_hints + } except NameError as e: pass @@ -193,23 +198,32 @@ def _resolve_legacy_field_name(cls: type, field_name: str) -> str: } if cls in _CLS_TO_ATTRIBUTE_MAPPING: - return _CLS_TO_ATTRIBUTE_MAPPING[cls].get(field_name, field_name) + resolved_field_name = _CLS_TO_ATTRIBUTE_MAPPING[cls].get(field_name) + if resolved_field_name is not None: + return resolved_field_name + + if field_name == "root_conversation_id" and any( + base.__name__ == "Conversation" for base in cls.__mro__ + ): + return "conversation_id" return field_name +def _should_serialize_dataclass_field(dataclass_field: Any) -> bool: + return bool( + (not dataclass_field.name.startswith("_") or dataclass_field.name == "__metadata_info__") + and dataclass_field.init + and dataclass_field.metadata.get("serialize", True) + ) + + class SerializableDataclassMixin: def _serialize_to_dict(self, serialization_context: "SerializationContext") -> Dict[str, Any]: return { k.name: serialize_any_to_dict(getattr(self, k.name), serialization_context) for k in fields(self) # type: ignore - if ( - ( - not k.name.startswith("_") # we don't serialize private fields - or k.name == "__metadata_info__" # except the metadata - ) - and k.init # not part of the dataclass __init__ -> would fail at deserialization - ) + if _should_serialize_dataclass_field(k) } @classmethod diff --git a/wayflowcore/src/wayflowcore/steps/agentexecutionstep.py b/wayflowcore/src/wayflowcore/steps/agentexecutionstep.py index 25c4efe58..a26329035 100644 --- a/wayflowcore/src/wayflowcore/steps/agentexecutionstep.py +++ b/wayflowcore/src/wayflowcore/steps/agentexecutionstep.py @@ -296,7 +296,7 @@ def 
_get_or_create_agent_subconversation( agent_sub_conversation = self.agent.start_conversation( inputs=inputs, messages=init_messages, - root_conversation_id=caller_conv.root_conversation_id, + _root_conversation_id=caller_conv.root_conversation_id, ) return agent_sub_conversation diff --git a/wayflowcore/src/wayflowcore/swarm.py b/wayflowcore/src/wayflowcore/swarm.py index b4852fc18..676fa98f3 100644 --- a/wayflowcore/src/wayflowcore/swarm.py +++ b/wayflowcore/src/wayflowcore/swarm.py @@ -23,6 +23,7 @@ from wayflowcore.transforms import MessageTransform if TYPE_CHECKING: + from wayflowcore.checkpointing import Checkpointer from wayflowcore.conversation import Conversation from wayflowcore.messagelist import Message @@ -312,7 +313,37 @@ def start_conversation( messages: Union[None, str, "Message", List["Message"], MessageList] = None, conversation_id: Optional[str] = None, conversation_name: Optional[str] = None, + *, + checkpointer: Optional["Checkpointer"] = None, + checkpoint_id: Optional[str] = None, + _root_conversation_id: Optional[str] = None, + _attach_checkpointer: bool = True, ) -> "Conversation": + """ + Start a conversation for the swarm. + + Parameters + ---------- + inputs: + Optional input values available to the swarm execution state. + messages: + Optional shared message history for the swarm conversation. + conversation_id: + Optional identifier for this swarm conversation. + conversation_name: + Optional display name used for the created conversation object. + checkpointer: + Optional checkpoint backend used to restore and persist this conversation. + checkpoint_id: + Optional checkpoint identifier to restore. Requires ``checkpointer``. + _root_conversation_id: + Internal lineage identifier shared with nested or parent conversations. + + Returns + ------- + Conversation + A new or restored swarm conversation. 
+ """ from wayflowcore.executors._swarmconversation import ( SwarmConversation, SwarmConversationExecutionState, @@ -320,12 +351,24 @@ def start_conversation( SwarmUser, ) + restored_conversation, conversation_runtime_id, conversation_root_id = ( + self._prepare_conversation_start( + inputs=inputs, + messages=messages, + conversation_id=conversation_id, + checkpointer=checkpointer, + checkpoint_id=checkpoint_id, + _root_conversation_id=_root_conversation_id, + expected_conversation_type=SwarmConversation, + attach_checkpointer=_attach_checkpointer, + ) + ) + if restored_conversation is not None: + return restored_conversation + if not isinstance(messages, MessageList): messages = MessageList.from_messages(messages=messages) - if conversation_id is None: - conversation_id = IdGenerator.get_or_generate_id(conversation_id) - main_thread = SwarmThread( caller=SwarmUser(), recipient_agent=self.first_agent, is_main_thread=True ) @@ -344,17 +387,21 @@ def start_conversation( context_providers=[], inputs=inputs, messages=messages, + root_conversation_id=conversation_root_id, ) - return SwarmConversation( + conversation = SwarmConversation( component=self, inputs=inputs or {}, message_list=messages, + id=conversation_runtime_id, name=conversation_name or "swarm_conversation", state=state, status=None, - conversation_id=conversation_id, + checkpointer=checkpointer, + root_conversation_id=conversation_root_id, __metadata_info__={}, ) + return conversation def _referenced_tools_dict_inner( self, recursive: bool, visited_set: Set[str] diff --git a/wayflowcore/src/wayflowcore/tools/servertools.py b/wayflowcore/src/wayflowcore/tools/servertools.py index 3d6bd6daf..9c85792e3 100644 --- a/wayflowcore/src/wayflowcore/tools/servertools.py +++ b/wayflowcore/src/wayflowcore/tools/servertools.py @@ -639,7 +639,7 @@ async def __call__(self, **inputs: Any) -> Any: conversation = self.flow.start_conversation( inputs=inputs, messages=self._parent_conversation.message_list, - 
root_conversation_id=self._parent_conversation.root_conversation_id, + _root_conversation_id=self._parent_conversation.root_conversation_id, ) interrupts = self._parent_conversation._get_interrupts() diff --git a/wayflowcore/src/wayflowcore/tracing/span.py b/wayflowcore/src/wayflowcore/tracing/span.py index 07952157c..22dbfbbdc 100644 --- a/wayflowcore/src/wayflowcore/tracing/span.py +++ b/wayflowcore/src/wayflowcore/tracing/span.py @@ -550,7 +550,7 @@ class ConversationSpan(Span): def to_tracing_info(self, mask_sensitive_information: bool = True) -> Dict[str, Any]: return { **super().to_tracing_info(mask_sensitive_information=mask_sensitive_information), - "conversation.id": self.conversation.conversation_id, + "conversation.id": self.conversation.id, "conversation.name": self.conversation.name, "conversational_component.type": self.conversation.component.__class__.__name__, "conversational_component.id": self.conversation.component.id, diff --git a/wayflowcore/tests/integration/steps/test_prompt_execution_step.py b/wayflowcore/tests/integration/steps/test_prompt_execution_step.py index 7c85f1eca..fc6e291e7 100644 --- a/wayflowcore/tests/integration/steps/test_prompt_execution_step.py +++ b/wayflowcore/tests/integration/steps/test_prompt_execution_step.py @@ -451,7 +451,7 @@ def test_check_token_consumption(remotely_hosted_llm): assert isinstance(status, FinishedStatus) assert PromptExecutionStep.OUTPUT in status.output_values assert isinstance(status.output_values[PromptExecutionStep.OUTPUT], str) - token_consumption = remotely_hosted_llm.get_total_token_consumption(conv.conversation_id) + token_consumption = remotely_hosted_llm.get_total_token_consumption(conv.root_conversation_id) assert token_consumption.input_tokens == 39 assert token_consumption.output_tokens == 10 diff --git a/wayflowcore/tests/serialization/test_conversation_checkpointing.py b/wayflowcore/tests/serialization/test_conversation_checkpointing.py new file mode 100644 index 000000000..2d694a346 
--- /dev/null +++ b/wayflowcore/tests/serialization/test_conversation_checkpointing.py @@ -0,0 +1,654 @@ +# Copyright © 2025, 2026 Oracle and/or its affiliates. +# +# This software is under the Apache License 2.0 +# (LICENSE-APACHE or http://www.apache.org/licenses/LICENSE-2.0) or Universal Permissive License +# (UPL) 1.0 (LICENSE-UPL or https://oss.oracle.com/licenses/upl), at your option. + +import warnings +from types import SimpleNamespace +from typing import Any, Dict, Optional +from unittest.mock import AsyncMock + +import pytest + +import wayflowcore.checkpointing.serialization as checkpoint_serialization +from wayflowcore.a2a.a2aagent import A2AAgent, A2AConnectionConfig +from wayflowcore.agent import Agent +from wayflowcore.checkpointing import ( + CheckpointingInterval, + ConversationCheckpoint, + InMemoryCheckpointer, +) +from wayflowcore.checkpointing.checkpointeventlistener import _save_conversation_checkpoint +from wayflowcore.executors.executionstatus import FinishedStatus, UserMessageRequestStatus +from wayflowcore.flowhelpers import create_single_step_flow +from wayflowcore.managerworkers import ManagerWorkers +from wayflowcore.models.ociclientconfig import OCIClientConfigWithApiKey +from wayflowcore.ociagent import OciAgent +from wayflowcore.serialization.serializer import _resolve_legacy_field_name +from wayflowcore.steps import OutputMessageStep, PromptExecutionStep +from wayflowcore.swarm import Swarm + +from ..testhelpers.dummy import DummyModel +from ..testhelpers.testhelpers import retry_test +from .test_assistant_serialization import create_flow + + +class RecordingCheckpointer: + def __init__(self, *, fail_first_save: bool = False) -> None: + self.checkpointing_interval = CheckpointingInterval.CONVERSATION_TURNS + self.should_fail_next_save = fail_first_save + self.saved_checkpoints: Dict[tuple[str, str], Dict[str, Any]] = {} + + def save_conversation( + self, + conversation, + *, + checkpoint_id: Optional[str] = None, + metadata: 
Optional[Dict[str, Any]] = None, + ): + if self.should_fail_next_save: + self.should_fail_next_save = False + raise RuntimeError("checkpoint save failed") + resolved_checkpoint_id = checkpoint_id or "generated-checkpoint-id" + self.saved_checkpoints[(conversation.id, resolved_checkpoint_id)] = dict(metadata or {}) + conversation.checkpoint_id = resolved_checkpoint_id + return SimpleNamespace( + checkpoint_id=resolved_checkpoint_id, + metadata=dict(metadata or {}), + ) + + def load(self, conversation_id: str, checkpoint_id: str) -> Any: + return SimpleNamespace(metadata=self.saved_checkpoints[(conversation_id, checkpoint_id)]) + + def load_latest(self, conversation_id: str) -> Any: + return None + + +class StaticLoadCheckpointer: + def __init__(self, checkpoint: ConversationCheckpoint) -> None: + self.checkpointing_interval = CheckpointingInterval.CONVERSATION_TURNS + self.checkpoint = checkpoint + + def load(self, conversation_id: str, checkpoint_id: str) -> ConversationCheckpoint: + return self.checkpoint + + def load_latest(self, conversation_id: str) -> Optional[ConversationCheckpoint]: + if conversation_id != self.checkpoint.conversation_id: + return None + return self.checkpoint + + +def _build_checkpointable_agent( + *, + name: str, + initial_message: str, +) -> tuple[Agent, DummyModel]: + llm = DummyModel() + agent = Agent( + llm=llm, + name=name, + description=f"{name} description", + custom_instruction="Be helpful.", + initial_message=initial_message, + ) + return agent, llm + + +def _build_checkpointable_swarm() -> tuple[Swarm, DummyModel]: + first_agent, first_agent_llm = _build_checkpointable_agent( + name="checkpoint_swarm_first_agent", + initial_message="Hello from the swarm.", + ) + second_agent = Agent( + llm=DummyModel(fails_if_not_set=False), + name="checkpoint_swarm_second_agent", + description="Swarm helper", + custom_instruction="Help with delegated tasks.", + ) + swarm = Swarm( + first_agent=first_agent, + relationships=[(first_agent, 
second_agent)], + name="checkpoint_swarm", + ) + return swarm, first_agent_llm + + +def _build_checkpointable_managerworkers() -> tuple[ManagerWorkers, DummyModel]: + manager_agent, manager_llm = _build_checkpointable_agent( + name="checkpoint_manager_agent", + initial_message="Hello from the manager.", + ) + worker_agent = Agent( + llm=DummyModel(fails_if_not_set=False), + name="checkpoint_worker_agent", + description="Worker agent", + custom_instruction="Help the manager.", + ) + managerworkers = ManagerWorkers( + group_manager=manager_agent, + workers=[worker_agent], + name="checkpoint_managerworkers", + ) + return managerworkers, manager_llm + + +@pytest.fixture(scope="session") +def connection_config_no_verify(): + with warnings.catch_warnings(): + warnings.simplefilter("ignore", UserWarning) + return A2AConnectionConfig(verify=False) + + +@pytest.fixture +def a2a_agent(a2a_server, connection_config_no_verify): + return A2AAgent( + name="Checkpoint A2A Agent", + agent_url=a2a_server, + connection_config=connection_config_no_verify, + ) + + +def test_inmemory_checkpointer_can_save_load_list_and_delete_checkpoints() -> None: + checkpointer = InMemoryCheckpointer() + flow = create_single_step_flow(OutputMessageStep(message_template="Hello from checkpointing.")) + + conversation = flow.start_conversation( + conversation_id="checkpoint-lifecycle", checkpointer=checkpointer + ) + assert conversation.checkpointer is checkpointer + + status = conversation.execute() + + assert isinstance(status, FinishedStatus) + first_checkpoint_id = conversation.checkpoint_id + assert first_checkpoint_id is not None + + checkpointer.save(conversation) + second_checkpoint_id = conversation.checkpoint_id + assert second_checkpoint_id is not None + assert second_checkpoint_id != first_checkpoint_id + + checkpoints = checkpointer.list_checkpoints("checkpoint-lifecycle") + assert [checkpoint.checkpoint_id for checkpoint in checkpoints] == [ + first_checkpoint_id, + second_checkpoint_id, + 
] + assert checkpoints[-1].metadata["save_sequence"] == 2 + + latest_checkpoint = checkpointer.load_latest("checkpoint-lifecycle") + assert latest_checkpoint is not None + assert latest_checkpoint.checkpoint_id == second_checkpoint_id + + restored_conversation = flow.start_conversation( + conversation_id="checkpoint-lifecycle", + checkpoint_id=first_checkpoint_id, + checkpointer=checkpointer, + ) + assert restored_conversation.checkpointer is checkpointer + assert restored_conversation.checkpoint_id == first_checkpoint_id + assert restored_conversation.get_last_message().content == "Hello from checkpointing." + + checkpointer.delete("checkpoint-lifecycle", second_checkpoint_id) + promoted_checkpoint = checkpointer.load_latest("checkpoint-lifecycle") + assert promoted_checkpoint is not None + assert promoted_checkpoint.checkpoint_id == first_checkpoint_id + assert [ + checkpoint.checkpoint_id + for checkpoint in checkpointer.list_checkpoints("checkpoint-lifecycle") + ] == [first_checkpoint_id] + + +def test_conversation_turns_checkpoint_interval_saves_once_after_outer_execute() -> None: + checkpointer = InMemoryCheckpointer( + checkpointing_interval=CheckpointingInterval.CONVERSATION_TURNS + ) + flow = create_single_step_flow(OutputMessageStep(message_template="Hello once.")) + + status = flow.start_conversation( + conversation_id="conversation-turn", checkpointer=checkpointer + ).execute() + + assert isinstance(status, FinishedStatus) + checkpoints = checkpointer.list_checkpoints("conversation-turn") + assert len(checkpoints) == 1 + assert checkpoints[0].metadata["save_reason"] == "conversation_turn" + assert checkpoints[0].metadata["status_type"] == "FinishedStatus" + + +def test_all_internal_turns_checkpoint_interval_saves_before_each_flow_turn() -> None: + checkpointer = InMemoryCheckpointer( + checkpointing_interval=CheckpointingInterval.ALL_INTERNAL_TURNS + ) + flow = create_single_step_flow(OutputMessageStep(message_template="Hello internal turns.")) + + 
status = flow.start_conversation( + conversation_id="all-internal-turns", checkpointer=checkpointer + ).execute() + + assert isinstance(status, FinishedStatus) + checkpoints = checkpointer.list_checkpoints("all-internal-turns") + assert len(checkpoints) == 3 + assert [checkpoint.metadata["save_reason"] for checkpoint in checkpoints] == [ + "internal_turn_boundary", + "internal_turn_boundary", + "conversation_turn", + ] + assert [checkpoint.metadata.get("event_type") for checkpoint in checkpoints[:-1]] == [ + "FlowExecutionIterationStartedEvent", + "FlowExecutionIterationStartedEvent", + ] + assert checkpoints[-1].metadata["status_type"] == "FinishedStatus" + + +def test_llm_turns_checkpoint_interval_saves_only_after_llm_backed_turns() -> None: + checkpointer = InMemoryCheckpointer(checkpointing_interval=CheckpointingInterval.LLM_TURNS) + dummy_llm = DummyModel() + dummy_llm.set_next_output("Hello from the prompt step.") + flow = create_single_step_flow( + PromptExecutionStep( + llm=dummy_llm, + prompt_template="Say hello.", + ) + ) + + status = flow.start_conversation( + conversation_id="llm-turns", checkpointer=checkpointer + ).execute() + + assert isinstance(status, FinishedStatus) + checkpoints = checkpointer.list_checkpoints("llm-turns") + assert len(checkpoints) == 2 + assert checkpoints[0].metadata["save_reason"] == "internal_turn_boundary" + assert checkpoints[0].metadata["event_type"] == "FlowExecutionIterationStartedEvent" + assert checkpoints[0].metadata["llm_used_in_previous_turn"] is True + assert checkpoints[1].metadata["save_reason"] == "conversation_turn" + + +def test_checkpoint_serialization_context_registers_component_tree_as_external_refs( + monkeypatch: pytest.MonkeyPatch, +) -> None: + class _FakeSerializationContext: + def __init__(self, root: Any) -> None: + self.root = root + self.external_refs: set[str] = set() + self.recorded_refs: Dict[str, Dict[str, Any]] = {} + + @staticmethod + def get_reference(obj: Any) -> str: + obj_id = 
getattr(obj, "id", id(obj)) + return f"{obj.__class__.__name__.lower()}/{obj_id}" + + def register_external_reference(self, obj: Any) -> None: + self.external_refs.add(self.get_reference(obj)) + + def record_obj_dict(self, obj: Any, obj_as_dict: Dict[str, Any]) -> None: + self.recorded_refs[self.get_reference(obj)] = obj_as_dict + + monkeypatch.setattr( + checkpoint_serialization, + "SerializationContext", + _FakeSerializationContext, + ) + + flow = create_single_step_flow(OutputMessageStep(message_template="Hello external refs.")) + conversation = flow.start_conversation(conversation_id="checkpoint-external-refs") + serialization_context = checkpoint_serialization._build_checkpoint_serialization_context( + conversation + ) + + expected_component_refs = { + _FakeSerializationContext.get_reference(component) + for component in checkpoint_serialization._iter_component_tree(conversation.component) + } + + assert serialization_context.external_refs == expected_component_refs + assert serialization_context.recorded_refs == {} + + +def test_explicit_final_checkpoint_parameters_can_be_retried_after_save_fails() -> None: + checkpointer = RecordingCheckpointer(fail_first_save=True) + flow = create_single_step_flow(OutputMessageStep(message_template="Hello overrides.")) + + conversation = flow.start_conversation( + conversation_id="checkpoint-final-overrides", checkpointer=checkpointer + ) + + with pytest.raises(RuntimeError, match="checkpoint save failed"): + _save_conversation_checkpoint( + conversation, + save_reason="conversation_turn", + checkpoint_id="final-checkpoint-id", + metadata={"response_id": "resp-123"}, + ) + + _save_conversation_checkpoint( + conversation, + save_reason="conversation_turn", + checkpoint_id="final-checkpoint-id", + metadata={"response_id": "resp-123"}, + ) + + assert conversation.checkpoint_id == "final-checkpoint-id" + checkpoint = checkpointer.load("checkpoint-final-overrides", "final-checkpoint-id") + assert 
checkpoint.metadata["response_id"] == "resp-123" + + +def test_execute_final_checkpoint_parameters_are_applied_to_final_save() -> None: + checkpointer = RecordingCheckpointer() + flow = create_single_step_flow(OutputMessageStep(message_template="Hello execute.")) + + conversation = flow.start_conversation( + conversation_id="checkpoint-final-execute", + checkpointer=checkpointer, + ) + status = conversation.execute( + _final_checkpoint_id="final-checkpoint-id", + _final_checkpoint_metadata={"response_id": "resp-123"}, + ) + + assert isinstance(status, FinishedStatus) + assert conversation.checkpoint_id == "final-checkpoint-id" + checkpoint = checkpointer.load("checkpoint-final-execute", "final-checkpoint-id") + assert checkpoint.metadata["response_id"] == "resp-123" + assert checkpoint.metadata["save_reason"] == "conversation_turn" + + +def test_execute_async_does_not_save_final_checkpoint_when_execution_fails( + monkeypatch: pytest.MonkeyPatch, +) -> None: + checkpointer = RecordingCheckpointer() + flow = create_single_step_flow(OutputMessageStep(message_template="Hello failure.")) + + conversation = flow.start_conversation( + conversation_id="checkpoint-final-exception", + checkpointer=checkpointer, + ) + monkeypatch.setattr( + conversation.component.runner, + "execute_async", + AsyncMock(side_effect=RuntimeError("runner failed")), + ) + + with pytest.raises(RuntimeError, match="runner failed"): + conversation.execute() + + assert conversation.checkpoint_id is None + assert checkpointer.saved_checkpoints == {} + + +def test_checkpoint_restore_rejects_conversations_from_other_components() -> None: + original_flow = create_single_step_flow(OutputMessageStep(message_template="Hello original.")) + other_flow = create_single_step_flow(OutputMessageStep(message_template="Hello other.")) + checkpointer = StaticLoadCheckpointer( + ConversationCheckpoint( + checkpoint_id="checkpoint-1", + conversation_id="checkpoint-other-component", + component_id=original_flow.id, + 
created_at=0, + state="unused-because-component-mismatch-is-checked-first", + metadata={}, + ) + ) + + with pytest.raises(ValueError, match="started with another component"): + other_flow.start_conversation( + conversation_id="checkpoint-other-component", + checkpointer=checkpointer, + ) + + +def test_restore_can_skip_attaching_live_checkpointer( + monkeypatch: pytest.MonkeyPatch, +) -> None: + flow = create_single_step_flow(OutputMessageStep(message_template="Hello restore.")) + checkpoint_id = "checkpoint-no-attach-id" + checkpointer = StaticLoadCheckpointer( + ConversationCheckpoint( + checkpoint_id=checkpoint_id, + conversation_id="checkpoint-no-attach", + component_id=flow.id, + created_at=0, + state="unused-because-deserialization-is-mocked", + metadata={}, + ) + ) + + monkeypatch.setattr( + checkpoint_serialization, + "_deserialize_conversation_checkpoint_state", + lambda *args, **kwargs: flow.start_conversation(conversation_id="checkpoint-no-attach"), + ) + + restored_conversation = flow.start_conversation( + conversation_id="checkpoint-no-attach", + checkpoint_id=checkpoint_id, + checkpointer=checkpointer, + _attach_checkpointer=False, + ) + + assert restored_conversation.checkpointer is None + assert restored_conversation.checkpoint_id == checkpoint_id + + +def test_legacy_serialized_conversation_id_restores_root_conversation_id() -> None: + agent, _ = _build_checkpointable_agent( + name="legacy_checkpoint_agent", + initial_message="Hello from the past.", + ) + conversation = agent.start_conversation(_root_conversation_id="legacy-root-conversation") + + assert conversation.root_conversation_id == "legacy-root-conversation" + assert ( + _resolve_legacy_field_name(type(conversation), "root_conversation_id") == "conversation_id" + ) + assert not hasattr(conversation, "conversation_id") + + +def test_flow_checkpointing_supports_resume_and_time_travel() -> None: + checkpointer = InMemoryCheckpointer() + flow = create_flow() + + conversation = 
flow.start_conversation( + conversation_id="flow-checkpoint", checkpointer=checkpointer + ) + first_status = conversation.execute() + + assert isinstance(first_status, UserMessageRequestStatus) + first_checkpoint_id = conversation.checkpoint_id + assert first_checkpoint_id is not None + + restored_conversation = flow.start_conversation( + conversation_id=conversation.id, + checkpointer=checkpointer, + ) + assert restored_conversation.checkpoint_id == first_checkpoint_id + assert isinstance(restored_conversation.status, UserMessageRequestStatus) + + restored_conversation.append_user_message("continue") + restored_status = restored_conversation.execute() + assert isinstance(restored_status, FinishedStatus) + + rewound_conversation = flow.start_conversation( + conversation_id=conversation.id, + checkpointer=checkpointer, + checkpoint_id=first_checkpoint_id, + ) + assert len(rewound_conversation.get_messages()) < len(restored_conversation.get_messages()) + rewound_conversation.append_user_message("rewind") + rewound_status = rewound_conversation.execute() + assert isinstance(rewound_status, FinishedStatus) + + +def test_agent_checkpointing_supports_resume_and_time_travel() -> None: + checkpointer = InMemoryCheckpointer() + agent, llm = _build_checkpointable_agent( + name="checkpoint_agent", + initial_message="Hello from the agent.", + ) + + conversation = agent.start_conversation( + conversation_id="agent-checkpoint", checkpointer=checkpointer + ) + first_status = conversation.execute() + + assert isinstance(first_status, UserMessageRequestStatus) + first_checkpoint_id = conversation.checkpoint_id + assert first_checkpoint_id is not None + + restored_conversation = agent.start_conversation( + conversation_id=conversation.id, + checkpointer=checkpointer, + ) + llm.set_next_output("Agent resumed successfully.") + restored_conversation.append_user_message("Please continue.") + restored_status = restored_conversation.execute() + assert isinstance(restored_status, 
UserMessageRequestStatus) + assert restored_conversation.get_last_message().content == "Agent resumed successfully." + + rewound_conversation = agent.start_conversation( + conversation_id=conversation.id, + checkpointer=checkpointer, + checkpoint_id=first_checkpoint_id, + ) + llm.set_next_output("Agent rewound successfully.") + rewound_conversation.append_user_message("Try again.") + rewound_status = rewound_conversation.execute() + assert isinstance(rewound_status, UserMessageRequestStatus) + assert rewound_conversation.get_last_message().content == "Agent rewound successfully." + + +def test_swarm_checkpointing_supports_resume_and_time_travel() -> None: + checkpointer = InMemoryCheckpointer() + swarm, llm = _build_checkpointable_swarm() + + conversation = swarm.start_conversation( + conversation_id="swarm-checkpoint", checkpointer=checkpointer + ) + first_status = conversation.execute() + + assert isinstance(first_status, UserMessageRequestStatus) + first_checkpoint_id = conversation.checkpoint_id + assert first_checkpoint_id is not None + + restored_conversation = swarm.start_conversation( + conversation_id=conversation.id, + checkpointer=checkpointer, + ) + llm.set_next_output("Swarm resumed successfully.") + restored_conversation.append_user_message("Continue the swarm conversation.") + restored_status = restored_conversation.execute() + assert isinstance(restored_status, UserMessageRequestStatus) + assert restored_conversation.get_last_message().content == "Swarm resumed successfully." + + rewound_conversation = swarm.start_conversation( + conversation_id=conversation.id, + checkpointer=checkpointer, + checkpoint_id=first_checkpoint_id, + ) + llm.set_next_output("Swarm rewound successfully.") + rewound_conversation.append_user_message("Try the swarm again.") + rewound_status = rewound_conversation.execute() + assert isinstance(rewound_status, UserMessageRequestStatus) + assert rewound_conversation.get_last_message().content == "Swarm rewound successfully." 
+ + +def test_managerworkers_checkpointing_supports_resume_and_time_travel() -> None: + checkpointer = InMemoryCheckpointer() + managerworkers, llm = _build_checkpointable_managerworkers() + + conversation = managerworkers.start_conversation( + conversation_id="managerworkers-checkpoint", + checkpointer=checkpointer, + ) + first_status = conversation.execute() + + assert isinstance(first_status, UserMessageRequestStatus) + first_checkpoint_id = conversation.checkpoint_id + assert first_checkpoint_id is not None + + restored_conversation = managerworkers.start_conversation( + conversation_id=conversation.id, + checkpointer=checkpointer, + ) + llm.set_next_output("Manager resumed successfully.") + restored_conversation.append_user_message("Continue the manager workflow.") + restored_status = restored_conversation.execute() + assert isinstance(restored_status, UserMessageRequestStatus) + assert restored_conversation.get_last_message().content == "Manager resumed successfully." + + rewound_conversation = managerworkers.start_conversation( + conversation_id=conversation.id, + checkpointer=checkpointer, + checkpoint_id=first_checkpoint_id, + ) + llm.set_next_output("Manager rewound successfully.") + rewound_conversation.append_user_message("Try the manager workflow again.") + rewound_status = rewound_conversation.execute() + assert isinstance(rewound_status, UserMessageRequestStatus) + assert rewound_conversation.get_last_message().content == "Manager rewound successfully." 
+ + +@retry_test(max_attempts=4) +def test_a2aagent_checkpointing_supports_resume_and_time_travel(a2a_agent: A2AAgent) -> None: + """ + Failure rate: 0 out of 20 + Observed on: 2026-03-23 + Average success time: 0.00 seconds per successful attempt + Average failure time: No time measurement + Max attempt: 4 + Justification: (0.05 ** 4) ~= 0.6 / 100'000 + """ + checkpointer = InMemoryCheckpointer() + + conversation = a2a_agent.start_conversation( + conversation_id="a2a-checkpoint", checkpointer=checkpointer + ) + conversation.append_user_message("What is 5+5? Just output the answer.") + first_status = conversation.execute() + + assert isinstance(first_status, UserMessageRequestStatus) + first_checkpoint_id = conversation.checkpoint_id + assert first_checkpoint_id is not None + first_message_count = len(conversation.get_messages()) + + restored_conversation = a2a_agent.start_conversation( + conversation_id=conversation.id, + checkpointer=checkpointer, + ) + assert restored_conversation.checkpoint_id == first_checkpoint_id + assert len(restored_conversation.get_messages()) == first_message_count + restored_conversation.append_user_message( + "What if you replace 5 by 10? Just output the answer." + ) + restored_status = restored_conversation.execute() + + assert isinstance(restored_status, UserMessageRequestStatus) + assert len(restored_conversation.get_messages()) > first_message_count + assert restored_conversation.get_last_message() is not None + assert checkpointer.load_latest(conversation.id) is not None + + rewound_conversation = a2a_agent.start_conversation( + conversation_id=conversation.id, + checkpointer=checkpointer, + checkpoint_id=first_checkpoint_id, + ) + assert len(rewound_conversation.get_messages()) == first_message_count + assert len(rewound_conversation.get_messages()) < len(restored_conversation.get_messages()) + rewound_conversation.append_user_message("What if you replace 5 by 7? 
Just output the answer.") + rewound_status = rewound_conversation.execute() + + assert isinstance(rewound_status, UserMessageRequestStatus) + assert len(rewound_conversation.get_messages()) > first_message_count + assert rewound_conversation.get_last_message() is not None + + +def test_ociagent_explicitly_rejects_checkpoint_restore_arguments() -> None: + oci_agent = OciAgent( + agent_endpoint_id="ocid1.test.oc1..example", + client_config=OCIClientConfigWithApiKey(service_endpoint="https://example.com"), + name="checkpoint_oci_agent", + ) + + with pytest.raises(NotImplementedError, match="checkpoint restore"): + oci_agent.start_conversation(checkpointer=InMemoryCheckpointer()) diff --git a/wayflowcore/tests/test_managerworkers.py b/wayflowcore/tests/test_managerworkers.py index 8141e4d5c..6b977cbb7 100644 --- a/wayflowcore/tests/test_managerworkers.py +++ b/wayflowcore/tests/test_managerworkers.py @@ -184,7 +184,7 @@ def test_managerworkers_can_execute_with_initial_params_passed_in_start_conversa conversation = group.start_conversation( messages=[Message(content="Please compute 3*4 + 2", message_type=MessageType.USER)], inputs={"USER": "Iris"}, - root_conversation_id="12345", + _root_conversation_id="12345", ) conversation.execute() diff --git a/wayflowcore/tests/test_swarm.py b/wayflowcore/tests/test_swarm.py index 330654872..193be9028 100644 --- a/wayflowcore/tests/test_swarm.py +++ b/wayflowcore/tests/test_swarm.py @@ -226,7 +226,7 @@ def test_can_execute_swarm_with_initial_params_passed_in_start_conversation( ) ], inputs={"USER": "Iris"}, - root_conversation_id="12345", + _root_conversation_id="12345", ) conversation.execute() diff --git a/wayflowcore/tests/tracing/spans/test_conversation_span.py b/wayflowcore/tests/tracing/spans/test_conversation_span.py index 24e6d7098..274d28efc 100644 --- a/wayflowcore/tests/tracing/spans/test_conversation_span.py +++ b/wayflowcore/tests/tracing/spans/test_conversation_span.py @@ -92,7 +92,7 @@ def 
test_span_serialization_format( assert serialized_span["span_type"] == str(span.__class__.__name__) for attribute_name in attributes_to_check: assert getattr(span, attribute_name) == serialized_span[attribute_name] - assert serialized_span["conversation.id"] == span.conversation.conversation_id + assert serialized_span["conversation.id"] == span.conversation.id assert serialized_span["conversation.name"] == span.conversation.name assert ( serialized_span["conversational_component.type"] == "Agent" From 5cb722d69e35a7c17df594ffef585dda94f0b8be Mon Sep 17 00:00:00 2001 From: jschweiz Date: Wed, 22 Apr 2026 11:12:56 +0200 Subject: [PATCH 3/4] [fix]: fix import --- wayflowcore/src/wayflowcore/agentserver/_storagehelpers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/wayflowcore/src/wayflowcore/agentserver/_storagehelpers.py b/wayflowcore/src/wayflowcore/agentserver/_storagehelpers.py index 6a09bad18..a269babf9 100644 --- a/wayflowcore/src/wayflowcore/agentserver/_storagehelpers.py +++ b/wayflowcore/src/wayflowcore/agentserver/_storagehelpers.py @@ -7,7 +7,7 @@ from typing import Dict, Optional from wayflowcore.agentserver.serverstorageconfig import ServerStorageConfig -from wayflowcore.checkpointing.datastore import ( +from wayflowcore.checkpointing.datastorecheckpointer import ( _prepare_oracle_checkpoint_datastore, _prepare_postgres_checkpoint_datastore, ) From c2f878ee81034b7110e26290b23bf0e017010b89 Mon Sep 17 00:00:00 2001 From: Son Le Date: Fri, 15 May 2026 18:43:44 +0200 Subject: [PATCH 4/4] Harden checkpoint persistence and restore support --- docs/wayflowcore/requirements-docs.txt | 1 + .../source/core/api/conversation.rst | 15 + .../core/code_examples/howto_checkpointing.py | 2 +- .../howto_serialize_conversations.py | 8 +- .../core/howtoguides/howto_checkpointing.rst | 12 +- wayflowcore/requirements-dev.txt | 1 + wayflowcore/src/wayflowcore/a2a/a2aagent.py | 2 +- wayflowcore/src/wayflowcore/agent.py | 2 +- .../services/wayflowservice.py 
| 182 +++++++-- .../agentserver/serverstorageconfig.py | 4 + .../src/wayflowcore/checkpointing/__init__.py | 3 - .../checkpointing/_componentidentity.py | 259 ++++++++++++ .../wayflowcore/checkpointing/checkpointer.py | 12 +- .../checkpointing/checkpointeventlistener.py | 14 +- .../checkpointing/datastorecheckpointer.py | 223 ++++++---- .../checkpointing/serialization.py | 93 ++--- wayflowcore/src/wayflowcore/conversation.py | 4 + .../wayflowcore/conversationalcomponent.py | 19 +- .../src/wayflowcore/datastore/_relational.py | 10 + .../src/wayflowcore/datastore/oracle.py | 34 +- .../src/wayflowcore/datastore/postgres.py | 22 +- .../wayflowcore/executors/_agentexecutor.py | 2 + .../wayflowcore/executors/executionstatus.py | 34 +- wayflowcore/src/wayflowcore/flow.py | 2 +- wayflowcore/src/wayflowcore/managerworkers.py | 2 +- wayflowcore/src/wayflowcore/messagelist.py | 21 +- .../datastoresteps/datastorequerystep.py | 2 + .../src/wayflowcore/tools/servertools.py | 7 +- wayflowcore/tests/agentserver/conftest.py | 12 +- .../tests/agentserver/test_wayflow_server.py | 342 +++++++++++++++- wayflowcore/tests/datastores/conftest.py | 45 ++- .../tests/datastores/test_datastore.py | 18 +- wayflowcore/tests/search/conftest.py | 76 ++-- wayflowcore/tests/serialization/conftest.py | 2 + .../test_conversation_checkpointing.py | 382 ++++++++++++++++++ wayflowcore/tests/test_docstring.py | 52 +-- wayflowcore/tests/test_managerworkers.py | 14 +- wayflowcore/tests/transforms/conftest.py | 19 +- 38 files changed, 1639 insertions(+), 315 deletions(-) create mode 100644 wayflowcore/src/wayflowcore/checkpointing/_componentidentity.py diff --git a/docs/wayflowcore/requirements-docs.txt b/docs/wayflowcore/requirements-docs.txt index 7ca2c04ef..85d42a9a3 100644 --- a/docs/wayflowcore/requirements-docs.txt +++ b/docs/wayflowcore/requirements-docs.txt @@ -1,4 +1,5 @@ sphinx==8.1.3 +sphinx-autodoc-typehints==3.0.1 sphinx-substitution-extensions==2022.2.16 sphinx-tabs==3.4.5 
sphinx-copybutton==0.5.2 diff --git a/docs/wayflowcore/source/core/api/conversation.rst b/docs/wayflowcore/source/core/api/conversation.rst index fc70aba07..308070f7c 100644 --- a/docs/wayflowcore/source/core/api/conversation.rst +++ b/docs/wayflowcore/source/core/api/conversation.rst @@ -57,6 +57,21 @@ Base class for conversations. Can manipulate a conversation object, and can be s .. _conversation: .. autoclass:: wayflowcore.conversation.Conversation +Checkpointing +------------- + +.. _conversationcheckpoint: +.. autoclass:: wayflowcore.checkpointing.checkpointer.ConversationCheckpoint + +.. _checkpointinginterval: +.. autoclass:: wayflowcore.checkpointing.checkpointer.CheckpointingInterval + +.. _storageconfig: +.. autoclass:: wayflowcore.checkpointing.checkpointer.StorageConfig + +.. _checkpointer: +.. autoclass:: wayflowcore.checkpointing.checkpointer.Checkpointer + Execution Plan -------------- diff --git a/docs/wayflowcore/source/core/code_examples/howto_checkpointing.py b/docs/wayflowcore/source/core/code_examples/howto_checkpointing.py index bcfc88cc4..795a50546 100644 --- a/docs/wayflowcore/source/core/code_examples/howto_checkpointing.py +++ b/docs/wayflowcore/source/core/code_examples/howto_checkpointing.py @@ -25,7 +25,7 @@ from wayflowcore import Agent from wayflowcore.checkpointing import InMemoryCheckpointer -agent = Agent(llm=llm) +agent = Agent(llm=llm, agent_id="support-agent") checkpointer = InMemoryCheckpointer() conversation = agent.start_conversation( diff --git a/docs/wayflowcore/source/core/code_examples/howto_serialize_conversations.py b/docs/wayflowcore/source/core/code_examples/howto_serialize_conversations.py index 0dc08cff3..31b9a96c3 100644 --- a/docs/wayflowcore/source/core/code_examples/howto_serialize_conversations.py +++ b/docs/wayflowcore/source/core/code_examples/howto_serialize_conversations.py @@ -43,7 +43,7 @@ def store_conversation(path: str, conversation: Conversation) -> str: """Store the given conversation and return the 
conversation id.""" - conversation_id = conversation.conversation_id + conversation_id = conversation.id serialized_conversation = serialize(conversation) # Read existing data @@ -87,7 +87,7 @@ def load_conversation(path: str, conversation_id: str) -> Conversation: # .. start-##_Run_the_agent # Start a conversation conversation = assistant.start_conversation() -conversation_id = conversation.conversation_id +conversation_id = conversation.id print(f"1. Started conversation with ID: {conversation_id}") # Execute initial greeting @@ -155,7 +155,7 @@ def load_conversation(path: str, conversation_id: str) -> Conversation: # .. end-##_Creating_a_flow # .. start-##_Run_the_flow flow_conversation = simple_flow.start_conversation() -flow_id = flow_conversation.conversation_id +flow_id = flow_conversation.id print(f"1. Started flow conversation with ID: {flow_id}") # Execute until user input is needed @@ -208,7 +208,7 @@ def run_persistent_agent(assistant: Agent, store_path: str, conversation_id: str conversation = assistant.start_conversation() else: conversation = assistant.start_conversation() - print(f"Started new conversation {conversation.conversation_id}") + print(f"Started new conversation {conversation.id}") # Main conversation loop while True: diff --git a/docs/wayflowcore/source/core/howtoguides/howto_checkpointing.rst b/docs/wayflowcore/source/core/howtoguides/howto_checkpointing.rst index e22d9a64e..568c44c87 100644 --- a/docs/wayflowcore/source/core/howtoguides/howto_checkpointing.rst +++ b/docs/wayflowcore/source/core/howtoguides/howto_checkpointing.rst @@ -1,8 +1,8 @@ .. _top-howtocheckpointing: -========================================= +========================================== How to Checkpoint and Resume Conversations -========================================= +========================================== .. 
admonition:: Prerequisites @@ -38,7 +38,10 @@ Start a checkpointed conversation ================================= Attach a checkpointer when you start the conversation. ``conversation_id`` becomes the durable key -used to look up the conversation later. +used to look up the conversation later. For persistent checkpointers such as PostgreSQL or Oracle +Database, construct the agent, flow, or other top-level component with a stable ``id`` or +component-specific id alias, such as ``agent_id`` for ``Agent``. The same component id must be used +after a process restart to restore the checkpoint safely. .. literalinclude:: ../code_examples/howto_checkpointing.py :language: python @@ -46,8 +49,7 @@ used to look up the conversation later. :end-before: .. end-##_Start_a_checkpointed_conversation Once checkpointing is enabled, WayFlow saves the root conversation automatically at the configured -checkpoint boundaries. For nested execution lineage without checkpoint restore, pass -``root_conversation_id`` explicitly. +checkpoint boundaries. Resume the latest checkpoint diff --git a/wayflowcore/requirements-dev.txt b/wayflowcore/requirements-dev.txt index 13af2b9c9..8660c9bef 100644 --- a/wayflowcore/requirements-dev.txt +++ b/wayflowcore/requirements-dev.txt @@ -1,5 +1,6 @@ # For docs sphinx==8.1.3 +sphinx-autodoc-typehints==3.0.1 sphinx-substitution-extensions==2022.2.16 sphinx-tabs==3.4.5 sphinx-copybutton==0.5.2 diff --git a/wayflowcore/src/wayflowcore/a2a/a2aagent.py b/wayflowcore/src/wayflowcore/a2a/a2aagent.py index 7a654ad3f..3bd410f55 100644 --- a/wayflowcore/src/wayflowcore/a2a/a2aagent.py +++ b/wayflowcore/src/wayflowcore/a2a/a2aagent.py @@ -277,7 +277,7 @@ def start_conversation( Returns ------- - A2AAgentConversation + Conversation A new or restored A2A agent conversation. 
""" from wayflowcore.executors._a2aagentconversation import A2AAgentConversation diff --git a/wayflowcore/src/wayflowcore/agent.py b/wayflowcore/src/wayflowcore/agent.py index a74d66fe0..bebe9f6c7 100644 --- a/wayflowcore/src/wayflowcore/agent.py +++ b/wayflowcore/src/wayflowcore/agent.py @@ -413,7 +413,7 @@ def start_conversation( Returns ------- - AgentConversation + Conversation A new or restored agent conversation. """ from wayflowcore.events.event import ConversationCreatedEvent diff --git a/wayflowcore/src/wayflowcore/agentserver/openairesponses/services/wayflowservice.py b/wayflowcore/src/wayflowcore/agentserver/openairesponses/services/wayflowservice.py index 58a4d13ea..0e7805f05 100644 --- a/wayflowcore/src/wayflowcore/agentserver/openairesponses/services/wayflowservice.py +++ b/wayflowcore/src/wayflowcore/agentserver/openairesponses/services/wayflowservice.py @@ -15,6 +15,7 @@ from wayflowcore.agentserver.serverstorageconfig import ServerStorageConfig from wayflowcore.checkpointing import ConversationCheckpoint, DatastoreCheckpointer +from wayflowcore.checkpointing.checkpointeventlistener import _build_checkpoint_metadata from wayflowcore.conversation import Conversation from wayflowcore.conversationalcomponent import ConversationalComponent from wayflowcore.datastore import Datastore, InMemoryDatastore @@ -51,6 +52,10 @@ logger = logging.getLogger(__name__) +_OPENAI_RESPONSES_CHECKPOINT_FORMAT_VERSION = 2 +_OPENAI_RESPONSES_CHECKPOINT_FORMAT_VERSION_KEY = "wayflow_openai_responses_format_version" +_OPENAI_RESPONSES_SERVER_MODEL_ID_KEY = "server_model_id" + class WayFlowOpenAIResponsesService(OpenAIResponsesService): def __init__( @@ -127,13 +132,54 @@ async def get_response( raise HTTPException( status_code=http_status_code.HTTP_404_NOT_FOUND, detail="Response not found" ) - response_as_txt = checkpoint.metadata.get("response") - if not isinstance(response_as_txt, str): + response_as_txt = self._get_checkpoint_response_text(checkpoint) + if 
response_as_txt is None: raise HTTPException( status_code=http_status_code.HTTP_404_NOT_FOUND, detail="Response not found" ) return Response.model_validate_json(response_as_txt) + def _get_checkpoint_response_text(self, checkpoint: ConversationCheckpoint) -> Optional[str]: + response_as_txt = checkpoint.metadata.get("response") + return response_as_txt if isinstance(response_as_txt, str) else None + + def _get_checkpoint_response_model(self, checkpoint: ConversationCheckpoint) -> Optional[str]: + response_as_txt = self._get_checkpoint_response_text(checkpoint) + if response_as_txt is None: + return None + try: + return Response.model_validate_json(response_as_txt).model + except ValueError: + return None + + def _get_checkpoint_model_identity(self, checkpoint: ConversationCheckpoint) -> Optional[str]: + checkpoint_model = self._get_checkpoint_response_model(checkpoint) + if checkpoint_model is not None: + return checkpoint_model + metadata_model = checkpoint.metadata.get(_OPENAI_RESPONSES_SERVER_MODEL_ID_KEY) + return metadata_model if isinstance(metadata_model, str) else None + + def _validate_checkpoint_response_model( + self, + *, + checkpoint: ConversationCheckpoint, + agent_id: str, + resume_parameter_name: str, + resume_parameter_value: str, + checkpoint_description: str, + ) -> None: + checkpoint_model = self._get_checkpoint_model_identity(checkpoint) + if checkpoint_model is not None and checkpoint_model != agent_id: + raise HTTPException( + status_code=http_status_code.HTTP_400_BAD_REQUEST, + detail=( + f"Cannot use {resume_parameter_name} `{resume_parameter_value}` with " + f"model `{agent_id}` because the {checkpoint_description} was created " + f"with model `{checkpoint_model}`. Use the same model as the " + f"{checkpoint_description}." 
+ ), + ) + async def delete_response(self, response_id: str) -> Optional[ResponseError]: checkpoint = self._lookup_checkpoint_by_response_id(response_id) if checkpoint is not None: @@ -257,7 +303,7 @@ async def runner(conversation: Conversation) -> None: try: with register_event_listeners([token_usage_listener, yielding_listener]): status = await conversation.execute_async( - _final_checkpoint_id=response_id if should_store_response else None, + _save_final_checkpoint=not should_store_response, ) except Exception as e: nonlocal raised_exception @@ -266,12 +312,13 @@ async def runner(conversation: Conversation) -> None: # close the send side so the receiver side's async for terminates await send_stream.aclose() - async with anyio.create_task_group() as tg: - tg.start_soon(runner, state) + async with receive_stream: + async with anyio.create_task_group() as tg: + tg.start_soon(runner, state) - async for ev in receive_stream: - # These events come from the synchronous callback - yield ev + async for ev in receive_stream: + # These events come from the synchronous callback + yield ev if raised_exception: if "not a multimodal model" in str(raised_exception): @@ -306,7 +353,18 @@ async def runner(conversation: Conversation) -> None: self.checkpointer.save_conversation( state, checkpoint_id=current_response.id, - metadata={"response": current_response.model_dump_json()}, + component_id=current_response.model, + metadata=_build_checkpoint_metadata( + state, + save_reason="conversation_turn", + metadata={ + _OPENAI_RESPONSES_CHECKPOINT_FORMAT_VERSION_KEY: ( + _OPENAI_RESPONSES_CHECKPOINT_FORMAT_VERSION + ), + _OPENAI_RESPONSES_SERVER_MODEL_ID_KEY: current_response.model, + "response": current_response.model_dump_json(), + }, + ), ) self._response_conversation_ids[current_response.id] = state.id @@ -371,53 +429,103 @@ def _load_state( status_code=http_status_code.HTTP_404_NOT_FOUND, detail=f"No previous response with id `{previous_response_id}` was found", ) + 
self._validate_checkpoint_response_model( + checkpoint=checkpoint, + agent_id=agent_id, + resume_parameter_name="previous_response_id", + resume_parameter_value=previous_response_id, + checkpoint_description="previous response", + ) self._response_conversation_ids[checkpoint.checkpoint_id] = checkpoint.conversation_id - try: - return self.agents[agent_id].start_conversation( - conversation_id=checkpoint.conversation_id, - checkpoint_id=checkpoint.checkpoint_id, - checkpointer=self.checkpointer, - _attach_checkpointer=attach_checkpointer, - ) - except (TypeError, ValueError) as e: - raise HTTPException( - status_code=http_status_code.HTTP_500_INTERNAL_SERVER_ERROR, - detail=f"Conversation state is corrupted, it cannot be de-serialized: {e}", - ) from e + return self._restore_checkpoint_state( + checkpoint=checkpoint, + agent_id=agent_id, + attach_checkpointer=attach_checkpointer, + ) elif conversation_id: - checkpoint = self.checkpointer.load_latest(conversation_id) + checkpoint = self._lookup_latest_response_checkpoint_by_conversation_id(conversation_id) if checkpoint is None: raise HTTPException( status_code=http_status_code.HTTP_404_NOT_FOUND, detail=f"No conversation with id `{conversation_id}` was found", ) + self._validate_checkpoint_response_model( + checkpoint=checkpoint, + agent_id=agent_id, + resume_parameter_name="conversation", + resume_parameter_value=conversation_id, + checkpoint_description="latest response in that conversation", + ) self._response_conversation_ids[checkpoint.checkpoint_id] = checkpoint.conversation_id - try: - return self.agents[agent_id].start_conversation( - conversation_id=conversation_id, - checkpointer=self.checkpointer, - _attach_checkpointer=attach_checkpointer, - ) - except (TypeError, ValueError) as e: - raise HTTPException( - status_code=http_status_code.HTTP_500_INTERNAL_SERVER_ERROR, - detail=f"Conversation state is corrupted, it cannot be de-serialized: {e}", - ) from e + return self._restore_checkpoint_state( + 
checkpoint=checkpoint, + agent_id=agent_id, + attach_checkpointer=attach_checkpointer, + ) else: return None + def _restore_checkpoint_state( + self, + *, + checkpoint: ConversationCheckpoint, + agent_id: str, + attach_checkpointer: bool, + ) -> Conversation: + agent = self.agents[agent_id] + accepted_checkpoint_component_ids = [] + if self._get_checkpoint_model_identity(checkpoint) == agent_id: + accepted_checkpoint_component_ids.append(checkpoint.component_id) + + try: + return cast( + Conversation, + agent._restore_checkpointed_conversation( + checkpoint=checkpoint, + checkpointer=self.checkpointer, + expected_conversation_type=agent.conversation_class, + attach_checkpointer=attach_checkpointer, + accepted_checkpoint_component_ids=accepted_checkpoint_component_ids, + ), + ) + except (TypeError, ValueError) as e: + raise HTTPException( + status_code=http_status_code.HTTP_500_INTERNAL_SERVER_ERROR, + detail=f"Conversation state is corrupted, it cannot be de-serialized: {e}", + ) from e + def _lookup_checkpoint_by_response_id( self, response_id: str ) -> Optional[ConversationCheckpoint]: conversation_id = self._response_conversation_ids.get(response_id) if conversation_id is not None: try: - return self.checkpointer.load(conversation_id, response_id) + checkpoint = self.checkpointer.load(conversation_id, response_id) except ValueError: self._response_conversation_ids.pop(response_id, None) - checkpoint = self.checkpointer._find_checkpoint_by_id(response_id) - if checkpoint is not None: - self._response_conversation_ids[response_id] = checkpoint.conversation_id + return None + if self._get_checkpoint_response_text(checkpoint) is not None: + return checkpoint + self._response_conversation_ids.pop(response_id, None) + return None + checkpoint_by_id = self.checkpointer._find_checkpoint_by_id(response_id) + if checkpoint_by_id is None or self._get_checkpoint_response_text(checkpoint_by_id) is None: + return None + self._response_conversation_ids[response_id] = 
checkpoint_by_id.conversation_id + return checkpoint_by_id + + def _lookup_latest_response_checkpoint_by_conversation_id( + self, conversation_id: str + ) -> Optional[ConversationCheckpoint]: + response_checkpoints = [ + checkpoint + for checkpoint in self.checkpointer.list_checkpoints(conversation_id, limit=None) + if self._get_checkpoint_response_text(checkpoint) is not None + ] + if not response_checkpoints: + return None + checkpoint = response_checkpoints[-1] + self._response_conversation_ids[checkpoint.checkpoint_id] = checkpoint.conversation_id return checkpoint async def _create_state( diff --git a/wayflowcore/src/wayflowcore/agentserver/serverstorageconfig.py b/wayflowcore/src/wayflowcore/agentserver/serverstorageconfig.py index f2701cacb..0764c2e0a 100644 --- a/wayflowcore/src/wayflowcore/agentserver/serverstorageconfig.py +++ b/wayflowcore/src/wayflowcore/agentserver/serverstorageconfig.py @@ -5,10 +5,14 @@ # (UPL) 1.0 (LICENSE-UPL or https://oss.oracle.com/licenses/upl), at your option. from dataclasses import dataclass +from typing import Optional from wayflowcore.checkpointing import StorageConfig +from wayflowcore.datastore.datastore import Datastore @dataclass class ServerStorageConfig(StorageConfig): """Configuration for agent-server conversation storage.""" + + datastore: Optional[Datastore] = None diff --git a/wayflowcore/src/wayflowcore/checkpointing/__init__.py b/wayflowcore/src/wayflowcore/checkpointing/__init__.py index 7db08002e..2da450b0c 100644 --- a/wayflowcore/src/wayflowcore/checkpointing/__init__.py +++ b/wayflowcore/src/wayflowcore/checkpointing/__init__.py @@ -4,9 +4,6 @@ # (LICENSE-APACHE or http://www.apache.org/licenses/LICENSE-2.0) or Universal Permissive License # (UPL) 1.0 (LICENSE-UPL or https://oss.oracle.com/licenses/upl), at your option. 
-from importlib import import_module -from typing import Any - from .checkpointer import Checkpointer, CheckpointingInterval, ConversationCheckpoint, StorageConfig from .datastorecheckpointer import ( DatastoreCheckpointer, diff --git a/wayflowcore/src/wayflowcore/checkpointing/_componentidentity.py b/wayflowcore/src/wayflowcore/checkpointing/_componentidentity.py new file mode 100644 index 000000000..7b886a6f6 --- /dev/null +++ b/wayflowcore/src/wayflowcore/checkpointing/_componentidentity.py @@ -0,0 +1,259 @@ +# Copyright © 2025, 2026 Oracle and/or its affiliates. +# +# This software is under the Apache License 2.0 +# (LICENSE-APACHE or http://www.apache.org/licenses/LICENSE-2.0) or Universal Permissive License +# (UPL) 1.0 (LICENSE-UPL or https://oss.oracle.com/licenses/upl), at your option. + +from typing import TYPE_CHECKING, Any, Dict, Iterable, Iterator, List, Optional, Sequence + +from wayflowcore.idgeneration import IdGenerator +from wayflowcore.serialization.context import DeserializationContext, SerializationContext + +if TYPE_CHECKING: + from wayflowcore.component import Component + from wayflowcore.conversation import Conversation + + +CHECKPOINT_COMPONENT_REFERENCES_KEY = "component_references" + +_COMPONENT_TYPE_KEY = "component_type" +_COMPONENT_NAME_KEY = "name" +_COMPONENT_STABLE_IDS_KEY = "stable_ids" + + +def iter_checkpoint_conversation_graph( + root_conversation: "Conversation", +) -> Sequence["Conversation"]: + visited_conversation_ids: set[str] = set() + queue: List["Conversation"] = [root_conversation] + ordered_conversations: List["Conversation"] = [] + + while queue: + conversation = queue.pop() + if conversation.id in visited_conversation_ids: + continue + visited_conversation_ids.add(conversation.id) + ordered_conversations.append(conversation) + queue.extend(conversation._get_all_sub_conversations()) + + return ordered_conversations + + +def _iter_checkpoint_child_components(component: "Component") -> Iterator["Component"]: + from 
wayflowcore.agent import Agent + from wayflowcore.component import Component + from wayflowcore.flow import Flow + from wayflowcore.managerworkers import ManagerWorkers + from wayflowcore.steps.agentexecutionstep import AgentExecutionStep + from wayflowcore.swarm import Swarm + + if isinstance(component, Agent): + for agent_child_component in [*component.agents, *component.flows]: + if isinstance(agent_child_component, Component): + yield agent_child_component + return + elif isinstance(component, Flow): + for step in component.steps.values(): + if isinstance(step, AgentExecutionStep) and isinstance(step.agent, Component): + yield step.agent + for sub_flow in step.sub_flows() or []: + if isinstance(sub_flow, Component): + yield sub_flow + return + elif isinstance(component, ManagerWorkers): + for manager_child_component in [component.manager_agent, *component.workers]: + if isinstance(manager_child_component, Component): + yield manager_child_component + return + elif isinstance(component, Swarm): + for swarm_child_component in component._agent_by_name.values(): + if isinstance(swarm_child_component, Component): + yield swarm_child_component + + +def iter_checkpoint_component_tree(component: "Component") -> Sequence["Component"]: + visited_component_refs: set[str] = set() + ordered_components: List["Component"] = [] + queue: List["Component"] = [component] + + while queue: + current_component = queue.pop() + current_component_ref = SerializationContext.get_reference(current_component) + if current_component_ref in visited_component_refs: + continue + visited_component_refs.add(current_component_ref) + ordered_components.append(current_component) + queue.extend(_iter_checkpoint_child_components(current_component)) + + return ordered_components + + +def _stable_component_name(component: "Component") -> Optional[str]: + component_name = getattr(component, "name", None) + if not isinstance(component_name, str) or IdGenerator.is_auto_generated(component_name): + return 
None + return component_name + + +def build_checkpoint_component_references( + root_component: "Component", + *, + root_component_id: Optional[str] = None, +) -> Dict[str, Dict[str, Any]]: + component_references: Dict[str, Dict[str, Any]] = {} + for component in iter_checkpoint_component_tree(root_component): + stable_ids = [component.id] + if component is root_component and root_component_id is not None: + stable_ids.append(root_component_id) + + descriptor: Dict[str, Any] = { + _COMPONENT_TYPE_KEY: component.__class__.__name__, + _COMPONENT_STABLE_IDS_KEY: sorted(set(stable_ids)), + } + component_name = _stable_component_name(component) + if component_name is not None: + descriptor[_COMPONENT_NAME_KEY] = component_name + component_references[SerializationContext.get_reference(component)] = descriptor + return component_references + + +def _component_reference_for_id(component: "Component", component_id: str) -> str: + return f"{component.__class__.__name__.lower()}/{component_id}" + + +def _add_unique_index_value( + index: Dict[tuple[str, str], Optional["Component"]], + key: tuple[str, str], + component: "Component", +) -> None: + existing_component = index.get(key) + if existing_component is component: + return + if key in index: + index[key] = None + else: + index[key] = component + + +def _build_current_component_index( + current_components: Sequence["Component"], + *, + root_component: "Component", + root_component_id_aliases: Iterable[str], +) -> Dict[tuple[str, str], Optional["Component"]]: + current_components_by_identity: Dict[tuple[str, str], Optional["Component"]] = {} + for current_component in current_components: + component_type = current_component.__class__.__name__ + _add_unique_index_value( + current_components_by_identity, + (component_type, current_component.id), + current_component, + ) + + component_name = _stable_component_name(current_component) + if component_name is not None: + _add_unique_index_value( + current_components_by_identity, + 
(component_type, component_name), + current_component, + ) + + root_component_type = root_component.__class__.__name__ + for root_component_id_alias in root_component_id_aliases: + current_components_by_identity[(root_component_type, root_component_id_alias)] = ( + root_component + ) + + return current_components_by_identity + + +def register_checkpoint_component_references( + *, + deserialization_context: DeserializationContext, + root_component: Optional["Component"], + component_references: Any, + root_component_id_aliases: Optional[Sequence[str]] = None, +) -> None: + if root_component is None: + return + + root_component_id_aliases = root_component_id_aliases or [] + current_components = iter_checkpoint_component_tree(root_component) + for current_tree_component in current_components: + deserialization_context.recorddeserialized_object( + SerializationContext.get_reference(current_tree_component), + current_tree_component, + ) + for root_component_id_alias in root_component_id_aliases: + deserialization_context.recorddeserialized_object( + _component_reference_for_id(root_component, root_component_id_alias), + root_component, + ) + + if not isinstance(component_references, dict): + return + + current_components_by_identity = _build_current_component_index( + current_components, + root_component=root_component, + root_component_id_aliases=root_component_id_aliases, + ) + + for serialized_reference, descriptor in component_references.items(): + if not isinstance(serialized_reference, str) or not isinstance(descriptor, dict): + continue + + component_type = descriptor.get(_COMPONENT_TYPE_KEY) + if not isinstance(component_type, str): + continue + + matched_component: Optional["Component"] = None + stable_ids = descriptor.get(_COMPONENT_STABLE_IDS_KEY) + if isinstance(stable_ids, list): + for stable_id in stable_ids: + if not isinstance(stable_id, str): + continue + matched_component = current_components_by_identity.get((component_type, stable_id)) + if 
matched_component is not None: + break + + if matched_component is None: + component_name = descriptor.get(_COMPONENT_NAME_KEY) + if isinstance(component_name, str): + matched_component = current_components_by_identity.get( + (component_type, component_name) + ) + + if matched_component is not None: + deserialization_context.recorddeserialized_object( + serialized_reference, + matched_component, + ) + + +def normalize_restored_component_keyed_state(conversation: "Conversation") -> None: + for sub_conversation in iter_checkpoint_conversation_graph(conversation): + state = getattr(sub_conversation, "state", None) + if state is None: + continue + sub_component_conversations = getattr( + state, + "current_sub_component_conversations", + None, + ) + if not isinstance(sub_component_conversations, dict): + continue + + rekeyed_sub_component_conversations: Dict[str, Any] = {} + for child_conversation in sub_component_conversations.values(): + child_component = getattr(child_conversation, "component", None) + child_component_id = getattr(child_component, "id", None) + if not isinstance(child_component_id, str): + break + rekeyed_sub_component_conversations[child_component_id] = child_conversation + else: + setattr( + state, + "current_sub_component_conversations", + rekeyed_sub_component_conversations, + ) diff --git a/wayflowcore/src/wayflowcore/checkpointing/checkpointer.py b/wayflowcore/src/wayflowcore/checkpointing/checkpointer.py index 0ecafb48f..c9190cf34 100644 --- a/wayflowcore/src/wayflowcore/checkpointing/checkpointer.py +++ b/wayflowcore/src/wayflowcore/checkpointing/checkpointer.py @@ -109,7 +109,8 @@ def save(self, checkpoint: Any) -> None: return if not isinstance(checkpoint, ConversationCheckpoint): raise TypeError( - f"Expected a Conversation or ConversationCheckpoint, got {type(checkpoint).__name__}." + "Expected a Conversation or ConversationCheckpoint, got " + f"{type(checkpoint).__name__}." 
) self._save_checkpoint(checkpoint) @@ -121,19 +122,24 @@ def save_conversation( conversation: "Conversation", *, checkpoint_id: Optional[str] = None, + component_id: Optional[str] = None, metadata: Optional[Dict[str, Any]] = None, ) -> ConversationCheckpoint: next_save_sequence = self._save_sequence_by_conversation.get(conversation.id, 0) + 1 self._save_sequence_by_conversation[conversation.id] = next_save_sequence + checkpoint_component_id = component_id or conversation.component.id checkpoint_metadata = {"save_sequence": next_save_sequence} if metadata: checkpoint_metadata.update(metadata) checkpoint = ConversationCheckpoint( checkpoint_id=checkpoint_id or IdGenerator.get_or_generate_id(), conversation_id=conversation.id, - component_id=conversation.component.id, + component_id=checkpoint_component_id, created_at=int(time.time()), - state=_serialize_conversation_checkpoint_state(conversation), + state=_serialize_conversation_checkpoint_state( + conversation, + root_component_id=checkpoint_component_id, + ), metadata=checkpoint_metadata, ) self._save_checkpoint(checkpoint) diff --git a/wayflowcore/src/wayflowcore/checkpointing/checkpointeventlistener.py b/wayflowcore/src/wayflowcore/checkpointing/checkpointeventlistener.py index af750a3b2..c507509b2 100644 --- a/wayflowcore/src/wayflowcore/checkpointing/checkpointeventlistener.py +++ b/wayflowcore/src/wayflowcore/checkpointing/checkpointeventlistener.py @@ -154,6 +154,7 @@ def get_conversation_checkpoint_execution_context( is_outermost_execution: bool, final_checkpoint_id: Optional[str] = None, final_checkpoint_metadata: Optional[Dict[str, Any]] = None, + save_final_checkpoint: bool = True, ) -> Iterator[None]: if conversation.checkpointer is None or not is_outermost_execution: with nullcontext(): @@ -170,9 +171,10 @@ def get_conversation_checkpoint_execution_context( raise else: listener.flush_pending_checkpoint() - _save_conversation_checkpoint( - conversation, - save_reason="conversation_turn", - 
checkpoint_id=final_checkpoint_id, - metadata=final_checkpoint_metadata, - ) + if save_final_checkpoint: + _save_conversation_checkpoint( + conversation, + save_reason="conversation_turn", + checkpoint_id=final_checkpoint_id, + metadata=final_checkpoint_metadata, + ) diff --git a/wayflowcore/src/wayflowcore/checkpointing/datastorecheckpointer.py b/wayflowcore/src/wayflowcore/checkpointing/datastorecheckpointer.py index 27bb5a2bb..9be842eba 100644 --- a/wayflowcore/src/wayflowcore/checkpointing/datastorecheckpointer.py +++ b/wayflowcore/src/wayflowcore/checkpointing/datastorecheckpointer.py @@ -6,7 +6,6 @@ import json import warnings -from textwrap import dedent from typing import Any, Dict, List, Optional, Sequence from wayflowcore.datastore import ( @@ -47,26 +46,61 @@ def _build_checkpoint_create_table_columns( return columns +def _latest_checkpoint_index_name(storage_config: StorageConfig) -> str: + return f"{storage_config.table_name}_last_turn_idx" + + +def _build_postgres_latest_checkpoint_index_query(storage_config: StorageConfig) -> str: + return "\n".join( + [ + f"CREATE UNIQUE INDEX {_latest_checkpoint_index_name(storage_config)}", + f"ON {storage_config.table_name} ({storage_config.conversation_id_column_name})", + f"WHERE {storage_config.is_last_turn_column_name} = 1", + ] + ) + + +def _build_oracle_latest_checkpoint_index_query(storage_config: StorageConfig) -> str: + case_expression = ( + f"CASE WHEN {storage_config.is_last_turn_column_name} = 1 " + f"THEN {storage_config.conversation_id_column_name} END" + ) + return "\n".join( + [ + f"CREATE UNIQUE INDEX {_latest_checkpoint_index_name(storage_config)}", + f"ON {storage_config.table_name} ({case_expression})", + ] + ) + + +def _datastore_already_setup_error(storage_config: StorageConfig) -> ValueError: + return ValueError( + "The datastore is already setup. Either delete the existing " + f'"{storage_config.table_name}" table or start the server with ' + "'--setup-datastore=no'." 
+ ) + + def _prepare_postgres_checkpoint_datastore( connection_config: PostgresDatabaseConnectionConfig, storage_config: StorageConfig, ) -> None: from sqlalchemy.exc import ProgrammingError - create_table_query = dedent( - f""" - CREATE TABLE {storage_config.table_name} ( - {", ".join(_build_checkpoint_create_table_columns(storage_config, is_oracle=False))} - ); - """ - ) + columns = ", ".join(_build_checkpoint_create_table_columns(storage_config, is_oracle=False)) + create_table_query = f"CREATE TABLE {storage_config.table_name} ({columns})" try: _execute_query_on_postgres_db(connection_config, create_table_query) + _execute_query_on_postgres_db( + connection_config, + _build_postgres_latest_checkpoint_index_query(storage_config), + ) except ProgrammingError as e: - if f'relation "{storage_config.table_name}" already exists' in str(e): - raise ValueError( - f'The datastore is already setup. Either delete the existing "{storage_config.table_name}" table or start the server with `--setup-datastore=no`.' 
- ) from e + latest_index_name = _latest_checkpoint_index_name(storage_config) + if f'relation "{storage_config.table_name}" already exists' in str( + e + ) or f'relation "{latest_index_name}" already exists' in str(e): + raise _datastore_already_setup_error(storage_config) from e raise @@ -74,20 +108,17 @@ def _prepare_oracle_checkpoint_datastore( connection_config: OracleDatabaseConnectionConfig, storage_config: StorageConfig, ) -> None: - create_table_query = dedent( - f""" - CREATE TABLE {storage_config.table_name} ( - {", ".join(_build_checkpoint_create_table_columns(storage_config, is_oracle=True))} - ); - """ - ) + columns = ", ".join(_build_checkpoint_create_table_columns(storage_config, is_oracle=True)) + create_table_query = f"CREATE TABLE {storage_config.table_name} ({columns})" try: _execute_query_on_oracle_db(connection_config, query=create_table_query) + _execute_query_on_oracle_db( + connection_config, + query=_build_oracle_latest_checkpoint_index_query(storage_config), + ) except Exception as e: if "already exists" in str(e): - raise ValueError( - f'The datastore is already setup. Either delete the existing "{storage_config.table_name}" table or start the server with `--setup-datastore=no`.' - ) from e + raise _datastore_already_setup_error(storage_config) from e raise @@ -194,11 +225,31 @@ def load(self, conversation_id: str, checkpoint_id: str) -> ConversationCheckpoi ) if checkpoint is None: raise ValueError( - f"Checkpoint `{checkpoint_id}` was not found for conversation `{conversation_id}`." + f"Checkpoint '{checkpoint_id}' was not found for conversation " + f"'{conversation_id}'." 
) return checkpoint def _save_checkpoint(self, checkpoint: ConversationCheckpoint) -> None: + if isinstance(self.datastore, RelationalDatastore): + self._save_checkpoint_relational(checkpoint) + return + + self._save_checkpoint_non_relational(checkpoint) + + def _save_checkpoint_relational(self, checkpoint: ConversationCheckpoint) -> None: + from sqlalchemy.exc import IntegrityError + + max_attempts = 3 + for attempt in range(max_attempts): + try: + self._save_checkpoint_relational_once(checkpoint) + return + except IntegrityError: + if attempt == max_attempts - 1: + raise + + def _save_checkpoint_relational_once(self, checkpoint: ConversationCheckpoint) -> None: existing_checkpoint = self._find_checkpoint( conversation_id=checkpoint.conversation_id, checkpoint_id=checkpoint.checkpoint_id, @@ -220,55 +271,26 @@ def _save_checkpoint(self, checkpoint: ConversationCheckpoint) -> None: update_latest_values = {self.storage_config.is_last_turn_column_name: 0} entity = self._checkpoint_to_entity(checkpoint) - if isinstance(self.datastore, RelationalDatastore): - data_table = self.datastore.data_tables[self.storage_config.table_name] - with data_table.engine.connect() as connection: - connection.execute( - data_table._update_query( - where=update_latest_where, - update=update_latest_values, - ) + datastore = self.datastore + if not isinstance(datastore, RelationalDatastore): + raise TypeError("Relational checkpoint save requires a relational datastore.") + + data_table = datastore.data_tables[self.storage_config.table_name] + with data_table.engine.begin() as connection: + connection.execute( + data_table._update_query( + where=update_latest_where, + update=update_latest_values, ) - if existing_checkpoint is None: - sql_create_stmt, new_entities = data_table._create_query([entity]) - connection.execute(sql_create_stmt, new_entities) - else: - update_checkpoint_where = { - self.storage_config.conversation_id_column_name: checkpoint.conversation_id, - 
self.storage_config.turn_id_column_name: checkpoint.checkpoint_id, - } - update_checkpoint_values = { - self.storage_config.agent_id_column_name: checkpoint.component_id, - self.storage_config.created_at_column_name: checkpoint.created_at, - self.storage_config.conversation_turn_state_column_name: checkpoint.state, - self.storage_config.is_last_turn_column_name: 1, - self.storage_config.extra_metadata_column_name: json.dumps( - checkpoint.metadata - ), - } - if self.storage_config.max_retention is not None: - update_checkpoint_values[self.storage_config.remove_by_column_name] = ( - checkpoint.created_at + self.storage_config.max_retention - ) - connection.execute( - data_table._update_query( - where=update_checkpoint_where, - update=update_checkpoint_values, - ) - ) - connection.commit() - else: - self.datastore.update( - collection_name=self.storage_config.table_name, - where=update_latest_where, - update=update_latest_values, ) if existing_checkpoint is None: - self.datastore.create( - collection_name=self.storage_config.table_name, - entities=[entity], - ) + sql_create_stmt, new_entities = data_table._create_query([entity]) + connection.execute(sql_create_stmt, new_entities) else: + update_checkpoint_where = { + self.storage_config.conversation_id_column_name: checkpoint.conversation_id, + self.storage_config.turn_id_column_name: checkpoint.checkpoint_id, + } update_checkpoint_values = { self.storage_config.agent_id_column_name: checkpoint.component_id, self.storage_config.created_at_column_name: checkpoint.created_at, @@ -280,14 +302,65 @@ def _save_checkpoint(self, checkpoint: ConversationCheckpoint) -> None: update_checkpoint_values[self.storage_config.remove_by_column_name] = ( checkpoint.created_at + self.storage_config.max_retention ) - self.datastore.update( - collection_name=self.storage_config.table_name, - where={ - self.storage_config.conversation_id_column_name: checkpoint.conversation_id, - self.storage_config.turn_id_column_name: 
checkpoint.checkpoint_id, - }, - update=update_checkpoint_values, + connection.execute( + data_table._update_query( + where=update_checkpoint_where, + update=update_checkpoint_values, + ) + ) + + def _save_checkpoint_non_relational(self, checkpoint: ConversationCheckpoint) -> None: + existing_checkpoint = self._find_checkpoint( + conversation_id=checkpoint.conversation_id, + checkpoint_id=checkpoint.checkpoint_id, + ) + if existing_checkpoint is not None: + checkpoint = ConversationCheckpoint( + checkpoint_id=checkpoint.checkpoint_id, + conversation_id=checkpoint.conversation_id, + component_id=checkpoint.component_id, + created_at=checkpoint.created_at, + state=checkpoint.state, + metadata=existing_checkpoint.metadata | checkpoint.metadata, + ) + + update_latest_where = { + self.storage_config.conversation_id_column_name: checkpoint.conversation_id, + self.storage_config.is_last_turn_column_name: 1, + } + update_latest_values = {self.storage_config.is_last_turn_column_name: 0} + entity = self._checkpoint_to_entity(checkpoint) + + self.datastore.update( + collection_name=self.storage_config.table_name, + where=update_latest_where, + update=update_latest_values, + ) + if existing_checkpoint is None: + self.datastore.create( + collection_name=self.storage_config.table_name, + entities=[entity], + ) + else: + update_checkpoint_values = { + self.storage_config.agent_id_column_name: checkpoint.component_id, + self.storage_config.created_at_column_name: checkpoint.created_at, + self.storage_config.conversation_turn_state_column_name: checkpoint.state, + self.storage_config.is_last_turn_column_name: 1, + self.storage_config.extra_metadata_column_name: json.dumps(checkpoint.metadata), + } + if self.storage_config.max_retention is not None: + update_checkpoint_values[self.storage_config.remove_by_column_name] = ( + checkpoint.created_at + self.storage_config.max_retention ) + self.datastore.update( + collection_name=self.storage_config.table_name, + where={ + 
self.storage_config.conversation_id_column_name: checkpoint.conversation_id, + self.storage_config.turn_id_column_name: checkpoint.checkpoint_id, + }, + update=update_checkpoint_values, + ) def list_checkpoints( self, conversation_id: str, limit: Optional[int] = 50 diff --git a/wayflowcore/src/wayflowcore/checkpointing/serialization.py b/wayflowcore/src/wayflowcore/checkpointing/serialization.py index d757a3bc0..133bc5250 100644 --- a/wayflowcore/src/wayflowcore/checkpointing/serialization.py +++ b/wayflowcore/src/wayflowcore/checkpointing/serialization.py @@ -4,7 +4,7 @@ # (LICENSE-APACHE or http://www.apache.org/licenses/LICENSE-2.0) or Universal Permissive License # (UPL) 1.0 (LICENSE-UPL or https://oss.oracle.com/licenses/upl), at your option. -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, cast +from typing import TYPE_CHECKING, Any, Dict, Optional, Sequence, cast import yaml @@ -12,6 +12,15 @@ from wayflowcore.serialization.context import DeserializationContext, SerializationContext from wayflowcore.serialization.serializer import autodeserialize_from_dict +from ._componentidentity import ( + CHECKPOINT_COMPONENT_REFERENCES_KEY, + build_checkpoint_component_references, + iter_checkpoint_component_tree, + iter_checkpoint_conversation_graph, + normalize_restored_component_keyed_state, + register_checkpoint_component_references, +) + if TYPE_CHECKING: from wayflowcore.component import Component from wayflowcore.conversation import Conversation @@ -22,19 +31,7 @@ def _iter_conversation_graph(root_conversation: "Conversation") -> Sequence["Conversation"]: - visited_conversation_ids: set[str] = set() - queue: List["Conversation"] = [root_conversation] - ordered_conversations: List["Conversation"] = [] - - while queue: - conversation = queue.pop() - if conversation.id in visited_conversation_ids: - continue - visited_conversation_ids.add(conversation.id) - ordered_conversations.append(conversation) - 
queue.extend(conversation._get_all_sub_conversations()) - - return ordered_conversations + return iter_checkpoint_conversation_graph(root_conversation) def _ensure_checkpointing_supported(conversation: "Conversation") -> None: @@ -48,44 +45,7 @@ def _ensure_checkpointing_supported(conversation: "Conversation") -> None: def _iter_component_tree(component: "Component") -> Sequence["Component"]: - from wayflowcore.component import Component - - def _iter_nested_components(value: Any) -> List["Component"]: - if isinstance(value, Component): - return [value] - if isinstance(value, dict): - nested_components: List["Component"] = [] - for nested_value in value.values(): - nested_components.extend(_iter_nested_components(nested_value)) - return nested_components - if isinstance(value, (list, tuple, set)): - nested_components = [] - for nested_value in value: - nested_components.extend(_iter_nested_components(nested_value)) - return nested_components - return [] - - visited_component_ids: set[str] = set() - ordered_components: List["Component"] = [] - queue: List["Component"] = [component] - - while queue: - current_component = queue.pop() - current_component_ref = SerializationContext.get_reference(current_component) - if current_component_ref in visited_component_ids: - continue - visited_component_ids.add(current_component_ref) - ordered_components.append(current_component) - - all_public_attrs = { - name: value - for name, value in vars(current_component).items() - if not name.startswith("_") - } - for attr in all_public_attrs.values(): - queue.extend(_iter_nested_components(attr)) - - return ordered_components + return iter_checkpoint_component_tree(component) def _build_checkpoint_serialization_context(conversation: "Conversation") -> SerializationContext: @@ -95,7 +55,11 @@ def _build_checkpoint_serialization_context(conversation: "Conversation") -> Ser return serialization_context -def _serialize_conversation_checkpoint_state(conversation: "Conversation") -> str: 
+def _serialize_conversation_checkpoint_state( + conversation: "Conversation", + *, + root_component_id: Optional[str] = None, +) -> str: _ensure_checkpointing_supported(conversation) serialized_conversation = serialize_to_dict( @@ -107,6 +71,10 @@ def _serialize_conversation_checkpoint_state(conversation: "Conversation") -> st "checkpoint_format": _CHECKPOINT_ENVELOPE_FORMAT, "version": _CHECKPOINT_ENVELOPE_VERSION, "conversation": serialized_conversation, + CHECKPOINT_COMPONENT_REFERENCES_KEY: build_checkpoint_component_references( + conversation.component, + root_component_id=root_component_id, + ), } return yaml.safe_dump(envelope) @@ -116,14 +84,23 @@ def _deserialize_conversation_checkpoint_state( *, tool_registry: Optional[Dict[str, Any]] = None, component: Optional["Component"] = None, + root_component_id_aliases: Optional[Sequence[str]] = None, ) -> "Conversation": deserialization_context = DeserializationContext() deserialization_context.registered_tools = tool_registry.copy() if tool_registry else {} - if component is not None: - deserialization_context._add_component_to_context(component) - state_payload = yaml.safe_load(serialized_state) + component_references = ( + state_payload.get(CHECKPOINT_COMPONENT_REFERENCES_KEY) + if isinstance(state_payload, dict) + else None + ) + register_checkpoint_component_references( + deserialization_context=deserialization_context, + root_component=component, + component_references=component_references, + root_component_id_aliases=root_component_id_aliases, + ) if ( isinstance(state_payload, dict) and state_payload.get("checkpoint_format") == _CHECKPOINT_ENVELOPE_FORMAT @@ -140,4 +117,6 @@ def _deserialize_conversation_checkpoint_state( deserialization_context=deserialization_context, ) - return cast("Conversation", conversation) + restored_conversation = cast("Conversation", conversation) + normalize_restored_component_keyed_state(restored_conversation) + return restored_conversation diff --git 
a/wayflowcore/src/wayflowcore/conversation.py b/wayflowcore/src/wayflowcore/conversation.py index c2c2dadd2..8a58fa9f8 100644 --- a/wayflowcore/src/wayflowcore/conversation.py +++ b/wayflowcore/src/wayflowcore/conversation.py @@ -131,6 +131,7 @@ def execute( *, _final_checkpoint_id: Optional[str] = None, _final_checkpoint_metadata: Optional[Dict[str, Any]] = None, + _save_final_checkpoint: bool = True, ) -> "ExecutionStatus": """ Execute the conversation and get its ``ExecutionStatus`` based on the outcome. @@ -144,6 +145,7 @@ async def _execute_async_wrapper() -> "ExecutionStatus": execution_interrupts, _final_checkpoint_id=_final_checkpoint_id, _final_checkpoint_metadata=_final_checkpoint_metadata, + _save_final_checkpoint=_save_final_checkpoint, ) return run_async_in_sync(_execute_async_wrapper, method_name="execute_async") @@ -154,6 +156,7 @@ async def execute_async( *, _final_checkpoint_id: Optional[str] = None, _final_checkpoint_metadata: Optional[Dict[str, Any]] = None, + _save_final_checkpoint: bool = True, ) -> "ExecutionStatus": """ Execute the conversation and get its ``ExecutionStatus`` based on the outcome. 
@@ -173,6 +176,7 @@ async def execute_async( is_outermost_execution=is_outermost_execution(), final_checkpoint_id=_final_checkpoint_id, final_checkpoint_metadata=_final_checkpoint_metadata, + save_final_checkpoint=_save_final_checkpoint, ): with _register_conversation(self): new_status = await self.component.runner.execute_async(self, execution_interrupts) diff --git a/wayflowcore/src/wayflowcore/conversationalcomponent.py b/wayflowcore/src/wayflowcore/conversationalcomponent.py index ee4ae5105..79cd35288 100644 --- a/wayflowcore/src/wayflowcore/conversationalcomponent.py +++ b/wayflowcore/src/wayflowcore/conversationalcomponent.py @@ -14,6 +14,7 @@ Generic, List, Optional, + Sequence, Set, Type, TypeVar, @@ -23,6 +24,7 @@ from wayflowcore._metadata import MetadataType from wayflowcore.componentwithio import ComponentWithInputsOutputs from wayflowcore.idgeneration import IdGenerator +from wayflowcore.messagelist import Message from wayflowcore.property import Property logger = logging.getLogger(__name__) @@ -33,7 +35,7 @@ from wayflowcore.checkpointing.checkpointer import ConversationCheckpoint from wayflowcore.conversation import Conversation from wayflowcore.executors._executor import ConversationExecutor - from wayflowcore.messagelist import Message, MessageList + from wayflowcore.messagelist import MessageList from wayflowcore.models.llmmodel import LlmModel from wayflowcore.tools import Tool @@ -238,21 +240,32 @@ def _restore_checkpointed_conversation( checkpointer: "Checkpointer", expected_conversation_type: Type[ConversationTypeT], attach_checkpointer: bool, + accepted_checkpoint_component_ids: Optional[Sequence[str]] = None, ) -> ConversationTypeT: from wayflowcore.checkpointing.serialization import ( _deserialize_conversation_checkpoint_state, ) - if checkpoint.component_id != self.id: + accepted_checkpoint_component_ids = accepted_checkpoint_component_ids or [] + if ( + checkpoint.component_id != self.id + and checkpoint.component_id not in 
accepted_checkpoint_component_ids + ): raise ValueError( "Cannot restore this checkpoint because this conversation was started with another " - f"component. Checkpoint component id: `{checkpoint.component_id}`. Current component id: `{self.id}`." + f"component. Checkpoint component id: `{checkpoint.component_id}`. Current component id: `{self.id}`. " + "For persistent restore across process restarts, construct the component with a stable `id` " + "or component-specific id alias, such as `agent_id` for `Agent`." ) + root_component_id_aliases = ( + [checkpoint.component_id] if checkpoint.component_id != self.id else [] + ) conversation = _deserialize_conversation_checkpoint_state( checkpoint.state, tool_registry={tool.name: tool for tool in self._referenced_tools()}, component=self, + root_component_id_aliases=root_component_id_aliases, ) if not isinstance(conversation, expected_conversation_type): raise ValueError( diff --git a/wayflowcore/src/wayflowcore/datastore/_relational.py b/wayflowcore/src/wayflowcore/datastore/_relational.py index a973127cd..dcf1a49d6 100644 --- a/wayflowcore/src/wayflowcore/datastore/_relational.py +++ b/wayflowcore/src/wayflowcore/datastore/_relational.py @@ -405,6 +405,16 @@ def __init__( description=description, ) + def close(self) -> None: + """Close pooled database connections held by this datastore.""" + self.engine.dispose() + + def __del__(self) -> None: + try: + self.close() + except Exception: + return + def _create_data_tables_from_entities(self) -> Dict[str, _RelationalDatatable]: metadata = sqlalchemy.MetaData() diff --git a/wayflowcore/src/wayflowcore/datastore/oracle.py b/wayflowcore/src/wayflowcore/datastore/oracle.py index 91f1616ed..618e26025 100644 --- a/wayflowcore/src/wayflowcore/datastore/oracle.py +++ b/wayflowcore/src/wayflowcore/datastore/oracle.py @@ -205,18 +205,22 @@ def __init__( ) self.engine = engine - super().__init__( - schema=schema, - engine=engine, - search_configs=search_configs, - 
vector_configs=vector_configs, - id=id, - name=IdGenerator.get_or_generate_name(name, prefix="oracle_datastore", length=8), - description=description, - __metadata_info__=__metadata_info__, - ) - - SerializableObject.__init__(self, None) + try: + super().__init__( + schema=schema, + engine=engine, + search_configs=search_configs, + vector_configs=vector_configs, + id=id, + name=IdGenerator.get_or_generate_name(name, prefix="oracle_datastore", length=8), + description=description, + __metadata_info__=__metadata_info__, + ) + + SerializableObject.__init__(self, None) + except Exception: + engine.dispose() + raise def _serialize_to_dict(self, serialization_context: SerializationContext) -> Dict[str, Any]: result: Dict[str, Any] = { @@ -344,4 +348,8 @@ def _execute_query_on_oracle_db( connection_config: OracleDatabaseConnectionConfig, query: str ) -> None: with connection_config.get_connection() as conn: - conn.cursor().execute(query) + cursor = conn.cursor() + try: + cursor.execute(query) + finally: + cursor.close() diff --git a/wayflowcore/src/wayflowcore/datastore/postgres.py b/wayflowcore/src/wayflowcore/datastore/postgres.py index 899985631..2a318e6b0 100644 --- a/wayflowcore/src/wayflowcore/datastore/postgres.py +++ b/wayflowcore/src/wayflowcore/datastore/postgres.py @@ -153,15 +153,19 @@ def __init__( """ self.connection_config = connection_config engine = connection_config.get_connection() - super().__init__( - schema, - engine, - name=name, - description=description, - id=id, - __metadata_info__=__metadata_info__, - ) - SerializableObject.__init__(self) + try: + super().__init__( + schema, + engine, + name=name, + description=description, + id=id, + __metadata_info__=__metadata_info__, + ) + SerializableObject.__init__(self) + except Exception: + engine.dispose() + raise def _serialize_to_dict(self, serialization_context: SerializationContext) -> Dict[str, Any]: return { diff --git a/wayflowcore/src/wayflowcore/executors/_agentexecutor.py 
b/wayflowcore/src/wayflowcore/executors/_agentexecutor.py index da62f6a8f..5bfbe3f85 100644 --- a/wayflowcore/src/wayflowcore/executors/_agentexecutor.py +++ b/wayflowcore/src/wayflowcore/executors/_agentexecutor.py @@ -648,6 +648,8 @@ def _handle_unknown_call( @staticmethod def _get_tool_response_message(content: Any, tool_request_id: str, agent_id: str) -> Message: + if isinstance(content, Exception): + content = str(content) return Message( tool_result=ToolResult( content=content, diff --git a/wayflowcore/src/wayflowcore/executors/executionstatus.py b/wayflowcore/src/wayflowcore/executors/executionstatus.py index a314e1f19..1b76e0afa 100644 --- a/wayflowcore/src/wayflowcore/executors/executionstatus.py +++ b/wayflowcore/src/wayflowcore/executors/executionstatus.py @@ -220,11 +220,41 @@ def _requires_yielding(self) -> bool: return True def _serialize_to_dict(self, serialization_context: "SerializationContext") -> Dict[str, Any]: + from wayflowcore.serialization.serializer import serialize_any_to_dict_or_stringify + return { - "tool_requests": [asdict(tool) for tool in self.tool_requests], + "tool_requests": [ + { + "name": tool.name, + "args": serialize_any_to_dict_or_stringify( + tool.args, + serialization_context, + ), + "tool_request_id": tool.tool_request_id, + "_extra_content": serialize_any_to_dict_or_stringify( + tool._extra_content, + serialization_context, + ), + "_requires_confirmation": tool._requires_confirmation, + "_tool_execution_confirmed": tool._tool_execution_confirmed, + "_tool_rejection_reason": tool._tool_rejection_reason, + } + for tool in self.tool_requests + ], "_conversation_id": self._conversation_id, "_tool_results": ( - [asdict(t) for t in self._tool_results] if self._tool_results is not None else None + [ + { + "content": serialize_any_to_dict_or_stringify( + tool_result.content, + serialization_context, + ), + "tool_request_id": tool_result.tool_request_id, + } + for tool_result in self._tool_results + ] + if self._tool_results is 
not None + else None ), "id": self.id, } diff --git a/wayflowcore/src/wayflowcore/flow.py b/wayflowcore/src/wayflowcore/flow.py index 9710ca7ce..921b775ea 100644 --- a/wayflowcore/src/wayflowcore/flow.py +++ b/wayflowcore/src/wayflowcore/flow.py @@ -1198,7 +1198,7 @@ def start_conversation( Returns ------- - FlowConversation + Conversation A new or restored flow conversation. """ from wayflowcore.events.event import ConversationCreatedEvent diff --git a/wayflowcore/src/wayflowcore/managerworkers.py b/wayflowcore/src/wayflowcore/managerworkers.py index a64c47fd5..8f82888a5 100644 --- a/wayflowcore/src/wayflowcore/managerworkers.py +++ b/wayflowcore/src/wayflowcore/managerworkers.py @@ -251,7 +251,7 @@ def start_conversation( Returns ------- - ManagerWorkersConversation + Conversation A new or restored manager-workers conversation. """ from wayflowcore.agentconversation import AgentConversation diff --git a/wayflowcore/src/wayflowcore/messagelist.py b/wayflowcore/src/wayflowcore/messagelist.py index e13ad91f1..2e664f51c 100644 --- a/wayflowcore/src/wayflowcore/messagelist.py +++ b/wayflowcore/src/wayflowcore/messagelist.py @@ -438,7 +438,10 @@ def content(self) -> str: return "" def _serialize_to_dict(self, serialization_context: "SerializationContext") -> Dict[str, Any]: - from wayflowcore.serialization.serializer import serialize_to_dict + from wayflowcore.serialization.serializer import ( + serialize_any_to_dict_or_stringify, + serialize_to_dict, + ) # NOTE: Manual deserialization is required because of the tool request and tool result objects @@ -456,7 +459,11 @@ def _serialize_to_dict(self, serialization_context: "SerializationContext") -> D "__metadata_info__": self.__metadata_info__, "tool_requests": ( [ - {"name": t.name, "args": t.args, "tool_request_id": t.tool_request_id} + { + "name": t.name, + "args": serialize_any_to_dict_or_stringify(t.args, serialization_context), + "tool_request_id": t.tool_request_id, + } for t in self.tool_requests ] if 
self.tool_requests is not None @@ -465,12 +472,18 @@ def _serialize_to_dict(self, serialization_context: "SerializationContext") -> D "tool_result": ( { "tool_request_id": self.tool_result.tool_request_id, - "content": self.tool_result.content, + "content": serialize_any_to_dict_or_stringify( + self.tool_result.content, + serialization_context, + ), } if self.tool_result is not None else None ), - "_extra_content": self._extra_content, + "_extra_content": serialize_any_to_dict_or_stringify( + self._extra_content, + serialization_context, + ), } @property diff --git a/wayflowcore/src/wayflowcore/steps/datastoresteps/datastorequerystep.py b/wayflowcore/src/wayflowcore/steps/datastoresteps/datastorequerystep.py index 974a048e4..bf19e5400 100644 --- a/wayflowcore/src/wayflowcore/steps/datastoresteps/datastorequerystep.py +++ b/wayflowcore/src/wayflowcore/steps/datastoresteps/datastorequerystep.py @@ -177,6 +177,8 @@ def __init__( ... TypeError: The input passed: `{'salary': '1', 'depname': 'sales'}` of type `dict` is not of the expected type ... 
+ >>> datastore.close() + """ super().__init__( step_static_configuration=dict(datastore=datastore, query=query), diff --git a/wayflowcore/src/wayflowcore/tools/servertools.py b/wayflowcore/src/wayflowcore/tools/servertools.py index 9c85792e3..b6f1a91dd 100644 --- a/wayflowcore/src/wayflowcore/tools/servertools.py +++ b/wayflowcore/src/wayflowcore/tools/servertools.py @@ -339,8 +339,9 @@ async def _run( raise_exceptions=raise_exceptions, ) - tool_result = ToolResult( - content=output, tool_request_id=tool_request.tool_request_id + message_output = str(output) if isinstance(output, Exception) else output + message_tool_result = ToolResult( + content=message_output, tool_request_id=tool_request.tool_request_id ) sender = None recipients = None @@ -350,7 +351,7 @@ async def _run( conversation.state.current_tool_request = None conversation.message_list.append_message( Message( - tool_result=tool_result, + tool_result=message_tool_result, message_type=MessageType.TOOL_RESULT, sender=sender, recipients=recipients, diff --git a/wayflowcore/tests/agentserver/conftest.py b/wayflowcore/tests/agentserver/conftest.py index 336a54379..756f4bd55 100644 --- a/wayflowcore/tests/agentserver/conftest.py +++ b/wayflowcore/tests/agentserver/conftest.py @@ -9,7 +9,7 @@ import time from dataclasses import asdict from pathlib import Path -from typing import Any, Callable, List, Optional, Tuple, Union +from typing import Any, Callable, Dict, List, Optional, Tuple, Union import httpx import pytest @@ -30,6 +30,15 @@ from ..utils import LogTee, _terminate_process_tree, get_available_port from .datastore_agent_server import ORACLE_DB_CREATE_DDL, ORACLE_DB_DELETE_DDL +_SERVER_LOG_TEES: Dict[str, LogTee] = {} + + +def get_recent_server_logs(base_url: str) -> str: + tee = _SERVER_LOG_TEES.get(base_url.rstrip("/")) + if tee is None: + return "" + return tee.dump() + def _wait_for_http_ready(url: str, timeout: float) -> None: deadline = time.time() + timeout @@ -106,6 +115,7 @@ def _run_server( 
raise RuntimeError("Failed to capture server stdout") tee = LogTee(process.stdout, prefix="[uvicorn] ") tee.start() + _SERVER_LOG_TEES[url] = tee # Poll for readiness or early exit start = time.time() diff --git a/wayflowcore/tests/agentserver/test_wayflow_server.py b/wayflowcore/tests/agentserver/test_wayflow_server.py index 1a2c0f9fd..d8df45fad 100644 --- a/wayflowcore/tests/agentserver/test_wayflow_server.py +++ b/wayflowcore/tests/agentserver/test_wayflow_server.py @@ -6,29 +6,97 @@ import base64 import json from pathlib import Path -from typing import Any, Dict, Union +from typing import Any, Dict, Optional, Union, cast import httpx import pytest - +import yaml +from fastapi.testclient import TestClient + +from wayflowcore._utils.async_helpers import run_async_in_sync +from wayflowcore.agentserver.openairesponses.models.openairesponsespydanticmodels import ( + Conversation2, + CreateResponse, + Response, +) +from wayflowcore.agentserver.openairesponses.services.wayflowservice import ( + WayFlowOpenAIResponsesService, +) +from wayflowcore.agentserver.server import OpenAIResponsesServer +from wayflowcore.flowhelpers import create_single_step_flow +from wayflowcore.idgeneration import IdGenerator from wayflowcore.messagelist import Message, MessageType from wayflowcore.models import OpenAIAPIType, OpenAICompatibleModel, StreamChunkType from wayflowcore.models.llmmodel import LlmCompletion, Prompt +from wayflowcore.serialization import serialize +from wayflowcore.steps import OutputMessageStep from wayflowcore.tools import ToolResult from ..testhelpers.testhelpers import retry_test -from .conftest import _get_api_key_headers, get_all_server_fixtures_name +from .conftest import _get_api_key_headers, get_all_server_fixtures_name, get_recent_server_logs all_available_servers = pytest.mark.parametrize( "server_fixture_name", get_all_server_fixtures_name() ) +_MAX_FAILURE_CONTEXT_CHARS = 12000 + + +def _truncate_for_failure_log(value: str, max_chars: int = 
_MAX_FAILURE_CONTEXT_CHARS) -> str: + if len(value) <= max_chars: + return value + return f"{value[:max_chars]}\n... truncated {len(value) - max_chars} chars ..." + + +def _format_json_for_failure_log(value: Any) -> str: + try: + return json.dumps(value, indent=2, sort_keys=True, default=str) + except TypeError: + return repr(value) + + +def _format_response_body_for_failure_log(response: httpx.Response) -> str: + try: + return _format_json_for_failure_log(response.json()) + except ValueError: + return response.text + + +def _raise_for_status_with_context( + response: httpx.Response, + *, + base_url: str, + payload: Dict[str, Any], +) -> None: + try: + response.raise_for_status() + except httpx.HTTPStatusError as exc: + recent_logs = get_recent_server_logs(base_url) + message_parts = [ + str(exc), + "", + "Request payload:", + _truncate_for_failure_log(_format_json_for_failure_log(payload)), + "", + "Response body:", + _truncate_for_failure_log(_format_response_body_for_failure_log(response)), + ] + if recent_logs: + message_parts.extend( + [ + "", + "Recent server logs:", + _truncate_for_failure_log(recent_logs), + ] + ) + raise AssertionError("\n".join(message_parts)) from exc + def _create_response( base_url: str, input_value: Any, model: str = "hr-assistant", - headers: Dict[str, str] = None, + headers: Optional[Dict[str, str]] = None, **payload_fields: Any, ) -> Dict[str, Any]: payload: Dict[str, Any] = { @@ -37,10 +105,226 @@ def _create_response( **payload_fields, } response = httpx.post(f"{base_url}/v1/responses", json=payload, timeout=120.0, headers=headers) - response.raise_for_status() + _raise_for_status_with_context(response, base_url=base_url, payload=payload) return response.json() +def _seed_legacy_openai_response_row( + service: WayFlowOpenAIResponsesService, + *, + model_id: str, + conversation_id: str, + response_id: str, +) -> None: + """Seed storage with the row shape written by the pre-checkpointing server.""" + conversation = 
service.agents[model_id].start_conversation(conversation_id=conversation_id) + conversation.execute() + + serialized_conversation = yaml.safe_load(serialize(conversation)) + serialized_conversation.pop("root_conversation_id", None) + serialized_conversation["conversation_id"] = "legacy-deprecated-conversation-id" + + response = Response( + id=response_id, + created_at=1, + error=None, + incomplete_details=None, + instructions=None, + model=model_id, + object="response", + output=[], + parallel_tool_calls=True, + conversation=Conversation2(id=conversation_id), + status="completed", + ) + + service.storage.create( + collection_name=service.storage_config.table_name, + entities=[ + { + service.storage_config.agent_id_column_name: model_id, + service.storage_config.conversation_id_column_name: conversation_id, + service.storage_config.turn_id_column_name: response_id, + service.storage_config.created_at_column_name: 1, + service.storage_config.conversation_turn_state_column_name: yaml.safe_dump( + serialized_conversation + ), + service.storage_config.is_last_turn_column_name: 1, + service.storage_config.extra_metadata_column_name: json.dumps( + {"response": response.model_dump_json()} + ), + } + ], + ) + + +def _create_response_with_service( + service: WayFlowOpenAIResponsesService, + body: CreateResponse, +) -> Response: + async def collect_response() -> Response: + response = None + async for event in service.create_response(body): + if hasattr(event, "response"): + response = event.response + assert response is not None + return response + + return run_async_in_sync(collect_response, method_name="create_response") + + +@pytest.mark.filterwarnings("ignore:InMemoryDatastore is for DEVELOPMENT:UserWarning") +def test_openai_responses_restores_checkpoint_after_component_restart_with_same_model() -> None: + model_id = "restart-stable-model" + original_flow = create_single_step_flow(OutputMessageStep(message_template="original response")) + original_service = 
WayFlowOpenAIResponsesService(agents={model_id: original_flow}) + response = _create_response_with_service( + original_service, + CreateResponse(model=model_id, input="start"), + ) + checkpoint = original_service._lookup_checkpoint_by_response_id(response.id) + assert checkpoint is not None + assert checkpoint.component_id == model_id + + restarted_flow = create_single_step_flow( + OutputMessageStep(message_template="restarted response") + ) + restarted_service = WayFlowOpenAIResponsesService( + agents={model_id: restarted_flow}, + storage=original_service.storage, + storage_config=original_service.storage_config, + ) + + restored = restarted_service._load_state( + previous_response_id=response.id, + conversation_id=None, + agent_id=model_id, + ) + + assert restored is not None + assert restored.component is restarted_flow + assert restored.checkpoint_id == response.id + + +@pytest.mark.parametrize( + "load_kwargs", + [ + {"previous_response_id": "legacy-response-id", "conversation_id": None}, + {"previous_response_id": None, "conversation_id": "legacy-conversation-id"}, + ], + ids=["previous_response_id", "conversation_id"], +) +@pytest.mark.filterwarnings("ignore:InMemoryDatastore is for DEVELOPMENT:UserWarning") +def test_openai_responses_restores_legacy_rows_with_model_alias_as_agent_id( + load_kwargs: Dict[str, Optional[str]], +) -> None: + model_id = "legacy-flow-model" + conversation_id = "legacy-conversation-id" + response_id = "legacy-response-id" + flow = create_single_step_flow(OutputMessageStep(message_template="legacy response")) + service = WayFlowOpenAIResponsesService(agents={model_id: flow}) + _seed_legacy_openai_response_row( + service, + model_id=model_id, + conversation_id=conversation_id, + response_id=response_id, + ) + + restored = service._load_state( + agent_id=model_id, + attach_checkpointer=True, + **load_kwargs, + ) + + assert restored is not None + assert restored.id == conversation_id + assert restored.root_conversation_id == 
"legacy-deprecated-conversation-id" + assert restored.component is flow + assert restored.checkpointer is service.checkpointer + + +@pytest.mark.filterwarnings("ignore:InMemoryDatastore is for DEVELOPMENT:UserWarning") +def test_openai_responses_does_not_leave_response_checkpoint_when_final_save_fails( + monkeypatch: pytest.MonkeyPatch, +) -> None: + model_id = "partial-response-flow" + response_id = "partial-response-id" + flow = create_single_step_flow(OutputMessageStep(message_template="partial response")) + server = OpenAIResponsesServer(agents={model_id: flow}) + service = cast(WayFlowOpenAIResponsesService, server.agent_service) + + original_get_or_generate_id = IdGenerator.get_or_generate_id + deterministic_ids = [response_id] + + def deterministic_response_id(id: Optional[str] = None) -> str: + if id is not None: + return id + if deterministic_ids: + return deterministic_ids.pop(0) + return original_get_or_generate_id() + + monkeypatch.setattr( + IdGenerator, + "get_or_generate_id", + staticmethod(deterministic_response_id), + ) + + original_save_conversation = service.checkpointer.save_conversation + response_metadata_save_attempts = 0 + failed_response_conversation_id: Optional[str] = None + + def fail_response_metadata_save(*args: Any, **kwargs: Any) -> Any: + nonlocal failed_response_conversation_id, response_metadata_save_attempts + metadata = kwargs.get("metadata") + if isinstance(metadata, dict) and "response" in metadata: + response_metadata_save_attempts += 1 + failed_response_conversation_id = args[0].id + if response_metadata_save_attempts == 1: + raise RuntimeError("response metadata save failed") + return original_save_conversation(*args, **kwargs) + + monkeypatch.setattr( + service.checkpointer, + "save_conversation", + fail_response_metadata_save, + ) + + with TestClient(server.get_app(), raise_server_exceptions=False) as client: + response = client.post( + "/v1/responses", + json={"model": model_id, "input": "hello", "store": True}, + ) + 
assert response.status_code == 500 + assert response_metadata_save_attempts == 1 + assert failed_response_conversation_id is not None + + missing_response = client.get(f"/v1/responses/{response_id}") + assert missing_response.status_code == 404 + + follow_up_response = client.post( + "/v1/responses", + json={ + "model": model_id, + "input": "continue", + "previous_response_id": response_id, + }, + ) + conversation_follow_up_response = client.post( + "/v1/responses", + json={ + "model": model_id, + "input": "continue", + "conversation": {"id": failed_response_conversation_id}, + }, + ) + + assert follow_up_response.status_code == 404 + assert conversation_follow_up_response.status_code == 404 + assert failed_response_conversation_id is not None + assert service.checkpointer.list_checkpoints(failed_response_conversation_id, limit=None) == [] + assert service._lookup_checkpoint_by_response_id(response_id) is None + + @pytest.fixture def official_openai_client(server_url): try: @@ -274,6 +558,52 @@ def test_create_response_unknown_response(server_url) -> None: assert "previous response" in detail.lower() +@all_available_servers +def test_create_response_previous_response_id_rejects_different_model(server_url) -> None: + created = _create_response( + base_url=server_url, + model="simple-flow", + input_value="whatever", + ) + + resp = httpx.post( + f"{server_url}/v1/responses", + json={ + "model": "hr-assistant", + "input": "Continue the conversation.", + "previous_response_id": created["id"], + }, + timeout=30.0, + ) + + assert resp.status_code == 400 + detail = resp.json().get("detail") + assert "previous response was created with model `simple-flow`" in detail + + +@all_available_servers +def test_create_response_conversation_rejects_different_model(server_url) -> None: + created = _create_response( + base_url=server_url, + model="simple-flow", + input_value="whatever", + ) + + resp = httpx.post( + f"{server_url}/v1/responses", + json={ + "model": "hr-assistant", + 
"input": "Continue the conversation.", + "conversation": created["conversation"]["id"], + }, + timeout=30.0, + ) + + assert resp.status_code == 400 + detail = resp.json().get("detail") + assert "latest response in that conversation was created with model `simple-flow`" in detail + + @all_available_servers def test_create_response_with_instructions_when_agent_does_not_support_it(server_url) -> None: @@ -725,7 +1055,7 @@ def test_agent_with_datastore_is_supported(datastore_agent_inmemory_server): follow_up = _create_response( base_url=datastore_agent_inmemory_server, input_value="and the biggest city?", - model="datastore-assistant", + model="datastore-swarm", previous_response_id=response["id"], ) output = follow_up["output"][0]["content"][0]["text"] diff --git a/wayflowcore/tests/datastores/conftest.py b/wayflowcore/tests/datastores/conftest.py index 755534788..09db5065d 100644 --- a/wayflowcore/tests/datastores/conftest.py +++ b/wayflowcore/tests/datastores/conftest.py @@ -174,10 +174,7 @@ def get_tls_postgres_connection_config(): def get_oracle_datastore_with_schema(ddl: List[str], entities: Dict[str, Entity]): connection_config = get_oracle_connection_config() - conn = connection_config.get_connection() - for stmt in ddl: - conn.cursor().execute(stmt) - conn.close() + _execute_oracle_ddl(ddl) return OracleDatabaseDatastore(entities, connection_config=connection_config) @@ -207,11 +204,21 @@ def get_tls_postgres_datastore_with_schema(ddl: List[str], entities: Dict[str, E def cleanup_oracle_datastore(ddl: Optional[List[str]] = None): stmts = ddl if ddl is not None else ORACLE_DB_DDL[:2] + _execute_oracle_ddl(stmts) + + +def _execute_oracle_ddl(ddl: List[str]) -> None: connection_config = get_oracle_connection_config() conn = connection_config.get_connection() - for stmt in stmts: - conn.cursor().execute(stmt) - conn.close() + try: + cursor = conn.cursor() + try: + for stmt in ddl: + cursor.execute(stmt) + finally: + cursor.close() + finally: + conn.close() def 
cleanup_postgres_datastore(ddl: Optional[List[str]] = None): @@ -233,20 +240,32 @@ def cleanup_datastore_content(datastore: Datastore): @pytest.fixture(scope="session") def oracle_datastore(): - yield get_oracle_datastore_with_schema(ORACLE_DB_DDL, get_basic_office_entities()) - cleanup_oracle_datastore() + datastore = get_oracle_datastore_with_schema(ORACLE_DB_DDL, get_basic_office_entities()) + try: + yield datastore + finally: + datastore.close() + cleanup_oracle_datastore() @pytest.fixture(scope="session") def postgres_datastore(): - yield get_postgres_datastore_with_schema(POSTGRES_DDL, get_basic_office_entities()) - cleanup_postgres_datastore() + datastore = get_postgres_datastore_with_schema(POSTGRES_DDL, get_basic_office_entities()) + try: + yield datastore + finally: + datastore.close() + cleanup_postgres_datastore() @pytest.fixture(scope="session") def tls_postgres_datastore(): - yield get_postgres_datastore_with_schema(POSTGRES_DDL, get_basic_office_entities()) - cleanup_postgres_datastore() + datastore = get_postgres_datastore_with_schema(POSTGRES_DDL, get_basic_office_entities()) + try: + yield datastore + finally: + datastore.close() + cleanup_postgres_datastore() @pytest.fixture(scope="function") diff --git a/wayflowcore/tests/datastores/test_datastore.py b/wayflowcore/tests/datastores/test_datastore.py index c0c1a1494..4325d66e6 100644 --- a/wayflowcore/tests/datastores/test_datastore.py +++ b/wayflowcore/tests/datastores/test_datastore.py @@ -404,10 +404,13 @@ def test_oracle_column_type_mapping(): "category": nullable(StringProperty()), }, ) + datastore = None try: datastore = get_oracle_datastore_with_schema(ddl, {"products": product}) assert datastore.list("products") == [] finally: + if datastore is not None: + datastore.close() cleanup_oracle_datastore(ddl=["DROP TABLE IF EXISTS PRODUCTS cascade constraints"]) @@ -453,6 +456,7 @@ def test_oracle_json_columns(caplog): "details": StringProperty(description="A JSON with all the product details"), 
}, ) + datastore = None try: with pytest.raises( DatastoreTypeError, @@ -474,12 +478,14 @@ def test_oracle_json_columns(caplog): ) caplog.clear() # If the datastore schema doesn't reference any JSON column, then this should work without warnings - _ = get_oracle_datastore_with_schema(ddl, {"products": product_slim}) + datastore = get_oracle_datastore_with_schema(ddl, {"products": product_slim}) assert ( "Suppressed warning during database inspection: Did not recognize type 'JSON' of column 'details'" in caplog.text ) finally: + if datastore is not None: + datastore.close() cleanup_oracle_datastore(ddl=["DROP TABLE IF EXISTS PRODUCTS cascade constraints"]) @@ -509,8 +515,12 @@ def test_schema_inspection_on_dropped_tables(): connection_config = get_oracle_connection_config() with connection_config.get_connection() as conn: - for stmt in ddl: - conn.cursor().execute(stmt) + cursor = conn.cursor() + try: + for stmt in ddl: + cursor.execute(stmt) + finally: + cursor.close() product = Entity( properties={ @@ -556,6 +566,8 @@ def drop_table(): assert datastore is not None assert datastore.list("products") == [] finally: + if datastore is not None: + datastore.close() cleanup_oracle_datastore( ddl=["DROP TABLE IF EXISTS PRODUCTS cascade constraints"] + drop_ddl ) diff --git a/wayflowcore/tests/search/conftest.py b/wayflowcore/tests/search/conftest.py index c4dc3322e..edd28d417 100644 --- a/wayflowcore/tests/search/conftest.py +++ b/wayflowcore/tests/search/conftest.py @@ -206,10 +206,7 @@ def get_empty_table_entities(): def get_oracle_datastore_with_schema(ddl: List[str], entities: Dict[str, Entity], embedding_model): connection_config = get_oracle_connection_config() - conn = connection_config.get_connection() - for stmt in ddl: - conn.cursor().execute(stmt) - conn.close() + _execute_oracle_ddl(ddl) if "motorcycles" in entities: search_config = SearchConfig( name="search_motorcycles", @@ -237,10 +234,7 @@ def get_oracle_datastore_with_multiple_search_configs_and_schema( 
ddl: List[str], entities: Dict[str, Entity], embedding_model ): connection_config = get_oracle_connection_config() - conn = connection_config.get_connection() - for stmt in ddl: - conn.cursor().execute(stmt) - conn.close() + _execute_oracle_ddl(ddl) return OracleDatabaseDatastore( entities, @@ -253,10 +247,7 @@ def get_oracle_datastore_with_multiple_search_and_vector_configs( ddl: List[str], entities: Dict[str, Entity], embedding_model ): connection_config = get_oracle_connection_config() - conn = connection_config.get_connection() - for stmt in ddl: - conn.cursor().execute(stmt) - conn.close() + _execute_oracle_ddl(ddl) search_configs = get_search_configs(embedding_model) vector_config1 = VectorConfig(model=embedding_model, collection_name="motorcycles") @@ -308,10 +299,7 @@ def create_oracle_datastore_with_vector_config( ddl: List[str], entities: Dict[str, Entity], embedding_model ): connection_config = get_oracle_connection_config() - conn = connection_config.get_connection() - for stmt in ddl: - conn.cursor().execute(stmt) - conn.close() + _execute_oracle_ddl(ddl) vector_config = VectorConfig( name="vector_config1", collection_name="motorcycles", vector_property="embeddings" @@ -336,11 +324,21 @@ def create_oracle_datastore_with_vector_config( def cleanup_oracle_datastore(ddl: Optional[List[str]] = None): stmts = ddl if ddl is not None else ORACLE_DB_DDL[:3] + _execute_oracle_ddl(stmts) + + +def _execute_oracle_ddl(ddl: List[str]) -> None: connection_config = get_oracle_connection_config() conn = connection_config.get_connection() - for stmt in stmts: - conn.cursor().execute(stmt) - conn.close() + try: + cursor = conn.cursor() + try: + for stmt in ddl: + cursor.execute(stmt) + finally: + cursor.close() + finally: + conn.close() def cleanup_datastore_content(oracle_datastore: OracleDatabaseDatastore): @@ -351,28 +349,40 @@ def cleanup_datastore_content(oracle_datastore: OracleDatabaseDatastore): @pytest.fixture(scope="function") def 
oracle_vehicle_datastore(embedding_model): cleanup_oracle_datastore() - yield get_oracle_datastore_with_schema( + datastore = get_oracle_datastore_with_schema( ORACLE_DB_DDL, get_basic_vehicle_entities(), embedding_model ) - cleanup_oracle_datastore() + try: + yield datastore + finally: + datastore.close() + cleanup_oracle_datastore() @pytest.fixture(scope="function") def oracle_empty_table_datastore(embedding_model): cleanup_oracle_datastore() - yield get_oracle_datastore_with_schema( + datastore = get_oracle_datastore_with_schema( ORACLE_DB_DDL, get_empty_table_entities(), embedding_model ) - cleanup_oracle_datastore() + try: + yield datastore + finally: + datastore.close() + cleanup_oracle_datastore() @pytest.fixture(scope="function") def oracle_vehicle_multi_search_config_datastore(embedding_model): cleanup_oracle_datastore() - yield get_oracle_datastore_with_multiple_search_configs_and_schema( + datastore = get_oracle_datastore_with_multiple_search_configs_and_schema( ORACLE_DB_DDL, get_basic_vehicle_entities(), embedding_model ) - cleanup_oracle_datastore() + try: + yield datastore + finally: + datastore.close() + cleanup_oracle_datastore() @pytest.fixture(scope="function") @@ -395,19 +405,27 @@ def oracle_vehicle_multi_search_and_vector_config_datastore(embedding_model): ) new_entities[collection_name] = new_entity - yield get_oracle_datastore_with_multiple_search_and_vector_configs( + datastore = get_oracle_datastore_with_multiple_search_and_vector_configs( ORACLE_DB_DDL_2, new_entities, embedding_model ) - cleanup_oracle_datastore() + try: + yield datastore + finally: + datastore.close() + cleanup_oracle_datastore() @pytest.fixture(scope="function") def oracle_vehicle_vector_config_datastore(embedding_model): cleanup_oracle_datastore() - yield create_oracle_datastore_with_vector_config( + datastore = create_oracle_datastore_with_vector_config( ORACLE_DB_DDL, get_basic_vehicle_entities(), embedding_model ) - cleanup_oracle_datastore() + try: + yield 
datastore + finally: + datastore.close() + cleanup_oracle_datastore() @pytest.fixture(scope="function") diff --git a/wayflowcore/tests/serialization/conftest.py b/wayflowcore/tests/serialization/conftest.py index dc23392cf..38692b2d8 100644 --- a/wayflowcore/tests/serialization/conftest.py +++ b/wayflowcore/tests/serialization/conftest.py @@ -28,6 +28,8 @@ def _start_a2a_server( host, "--port", str(port), + "--agent", + "flow_that_yields_once", ] url = f"http://{host}:{port}" diff --git a/wayflowcore/tests/serialization/test_conversation_checkpointing.py b/wayflowcore/tests/serialization/test_conversation_checkpointing.py index 2d694a346..f9980a257 100644 --- a/wayflowcore/tests/serialization/test_conversation_checkpointing.py +++ b/wayflowcore/tests/serialization/test_conversation_checkpointing.py @@ -11,23 +11,31 @@ import pytest +import wayflowcore.checkpointing.datastorecheckpointer as checkpoint_datastore import wayflowcore.checkpointing.serialization as checkpoint_serialization from wayflowcore.a2a.a2aagent import A2AAgent, A2AConnectionConfig from wayflowcore.agent import Agent from wayflowcore.checkpointing import ( CheckpointingInterval, ConversationCheckpoint, + DatastoreCheckpointer, InMemoryCheckpointer, + StorageConfig, ) from wayflowcore.checkpointing.checkpointeventlistener import _save_conversation_checkpoint +from wayflowcore.datastore._relational import RelationalDatastore +from wayflowcore.exceptions import DatastoreEntityError +from wayflowcore.executors._agentexecutor import AgentConversationExecutor from wayflowcore.executors.executionstatus import FinishedStatus, UserMessageRequestStatus from wayflowcore.flowhelpers import create_single_step_flow from wayflowcore.managerworkers import ManagerWorkers +from wayflowcore.messagelist import Message, MessageType from wayflowcore.models.ociclientconfig import OCIClientConfigWithApiKey from wayflowcore.ociagent import OciAgent from wayflowcore.serialization.serializer import 
_resolve_legacy_field_name from wayflowcore.steps import OutputMessageStep, PromptExecutionStep from wayflowcore.swarm import Swarm +from wayflowcore.tools import ToolResult from ..testhelpers.dummy import DummyModel from ..testhelpers.testhelpers import retry_test @@ -45,6 +53,7 @@ def save_conversation( conversation, *, checkpoint_id: Optional[str] = None, + component_id: Optional[str] = None, metadata: Optional[Dict[str, Any]] = None, ): if self.should_fail_next_save: @@ -79,6 +88,36 @@ def load_latest(self, conversation_id: str) -> Optional[ConversationCheckpoint]: return self.checkpoint +class FakeRelationalDatastore(RelationalDatastore): + def __init__(self) -> None: + pass + + def _serialize_to_dict(self, serialization_context: Any) -> Dict[str, Any]: + raise NotImplementedError() + + @classmethod + def _deserialize_from_dict( + cls, input_dict: Dict[str, Any], deserialization_context: Any + ) -> Any: + raise NotImplementedError() + + +class IntegrityRetryCheckpointer(DatastoreCheckpointer): + def __init__(self) -> None: + super().__init__( + datastore=FakeRelationalDatastore(), + storage_config=StorageConfig(), + ) + self.relational_save_attempts = 0 + + def _save_checkpoint_relational_once(self, checkpoint: ConversationCheckpoint) -> None: + from sqlalchemy.exc import IntegrityError + + self.relational_save_attempts += 1 + if self.relational_save_attempts == 1: + raise IntegrityError("statement", {}, Exception("duplicate latest row")) + + def _build_checkpointable_agent( *, name: str, @@ -149,6 +188,113 @@ def a2a_agent(a2a_server, connection_config_no_verify): ) +def test_postgres_checkpoint_setup_creates_latest_turn_unique_index( + monkeypatch: pytest.MonkeyPatch, +) -> None: + queries: list[str] = [] + + def record_query(connection_config: Any, query: str) -> None: + queries.append(query) + + monkeypatch.setattr( + checkpoint_datastore, + "_execute_query_on_postgres_db", + record_query, + ) + + connection_config: Any = object() + 
checkpoint_datastore._prepare_postgres_checkpoint_datastore(connection_config, StorageConfig()) + + assert len(queries) == 2 + assert "CREATE TABLE conversations" in queries[0] + assert "CREATE UNIQUE INDEX conversations_last_turn_idx" in queries[1] + assert "ON conversations (conversation_id)" in queries[1] + assert "WHERE is_last_turn = 1" in queries[1] + + +def test_oracle_checkpoint_setup_creates_latest_turn_unique_index( + monkeypatch: pytest.MonkeyPatch, +) -> None: + queries: list[str] = [] + + def record_query(connection_config: Any, query: str) -> None: + queries.append(query) + + monkeypatch.setattr( + checkpoint_datastore, + "_execute_query_on_oracle_db", + record_query, + ) + + connection_config: Any = object() + checkpoint_datastore._prepare_oracle_checkpoint_datastore(connection_config, StorageConfig()) + + assert len(queries) == 2 + assert "CREATE TABLE conversations" in queries[0] + assert "CREATE UNIQUE INDEX conversations_last_turn_idx" in queries[1] + assert "CASE WHEN is_last_turn = 1" in queries[1] + assert "THEN conversation_id END" in queries[1] + + +def test_relational_checkpoint_save_retries_integrity_errors() -> None: + checkpointer = IntegrityRetryCheckpointer() + + checkpointer.save( + ConversationCheckpoint( + checkpoint_id="checkpoint-1", + conversation_id="conversation-1", + component_id="component-1", + created_at=0, + state="state", + metadata={}, + ) + ) + + assert checkpointer.relational_save_attempts == 2 + + +def test_checkpoint_serializes_exception_tool_result_content_as_text() -> None: + agent = Agent( + llm=DummyModel(), + agent_id="checkpoint-exception-agent", + name="checkpoint_exception_agent", + description="Checkpoint exception agent.", + ) + conversation = agent.start_conversation(conversation_id="checkpoint-exception-tool-result") + conversation.message_list.append_message( + Message( + tool_result=ToolResult( + content=DatastoreEntityError("datastore create failed"), + tool_request_id="tool-call-1", + ), + 
message_type=MessageType.TOOL_RESULT, + ) + ) + + serialized_state = checkpoint_serialization._serialize_conversation_checkpoint_state( + conversation + ) + restored_conversation = checkpoint_serialization._deserialize_conversation_checkpoint_state( + serialized_state, + component=agent, + ) + + restored_tool_result = restored_conversation.get_last_message().tool_result + assert restored_tool_result is not None + assert restored_tool_result.content == "datastore create failed" + + +def test_agent_tool_response_messages_stringify_exception_content() -> None: + message = AgentConversationExecutor._get_tool_response_message( + content=DatastoreEntityError("datastore create failed"), + tool_request_id="tool-call-1", + agent_id="agent-1", + ) + + assert message.tool_result is not None + assert message.tool_result.content == "datastore create failed" + + def test_inmemory_checkpointer_can_save_load_list_and_delete_checkpoints() -> None: checkpointer = InMemoryCheckpointer() flow = create_single_step_flow(OutputMessageStep(message_template="Hello from checkpointing.")) @@ -199,6 +345,34 @@ def test_inmemory_checkpointer_can_save_load_list_and_delete_checkpoints() -> No ] == [first_checkpoint_id] +def test_checkpoint_restore_handles_messages_validation_without_runtime_name_error() -> None: + checkpointer = InMemoryCheckpointer() + flow = create_single_step_flow(OutputMessageStep(message_template="Hello from checkpointing.")) + + conversation = flow.start_conversation( + conversation_id="checkpoint-messages-validation", + checkpointer=checkpointer, + ) + status = conversation.execute() + assert isinstance(status, FinishedStatus) + checkpoint_id = conversation.checkpoint_id + assert checkpoint_id is not None + + restored_conversation = flow.start_conversation( + conversation_id="checkpoint-messages-validation", + checkpointer=checkpointer, + messages=[], + ) + assert restored_conversation.checkpoint_id == checkpoint_id + + with pytest.raises(ValueError, match="Cannot restore 
a checkpoint"): + flow.start_conversation( + conversation_id="checkpoint-messages-validation", + checkpointer=checkpointer, + messages=Message("new input"), + ) + + def test_conversation_turns_checkpoint_interval_saves_once_after_outer_execute() -> None: checkpointer = InMemoryCheckpointer( checkpointing_interval=CheckpointingInterval.CONVERSATION_TURNS @@ -241,6 +415,26 @@ def test_all_internal_turns_checkpoint_interval_saves_before_each_flow_turn() -> assert checkpoints[-1].metadata["status_type"] == "FinishedStatus" +def test_execute_can_skip_final_checkpoint_while_preserving_internal_checkpoints() -> None: + checkpointer = InMemoryCheckpointer( + checkpointing_interval=CheckpointingInterval.ALL_INTERNAL_TURNS + ) + flow = create_single_step_flow(OutputMessageStep(message_template="Hello internal only.")) + + status = flow.start_conversation( + conversation_id="skip-final-checkpoint", + checkpointer=checkpointer, + ).execute(_save_final_checkpoint=False) + + assert isinstance(status, FinishedStatus) + checkpoints = checkpointer.list_checkpoints("skip-final-checkpoint") + assert len(checkpoints) == 2 + assert [checkpoint.metadata["save_reason"] for checkpoint in checkpoints] == [ + "internal_turn_boundary", + "internal_turn_boundary", + ] + + def test_llm_turns_checkpoint_interval_saves_only_after_llm_backed_turns() -> None: checkpointer = InMemoryCheckpointer(checkpointing_interval=CheckpointingInterval.LLM_TURNS) dummy_llm = DummyModel() @@ -398,6 +592,194 @@ def test_checkpoint_restore_rejects_conversations_from_other_components() -> Non ) +def test_checkpoint_restore_supports_fresh_component_with_same_stable_id() -> None: + checkpointer = InMemoryCheckpointer() + original_agent = Agent( + llm=DummyModel(), + agent_id="stable-checkpoint-agent", + name="stable_checkpoint_agent", + description="Stable checkpoint agent.", + initial_message="Hello from the original agent.", + ) + + conversation = original_agent.start_conversation( + 
conversation_id="checkpoint-stable-component", + checkpointer=checkpointer, + ) + first_status = conversation.execute() + + assert isinstance(first_status, UserMessageRequestStatus) + first_checkpoint_id = conversation.checkpoint_id + assert first_checkpoint_id is not None + + restarted_agent = Agent( + llm=DummyModel(), + agent_id="stable-checkpoint-agent", + name="stable_checkpoint_agent", + description="Stable checkpoint agent.", + initial_message="Hello from the restarted agent.", + ) + restored_conversation = restarted_agent.start_conversation( + conversation_id="checkpoint-stable-component", + checkpointer=checkpointer, + ) + + assert restored_conversation.component is restarted_agent + assert restored_conversation.checkpoint_id == first_checkpoint_id + assert isinstance(restored_conversation.status, UserMessageRequestStatus) + + +def test_checkpoint_restore_remaps_nested_agent_refs_after_restart() -> None: + def build_parent_agent() -> tuple[Agent, Agent]: + sub_agent = Agent( + llm=DummyModel(fails_if_not_set=False), + name="checkpoint_sub_agent", + description="Checkpoint sub-agent.", + initial_message="Hello from sub-agent.", + ) + parent_agent = Agent( + llm=DummyModel(fails_if_not_set=False), + agent_id="stable-parent-agent", + name="checkpoint_parent_agent", + description="Checkpoint parent agent.", + agents=[sub_agent], + initial_message="Hello from parent.", + ) + return parent_agent, sub_agent + + checkpointer = InMemoryCheckpointer() + original_parent, original_sub_agent = build_parent_agent() + original_conversation = original_parent.start_conversation( + conversation_id="agent-nested-restart", + checkpointer=checkpointer, + ) + original_sub_conversation = original_sub_agent.start_conversation( + conversation_id="agent-nested-restart-sub", + _root_conversation_id=original_conversation.root_conversation_id, + ) + original_conversation.state.current_sub_component_conversations[original_sub_agent.id] = ( + original_sub_conversation + ) + 
checkpointer.save_conversation(original_conversation) + + restarted_parent, restarted_sub_agent = build_parent_agent() + restored_conversation = restarted_parent.start_conversation( + conversation_id="agent-nested-restart", + checkpointer=checkpointer, + ) + + restored_sub_conversation = restored_conversation._get_sub_component_conversation( + restarted_sub_agent + ) + assert restored_sub_conversation is not None + assert restored_sub_conversation.component is restarted_sub_agent + assert list(restored_conversation.state.current_sub_component_conversations) == [ + restarted_sub_agent.id + ] + + +def test_checkpoint_restore_remaps_managerworkers_nested_refs_after_restart() -> None: + def build_managerworkers() -> tuple[ManagerWorkers, DummyModel]: + manager_llm = DummyModel() + manager_agent = Agent( + llm=manager_llm, + name="checkpoint_restart_manager_agent", + description="Checkpoint restart manager.", + initial_message="Hello from the manager.", + ) + worker_agent = Agent( + llm=DummyModel(fails_if_not_set=False), + name="checkpoint_restart_worker_agent", + description="Checkpoint restart worker.", + custom_instruction="Help the manager.", + ) + managerworkers = ManagerWorkers( + group_manager=manager_agent, + workers=[worker_agent], + name="checkpoint_restart_managerworkers", + id="stable-managerworkers", + ) + return managerworkers, manager_llm + + checkpointer = InMemoryCheckpointer() + original_managerworkers, _ = build_managerworkers() + original_conversation = original_managerworkers.start_conversation( + conversation_id="managerworkers-nested-restart", + checkpointer=checkpointer, + ) + first_status = original_conversation.execute() + assert isinstance(first_status, UserMessageRequestStatus) + + restarted_managerworkers, restarted_manager_llm = build_managerworkers() + restored_conversation = restarted_managerworkers.start_conversation( + conversation_id="managerworkers-nested-restart", + checkpointer=checkpointer, + ) + + main_subconversation = 
restored_conversation._get_main_subconversation() + assert restored_conversation.component is restarted_managerworkers + assert main_subconversation.component is restarted_managerworkers.manager_agent + + restarted_manager_llm.set_next_output("ManagerWorkers resumed successfully.") + restored_conversation.append_user_message("Please continue.") + restored_status = restored_conversation.execute() + + assert isinstance(restored_status, UserMessageRequestStatus) + assert restored_conversation.get_last_message().content == ( + "ManagerWorkers resumed successfully." + ) + + +def test_checkpoint_restore_remaps_swarm_nested_refs_after_restart() -> None: + def build_swarm() -> tuple[Swarm, DummyModel]: + first_llm = DummyModel() + first_agent = Agent( + llm=first_llm, + name="checkpoint_restart_swarm_first_agent", + description="Checkpoint restart first swarm agent.", + initial_message="Hello from the swarm.", + ) + second_agent = Agent( + llm=DummyModel(fails_if_not_set=False), + name="checkpoint_restart_swarm_second_agent", + description="Checkpoint restart second swarm agent.", + custom_instruction="Help with delegated tasks.", + ) + swarm = Swarm( + first_agent=first_agent, + relationships=[(first_agent, second_agent)], + name="checkpoint_restart_swarm", + id="stable-swarm", + ) + return swarm, first_llm + + checkpointer = InMemoryCheckpointer() + original_swarm, _ = build_swarm() + original_conversation = original_swarm.start_conversation( + conversation_id="swarm-nested-restart", + checkpointer=checkpointer, + ) + first_status = original_conversation.execute() + assert isinstance(first_status, UserMessageRequestStatus) + + restarted_swarm, restarted_first_llm = build_swarm() + restored_conversation = restarted_swarm.start_conversation( + conversation_id="swarm-nested-restart", + checkpointer=checkpointer, + ) + + main_thread_conversation = restored_conversation._get_main_thread_conversation() + assert restored_conversation.component is restarted_swarm + assert 
main_thread_conversation.component is restarted_swarm.first_agent + + restarted_first_llm.set_next_output("Swarm resumed successfully.") + restored_conversation.append_user_message("Please continue.") + restored_status = restored_conversation.execute() + + assert isinstance(restored_status, UserMessageRequestStatus) + assert restored_conversation.get_last_message().content == "Swarm resumed successfully." + + def test_restore_can_skip_attaching_live_checkpointer( monkeypatch: pytest.MonkeyPatch, ) -> None: diff --git a/wayflowcore/tests/test_docstring.py b/wayflowcore/tests/test_docstring.py index 92720cee4..181b49f37 100644 --- a/wayflowcore/tests/test_docstring.py +++ b/wayflowcore/tests/test_docstring.py @@ -21,6 +21,7 @@ from .datastores.conftest import ( # noqa ORACLE_DB_DDL, + cleanup_oracle_datastore, get_basic_office_entities, get_oracle_datastore_with_schema, populate_with_basic_entities, @@ -72,26 +73,31 @@ def test_examples_in_docstrings_can_be_successfully_ran( # Check the docs at https://docs.python.org/3/library/doctest.html#doctest.testfile # if you want to understand how this test works. 
assistant = create_assistant() - doctest.testfile( - filename=file_path, - module_relative=False, - globs={ - "llm": remotely_hosted_llm, - "assistant": assistant, - "config_file_path": CONFIGS_DIR / "tests/configs/docstring_assistant.yaml", - "serialized_assistant_as_str": serialize(assistant), - "LLAMA70B_API_ENDPOINT": LLAMA70BV33_API_URL, - # Note: the docstring of the datastore query step instantiates a new datastore, - # but we use this so that we create a new datastore in the test that connects to the - # existing database with the data already there - "testing_oracle_data_store_with_data": testing_oracle_data_store_with_data, - "database_connection_config": ( - testing_oracle_data_store_with_data.connection_config - if testing_oracle_data_store_with_data - else None - ), - "multimodal_llm": remote_gemma_llm, - }, - raise_on_error=True, - verbose=True, - ) + try: + doctest.testfile( + filename=file_path, + module_relative=False, + globs={ + "llm": remotely_hosted_llm, + "assistant": assistant, + "config_file_path": CONFIGS_DIR / "tests/configs/docstring_assistant.yaml", + "serialized_assistant_as_str": serialize(assistant), + "LLAMA70B_API_ENDPOINT": LLAMA70BV33_API_URL, + # Note: the docstring of the datastore query step instantiates a new datastore, + # but we use this so that we create a new datastore in the test that connects to the + # existing database with the data already there + "testing_oracle_data_store_with_data": testing_oracle_data_store_with_data, + "database_connection_config": ( + testing_oracle_data_store_with_data.connection_config + if testing_oracle_data_store_with_data + else None + ), + "multimodal_llm": remote_gemma_llm, + }, + raise_on_error=True, + verbose=True, + ) + finally: + if testing_oracle_data_store_with_data is not None: + testing_oracle_data_store_with_data.close() + cleanup_oracle_datastore() diff --git a/wayflowcore/tests/test_managerworkers.py b/wayflowcore/tests/test_managerworkers.py index 6b977cbb7..b5de87206 100644 
--- a/wayflowcore/tests/test_managerworkers.py +++ b/wayflowcore/tests/test_managerworkers.py @@ -1480,15 +1480,15 @@ def test_multi_managers_with_mock_outputs(vllm_responses_llm): assert conv.get_last_message().content == "first-level manager answers to user" -@retry_test(max_attempts=5) +@retry_test(max_attempts=14) def test_multi_managers_with_llms(vllm_responses_llm): """ - Failure rate: 2 out of 20 - Observed on: 2026-01-28 - Average success time: 14.97 seconds per successful attempt - Average failure time: 20.80 seconds per failed attempt - Max attempt: 5 - Justification: (0.14 ** 5) ~= 4.7 / 100'000 + Failure rate: 25 out of 50 + Observed on: 2026-05-13 + Average success time: 20.35 seconds per successful attempt + Average failure time: 11.01 seconds per failed attempt + Max attempt: 14 + Justification: (0.50 ** 14) ~= 6.1 / 100'000 """ llm = vllm_responses_llm diff --git a/wayflowcore/tests/transforms/conftest.py b/wayflowcore/tests/transforms/conftest.py index f816d46d3..6a33f9b3e 100644 --- a/wayflowcore/tests/transforms/conftest.py +++ b/wayflowcore/tests/transforms/conftest.py @@ -51,14 +51,22 @@ def oracle_database_connection(): def create_entities_inside_oracle_database(oracle_database_connection, schema): ddl = _get_dll_for_creation_of_one_entity_schema(schema) - for stmt in ddl: - oracle_database_connection.cursor().execute(stmt) + cursor = oracle_database_connection.cursor() + try: + for stmt in ddl: + cursor.execute(stmt) + finally: + cursor.close() def delete_entities_inside_oracle_database(oracle_database_connection, schema): ddl = _get_dll_for_deletion_of_one_entity_schema(schema) - for stmt in ddl: - oracle_database_connection.cursor().execute(stmt) + cursor = oracle_database_connection.cursor() + try: + for stmt in ddl: + cursor.execute(stmt) + finally: + cursor.close() def find_datastore_by_schema(pool, schema): @@ -74,6 +82,7 @@ def find_datastore_by_schema(pool, schema): @pytest.fixture(scope="session") def 
oracle_database_datastores_pool(): + testing_schemas = [] pool = [] try: testing_schemas = get_testing_schemas() @@ -84,6 +93,8 @@ def oracle_database_datastores_pool(): pool.append((schema, datastore)) yield pool finally: + for _, datastore in pool: + datastore.close() for schema in testing_schemas: cleanup_oracle_datastore(_get_dll_for_deletion_of_one_entity_schema(schema))