diff --git a/wayflowcore/src/wayflowcore/tools/toolhelpers.py b/wayflowcore/src/wayflowcore/tools/toolhelpers.py index 2591f98ab..54558c4d9 100644 --- a/wayflowcore/src/wayflowcore/tools/toolhelpers.py +++ b/wayflowcore/src/wayflowcore/tools/toolhelpers.py @@ -22,6 +22,7 @@ Union, get_args, get_origin, + overload, ) from wayflowcore.property import JsonSchemaParam, Property @@ -79,7 +80,9 @@ def _get_partial_schema_from_annotation(arg_type: Type[Any]) -> JsonSchemaParam: # Handle Dict[str, X] if origin is dict or origin is Dict: if len(args) != 2: - raise TypeError("Dict must have exactly two type arguments, e.g., Dict[str, int]") + raise TypeError( + "Dict must have exactly two type arguments, e.g., Dict[str, int]" + ) key_type, value_type = args if key_type is not str: raise TypeError("JSON object keys must be strings") @@ -104,7 +107,10 @@ def _get_partial_schema_from_annotation(arg_type: Type[Any]) -> JsonSchemaParam: if len(unique_json_types) > 1: return { "anyOf": [ - {"type": t, "enum": [v for v in args if PRIMITIVE_TYPE_MAP[type(v)] == t]} + { + "type": t, + "enum": [v for v in args if PRIMITIVE_TYPE_MAP[type(v)] == t], + } for t in unique_json_types ] } @@ -141,7 +147,7 @@ def _get_tool_schema_no_parsing( if "self" in tool_signature.parameters.keys(): raise TypeError( - f"The tool decorator cannot be used directly on a class method, use `tool(my_object.my_method)` instead" + "The tool decorator cannot be used directly on a class method, use `tool(my_object.my_method)` instead" ) # Determining the schema of input parameters @@ -172,7 +178,7 @@ def _get_tool_schema_no_parsing( raise TypeError(f"Return annotation is not specified for tool {tool_name}") if _is_annotated_type(output_annotation): raise TypeError( - f"Annotated types are not permitted when using the description mode `only_docstring`. " + "Annotated types are not permitted when using the description mode `only_docstring`. " f"Return annotation of tool {tool_name} has type `{output_annotation}`" ) output_schema = _get_partial_schema_from_annotation(output_annotation) @@ -187,7 +193,7 @@ def _get_tool_schema_from_parsed_signature( if "self" in tool_signature.parameters.keys(): raise TypeError( - f"The tool decorator cannot be used directly on a class method, use `tool(my_object.my_method)` instead" + "The tool decorator cannot be used directly on a class method, use `tool(my_object.my_method)` instead" ) # Determining the schema of input parameters @@ -205,7 +211,9 @@ def _get_tool_schema_from_parsed_signature( f"Description mode is `infer_from_signature` but parameter {param_name} of tool {tool_name} is not Annotated " f"(has type {annotated_param.annotation}). Either annotate the parameter or use the `only_docstring` description mode." ) - param_annotation, param_description = _unpack_annotated_types(annotated_param.annotation) + param_annotation, param_description = _unpack_annotated_types( + annotated_param.annotation + ) param_schema = _get_partial_schema_from_annotation(param_annotation) param_schema["description"] = param_description @@ -220,7 +228,9 @@ def _get_tool_schema_from_parsed_signature( raise TypeError(f"Return annotation is not specified for tool {tool_name}") output_annotation: Any if _is_annotated_type(annotated_output_type): - output_annotation, output_description = _unpack_annotated_types(annotated_output_type) + output_annotation, output_description = _unpack_annotated_types( + annotated_output_type + ) else: output_annotation, output_description = annotated_output_type, "" @@ -231,8 +241,55 @@ def _get_tool_schema_from_parsed_signature( return args_schema, output_schema +@overload +def tool( + func_or_name: str, + func: Callable[..., Any], + /, + description_mode: Literal[ + DescriptionMode.INFER_FROM_SIGNATURE, + DescriptionMode.ONLY_DOCSTRING, + DescriptionMode.EXTRACT_FROM_DOCSTRING, + ] = DescriptionMode.INFER_FROM_SIGNATURE, + output_descriptors: Optional[List[Property]] = None, + requires_confirmation: bool = False, +) -> ServerTool: ... + + +@overload +def tool( + func_or_name: Callable[..., Any], + func: None = None, + /, + description_mode: Literal[ + DescriptionMode.INFER_FROM_SIGNATURE, + DescriptionMode.ONLY_DOCSTRING, + DescriptionMode.EXTRACT_FROM_DOCSTRING, + ] = DescriptionMode.INFER_FROM_SIGNATURE, + output_descriptors: Optional[List[Property]] = None, + requires_confirmation: bool = False, +) -> ServerTool: ... + + +@overload +def tool( + func_or_name: str, + func: None = None, + /, + description_mode: Literal[ + DescriptionMode.INFER_FROM_SIGNATURE, + DescriptionMode.ONLY_DOCSTRING, + DescriptionMode.EXTRACT_FROM_DOCSTRING, + ] = DescriptionMode.INFER_FROM_SIGNATURE, + output_descriptors: Optional[List[Property]] = None, + requires_confirmation: bool = False, +) -> Callable[[Callable[..., Any]], ServerTool]: ... + + def tool( - *args: Union[str, Callable[..., Any]], + func_or_name: Callable[..., Any] | str, + func: Callable[..., Any] | None = None, + /, description_mode: Literal[ DescriptionMode.INFER_FROM_SIGNATURE, DescriptionMode.ONLY_DOCSTRING, @@ -390,61 +447,58 @@ def _make_tool( requires_confirmation=requires_confirmation, ) - # When used as a wrapper, `args` can be [tool_name, callable] or [callable] - # When used as a decorator, `args` can be [tool_name, callable] or [callable] - if len(args) == 2 and (isinstance(args[0], str) and callable(args[1])): + # When used as a wrapper, the function arguments can be [tool_name, callable] or [callable] + # When used as a decorator, the decorator arguments can be [tool_name, callable] or [callable] + if func is not None and isinstance(func_or_name, str) and callable(func): # Example case: wrapper with custom tool name # def my_callable(): # pass # my_tool = tool("my_callable1", my_callable) - # here args[0] is the tool name, and args[1] the callable + # here func_or_name is the tool name, and func the callable # we simply return the newly created ServerTool - tool_name = args[0] + tool_name = func_or_name return _make_tool( - args[1], tool_name, description_mode, output_descriptors, requires_confirmation + func, + tool_name, + description_mode, + output_descriptors, + requires_confirmation, ) - elif len(args) == 1 and isinstance(args[0], str): + elif isinstance(func_or_name, str): # Example case: decorator with custom tool name # @tool("my_callable1") # def my_callable(): # pass - # here args[0] is the tool name + # here func_or_name is the tool name # Upon instantiation, first the `tool` function is called, directly followed # by the `_partial_with_name` function being called, thus converting the # callable to a ServerTool - tool_name = args[0] + tool_name = func_or_name def _partial_with_name(func: Callable[..., Any]) -> ServerTool: return _make_tool( - func, tool_name, description_mode, output_descriptors, requires_confirmation + func, + tool_name, + description_mode, + output_descriptors, + requires_confirmation, ) return _partial_with_name - elif len(args) == 1 and callable(args[0]): + elif callable(func_or_name): # Example case: wrapper # def my_callable(): # pass # my_tool = tool(my_callable) - # here args[0] is the callable + # here func_or_name is the callable # we simply return the newly created ServerTool return _make_tool( - args[0], None, description_mode, output_descriptors, requires_confirmation + func_or_name, + None, + description_mode, + output_descriptors, + requires_confirmation, ) - elif len(args) == 0: - # Example case: decorator with user-specified description_mode - # @tool(description_mode='only_docstring') - # def my_callable(param1: int = 2) -> int: - # """Callable description""" - # return 0 - # Upon instantiation, first the `tool` function is called, directly followed - # by the `_partial_no_name` function being called, thus converting the - # callable to a ServerTool - def _partial_no_name(func: Callable[..., Any]) -> ServerTool: - return _make_tool( - func, None, description_mode, output_descriptors, requires_confirmation - ) - - return _partial_no_name else: raise ValueError("Invalid usage of the `tool` helper") @@ -477,7 +531,9 @@ def _to_react_template_dict(tool: Tool) -> Dict[str, str]: f"- {parameter_name}: {_find_json_schema_full_type(parameter_info)}" ) if "default" in parameter_info: - param_description += f" (Optional, default={parameter_info['default']})" + param_description += ( + f" (Optional, default={parameter_info['default']})" + ) else: param_description += " (Required)" if "description" in parameter_info: @@ -485,7 +541,9 @@ def _to_react_template_dict(tool: Tool) -> Dict[str, str]: parameter_descriptions.append(param_description) formatted_parameter_descriptions = "\n".join(parameter_descriptions) - description = tool_as_str + f"Parameters:\n{formatted_parameter_descriptions}" + description = ( + tool_as_str + f"Parameters:\n{formatted_parameter_descriptions}" + ) return { "name": tool.name, "description": description, diff --git a/wayflowcore/tests/conftest.py b/wayflowcore/tests/conftest.py index 9550f355d..8cf7ab51a 100644 --- a/wayflowcore/tests/conftest.py +++ b/wayflowcore/tests/conftest.py @@ -98,7 +98,7 @@ compartment_id = os.environ.get("COMPARTMENT_ID") if not compartment_id: - raise Exception("compartment_id is not set in the environment") + raise Exception("COMPARTMENT_ID is not set in the environment") oracle_http_proxy = os.environ.get("ORACLE_HTTP_PROXY") diff --git a/wayflowcore/tests/tools/test_tool_decorator.py b/wayflowcore/tests/tools/test_tool_decorator.py new file mode 100644 index 000000000..f8c76de0f --- /dev/null +++ b/wayflowcore/tests/tools/test_tool_decorator.py @@ -0,0 +1,53 @@ +from typing import Any, Callable +from typing_extensions import assert_type +from wayflowcore.tools.servertools import ServerTool +from wayflowcore.tools.toolhelpers import DescriptionMode, tool + + +def test_tool_decorator_result_type_is_correct() -> None: + # Wrapper with different name + tool_one = tool("tool_one") + assert_type(tool_one, Callable[[Callable[..., Any]], ServerTool]) + assert isinstance(tool_one, Callable) + + def actual_func() -> None: + """Actual func""" + + func_tool = tool_one(actual_func) + assert_type(func_tool, ServerTool) + assert isinstance(func_tool, ServerTool) + assert func_tool.name == "tool_one" + + # Decorator with different tool name + + @tool("real_function_name") + def another_func() -> None: + """Another func""" + + assert_type(another_func, ServerTool) + assert isinstance(another_func, ServerTool) + assert another_func.name == "real_function_name" + + # Wrapper with name and function passed as arguments + def func_two() -> None: + """Just a func""" + + tool_two = tool("tool_two", func_two) + assert_type(tool_two, ServerTool) + assert isinstance(tool_two, ServerTool) + + # Decorator with description mode + @tool("tool_three", description_mode=DescriptionMode.ONLY_DOCSTRING) + def tool_three() -> None: + """tool_three function""" + + assert_type(tool_three, ServerTool) + assert isinstance(tool_three, ServerTool) + + # Decorator with no arguments passed + @tool + def tool_four() -> None: + """tool_four function""" + + assert_type(tool_four, ServerTool) + assert isinstance(tool_four, ServerTool)