From c82f100235704a9821348603fba23021181a28ed Mon Sep 17 00:00:00 2001 From: Paul Cayet Date: Fri, 10 Apr 2026 16:18:34 +0200 Subject: [PATCH] fix: align Agent Spec tracing and transform serialization compatibility --- .../wayflowcore/source/core/api/agentspec.rst | 2 + docs/wayflowcore/source/core/changelog.rst | 6 + .../core/code_examples/howto_tracing.py | 12 ++ .../source/core/howtoguides/howto_tracing.rst | 9 ++ .../src/wayflowcore/agentspec/_legacy.py | 15 ++ .../agentspec/components/__init__.py | 4 +- .../agentspec/components/transforms.py | 9 +- .../wayflowcore/agentspec/runtimeloader.py | 3 + .../src/wayflowcore/agentspec/tracing.py | 58 +++++++- .../wayflowcore/executors/_flowexecutor.py | 3 + .../wayflowcore/executors/executionstatus.py | 8 +- .../serialization/_builtins_components.py | 2 +- .../_builtins_deserialization_plugin.py | 10 +- .../_builtins_serialization_plugin.py | 12 +- .../wayflowcore/serialization/serializer.py | 11 ++ .../templates/_managerworkerstemplate.py | 89 +---------- .../wayflowcore/templates/_swarmtemplate.py | 101 +------------ .../templates/agenticpatterntemplate.py | 97 ++++++++++++ .../test_agentspec_conversion_coverage.py | 1 - wayflowcore/tests/agentspec/test_tracing.py | 138 +++++++++++++++++- .../tests/agentspec/test_transforms.py | 55 +++++++ .../test_managerworkers_serialization.py | 24 +++ .../serialization/test_serializableobject.py | 14 +- .../serialization/test_swarm_serialization.py | 12 ++ 24 files changed, 474 insertions(+), 221 deletions(-) create mode 100644 wayflowcore/src/wayflowcore/agentspec/_legacy.py create mode 100644 wayflowcore/src/wayflowcore/templates/agenticpatterntemplate.py diff --git a/docs/wayflowcore/source/core/api/agentspec.rst b/docs/wayflowcore/source/core/api/agentspec.rst index 6bc75ae78..7c5c8c7dd 100644 --- a/docs/wayflowcore/source/core/api/agentspec.rst +++ b/docs/wayflowcore/source/core/api/agentspec.rst @@ -40,6 +40,8 @@ This event listener makes WayFlow components emit traces according to the Agent .. _agentspeceventlistener: .. autoclass:: wayflowcore.agentspec.tracing.AgentSpecEventListener +.. autofunction:: wayflowcore.agentspec.tracing.dump_tracing_model + Custom Components ================= diff --git a/docs/wayflowcore/source/core/changelog.rst b/docs/wayflowcore/source/core/changelog.rst index 134725431..da9bec2b3 100644 --- a/docs/wayflowcore/source/core/changelog.rst +++ b/docs/wayflowcore/source/core/changelog.rst @@ -51,6 +51,12 @@ New features Improvements ^^^^^^^^^^^^ +* **Improved Agent Spec tracing compatibility** + + Agent Spec tracing exports now serialize WayFlow Agent Spec plugin components with the + proper plugin context and report a valid flow end branch when a flow finishes through a + transition to ``None``. + * **Scoped opt-in for authless MCP clients** Added ``authless_mcp_enabled()`` as a scoped context manager for local or test MCP clients diff --git a/docs/wayflowcore/source/core/code_examples/howto_tracing.py b/docs/wayflowcore/source/core/code_examples/howto_tracing.py index 40f029f8a..ac8526f08 100644 --- a/docs/wayflowcore/source/core/code_examples/howto_tracing.py +++ b/docs/wayflowcore/source/core/code_examples/howto_tracing.py @@ -144,3 +144,15 @@ def subtract(a: float, b: float) -> float: conversation.append_user_message("Compute 2+3") status = conversation.execute() # .. end-##_Enable_Agent_Spec_Tracing + +# .. start-##_Dump_Agent_Spec_Tracing +from pyagentspec.tracing.events import AgentExecutionStart as AgentSpecAgentExecutionStart +from wayflowcore.agentspec.tracing import dump_tracing_model + +# Use this helper when a custom Agent Spec span processor/exporter needs to serialize +# Agent Spec tracing events or spans that may contain WayFlow extension/plugin components. +agentspec_agent = AgentSpecExporter().to_component(agent) +serialized_event = dump_tracing_model( + AgentSpecAgentExecutionStart(agent=agentspec_agent, inputs={}) +) +# .. end-##_Dump_Agent_Spec_Tracing diff --git a/docs/wayflowcore/source/core/howtoguides/howto_tracing.rst b/docs/wayflowcore/source/core/howtoguides/howto_tracing.rst index 8a1de2286..7464734c7 100644 --- a/docs/wayflowcore/source/core/howtoguides/howto_tracing.rst +++ b/docs/wayflowcore/source/core/howtoguides/howto_tracing.rst @@ -152,6 +152,15 @@ Here's an example of how to use it in your code. :start-after: .. start-##_Enable_Agent_Spec_Tracing :end-before: .. end-##_Enable_Agent_Spec_Tracing +If you implement a custom Agent Spec span processor or exporter and need to serialize Agent Spec +tracing events/spans emitted by WayFlow, use ``dump_tracing_model`` so WayFlow's Agent Spec plugin +components are dumped with the correct serialization context. + +.. literalinclude:: ../code_examples/howto_tracing.py + :language: python + :start-after: .. start-##_Dump_Agent_Spec_Tracing + :end-before: .. end-##_Dump_Agent_Spec_Tracing + Agent Spec Exporting/Loading ============================ diff --git a/wayflowcore/src/wayflowcore/agentspec/_legacy.py b/wayflowcore/src/wayflowcore/agentspec/_legacy.py new file mode 100644 index 000000000..0029ac80f --- /dev/null +++ b/wayflowcore/src/wayflowcore/agentspec/_legacy.py @@ -0,0 +1,15 @@ +# Copyright © 2025, 2026 Oracle and/or its affiliates. +# +# This software is under the Apache License 2.0 +# (LICENSE-APACHE or http://www.apache.org/licenses/LICENSE-2.0) or Universal Permissive License +# (UPL) 1.0 (LICENSE-UPL or https://oss.oracle.com/licenses/upl), at your option. + + +def _resolve_legacy_configurations(serialized_config: str) -> str: + """ + Normalize legacy Agent Spec component names before deserialization. + """ + return serialized_config.replace( + "PluginSwarmToolRequestAndCallsTransform", + "PluginToolRequestAndCallsTransform", + ) diff --git a/wayflowcore/src/wayflowcore/agentspec/components/__init__.py b/wayflowcore/src/wayflowcore/agentspec/components/__init__.py index dfa345296..2b397e2fd 100644 --- a/wayflowcore/src/wayflowcore/agentspec/components/__init__.py +++ b/wayflowcore/src/wayflowcore/agentspec/components/__init__.py @@ -117,7 +117,7 @@ PluginCoalesceSystemMessagesTransform, PluginReactMergeToolRequestAndCallsTransform, PluginRemoveEmptyNonUserMessageTransform, - PluginSwarmToolRequestAndCallsTransform, + PluginToolRequestAndCallsTransform, messagetransform_deserialization_plugin, messagetransform_serialization_plugin, ) @@ -225,9 +225,9 @@ "contextprovider_deserialization_plugin", "PluginAppendTrailingSystemMessageToUserMessageTransform", "PluginCoalesceSystemMessagesTransform", + "PluginToolRequestAndCallsTransform", "PluginRemoveEmptyNonUserMessageTransform", "PluginReactMergeToolRequestAndCallsTransform", - "PluginSwarmToolRequestAndCallsTransform", "messagetransform_serialization_plugin", "messagetransform_deserialization_plugin", "PluginPromptTemplate", diff --git a/wayflowcore/src/wayflowcore/agentspec/components/transforms.py b/wayflowcore/src/wayflowcore/agentspec/components/transforms.py index 2e2bc0bc3..16215f0bf 100644 --- a/wayflowcore/src/wayflowcore/agentspec/components/transforms.py +++ b/wayflowcore/src/wayflowcore/agentspec/components/transforms.py @@ -53,9 +53,8 @@ class PluginReactMergeToolRequestAndCallsTransform(MessageTransform): """Simple message processor that joins tool requests and calls into a python-like message""" -class PluginSwarmToolRequestAndCallsTransform(MessageTransform): - """Format Tool requests as Agent messages and Tool results as User messages to have a simple User/Agent - sequence of messages.""" +class PluginToolRequestAndCallsTransform(MessageTransform): + """Format tool requests as agent messages and tool results as user messages.""" class PluginCanonicalizationMessageTransform(MessageTransform): @@ -98,7 +97,7 @@ class PluginSplitPromptOnMarkerMessageTransform(MessageTransform): PluginAppendTrailingSystemMessageToUserMessageTransform.__name__: PluginAppendTrailingSystemMessageToUserMessageTransform, PluginLlamaMergeToolRequestAndCallsTransform.__name__: PluginLlamaMergeToolRequestAndCallsTransform, PluginReactMergeToolRequestAndCallsTransform.__name__: PluginReactMergeToolRequestAndCallsTransform, - PluginSwarmToolRequestAndCallsTransform.__name__: PluginSwarmToolRequestAndCallsTransform, + PluginToolRequestAndCallsTransform.__name__: PluginToolRequestAndCallsTransform, PluginCanonicalizationMessageTransform.__name__: PluginCanonicalizationMessageTransform, PluginSplitPromptOnMarkerMessageTransform.__name__: PluginSplitPromptOnMarkerMessageTransform, }, @@ -111,7 +110,7 @@ class PluginSplitPromptOnMarkerMessageTransform(MessageTransform): PluginAppendTrailingSystemMessageToUserMessageTransform.__name__: PluginAppendTrailingSystemMessageToUserMessageTransform, PluginLlamaMergeToolRequestAndCallsTransform.__name__: PluginLlamaMergeToolRequestAndCallsTransform, PluginReactMergeToolRequestAndCallsTransform.__name__: PluginReactMergeToolRequestAndCallsTransform, - PluginSwarmToolRequestAndCallsTransform.__name__: PluginSwarmToolRequestAndCallsTransform, + PluginToolRequestAndCallsTransform.__name__: PluginToolRequestAndCallsTransform, PluginCanonicalizationMessageTransform.__name__: PluginCanonicalizationMessageTransform, PluginSplitPromptOnMarkerMessageTransform.__name__: PluginSplitPromptOnMarkerMessageTransform, }, diff --git a/wayflowcore/src/wayflowcore/agentspec/runtimeloader.py b/wayflowcore/src/wayflowcore/agentspec/runtimeloader.py index 1446f23cb..6d98f72b4 100644 --- a/wayflowcore/src/wayflowcore/agentspec/runtimeloader.py +++ b/wayflowcore/src/wayflowcore/agentspec/runtimeloader.py @@ -20,6 +20,7 @@ from pyagentspec.serialization.types import ComponentsRegistryT as AgentSpecComponentsRegistryT from typing_extensions import TypeAlias +from wayflowcore.agentspec._legacy import _resolve_legacy_configurations from wayflowcore.agentspec.components.mcp import PluginStdioTransport from wayflowcore.component import Component as RuntimeComponent from wayflowcore.serialization.plugins import WayflowDeserializationPlugin @@ -280,6 +281,7 @@ def load_json( ... ) """ + serialized_assistant = _resolve_legacy_configurations(serialized_assistant) deserializer = AgentSpecDeserializer( plugins=self._get_all_agentspec_plugins(), allowed_components=self.allowed_components, @@ -459,6 +461,7 @@ def load_yaml( ... ) """ + serialized_assistant = _resolve_legacy_configurations(serialized_assistant) deserializer = AgentSpecDeserializer( plugins=self._get_all_agentspec_plugins(), allowed_components=self.allowed_components, diff --git a/wayflowcore/src/wayflowcore/agentspec/tracing.py b/wayflowcore/src/wayflowcore/agentspec/tracing.py index e42147179..ed7e59b22 100644 --- a/wayflowcore/src/wayflowcore/agentspec/tracing.py +++ b/wayflowcore/src/wayflowcore/agentspec/tracing.py @@ -4,7 +4,7 @@ # (LICENSE-APACHE or http://www.apache.org/licenses/LICENSE-2.0) or Universal Permissive License # (UPL) 1.0 (LICENSE-UPL or https://oss.oracle.com/licenses/upl), at your option. import json -from typing import Dict, Optional, Union, cast +from typing import Dict, List, Optional, Union, cast from pyagentspec import Component as AgentSpecComponent from pyagentspec.agent import Agent as AgentSpecAgent @@ -12,9 +12,12 @@ from pyagentspec.flows.node import Node as AgentSpecNode from pyagentspec.llms import LlmConfig as AgentSpecLlmConfig from pyagentspec.llms import LlmGenerationConfig +from pyagentspec.serialization import ComponentSerializationPlugin from pyagentspec.tools import Tool as AgentSpecTool +from pyagentspec.tracing._basemodel import _TracingSerializationContextImpl from pyagentspec.tracing.events import AgentExecutionEnd as AgentSpecAgentExecutionEnd from pyagentspec.tracing.events import AgentExecutionStart as AgentSpecAgentExecutionStart +from pyagentspec.tracing.events import Event as AgentSpecEvent from pyagentspec.tracing.events import ExceptionRaised as AgentSpecExceptionRaised from pyagentspec.tracing.events import FlowExecutionEnd as AgentSpecFlowExecutionEnd from pyagentspec.tracing.events import FlowExecutionStart as AgentSpecFlowExecutionStart @@ -56,18 +59,51 @@ ) from wayflowcore.events.eventlistener import EventListener from wayflowcore.executors.executionstatus import FinishedStatus +from wayflowcore.serialization.plugins import WayflowSerializationPlugin from wayflowcore.tracing.span import LlmGenerationSpan, get_active_span_stack, get_current_span +def _create_tracing_serialization_context( + plugins: Optional[List[Union[ComponentSerializationPlugin, WayflowSerializationPlugin]]] = None, +) -> _TracingSerializationContextImpl: + """ + Build a tracing serialization context that knows how to dump WayFlow Agent Spec extension components. + """ + exporter = AgentSpecExporter(plugins=plugins) + return _TracingSerializationContextImpl(plugins=exporter._get_all_agentspec_plugins()) + + +def dump_tracing_model( + model: Union[AgentSpecEvent, AgentSpecSpan], + mask_sensitive_information: bool = True, + plugins: Optional[List[Union[ComponentSerializationPlugin, WayflowSerializationPlugin]]] = None, +) -> Dict[str, object]: + """ + Serialize an Agent Spec tracing span/event using the WayFlow Agent Spec serialization plugins. + """ + if not isinstance(model, (AgentSpecEvent, AgentSpecSpan)): + raise TypeError("dump_tracing_model only supports Agent Spec tracing events and spans.") + return model.model_dump( + mask_sensitive_information=mask_sensitive_information, + context=_create_tracing_serialization_context(plugins=plugins), + ) + + class AgentSpecEventListener(EventListener): """Event listener that emits traces according to the Open Agent Spec Tracing standard""" - def __init__(self) -> None: + def __init__( + self, + plugins: Optional[ + List[Union[ComponentSerializationPlugin, WayflowSerializationPlugin]] + ] = None, + ) -> None: super().__init__() + self.plugins = plugins # We keep track of the mapping between the wayflow span (id) and the corresponding agent spec span self.agentspec_spans_registry: Dict[str, AgentSpecSpan] = {} # As we need to store agent spec objects in the agent spec spans and events, we need to perform conversions - self.agentspec_exporter: AgentSpecExporter = AgentSpecExporter() + self.agentspec_exporter: AgentSpecExporter = AgentSpecExporter(plugins=plugins) # We keep a registry of conversions, so that we do not repeat the conversion for the same object twice self.agentspec_components_registry: Dict[str, AgentSpecComponent] = {} # Track last assistant message id and a robust mapping tool_request_id -> assistant message id. @@ -76,6 +112,16 @@ def __init__(self) -> None: self._last_assistant_message_id: Union[str, None] = None self._tool_to_message: Dict[str, Optional[str]] = {} + def dump_tracing_model( + self, model: Union[AgentSpecEvent, AgentSpecSpan], mask_sensitive_information: bool = True + ) -> Dict[str, object]: + """Serialize an Agent Spec tracing span/event with the listener's plugin set.""" + return dump_tracing_model( + model=model, + mask_sensitive_information=mask_sensitive_information, + plugins=self.plugins, + ) + def _convert_to_agentspec(self, component: Component) -> AgentSpecComponent: if component.id not in self.agentspec_components_registry: self.agentspec_components_registry[component.id] = self.agentspec_exporter.to_component( @@ -340,7 +386,11 @@ def __call__(self, event: Event) -> None: AgentSpecFlow, self._convert_to_agentspec(event.conversational_component) ) if isinstance(event.execution_status, FinishedStatus): - branch_selected = event.execution_status.complete_step_name + branch_selected = ( + event.execution_status.complete_step_name + or event.execution_status._final_step_name + or "" + ) outputs = event.execution_status.output_values else: branch_selected = "" diff --git a/wayflowcore/src/wayflowcore/executors/_flowexecutor.py b/wayflowcore/src/wayflowcore/executors/_flowexecutor.py index 8f164b8ea..eb572ff22 100644 --- a/wayflowcore/src/wayflowcore/executors/_flowexecutor.py +++ b/wayflowcore/src/wayflowcore/executors/_flowexecutor.py @@ -737,6 +737,7 @@ async def _execute_flow( logger.debug("Interrupts received: %s", execution_interrupts) flow_state = conversation.state last_complete_step_name_executed = None + last_step_name_executed = None try: @@ -756,6 +757,7 @@ async def _execute_flow( ) current_step = flow_state.flow.steps[flow_state.current_step_name] + last_step_name_executed = flow_state.current_step_name if isinstance(current_step, CompleteStep): last_complete_step_name_executed = flow_state.current_step_name @@ -918,6 +920,7 @@ async def _execute_flow( return FinishedStatus( output_values=outputs, complete_step_name=last_complete_step_name_executed, + _final_step_name=last_step_name_executed, _conversation_id=conversation.id, ) diff --git a/wayflowcore/src/wayflowcore/executors/executionstatus.py b/wayflowcore/src/wayflowcore/executors/executionstatus.py index a314e1f19..e4975b309 100644 --- a/wayflowcore/src/wayflowcore/executors/executionstatus.py +++ b/wayflowcore/src/wayflowcore/executors/executionstatus.py @@ -45,8 +45,10 @@ class FinishedStatus(ExecutionStatus): output_values: Dict[str, Any] """The outputs produced by the agent or flow returning this execution status.""" complete_step_name: Optional[str] = None - """The name of the last step reached if the flow returning this execution status transitioned \ + """The name of the last step reached if the flow returning this execution status transitioned to a ``CompleteStep``, otherwise ``None``.""" + _final_step_name: Optional[str] = None + """The name of the last executed step, including flows that end via a transition to ``None``.""" @property def _requires_yielding(self) -> bool: @@ -56,6 +58,7 @@ def _serialize_to_dict(self, serialization_context: "SerializationContext") -> D return { "output_values": self.output_values, "complete_step_name": self.complete_step_name, + "_final_step_name": self._final_step_name, "_conversation_id": self._conversation_id, "id": self.id, } @@ -66,7 +69,8 @@ def _deserialize_from_dict( ) -> "SerializableObject": return FinishedStatus( output_values=input_dict["output_values"], - complete_step_name=input_dict["complete_step_name"], + complete_step_name=input_dict.get("complete_step_name"), + _final_step_name=input_dict.get("_final_step_name", input_dict.get("final_step_name")), _conversation_id=input_dict.get("_conversation_id", None), id=input_dict.get("id") or IdGenerator.get_or_generate_id(), ) diff --git a/wayflowcore/src/wayflowcore/serialization/_builtins_components.py b/wayflowcore/src/wayflowcore/serialization/_builtins_components.py index 04a04818f..456d01512 100644 --- a/wayflowcore/src/wayflowcore/serialization/_builtins_components.py +++ b/wayflowcore/src/wayflowcore/serialization/_builtins_components.py @@ -203,5 +203,5 @@ "_PythonMergeToolRequestAndCallsTransform", "_ReactMergeToolRequestAndCallsTransform", "_TokenConsumptionEvent", - "_ToolRequestAndCallsTransform", + "ToolRequestAndCallsTransform", } diff --git a/wayflowcore/src/wayflowcore/serialization/_builtins_deserialization_plugin.py b/wayflowcore/src/wayflowcore/serialization/_builtins_deserialization_plugin.py index 1ffdb3c37..d562ba96d 100644 --- a/wayflowcore/src/wayflowcore/serialization/_builtins_deserialization_plugin.py +++ b/wayflowcore/src/wayflowcore/serialization/_builtins_deserialization_plugin.py @@ -305,7 +305,7 @@ PluginSplitPromptOnMarkerMessageTransform as AgentSpecPluginSplitPromptOnMarkerMessageTransform, ) from wayflowcore.agentspec.components.transforms import ( - PluginSwarmToolRequestAndCallsTransform as AgentSpecPluginSwarmToolRequestAndCallsTransform, + PluginToolRequestAndCallsTransform as AgentSpecPluginToolRequestAndCallsTransform, ) from wayflowcore.contextproviders.constantcontextprovider import ( ConstantContextProvider as RuntimeConstantContextProvider, @@ -445,8 +445,8 @@ from wayflowcore.swarm import HandoffMode as RuntimeHandoffMode from wayflowcore.swarm import Swarm as RuntimeSwarm from wayflowcore.templates import PromptTemplate as RuntimePromptTemplate -from wayflowcore.templates._swarmtemplate import ( - _ToolRequestAndCallsTransform as RuntimeSwarmToolRequestAndCallsTransform, +from wayflowcore.templates.agenticpatterntemplate import ( + ToolRequestAndCallsTransform as RuntimeToolRequestAndCallsTransform, ) from wayflowcore.templates.llamatemplates import ( _LlamaMergeToolRequestAndCallsTransform as RuntimeLlamaMergeToolRequestAndCallsTransform, @@ -1921,8 +1921,8 @@ class SupportsTimeoutKwargs(TypedDict, total=False): return RuntimeReactMergeToolRequestAndCallsTransform( **self._get_component_arguments(agentspec_component) ) - elif isinstance(agentspec_component, AgentSpecPluginSwarmToolRequestAndCallsTransform): - return RuntimeSwarmToolRequestAndCallsTransform( + elif isinstance(agentspec_component, AgentSpecPluginToolRequestAndCallsTransform): + return RuntimeToolRequestAndCallsTransform( **self._get_component_arguments(agentspec_component) ) elif isinstance(agentspec_component, AgentSpecPluginCanonicalizationMessageTransform): diff --git a/wayflowcore/src/wayflowcore/serialization/_builtins_serialization_plugin.py b/wayflowcore/src/wayflowcore/serialization/_builtins_serialization_plugin.py index 698034358..47bbf26d3 100644 --- a/wayflowcore/src/wayflowcore/serialization/_builtins_serialization_plugin.py +++ b/wayflowcore/src/wayflowcore/serialization/_builtins_serialization_plugin.py @@ -296,7 +296,7 @@ PluginSplitPromptOnMarkerMessageTransform as AgentSpecPluginSplitPromptOnMarkerMessageTransform, ) from wayflowcore.agentspec.components.transforms import ( - PluginSwarmToolRequestAndCallsTransform as AgentSpecPluginSwarmToolRequestAndCallsTransform, + PluginToolRequestAndCallsTransform as AgentSpecPluginToolRequestAndCallsTransform, ) from wayflowcore.contextproviders import ContextProvider as RuntimeContextProvider from wayflowcore.contextproviders.constantcontextprovider import ( @@ -447,8 +447,8 @@ ) from wayflowcore.swarm import Swarm as RuntimeSwarm from wayflowcore.templates import PromptTemplate as RuntimePromptTemplate -from wayflowcore.templates._swarmtemplate import ( - _ToolRequestAndCallsTransform as RuntimeSwarmToolRequestAndCallsTransform, +from wayflowcore.templates.agenticpatterntemplate import ( + ToolRequestAndCallsTransform as RuntimeToolRequestAndCallsTransform, ) from wayflowcore.templates.llamatemplates import ( _LlamaMergeToolRequestAndCallsTransform as RuntimeLlamaMergeToolRequestAndCallsTransform, @@ -1741,9 +1741,9 @@ def _messagetransform_convert_to_agentspec( runtime_messagetransform ), ) - elif isinstance(runtime_messagetransform, RuntimeSwarmToolRequestAndCallsTransform): - return AgentSpecPluginSwarmToolRequestAndCallsTransform( - name="swarmtoolrequestandcalls_messagetransform", + elif isinstance(runtime_messagetransform, RuntimeToolRequestAndCallsTransform): + return AgentSpecPluginToolRequestAndCallsTransform( + name="toolrequestandcalls_messagetransform", metadata=_create_agentspec_metadata_from_runtime_component( runtime_messagetransform ), diff --git a/wayflowcore/src/wayflowcore/serialization/serializer.py b/wayflowcore/src/wayflowcore/serialization/serializer.py index 37ca23b45..fcb770db3 100644 --- a/wayflowcore/src/wayflowcore/serialization/serializer.py +++ b/wayflowcore/src/wayflowcore/serialization/serializer.py @@ -198,6 +198,15 @@ def _resolve_legacy_field_name(cls: type, field_name: str) -> str: return field_name +def _resolve_legacy_configurations(serialized_config: str) -> str: + """ + Normalize legacy WayFlow component names before native deserialization. + """ + return serialized_config.replace( + "_ToolRequestAndCallsTransform", "ToolRequestAndCallsTransform" + ) + + class SerializableDataclassMixin: def _serialize_to_dict(self, serialization_context: "SerializationContext") -> Dict[str, Any]: return { @@ -681,6 +690,7 @@ def deserialize( UserWarning, ) + obj = _resolve_legacy_configurations(obj) obj_as_dict: Dict[str, Any] = yaml.safe_load(obj) component_type: str = obj_as_dict["_component_type"] @@ -733,6 +743,7 @@ def autodeserialize( UserWarning, ) + obj = _resolve_legacy_configurations(obj) obj_as_dict: Dict[str, Any] = yaml.safe_load(obj) return autodeserialize_from_dict(obj_as_dict, deserialization_context) diff --git a/wayflowcore/src/wayflowcore/templates/_managerworkerstemplate.py b/wayflowcore/src/wayflowcore/templates/_managerworkerstemplate.py index 8cc43c109..cd86f9522 100644 --- a/wayflowcore/src/wayflowcore/templates/_managerworkerstemplate.py +++ b/wayflowcore/src/wayflowcore/templates/_managerworkerstemplate.py @@ -4,20 +4,16 @@ # (LICENSE-APACHE or http://www.apache.org/licenses/LICENSE-2.0) or Universal Permissive License # (UPL) 1.0 (LICENSE-UPL or https://oss.oracle.com/licenses/upl), at your option. -import json -from textwrap import shorten -from typing import List, Tuple +from typing import Tuple -from wayflowcore._utils.formatting import format_tool_output_for_llm -from wayflowcore.messagelist import Message, TextContent from wayflowcore.outputparser import JsonToolOutputParser from wayflowcore.serialization.serializer import SerializableObject from wayflowcore.templates import PromptTemplate +from wayflowcore.templates.agenticpatterntemplate import ToolRequestAndCallsTransform from wayflowcore.templates.template import _TOOL_OUTPUT_SYSTEM_RULE from wayflowcore.transforms import ( AppendTrailingSystemMessageToUserMessageTransform, CoalesceSystemMessagesTransform, - MessageTransform, ) _DEFAULT_MANAGERWORKERS_SYSTEM_PROMPT = ( @@ -100,85 +96,6 @@ """ + _TOOL_OUTPUT_SYSTEM_RULE).strip() -_MAX_CHAR_TOOL_RESULT_HEADER = 140 -"""Max number of characters in the message header when formatting a Tool Result""" - - -class _ToolRequestAndCallsTransform(MessageTransform): - def __call__(self, messages: List["Message"]) -> List["Message"]: - """ - Format Tool requests as Agent messages and Tool results as clearly-labelled - tool-result User messages to have a simple User/Agent sequence of - messages. - """ - from wayflowcore import Message, MessageType - - tool_request_by_id = { # Mapping for fast lookup - tool_request.tool_request_id: tool_request - for msg in messages - if msg.message_type == MessageType.TOOL_REQUEST and msg.tool_requests - for tool_request in msg.tool_requests - } - - formatted_messages = [] - for message in messages: - if message.message_type == MessageType.TOOL_RESULT: - # Find corresponding ToolRequest by tool_request_id - if not message.tool_result: - raise ValueError(f"TOOL_RESULT message must contain tool_result: {message}") - tool_request_id = message.tool_result.tool_request_id - tool_request = tool_request_by_id.get(tool_request_id) - if not tool_request: - raise ValueError( - f"Could not find matching ToolRequest for TOOL_RESULT with id: {tool_request_id}" - ) - - message_header_tool_info = shorten( - f"name={tool_request.name}, parameters={tool_request.args}", - width=_MAX_CHAR_TOOL_RESULT_HEADER, - placeholder=" ...}", - ) - formatted_messages.append( - Message( - content=( - f"--- TOOL RESULT: {message_header_tool_info} ---\n" - f"{format_tool_output_for_llm(message.tool_result.content)}" - ), - message_type=MessageType.USER, - ) - ) - - elif message.message_type is MessageType.TOOL_REQUEST: - if not message.tool_requests: - raise ValueError( - "Message is of type TOOL_REQUEST but has no tool_requests. This should be reported." - ) - - formatted_tool_calls = "\n".join( - json.dumps({"name": tool_request.name, "parameters": tool_request.args}) - for tool_request in message.tool_requests - ) - for tool_request in message.tool_requests: - formatted_messages.append( - Message( - content=( - f"--- MESSAGE: From: {message.sender} ---\n" - f"{message.content}\n" - f"{formatted_tool_calls}" - ), - message_type=MessageType.AGENT, - ) - ) - elif message.message_type == MessageType.SYSTEM: - formatted_messages.append(message) - else: - message_copy = message.copy() - message_copy.contents.insert( - 0, TextContent(f"--- MESSAGE: From: {message_copy.sender} ---\n") - ) - formatted_messages.append(message_copy) - return formatted_messages - class ManagerWorkersJsonToolOutputParser(JsonToolOutputParser, SerializableObject): def parse_thoughts_and_calls(self, raw_txt: str) -> Tuple[str, str]: @@ -197,7 +114,7 @@ def parse_thoughts_and_calls(self, raw_txt: str) -> Tuple[str, str]: ], native_tool_calling=False, post_rendering_transforms=[ - _ToolRequestAndCallsTransform(), + ToolRequestAndCallsTransform(), CoalesceSystemMessagesTransform(), AppendTrailingSystemMessageToUserMessageTransform(), ], diff --git a/wayflowcore/src/wayflowcore/templates/_swarmtemplate.py b/wayflowcore/src/wayflowcore/templates/_swarmtemplate.py index f67a5b791..80c2fedf9 100644 --- a/wayflowcore/src/wayflowcore/templates/_swarmtemplate.py +++ b/wayflowcore/src/wayflowcore/templates/_swarmtemplate.py @@ -4,20 +4,17 @@ # (LICENSE-APACHE or http://www.apache.org/licenses/LICENSE-2.0) or Universal Permissive License # (UPL) 1.0 (LICENSE-UPL or https://oss.oracle.com/licenses/upl), at your option. -import json -from textwrap import shorten -from typing import List, Tuple +from typing import Tuple -from wayflowcore._utils.formatting import format_tool_output_for_llm -from wayflowcore.messagelist import Message, MessageType, TextContent +from wayflowcore.messagelist import Message, MessageType from wayflowcore.outputparser import JsonToolOutputParser from wayflowcore.serialization.serializer import SerializableObject from wayflowcore.templates import PromptTemplate +from wayflowcore.templates.agenticpatterntemplate import ToolRequestAndCallsTransform from wayflowcore.templates.template import _TOOL_OUTPUT_SYSTEM_RULE from wayflowcore.transforms import ( AppendTrailingSystemMessageToUserMessageTransform, CoalesceSystemMessagesTransform, - MessageTransform, ) _DEFAULT_SWARM_SYSTEM_PROMPT = ( @@ -119,96 +116,6 @@ def _is_system_reminder(message: "Message") -> bool: Simply continue the conversation from where the previous agent left off. """.strip() -_MAX_CHAR_TOOL_RESULT_HEADER = 140 -"""Max number of characters in the message header when formatting a Tool Result""" - - -class _ToolRequestAndCallsTransform(MessageTransform): - def __call__(self, messages: List["Message"]) -> List["Message"]: - """ - Format Tool requests as Agent messages and Tool results as clearly-labelled - tool-result User messages to have a simple User/Agent sequence of - messages. - """ - from wayflowcore import Message, MessageType - - tool_request_by_id = { # Mapping for fast lookup - tool_request.tool_request_id: tool_request - for msg in messages - if msg.message_type is MessageType.TOOL_REQUEST and msg.tool_requests - for tool_request in msg.tool_requests - } - - formatted_messages = [] - for message in messages: - if message.message_type == MessageType.TOOL_RESULT: - # Find corresponding ToolRequest by tool_request_id - if not message.tool_result: - raise ValueError(f"TOOL_RESULT message must contain tool_result: {message}") - tool_request_id = message.tool_result.tool_request_id - tool_request = tool_request_by_id.get(tool_request_id) - if not tool_request: - raise ValueError( - f"Could not find matching ToolRequest for TOOL_RESULT with id: {tool_request_id}" - ) - - message_header_tool_info = shorten( - f"name={tool_request.name}, parameters={tool_request.args}", - width=_MAX_CHAR_TOOL_RESULT_HEADER, - placeholder=" ...}", - ) - formatted_messages.append( - Message( - content=( - f"--- TOOL RESULT: {message_header_tool_info} ---\n" - f"{format_tool_output_for_llm(message.tool_result.content)}" - ), - message_type=MessageType.USER, - ) - ) - - elif message.message_type == MessageType.TOOL_REQUEST: - if not message.tool_requests: - raise ValueError( - "Message is of type TOOL_REQUEST but has no tool_requests. This should be reported." - ) - - formatted_tool_calls = "\n".join( - json.dumps({"name": tool_request.name, "parameters": tool_request.args}) - for tool_request in message.tool_requests - ) - - header = f"--- MESSAGE: From: {message.sender} ---\n" - content = ( - message.content # sometimes the llm outputs this header automatically -> no need to add it. - if message.content.startswith(header) - else f"{header}{message.content}" - ) - - formatted_messages.append( - Message( - content=( - f"{content}\n{formatted_tool_calls}" - if formatted_tool_calls not in content - else f"{content}" - ), - message_type=MessageType.AGENT, - ) - ) - elif message.message_type == MessageType.SYSTEM: - formatted_messages.append(message) - else: - message_copy = message.copy() - if message_copy.role == "user" and not message_copy.sender: - # If the message's sender is None, it is from the HUMAN USER - message_copy.sender = "HUMAN USER" - - message_copy.contents.insert( - 0, TextContent(f"--- MESSAGE: From: {message_copy.sender} ---\n") - ) - formatted_messages.append(message_copy) - return formatted_messages - class SwarmJsonToolOutputParser(JsonToolOutputParser, SerializableObject): def parse_thoughts_and_calls(self, raw_txt: str) -> Tuple[str, str]: @@ -227,7 +134,7 @@ def parse_thoughts_and_calls(self, raw_txt: str) -> Tuple[str, str]: ], native_tool_calling=False, post_rendering_transforms=[ - _ToolRequestAndCallsTransform(), + ToolRequestAndCallsTransform(), CoalesceSystemMessagesTransform(), AppendTrailingSystemMessageToUserMessageTransform(), ], diff --git a/wayflowcore/src/wayflowcore/templates/agenticpatterntemplate.py b/wayflowcore/src/wayflowcore/templates/agenticpatterntemplate.py new file mode 100644 index 000000000..81683b02f --- /dev/null +++ b/wayflowcore/src/wayflowcore/templates/agenticpatterntemplate.py @@ -0,0 +1,97 @@ +# Copyright © 2025, 2026 Oracle and/or its affiliates. +# +# This software is under the Apache License 2.0 +# (LICENSE-APACHE or http://www.apache.org/licenses/LICENSE-2.0) or Universal Permissive License +# (UPL) 1.0 (LICENSE-UPL or https://oss.oracle.com/licenses/upl), at your option. + +import json +from textwrap import shorten +from typing import List + +from wayflowcore._utils.formatting import format_tool_output_for_llm +from wayflowcore.messagelist import Message, TextContent +from wayflowcore.transforms import MessageTransform + +_MAX_CHAR_TOOL_RESULT_HEADER = 140 +"""Max number of characters in the message header when formatting a Tool Result""" + + +class ToolRequestAndCallsTransform(MessageTransform): + def __call__(self, messages: List["Message"]) -> List["Message"]: + """ + Format tool requests as agent messages and tool results as user messages so the + conversation remains a simple user/agent sequence. + """ + from wayflowcore import Message, MessageType + + tool_request_by_id = { + tool_request.tool_request_id: tool_request + for msg in messages + if msg.message_type is MessageType.TOOL_REQUEST and msg.tool_requests + for tool_request in msg.tool_requests + } + + formatted_messages = [] + for message in messages: + if message.message_type == MessageType.TOOL_RESULT: + if not message.tool_result: + raise ValueError(f"TOOL_RESULT message must contain tool_result: {message}") + tool_request_id = message.tool_result.tool_request_id + tool_request = tool_request_by_id.get(tool_request_id) + if not tool_request: + raise ValueError( + f"Could not find matching ToolRequest for TOOL_RESULT with id: {tool_request_id}" + ) + + message_header_tool_info = shorten( + f"name={tool_request.name}, parameters={tool_request.args}", + width=_MAX_CHAR_TOOL_RESULT_HEADER, + placeholder=" ...}", + ) + formatted_messages.append( + Message( + content=( + f"--- TOOL RESULT: {message_header_tool_info} ---\n" + f"{format_tool_output_for_llm(message.tool_result.content)}" + ), + message_type=MessageType.USER, + ) + ) + elif message.message_type == MessageType.TOOL_REQUEST: + if not message.tool_requests: + raise ValueError( + "Message is of type TOOL_REQUEST but has no tool_requests. This should be reported." + ) + + formatted_tool_calls = "\n".join( + json.dumps({"name": tool_request.name, "parameters": tool_request.args}) + for tool_request in message.tool_requests + ) + + header = f"--- MESSAGE: From: {message.sender} ---\n" + content = ( + message.content + if message.content.startswith(header) + else f"{header}{message.content}" + ) + formatted_messages.append( + Message( + content=( + f"{content}\n{formatted_tool_calls}" + if formatted_tool_calls not in content + else f"{content}" + ), + message_type=MessageType.AGENT, + ) + ) + elif message.message_type == MessageType.SYSTEM: + formatted_messages.append(message) + else: + message_copy = message.copy() + if message_copy.role == "user" and not message_copy.sender: + message_copy.sender = "HUMAN USER" + message_copy.contents.insert( + 0, TextContent(f"--- MESSAGE: From: {message_copy.sender} ---\n") + ) + formatted_messages.append(message_copy) + return formatted_messages diff --git a/wayflowcore/tests/agentspec/test_agentspec_conversion_coverage.py b/wayflowcore/tests/agentspec/test_agentspec_conversion_coverage.py index fe5934610..15e1eac3c 100644 --- a/wayflowcore/tests/agentspec/test_agentspec_conversion_coverage.py +++ b/wayflowcore/tests/agentspec/test_agentspec_conversion_coverage.py @@ -119,7 +119,6 @@ "DoNothingStep", # internal transforms not supported in AgentSpec "_PythonMergeToolRequestAndCallsTransform", - "_ToolRequestAndCallsTransform", # TODO: Support these in the future "ToolFromToolBox", # Requires search config to be set up along with an embedding model "DatastoreQueryStep", # requires a relational datastore, we only have oracledb and it requires a connection to create the object diff --git a/wayflowcore/tests/agentspec/test_tracing.py b/wayflowcore/tests/agentspec/test_tracing.py index f265b7bae..cec20c222 100644 --- a/wayflowcore/tests/agentspec/test_tracing.py +++ b/wayflowcore/tests/agentspec/test_tracing.py @@ -4,7 +4,7 @@ # (LICENSE-APACHE or http://www.apache.org/licenses/LICENSE-2.0) or Universal Permissive License # (UPL) 1.0 (LICENSE-UPL or https://oss.oracle.com/licenses/upl), at your option. import asyncio -from typing import List, Tuple, cast +from typing import Annotated, List, Tuple, cast import pytest from pyagentspec.agent import Agent as AgentSpecAgent @@ -38,15 +38,22 @@ from pyagentspec.tracing.trace import Trace as AgentSpecTrace from wayflowcore import Agent, Flow -from wayflowcore.agentspec import AgentSpecLoader -from wayflowcore.agentspec.tracing import AgentSpecEventListener +from wayflowcore._utils._templating_helpers import render_template +from wayflowcore.agentspec import AgentSpecExporter, AgentSpecLoader +from wayflowcore.agentspec.tracing import AgentSpecEventListener, dump_tracing_model from wayflowcore.events.eventlistener import register_event_listeners from wayflowcore.executors.executionstatus import ( ExecutionStatus, FinishedStatus, UserMessageRequestStatus, ) +from wayflowcore.managerworkers import ManagerWorkers +from wayflowcore.messagelist import Message, MessageType +from wayflowcore.serialization import deserialize, serialize +from wayflowcore.steps import CompleteStep, ConstantValuesStep, InputMessageStep, StartStep +from wayflowcore.tools import tool +from ..conftest import mock_llm from ..testhelpers.patching import patch_llm @@ -283,6 +290,131 @@ def test_agentspec_flow_raises_correct_events(wayflow_flow: Flow): assert all(e.request_id == response_event.request_id for e in llm_events) +def test_agentspec_flow_end_uses_last_executed_step_when_flow_ends_on_none_transition(): + flow = Flow.from_steps( + [ + StartStep(name="start"), + ConstantValuesStep(name="finish", constant_values={}), + ] + ) + + listener = AgentSpecEventListener() + span_processor = DummyAgentSpecSpanProcessor() + + with AgentSpecTrace(span_processors=[span_processor]): + with register_event_listeners([listener]): + status = flow.start_conversation().execute() + + assert isinstance(status, FinishedStatus) + assert status.complete_step_name is None + assert status._final_step_name == "finish" + + flow_span = next(s for s in span_processor.starts if isinstance(s, AgentSpecFlowExecutionSpan)) + end_event = next(e for e in flow_span.events if isinstance(e, AgentSpecFlowExecutionEnd)) + assert end_event.branch_selected == "finish" + + +def test_finished_status_serialization_round_trip_keeps_final_step_name(): + status = FinishedStatus( + output_values={"value": "ok"}, + complete_step_name=None, + _final_step_name="finish", + _conversation_id="conv-id", + ) + + round_tripped_status = deserialize(FinishedStatus, serialize(status)) + + assert isinstance(round_tripped_status, FinishedStatus) + assert round_tripped_status.complete_step_name is None + assert round_tripped_status._final_step_name == "finish" + assert round_tripped_status.output_values == {"value": "ok"} + + +def test_finished_status_deserialization_supports_legacy_final_step_name(): + status = FinishedStatus( + output_values={"value": "ok"}, + complete_step_name=None, + _final_step_name="finish", + _conversation_id="conv-id", + ) + + legacy_serialized_status = serialize(status).replace("_final_step_name", "final_step_name") + + round_tripped_status = deserialize(FinishedStatus, legacy_serialized_status) + + assert isinstance(round_tripped_status, FinishedStatus) + assert round_tripped_status._final_step_name == "finish" + + +def test_dump_tracing_model_supports_plugin_components(): + flow = Flow.from_steps( + [ + StartStep(name="start"), + InputMessageStep(name="ask_name", message_template="What is your name?"), + CompleteStep(name="end"), + ] + ) + + event = AgentSpecFlowExecutionStart(flow=AgentSpecExporter().to_component(flow), inputs={}) + + with pytest.raises(Exception, match="PluginInputMessageNode"): + event.model_dump() + + dumped_event = dump_tracing_model(event) + assert dumped_event["flow"]["nodes"][1]["component_type"] == "PluginInputMessageNode" + + listener = AgentSpecEventListener() + dumped_with_listener = listener.dump_tracing_model(event) + assert dumped_with_listener["flow"]["nodes"][1]["component_type"] == "PluginInputMessageNode" + + +def test_agentspec_tracing_supports_managerworkers_template_transform(): + llm = mock_llm() + + @tool + def say_hello(user_name: Annotated[str, "Name of the user"]) -> str: + """Return a greeting.""" + return f"Hello {user_name}!" + + manager_agent = Agent(name="manager", description="manager agent", tools=[say_hello], llm=llm) + worker = Agent(name="worker", description="worker", llm=llm) + group = ManagerWorkers(workers=[worker], group_manager=manager_agent) + + listener = AgentSpecEventListener() + span_processor = DummyAgentSpecSpanProcessor() + conversation = group.start_conversation() + conversation.append_user_message("Dummy") + + with patch_llm( + llm, + outputs=[ + Message( + render_template( + """ +{{thoughts}} + +{"name": {{tool_name}}, "parameters": {{tool_params}}} +""".strip(), + inputs=dict( + thoughts="", + tool_name="say_hello", + tool_params={"user_name": "Iris"}, + ), + ), + message_type=MessageType.AGENT, + ), + "Dummy", + ], + patch_internal=True, + ): + with AgentSpecTrace(span_processors=[span_processor]): + with register_event_listeners([listener]): + status = conversation.execute() + + assert isinstance(status, UserMessageRequestStatus) + assert any(isinstance(span, AgentSpecAgentExecutionSpan) for span in span_processor.starts) + + def test_agentspec_flow_async_raises_correct_events(wayflow_flow: Flow): # Retrieve the wayflow LLM step to patch its LLM calls diff --git a/wayflowcore/tests/agentspec/test_transforms.py b/wayflowcore/tests/agentspec/test_transforms.py index e2d0c7382..f72c5fd73 100644 --- a/wayflowcore/tests/agentspec/test_transforms.py +++ b/wayflowcore/tests/agentspec/test_transforms.py @@ -13,6 +13,7 @@ import pytest from pyagentspec.datastores.datastore import InMemoryCollectionDatastore from pyagentspec.llms import VllmConfig +from pyagentspec.serialization import AgentSpecSerializer from pyagentspec.transforms import ( ConversationSummarizationTransform as AgentSpecConversationSummarizationTransform, ) @@ -21,9 +22,18 @@ ) from wayflowcore.agentspec.agentspecexporter import AgentSpecExporter +from wayflowcore.agentspec.components.transforms import ( + PluginToolRequestAndCallsTransform as AgentSpecToolRequestAndCallsTransform, +) +from wayflowcore.agentspec.components.transforms import ( + messagetransform_serialization_plugin, +) from wayflowcore.agentspec.runtimeloader import AgentSpecLoader from wayflowcore.datastore import InMemoryDatastore from wayflowcore.datastore.inmemory import _INMEMORY_USER_WARNING +from wayflowcore.templates.agenticpatterntemplate import ( + ToolRequestAndCallsTransform as WayflowToolRequestAndCallsTransform, +) from wayflowcore.transforms import ( ConversationSummarizationTransform as WayflowConversationSummarizationTransform, ) @@ -218,6 +228,51 @@ def test_agentspec_summarization_conversation_transform_can_be_converted_to_wayf ) +def test_wayflow_tool_request_and_calls_transform_can_be_converted_to_agentspec(): + converted_transform = AgentSpecExporter().to_component(WayflowToolRequestAndCallsTransform()) + + assert type(converted_transform) is AgentSpecToolRequestAndCallsTransform + assert converted_transform.name == "toolrequestandcalls_messagetransform" + + +def test_agentspec_tool_request_and_calls_transform_can_be_converted_to_wayflow(): + converted_transform = AgentSpecLoader().load_component( + AgentSpecToolRequestAndCallsTransform(name="toolrequestandcalls_messagetransform") + ) + + assert type(converted_transform) is WayflowToolRequestAndCallsTransform + + +def test_legacy_agentspec_swarm_tool_request_and_calls_transform_is_upgraded(): + serialized_transform = AgentSpecSerializer( + plugins=[messagetransform_serialization_plugin] + ).to_json(AgentSpecToolRequestAndCallsTransform(name="toolrequestandcalls_messagetransform")) + serialized_transform = serialized_transform.replace( + "PluginToolRequestAndCallsTransform", + "PluginSwarmToolRequestAndCallsTransform", + 1, + ) + + converted_transform = AgentSpecLoader().load_json(serialized_transform) + + assert type(converted_transform) is WayflowToolRequestAndCallsTransform + + +def test_legacy_agentspec_swarm_tool_request_and_calls_transform_is_upgraded_from_yaml(): + serialized_transform = AgentSpecSerializer( + plugins=[messagetransform_serialization_plugin] + ).to_yaml(AgentSpecToolRequestAndCallsTransform(name="toolrequestandcalls_messagetransform")) + serialized_transform = serialized_transform.replace( + "PluginToolRequestAndCallsTransform", + "PluginSwarmToolRequestAndCallsTransform", + 1, + ) + + converted_transform = AgentSpecLoader().load_yaml(serialized_transform) + + assert type(converted_transform) is WayflowToolRequestAndCallsTransform + + def assert_message_summarization_transforms_are_equal(converted_transform, expected_transform): converted_datastore = ( converted_transform.datastore diff --git a/wayflowcore/tests/serialization/test_managerworkers_serialization.py b/wayflowcore/tests/serialization/test_managerworkers_serialization.py index de5e20938..323f0c4c3 100644 --- a/wayflowcore/tests/serialization/test_managerworkers_serialization.py +++ b/wayflowcore/tests/serialization/test_managerworkers_serialization.py @@ -7,6 +7,7 @@ from copy import deepcopy import pytest +import yaml from wayflowcore.agent import Agent from wayflowcore.executors._flowconversation import FlowConversation @@ -21,6 +22,7 @@ from wayflowcore.models.llmmodelfactory import LlmModelFactory from wayflowcore.serialization import deserialize, serialize, serialize_to_dict from wayflowcore.steps.agentexecutionstep import AgentExecutionStep +from wayflowcore.templates.agenticpatterntemplate import ToolRequestAndCallsTransform from wayflowcore.tools import ToolRequest from wayflowcore.transforms import RemoveEmptyNonUserMessageTransform @@ -224,6 +226,28 @@ def test_can_continue_a_deserialized_conversation(simple_managerworkers: Manager deser_conv.execute() +def test_legacy_managerworkers_transform_is_upgraded_on_deserialization( + simple_managerworkers: ManagerWorkers, +): + serialized_managerworkers = serialize_to_dict(simple_managerworkers) + managerworkers_template_ref = serialized_managerworkers["managerworkers_template"]["$ref"] + serialized_managerworkers["_referenced_objects"][managerworkers_template_ref][ + "post_rendering_transforms" + ][0]["_component_type"] = "_ToolRequestAndCallsTransform" + + deserialized_managerworkers = deserialize( + ManagerWorkers, yaml.safe_dump(serialized_managerworkers, sort_keys=False) + ) + + assert isinstance( + deserialized_managerworkers.managerworkers_template.post_rendering_transforms[0], + ToolRequestAndCallsTransform, + ) + reserialized_managerworkers = serialize(deserialized_managerworkers) + assert "_component_type: ToolRequestAndCallsTransform" in reserialized_managerworkers + assert "_component_type: _ToolRequestAndCallsTransform" not in reserialized_managerworkers + + def test_deserialized_conversation_does_not_duplicate_internal_tool_results() -> None: manager_llm = LlmModelFactory.from_config(deepcopy(VLLM_MODEL_CONFIG)) addition_llm = LlmModelFactory.from_config(deepcopy(GEMMA_CONFIG)) diff --git a/wayflowcore/tests/serialization/test_serializableobject.py b/wayflowcore/tests/serialization/test_serializableobject.py index bc5e2d60c..9f48ebba5 100644 --- a/wayflowcore/tests/serialization/test_serializableobject.py +++ b/wayflowcore/tests/serialization/test_serializableobject.py @@ -191,6 +191,7 @@ "Tool", "ToolBox", "ToolContextProvider", + "ToolRequestAndCallsTransform", "ToolExecutionConfirmationStatus", "ToolExecutionStep", "ToolOutputParser", @@ -213,15 +214,13 @@ "_TokenConsumptionEvent", "VectorConfig", "VectorRetrieverConfig", - "_ToolRequestAndCallsTransform", } def test_componentregistry_is_complete(tmp_path): # We need to run this in a separate script to avoid that creating classes in tests poison the registry of components all_classes_str = "{" + ", ".join(f'"{c}"' for c in ALL_SERIALIZABLE_CLASSES) + "}" - script = dedent( - f""" + script = dedent(f""" from wayflowcore.serialization.serializer import SerializableObject, _import_all_submodules _import_all_submodules("wayflowcore") @@ -242,8 +241,7 @@ def test_componentregistry_is_complete(tmp_path): lines.append(f" Extra ({{len(extra)}}):") lines.extend(f" + {{name}}" for name in extra) raise AssertionError("\\n".join(lines)) - """ - ) + """) testfile = tmp_path / "_temp_test_componentregistry_is_complete.py" with open(testfile, "w") as f: f.write(script) @@ -255,8 +253,7 @@ def test_componentregistry_is_complete(tmp_path): def test_all_components_are_builtin_components(tmp_path): # We need to run this in a separate script to avoid that creating classes in tests poison the registry of components - script = dedent( - """ + script = dedent(""" from wayflowcore.serialization._builtins_components import _BUILTIN_COMPONENTS from wayflowcore.serialization.serializer import SerializableObject, _import_all_submodules @@ -264,8 +261,7 @@ def test_all_components_are_builtin_components(tmp_path): component_registry = SerializableObject._COMPONENT_REGISTRY assert set(_BUILTIN_COMPONENTS) == set(component_registry) - """ - ) + """) testfile = tmp_path / "_temp_test_all_components_are_builtin_components.py" with open(testfile, "w") as f: f.write(script) diff --git a/wayflowcore/tests/serialization/test_swarm_serialization.py b/wayflowcore/tests/serialization/test_swarm_serialization.py index 54b511cae..dc1b050fc 100644 --- a/wayflowcore/tests/serialization/test_swarm_serialization.py +++ b/wayflowcore/tests/serialization/test_swarm_serialization.py @@ -17,6 +17,8 @@ ) from wayflowcore.serialization import deserialize, serialize, serialize_to_dict from wayflowcore.swarm import Swarm +from wayflowcore.templates._swarmtemplate import _DEFAULT_SWARM_CHAT_TEMPLATE +from wayflowcore.templates.agenticpatterntemplate import ToolRequestAndCallsTransform from wayflowcore.transforms import RemoveEmptyNonUserMessageTransform from ..conftest import _assert_config_are_equal @@ -245,3 +247,13 @@ def test_can_continue_a_deserialized_swarm_conversation(simple_swarm: Swarm) -> assert len(deser_conv.get_messages()) == conv_length_before_serialization deser_conv.append_user_message("Actually it's better now") deser_conv.execute() + + +def test_legacy_swarm_transform_keeps_swarm_deserialization_behavior(): + deserialized_template = deserialize( + type(_DEFAULT_SWARM_CHAT_TEMPLATE), serialize(_DEFAULT_SWARM_CHAT_TEMPLATE) + ) + + assert isinstance( + deserialized_template.post_rendering_transforms[0], ToolRequestAndCallsTransform + )