Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
134 changes: 96 additions & 38 deletions wayflowcore/src/wayflowcore/tools/toolhelpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
Union,
get_args,
get_origin,
overload,
)

from wayflowcore.property import JsonSchemaParam, Property
Expand Down Expand Up @@ -79,7 +80,9 @@ def _get_partial_schema_from_annotation(arg_type: Type[Any]) -> JsonSchemaParam:
# Handle Dict[str, X]
if origin is dict or origin is Dict:
if len(args) != 2:
raise TypeError("Dict must have exactly two type arguments, e.g., Dict[str, int]")
raise TypeError(
"Dict must have exactly two type arguments, e.g., Dict[str, int]"
)
key_type, value_type = args
if key_type is not str:
raise TypeError("JSON object keys must be strings")
Expand All @@ -104,7 +107,10 @@ def _get_partial_schema_from_annotation(arg_type: Type[Any]) -> JsonSchemaParam:
if len(unique_json_types) > 1:
return {
"anyOf": [
{"type": t, "enum": [v for v in args if PRIMITIVE_TYPE_MAP[type(v)] == t]}
{
"type": t,
"enum": [v for v in args if PRIMITIVE_TYPE_MAP[type(v)] == t],
}
for t in unique_json_types
]
}
Expand Down Expand Up @@ -141,7 +147,7 @@ def _get_tool_schema_no_parsing(

if "self" in tool_signature.parameters.keys():
raise TypeError(
f"The tool decorator cannot be used directly on a class method, use `tool(my_object.my_method)` instead"
"The tool decorator cannot be used directly on a class method, use `tool(my_object.my_method)` instead"
)

# Determining the schema of input parameters
Expand Down Expand Up @@ -172,7 +178,7 @@ def _get_tool_schema_no_parsing(
raise TypeError(f"Return annotation is not specified for tool {tool_name}")
if _is_annotated_type(output_annotation):
raise TypeError(
f"Annotated types are not permitted when using the description mode `only_docstring`. "
"Annotated types are not permitted when using the description mode `only_docstring`. "
f"Return annotation of tool {tool_name} has type `{output_annotation}`"
)
output_schema = _get_partial_schema_from_annotation(output_annotation)
Expand All @@ -187,7 +193,7 @@ def _get_tool_schema_from_parsed_signature(

if "self" in tool_signature.parameters.keys():
raise TypeError(
f"The tool decorator cannot be used directly on a class method, use `tool(my_object.my_method)` instead"
"The tool decorator cannot be used directly on a class method, use `tool(my_object.my_method)` instead"
)

# Determining the schema of input parameters
Expand All @@ -205,7 +211,9 @@ def _get_tool_schema_from_parsed_signature(
f"Description mode is `infer_from_signature` but parameter {param_name} of tool {tool_name} is not Annotated "
f"(has type {annotated_param.annotation}). Either annotate the parameter or use the `only_docstring` description mode."
)
param_annotation, param_description = _unpack_annotated_types(annotated_param.annotation)
param_annotation, param_description = _unpack_annotated_types(
annotated_param.annotation
)
param_schema = _get_partial_schema_from_annotation(param_annotation)

param_schema["description"] = param_description
Expand All @@ -220,7 +228,9 @@ def _get_tool_schema_from_parsed_signature(
raise TypeError(f"Return annotation is not specified for tool {tool_name}")
output_annotation: Any
if _is_annotated_type(annotated_output_type):
output_annotation, output_description = _unpack_annotated_types(annotated_output_type)
output_annotation, output_description = _unpack_annotated_types(
annotated_output_type
)
else:
output_annotation, output_description = annotated_output_type, ""

Expand All @@ -231,8 +241,55 @@ def _get_tool_schema_from_parsed_signature(
return args_schema, output_schema


@overload
def tool(
func_or_name: str,
func: Callable[..., Any],
/,
description_mode: Literal[
DescriptionMode.INFER_FROM_SIGNATURE,
DescriptionMode.ONLY_DOCSTRING,
DescriptionMode.EXTRACT_FROM_DOCSTRING,
] = DescriptionMode.INFER_FROM_SIGNATURE,
output_descriptors: Optional[List[Property]] = None,
requires_confirmation: bool = False,
) -> ServerTool: ...


@overload
def tool(
func_or_name: Callable[..., Any],
func: None = None,
/,
description_mode: Literal[
DescriptionMode.INFER_FROM_SIGNATURE,
DescriptionMode.ONLY_DOCSTRING,
DescriptionMode.EXTRACT_FROM_DOCSTRING,
] = DescriptionMode.INFER_FROM_SIGNATURE,
output_descriptors: Optional[List[Property]] = None,
requires_confirmation: bool = False,
) -> ServerTool: ...


@overload
def tool(
func_or_name: str,
func: None = None,
/,
description_mode: Literal[
DescriptionMode.INFER_FROM_SIGNATURE,
DescriptionMode.ONLY_DOCSTRING,
DescriptionMode.EXTRACT_FROM_DOCSTRING,
] = DescriptionMode.INFER_FROM_SIGNATURE,
output_descriptors: Optional[List[Property]] = None,
requires_confirmation: bool = False,
) -> Callable[[Callable[..., Any]], ServerTool]: ...


def tool(
*args: Union[str, Callable[..., Any]],
func_or_name: Callable[..., Any] | str,
func: Callable[..., Any] | None = None,
/,
description_mode: Literal[
DescriptionMode.INFER_FROM_SIGNATURE,
DescriptionMode.ONLY_DOCSTRING,
Expand Down Expand Up @@ -390,61 +447,58 @@ def _make_tool(
requires_confirmation=requires_confirmation,
)

# When used as a wrapper, `args` can be [tool_name, callable] or [callable]
# When used as a decorator, `args` can be [tool_name, callable] or [callable]
if len(args) == 2 and (isinstance(args[0], str) and callable(args[1])):
# When used as a wrapper, the function arguments can be [tool_name, callable] or [callable]
# When used as a decorator, the decorator arguments can be [tool_name, callable] or [callable]
if func is not None and isinstance(func_or_name, str) and callable(func):
# Example case: wrapper with custom tool name
# def my_callable():
# pass
# my_tool = tool("my_callable1", my_callable)
# here args[0] is the tool name, and args[1] the callable
# here func_or_name is the tool name, and func the callable
# we simply return the newly created ServerTool
tool_name = args[0]
tool_name = func_or_name
return _make_tool(
args[1], tool_name, description_mode, output_descriptors, requires_confirmation
func,
tool_name,
description_mode,
output_descriptors,
requires_confirmation,
)
elif len(args) == 1 and isinstance(args[0], str):
elif isinstance(func_or_name, str):
# Example case: decorator with custom tool name
# @tool("my_callable1")
# def my_callable():
# pass
# here args[0] is the tool name
# here func_or_name is the tool name
# Upon instantiation, first the `tool` function is called, directly followed
# by the `_partial_with_name` function being called, thus converting the
# callable to a ServerTool
tool_name = args[0]
tool_name = func_or_name

def _partial_with_name(func: Callable[..., Any]) -> ServerTool:
return _make_tool(
func, tool_name, description_mode, output_descriptors, requires_confirmation
func,
tool_name,
description_mode,
output_descriptors,
requires_confirmation,
)

return _partial_with_name
elif len(args) == 1 and callable(args[0]):
elif callable(func_or_name):
# Example case: wrapper
# def my_callable():
# pass
# my_tool = tool(my_callable)
# here args[0] is the callable
# here func_or_name is the callable
# we simply return the newly created ServerTool
return _make_tool(
args[0], None, description_mode, output_descriptors, requires_confirmation
func_or_name,
None,
description_mode,
output_descriptors,
requires_confirmation,
)
elif len(args) == 0:
# Example case: decorator with user-specified description_mode
# @tool(description_mode='only_docstring')
# def my_callable(param1: int = 2) -> int:
# """Callable description"""
# return 0
# Upon instantiation, first the `tool` function is called, directly followed
# by the `_partial_no_name` function being called, thus converting the
# callable to a ServerTool
def _partial_no_name(func: Callable[..., Any]) -> ServerTool:
return _make_tool(
func, None, description_mode, output_descriptors, requires_confirmation
)

return _partial_no_name
else:
raise ValueError("Invalid usage of the `tool` helper")

Expand Down Expand Up @@ -477,15 +531,19 @@ def _to_react_template_dict(tool: Tool) -> Dict[str, str]:
f"- {parameter_name}: {_find_json_schema_full_type(parameter_info)}"
)
if "default" in parameter_info:
param_description += f" (Optional, default={parameter_info['default']})"
param_description += (
f" (Optional, default={parameter_info['default']})"
)
else:
param_description += " (Required)"
if "description" in parameter_info:
param_description += f" {parameter_info['description']}"
parameter_descriptions.append(param_description)

formatted_parameter_descriptions = "\n".join(parameter_descriptions)
description = tool_as_str + f"Parameters:\n{formatted_parameter_descriptions}"
description = (
tool_as_str + f"Parameters:\n{formatted_parameter_descriptions}"
)
return {
"name": tool.name,
"description": description,
Expand Down
2 changes: 1 addition & 1 deletion wayflowcore/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@

compartment_id = os.environ.get("COMPARTMENT_ID")
if not compartment_id:
raise Exception("compartment_id is not set in the environment")
raise Exception("COMPARTMENT_ID is not set in the environment")

oracle_http_proxy = os.environ.get("ORACLE_HTTP_PROXY")

Expand Down
53 changes: 53 additions & 0 deletions wayflowcore/tests/tools/test_tool_decorator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
from typing import Any, Callable
from typing_extensions import assert_type
from wayflowcore.tools.servertools import ServerTool
from wayflowcore.tools.toolhelpers import DescriptionMode, tool


def test_tool_decorator_result_type_is_correct() -> None:
# Wrapper with different name
tool_one = tool("tool_one")
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

hmmm, I am not sure to understand this use case.
What is supposed to happen in this case?
The tool is not invocable really is it?

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm preserving whatever semantics already existed, the general use-case is to use it to rename a tool. I added more comments to the tests to clarify things.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ok

assert_type(tool_one, Callable[[Callable[..., Any]], ServerTool])
assert isinstance(tool_one, Callable)

def actual_func() -> None:
"""Actual func"""

func_tool = tool_one(actual_func)
assert_type(func_tool, ServerTool)
assert isinstance(func_tool, ServerTool)
assert func_tool.name == "tool_one"

# Decorator with different tool name

@tool("real_function_name")
def another_func() -> None:
"""Another func"""

assert_type(another_func, ServerTool)
assert isinstance(another_func, ServerTool)
assert another_func.name == "real_function_name"

# Wrapper with name and function passed as arguments
def func_two() -> None:
"""Just a func"""

tool_two = tool("tool_two", func_two)
assert_type(tool_two, ServerTool)
assert isinstance(tool_two, ServerTool)

# Decorator with description mode
@tool("tool_three", description_mode=DescriptionMode.ONLY_DOCSTRING)
def tool_three() -> None:
"""tool_three function"""

assert_type(tool_three, ServerTool)
assert isinstance(tool_three, ServerTool)

# Decorator with no arguments passed
@tool
def tool_four() -> None:
"""tool_four function"""

assert_type(tool_four, ServerTool)
assert isinstance(tool_four, ServerTool)