Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
# (UPL) 1.0 (LICENSE-UPL or https://oss.oracle.com/licenses/upl), at your option.
import json
import logging
import re
from typing import Any, AsyncIterable, Callable, Dict, List, Optional, TypedDict

from wayflowcore._utils.formatting import format_tool_output_for_llm
Expand All @@ -27,6 +28,10 @@

logger = logging.getLogger(__name__)

_LM_STUDIO_TEXT_TOOL_CALL_OPEN_MARKER = "<|tool_call>"
_LM_STUDIO_TEXT_TOOL_CALL_CLOSE_MARKER = "<tool_call|>"
_LM_STUDIO_TEXT_TOOL_CALL_NAME_RE = re.compile(r"[A-Za-z0-9_.-]+")


class OpenAIToolRequestAsDictT(TypedDict, total=True):
tool_request_id: str
Expand All @@ -36,6 +41,71 @@ class OpenAIToolRequestAsDictT(TypedDict, total=True):


class _ChatCompletionsAPIProcessor(_APIProcessor):
@staticmethod
def _parse_text_encoded_tool_calls_fallback(
content: Any, extra_content: Optional[ExtraContentT] = None
) -> Optional[List[ToolRequest]]:
try:
return _ChatCompletionsAPIProcessor._parse_lm_studio_text_encoded_tool_calls_fallback(
content, extra_content=extra_content
)
except (AttributeError, TypeError, ValueError, json.JSONDecodeError):
logger.debug("Failed to parse text-encoded tool call fallback.", exc_info=True)
return None

@staticmethod
def _parse_lm_studio_text_encoded_tool_calls_fallback(
content: Any, extra_content: Optional[ExtraContentT] = None
) -> Optional[List[ToolRequest]]:
if not isinstance(content, str) or _LM_STUDIO_TEXT_TOOL_CALL_OPEN_MARKER not in content:
return None

remaining_content = content.strip()
parsed_tool_requests: List[ToolRequest] = []
json_decoder = json.JSONDecoder()
open_marker_len = len(_LM_STUDIO_TEXT_TOOL_CALL_OPEN_MARKER)
close_marker_len = len(_LM_STUDIO_TEXT_TOOL_CALL_CLOSE_MARKER)

while remaining_content:
# Each LM Studio text tool call starts with its sentinel marker.
if not remaining_content.startswith(_LM_STUDIO_TEXT_TOOL_CALL_OPEN_MARKER):
return None

remaining_content = remaining_content[open_marker_len:].lstrip()
# LM Studio emits the tool name after a literal "call:" prefix.
if not remaining_content.startswith("call:"):
return None

remaining_content = remaining_content[len("call:") :].lstrip()
name_match = _LM_STUDIO_TEXT_TOOL_CALL_NAME_RE.match(remaining_content)
if name_match is None:
return None

tool_name = name_match.group(0)
remaining_content = remaining_content[name_match.end() :].lstrip()
# Arguments must be a JSON object immediately following the tool name.
if not remaining_content.startswith("{"):
return None

try:
# raw_decode handles nested objects and braces inside JSON strings.
tool_args, consumed_chars = json_decoder.raw_decode(remaining_content)
except json.JSONDecodeError:
return None

if not isinstance(tool_args, dict):
return None

parsed_tool_requests.append(
ToolRequest(name=tool_name, args=tool_args, _extra_content=extra_content)
)
remaining_content = remaining_content[consumed_chars:].lstrip()

# Some templates add a closing marker; others stop after the JSON object.
if remaining_content.startswith(_LM_STUDIO_TEXT_TOOL_CALL_CLOSE_MARKER):
remaining_content = remaining_content[close_marker_len:].lstrip()

return parsed_tool_requests or None

@staticmethod
def _convert_openai_logprobs_into_text_logprobs(logprobs: Any) -> List[TextTokenLogProb]:
Expand Down Expand Up @@ -218,6 +288,17 @@ def _convert_openai_response_into_message(self, response: Any) -> "Message":
# content might be empty when certain models (like gemini) decide
# to finish the conversation
content = extracted_message.get("content", "")
# Compatibility fix: some OpenAI-compatible providers encode tool calls
# as assistant text content instead of structured message.tool_calls.
text_tool_calls = self._parse_text_encoded_tool_calls_fallback(
content, extra_content=extracted_message.get("extra_content")
)
if text_tool_calls:
return Message(
tool_requests=text_tool_calls,
role="assistant",
_extra_content=extracted_message.get("extra_content"),
)

logprobs = None
choice_logprobs = response["choices"][0].get("logprobs")
Expand Down Expand Up @@ -315,14 +396,18 @@ async def _tagged_chunk_iterator_from_stream_of_openai_compatible_json(
content=text_delta, message_type=MessageType.AGENT
), None

tool_calls: Optional[List[ToolRequest]]
if len(tool_deltas) > 0:
message_type = MessageType.TOOL_REQUEST
tool_calls = self._convert_tool_deltas_into_tool_requests(tool_deltas)
else:
message_type = MessageType.AGENT
tool_calls = None
tool_calls = self._parse_text_encoded_tool_calls_fallback(text)
message_type = MessageType.TOOL_REQUEST if tool_calls else MessageType.AGENT

message = Message(content=text, message_type=message_type, tool_requests=tool_calls)
message_content = "" if tool_calls else text
message = Message(
content=message_content, message_type=message_type, tool_requests=tool_calls
)
if post_processing is not None:
message = post_processing(message)
yield StreamChunkType.END_CHUNK, message, token_usage
Expand Down
131 changes: 131 additions & 0 deletions wayflowcore/tests/models/test_openaicompatiblemodel.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,79 @@ def test_chat_completions_processor_formats_tool_result_as_tool_data():
]


def test_chat_completions_processor_converts_text_encoded_tool_call_content():
processor = _ChatCompletionsAPIProcessor(
model_id="test-model",
base_url="http://example.test",
api_type=OpenAIAPIType.CHAT_COMPLETIONS,
)
response = {
"choices": [
{
"message": {
"role": "assistant",
"content": (
'<|tool_call>call:send_message{"message":"please {inspect} '
'the bug","recipient":"worker-1"}<tool_call|>'
),
}
}
]
}

message = processor._convert_openai_response_into_message(response)

assert message.message_type == MessageType.TOOL_REQUEST
assert message.tool_requests is not None
assert len(message.tool_requests) == 1
assert message.tool_requests[0].name == "send_message"
assert message.tool_requests[0].args == {
"message": "please {inspect} the bug",
"recipient": "worker-1",
}


def test_chat_completions_processor_converts_text_encoded_tool_call_without_closer():
processor = _ChatCompletionsAPIProcessor(
model_id="test-model",
base_url="http://example.test",
api_type=OpenAIAPIType.CHAT_COMPLETIONS,
)
response = {
"choices": [
{
"message": {
"role": "assistant",
"content": '<|tool_call>call:send_message{"message":"hi","recipient":"worker-1"}',
}
}
]
}

message = processor._convert_openai_response_into_message(response)

assert message.message_type == MessageType.TOOL_REQUEST
assert message.tool_requests is not None
assert message.tool_requests[0].name == "send_message"
assert message.tool_requests[0].args == {"message": "hi", "recipient": "worker-1"}


def test_chat_completions_processor_ignores_malformed_text_encoded_tool_call_content():
processor = _ChatCompletionsAPIProcessor(
model_id="test-model",
base_url="http://example.test",
api_type=OpenAIAPIType.CHAT_COMPLETIONS,
)
raw_content = '<|tool_call>call:send_message{"message":"hi","recipient":'
response = {"choices": [{"message": {"role": "assistant", "content": raw_content}}]}

message = processor._convert_openai_response_into_message(response)

assert message.message_type == MessageType.AGENT
assert message.content == raw_content
assert message.tool_requests is None


def test_responses_processor_formats_tool_result_as_tool_data():
processor = _ResponsesAPIProcessor(
model_id="test-model",
Expand Down Expand Up @@ -435,6 +508,64 @@ async def test_chat_completions_streaming_preserves_terminal_usage():
assert chunks[-1][2].total_tokens == 12


@pytest.mark.anyio
async def test_chat_completions_streaming_converts_text_encoded_tool_call_content():
processor = _ChatCompletionsAPIProcessor(
model_id="test-model",
base_url="http://example.test",
api_type=OpenAIAPIType.CHAT_COMPLETIONS,
)

chunks = []
async for (
tagged_chunk
) in processor._tagged_chunk_iterator_from_stream_of_openai_compatible_json(
_yield_json_objects(
{"choices": [{"delta": {"content": "<|tool_call>"}}]},
{"choices": [{"delta": {"content": 'call:send_message{"message":"please '}}]},
{"choices": [{"delta": {"content": '{inspect} the bug","recipient":"worker-1"}'}}]},
{"choices": [{"delta": {"content": "<tool_call|>"}}]},
)
):
chunks.append(tagged_chunk)

final_message = chunks[-1][1]
assert final_message is not None
assert final_message.message_type == MessageType.TOOL_REQUEST
assert final_message.content == ""
assert final_message.tool_requests is not None
assert len(final_message.tool_requests) == 1
assert final_message.tool_requests[0].name == "send_message"
assert final_message.tool_requests[0].args == {
"message": "please {inspect} the bug",
"recipient": "worker-1",
}


@pytest.mark.anyio
async def test_chat_completions_streaming_ignores_malformed_text_encoded_tool_call_content():
processor = _ChatCompletionsAPIProcessor(
model_id="test-model",
base_url="http://example.test",
api_type=OpenAIAPIType.CHAT_COMPLETIONS,
)
raw_content = '<|tool_call>call:send_message{"message":"hi","recipient":'

chunks = []
async for (
tagged_chunk
) in processor._tagged_chunk_iterator_from_stream_of_openai_compatible_json(
_yield_json_objects({"choices": [{"delta": {"content": raw_content}}]})
):
chunks.append(tagged_chunk)

final_message = chunks[-1][1]
assert final_message is not None
assert final_message.message_type == MessageType.AGENT
assert final_message.content == raw_content
assert final_message.tool_requests is None


@pytest.mark.anyio
async def test_responses_streaming_preserves_terminal_usage():
processor = _ResponsesAPIProcessor(
Expand Down