114 changes: 111 additions & 3 deletions astrbot/core/provider/sources/openai_source.py
@@ -1,7 +1,9 @@
import asyncio
import base64
import copy
import hashlib
import inspect

import json
import random
import re
@@ -562,6 +564,96 @@ def _is_empty(content: Any) -> bool:

payloads["messages"] = cleaned

@staticmethod
def _shorten_tool_call_id(raw_id: str | None) -> str | None:
"""Deterministically shorten an oversized tool_call ID.

Non-cryptographic by design; MUST NOT be used for any
security-sensitive purpose. Its only job is to normalize IDs
that exceed the 64-character limit enforced by the OpenAI API
spec into a stable, compact form so the same ID collapses to
the same short form across retries of the same request.

Short IDs (or empty/None) are returned unchanged.
"""
if not raw_id or len(raw_id) <= 64:
return raw_id
# MD5 is used purely for deterministic compact hashing, not security.
return "call_" + hashlib.md5(raw_id.encode("utf-8")).hexdigest()

@staticmethod
def _normalize_tool_call_ids(payloads: dict) -> None:
"""Normalize oversized tool_call IDs in outgoing payloads.

Some OpenAI-compatible relay services return tool_call IDs that
far exceed the 64-character limit enforced by the OpenAI API spec
(observed lengths of 660 / 1650+ chars in the wild). Round-tripping
those IDs into the next request's ``messages[].tool_calls[].id`` or
``tool_call_id`` fields triggers HTTP 400 ``string_above_max_length``
from the upstream. Some relays internally translate Chat Completions
payloads into the Responses API format, which renames
``tool_call_id`` to ``call_id`` — but the root cause is the same.

A shared map keeps assistant ``tool_calls[].id`` and its matching
tool ``tool_call_id`` in sync after normalization. The conversation
history is mutated in place.
"""
messages = payloads.get("messages")
if not isinstance(messages, list):
return

id_map: dict[str, str] = {}

def _register(tid: str | None) -> None:
if not tid or tid in id_map or len(tid) <= 64:
return
shortened = ProviderOpenAIOfficial._shorten_tool_call_id(tid)
if shortened is not None and shortened != tid:
id_map[tid] = shortened

# First pass: collect every oversized ID.
for msg in messages:
if not isinstance(msg, dict):
continue
role = msg.get("role")

if role == "assistant":
tool_calls = msg.get("tool_calls")
if isinstance(tool_calls, list):
for tc in tool_calls:
if isinstance(tc, dict):
_register(tc.get("id"))
elif role == "tool":
_register(msg.get("tool_call_id"))

if not id_map:
return

logger.warning(
"Normalized %d oversized tool_call ID(s) before sending request.",
len(id_map),
)

# Second pass: apply the rewrite map.
for msg in messages:
if not isinstance(msg, dict):
continue
role = msg.get("role")

if role == "assistant":
tool_calls = msg.get("tool_calls")
if isinstance(tool_calls, list):
for tc in tool_calls:
if not isinstance(tc, dict):
continue
tid = tc.get("id")
if tid in id_map:
tc["id"] = id_map[tid]
elif role == "tool":
tid = msg.get("tool_call_id")
if tid and tid in id_map:
msg["tool_call_id"] = id_map[tid]

async def _query(self, payloads: dict, tools: ToolSet | None) -> LLMResponse:
if tools:
model = payloads.get("model", "").lower()
@@ -592,6 +684,7 @@ async def _query(self, payloads: dict, tools: ToolSet | None) -> LLMResponse:
model = payloads.get("model", "").lower()

self._sanitize_assistant_messages(payloads)
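# Rewrite oversized tool_call IDs once the message history has been sanitized.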
self._normalize_tool_call_ids(payloads)

completion = await self.client.chat.completions.create(
**payloads,
@@ -644,6 +737,7 @@ async def _query_stream(
self._apply_provider_specific_extra_body_overrides(extra_body)

self._sanitize_assistant_messages(payloads)
self._normalize_tool_call_ids(payloads)

stream = await self.client.chat.completions.create(
**payloads,
@@ -903,13 +997,27 @@ async def _parse_openai_completion(
args = tool_call.function.arguments
args_ls.append(args)
func_name_ls.append(tool_call.function.name)

raw_id = tool_call.id
safe_id = self._shorten_tool_call_id(raw_id)
if raw_id and safe_id != raw_id:
# Log only the length and the normalized short ID —
# the raw ID is opaque and may be provider-specific,
# so we avoid leaking its prefix into logs.
logger.warning(
"tool_call.id exceeded 64 chars (length=%d); "
"normalized to %s",
len(raw_id),
safe_id,
)

tool_call_ids.append(safe_id)

# gemini-2.5 / gemini-3 series extra_content handling
extra_content = getattr(tool_call, "extra_content", None)
if extra_content is not None:
tool_call_extra_content_dict[safe_id] = extra_content
llm_response.role = "tool"
llm_response.tools_call_args = args_ls
llm_response.tools_call_name = func_name_ls