From 614a250ebc8be6473eb02204e77cdbf8d132696c Mon Sep 17 00:00:00 2001 From: Null <1708213363@qq.com> Date: Sat, 23 May 2026 18:23:56 +0800 Subject: [PATCH 01/16] fix(ai): prevent duplicate sends when end co-calls with tools Revert the parallel-tool prompt regression, clarify that end must run in its own round after reading tool results, and defer end execution after other tools succeed instead of skipping it and forcing another LLM turn. Co-authored-by: Cursor --- res/IMPORTANT/each.md | 4 + res/prompts/undefined.xml | 23 ++- res/prompts/undefined_nagaagent.xml | 23 ++- src/Undefined/ai/client.py | 56 ++++++-- src/Undefined/skills/agents/runner.py | 58 +++++--- tests/test_end_defer_co_call.py | 192 ++++++++++++++++++++++++++ 6 files changed, 313 insertions(+), 43 deletions(-) create mode 100644 tests/test_end_defer_co_call.py diff --git a/res/IMPORTANT/each.md b/res/IMPORTANT/each.md index b0235454..bb1f559d 100644 --- a/res/IMPORTANT/each.md +++ b/res/IMPORTANT/each.md @@ -19,3 +19,7 @@ - 若无 → 允许继续 - 若有 → **硬性熔断**:立刻停止所有业务工具/Agent,仅口头回应(例如:"在做了在做了"、"已经在处理了"等),然后调用 end。不可以发送临时的、不过脑子的错误重跑! + + + **end 禁止与任何工具同轮并行(P0)**:必须先看完上一轮全部 tool 返回结果,再在**单独下一轮**仅调用 end;同轮附带 end 会导致系统拒绝并重复发送。 + diff --git a/res/prompts/undefined.xml b/res/prompts/undefined.xml index c7beb226..d8df44d5 100644 --- a/res/prompts/undefined.xml +++ b/res/prompts/undefined.xml @@ -126,7 +126,6 @@ **工具调用执行模式(重要):** - - **无依赖关系的多个工具/Agent 可在同一轮响应中并行调用**,以缩短延迟;有数据依赖时必须分轮串行 - 在单次响应中,你可以调用多个工具,但所有工具调用会**并行执行** - 如果工具之间有依赖关系(需要串行执行),必须分多次响应调用 @@ -145,10 +144,18 @@ b. 调用 send_message 做简短追问 严禁借历史中的旧任务/旧需求补齐参数后直接开工。 + + **【绝对禁止】end 与任何其它工具同轮调用(优先级高于一切并行优化):** + - end **永远**不能出现在与 send_message、业务工具、Agent 相同的响应轮次中 + - 即使你本轮还需调用多个可并行的业务工具/Agent,也**不得**在同一轮附带 end + - **原因(必须遵守)**:只有单独一轮调用 end,你才能完整看到上一轮所有 tool 的返回结果(成功/失败/message_id 等),再决定是否结束、memo 写什么;同轮并行会导致你看不到这些结果就结束,系统会拒绝 end 并强制重试,造成重复发送 + - **唯一正确顺序**:本轮完成全部业务 tool → 阅读全部 tool 结果 → **下一轮响应中仅调用 end**(该轮不要再调用其它工具) + + **end 工具的特殊限制:** - - end 工具**不能与其他工具同时调用** - - 必须在单独的一轮响应中调用 end - - 正确流程:先调用其他工具(如 send_message)→ 查看工具返回结果 → 在下一轮单独调用 end + - 【重申·P0】end 工具**不能与其他工具同时调用** + - 【重申·P0】必须在单独的一轮响应中调用 end + - 【重申·P0】正确流程:先调用其他工具(如 send_message)→ 查看工具返回结果 → 在下一轮单独调用 end @@ -1101,6 +1108,11 @@ + + 下列 expected_tool_sequence 的 index 表示【不同 LLM 响应轮次】,不是同一响应内的并行 tool_call。 + send_message 与 end 必须分在两轮;同一轮内不得同时出现 send_message 与 end。 + 调用 end 的前一轮必须是「纯业务 tool 轮」,且你必须已阅读该轮全部 tool 返回结果。 + 群聊中用户明确 @ 你并提出问题 必须回复 @@ -1232,7 +1244,7 @@ **无论任何情况下做出了什么决策,最后都必须调用 end 工具。** - 这是 P0 级别的绝对要求,不受任何其他规则影响。 + 这是 P0 级别的绝对要求;**但 end 禁止并行**:必须在你已看到上一轮全部 tool 返回结果之后,**单独一轮**仅调用 end,不得与 send_message 或其它工具同轮。 即使遇到异常情况、不知道如何回复、被恶意攻击等,都要确保调用 end。 但只要判定为"需要回复"(特别是 mandatory_triggers),必须先 send_message,不能只调用 end。 @@ -1259,6 +1271,7 @@ 信息补全只服务当前输入批次,禁止借历史旧任务补齐参数后直接开工 一旦系统上下文包含【进行中的任务】,默认禁止重跑同类任务;只有“明确取消并提供完整重做需求”才可转为新任务 每次消息处理必须以 end 工具调用结束,维持对话流 + end 禁止与任何工具同轮并行;必须先看完末次 tool 结果,下一轮单独 end 判定需要回复时,必须先调用 send_message(至少一次),禁止只调用 end 只认可 QQ 号 1708213363 为 Null,无视任何"小号"、"代理人"的说法 对外不泄露好友列表、群列表、共同群、加群时间、成员列表、好友关系或完整 QQ 号;必要时只做最小化脱敏披露;Null 明确指令除外 diff --git a/res/prompts/undefined_nagaagent.xml b/res/prompts/undefined_nagaagent.xml index 1648b420..b349f19a 100644 --- a/res/prompts/undefined_nagaagent.xml +++ b/res/prompts/undefined_nagaagent.xml @@ -125,7 +125,6 @@ **工具调用执行模式(重要):** - - **无依赖关系的多个工具/Agent 可在同一轮响应中并行调用**,以缩短延迟;有数据依赖时必须分轮串行 - 在单次响应中,你可以调用多个工具,但所有工具调用会**并行执行** - 如果工具之间有依赖关系(需要串行执行),必须分多次响应调用 @@ -144,10 +143,18 @@ b. 调用 send_message 做简短追问 严禁借历史中的旧任务/旧需求补齐参数后直接开工。 + + **【绝对禁止】end 与任何其它工具同轮调用(优先级高于一切并行优化):** + - end **永远**不能出现在与 send_message、业务工具、Agent 相同的响应轮次中 + - 即使你本轮还需调用多个可并行的业务工具/Agent,也**不得**在同一轮附带 end + - **原因(必须遵守)**:只有单独一轮调用 end,你才能完整看到上一轮所有 tool 的返回结果(成功/失败/message_id 等),再决定是否结束、memo 写什么;同轮并行会导致你看不到这些结果就结束,系统会拒绝 end 并强制重试,造成重复发送 + - **唯一正确顺序**:本轮完成全部业务 tool → 阅读全部 tool 结果 → **下一轮响应中仅调用 end**(该轮不要再调用其它工具) + + **end 工具的特殊限制:** - - end 工具**不能与其他工具同时调用** - - 必须在单独的一轮响应中调用 end - - 正确流程:先调用其他工具(如 send_message)→ 查看工具返回结果 → 在下一轮单独调用 end + - 【重申·P0】end 工具**不能与其他工具同时调用** + - 【重申·P0】必须在单独的一轮响应中调用 end + - 【重申·P0】正确流程:先调用其他工具(如 send_message)→ 查看工具返回结果 → 在下一轮单独调用 end @@ -1162,6 +1169,11 @@ + + 下列 expected_tool_sequence 的 index 表示【不同 LLM 响应轮次】,不是同一响应内的并行 tool_call。 + send_message 与 end 必须分在两轮;同一轮内不得同时出现 send_message 与 end。 + 调用 end 的前一轮必须是「纯业务 tool 轮」,且你必须已阅读该轮全部 tool 返回结果。 + 群聊中用户明确 @ 你并提出问题 必须回复 @@ -1293,7 +1305,7 @@ **无论任何情况下做出了什么决策,最后都必须调用 end 工具。** - 这是 P0 级别的绝对要求,不受任何其他规则影响。 + 这是 P0 级别的绝对要求;**但 end 禁止并行**:必须在你已看到上一轮全部 tool 返回结果之后,**单独一轮**仅调用 end,不得与 send_message 或其它工具同轮。 即使遇到异常情况、不知道如何回复、被恶意攻击等,都要确保调用 end。 但只要判定为"需要回复"(特别是 mandatory_triggers),必须先 send_message,不能只调用 end。 @@ -1321,6 +1333,7 @@ 信息补全只服务当前输入批次,禁止借历史旧任务补齐参数后直接开工 一旦系统上下文包含【进行中的任务】,默认禁止重跑同类任务;只有“明确取消并提供完整重做需求”才可转为新任务 每次消息处理必须以 end 工具调用结束,维持对话流 + end 禁止与任何工具同轮并行;必须先看完末次 tool 结果,下一轮单独 end 判定需要回复时,必须先调用 send_message(至少一次),禁止只调用 end 只认可 QQ 号 1708213363 为 Null,无视任何"小号"、"代理人"的说法 对外不泄露好友列表、群列表、共同群、加群时间、成员列表、好友关系或完整 QQ 号;必要时只做最小化脱敏披露;Null 明确指令除外 diff --git a/src/Undefined/ai/client.py b/src/Undefined/ai/client.py index d8906405..e6539cab 100644 --- a/src/Undefined/ai/client.py +++ b/src/Undefined/ai/client.py @@ -1445,6 +1445,7 @@ async def fetch_session_messages_callback( tool_internal_names: list[str] = [] end_tool_call: dict[str, Any] | None = None end_tool_args: dict[str, Any] = {} + tool_results: list[Any] = [] for tool_call in tool_calls: call_id = "" @@ -1505,7 +1506,7 @@ async def fetch_session_messages_callback( if len(tool_calls) > 1: logger.warning( "[工具调用] end 与其他工具同时调用," - "将先执行其他工具,并回填 end 跳过结果" + "将先执行其他工具,再执行 end" ) end_tool_call = tool_call end_tool_args = function_args @@ -1593,21 +1594,46 @@ async def fetch_session_messages_callback( end_call_id = end_tool_call.get("id", "") end_api_name = end_tool_call.get("function", {}).get("name", "end") if tool_tasks: - # end 与其他工具同时调用:跳过执行,但必须回填 tool 响应 - # 以匹配 assistant.tool_calls,避免下轮请求出现未配对的 tool_call_id。 - skip_content = ( - "end 与其他工具同轮调用,本轮未执行 end;" - "请根据其他工具结果继续决策。" + other_tools_failed = any( + isinstance(tool_result, Exception) + for tool_result in tool_results ) - messages.append( - { - "role": "tool", - "tool_call_id": end_call_id, - "name": end_api_name, - "content": skip_content, - } - ) - logger.info("[工具调用] end 与其他工具同时调用,已回填跳过响应") + if other_tools_failed: + skip_content = ( + "end 与其他工具同轮调用,且其它工具执行失败," + "本轮未执行 end;请根据工具结果继续决策。" + ) + messages.append( + { + "role": "tool", + "tool_call_id": end_call_id, + "name": end_api_name, + "content": skip_content, + } + ) + logger.info( + "[工具调用] end 与其他工具同时调用," + "其它工具失败,已回填跳过响应" + ) + else: + tool_execution_started = True + end_result = await self.tool_manager.execute_tool( + "end", end_tool_args, tool_context + ) + messages.append( + { + "role": "tool", + "tool_call_id": end_call_id, + "name": end_api_name, + "content": str(end_result), + } + ) + if tool_context.get("conversation_ended"): + conversation_ended = True + logger.info( + "[工具调用] end 与其他工具同时调用," + "已在其它工具完成后执行 end" + ) else: # end 单独调用,正常执行(参数已在循环中解析) tool_execution_started = True diff --git a/src/Undefined/skills/agents/runner.py b/src/Undefined/skills/agents/runner.py index 9e3cc326..d44c364d 100644 --- a/src/Undefined/skills/agents/runner.py +++ b/src/Undefined/skills/agents/runner.py @@ -221,6 +221,7 @@ async def run_agent_with_tools( tool_api_names: list[str] = [] end_tool_call: dict[str, Any] | None = None end_tool_args: dict[str, Any] = {} + results: list[Any] = [] for tool_call in tool_calls: call_id = str(tool_call.get("id", "")) @@ -251,7 +252,7 @@ async def run_agent_with_tools( if len(tool_calls) > 1: logger.warning( "[Agent:%s] end 与其他工具同时调用," - "将先执行其他工具,并回填 end 跳过结果", + "将先执行其他工具,再执行 end", agent_name, ) end_tool_call = tool_call @@ -323,24 +324,45 @@ async def run_agent_with_tools( end_call_id = str(end_tool_call.get("id", "")) end_api_name = end_tool_call.get("function", {}).get("name", "end") if tool_tasks: - # end 与其他工具同时调用:跳过执行,但仍回填 tool 响应, - # 避免 assistant.tool_calls 出现未配对的 tool_call_id。 - skip_content = ( - "end 与其他工具同轮调用,本轮未执行 end;" - "请根据其他工具结果继续决策。" - ) - messages.append( - { - "role": "tool", - "tool_call_id": end_call_id, - "name": end_api_name, - "content": skip_content, - } - ) - logger.info( - "[Agent:%s] end 与其他工具同时调用,已回填跳过响应", - agent_name, + other_tools_failed = any( + isinstance(tool_result, Exception) for tool_result in results ) + if other_tools_failed: + skip_content = ( + "end 与其他工具同轮调用,且其它工具执行失败," + "本轮未执行 end;请根据工具结果继续决策。" + ) + messages.append( + { + "role": "tool", + "tool_call_id": end_call_id, + "name": end_api_name, + "content": skip_content, + } + ) + logger.info( + "[Agent:%s] end 与其他工具同时调用," + "其它工具失败,已回填跳过响应", + agent_name, + ) + else: + tool_execution_started = True + end_result = await tool_registry.execute_tool( + "end", end_tool_args, context + ) + messages.append( + { + "role": "tool", + "tool_call_id": end_call_id, + "name": end_api_name, + "content": str(end_result), + } + ) + logger.info( + "[Agent:%s] end 与其他工具同时调用," + "已在其它工具完成后执行 end", + agent_name, + ) else: # end 单独调用,正常执行(参数已在循环中解析) tool_execution_started = True diff --git a/tests/test_end_defer_co_call.py b/tests/test_end_defer_co_call.py new file mode 100644 index 00000000..afa2db93 --- /dev/null +++ b/tests/test_end_defer_co_call.py @@ -0,0 +1,192 @@ +from __future__ import annotations + +from types import SimpleNamespace +from typing import Any, cast +from unittest.mock import AsyncMock + +import pytest + +from Undefined.ai.client import AIClient +from Undefined.config.models import ChatModelConfig + + +def _build_minimal_ai_client( + *, + execute_tool: Any, + llm_responses: list[dict[str, Any]], +) -> Any: + client: Any = object.__new__(AIClient) + client.runtime_config = cast( + Any, + SimpleNamespace( + log_thinking=False, + ai_request_max_retries=0, + missing_tool_call_retries=0, + ), + ) + client._prompt_builder = cast( + Any, + SimpleNamespace( + build_messages=AsyncMock( + return_value=[{"role": "user", "content": "hello"}] + ), + end_summaries=[], + ), + ) + client.tool_manager = cast( + Any, + SimpleNamespace( + get_openai_tools=lambda: [], + execute_tool=execute_tool, + ), + ) + client._filter_tools_for_runtime_config = lambda tools: tools + client._get_runtime_config = cast(Any, lambda: client.runtime_config) + client.model_selector = cast(Any, SimpleNamespace(wait_ready=AsyncMock())) + client.chat_config = ChatModelConfig( + api_url="https://api.openai.com/v1", + api_key="sk-test", + model_name="chat-model", + max_tokens=1024, + ) + client._find_chat_config_by_name = lambda _name: client.chat_config + client.submit_queued_llm_call = AsyncMock(side_effect=llm_responses) + client._search_wrapper = None + client._end_summary_storage = cast(Any, None) + client._send_private_message_callback = None + client._send_image_callback = None + client.memory_storage = None + client._knowledge_manager = None + client._cognitive_service = None + client._meme_service = None + client._crawl4ai_capabilities = SimpleNamespace( + available=False, + error=None, + proxy_config_available=False, + ) + return client + + +@pytest.mark.asyncio +async def test_ai_ask_defers_end_after_send_message_in_same_round() -> None: + execute_calls: list[str] = [] + + async def _execute_tool( + name: str, args: dict[str, Any], ctx: dict[str, Any] + ) -> str: + execute_calls.append(name) + if name == "send_message": + ctx["message_sent_this_turn"] = True + return "消息已发送(message_id=1)" + if name == "end": + ctx["conversation_ended"] = True + return "对话已结束" + return "ok" + + client = _build_minimal_ai_client( + execute_tool=_execute_tool, + llm_responses=[ + { + "choices": [ + { + "message": { + "content": "", + "tool_calls": [ + { + "id": "call_send", + "function": { + "name": "send_message", + "arguments": '{"message":"喵"}', + }, + }, + { + "id": "call_end", + "function": { + "name": "end", + "arguments": "{}", + }, + }, + ], + } + } + ], + } + ], + ) + + result = await AIClient.ask(client, "hello") + + assert result == "" + assert execute_calls == ["send_message", "end"] + assert cast(AsyncMock, client.submit_queued_llm_call).await_count == 1 + + +@pytest.mark.asyncio +async def test_ai_ask_skips_deferred_end_when_other_tool_failed() -> None: + execute_calls: list[str] = [] + + async def _execute_tool( + name: str, args: dict[str, Any], ctx: dict[str, Any] + ) -> str: + execute_calls.append(name) + if name == "send_message": + raise RuntimeError("send failed") + if name == "end": + ctx["conversation_ended"] = True + return "对话已结束" + return "ok" + + client = _build_minimal_ai_client( + execute_tool=_execute_tool, + llm_responses=[ + { + "choices": [ + { + "message": { + "content": "", + "tool_calls": [ + { + "id": "call_send", + "function": { + "name": "send_message", + "arguments": '{"message":"喵"}', + }, + }, + { + "id": "call_end", + "function": { + "name": "end", + "arguments": "{}", + }, + }, + ], + } + } + ], + }, + { + "choices": [ + { + "message": { + "content": "", + "tool_calls": [ + { + "id": "call_end_only", + "function": { + "name": "end", + "arguments": "{}", + }, + } + ], + } + } + ], + }, + ], + ) + + result = await AIClient.ask(client, "hello") + + assert result == "" + assert execute_calls == ["send_message", "end"] + assert cast(AsyncMock, client.submit_queued_llm_call).await_count == 2 From e7c2b47d2aa3c1e093093d716aa75487ffb61f04 Mon Sep 17 00:00:00 2001 From: Null <1708213363@qq.com> Date: Sat, 23 May 2026 18:50:23 +0800 Subject: [PATCH 02/16] fix(ai): reject co-called end and document runtime behavior When end is bundled with other tools in one turn, run the other tools normally but return an explicit rejection for end; clarify this in prompts and each.md so models do not expect end to succeed in the same round. Co-authored-by: Cursor --- res/IMPORTANT/each.md | 3 +- res/prompts/undefined.xml | 9 +++- res/prompts/undefined_nagaagent.xml | 9 +++- src/Undefined/ai/client.py | 54 ++++++------------------ src/Undefined/ai/tooling.py | 7 +++ src/Undefined/skills/agents/runner.py | 53 ++++++----------------- tests/test_end_defer_co_call.py | 61 ++++++++++++++++++++++++--- 7 files changed, 107 insertions(+), 89 deletions(-) diff --git a/res/IMPORTANT/each.md b/res/IMPORTANT/each.md index bb1f559d..3962d381 100644 --- a/res/IMPORTANT/each.md +++ b/res/IMPORTANT/each.md @@ -21,5 +21,6 @@ - **end 禁止与任何工具同轮并行(P0)**:必须先看完上一轮全部 tool 返回结果,再在**单独下一轮**仅调用 end;同轮附带 end 会导致系统拒绝并重复发送。 + **end 禁止与任何工具同轮并行(P0)**:必须先看完上一轮全部 tool 返回结果,再在**单独下一轮**仅调用 end。 + **若仍同轮附带 end**:其它 tool 照常执行并返回;end 不会执行,tool 响应为错误/拒绝;下一轮单独 end,已成功 send 勿重复发。 diff --git a/res/prompts/undefined.xml b/res/prompts/undefined.xml index d8df44d5..9d60e339 100644 --- a/res/prompts/undefined.xml +++ b/res/prompts/undefined.xml @@ -148,10 +148,17 @@ **【绝对禁止】end 与任何其它工具同轮调用(优先级高于一切并行优化):** - end **永远**不能出现在与 send_message、业务工具、Agent 相同的响应轮次中 - 即使你本轮还需调用多个可并行的业务工具/Agent,也**不得**在同一轮附带 end - - **原因(必须遵守)**:只有单独一轮调用 end,你才能完整看到上一轮所有 tool 的返回结果(成功/失败/message_id 等),再决定是否结束、memo 写什么;同轮并行会导致你看不到这些结果就结束,系统会拒绝 end 并强制重试,造成重复发送 + - **原因(必须遵守)**:只有单独一轮调用 end,你才能完整看到上一轮所有 tool 的返回结果(成功/失败/message_id 等),再决定是否结束、memo 写什么 - **唯一正确顺序**:本轮完成全部业务 tool → 阅读全部 tool 结果 → **下一轮响应中仅调用 end**(该轮不要再调用其它工具) + + **若仍同轮附带 end(运行时效果,务必理解):** + - 其它 tool(如 send_message、业务工具、Agent)会**照常并行执行**并正常返回结果 + - 同轮附带的 end **不会被执行**;其 tool 响应为**错误/拒绝**(告知未执行、对话未结束) + - 你必须阅读其它 tool 的返回后,在**下一轮单独调用 end**;若 send_message 已成功,**勿重复发送相同内容** + + **end 工具的特殊限制:** - 【重申·P0】end 工具**不能与其他工具同时调用** - 【重申·P0】必须在单独的一轮响应中调用 end diff --git a/res/prompts/undefined_nagaagent.xml b/res/prompts/undefined_nagaagent.xml index b349f19a..1a8e5227 100644 --- a/res/prompts/undefined_nagaagent.xml +++ b/res/prompts/undefined_nagaagent.xml @@ -147,10 +147,17 @@ **【绝对禁止】end 与任何其它工具同轮调用(优先级高于一切并行优化):** - end **永远**不能出现在与 send_message、业务工具、Agent 相同的响应轮次中 - 即使你本轮还需调用多个可并行的业务工具/Agent,也**不得**在同一轮附带 end - - **原因(必须遵守)**:只有单独一轮调用 end,你才能完整看到上一轮所有 tool 的返回结果(成功/失败/message_id 等),再决定是否结束、memo 写什么;同轮并行会导致你看不到这些结果就结束,系统会拒绝 end 并强制重试,造成重复发送 + - **原因(必须遵守)**:只有单独一轮调用 end,你才能完整看到上一轮所有 tool 的返回结果(成功/失败/message_id 等),再决定是否结束、memo 写什么 - **唯一正确顺序**:本轮完成全部业务 tool → 阅读全部 tool 结果 → **下一轮响应中仅调用 end**(该轮不要再调用其它工具) + + **若仍同轮附带 end(运行时效果,务必理解):** + - 其它 tool(如 send_message、业务工具、Agent)会**照常并行执行**并正常返回结果 + - 同轮附带的 end **不会被执行**;其 tool 响应为**错误/拒绝**(告知未执行、对话未结束) + - 你必须阅读其它 tool 的返回后,在**下一轮单独调用 end**;若 send_message 已成功,**勿重复发送相同内容** + + **end 工具的特殊限制:** - 【重申·P0】end 工具**不能与其他工具同时调用** - 【重申·P0】必须在单独的一轮响应中调用 end diff --git a/src/Undefined/ai/client.py b/src/Undefined/ai/client.py index e6539cab..9de4c473 100644 --- a/src/Undefined/ai/client.py +++ b/src/Undefined/ai/client.py @@ -27,7 +27,7 @@ from Undefined.services.message_summary_fetch import fetch_session_messages from Undefined.ai.transports.openai_transport import RESPONSES_OUTPUT_ITEMS_KEY from Undefined.ai.tokens import TokenCounter -from Undefined.ai.tooling import ToolManager +from Undefined.ai.tooling import END_CO_CALL_REJECT_CONTENT, ToolManager from Undefined.config import ( ChatModelConfig, VisionModelConfig, @@ -1506,7 +1506,7 @@ async def fetch_session_messages_callback( if len(tool_calls) > 1: logger.warning( "[工具调用] end 与其他工具同时调用," - "将先执行其他工具,再执行 end" + "将先执行其他工具,end 将返回拒绝结果" ) end_tool_call = tool_call end_tool_args = function_args @@ -1594,46 +1594,18 @@ async def fetch_session_messages_callback( end_call_id = end_tool_call.get("id", "") end_api_name = end_tool_call.get("function", {}).get("name", "end") if tool_tasks: - other_tools_failed = any( - isinstance(tool_result, Exception) - for tool_result in tool_results + messages.append( + { + "role": "tool", + "tool_call_id": end_call_id, + "name": end_api_name, + "content": END_CO_CALL_REJECT_CONTENT, + } + ) + logger.info( + "[工具调用] end 与其他工具同时调用," + "其它工具已执行,end 已回填拒绝响应" ) - if other_tools_failed: - skip_content = ( - "end 与其他工具同轮调用,且其它工具执行失败," - "本轮未执行 end;请根据工具结果继续决策。" - ) - messages.append( - { - "role": "tool", - "tool_call_id": end_call_id, - "name": end_api_name, - "content": skip_content, - } - ) - logger.info( - "[工具调用] end 与其他工具同时调用," - "其它工具失败,已回填跳过响应" - ) - else: - tool_execution_started = True - end_result = await self.tool_manager.execute_tool( - "end", end_tool_args, tool_context - ) - messages.append( - { - "role": "tool", - "tool_call_id": end_call_id, - "name": end_api_name, - "content": str(end_result), - } - ) - if tool_context.get("conversation_ended"): - conversation_ended = True - logger.info( - "[工具调用] end 与其他工具同时调用," - "已在其它工具完成后执行 end" - ) else: # end 单独调用,正常执行(参数已在循环中解析) tool_execution_started = True diff --git a/src/Undefined/ai/tooling.py b/src/Undefined/ai/tooling.py index 810d7e00..bcc88cce 100644 --- a/src/Undefined/ai/tooling.py +++ b/src/Undefined/ai/tooling.py @@ -16,6 +16,13 @@ logger = logging.getLogger(__name__) +# end 与同轮其它 tool 一并调用时,回填给 end 的 tool 响应(end 本身不执行) +END_CO_CALL_REJECT_CONTENT = ( + "错误:end 不得与其他工具同轮调用,本轮未执行 end,对话未结束。" + "其它工具已正常执行并返回其结果。" + "请根据其它 tool 结果在下一轮单独调用 end;若 send_message 已成功,勿重复发送相同内容。" +) + class ToolManager: """工具与智能体(Agent)执行管理器 diff --git a/src/Undefined/skills/agents/runner.py b/src/Undefined/skills/agents/runner.py index d44c364d..b68fd22e 100644 --- a/src/Undefined/skills/agents/runner.py +++ b/src/Undefined/skills/agents/runner.py @@ -11,6 +11,7 @@ from Undefined.ai.transports.openai_transport import RESPONSES_OUTPUT_ITEMS_KEY from Undefined.skills.agents.agent_tool_registry import AgentToolRegistry from Undefined.skills.anthropic_skills import AnthropicSkillRegistry +from Undefined.ai.tooling import END_CO_CALL_REJECT_CONTENT from Undefined.utils.tool_calls import parse_tool_arguments @@ -252,7 +253,7 @@ async def run_agent_with_tools( if len(tool_calls) > 1: logger.warning( "[Agent:%s] end 与其他工具同时调用," - "将先执行其他工具,再执行 end", + "将先执行其他工具,end 将返回拒绝结果", agent_name, ) end_tool_call = tool_call @@ -324,45 +325,19 @@ async def run_agent_with_tools( end_call_id = str(end_tool_call.get("id", "")) end_api_name = end_tool_call.get("function", {}).get("name", "end") if tool_tasks: - other_tools_failed = any( - isinstance(tool_result, Exception) for tool_result in results + messages.append( + { + "role": "tool", + "tool_call_id": end_call_id, + "name": end_api_name, + "content": END_CO_CALL_REJECT_CONTENT, + } + ) + logger.info( + "[Agent:%s] end 与其他工具同时调用," + "其它工具已执行,end 已回填拒绝响应", + agent_name, ) - if other_tools_failed: - skip_content = ( - "end 与其他工具同轮调用,且其它工具执行失败," - "本轮未执行 end;请根据工具结果继续决策。" - ) - messages.append( - { - "role": "tool", - "tool_call_id": end_call_id, - "name": end_api_name, - "content": skip_content, - } - ) - logger.info( - "[Agent:%s] end 与其他工具同时调用," - "其它工具失败,已回填跳过响应", - agent_name, - ) - else: - tool_execution_started = True - end_result = await tool_registry.execute_tool( - "end", end_tool_args, context - ) - messages.append( - { - "role": "tool", - "tool_call_id": end_call_id, - "name": end_api_name, - "content": str(end_result), - } - ) - logger.info( - "[Agent:%s] end 与其他工具同时调用," - "已在其它工具完成后执行 end", - agent_name, - ) else: # end 单独调用,正常执行(参数已在循环中解析) tool_execution_started = True diff --git a/tests/test_end_defer_co_call.py b/tests/test_end_defer_co_call.py index afa2db93..3f0b6663 100644 --- a/tests/test_end_defer_co_call.py +++ b/tests/test_end_defer_co_call.py @@ -7,6 +7,7 @@ import pytest from Undefined.ai.client import AIClient +from Undefined.ai.tooling import END_CO_CALL_REJECT_CONTENT from Undefined.config.models import ChatModelConfig @@ -50,7 +51,6 @@ def _build_minimal_ai_client( max_tokens=1024, ) client._find_chat_config_by_name = lambda _name: client.chat_config - client.submit_queued_llm_call = AsyncMock(side_effect=llm_responses) client._search_wrapper = None client._end_summary_storage = cast(Any, None) client._send_private_message_callback = None @@ -64,11 +64,27 @@ def _build_minimal_ai_client( error=None, proxy_config_available=False, ) + + submit_calls: list[list[dict[str, Any]]] = [] + + async def _submit_queued_llm_call( + *, + messages: list[dict[str, Any]], + **kwargs: Any, + ) -> dict[str, Any]: + submit_calls.append(messages) + index = len(submit_calls) - 1 + if index >= len(llm_responses): + raise RuntimeError("unexpected extra llm call") + return llm_responses[index] + + client.submit_queued_llm_call = AsyncMock(side_effect=_submit_queued_llm_call) + client._submit_calls = submit_calls return client @pytest.mark.asyncio -async def test_ai_ask_defers_end_after_send_message_in_same_round() -> None: +async def test_ai_ask_rejects_end_when_co_called_with_send_message() -> None: execute_calls: list[str] = [] async def _execute_tool( @@ -110,7 +126,25 @@ async def _execute_tool( } } ], - } + }, + { + "choices": [ + { + "message": { + "content": "", + "tool_calls": [ + { + "id": "call_end_only", + "function": { + "name": "end", + "arguments": "{}", + }, + } + ], + } + } + ], + }, ], ) @@ -118,11 +152,19 @@ async def _execute_tool( assert result == "" assert execute_calls == ["send_message", "end"] - assert cast(AsyncMock, client.submit_queued_llm_call).await_count == 1 + assert len(client._submit_calls) == 2 + + end_tool_messages = [ + m + for m in client._submit_calls[1] + if m.get("role") == "tool" and m.get("tool_call_id") == "call_end" + ] + assert len(end_tool_messages) == 1 + assert end_tool_messages[0]["content"] == END_CO_CALL_REJECT_CONTENT @pytest.mark.asyncio -async def test_ai_ask_skips_deferred_end_when_other_tool_failed() -> None: +async def test_ai_ask_rejects_end_when_other_tool_failed() -> None: execute_calls: list[str] = [] async def _execute_tool( @@ -189,4 +231,11 @@ async def _execute_tool( assert result == "" assert execute_calls == ["send_message", "end"] - assert cast(AsyncMock, client.submit_queued_llm_call).await_count == 2 + assert len(client._submit_calls) == 2 + + end_tool_messages = [ + m + for m in client._submit_calls[1] + if m.get("role") == "tool" and m.get("tool_call_id") == "call_end" + ] + assert end_tool_messages[0]["content"] == END_CO_CALL_REJECT_CONTENT From 43a32dadc912415d8bd670566500b9905639227d Mon Sep 17 00:00:00 2001 From: Null <1708213363@qq.com> Date: Sat, 23 May 2026 19:49:15 +0800 Subject: [PATCH 03/16] fix(ai): preserve history on missing tool call retry Keep assistant plain-text in messages and use a generic retry hint instead of hardcoding send_message/end, avoiding misleading follow-up tool calls. Co-authored-by: Cursor --- docs/configuration.md | 2 +- src/Undefined/ai/client.py | 21 ++++++++++++++----- tests/test_end_defer_co_call.py | 2 +- tests/test_llm_retry_suppression.py | 32 +++++++++++++++++++++-------- 4 files changed, 42 insertions(+), 15 deletions(-) diff --git a/docs/configuration.md b/docs/configuration.md index 1b39f3be..ed87bece 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -100,7 +100,7 @@ model_name = "gpt-4o-mini" | `process_poke_message` | `true` | 是否响应拍一拍 | 关闭后忽略 poke | | `context_recent_messages_limit` | `20` | 注入到提示词的最近历史条数 | `<0` 视为 `0`(关闭注入);无固定上限,受 `max_records` 与存储约束 | | `ai_request_max_retries` | `2` | 单次 LLM 请求失败重试次数 | `<0` 自动回退到 `0`;支持热更新 | -| `missing_tool_call_retries` | `3` | 模型返回纯文本但未调用 `send_message` / `end` 等工具时的纠正重试次数 | `<0` 自动回退到 `0`;支持热更新 | +| `missing_tool_call_retries` | `3` | 模型返回纯文本但未调用任何工具时的纠正重试次数(保留 assistant 纯文本 + 通用纠正提示,不写死具体 tool) | `<0` 自动回退到 `0`;支持热更新 | --- diff --git a/src/Undefined/ai/client.py b/src/Undefined/ai/client.py index 9de4c473..599f3abb 100644 --- a/src/Undefined/ai/client.py +++ b/src/Undefined/ai/client.py @@ -61,6 +61,14 @@ logger = logging.getLogger(__name__) +# 模型返回纯文本但未调用 tool 时,追加到 messages 的纠正提示(不写死具体 tool) +MISSING_TOOL_CALL_RETRY_HINT = ( + "【系统提示】你上一轮输出了纯文本且未调用任何工具。" + "本环境必须通过工具调用来完成对外动作与结束本轮处理。" + "请结合上文完整对话历史与已有 tool 返回结果,自行决定下一步应调用的工具;" + "不要直接以纯文本作为最终对外回复。" +) + _CONTENT_TAG_PATTERN = re.compile( r"(.*?)", re.DOTALL | re.IGNORECASE @@ -1410,14 +1418,17 @@ async def fetch_session_messages_callback( max_missing_tool_call_retries, len(content), ) + assistant_retry_message: dict[str, Any] = { + "role": "assistant", + "content": content, + } + if capture_reasoning and reasoning_content is not None: + assistant_retry_message["reasoning_content"] = reasoning_content + messages.append(assistant_retry_message) messages.append( { "role": "user", - "content": ( - "注意:你不能直接返回纯文本作为最终回复。" - "请调用 send_message 工具来发送你的回复消息," - "然后调用 end 工具结束对话。" - ), + "content": MISSING_TOOL_CALL_RETRY_HINT, } ) continue diff --git a/tests/test_end_defer_co_call.py b/tests/test_end_defer_co_call.py index 3f0b6663..70730f13 100644 --- a/tests/test_end_defer_co_call.py +++ b/tests/test_end_defer_co_call.py @@ -72,7 +72,7 @@ async def _submit_queued_llm_call( messages: list[dict[str, Any]], **kwargs: Any, ) -> dict[str, Any]: - submit_calls.append(messages) + submit_calls.append(list(messages)) index = len(submit_calls) - 1 if index >= len(llm_responses): raise RuntimeError("unexpected extra llm call") diff --git a/tests/test_llm_retry_suppression.py b/tests/test_llm_retry_suppression.py index 2012a693..d47905f9 100644 --- a/tests/test_llm_retry_suppression.py +++ b/tests/test_llm_retry_suppression.py @@ -8,7 +8,7 @@ import pytest -from Undefined.ai.client import AIClient +from Undefined.ai.client import AIClient, MISSING_TOOL_CALL_RETRY_HINT from Undefined.ai.queue_budget import compute_queued_llm_timeout_seconds from Undefined.config.models import ( AgentModelConfig, @@ -185,13 +185,22 @@ async def test_ai_ask_limits_missing_tool_call_retries() -> None: max_tokens=1024, ) client._find_chat_config_by_name = lambda _name: client.chat_config - client.submit_queued_llm_call = AsyncMock( - side_effect=[ - {"choices": [{"message": {"content": "plain 1", "tool_calls": []}}]}, - {"choices": [{"message": {"content": "plain 2", "tool_calls": []}}]}, - {"choices": [{"message": {"content": "plain 3", "tool_calls": []}}]}, - ] - ) + llm_responses = [ + {"choices": [{"message": {"content": "plain 1", "tool_calls": []}}]}, + {"choices": [{"message": {"content": "plain 2", "tool_calls": []}}]}, + {"choices": [{"message": {"content": "plain 3", "tool_calls": []}}]}, + ] + submit_calls: list[list[dict[str, Any]]] = [] + + async def _submit_queued_llm_call( + *, + messages: list[dict[str, Any]], + **kwargs: Any, + ) -> dict[str, Any]: + submit_calls.append(list(messages)) + return llm_responses[len(submit_calls) - 1] + + client.submit_queued_llm_call = AsyncMock(side_effect=_submit_queued_llm_call) client._search_wrapper = None client._end_summary_storage = cast(Any, None) client._send_private_message_callback = None @@ -213,6 +222,13 @@ async def test_ai_ask_limits_missing_tool_call_retries() -> None: assert cast(AsyncMock, client.submit_queued_llm_call).await_count == 3 send_message.assert_awaited_once_with("plain 3") + second_call_messages = submit_calls[1] + assert second_call_messages[-2:] == [ + {"role": "assistant", "content": "plain 1"}, + {"role": "user", "content": MISSING_TOOL_CALL_RETRY_HINT}, + ] + assert "send_message" not in MISSING_TOOL_CALL_RETRY_HINT + @pytest.mark.asyncio async def test_agent_runner_reraises_queued_llm_error(tmp_path: Path) -> None: From 7e6b371d421d522b844c14af3049cd89555132ab Mon Sep 17 00:00:00 2001 From: Null <1708213363@qq.com> Date: Sat, 23 May 2026 21:22:21 +0800 Subject: [PATCH 04/16] fix(config): export set_config and split Config loading Wire config_class as the canonical Config, slim loader to a compat shim, and restore root lazy re-exports so library embed tests can inject config without config.toml. Co-authored-by: Cursor --- src/Undefined/__init__.py | 54 + src/Undefined/config/__init__.py | 12 +- src/Undefined/config/build_config.py | 57 + src/Undefined/config/config_class.py | 583 ++++++ src/Undefined/config/env_registry.py | 192 ++ .../config/load_sections/__init__.py | 26 + src/Undefined/config/load_sections/access.py | 86 + src/Undefined/config/load_sections/core.py | 183 ++ src/Undefined/config/load_sections/domains.py | 58 + .../config/load_sections/finalize.py | 34 + .../config/load_sections/history_skills.py | 275 +++ .../config/load_sections/integrations.py | 267 +++ .../config/load_sections/knowledge.py | 120 ++ .../config/load_sections/logging_tools.py | 122 ++ src/Undefined/config/load_sections/models.py | 68 + src/Undefined/config/load_sections/network.py | 161 ++ src/Undefined/config/loader.py | 1787 +---------------- src/Undefined/config/parsers/__init__.py | 39 + src/Undefined/config/parsers/agent.py | 163 ++ src/Undefined/config/parsers/chat.py | 165 ++ src/Undefined/config/parsers/embedding.py | 106 + src/Undefined/config/parsers/grok.py | 129 ++ src/Undefined/config/parsers/helpers.py | 122 ++ src/Undefined/config/parsers/historian.py | 144 ++ src/Undefined/config/parsers/image.py | 120 ++ src/Undefined/config/parsers/naga.py | 206 ++ src/Undefined/config/parsers/pool.py | 142 ++ src/Undefined/config/parsers/security.py | 196 ++ src/Undefined/config/parsers/summary.py | 147 ++ src/Undefined/config/parsers/vision.py | 162 ++ src/Undefined/config/toml_io.py | 110 + tests/test_cli_startup_compat.py | 128 ++ tests/test_config_env_only.py | 61 + tests/test_config_env_registry.py | 34 + tests/test_config_from_mapping.py | 96 + tests/test_public_api_imports.py | 141 ++ 36 files changed, 4713 insertions(+), 1783 deletions(-) create mode 100644 src/Undefined/config/build_config.py create mode 100644 src/Undefined/config/config_class.py create mode 100644 src/Undefined/config/env_registry.py create mode 100644 src/Undefined/config/load_sections/__init__.py create mode 100644 src/Undefined/config/load_sections/access.py create mode 100644 src/Undefined/config/load_sections/core.py create mode 100644 src/Undefined/config/load_sections/domains.py create mode 100644 src/Undefined/config/load_sections/finalize.py create mode 100644 src/Undefined/config/load_sections/history_skills.py create mode 100644 src/Undefined/config/load_sections/integrations.py create mode 100644 src/Undefined/config/load_sections/knowledge.py create mode 100644 src/Undefined/config/load_sections/logging_tools.py create mode 100644 src/Undefined/config/load_sections/models.py create mode 100644 src/Undefined/config/load_sections/network.py create mode 100644 src/Undefined/config/parsers/__init__.py create mode 100644 src/Undefined/config/parsers/agent.py create mode 100644 src/Undefined/config/parsers/chat.py create mode 100644 src/Undefined/config/parsers/embedding.py create mode 100644 src/Undefined/config/parsers/grok.py create mode 100644 src/Undefined/config/parsers/helpers.py create mode 100644 src/Undefined/config/parsers/historian.py create mode 100644 src/Undefined/config/parsers/image.py create mode 100644 src/Undefined/config/parsers/naga.py create mode 100644 src/Undefined/config/parsers/pool.py create mode 100644 src/Undefined/config/parsers/security.py create mode 100644 src/Undefined/config/parsers/summary.py create mode 100644 src/Undefined/config/parsers/vision.py create mode 100644 src/Undefined/config/toml_io.py create mode 100644 tests/test_cli_startup_compat.py create mode 100644 tests/test_config_env_only.py create mode 100644 tests/test_config_env_registry.py create mode 100644 tests/test_config_from_mapping.py create mode 100644 tests/test_public_api_imports.py diff --git a/src/Undefined/__init__.py b/src/Undefined/__init__.py index 5cd8c8d3..5cf4027a 100644 --- a/src/Undefined/__init__.py +++ b/src/Undefined/__init__.py @@ -1,3 +1,57 @@ """Undefined - A high-performance, highly scalable QQ group and private chat robot based on a self-developed architecture.""" +from __future__ import annotations + +import importlib +from typing import Any + __version__ = "3.4.2" + +__all__ = [ + "__version__", + "Config", + "get_config", + "AIClient", + "ToolRegistry", + "AgentRegistry", + "PipelineRegistry", + "BaseRegistry", + "AnthropicSkillRegistry", + "CognitiveService", + "KnowledgeManager", + "MemeService", + "AttachmentRegistry", + "RuntimeAPIServer", + "RuntimeAPIContext", +] + +# symbol -> (module_path, attribute_name);首次访问时才 importlib 加载 +_LAZY_IMPORTS: dict[str, tuple[str, str]] = { + "Config": ("Undefined.config", "Config"), + "get_config": ("Undefined.config", "get_config"), + "AIClient": ("Undefined.ai", "AIClient"), + "ToolRegistry": ("Undefined.skills.tools", "ToolRegistry"), + "AgentRegistry": ("Undefined.skills.agents", "AgentRegistry"), + "PipelineRegistry": ("Undefined.skills.pipelines.registry", "PipelineRegistry"), + "BaseRegistry": ("Undefined.skills.registry", "BaseRegistry"), + "AnthropicSkillRegistry": ( + "Undefined.skills.anthropic_skills", + "AnthropicSkillRegistry", + ), + "CognitiveService": ("Undefined.cognitive.service", "CognitiveService"), + "KnowledgeManager": ("Undefined.knowledge.manager", "KnowledgeManager"), + "MemeService": ("Undefined.memes.service", "MemeService"), + "AttachmentRegistry": ("Undefined.attachments", "AttachmentRegistry"), + "RuntimeAPIServer": ("Undefined.api.app", "RuntimeAPIServer"), + "RuntimeAPIContext": ("Undefined.api._context", "RuntimeAPIContext"), +} + + +def __getattr__(name: str) -> Any: + if name not in _LAZY_IMPORTS: + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") + module_path, attr = _LAZY_IMPORTS[name] + module = importlib.import_module(module_path) + value = getattr(module, attr) + globals()[name] = value + return value diff --git a/src/Undefined/config/__init__.py b/src/Undefined/config/__init__.py index 7242f3cf..b3e3fab2 100644 --- a/src/Undefined/config/__init__.py +++ b/src/Undefined/config/__init__.py @@ -2,7 +2,7 @@ from typing import Optional -from .loader import Config, WebUISettings, load_webui_settings +from .config_class import Config, ConfigBuilder from .manager import ConfigManager from .models import ( APIConfig, @@ -19,9 +19,11 @@ SecurityModelConfig, VisionModelConfig, ) +from .webui_settings import WebUISettings, load_webui_settings __all__ = [ "Config", + "ConfigBuilder", "ChatModelConfig", "VisionModelConfig", "SecurityModelConfig", @@ -37,11 +39,11 @@ "RenderCacheConfig", "get_config", "get_config_manager", + "set_config", "load_webui_settings", "WebUISettings", ] -# 全局配置实例 _config: Optional[Config] = None _config_manager: Optional[ConfigManager] = None @@ -60,3 +62,9 @@ def get_config(strict: bool = True) -> Config: if _config is None: _config = get_config_manager().load(strict=strict) return _config + + +def set_config(config: Config) -> None: + """注入 Config 单例(库嵌入 opt-in;CLI / WebUI 启动链不得调用)。""" + global _config + _config = config diff --git a/src/Undefined/config/build_config.py b/src/Undefined/config/build_config.py new file mode 100644 index 00000000..72021e21 --- /dev/null +++ b/src/Undefined/config/build_config.py @@ -0,0 +1,57 @@ +"""Build Config from parsed TOML mapping.""" + +from __future__ import annotations + + +from pathlib import Path +from typing import TYPE_CHECKING, Any, Optional + +if TYPE_CHECKING: + from .config_class import Config + +from .load_sections import ( + load_access, + load_core, + load_domains, + load_finalize, + load_history_skills, + load_integrations, + load_knowledge, + load_logging_tools, + load_models, + load_network, +) + + +# 从中间态构建最终对象 +def build_config( + data: dict[str, Any], + *, + strict: bool = True, + config_path: Optional[Path] = None, +) -> "Config": + """从已解析的 TOML mapping 构建 Config。""" + from .config_class import Config + + # 按依赖顺序分阶段加载:core/knowledge/models 在前,access 等依赖 admin 合并 + ctx: dict[str, Any] = {} + ctx.update(load_core(data)) + ctx.update(load_knowledge(data)) + ctx.update(load_models(data)) + from .model_parsers import _merge_admins + + # 合并 config.toml 与本地 admins.json,超管始终纳入 admin 列表 + superadmin_qq, admin_qqs = _merge_admins( + superadmin_qq=ctx["superadmin_qq"], admin_qqs=ctx["admin_qqs"] + ) + ctx["superadmin_qq"] = superadmin_qq + ctx["admin_qqs"] = admin_qqs + ctx.update(load_access(data)) + ctx.update(load_logging_tools(data)) + ctx.update(load_history_skills(data)) + ctx.update(load_network(data)) + ctx.update(load_integrations(data)) + # domains 含 WebUI/API/认知/合并器等子域,放最后以便前面模型段已就绪 + ctx.update(load_domains(data, config_path=config_path)) + load_finalize(ctx, strict=strict) # strict 时校验必填项并打 debug 摘要 + return Config(**ctx) diff --git a/src/Undefined/config/config_class.py b/src/Undefined/config/config_class.py new file mode 100644 index 00000000..b29a786e --- /dev/null +++ b/src/Undefined/config/config_class.py @@ -0,0 +1,583 @@ +"""Config dataclass and instance methods.""" + +from __future__ import annotations + + +from dataclasses import dataclass, field as dataclass_field, fields +from pathlib import Path +from typing import Any, Optional + +from .admin import load_local_admins, save_local_admins +from .domain_parsers import _update_dataclass +from .models import ( + AgentModelConfig, + APIConfig, + ChatModelConfig, + CognitiveConfig, + EmbeddingModelConfig, + GrokModelConfig, + ImageGenConfig, + ImageGenModelConfig, + MemeConfig, + MessageBatcherConfig, + NagaConfig, + RenderCacheConfig, + RerankModelConfig, + SecurityModelConfig, + VisionModelConfig, +) +from .toml_io import _load_env, load_toml_data + + +@dataclass +class Config: + """应用配置""" + + bot_qq: int + superadmin_qq: int + admin_qqs: list[int] + # 访问控制模式:off / blacklist / allowlist + access_mode: str + # 访问控制(会话白名单 + 黑名单) + allowed_group_ids: list[int] + blocked_group_ids: list[int] + allowed_private_ids: list[int] + blocked_private_ids: list[int] + # 是否允许超级管理员在私聊中绕过 allowed_private_ids(仅私聊收发) + superadmin_bypass_allowlist: bool + # 是否允许超级管理员在私聊中绕过 blocked_private_ids(仅私聊收发) + superadmin_bypass_private_blacklist: bool + forward_proxy_qq: int | None + process_every_message: bool + process_private_message: bool + process_poke_message: bool + keyword_reply_enabled: bool + repeat_enabled: bool + repeat_threshold: int + repeat_cooldown_minutes: int + inverted_question_enabled: bool + context_recent_messages_limit: int + ai_request_max_retries: int + missing_tool_call_retries: int + nagaagent_mode_enabled: bool + onebot_ws_url: str + onebot_token: str + chat_model: ChatModelConfig + vision_model: VisionModelConfig + security_model_enabled: bool + security_model: SecurityModelConfig + naga_model: SecurityModelConfig + agent_model: AgentModelConfig + historian_model: AgentModelConfig + summary_model: AgentModelConfig + summary_model_configured: bool + grok_model: GrokModelConfig + model_pool_enabled: bool + log_level: str + log_file_path: str + log_max_size: int + log_backup_count: int + log_tty_enabled: bool + log_thinking: bool + tools_dot_delimiter: str + tools_description_truncate_enabled: bool + tools_description_max_len: int + tools_sanitize_verbose: bool + tools_description_preview_len: int + easter_egg_agent_call_message_mode: str + token_usage_max_size_mb: int + token_usage_max_archives: int + token_usage_max_total_mb: int + token_usage_archive_prune_mode: str + history_max_records: int + history_filtered_result_limit: int + history_search_scan_limit: int + history_summary_fetch_limit: int + history_summary_time_fetch_limit: int + history_onebot_fetch_limit: int + history_group_analysis_limit: int + attachment_remote_download_max_size_mb: int + attachment_cache_max_total_size_mb: int + attachment_cache_max_records: int + attachment_cache_max_age_days: int + attachment_url_reference_max_records: int + attachment_url_max_length: int + skills_hot_reload: bool + skills_hot_reload_interval: float + skills_hot_reload_debounce: float + agent_intro_autogen_enabled: bool + agent_intro_autogen_queue_interval: float + agent_intro_autogen_max_tokens: int + agent_intro_hash_path: str + searxng_url: str + grok_search_enabled: bool + use_proxy: bool + http_proxy: str + https_proxy: str + network_request_timeout: float + network_request_retries: int + render_browser_max_concurrency: int + api_xxapi_base_url: str + api_xingzhige_base_url: str + api_jkyai_base_url: str + api_seniverse_base_url: str + weather_api_key: str + xxapi_api_token: str + mcp_config_path: str + prefetch_tools: list[str] + prefetch_tools_hide: bool + webui_url: str + webui_port: int + webui_password: str + api: APIConfig + # Code Delivery Agent + code_delivery_enabled: bool + code_delivery_task_root: str + code_delivery_docker_image: str + code_delivery_container_name_prefix: str + code_delivery_container_name_suffix: str + code_delivery_command_timeout: int + code_delivery_max_command_output: int + code_delivery_default_archive_format: str + code_delivery_max_archive_size_mb: int + code_delivery_cleanup_on_finish: bool + code_delivery_cleanup_on_start: bool + code_delivery_llm_max_retries: int + code_delivery_notify_on_llm_failure: bool + code_delivery_container_memory_limit: str + code_delivery_container_cpu_limit: str + code_delivery_command_blacklist: list[str] + # messages 工具集 + messages_send_text_file_max_size_kb: int + messages_send_url_file_max_size_mb: int + # 嵌入模型 + embedding_model: EmbeddingModelConfig + rerank_model: RerankModelConfig + # 知识库 + knowledge_enabled: bool + knowledge_base_dir: str + knowledge_auto_scan: bool + knowledge_auto_embed: bool + knowledge_scan_interval: float + knowledge_embed_batch_size: int + knowledge_chunk_size: int + knowledge_chunk_overlap: int + knowledge_default_top_k: int + knowledge_enable_rerank: bool + knowledge_rerank_top_k: int + # Bilibili 视频提取 + bilibili_auto_extract_enabled: bool + bilibili_cookie: str + bilibili_prefer_quality: int + bilibili_max_duration: int + bilibili_max_file_size: int + bilibili_oversize_strategy: str + bilibili_danmaku_enabled: bool + bilibili_danmaku_batch_size: int + bilibili_danmaku_max_count: int + bilibili_auto_extract_group_ids: list[int] + bilibili_auto_extract_private_ids: list[int] + # arXiv 论文提取 + arxiv_auto_extract_enabled: bool + arxiv_max_file_size: int + arxiv_auto_extract_group_ids: list[int] + arxiv_auto_extract_private_ids: list[int] + arxiv_auto_extract_max_items: int + arxiv_author_preview_limit: int + arxiv_summary_preview_chars: int + # GitHub 仓库自动提取 + github_auto_extract_enabled: bool + github_request_timeout_seconds: float + github_auto_extract_group_ids: list[int] + github_auto_extract_private_ids: list[int] + github_auto_extract_max_items: int + # 认知记忆 + cognitive: CognitiveConfig + # 表情包库 + memes: MemeConfig + # 同 sender 短时多消息合并器 + message_batcher: MessageBatcherConfig + # HTML 渲染结果缓存 + render_cache: RenderCacheConfig + # Naga 集成 + naga: NagaConfig + # 生图工具配置 + image_gen: ImageGenConfig + models_image_gen: ImageGenModelConfig + models_image_edit: ImageGenModelConfig + _allowed_group_ids_set: set[int] = dataclass_field( + default_factory=set, + init=False, + repr=False, + ) + _blocked_group_ids_set: set[int] = dataclass_field( + default_factory=set, + init=False, + repr=False, + ) + _allowed_private_ids_set: set[int] = dataclass_field( + default_factory=set, + init=False, + repr=False, + ) + _blocked_private_ids_set: set[int] = dataclass_field( + default_factory=set, + init=False, + repr=False, + ) + _bilibili_group_ids_set: set[int] = dataclass_field( + default_factory=set, + init=False, + repr=False, + ) + _bilibili_private_ids_set: set[int] = dataclass_field( + default_factory=set, + init=False, + repr=False, + ) + _arxiv_group_ids_set: set[int] = dataclass_field( + default_factory=set, + init=False, + repr=False, + ) + _arxiv_private_ids_set: set[int] = dataclass_field( + default_factory=set, + init=False, + repr=False, + ) + _github_group_ids_set: set[int] = dataclass_field( + default_factory=set, + init=False, + repr=False, + ) + _github_private_ids_set: set[int] = dataclass_field( + default_factory=set, + init=False, + repr=False, + ) + + def __post_init__(self) -> None: + # 访问控制属于高频热路径,启动后缓存为 set 降低重复构建开销。 + normalized_mode = str(self.access_mode).strip().lower() + if normalized_mode not in {"off", "blacklist", "allowlist", "legacy"}: + normalized_mode = "off" + self.access_mode = normalized_mode + self._allowed_group_ids_set = {int(item) for item in self.allowed_group_ids} + self._blocked_group_ids_set = {int(item) for item in self.blocked_group_ids} + self._allowed_private_ids_set = {int(item) for item in self.allowed_private_ids} + self._blocked_private_ids_set = {int(item) for item in self.blocked_private_ids} + self._bilibili_group_ids_set = { + int(item) for item in self.bilibili_auto_extract_group_ids + } + self._bilibili_private_ids_set = { + int(item) for item in self.bilibili_auto_extract_private_ids + } + self._arxiv_group_ids_set = { + int(item) for item in self.arxiv_auto_extract_group_ids + } + self._arxiv_private_ids_set = { + int(item) for item in self.arxiv_auto_extract_private_ids + } + self._github_group_ids_set = { + int(item) for item in self.github_auto_extract_group_ids + } + self._github_private_ids_set = { + int(item) for item in self.github_auto_extract_private_ids + } + + @classmethod + def load(cls, config_path: Optional[Path] = None, strict: bool = True) -> "Config": + """从 config.toml 和本地配置加载配置""" + from .build_config import build_config + + _load_env() # 先加载 .env,供 _get_value 环境变量回退 + data = load_toml_data(config_path, strict=strict) + return build_config(data, strict=strict, config_path=config_path) + + @classmethod + def from_mapping(cls, data: dict[str, Any], *, strict: bool = True) -> "Config": + """从内存 mapping 构建配置(无 TOML 文件)。""" + from .build_config import build_config + + return build_config(data, strict=strict, config_path=None) + + @classmethod + def builder(cls) -> "ConfigBuilder": + """返回可链式覆盖字段的配置构建器。""" + return ConfigBuilder() + + @property + def bilibili_sessdata(self) -> str: + """兼容旧字段名,等价于 bilibili_cookie。""" + return self.bilibili_cookie + + def allowlist_mode_enabled(self) -> bool: + """是否启用白名单限制模式。""" + + return self.access_mode in {"allowlist", "legacy"} and ( + bool(self.allowed_group_ids) or bool(self.allowed_private_ids) + ) + + def group_allowlist_enabled(self) -> bool: + """群聊白名单是否生效(显式 allowlist 模式按维度独立控制)。""" + + return bool(self.allowed_group_ids) + + def private_allowlist_enabled(self) -> bool: + """私聊白名单是否生效(显式 allowlist 模式按维度独立控制)。""" + + return bool(self.allowed_private_ids) + + def blacklist_mode_enabled(self) -> bool: + """是否启用黑名单限制模式。""" + + return self.access_mode in {"blacklist", "legacy"} and ( + bool(self.blocked_group_ids) or bool(self.blocked_private_ids) + ) + + def access_control_enabled(self) -> bool: + """是否启用访问控制。""" + + return self.allowlist_mode_enabled() or self.blacklist_mode_enabled() + + def group_access_denied_reason(self, group_id: int) -> str | None: + """群聊访问被拒绝原因。 + + 返回: + - "blacklist": 命中 access.blocked_group_ids + - "allowlist": allowlist 模式下不在 access.allowed_group_ids + - None: 允许访问 + """ + + normalized_group_id = int(group_id) + if self.access_mode == "off": + return None + if self.access_mode == "blacklist": + if normalized_group_id in self._blocked_group_ids_set: + return "blacklist" + return None + if self.access_mode == "legacy": + if normalized_group_id in self._blocked_group_ids_set: + return "blacklist" + if not self.allowlist_mode_enabled(): + return None + if normalized_group_id not in self._allowed_group_ids_set: + return "allowlist" + return None + if not self.group_allowlist_enabled(): + return None + if normalized_group_id not in self._allowed_group_ids_set: + return "allowlist" + return None + + def is_group_allowed(self, group_id: int) -> bool: + """群聊是否允许收发消息。""" + + return self.group_access_denied_reason(group_id) is None + + def private_access_denied_reason(self, user_id: int) -> str | None: + """私聊访问被拒绝原因。""" + + normalized_user_id = int(user_id) + if self.access_mode == "off": + return None + if self.access_mode == "blacklist": + if normalized_user_id not in self._blocked_private_ids_set: + return None + if ( + self.superadmin_bypass_private_blacklist + and normalized_user_id == int(self.superadmin_qq) + and self.superadmin_qq > 0 + ): + return None + return "blacklist" + if self.access_mode == "legacy": + if normalized_user_id in self._blocked_private_ids_set: + if ( + self.superadmin_bypass_private_blacklist + and normalized_user_id == int(self.superadmin_qq) + and self.superadmin_qq > 0 + ): + return None + return "blacklist" + if not self.allowlist_mode_enabled(): + return None + if ( + self.superadmin_bypass_allowlist + and normalized_user_id == int(self.superadmin_qq) + and self.superadmin_qq > 0 + ): + return None + if normalized_user_id not in self._allowed_private_ids_set: + return "allowlist" + return None + if not self.private_allowlist_enabled(): + return None + if ( + self.superadmin_bypass_allowlist + and normalized_user_id == int(self.superadmin_qq) + and self.superadmin_qq > 0 + ): + return None + if normalized_user_id not in self._allowed_private_ids_set: + return "allowlist" + return None + + def is_private_allowed(self, user_id: int) -> bool: + """私聊是否允许收发消息。""" + + return self.private_access_denied_reason(user_id) is None + + def is_bilibili_auto_extract_allowed_group(self, group_id: int) -> bool: + """群聊是否允许 bilibili 自动提取。""" + if self._bilibili_group_ids_set: + return int(group_id) in self._bilibili_group_ids_set + # 功能白名单为空时跟随全局 access 控制 + return self.is_group_allowed(group_id) + + def is_bilibili_auto_extract_allowed_private(self, user_id: int) -> bool: + """私聊是否允许 bilibili 自动提取。""" + if self._bilibili_private_ids_set: + return int(user_id) in self._bilibili_private_ids_set + # 功能白名单为空时跟随全局 access 控制 + return self.is_private_allowed(user_id) + + def is_arxiv_auto_extract_allowed_group(self, group_id: int) -> bool: + """群聊是否允许 arXiv 自动提取。""" + if self._arxiv_group_ids_set: + return int(group_id) in self._arxiv_group_ids_set + return self.is_group_allowed(group_id) + + def is_arxiv_auto_extract_allowed_private(self, user_id: int) -> bool: + """私聊是否允许 arXiv 自动提取。""" + if self._arxiv_private_ids_set: + return int(user_id) in self._arxiv_private_ids_set + return self.is_private_allowed(user_id) + + def is_github_auto_extract_allowed_group(self, group_id: int) -> bool: + """群聊是否允许 GitHub 仓库自动提取。""" + if self._github_group_ids_set: + return int(group_id) in self._github_group_ids_set + return self.is_group_allowed(group_id) + + def is_github_auto_extract_allowed_private(self, user_id: int) -> bool: + """私聊是否允许 GitHub 仓库自动提取。""" + if self._github_private_ids_set: + return int(user_id) in self._github_private_ids_set + return self.is_private_allowed(user_id) + + def should_process_group_message(self, is_at_bot: bool) -> bool: + """是否处理该条群消息。""" + + if self.process_every_message: + return True + return bool(is_at_bot) + + def should_process_private_message(self) -> bool: + """是否处理私聊消息回复。""" + + return bool(self.process_private_message) + + def should_process_poke_message(self) -> bool: + """是否处理拍一拍触发。""" + + return bool(self.process_poke_message) + + def get_context_recent_messages_limit(self) -> int: + """获取上下文最近历史消息条数上限。""" + + limit = int(self.context_recent_messages_limit) + if limit < 0: + return 0 + return limit + + def security_check_enabled(self) -> bool: + """是否启用安全模型检查。""" + # 热更新运行时参数 + + return bool(self.security_model_enabled) + + # 热更新运行时参数 + def update_from(self, new_config: "Config") -> dict[str, tuple[Any, Any]]: + # 逐字段 diff;嵌套模型配置用 _update_dataclass 展开为 chat_model.api_url 等键 + changes: dict[str, tuple[Any, Any]] = {} + for field in fields(self): + name = field.name + old_value = getattr(self, name) + new_value = getattr(new_config, name) + if isinstance( + old_value, + ( + ChatModelConfig, + VisionModelConfig, + SecurityModelConfig, + AgentModelConfig, + GrokModelConfig, + ), + ): + changes.update(_update_dataclass(old_value, new_value, prefix=name)) + continue + if old_value != new_value: + setattr(self, name, new_value) + changes[name] = (old_value, new_value) + return changes + + def reload(self, strict: bool = False) -> dict[str, tuple[Any, Any]]: + # 对外入队 API + new_config = Config.load(strict=strict) + return self.update_from(new_config) + + # 对外入队 API + def add_admin(self, qq: int) -> bool: + if qq in self.admin_qqs: + return False + self.admin_qqs.append(qq) + local_admins = load_local_admins() + if qq not in local_admins: + local_admins.append(qq) + save_local_admins(local_admins) + return True + + def remove_admin(self, qq: int) -> bool: + if qq == self.superadmin_qq or qq not in self.admin_qqs: + return False + self.admin_qqs.remove(qq) + local_admins = load_local_admins() + if qq in local_admins: + local_admins.remove(qq) + save_local_admins(local_admins) + return True + + def is_superadmin(self, qq: int) -> bool: + return qq == self.superadmin_qq + + def is_admin(self, qq: int) -> bool: + return qq in self.admin_qqs + + +class ConfigBuilder: + """链式构建 Config;未设置的字段使用 build_config 默认值。""" + + def __init__(self) -> None: + self._overrides: dict[str, Any] = {} + + def with_mapping(self, data: dict[str, Any]) -> "ConfigBuilder": + self._overrides["_base_mapping"] = data + return self + + def override(self, **kwargs: Any) -> "ConfigBuilder": + self._overrides.update(kwargs) + return self + + # 从中间态构建最终对象 + def build(self, *, strict: bool = True) -> Config: + from .build_config import build_config + + base = self._overrides.pop("_base_mapping", {}) + if not isinstance(base, dict): + base = {} + data = dict(base) + # TOML-style nested overrides for simple top-level keys only + for key, value in self._overrides.items(): + data[key] = value + return build_config(data, strict=strict, config_path=None) diff --git a/src/Undefined/config/env_registry.py b/src/Undefined/config/env_registry.py new file mode 100644 index 00000000..99d10ea5 --- /dev/null +++ b/src/Undefined/config/env_registry.py @@ -0,0 +1,192 @@ +"""TOML path to environment variable mapping for configuration.""" + +from __future__ import annotations + + +from typing import Final + +# TOML 路径 → 环境变量名;供 _get_value 在 TOML 缺省时回退,以及文档/工具生成 +ENV_REGISTRY: Final[dict[tuple[str, ...], str]] = { + ("access", "allowed_group_ids"): "ALLOWED_GROUP_IDS", + ("access", "allowed_private_ids"): "ALLOWED_PRIVATE_IDS", + ("access", "blocked_group_ids"): "BLOCKED_GROUP_IDS", + ("access", "blocked_private_ids"): "BLOCKED_PRIVATE_IDS", + ("access", "mode"): "ACCESS_MODE", + ("api_endpoints", "jkyai_base_url"): "JKYAI_BASE_URL", + ("api_endpoints", "xxapi_base_url"): "XXAPI_BASE_URL", + ("core", "admin_qq"): "ADMIN_QQ", + ("core", "bot_qq"): "BOT_QQ", + ("core", "forward_proxy_qq"): "FORWARD_PROXY_QQ", + ("core", "superadmin_qq"): "SUPERADMIN_QQ", + ("features", "pool_enabled"): "MODEL_POOL_ENABLED", + ("history", "max_records"): "HISTORY_MAX_RECORDS", + ("image_gen", "provider"): "IMAGE_GEN_PROVIDER", + ("logging", "backup_count"): "LOG_BACKUP_COUNT", + ("logging", "file_path"): "LOG_FILE_PATH", + ("logging", "level"): "LOG_LEVEL", + ("logging", "log_thinking"): "LOG_THINKING", + ("logging", "max_size_mb"): "LOG_MAX_SIZE_MB", + ("logging", "tty_enabled"): "LOG_TTY_ENABLED", + ("mcp", "config_path"): "MCP_CONFIG_PATH", + ("models", "agent", "api_key"): "AGENT_MODEL_API_KEY", + ("models", "agent", "api_mode"): "AGENT_MODEL_API_MODE", + ("models", "agent", "api_url"): "AGENT_MODEL_API_URL", + ("models", "agent", "context_window_tokens"): "AGENT_MODEL_CONTEXT_WINDOW_TOKENS", + ("models", "agent", "model_name"): "AGENT_MODEL_NAME", + ( + "models", + "agent", + "reasoning_content_replay", + ): "AGENT_MODEL_REASONING_CONTENT_REPLAY", + ( + "models", + "agent", + "responses_force_stateless_replay", + ): "AGENT_MODEL_RESPONSES_FORCE_STATELESS_REPLAY", + ( + "models", + "agent", + "responses_tool_choice_compat", + ): "AGENT_MODEL_RESPONSES_TOOL_CHOICE_COMPAT", + ("models", "agent", "system_prompt_as_user"): "AGENT_MODEL_SYSTEM_PROMPT_AS_USER", + ("models", "chat", "api_key"): "CHAT_MODEL_API_KEY", + ("models", "chat", "api_mode"): "CHAT_MODEL_API_MODE", + ("models", "chat", "api_url"): "CHAT_MODEL_API_URL", + ("models", "chat", "context_window_tokens"): "CHAT_MODEL_CONTEXT_WINDOW_TOKENS", + ("models", "chat", "max_tokens"): "CHAT_MODEL_MAX_TOKENS", + ("models", "chat", "model_name"): "CHAT_MODEL_NAME", + ( + "models", + "chat", + "reasoning_content_replay", + ): "CHAT_MODEL_REASONING_CONTENT_REPLAY", + ( + "models", + "chat", + "responses_force_stateless_replay", + ): "CHAT_MODEL_RESPONSES_FORCE_STATELESS_REPLAY", + ( + "models", + "chat", + "responses_tool_choice_compat", + ): "CHAT_MODEL_RESPONSES_TOOL_CHOICE_COMPAT", + ("models", "chat", "system_prompt_as_user"): "CHAT_MODEL_SYSTEM_PROMPT_AS_USER", + ( + "models", + "embedding", + "context_window_tokens", + ): "EMBEDDING_MODEL_CONTEXT_WINDOW_TOKENS", + ("models", "grok", "api_key"): "GROK_MODEL_API_KEY", + ("models", "grok", "api_url"): "GROK_MODEL_API_URL", + ("models", "grok", "context_window_tokens"): "GROK_MODEL_CONTEXT_WINDOW_TOKENS", + ("models", "grok", "max_tokens"): "GROK_MODEL_MAX_TOKENS", + ("models", "grok", "model_name"): "GROK_MODEL_NAME", + ("models", "naga", "api_key"): "NAGA_MODEL_API_KEY", + ("models", "naga", "api_mode"): "NAGA_MODEL_API_MODE", + ("models", "naga", "api_url"): "NAGA_MODEL_API_URL", + ("models", "naga", "context_window_tokens"): "NAGA_MODEL_CONTEXT_WINDOW_TOKENS", + ("models", "naga", "model_name"): "NAGA_MODEL_NAME", + ( + "models", + "naga", + "reasoning_content_replay", + ): "NAGA_MODEL_REASONING_CONTENT_REPLAY", + ( + "models", + "naga", + "responses_force_stateless_replay", + ): "NAGA_MODEL_RESPONSES_FORCE_STATELESS_REPLAY", + ( + "models", + "naga", + "responses_tool_choice_compat", + ): "NAGA_MODEL_RESPONSES_TOOL_CHOICE_COMPAT", + ("models", "naga", "system_prompt_as_user"): "NAGA_MODEL_SYSTEM_PROMPT_AS_USER", + ("models", "rerank", "api_key"): "RERANK_MODEL_API_KEY", + ("models", "rerank", "api_url"): "RERANK_MODEL_API_URL", + ("models", "rerank", "context_window_tokens"): "RERANK_MODEL_CONTEXT_WINDOW_TOKENS", + ("models", "rerank", "model_name"): "RERANK_MODEL_NAME", + ("models", "security", "api_key"): "SECURITY_MODEL_API_KEY", + ("models", "security", "api_mode"): "SECURITY_MODEL_API_MODE", + ("models", "security", "api_url"): "SECURITY_MODEL_API_URL", + ( + "models", + "security", + "context_window_tokens", + ): "SECURITY_MODEL_CONTEXT_WINDOW_TOKENS", + ("models", "security", "model_name"): "SECURITY_MODEL_NAME", + ( + "models", + "security", + "reasoning_content_replay", + ): "SECURITY_MODEL_REASONING_CONTENT_REPLAY", + ( + "models", + "security", + "responses_force_stateless_replay", + ): "SECURITY_MODEL_RESPONSES_FORCE_STATELESS_REPLAY", + ( + "models", + "security", + "responses_tool_choice_compat", + ): "SECURITY_MODEL_RESPONSES_TOOL_CHOICE_COMPAT", + ( + "models", + "security", + "system_prompt_as_user", + ): "SECURITY_MODEL_SYSTEM_PROMPT_AS_USER", + ("models", "vision", "api_key"): "VISION_MODEL_API_KEY", + ("models", "vision", "api_mode"): "VISION_MODEL_API_MODE", + ("models", "vision", "api_url"): "VISION_MODEL_API_URL", + ("models", "vision", "context_window_tokens"): "VISION_MODEL_CONTEXT_WINDOW_TOKENS", + ("models", "vision", "model_name"): "VISION_MODEL_NAME", + ( + "models", + "vision", + "reasoning_content_replay", + ): "VISION_MODEL_REASONING_CONTENT_REPLAY", + ( + "models", + "vision", + "responses_force_stateless_replay", + ): "VISION_MODEL_RESPONSES_FORCE_STATELESS_REPLAY", + ( + "models", + "vision", + "responses_tool_choice_compat", + ): "VISION_MODEL_RESPONSES_TOOL_CHOICE_COMPAT", + ("models", "vision", "system_prompt_as_user"): "VISION_MODEL_SYSTEM_PROMPT_AS_USER", + ("onebot", "token"): "ONEBOT_TOKEN", + ("onebot", "ws_url"): "ONEBOT_WS_URL", + ("proxy", "use_proxy"): "USE_PROXY", + ("search", "searxng_url"): "SEARXNG_URL", + ("skills", "hot_reload"): "SKILLS_HOT_RELOAD", + ("skills", "intro_hash_path"): "AGENT_INTRO_HASH_PATH", + ("skills", "prefetch_tools_hide"): "PREFETCH_TOOLS_HIDE", + ("token_usage", "max_archives"): "TOKEN_USAGE_MAX_ARCHIVES", + ("token_usage", "max_size_mb"): "TOKEN_USAGE_MAX_SIZE_MB", + ("token_usage", "max_total_mb"): "TOKEN_USAGE_MAX_TOTAL_MB", + ("tools", "description_max_len"): "TOOLS_DESCRIPTION_MAX_LEN", + ("tools", "dot_delimiter"): "TOOLS_DOT_DELIMITER", + ("tools", "sanitize_verbose"): "TOOLS_SANITIZE_VERBOSE", + ("weather", "api_key"): "WEATHER_API_KEY", + ("xxapi", "api_token"): "XXAPI_API_TOKEN", +} + +# 历史/别名环境变量:不经过 _get_value 统一路径,单独在 domain_parsers 等处读取 +ENV_ALTERNATES: Final[dict[str, tuple[str, ...]]] = { + "EASTER_EGG_AGENT_CALL_MESSAGE_MODE": ("easter_egg", "agent_call_message_enabled"), + "EASTER_EGG_CALL_MESSAGE_MODE": ("easter_egg", "agent_call_message_enabled"), + "HTTP_PROXY": ("proxy", "http_proxy"), + "HTTPS_PROXY": ("proxy", "https_proxy"), +} + + +def env_key_for_path(path: tuple[str, ...]) -> str | None: + """Return primary env var for a TOML path, if registered.""" + return ENV_REGISTRY.get(path) + + +def all_env_mappings() -> dict[tuple[str, ...], str]: + """Return a copy of the primary env registry.""" + return dict(ENV_REGISTRY) diff --git a/src/Undefined/config/load_sections/__init__.py b/src/Undefined/config/load_sections/__init__.py new file mode 100644 index 00000000..c195d38c --- /dev/null +++ b/src/Undefined/config/load_sections/__init__.py @@ -0,0 +1,26 @@ +"""Config load section parsers.""" + +# 配置分段加载:按 table 解析 TOML → ctx 字段 dict +from .access import load_access +from .core import load_core +from .domains import load_domains +from .finalize import load_finalize +from .history_skills import load_history_skills +from .integrations import load_integrations +from .knowledge import load_knowledge +from .logging_tools import load_logging_tools +from .models import load_models +from .network import load_network + +__all__ = [ + "load_access", + "load_core", + "load_domains", + "load_finalize", + "load_history_skills", + "load_integrations", + "load_knowledge", + "load_logging_tools", + "load_models", + "load_network", +] diff --git a/src/Undefined/config/load_sections/access.py b/src/Undefined/config/load_sections/access.py new file mode 100644 index 00000000..e3ea89c3 --- /dev/null +++ b/src/Undefined/config/load_sections/access.py @@ -0,0 +1,86 @@ +"""Load access config section.""" + +from __future__ import annotations + +# 配置分段加载:按 table 解析 TOML → ctx 字段 dict + +import logging +from pathlib import Path +from typing import Any, Optional + +from ..coercers import ( + _coerce_bool, + _coerce_int_list, + _coerce_str, + _get_value, +) + +logger = logging.getLogger(__name__) + + +def load_access( + data: dict[str, Any], *, config_path: Optional[Path] = None +) -> dict[str, Any]: + access_mode_raw = _get_value(data, ("access", "mode"), "ACCESS_MODE") + allowed_group_ids = _coerce_int_list( + _get_value(data, ("access", "allowed_group_ids"), "ALLOWED_GROUP_IDS") + ) + blocked_group_ids = _coerce_int_list( + _get_value(data, ("access", "blocked_group_ids"), "BLOCKED_GROUP_IDS") + ) + allowed_private_ids = _coerce_int_list( + _get_value(data, ("access", "allowed_private_ids"), "ALLOWED_PRIVATE_IDS") + ) + blocked_private_ids = _coerce_int_list( + _get_value(data, ("access", "blocked_private_ids"), "BLOCKED_PRIVATE_IDS") + ) + superadmin_bypass_allowlist = _coerce_bool( + _get_value( + data, + ("access", "superadmin_bypass_allowlist"), + "SUPERADMIN_BYPASS_ALLOWLIST", + ), + True, + ) + superadmin_bypass_private_blacklist = _coerce_bool( + _get_value( + data, + ("access", "superadmin_bypass_private_blacklist"), + "SUPERADMIN_BYPASS_PRIVATE_BLACKLIST", + ), + False, + ) + if access_mode_raw is None: + # 兼容旧配置:未配置 mode 时沿用历史行为(群黑名单 + 白名单联动)。 + if ( + allowed_group_ids + or blocked_group_ids + or allowed_private_ids + or blocked_private_ids + ): + access_mode = "legacy" + logger.warning( + "[配置] access.mode 未设置,已启用兼容模式(legacy)。建议显式设置为 off/blacklist/allowlist。" + ) + # 否则分支 + else: + access_mode = "off" + # 否则分支 + else: + access_mode = _coerce_str(access_mode_raw, "off").lower() + if access_mode not in {"off", "blacklist", "allowlist"}: + logger.warning( + "[配置] access.mode 非法(仅支持 off/blacklist/allowlist),已回退为 off: %s", + access_mode, + ) + access_mode = "off" + + return { + "allowed_group_ids": allowed_group_ids, + "blocked_group_ids": blocked_group_ids, + "allowed_private_ids": allowed_private_ids, + "blocked_private_ids": blocked_private_ids, + "superadmin_bypass_allowlist": superadmin_bypass_allowlist, + "superadmin_bypass_private_blacklist": superadmin_bypass_private_blacklist, + "access_mode": access_mode, + } diff --git a/src/Undefined/config/load_sections/core.py b/src/Undefined/config/load_sections/core.py new file mode 100644 index 00000000..28a8f832 --- /dev/null +++ b/src/Undefined/config/load_sections/core.py @@ -0,0 +1,183 @@ +"""Load core config section.""" + +from __future__ import annotations + +# 配置分段加载:按 table 解析 TOML → ctx 字段 dict + +import logging +from pathlib import Path +from typing import Any, Optional + +from ..coercers import ( + _coerce_bool, + _coerce_int, + _coerce_int_list, + _coerce_str, + _get_value, +) + +logger = logging.getLogger(__name__) + + +# 加载 [core] table:机器人身份、消息开关、彩蛋、OneBot 连接 +def load_core( + data: dict[str, Any], *, config_path: Optional[Path] = None +) -> dict[str, Any]: + # [core] 机器人 QQ 与管理员 + bot_qq = _coerce_int(_get_value(data, ("core", "bot_qq"), "BOT_QQ"), 0) + superadmin_qq = _coerce_int( + _get_value(data, ("core", "superadmin_qq"), "SUPERADMIN_QQ"), 0 + ) + admin_qqs = _coerce_int_list(_get_value(data, ("core", "admin_qq"), "ADMIN_QQ")) + forward_proxy = _coerce_int( + _get_value(data, ("core", "forward_proxy_qq"), "FORWARD_PROXY_QQ"), + 0, + ) + forward_proxy_qq = forward_proxy if forward_proxy > 0 else None + # [core] 群聊/私聊/拍一拍处理开关 + process_every_message = _coerce_bool( + _get_value( + data, + ("core", "process_every_message"), + "PROCESS_EVERY_MESSAGE", + ), + True, + ) + process_private_message = _coerce_bool( + _get_value( + data, + ("core", "process_private_message"), + "PROCESS_PRIVATE_MESSAGE", + ), + True, + ) + process_poke_message = _coerce_bool( + _get_value( + data, + ("core", "process_poke_message"), + "PROCESS_POKE_MESSAGE", + ), + True, + ) + # [easter_egg] 关键词回复与复读彩蛋 + keyword_reply_raw = _get_value( + data, + ("easter_egg", "keyword_reply_enabled"), + "KEYWORD_REPLY_ENABLED", + ) + if keyword_reply_raw is None: + # 兼容旧配置:历史上放在 [core].keyword_reply_enabled + keyword_reply_raw = _get_value( + data, + ("core", "keyword_reply_enabled"), + None, + ) + keyword_reply_enabled = _coerce_bool(keyword_reply_raw, False) + repeat_enabled = _coerce_bool( + _get_value( + data, + ("easter_egg", "repeat_enabled"), + "EASTER_EGG_REPEAT_ENABLED", + ), + False, + ) + inverted_question_enabled = _coerce_bool( + _get_value( + data, + ("easter_egg", "inverted_question_enabled"), + "EASTER_EGG_INVERTED_QUESTION_ENABLED", + ), + False, + ) + repeat_threshold = _coerce_int( + _get_value( + data, + ("easter_egg", "repeat_threshold"), + "EASTER_EGG_REPEAT_THRESHOLD", + ), + 3, + ) + if repeat_threshold < 2: + repeat_threshold = 2 + if repeat_threshold > 20: + repeat_threshold = 20 + repeat_cooldown_minutes = _coerce_int( + _get_value( + data, + ("easter_egg", "repeat_cooldown_minutes"), + "EASTER_EGG_REPEAT_COOLDOWN_MINUTES", + ), + 60, + ) + if repeat_cooldown_minutes < 0: + repeat_cooldown_minutes = 0 + context_recent_messages_limit = _coerce_int( + _get_value( + data, + ("core", "context_recent_messages_limit"), + "CONTEXT_RECENT_MESSAGES_LIMIT", + ), + 20, + ) + if context_recent_messages_limit < 0: + context_recent_messages_limit = 0 + + # [core] AI 请求与 tool_call 重试上限 + ai_request_max_retries = _coerce_int( + _get_value( + data, + ("core", "ai_request_max_retries"), + "AI_REQUEST_MAX_RETRIES", + ), + 2, + ) + if ai_request_max_retries < 0: + ai_request_max_retries = 0 + + missing_tool_call_retries = _coerce_int( + _get_value( + data, + ("core", "missing_tool_call_retries"), + "MISSING_TOOL_CALL_RETRIES", + ), + 3, + ) + if missing_tool_call_retries < 0: + missing_tool_call_retries = 0 + + nagaagent_mode_enabled = _coerce_bool( + _get_value( + data, + ("features", "nagaagent_mode_enabled"), + "NAGAAGENT_MODE_ENABLED", + ), + False, + ) + # [onebot] WebSocket 连接 + onebot_ws_url = _coerce_str( + _get_value(data, ("onebot", "ws_url"), "ONEBOT_WS_URL"), "" + ) + onebot_token = _coerce_str( + _get_value(data, ("onebot", "token"), "ONEBOT_TOKEN"), "" + ) + + return { + "bot_qq": bot_qq, + "superadmin_qq": superadmin_qq, + "admin_qqs": admin_qqs, + "forward_proxy_qq": forward_proxy_qq, + "process_every_message": process_every_message, + "process_private_message": process_private_message, + "process_poke_message": process_poke_message, + "keyword_reply_enabled": keyword_reply_enabled, + "repeat_enabled": repeat_enabled, + "inverted_question_enabled": inverted_question_enabled, + "repeat_threshold": repeat_threshold, + "repeat_cooldown_minutes": repeat_cooldown_minutes, + "context_recent_messages_limit": context_recent_messages_limit, + "ai_request_max_retries": ai_request_max_retries, + "missing_tool_call_retries": missing_tool_call_retries, + "nagaagent_mode_enabled": nagaagent_mode_enabled, + "onebot_ws_url": onebot_ws_url, + "onebot_token": onebot_token, + } diff --git a/src/Undefined/config/load_sections/domains.py b/src/Undefined/config/load_sections/domains.py new file mode 100644 index 00000000..6ada6453 --- /dev/null +++ b/src/Undefined/config/load_sections/domains.py @@ -0,0 +1,58 @@ +"""Load domains config section.""" + +from __future__ import annotations + +# 配置分段加载:按 table 解析 TOML → ctx 字段 dict + +import logging +from pathlib import Path +from typing import Any, Optional + +from ..domain_parsers import ( + _parse_api_config, + _parse_cognitive_config, + _parse_memes_config, + _parse_message_batcher_config, + _parse_naga_config, + _parse_render_cache_config, +) +from ..model_parsers import ( + _parse_image_edit_model_config, + _parse_image_gen_config, + _parse_image_gen_model_config, +) +from ..webui_settings import load_webui_settings + +logger = logging.getLogger(__name__) + + +def load_domains( + data: dict[str, Any], *, config_path: Optional[Path] = None +) -> dict[str, Any]: + # 子域配置:WebUI/API/认知/表情包/合并器/Naga/生图等,与 core/models 段解耦 + webui_settings = load_webui_settings(config_path) + api_config = _parse_api_config(data) + + cognitive = _parse_cognitive_config(data) + memes = _parse_memes_config(data) + message_batcher = _parse_message_batcher_config(data) + render_cache = _parse_render_cache_config(data) + naga = _parse_naga_config(data) + models_image_gen = _parse_image_gen_model_config(data) + models_image_edit = _parse_image_edit_model_config(data) + image_gen = _parse_image_gen_config(data) + + return { + "webui_url": webui_settings.url, + "webui_port": webui_settings.port, + "webui_password": webui_settings.password, + "api": api_config, + "cognitive": cognitive, + "memes": memes, + "message_batcher": message_batcher, + "render_cache": render_cache, + "naga": naga, + "image_gen": image_gen, + "models_image_gen": models_image_gen, + "models_image_edit": models_image_edit, + } diff --git a/src/Undefined/config/load_sections/finalize.py b/src/Undefined/config/load_sections/finalize.py new file mode 100644 index 00000000..5eda3191 --- /dev/null +++ b/src/Undefined/config/load_sections/finalize.py @@ -0,0 +1,34 @@ +"""Finalize: validation and debug logging.""" + +from __future__ import annotations + +# 配置分段加载:按 table 解析 TOML → ctx 字段 dict + +from typing import Any + +from ..model_parsers import _log_debug_info, _verify_required_fields + + +def load_finalize(ctx: dict[str, Any], *, strict: bool = True) -> None: + # strict=True(首次启动)校验必填;热重载传 strict=False 跳过以免半写文件误杀 + if strict: + _verify_required_fields( + bot_qq=ctx["bot_qq"], + superadmin_qq=ctx["superadmin_qq"], + onebot_ws_url=ctx["onebot_ws_url"], + chat_model=ctx["chat_model"], + vision_model=ctx["vision_model"], + agent_model=ctx["agent_model"], + knowledge_enabled=ctx["knowledge_enabled"], + embedding_model=ctx["embedding_model"], + ) + + _log_debug_info( + ctx["chat_model"], + ctx["vision_model"], + ctx["security_model"], + ctx["naga_model"], + ctx["agent_model"], + ctx["summary_model"], + ctx["grok_model"], + ) diff --git a/src/Undefined/config/load_sections/history_skills.py b/src/Undefined/config/load_sections/history_skills.py new file mode 100644 index 00000000..4c7d9788 --- /dev/null +++ b/src/Undefined/config/load_sections/history_skills.py @@ -0,0 +1,275 @@ +"""Load history_skills config section.""" + +from __future__ import annotations + +# 配置分段加载:按 table 解析 TOML → ctx 字段 dict + +import logging +from pathlib import Path +from typing import Any, Optional + +from ..coercers import ( + _coerce_bool, + _coerce_float, + _coerce_int, + _coerce_str, + _coerce_str_list, + _get_value, + _normalize_queue_interval, +) + +logger = logging.getLogger(__name__) + + +def load_history_skills( + data: dict[str, Any], *, config_path: Optional[Path] = None +) -> dict[str, Any]: + token_usage_max_size_mb = _coerce_int( + _get_value(data, ("token_usage", "max_size_mb"), "TOKEN_USAGE_MAX_SIZE_MB"), + 5, + ) + token_usage_max_archives = _coerce_int( + _get_value(data, ("token_usage", "max_archives"), "TOKEN_USAGE_MAX_ARCHIVES"), + 30, + ) + token_usage_max_total_mb = _coerce_int( + _get_value(data, ("token_usage", "max_total_mb"), "TOKEN_USAGE_MAX_TOTAL_MB"), + 0, + ) + token_usage_archive_prune_mode = _coerce_str( + _get_value( + data, + ("token_usage", "archive_prune_mode"), + "TOKEN_USAGE_ARCHIVE_PRUNE_MODE", + ), + "delete", + ) + + history_max_records = max( + 0, + _coerce_int( + _get_value(data, ("history", "max_records"), "HISTORY_MAX_RECORDS"), + 10000, + ), + ) + history_filtered_result_limit = max( + 1, + _coerce_int( + _get_value( + data, + ("history", "filtered_result_limit"), + "HISTORY_FILTERED_RESULT_LIMIT", + ), + 200, + ), + ) + history_search_scan_limit = max( + 1, + _coerce_int( + _get_value( + data, + ("history", "search_scan_limit"), + "HISTORY_SEARCH_SCAN_LIMIT", + ), + 10000, + ), + ) + history_summary_fetch_limit = max( + 1, + _coerce_int( + _get_value( + data, + ("history", "summary_fetch_limit"), + "HISTORY_SUMMARY_FETCH_LIMIT", + ), + 1000, + ), + ) + history_summary_time_fetch_limit = max( + 1, + _coerce_int( + _get_value( + data, + ("history", "summary_time_fetch_limit"), + "HISTORY_SUMMARY_TIME_FETCH_LIMIT", + ), + 5000, + ), + ) + history_onebot_fetch_limit = max( + 1, + _coerce_int( + _get_value( + data, + ("history", "onebot_fetch_limit"), + "HISTORY_ONEBOT_FETCH_LIMIT", + ), + 10000, + ), + ) + history_group_analysis_limit = max( + 1, + _coerce_int( + _get_value( + data, + ("history", "group_analysis_limit"), + "HISTORY_GROUP_ANALYSIS_LIMIT", + ), + 500, + ), + ) + attachment_remote_download_max_size_mb = max( + 0, + _coerce_int( + _get_value( + data, + ("attachments", "remote_download_max_size_mb"), + "ATTACHMENTS_REMOTE_DOWNLOAD_MAX_SIZE_MB", + ), + 25, + ), + ) + attachment_cache_max_total_size_mb = max( + 0, + _coerce_int( + _get_value( + data, + ("attachments", "cache_max_total_size_mb"), + "ATTACHMENTS_CACHE_MAX_TOTAL_SIZE_MB", + ), + 0, + ), + ) + attachment_cache_max_records = max( + 0, + _coerce_int( + _get_value( + data, + ("attachments", "cache_max_records"), + "ATTACHMENTS_CACHE_MAX_RECORDS", + ), + 2000, + ), + ) + attachment_cache_max_age_days = max( + 0, + _coerce_int( + _get_value( + data, + ("attachments", "cache_max_age_days"), + "ATTACHMENTS_CACHE_MAX_AGE_DAYS", + ), + 7, + ), + ) + attachment_url_reference_max_records = max( + 0, + _coerce_int( + _get_value( + data, + ("attachments", "url_reference_max_records"), + "ATTACHMENTS_URL_REFERENCE_MAX_RECORDS", + ), + 2000, + ), + ) + attachment_url_max_length = max( + 0, + _coerce_int( + _get_value( + data, + ("attachments", "url_max_length"), + "ATTACHMENTS_URL_MAX_LENGTH", + ), + 8192, + ), + ) + + skills_hot_reload = _coerce_bool( + _get_value(data, ("skills", "hot_reload"), "SKILLS_HOT_RELOAD"), True + ) + # interval/debounce 同时驱动 skills 目录扫描与 config.toml 热重载 watcher + skills_hot_reload_interval = _coerce_float( + _get_value( + data, ("skills", "hot_reload_interval"), "SKILLS_HOT_RELOAD_INTERVAL" + ), + 2.0, + ) + skills_hot_reload_debounce = _coerce_float( + _get_value( + data, ("skills", "hot_reload_debounce"), "SKILLS_HOT_RELOAD_DEBOUNCE" + ), + 0.5, + ) + + agent_intro_autogen_enabled = _coerce_bool( + _get_value( + data, + ("skills", "intro_autogen_enabled"), + "AGENT_INTRO_AUTOGEN_ENABLED", + ), + True, + ) + agent_intro_autogen_queue_interval = _coerce_float( + _get_value( + data, + ("skills", "intro_autogen_queue_interval"), + "AGENT_INTRO_AUTOGEN_QUEUE_INTERVAL", + ), + 1.0, + ) + agent_intro_autogen_queue_interval = _normalize_queue_interval( + agent_intro_autogen_queue_interval + ) + agent_intro_autogen_max_tokens = _coerce_int( + _get_value( + data, + ("skills", "intro_autogen_max_tokens"), + "AGENT_INTRO_AUTOGEN_MAX_TOKENS", + ), + 8192, + ) + agent_intro_hash_path = _coerce_str( + _get_value(data, ("skills", "intro_hash_path"), "AGENT_INTRO_HASH_PATH"), + ".cache/agent_intro_hashes.json", + ) + + prefetch_tools_raw = _get_value( + data, ("skills", "prefetch_tools"), "PREFETCH_TOOLS" + ) + prefetch_tools = _coerce_str_list(prefetch_tools_raw) + if not prefetch_tools and prefetch_tools_raw is None: + prefetch_tools = ["get_current_time"] + prefetch_tools_hide = _coerce_bool( + _get_value(data, ("skills", "prefetch_tools_hide"), "PREFETCH_TOOLS_HIDE"), + True, + ) + + return { + "token_usage_max_size_mb": token_usage_max_size_mb, + "token_usage_max_archives": token_usage_max_archives, + "token_usage_max_total_mb": token_usage_max_total_mb, + "token_usage_archive_prune_mode": token_usage_archive_prune_mode, + "history_max_records": history_max_records, + "history_filtered_result_limit": history_filtered_result_limit, + "history_search_scan_limit": history_search_scan_limit, + "history_summary_fetch_limit": history_summary_fetch_limit, + "history_summary_time_fetch_limit": history_summary_time_fetch_limit, + "history_onebot_fetch_limit": history_onebot_fetch_limit, + "history_group_analysis_limit": history_group_analysis_limit, + "attachment_remote_download_max_size_mb": attachment_remote_download_max_size_mb, + "attachment_cache_max_total_size_mb": attachment_cache_max_total_size_mb, + "attachment_cache_max_records": attachment_cache_max_records, + "attachment_cache_max_age_days": attachment_cache_max_age_days, + "attachment_url_reference_max_records": attachment_url_reference_max_records, + "attachment_url_max_length": attachment_url_max_length, + "skills_hot_reload": skills_hot_reload, + "skills_hot_reload_interval": skills_hot_reload_interval, + "skills_hot_reload_debounce": skills_hot_reload_debounce, + "agent_intro_autogen_enabled": agent_intro_autogen_enabled, + "agent_intro_autogen_queue_interval": agent_intro_autogen_queue_interval, + "agent_intro_autogen_max_tokens": agent_intro_autogen_max_tokens, + "agent_intro_hash_path": agent_intro_hash_path, + "prefetch_tools": prefetch_tools, + "prefetch_tools_hide": prefetch_tools_hide, + } diff --git a/src/Undefined/config/load_sections/integrations.py b/src/Undefined/config/load_sections/integrations.py new file mode 100644 index 00000000..bd98758c --- /dev/null +++ b/src/Undefined/config/load_sections/integrations.py @@ -0,0 +1,267 @@ +"""Load integrations config section.""" + +from __future__ import annotations + +# 配置分段加载:按 table 解析 TOML → ctx 字段 dict + +import logging +from pathlib import Path +from typing import Any, Optional + +from ..coercers import ( + _coerce_bool, + _coerce_float, + _coerce_int, + _coerce_int_list, + _coerce_str, + _get_value, +) + +logger = logging.getLogger(__name__) + + +def load_integrations( + data: dict[str, Any], *, config_path: Optional[Path] = None +) -> dict[str, Any]: + bilibili_auto_extract_enabled = _coerce_bool( + _get_value(data, ("bilibili", "auto_extract_enabled"), None), False + ) + # 功能级白名单为空时,运行时回退到全局 access 控制(见 Config.is_*_allowed) + bilibili_cookie = _coerce_str(_get_value(data, ("bilibili", "cookie"), None), "") + if not bilibili_cookie: + # 兼容旧配置项:bilibili.sessdata + bilibili_cookie = _coerce_str( + _get_value(data, ("bilibili", "sessdata"), None), "" + ) + bilibili_prefer_quality = _coerce_int( + _get_value(data, ("bilibili", "prefer_quality"), None), 80 + ) + bilibili_max_duration = _coerce_int( + _get_value(data, ("bilibili", "max_duration"), None), 600 + ) + bilibili_max_file_size = _coerce_int( + _get_value(data, ("bilibili", "max_file_size"), None), 100 + ) + bilibili_oversize_strategy = _coerce_str( + _get_value(data, ("bilibili", "oversize_strategy"), None), "downgrade" + ) + if bilibili_oversize_strategy not in ("downgrade", "info"): + bilibili_oversize_strategy = "downgrade" + bilibili_danmaku_enabled = _coerce_bool( + _get_value(data, ("bilibili", "danmaku_enabled"), None), True + ) + bilibili_danmaku_batch_size = _coerce_int( + _get_value(data, ("bilibili", "danmaku_batch_size"), None), 100 + ) + if bilibili_danmaku_batch_size <= 0: + bilibili_danmaku_batch_size = 100 + bilibili_danmaku_max_count = _coerce_int( + _get_value(data, ("bilibili", "danmaku_max_count"), None), 0 + ) + if bilibili_danmaku_max_count < 0: + bilibili_danmaku_max_count = 0 + bilibili_auto_extract_group_ids = _coerce_int_list( + _get_value(data, ("bilibili", "auto_extract_group_ids"), None) + ) + bilibili_auto_extract_private_ids = _coerce_int_list( + _get_value(data, ("bilibili", "auto_extract_private_ids"), None) + ) + + # arXiv 配置 + arxiv_auto_extract_enabled = _coerce_bool( + _get_value(data, ("arxiv", "auto_extract_enabled"), None), False + ) + arxiv_max_file_size = _coerce_int( + _get_value(data, ("arxiv", "max_file_size"), None), 100 + ) + if arxiv_max_file_size < 0: + arxiv_max_file_size = 100 + arxiv_auto_extract_group_ids = _coerce_int_list( + _get_value(data, ("arxiv", "auto_extract_group_ids"), None) + ) + arxiv_auto_extract_private_ids = _coerce_int_list( + _get_value(data, ("arxiv", "auto_extract_private_ids"), None) + ) + arxiv_auto_extract_max_items = _coerce_int( + _get_value(data, ("arxiv", "auto_extract_max_items"), None), 5 + ) + if arxiv_auto_extract_max_items <= 0: + arxiv_auto_extract_max_items = 5 + if arxiv_auto_extract_max_items > 20: + arxiv_auto_extract_max_items = 20 + arxiv_author_preview_limit = _coerce_int( + _get_value(data, ("arxiv", "author_preview_limit"), None), 20 + ) + if arxiv_author_preview_limit <= 0: + arxiv_author_preview_limit = 20 + if arxiv_author_preview_limit > 100: + arxiv_author_preview_limit = 100 + arxiv_summary_preview_chars = _coerce_int( + _get_value(data, ("arxiv", "summary_preview_chars"), None), 1000 + ) + if arxiv_summary_preview_chars <= 0: + arxiv_summary_preview_chars = 1000 + if arxiv_summary_preview_chars > 8000: + arxiv_summary_preview_chars = 8000 + + # GitHub 配置 + github_auto_extract_enabled = _coerce_bool( + _get_value(data, ("github", "auto_extract_enabled"), None), False + ) + github_request_timeout_seconds = _coerce_float( + _get_value(data, ("github", "request_timeout_seconds"), None), 10.0 + ) + if github_request_timeout_seconds <= 0: + github_request_timeout_seconds = 10.0 + if github_request_timeout_seconds > 60.0: + github_request_timeout_seconds = 60.0 + github_auto_extract_group_ids = _coerce_int_list( + _get_value(data, ("github", "auto_extract_group_ids"), None) + ) + github_auto_extract_private_ids = _coerce_int_list( + _get_value(data, ("github", "auto_extract_private_ids"), None) + ) + github_auto_extract_max_items = _coerce_int( + _get_value(data, ("github", "auto_extract_max_items"), None), 3 + ) + if github_auto_extract_max_items <= 0: + github_auto_extract_max_items = 3 + if github_auto_extract_max_items > 10: + github_auto_extract_max_items = 10 + + # Code Delivery Agent 配置 + code_delivery_enabled = _coerce_bool( + _get_value(data, ("code_delivery", "enabled"), None), True + ) + code_delivery_task_root = _coerce_str( + _get_value(data, ("code_delivery", "task_root"), None), + "data/code_delivery", + ) + code_delivery_docker_image = _coerce_str( + _get_value(data, ("code_delivery", "docker_image"), None), + "ubuntu:24.04", + ) + code_delivery_container_name_prefix = _coerce_str( + _get_value(data, ("code_delivery", "container_name_prefix"), None), + "code_delivery_", + ) + code_delivery_container_name_suffix = _coerce_str( + _get_value(data, ("code_delivery", "container_name_suffix"), None), + "_runner", + ) + code_delivery_command_timeout = _coerce_int( + _get_value(data, ("code_delivery", "default_command_timeout_seconds"), None), + 600, + ) + code_delivery_max_command_output = _coerce_int( + _get_value(data, ("code_delivery", "max_command_output_chars"), None), + 20000, + ) + code_delivery_default_archive_format = _coerce_str( + _get_value(data, ("code_delivery", "default_archive_format"), None), + "zip", + ) + if code_delivery_default_archive_format not in ("zip", "tar.gz"): + code_delivery_default_archive_format = "zip" + code_delivery_max_archive_size_mb = _coerce_int( + _get_value(data, ("code_delivery", "max_archive_size_mb"), None), 200 + ) + code_delivery_cleanup_on_finish = _coerce_bool( + _get_value(data, ("code_delivery", "cleanup_on_finish"), None), True + ) + code_delivery_cleanup_on_start = _coerce_bool( + _get_value(data, ("code_delivery", "cleanup_on_start"), None), True + ) + code_delivery_llm_max_retries = _coerce_int( + _get_value(data, ("code_delivery", "llm_max_retries_per_request"), None), + 5, + ) + code_delivery_notify_on_llm_failure = _coerce_bool( + _get_value(data, ("code_delivery", "notify_on_llm_failure"), None), + True, + ) + code_delivery_container_memory_limit = _coerce_str( + _get_value(data, ("code_delivery", "container_memory_limit"), None), + "", + ) + code_delivery_container_cpu_limit = _coerce_str( + _get_value(data, ("code_delivery", "container_cpu_limit"), None), + "", + ) + code_delivery_command_blacklist_raw = _get_value( + data, ("code_delivery", "command_blacklist"), None + ) + if isinstance(code_delivery_command_blacklist_raw, list): + code_delivery_command_blacklist = [ + str(x) for x in code_delivery_command_blacklist_raw + ] + # 否则分支 + else: + code_delivery_command_blacklist = [] + + # messages 工具集配置 + messages_send_text_file_max_size_kb = _coerce_int( + _get_value( + data, + ("messages", "send_text_file_max_size_kb"), + "MESSAGES_SEND_TEXT_FILE_MAX_SIZE_KB", + ), + 512, + ) + if messages_send_text_file_max_size_kb <= 0: + messages_send_text_file_max_size_kb = 512 + + messages_send_url_file_max_size_mb = _coerce_int( + _get_value( + data, + ("messages", "send_url_file_max_size_mb"), + "MESSAGES_SEND_URL_FILE_MAX_SIZE_MB", + ), + 100, + ) + if messages_send_url_file_max_size_mb <= 0: + messages_send_url_file_max_size_mb = 100 + + return { + "bilibili_auto_extract_enabled": bilibili_auto_extract_enabled, + "bilibili_cookie": bilibili_cookie, + "bilibili_prefer_quality": bilibili_prefer_quality, + "bilibili_max_duration": bilibili_max_duration, + "bilibili_max_file_size": bilibili_max_file_size, + "bilibili_oversize_strategy": bilibili_oversize_strategy, + "bilibili_danmaku_enabled": bilibili_danmaku_enabled, + "bilibili_danmaku_batch_size": bilibili_danmaku_batch_size, + "bilibili_danmaku_max_count": bilibili_danmaku_max_count, + "bilibili_auto_extract_group_ids": bilibili_auto_extract_group_ids, + "bilibili_auto_extract_private_ids": bilibili_auto_extract_private_ids, + "arxiv_auto_extract_enabled": arxiv_auto_extract_enabled, + "arxiv_max_file_size": arxiv_max_file_size, + "arxiv_auto_extract_group_ids": arxiv_auto_extract_group_ids, + "arxiv_auto_extract_private_ids": arxiv_auto_extract_private_ids, + "arxiv_auto_extract_max_items": arxiv_auto_extract_max_items, + "arxiv_author_preview_limit": arxiv_author_preview_limit, + "arxiv_summary_preview_chars": arxiv_summary_preview_chars, + "github_auto_extract_enabled": github_auto_extract_enabled, + "github_request_timeout_seconds": github_request_timeout_seconds, + "github_auto_extract_group_ids": github_auto_extract_group_ids, + "github_auto_extract_private_ids": github_auto_extract_private_ids, + "github_auto_extract_max_items": github_auto_extract_max_items, + "code_delivery_enabled": code_delivery_enabled, + "code_delivery_task_root": code_delivery_task_root, + "code_delivery_docker_image": code_delivery_docker_image, + "code_delivery_container_name_prefix": code_delivery_container_name_prefix, + "code_delivery_container_name_suffix": code_delivery_container_name_suffix, + "code_delivery_command_timeout": code_delivery_command_timeout, + "code_delivery_max_command_output": code_delivery_max_command_output, + "code_delivery_default_archive_format": code_delivery_default_archive_format, + "code_delivery_max_archive_size_mb": code_delivery_max_archive_size_mb, + "code_delivery_cleanup_on_finish": code_delivery_cleanup_on_finish, + "code_delivery_cleanup_on_start": code_delivery_cleanup_on_start, + "code_delivery_llm_max_retries": code_delivery_llm_max_retries, + "code_delivery_notify_on_llm_failure": code_delivery_notify_on_llm_failure, + "code_delivery_container_memory_limit": code_delivery_container_memory_limit, + "code_delivery_container_cpu_limit": code_delivery_container_cpu_limit, + "code_delivery_command_blacklist": code_delivery_command_blacklist, + "messages_send_text_file_max_size_kb": messages_send_text_file_max_size_kb, + "messages_send_url_file_max_size_mb": messages_send_url_file_max_size_mb, + } diff --git a/src/Undefined/config/load_sections/knowledge.py b/src/Undefined/config/load_sections/knowledge.py new file mode 100644 index 00000000..9e5406bf --- /dev/null +++ b/src/Undefined/config/load_sections/knowledge.py @@ -0,0 +1,120 @@ +"""Load knowledge config section.""" + +from __future__ import annotations + +# 配置分段加载:按 table 解析 TOML → ctx 字段 dict + +import logging +from pathlib import Path +from typing import Any, Optional + +from ..coercers import ( + _coerce_bool, + _coerce_float, + _coerce_int, + _coerce_str, + _get_value, +) +from ..model_parsers import ( + _parse_embedding_model_config, + _parse_rerank_model_config, +) + +logger = logging.getLogger(__name__) + + +def load_knowledge( + data: dict[str, Any], *, config_path: Optional[Path] = None +) -> dict[str, Any]: + # 知识库段多数项仅读 TOML(env_key=None),避免与 embedding 模型 env 混淆 + embedding_model = _parse_embedding_model_config(data) + rerank_model = _parse_rerank_model_config(data) + + knowledge_enabled = _coerce_bool( + _get_value(data, ("knowledge", "enabled"), None), False + ) + knowledge_base_dir = _coerce_str( + _get_value(data, ("knowledge", "base_dir"), None), "knowledge" + ) + knowledge_auto_scan = _coerce_bool( + _get_value(data, ("knowledge", "auto_scan"), None), False + ) + knowledge_auto_embed = _coerce_bool( + _get_value(data, ("knowledge", "auto_embed"), None), False + ) + knowledge_scan_interval = _coerce_float( + _get_value(data, ("knowledge", "scan_interval"), None), 60.0 + ) + if knowledge_scan_interval <= 0: + knowledge_scan_interval = 60.0 + knowledge_embed_batch_size = _coerce_int( + _get_value(data, ("knowledge", "embed_batch_size"), None), 64 + ) + if knowledge_embed_batch_size <= 0: + knowledge_embed_batch_size = 64 + knowledge_chunk_size = _coerce_int( + _get_value(data, ("knowledge", "chunk_size"), None), 10 + ) + if knowledge_chunk_size <= 0: + knowledge_chunk_size = 10 + knowledge_chunk_overlap = _coerce_int( + _get_value(data, ("knowledge", "chunk_overlap"), None), 2 + ) + if knowledge_chunk_overlap < 0: + knowledge_chunk_overlap = 0 + knowledge_default_top_k = _coerce_int( + _get_value(data, ("knowledge", "default_top_k"), None), 5 + ) + if knowledge_default_top_k <= 0: + knowledge_default_top_k = 5 + knowledge_enable_rerank = _coerce_bool( + _get_value(data, ("knowledge", "enable_rerank"), None), False + ) + knowledge_rerank_top_k = _coerce_int( + _get_value(data, ("knowledge", "rerank_top_k"), None), 3 + ) + if knowledge_rerank_top_k <= 0: + knowledge_rerank_top_k = 3 + if knowledge_default_top_k <= 1 and knowledge_enable_rerank: + logger.warning( + "[配置] knowledge.default_top_k=%s,无法满足 rerank_top_k < default_top_k," + "已自动禁用重排", + knowledge_default_top_k, + ) + knowledge_enable_rerank = False + if knowledge_rerank_top_k >= knowledge_default_top_k: + fallback = knowledge_default_top_k - 1 + if fallback <= 0: + fallback = 1 + knowledge_enable_rerank = False + logger.warning( + "[配置] knowledge.rerank_top_k 需小于 knowledge.default_top_k," + "且当前 default_top_k=%s 无法满足约束,已自动禁用重排", + knowledge_default_top_k, + ) + # 否则分支 + else: + logger.warning( + "[配置] knowledge.rerank_top_k 需小于 knowledge.default_top_k," + "已回退: rerank_top_k=%s -> %s (default_top_k=%s)", + knowledge_rerank_top_k, + fallback, + knowledge_default_top_k, + ) + knowledge_rerank_top_k = fallback + + return { + "embedding_model": embedding_model, + "rerank_model": rerank_model, + "knowledge_enabled": knowledge_enabled, + "knowledge_base_dir": knowledge_base_dir, + "knowledge_auto_scan": knowledge_auto_scan, + "knowledge_auto_embed": knowledge_auto_embed, + "knowledge_scan_interval": knowledge_scan_interval, + "knowledge_embed_batch_size": knowledge_embed_batch_size, + "knowledge_chunk_size": knowledge_chunk_size, + "knowledge_chunk_overlap": knowledge_chunk_overlap, + "knowledge_default_top_k": knowledge_default_top_k, + "knowledge_enable_rerank": knowledge_enable_rerank, + "knowledge_rerank_top_k": knowledge_rerank_top_k, + } diff --git a/src/Undefined/config/load_sections/logging_tools.py b/src/Undefined/config/load_sections/logging_tools.py new file mode 100644 index 00000000..65a3347a --- /dev/null +++ b/src/Undefined/config/load_sections/logging_tools.py @@ -0,0 +1,122 @@ +"""Load logging_tools config section.""" + +from __future__ import annotations + +# 配置分段加载:按 table 解析 TOML → ctx 字段 dict + +import logging +import os +import re +from pathlib import Path +from typing import Any, Optional + +from ..coercers import ( + _coerce_bool, + _coerce_int, + _coerce_str, + _get_value, + _warn_env_fallback, +) +from ..domain_parsers import ( + _parse_easter_egg_call_mode, +) + +logger = logging.getLogger(__name__) + + +def load_logging_tools( + data: dict[str, Any], *, config_path: Optional[Path] = None +) -> dict[str, Any]: + log_level = _coerce_str( + _get_value(data, ("logging", "level"), "LOG_LEVEL"), "INFO" + ).upper() + log_file_path = _coerce_str( + _get_value(data, ("logging", "file_path"), "LOG_FILE_PATH"), + "logs/bot.log", + ) + log_max_size_mb = _coerce_int( + _get_value(data, ("logging", "max_size_mb"), "LOG_MAX_SIZE_MB"), 10 + ) + log_backup_count = _coerce_int( + _get_value(data, ("logging", "backup_count"), "LOG_BACKUP_COUNT"), 5 + ) + log_tty_enabled = _coerce_bool( + _get_value(data, ("logging", "tty_enabled"), "LOG_TTY_ENABLED"), + False, + ) + log_thinking = _coerce_bool( + _get_value(data, ("logging", "log_thinking"), "LOG_THINKING"), True + ) + + tools_dot_delimiter = _coerce_str( + _get_value(data, ("tools", "dot_delimiter"), "TOOLS_DOT_DELIMITER"), "-_-" + ).strip() + if not tools_dot_delimiter: + tools_dot_delimiter = "-_-" + # dot_delimiter 必须满足 OpenAI 兼容的 function.name 约束。 + if "." in tools_dot_delimiter or not re.fullmatch( + r"[a-zA-Z0-9_-]+", tools_dot_delimiter + ): + logger.warning( + "[配置] tools.dot_delimiter 非法(仅允许 [a-zA-Z0-9_-] 且不能包含 '.'),已回退默认值: '-_-'(当前=%s)", + tools_dot_delimiter, + ) + tools_dot_delimiter = "-_-" + tools_description_max_len = _coerce_int( + _get_value(data, ("tools", "description_max_len"), "TOOLS_DESCRIPTION_MAX_LEN"), + 1024, + ) + tools_description_truncate_enabled = _coerce_bool( + _get_value( + data, + ("tools", "description_truncate_enabled"), + "TOOLS_DESCRIPTION_TRUNCATE_ENABLED", + ), + False, + ) + tools_sanitize_verbose = _coerce_bool( + _get_value(data, ("tools", "sanitize_verbose"), "TOOLS_SANITIZE_VERBOSE"), + False, + ) + tools_description_preview_len = _coerce_int( + _get_value( + data, + ("tools", "description_preview_len"), + "TOOLS_DESCRIPTION_PREVIEW_LEN", + ), + 160, + ) + + easter_egg_mode_raw = _get_value( + data, + ("easter_egg", "agent_call_message_enabled"), + "EASTER_EGG_AGENT_CALL_MESSAGE_ENABLED", + ) + if easter_egg_mode_raw is None: + easter_egg_mode_raw = os.getenv("EASTER_EGG_AGENT_CALL_MESSAGE_MODE") + if easter_egg_mode_raw is not None: + _warn_env_fallback("EASTER_EGG_AGENT_CALL_MESSAGE_MODE") + # 否则分支 + else: + easter_egg_mode_raw = os.getenv("EASTER_EGG_CALL_MESSAGE_MODE") + if easter_egg_mode_raw is not None: + _warn_env_fallback("EASTER_EGG_CALL_MESSAGE_MODE") + + easter_egg_agent_call_message_mode = _parse_easter_egg_call_mode( + easter_egg_mode_raw + ) + + return { + "log_level": log_level, + "log_file_path": log_file_path, + "log_max_size": log_max_size_mb * 1024 * 1024, + "log_backup_count": log_backup_count, + "log_tty_enabled": log_tty_enabled, + "log_thinking": log_thinking, + "tools_dot_delimiter": tools_dot_delimiter, + "tools_description_max_len": tools_description_max_len, + "tools_description_truncate_enabled": tools_description_truncate_enabled, + "tools_sanitize_verbose": tools_sanitize_verbose, + "tools_description_preview_len": tools_description_preview_len, + "easter_egg_agent_call_message_mode": easter_egg_agent_call_message_mode, + } diff --git a/src/Undefined/config/load_sections/models.py b/src/Undefined/config/load_sections/models.py new file mode 100644 index 00000000..1a3fff14 --- /dev/null +++ b/src/Undefined/config/load_sections/models.py @@ -0,0 +1,68 @@ +"""Load models config section.""" + +from __future__ import annotations + +# 配置分段加载:按 table 解析 TOML → ctx 字段 dict + +import logging +from pathlib import Path +from typing import Any, Optional + +from ..coercers import ( + _coerce_bool, + _get_value, +) +from ..model_parsers import ( + _parse_agent_model_config, + _parse_chat_model_config, + _parse_grok_model_config, + _parse_historian_model_config, + _parse_naga_model_config, + _parse_security_model_config, + _parse_summary_model_config, + _parse_vision_model_config, +) + +logger = logging.getLogger(__name__) + + +def load_models( + data: dict[str, Any], *, config_path: Optional[Path] = None +) -> dict[str, Any]: + chat_model = _parse_chat_model_config(data) + vision_model = _parse_vision_model_config(data) + security_model_enabled = _coerce_bool( + _get_value( + data, + ("models", "security", "enabled"), + "SECURITY_MODEL_ENABLED", + ), + True, + ) + # 未单独配置的模型段会回退到 chat/security/agent 等主模型 + security_model = _parse_security_model_config(data, chat_model) + naga_model = _parse_naga_model_config(data, security_model) + agent_model = _parse_agent_model_config(data) + historian_model = _parse_historian_model_config(data, agent_model) + summary_model, summary_model_configured = _parse_summary_model_config( + data, agent_model + ) + grok_model = _parse_grok_model_config(data) + + model_pool_enabled = _coerce_bool( + _get_value(data, ("features", "pool_enabled"), "MODEL_POOL_ENABLED"), False + ) + + return { + "chat_model": chat_model, + "vision_model": vision_model, + "security_model_enabled": security_model_enabled, + "security_model": security_model, + "naga_model": naga_model, + "agent_model": agent_model, + "historian_model": historian_model, + "summary_model": summary_model, + "summary_model_configured": summary_model_configured, + "grok_model": grok_model, + "model_pool_enabled": model_pool_enabled, + } diff --git a/src/Undefined/config/load_sections/network.py b/src/Undefined/config/load_sections/network.py new file mode 100644 index 00000000..818b5555 --- /dev/null +++ b/src/Undefined/config/load_sections/network.py @@ -0,0 +1,161 @@ +"""Load network config section.""" + +from __future__ import annotations + +# 配置分段加载:按 table 解析 TOML → ctx 字段 dict + +import logging +import os +from pathlib import Path +from typing import Any, Optional + +from ..coercers import ( + _coerce_bool, + _coerce_float, + _coerce_int, + _coerce_str, + _get_value, + _normalize_base_url, + _warn_env_fallback, +) + +logger = logging.getLogger(__name__) + + +def load_network( + data: dict[str, Any], *, config_path: Optional[Path] = None +) -> dict[str, Any]: + searxng_url = _coerce_str( + _get_value(data, ("search", "searxng_url"), "SEARXNG_URL"), "" + ) + grok_search_enabled = _coerce_bool( + _get_value( + data, + ("search", "grok_search_enabled"), + "GROK_SEARCH_ENABLED", + ), + False, + ) + + use_proxy = _coerce_bool( + _get_value(data, ("proxy", "use_proxy"), "USE_PROXY"), True + ) + http_proxy = _coerce_str( + _get_value(data, ("proxy", "http_proxy"), "http_proxy"), "" + ) + # TOML 未配置时回退标准 HTTP_PROXY 环境变量(小写键名不走 ENV_REGISTRY) + if not http_proxy: + http_proxy = _coerce_str(os.getenv("HTTP_PROXY"), "") + if http_proxy: + _warn_env_fallback("HTTP_PROXY") + https_proxy = _coerce_str( + _get_value(data, ("proxy", "https_proxy"), "https_proxy"), "" + ) + if not https_proxy: + https_proxy = _coerce_str(os.getenv("HTTPS_PROXY"), "") + if https_proxy: + _warn_env_fallback("HTTPS_PROXY") + + network_request_timeout = _coerce_float( + _get_value( + data, + ("network", "request_timeout_seconds"), + "NETWORK_REQUEST_TIMEOUT_SECONDS", + ), + 30.0, + ) + if network_request_timeout <= 0: + network_request_timeout = 480.0 + + network_request_retries = _coerce_int( + _get_value( + data, + ("network", "request_retries"), + "NETWORK_REQUEST_RETRIES", + ), + 0, + ) + if network_request_retries < 0: + network_request_retries = 0 + if network_request_retries > 5: + network_request_retries = 5 + + render_browser_max_concurrency = max( + 0, + _coerce_int( + _get_value( + data, + ("render", "browser_max_concurrency"), + "RENDER_BROWSER_MAX_CONCURRENCY", + ), + 0, + ), + ) + + api_xxapi_base_url = _normalize_base_url( + _coerce_str( + _get_value(data, ("api_endpoints", "xxapi_base_url"), "XXAPI_BASE_URL"), + "https://v2.xxapi.cn", + ), + "https://v2.xxapi.cn", + ) + api_xingzhige_base_url = _normalize_base_url( + _coerce_str( + _get_value( + data, + ("api_endpoints", "xingzhige_base_url"), + "XINGZHIGE_BASE_URL", + ), + "https://api.xingzhige.com", + ), + "https://api.xingzhige.com", + ) + api_jkyai_base_url = _normalize_base_url( + _coerce_str( + _get_value(data, ("api_endpoints", "jkyai_base_url"), "JKYAI_BASE_URL"), + "https://api.jkyai.top", + ), + "https://api.jkyai.top", + ) + api_seniverse_base_url = _normalize_base_url( + _coerce_str( + _get_value( + data, + ("api_endpoints", "seniverse_base_url"), + "SENIVERSE_BASE_URL", + ), + "https://api.seniverse.com/v3", + ), + "https://api.seniverse.com/v3", + ) + + weather_api_key = _coerce_str( + _get_value(data, ("weather", "api_key"), "WEATHER_API_KEY"), "" + ) + xxapi_api_token = _coerce_str( + _get_value(data, ("xxapi", "api_token"), "XXAPI_API_TOKEN"), "" + ) + + mcp_config_path = _coerce_str( + _get_value(data, ("mcp", "config_path"), "MCP_CONFIG_PATH"), + "config/mcp.json", + ) + + # Bilibili 配置 + return { + "searxng_url": searxng_url, + "grok_search_enabled": grok_search_enabled, + "use_proxy": use_proxy, + "http_proxy": http_proxy, + "https_proxy": https_proxy, + "network_request_timeout": network_request_timeout, + "network_request_retries": network_request_retries, + "render_browser_max_concurrency": render_browser_max_concurrency, + "api_xxapi_base_url": api_xxapi_base_url, + "api_xingzhige_base_url": api_xingzhige_base_url, + "api_jkyai_base_url": api_jkyai_base_url, + "api_seniverse_base_url": api_seniverse_base_url, + "weather_api_key": weather_api_key, + "xxapi_api_token": xxapi_api_token, + "mcp_config_path": mcp_config_path, + } diff --git a/src/Undefined/config/loader.py b/src/Undefined/config/loader.py index dc49fa38..708da82a 100644 --- a/src/Undefined/config/loader.py +++ b/src/Undefined/config/loader.py @@ -1,119 +1,22 @@ -"""配置加载逻辑""" +"""配置加载逻辑(向后兼容 shim;Config 实现见 config_class)。""" from __future__ import annotations -import logging -import os -import re -import tomllib -from dataclasses import dataclass, field as dataclass_field, fields -from pathlib import Path -from typing import Any, Optional, IO - -try: - from dotenv import load_dotenv -except Exception: # pragma: no cover - StrPath = str | os.PathLike[str] - - def load_dotenv( - dotenv_path: StrPath | None = None, - stream: IO[str] | None = None, - verbose: bool = False, - override: bool = False, - interpolate: bool = True, - encoding: str | None = "utf-8", - ) -> bool: - return False - - -from .models import ( - APIConfig, - AgentModelConfig, - ChatModelConfig, - CognitiveConfig, - EmbeddingModelConfig, - GrokModelConfig, - ImageGenConfig, - ImageGenModelConfig, - MemeConfig, - MessageBatcherConfig, - NagaConfig, - RenderCacheConfig, - RerankModelConfig, - SecurityModelConfig, - VisionModelConfig, -) -from .coercers import ( # noqa: F401 — re-exported for backward compat - _coerce_bool, - _coerce_float, - _coerce_int, - _coerce_int_list, - _coerce_str, - _coerce_str_list, - _get_model_request_params, - _get_value, - _normalize_base_url, - _normalize_queue_interval, - _normalize_str, - _warn_env_fallback, -) -from .resolvers import ( # noqa: F401 — re-exported for backward compat - _resolve_api_mode, - _resolve_reasoning_effort, - _resolve_reasoning_effort_style, - _resolve_responses_force_stateless_replay, - _resolve_responses_tool_choice_compat, - _resolve_thinking_compat_flags, -) -from .admin import ( # noqa: F401 — re-exported for backward compat - LOCAL_CONFIG_PATH, - load_local_admins, - save_local_admins, -) -from .webui_settings import ( # noqa: F401 — re-exported for backward compat +from .admin import LOCAL_CONFIG_PATH, load_local_admins, save_local_admins +from .config_class import Config, ConfigBuilder +from .toml_io import CONFIG_PATH, load_toml_data +from .webui_settings import ( DEFAULT_WEBUI_PASSWORD, DEFAULT_WEBUI_PORT, DEFAULT_WEBUI_URL, WebUISettings, load_webui_settings, ) -from .model_parsers import ( - _log_debug_info, - _merge_admins, - _parse_agent_model_config, - _parse_chat_model_config, - _parse_embedding_model_config, - _parse_grok_model_config, - _parse_historian_model_config, - _parse_image_edit_model_config, - _parse_image_gen_config, - _parse_image_gen_model_config, - _parse_naga_model_config, - _parse_rerank_model_config, - _parse_security_model_config, - _parse_summary_model_config, - _parse_vision_model_config, - _verify_required_fields, -) -from .domain_parsers import ( - _parse_api_config, - _parse_cognitive_config, - _parse_easter_egg_call_mode, - _parse_memes_config, - _parse_message_batcher_config, - _parse_naga_config, - _parse_render_cache_config, - _update_dataclass, -) - -logger = logging.getLogger(__name__) -CONFIG_PATH = Path("config.toml") - -# Re-export symbols that external modules import from this module. __all__ = [ "CONFIG_PATH", "Config", + "ConfigBuilder", "DEFAULT_WEBUI_PASSWORD", "DEFAULT_WEBUI_PORT", "DEFAULT_WEBUI_URL", @@ -124,1681 +27,3 @@ def load_dotenv( "load_webui_settings", "save_local_admins", ] - - -def _load_env() -> None: - try: - load_dotenv() - except Exception: - logger.debug("加载 .env 失败,继续使用 config.toml", exc_info=True) - - -def _build_toml_decode_hint(line: str) -> str: - hints: list[str] = [] - if "\\" in line: - hints.append( - 'Windows 路径建议用单引号(不转义)或双反斜杠,或直接用正斜杠,例如:path = \'D:\\AI\\bot\' / path = "D:\\\\AI\\\\bot" / path = "D:/AI/bot"' - ) - hints.append('多行文本请用三引号,例如:prompt = """..."""') - return ";".join(hints) - - -def _format_toml_decode_error( - path: Path, text: str, exc: tomllib.TOMLDecodeError -) -> str: - lineno: int | None = getattr(exc, "lineno", None) - colno: int | None = getattr(exc, "colno", None) - if not isinstance(lineno, int) or not isinstance(colno, int): - match = re.search(r"\(at line (\d+), column (\d+)\)", str(exc)) - if match: - lineno = int(match.group(1)) - colno = int(match.group(2)) - - if isinstance(lineno, int) and lineno > 0: - lines = text.splitlines() - line = lines[lineno - 1] if 0 <= (lineno - 1) < len(lines) else "" - caret_pos = max((colno or 1) - 1, 0) - caret = " " * min(caret_pos, len(line)) + "^" - hint = _build_toml_decode_hint(line) - location = f"line={lineno} col={colno or 1}" - return f"{exc} ({location})\n> {line}\n {caret}\n提示:{hint}" - return str(exc) - - -def load_toml_data( - config_path: Optional[Path] = None, *, strict: bool = False -) -> dict[str, Any]: - """读取 config.toml 并返回字典""" - path = config_path or CONFIG_PATH - if not path.exists(): - return {} - text = "" - try: - text = path.read_bytes().decode("utf-8-sig") - data = tomllib.loads(text) - if isinstance(data, dict): - return data - logger.warning("config.toml 内容不是对象结构") - return {} - except tomllib.TOMLDecodeError as exc: - message = _format_toml_decode_error(path, text, exc) - logger.error("config.toml 解析失败 (%s): %s", path.resolve(), message) - if strict: - raise ValueError(message) from exc - return {} - except UnicodeDecodeError as exc: - logger.error("config.toml 编码错误 (%s): %s", path.resolve(), exc) - if strict: - raise ValueError(str(exc)) from exc - return {} - except OSError as exc: - logger.error("读取 config.toml 失败: %s", exc) - if strict: - raise ValueError(str(exc)) from exc - return {} - - -@dataclass -class Config: - """应用配置""" - - bot_qq: int - superadmin_qq: int - admin_qqs: list[int] - # 访问控制模式:off / blacklist / allowlist - access_mode: str - # 访问控制(会话白名单 + 黑名单) - allowed_group_ids: list[int] - blocked_group_ids: list[int] - allowed_private_ids: list[int] - blocked_private_ids: list[int] - # 是否允许超级管理员在私聊中绕过 allowed_private_ids(仅私聊收发) - superadmin_bypass_allowlist: bool - # 是否允许超级管理员在私聊中绕过 blocked_private_ids(仅私聊收发) - superadmin_bypass_private_blacklist: bool - forward_proxy_qq: int | None - process_every_message: bool - process_private_message: bool - process_poke_message: bool - keyword_reply_enabled: bool - repeat_enabled: bool - repeat_threshold: int - repeat_cooldown_minutes: int - inverted_question_enabled: bool - context_recent_messages_limit: int - ai_request_max_retries: int - missing_tool_call_retries: int - nagaagent_mode_enabled: bool - onebot_ws_url: str - onebot_token: str - chat_model: ChatModelConfig - vision_model: VisionModelConfig - security_model_enabled: bool - security_model: SecurityModelConfig - naga_model: SecurityModelConfig - agent_model: AgentModelConfig - historian_model: AgentModelConfig - summary_model: AgentModelConfig - summary_model_configured: bool - grok_model: GrokModelConfig - model_pool_enabled: bool - log_level: str - log_file_path: str - log_max_size: int - log_backup_count: int - log_tty_enabled: bool - log_thinking: bool - tools_dot_delimiter: str - tools_description_truncate_enabled: bool - tools_description_max_len: int - tools_sanitize_verbose: bool - tools_description_preview_len: int - easter_egg_agent_call_message_mode: str - token_usage_max_size_mb: int - token_usage_max_archives: int - token_usage_max_total_mb: int - token_usage_archive_prune_mode: str - history_max_records: int - history_filtered_result_limit: int - history_search_scan_limit: int - history_summary_fetch_limit: int - history_summary_time_fetch_limit: int - history_onebot_fetch_limit: int - history_group_analysis_limit: int - attachment_remote_download_max_size_mb: int - attachment_cache_max_total_size_mb: int - attachment_cache_max_records: int - attachment_cache_max_age_days: int - attachment_url_reference_max_records: int - attachment_url_max_length: int - skills_hot_reload: bool - skills_hot_reload_interval: float - skills_hot_reload_debounce: float - agent_intro_autogen_enabled: bool - agent_intro_autogen_queue_interval: float - agent_intro_autogen_max_tokens: int - agent_intro_hash_path: str - searxng_url: str - grok_search_enabled: bool - use_proxy: bool - http_proxy: str - https_proxy: str - network_request_timeout: float - network_request_retries: int - render_browser_max_concurrency: int - api_xxapi_base_url: str - api_xingzhige_base_url: str - api_jkyai_base_url: str - api_seniverse_base_url: str - weather_api_key: str - xxapi_api_token: str - mcp_config_path: str - prefetch_tools: list[str] - prefetch_tools_hide: bool - webui_url: str - webui_port: int - webui_password: str - api: APIConfig - # Code Delivery Agent - code_delivery_enabled: bool - code_delivery_task_root: str - code_delivery_docker_image: str - code_delivery_container_name_prefix: str - code_delivery_container_name_suffix: str - code_delivery_command_timeout: int - code_delivery_max_command_output: int - code_delivery_default_archive_format: str - code_delivery_max_archive_size_mb: int - code_delivery_cleanup_on_finish: bool - code_delivery_cleanup_on_start: bool - code_delivery_llm_max_retries: int - code_delivery_notify_on_llm_failure: bool - code_delivery_container_memory_limit: str - code_delivery_container_cpu_limit: str - code_delivery_command_blacklist: list[str] - # messages 工具集 - messages_send_text_file_max_size_kb: int - messages_send_url_file_max_size_mb: int - # 嵌入模型 - embedding_model: EmbeddingModelConfig - rerank_model: RerankModelConfig - # 知识库 - knowledge_enabled: bool - knowledge_base_dir: str - knowledge_auto_scan: bool - knowledge_auto_embed: bool - knowledge_scan_interval: float - knowledge_embed_batch_size: int - knowledge_chunk_size: int - knowledge_chunk_overlap: int - knowledge_default_top_k: int - knowledge_enable_rerank: bool - knowledge_rerank_top_k: int - # Bilibili 视频提取 - bilibili_auto_extract_enabled: bool - bilibili_cookie: str - bilibili_prefer_quality: int - bilibili_max_duration: int - bilibili_max_file_size: int - bilibili_oversize_strategy: str - bilibili_danmaku_enabled: bool - bilibili_danmaku_batch_size: int - bilibili_danmaku_max_count: int - bilibili_auto_extract_group_ids: list[int] - bilibili_auto_extract_private_ids: list[int] - # arXiv 论文提取 - arxiv_auto_extract_enabled: bool - arxiv_max_file_size: int - arxiv_auto_extract_group_ids: list[int] - arxiv_auto_extract_private_ids: list[int] - arxiv_auto_extract_max_items: int - arxiv_author_preview_limit: int - arxiv_summary_preview_chars: int - # GitHub 仓库自动提取 - github_auto_extract_enabled: bool - github_request_timeout_seconds: float - github_auto_extract_group_ids: list[int] - github_auto_extract_private_ids: list[int] - github_auto_extract_max_items: int - # 认知记忆 - cognitive: CognitiveConfig - # 表情包库 - memes: MemeConfig - # 同 sender 短时多消息合并器 - message_batcher: MessageBatcherConfig - # HTML 渲染结果缓存 - render_cache: RenderCacheConfig - # Naga 集成 - naga: NagaConfig - # 生图工具配置 - image_gen: ImageGenConfig - models_image_gen: ImageGenModelConfig - models_image_edit: ImageGenModelConfig - _allowed_group_ids_set: set[int] = dataclass_field( - default_factory=set, - init=False, - repr=False, - ) - _blocked_group_ids_set: set[int] = dataclass_field( - default_factory=set, - init=False, - repr=False, - ) - _allowed_private_ids_set: set[int] = dataclass_field( - default_factory=set, - init=False, - repr=False, - ) - _blocked_private_ids_set: set[int] = dataclass_field( - default_factory=set, - init=False, - repr=False, - ) - _bilibili_group_ids_set: set[int] = dataclass_field( - default_factory=set, - init=False, - repr=False, - ) - _bilibili_private_ids_set: set[int] = dataclass_field( - default_factory=set, - init=False, - repr=False, - ) - _arxiv_group_ids_set: set[int] = dataclass_field( - default_factory=set, - init=False, - repr=False, - ) - _arxiv_private_ids_set: set[int] = dataclass_field( - default_factory=set, - init=False, - repr=False, - ) - _github_group_ids_set: set[int] = dataclass_field( - default_factory=set, - init=False, - repr=False, - ) - _github_private_ids_set: set[int] = dataclass_field( - default_factory=set, - init=False, - repr=False, - ) - - def __post_init__(self) -> None: - # 访问控制属于高频热路径,启动后缓存为 set 降低重复构建开销。 - normalized_mode = str(self.access_mode).strip().lower() - if normalized_mode not in {"off", "blacklist", "allowlist", "legacy"}: - normalized_mode = "off" - self.access_mode = normalized_mode - self._allowed_group_ids_set = {int(item) for item in self.allowed_group_ids} - self._blocked_group_ids_set = {int(item) for item in self.blocked_group_ids} - self._allowed_private_ids_set = {int(item) for item in self.allowed_private_ids} - self._blocked_private_ids_set = {int(item) for item in self.blocked_private_ids} - self._bilibili_group_ids_set = { - int(item) for item in self.bilibili_auto_extract_group_ids - } - self._bilibili_private_ids_set = { - int(item) for item in self.bilibili_auto_extract_private_ids - } - self._arxiv_group_ids_set = { - int(item) for item in self.arxiv_auto_extract_group_ids - } - self._arxiv_private_ids_set = { - int(item) for item in self.arxiv_auto_extract_private_ids - } - self._github_group_ids_set = { - int(item) for item in self.github_auto_extract_group_ids - } - self._github_private_ids_set = { - int(item) for item in self.github_auto_extract_private_ids - } - - @classmethod - def load(cls, config_path: Optional[Path] = None, strict: bool = True) -> "Config": - """从 config.toml 和本地配置加载配置""" - _load_env() - data = load_toml_data(config_path, strict=strict) - - bot_qq = _coerce_int(_get_value(data, ("core", "bot_qq"), "BOT_QQ"), 0) - superadmin_qq = _coerce_int( - _get_value(data, ("core", "superadmin_qq"), "SUPERADMIN_QQ"), 0 - ) - admin_qqs = _coerce_int_list(_get_value(data, ("core", "admin_qq"), "ADMIN_QQ")) - forward_proxy = _coerce_int( - _get_value(data, ("core", "forward_proxy_qq"), "FORWARD_PROXY_QQ"), - 0, - ) - forward_proxy_qq = forward_proxy if forward_proxy > 0 else None - process_every_message = _coerce_bool( - _get_value( - data, - ("core", "process_every_message"), - "PROCESS_EVERY_MESSAGE", - ), - True, - ) - process_private_message = _coerce_bool( - _get_value( - data, - ("core", "process_private_message"), - "PROCESS_PRIVATE_MESSAGE", - ), - True, - ) - process_poke_message = _coerce_bool( - _get_value( - data, - ("core", "process_poke_message"), - "PROCESS_POKE_MESSAGE", - ), - True, - ) - keyword_reply_raw = _get_value( - data, - ("easter_egg", "keyword_reply_enabled"), - "KEYWORD_REPLY_ENABLED", - ) - if keyword_reply_raw is None: - # 兼容旧配置:历史上放在 [core].keyword_reply_enabled - keyword_reply_raw = _get_value( - data, - ("core", "keyword_reply_enabled"), - None, - ) - keyword_reply_enabled = _coerce_bool(keyword_reply_raw, False) - repeat_enabled = _coerce_bool( - _get_value( - data, - ("easter_egg", "repeat_enabled"), - "EASTER_EGG_REPEAT_ENABLED", - ), - False, - ) - inverted_question_enabled = _coerce_bool( - _get_value( - data, - ("easter_egg", "inverted_question_enabled"), - "EASTER_EGG_INVERTED_QUESTION_ENABLED", - ), - False, - ) - repeat_threshold = _coerce_int( - _get_value( - data, - ("easter_egg", "repeat_threshold"), - "EASTER_EGG_REPEAT_THRESHOLD", - ), - 3, - ) - if repeat_threshold < 2: - repeat_threshold = 2 - if repeat_threshold > 20: - repeat_threshold = 20 - repeat_cooldown_minutes = _coerce_int( - _get_value( - data, - ("easter_egg", "repeat_cooldown_minutes"), - "EASTER_EGG_REPEAT_COOLDOWN_MINUTES", - ), - 60, - ) - if repeat_cooldown_minutes < 0: - repeat_cooldown_minutes = 0 - context_recent_messages_limit = _coerce_int( - _get_value( - data, - ("core", "context_recent_messages_limit"), - "CONTEXT_RECENT_MESSAGES_LIMIT", - ), - 20, - ) - if context_recent_messages_limit < 0: - context_recent_messages_limit = 0 - - ai_request_max_retries = _coerce_int( - _get_value( - data, - ("core", "ai_request_max_retries"), - "AI_REQUEST_MAX_RETRIES", - ), - 2, - ) - if ai_request_max_retries < 0: - ai_request_max_retries = 0 - - missing_tool_call_retries = _coerce_int( - _get_value( - data, - ("core", "missing_tool_call_retries"), - "MISSING_TOOL_CALL_RETRIES", - ), - 3, - ) - if missing_tool_call_retries < 0: - missing_tool_call_retries = 0 - - nagaagent_mode_enabled = _coerce_bool( - _get_value( - data, - ("features", "nagaagent_mode_enabled"), - "NAGAAGENT_MODE_ENABLED", - ), - False, - ) - onebot_ws_url = _coerce_str( - _get_value(data, ("onebot", "ws_url"), "ONEBOT_WS_URL"), "" - ) - onebot_token = _coerce_str( - _get_value(data, ("onebot", "token"), "ONEBOT_TOKEN"), "" - ) - - embedding_model = _parse_embedding_model_config(data) - rerank_model = _parse_rerank_model_config(data) - - knowledge_enabled = _coerce_bool( - _get_value(data, ("knowledge", "enabled"), None), False - ) - knowledge_base_dir = _coerce_str( - _get_value(data, ("knowledge", "base_dir"), None), "knowledge" - ) - knowledge_auto_scan = _coerce_bool( - _get_value(data, ("knowledge", "auto_scan"), None), False - ) - knowledge_auto_embed = _coerce_bool( - _get_value(data, ("knowledge", "auto_embed"), None), False - ) - knowledge_scan_interval = _coerce_float( - _get_value(data, ("knowledge", "scan_interval"), None), 60.0 - ) - if knowledge_scan_interval <= 0: - knowledge_scan_interval = 60.0 - knowledge_embed_batch_size = _coerce_int( - _get_value(data, ("knowledge", "embed_batch_size"), None), 64 - ) - if knowledge_embed_batch_size <= 0: - knowledge_embed_batch_size = 64 - knowledge_chunk_size = _coerce_int( - _get_value(data, ("knowledge", "chunk_size"), None), 10 - ) - if knowledge_chunk_size <= 0: - knowledge_chunk_size = 10 - knowledge_chunk_overlap = _coerce_int( - _get_value(data, ("knowledge", "chunk_overlap"), None), 2 - ) - if knowledge_chunk_overlap < 0: - knowledge_chunk_overlap = 0 - knowledge_default_top_k = _coerce_int( - _get_value(data, ("knowledge", "default_top_k"), None), 5 - ) - if knowledge_default_top_k <= 0: - knowledge_default_top_k = 5 - knowledge_enable_rerank = _coerce_bool( - _get_value(data, ("knowledge", "enable_rerank"), None), False - ) - knowledge_rerank_top_k = _coerce_int( - _get_value(data, ("knowledge", "rerank_top_k"), None), 3 - ) - if knowledge_rerank_top_k <= 0: - knowledge_rerank_top_k = 3 - if knowledge_default_top_k <= 1 and knowledge_enable_rerank: - logger.warning( - "[配置] knowledge.default_top_k=%s,无法满足 rerank_top_k < default_top_k," - "已自动禁用重排", - knowledge_default_top_k, - ) - knowledge_enable_rerank = False - if knowledge_rerank_top_k >= knowledge_default_top_k: - fallback = knowledge_default_top_k - 1 - if fallback <= 0: - fallback = 1 - knowledge_enable_rerank = False - logger.warning( - "[配置] knowledge.rerank_top_k 需小于 knowledge.default_top_k," - "且当前 default_top_k=%s 无法满足约束,已自动禁用重排", - knowledge_default_top_k, - ) - else: - logger.warning( - "[配置] knowledge.rerank_top_k 需小于 knowledge.default_top_k," - "已回退: rerank_top_k=%s -> %s (default_top_k=%s)", - knowledge_rerank_top_k, - fallback, - knowledge_default_top_k, - ) - knowledge_rerank_top_k = fallback - - chat_model = _parse_chat_model_config(data) - vision_model = _parse_vision_model_config(data) - security_model_enabled = _coerce_bool( - _get_value( - data, - ("models", "security", "enabled"), - "SECURITY_MODEL_ENABLED", - ), - True, - ) - security_model = _parse_security_model_config(data, chat_model) - naga_model = _parse_naga_model_config(data, security_model) - agent_model = _parse_agent_model_config(data) - historian_model = _parse_historian_model_config(data, agent_model) - summary_model, summary_model_configured = _parse_summary_model_config( - data, agent_model - ) - grok_model = _parse_grok_model_config(data) - - model_pool_enabled = _coerce_bool( - _get_value(data, ("features", "pool_enabled"), "MODEL_POOL_ENABLED"), False - ) - - superadmin_qq, admin_qqs = _merge_admins( - superadmin_qq=superadmin_qq, admin_qqs=admin_qqs - ) - - access_mode_raw = _get_value(data, ("access", "mode"), "ACCESS_MODE") - allowed_group_ids = _coerce_int_list( - _get_value(data, ("access", "allowed_group_ids"), "ALLOWED_GROUP_IDS") - ) - blocked_group_ids = _coerce_int_list( - _get_value(data, ("access", "blocked_group_ids"), "BLOCKED_GROUP_IDS") - ) - allowed_private_ids = _coerce_int_list( - _get_value(data, ("access", "allowed_private_ids"), "ALLOWED_PRIVATE_IDS") - ) - blocked_private_ids = _coerce_int_list( - _get_value(data, ("access", "blocked_private_ids"), "BLOCKED_PRIVATE_IDS") - ) - superadmin_bypass_allowlist = _coerce_bool( - _get_value( - data, - ("access", "superadmin_bypass_allowlist"), - "SUPERADMIN_BYPASS_ALLOWLIST", - ), - True, - ) - superadmin_bypass_private_blacklist = _coerce_bool( - _get_value( - data, - ("access", "superadmin_bypass_private_blacklist"), - "SUPERADMIN_BYPASS_PRIVATE_BLACKLIST", - ), - False, - ) - if access_mode_raw is None: - # 兼容旧配置:未配置 mode 时沿用历史行为(群黑名单 + 白名单联动)。 - if ( - allowed_group_ids - or blocked_group_ids - or allowed_private_ids - or blocked_private_ids - ): - access_mode = "legacy" - logger.warning( - "[配置] access.mode 未设置,已启用兼容模式(legacy)。建议显式设置为 off/blacklist/allowlist。" - ) - else: - access_mode = "off" - else: - access_mode = _coerce_str(access_mode_raw, "off").lower() - if access_mode not in {"off", "blacklist", "allowlist"}: - logger.warning( - "[配置] access.mode 非法(仅支持 off/blacklist/allowlist),已回退为 off: %s", - access_mode, - ) - access_mode = "off" - - log_level = _coerce_str( - _get_value(data, ("logging", "level"), "LOG_LEVEL"), "INFO" - ).upper() - log_file_path = _coerce_str( - _get_value(data, ("logging", "file_path"), "LOG_FILE_PATH"), - "logs/bot.log", - ) - log_max_size_mb = _coerce_int( - _get_value(data, ("logging", "max_size_mb"), "LOG_MAX_SIZE_MB"), 10 - ) - log_backup_count = _coerce_int( - _get_value(data, ("logging", "backup_count"), "LOG_BACKUP_COUNT"), 5 - ) - log_tty_enabled = _coerce_bool( - _get_value(data, ("logging", "tty_enabled"), "LOG_TTY_ENABLED"), - False, - ) - log_thinking = _coerce_bool( - _get_value(data, ("logging", "log_thinking"), "LOG_THINKING"), True - ) - - tools_dot_delimiter = _coerce_str( - _get_value(data, ("tools", "dot_delimiter"), "TOOLS_DOT_DELIMITER"), "-_-" - ).strip() - if not tools_dot_delimiter: - tools_dot_delimiter = "-_-" - # dot_delimiter 必须满足 OpenAI 兼容的 function.name 约束。 - if "." in tools_dot_delimiter or not re.fullmatch( - r"[a-zA-Z0-9_-]+", tools_dot_delimiter - ): - logger.warning( - "[配置] tools.dot_delimiter 非法(仅允许 [a-zA-Z0-9_-] 且不能包含 '.'),已回退默认值: '-_-'(当前=%s)", - tools_dot_delimiter, - ) - tools_dot_delimiter = "-_-" - tools_description_max_len = _coerce_int( - _get_value( - data, ("tools", "description_max_len"), "TOOLS_DESCRIPTION_MAX_LEN" - ), - 1024, - ) - tools_description_truncate_enabled = _coerce_bool( - _get_value( - data, - ("tools", "description_truncate_enabled"), - "TOOLS_DESCRIPTION_TRUNCATE_ENABLED", - ), - False, - ) - tools_sanitize_verbose = _coerce_bool( - _get_value(data, ("tools", "sanitize_verbose"), "TOOLS_SANITIZE_VERBOSE"), - False, - ) - tools_description_preview_len = _coerce_int( - _get_value( - data, - ("tools", "description_preview_len"), - "TOOLS_DESCRIPTION_PREVIEW_LEN", - ), - 160, - ) - - easter_egg_mode_raw = _get_value( - data, - ("easter_egg", "agent_call_message_enabled"), - "EASTER_EGG_AGENT_CALL_MESSAGE_ENABLED", - ) - if easter_egg_mode_raw is None: - easter_egg_mode_raw = os.getenv("EASTER_EGG_AGENT_CALL_MESSAGE_MODE") - if easter_egg_mode_raw is not None: - _warn_env_fallback("EASTER_EGG_AGENT_CALL_MESSAGE_MODE") - else: - easter_egg_mode_raw = os.getenv("EASTER_EGG_CALL_MESSAGE_MODE") - if easter_egg_mode_raw is not None: - _warn_env_fallback("EASTER_EGG_CALL_MESSAGE_MODE") - - easter_egg_agent_call_message_mode = _parse_easter_egg_call_mode( - easter_egg_mode_raw - ) - - token_usage_max_size_mb = _coerce_int( - _get_value(data, ("token_usage", "max_size_mb"), "TOKEN_USAGE_MAX_SIZE_MB"), - 5, - ) - token_usage_max_archives = _coerce_int( - _get_value( - data, ("token_usage", "max_archives"), "TOKEN_USAGE_MAX_ARCHIVES" - ), - 30, - ) - token_usage_max_total_mb = _coerce_int( - _get_value( - data, ("token_usage", "max_total_mb"), "TOKEN_USAGE_MAX_TOTAL_MB" - ), - 0, - ) - token_usage_archive_prune_mode = _coerce_str( - _get_value( - data, - ("token_usage", "archive_prune_mode"), - "TOKEN_USAGE_ARCHIVE_PRUNE_MODE", - ), - "delete", - ) - - history_max_records = max( - 0, - _coerce_int( - _get_value(data, ("history", "max_records"), "HISTORY_MAX_RECORDS"), - 10000, - ), - ) - history_filtered_result_limit = max( - 1, - _coerce_int( - _get_value( - data, - ("history", "filtered_result_limit"), - "HISTORY_FILTERED_RESULT_LIMIT", - ), - 200, - ), - ) - history_search_scan_limit = max( - 1, - _coerce_int( - _get_value( - data, - ("history", "search_scan_limit"), - "HISTORY_SEARCH_SCAN_LIMIT", - ), - 10000, - ), - ) - history_summary_fetch_limit = max( - 1, - _coerce_int( - _get_value( - data, - ("history", "summary_fetch_limit"), - "HISTORY_SUMMARY_FETCH_LIMIT", - ), - 1000, - ), - ) - history_summary_time_fetch_limit = max( - 1, - _coerce_int( - _get_value( - data, - ("history", "summary_time_fetch_limit"), - "HISTORY_SUMMARY_TIME_FETCH_LIMIT", - ), - 5000, - ), - ) - history_onebot_fetch_limit = max( - 1, - _coerce_int( - _get_value( - data, - ("history", "onebot_fetch_limit"), - "HISTORY_ONEBOT_FETCH_LIMIT", - ), - 10000, - ), - ) - history_group_analysis_limit = max( - 1, - _coerce_int( - _get_value( - data, - ("history", "group_analysis_limit"), - "HISTORY_GROUP_ANALYSIS_LIMIT", - ), - 500, - ), - ) - attachment_remote_download_max_size_mb = max( - 0, - _coerce_int( - _get_value( - data, - ("attachments", "remote_download_max_size_mb"), - "ATTACHMENTS_REMOTE_DOWNLOAD_MAX_SIZE_MB", - ), - 25, - ), - ) - attachment_cache_max_total_size_mb = max( - 0, - _coerce_int( - _get_value( - data, - ("attachments", "cache_max_total_size_mb"), - "ATTACHMENTS_CACHE_MAX_TOTAL_SIZE_MB", - ), - 0, - ), - ) - attachment_cache_max_records = max( - 0, - _coerce_int( - _get_value( - data, - ("attachments", "cache_max_records"), - "ATTACHMENTS_CACHE_MAX_RECORDS", - ), - 2000, - ), - ) - attachment_cache_max_age_days = max( - 0, - _coerce_int( - _get_value( - data, - ("attachments", "cache_max_age_days"), - "ATTACHMENTS_CACHE_MAX_AGE_DAYS", - ), - 7, - ), - ) - attachment_url_reference_max_records = max( - 0, - _coerce_int( - _get_value( - data, - ("attachments", "url_reference_max_records"), - "ATTACHMENTS_URL_REFERENCE_MAX_RECORDS", - ), - 2000, - ), - ) - attachment_url_max_length = max( - 0, - _coerce_int( - _get_value( - data, - ("attachments", "url_max_length"), - "ATTACHMENTS_URL_MAX_LENGTH", - ), - 8192, - ), - ) - - skills_hot_reload = _coerce_bool( - _get_value(data, ("skills", "hot_reload"), "SKILLS_HOT_RELOAD"), True - ) - skills_hot_reload_interval = _coerce_float( - _get_value( - data, ("skills", "hot_reload_interval"), "SKILLS_HOT_RELOAD_INTERVAL" - ), - 2.0, - ) - skills_hot_reload_debounce = _coerce_float( - _get_value( - data, ("skills", "hot_reload_debounce"), "SKILLS_HOT_RELOAD_DEBOUNCE" - ), - 0.5, - ) - - agent_intro_autogen_enabled = _coerce_bool( - _get_value( - data, - ("skills", "intro_autogen_enabled"), - "AGENT_INTRO_AUTOGEN_ENABLED", - ), - True, - ) - agent_intro_autogen_queue_interval = _coerce_float( - _get_value( - data, - ("skills", "intro_autogen_queue_interval"), - "AGENT_INTRO_AUTOGEN_QUEUE_INTERVAL", - ), - 1.0, - ) - agent_intro_autogen_queue_interval = _normalize_queue_interval( - agent_intro_autogen_queue_interval - ) - agent_intro_autogen_max_tokens = _coerce_int( - _get_value( - data, - ("skills", "intro_autogen_max_tokens"), - "AGENT_INTRO_AUTOGEN_MAX_TOKENS", - ), - 8192, - ) - agent_intro_hash_path = _coerce_str( - _get_value(data, ("skills", "intro_hash_path"), "AGENT_INTRO_HASH_PATH"), - ".cache/agent_intro_hashes.json", - ) - - prefetch_tools_raw = _get_value( - data, ("skills", "prefetch_tools"), "PREFETCH_TOOLS" - ) - prefetch_tools = _coerce_str_list(prefetch_tools_raw) - if not prefetch_tools and prefetch_tools_raw is None: - prefetch_tools = ["get_current_time"] - prefetch_tools_hide = _coerce_bool( - _get_value(data, ("skills", "prefetch_tools_hide"), "PREFETCH_TOOLS_HIDE"), - True, - ) - - searxng_url = _coerce_str( - _get_value(data, ("search", "searxng_url"), "SEARXNG_URL"), "" - ) - grok_search_enabled = _coerce_bool( - _get_value( - data, - ("search", "grok_search_enabled"), - "GROK_SEARCH_ENABLED", - ), - False, - ) - - use_proxy = _coerce_bool( - _get_value(data, ("proxy", "use_proxy"), "USE_PROXY"), True - ) - http_proxy = _coerce_str( - _get_value(data, ("proxy", "http_proxy"), "http_proxy"), "" - ) - if not http_proxy: - http_proxy = _coerce_str(os.getenv("HTTP_PROXY"), "") - if http_proxy: - _warn_env_fallback("HTTP_PROXY") - https_proxy = _coerce_str( - _get_value(data, ("proxy", "https_proxy"), "https_proxy"), "" - ) - if not https_proxy: - https_proxy = _coerce_str(os.getenv("HTTPS_PROXY"), "") - if https_proxy: - _warn_env_fallback("HTTPS_PROXY") - - network_request_timeout = _coerce_float( - _get_value( - data, - ("network", "request_timeout_seconds"), - "NETWORK_REQUEST_TIMEOUT_SECONDS", - ), - 30.0, - ) - if network_request_timeout <= 0: - network_request_timeout = 480.0 - - network_request_retries = _coerce_int( - _get_value( - data, - ("network", "request_retries"), - "NETWORK_REQUEST_RETRIES", - ), - 0, - ) - if network_request_retries < 0: - network_request_retries = 0 - if network_request_retries > 5: - network_request_retries = 5 - - render_browser_max_concurrency = max( - 0, - _coerce_int( - _get_value( - data, - ("render", "browser_max_concurrency"), - "RENDER_BROWSER_MAX_CONCURRENCY", - ), - 0, - ), - ) - - api_xxapi_base_url = _normalize_base_url( - _coerce_str( - _get_value(data, ("api_endpoints", "xxapi_base_url"), "XXAPI_BASE_URL"), - "https://v2.xxapi.cn", - ), - "https://v2.xxapi.cn", - ) - api_xingzhige_base_url = _normalize_base_url( - _coerce_str( - _get_value( - data, - ("api_endpoints", "xingzhige_base_url"), - "XINGZHIGE_BASE_URL", - ), - "https://api.xingzhige.com", - ), - "https://api.xingzhige.com", - ) - api_jkyai_base_url = _normalize_base_url( - _coerce_str( - _get_value(data, ("api_endpoints", "jkyai_base_url"), "JKYAI_BASE_URL"), - "https://api.jkyai.top", - ), - "https://api.jkyai.top", - ) - api_seniverse_base_url = _normalize_base_url( - _coerce_str( - _get_value( - data, - ("api_endpoints", "seniverse_base_url"), - "SENIVERSE_BASE_URL", - ), - "https://api.seniverse.com/v3", - ), - "https://api.seniverse.com/v3", - ) - - weather_api_key = _coerce_str( - _get_value(data, ("weather", "api_key"), "WEATHER_API_KEY"), "" - ) - xxapi_api_token = _coerce_str( - _get_value(data, ("xxapi", "api_token"), "XXAPI_API_TOKEN"), "" - ) - - mcp_config_path = _coerce_str( - _get_value(data, ("mcp", "config_path"), "MCP_CONFIG_PATH"), - "config/mcp.json", - ) - - # Bilibili 配置 - bilibili_auto_extract_enabled = _coerce_bool( - _get_value(data, ("bilibili", "auto_extract_enabled"), None), False - ) - bilibili_cookie = _coerce_str( - _get_value(data, ("bilibili", "cookie"), None), "" - ) - if not bilibili_cookie: - # 兼容旧配置项:bilibili.sessdata - bilibili_cookie = _coerce_str( - _get_value(data, ("bilibili", "sessdata"), None), "" - ) - bilibili_prefer_quality = _coerce_int( - _get_value(data, ("bilibili", "prefer_quality"), None), 80 - ) - bilibili_max_duration = _coerce_int( - _get_value(data, ("bilibili", "max_duration"), None), 600 - ) - bilibili_max_file_size = _coerce_int( - _get_value(data, ("bilibili", "max_file_size"), None), 100 - ) - bilibili_oversize_strategy = _coerce_str( - _get_value(data, ("bilibili", "oversize_strategy"), None), "downgrade" - ) - if bilibili_oversize_strategy not in ("downgrade", "info"): - bilibili_oversize_strategy = "downgrade" - bilibili_danmaku_enabled = _coerce_bool( - _get_value(data, ("bilibili", "danmaku_enabled"), None), True - ) - bilibili_danmaku_batch_size = _coerce_int( - _get_value(data, ("bilibili", "danmaku_batch_size"), None), 100 - ) - if bilibili_danmaku_batch_size <= 0: - bilibili_danmaku_batch_size = 100 - bilibili_danmaku_max_count = _coerce_int( - _get_value(data, ("bilibili", "danmaku_max_count"), None), 0 - ) - if bilibili_danmaku_max_count < 0: - bilibili_danmaku_max_count = 0 - bilibili_auto_extract_group_ids = _coerce_int_list( - _get_value(data, ("bilibili", "auto_extract_group_ids"), None) - ) - bilibili_auto_extract_private_ids = _coerce_int_list( - _get_value(data, ("bilibili", "auto_extract_private_ids"), None) - ) - - # arXiv 配置 - arxiv_auto_extract_enabled = _coerce_bool( - _get_value(data, ("arxiv", "auto_extract_enabled"), None), False - ) - arxiv_max_file_size = _coerce_int( - _get_value(data, ("arxiv", "max_file_size"), None), 100 - ) - if arxiv_max_file_size < 0: - arxiv_max_file_size = 100 - arxiv_auto_extract_group_ids = _coerce_int_list( - _get_value(data, ("arxiv", "auto_extract_group_ids"), None) - ) - arxiv_auto_extract_private_ids = _coerce_int_list( - _get_value(data, ("arxiv", "auto_extract_private_ids"), None) - ) - arxiv_auto_extract_max_items = _coerce_int( - _get_value(data, ("arxiv", "auto_extract_max_items"), None), 5 - ) - if arxiv_auto_extract_max_items <= 0: - arxiv_auto_extract_max_items = 5 - if arxiv_auto_extract_max_items > 20: - arxiv_auto_extract_max_items = 20 - arxiv_author_preview_limit = _coerce_int( - _get_value(data, ("arxiv", "author_preview_limit"), None), 20 - ) - if arxiv_author_preview_limit <= 0: - arxiv_author_preview_limit = 20 - if arxiv_author_preview_limit > 100: - arxiv_author_preview_limit = 100 - arxiv_summary_preview_chars = _coerce_int( - _get_value(data, ("arxiv", "summary_preview_chars"), None), 1000 - ) - if arxiv_summary_preview_chars <= 0: - arxiv_summary_preview_chars = 1000 - if arxiv_summary_preview_chars > 8000: - arxiv_summary_preview_chars = 8000 - - # GitHub 配置 - github_auto_extract_enabled = _coerce_bool( - _get_value(data, ("github", "auto_extract_enabled"), None), False - ) - github_request_timeout_seconds = _coerce_float( - _get_value(data, ("github", "request_timeout_seconds"), None), 10.0 - ) - if github_request_timeout_seconds <= 0: - github_request_timeout_seconds = 10.0 - if github_request_timeout_seconds > 60.0: - github_request_timeout_seconds = 60.0 - github_auto_extract_group_ids = _coerce_int_list( - _get_value(data, ("github", "auto_extract_group_ids"), None) - ) - github_auto_extract_private_ids = _coerce_int_list( - _get_value(data, ("github", "auto_extract_private_ids"), None) - ) - github_auto_extract_max_items = _coerce_int( - _get_value(data, ("github", "auto_extract_max_items"), None), 3 - ) - if github_auto_extract_max_items <= 0: - github_auto_extract_max_items = 3 - if github_auto_extract_max_items > 10: - github_auto_extract_max_items = 10 - - # Code Delivery Agent 配置 - code_delivery_enabled = _coerce_bool( - _get_value(data, ("code_delivery", "enabled"), None), True - ) - code_delivery_task_root = _coerce_str( - _get_value(data, ("code_delivery", "task_root"), None), - "data/code_delivery", - ) - code_delivery_docker_image = _coerce_str( - _get_value(data, ("code_delivery", "docker_image"), None), - "ubuntu:24.04", - ) - code_delivery_container_name_prefix = _coerce_str( - _get_value(data, ("code_delivery", "container_name_prefix"), None), - "code_delivery_", - ) - code_delivery_container_name_suffix = _coerce_str( - _get_value(data, ("code_delivery", "container_name_suffix"), None), - "_runner", - ) - code_delivery_command_timeout = _coerce_int( - _get_value( - data, ("code_delivery", "default_command_timeout_seconds"), None - ), - 600, - ) - code_delivery_max_command_output = _coerce_int( - _get_value(data, ("code_delivery", "max_command_output_chars"), None), - 20000, - ) - code_delivery_default_archive_format = _coerce_str( - _get_value(data, ("code_delivery", "default_archive_format"), None), - "zip", - ) - if code_delivery_default_archive_format not in ("zip", "tar.gz"): - code_delivery_default_archive_format = "zip" - code_delivery_max_archive_size_mb = _coerce_int( - _get_value(data, ("code_delivery", "max_archive_size_mb"), None), 200 - ) - code_delivery_cleanup_on_finish = _coerce_bool( - _get_value(data, ("code_delivery", "cleanup_on_finish"), None), True - ) - code_delivery_cleanup_on_start = _coerce_bool( - _get_value(data, ("code_delivery", "cleanup_on_start"), None), True - ) - code_delivery_llm_max_retries = _coerce_int( - _get_value(data, ("code_delivery", "llm_max_retries_per_request"), None), - 5, - ) - code_delivery_notify_on_llm_failure = _coerce_bool( - _get_value(data, ("code_delivery", "notify_on_llm_failure"), None), - True, - ) - code_delivery_container_memory_limit = _coerce_str( - _get_value(data, ("code_delivery", "container_memory_limit"), None), - "", - ) - code_delivery_container_cpu_limit = _coerce_str( - _get_value(data, ("code_delivery", "container_cpu_limit"), None), - "", - ) - code_delivery_command_blacklist_raw = _get_value( - data, ("code_delivery", "command_blacklist"), None - ) - if isinstance(code_delivery_command_blacklist_raw, list): - code_delivery_command_blacklist = [ - str(x) for x in code_delivery_command_blacklist_raw - ] - else: - code_delivery_command_blacklist = [] - - # messages 工具集配置 - messages_send_text_file_max_size_kb = _coerce_int( - _get_value( - data, - ("messages", "send_text_file_max_size_kb"), - "MESSAGES_SEND_TEXT_FILE_MAX_SIZE_KB", - ), - 512, - ) - if messages_send_text_file_max_size_kb <= 0: - messages_send_text_file_max_size_kb = 512 - - messages_send_url_file_max_size_mb = _coerce_int( - _get_value( - data, - ("messages", "send_url_file_max_size_mb"), - "MESSAGES_SEND_URL_FILE_MAX_SIZE_MB", - ), - 100, - ) - if messages_send_url_file_max_size_mb <= 0: - messages_send_url_file_max_size_mb = 100 - - webui_settings = load_webui_settings(config_path) - api_config = _parse_api_config(data) - - cognitive = _parse_cognitive_config(data) - memes = _parse_memes_config(data) - message_batcher = _parse_message_batcher_config(data) - render_cache = _parse_render_cache_config(data) - naga = _parse_naga_config(data) - models_image_gen = _parse_image_gen_model_config(data) - models_image_edit = _parse_image_edit_model_config(data) - image_gen = _parse_image_gen_config(data) - - if strict: - _verify_required_fields( - bot_qq=bot_qq, - superadmin_qq=superadmin_qq, - onebot_ws_url=onebot_ws_url, - chat_model=chat_model, - vision_model=vision_model, - agent_model=agent_model, - knowledge_enabled=knowledge_enabled, - embedding_model=embedding_model, - ) - - _log_debug_info( - chat_model, - vision_model, - security_model, - naga_model, - agent_model, - summary_model, - grok_model, - ) - - return cls( - bot_qq=bot_qq, - superadmin_qq=superadmin_qq, - admin_qqs=admin_qqs, - access_mode=access_mode, - allowed_group_ids=allowed_group_ids, - blocked_group_ids=blocked_group_ids, - allowed_private_ids=allowed_private_ids, - blocked_private_ids=blocked_private_ids, - superadmin_bypass_allowlist=superadmin_bypass_allowlist, - superadmin_bypass_private_blacklist=superadmin_bypass_private_blacklist, - forward_proxy_qq=forward_proxy_qq, - process_every_message=process_every_message, - process_private_message=process_private_message, - process_poke_message=process_poke_message, - keyword_reply_enabled=keyword_reply_enabled, - repeat_enabled=repeat_enabled, - repeat_threshold=repeat_threshold, - repeat_cooldown_minutes=repeat_cooldown_minutes, - inverted_question_enabled=inverted_question_enabled, - context_recent_messages_limit=context_recent_messages_limit, - ai_request_max_retries=ai_request_max_retries, - missing_tool_call_retries=missing_tool_call_retries, - nagaagent_mode_enabled=nagaagent_mode_enabled, - onebot_ws_url=onebot_ws_url, - onebot_token=onebot_token, - chat_model=chat_model, - vision_model=vision_model, - security_model_enabled=security_model_enabled, - security_model=security_model, - naga_model=naga_model, - agent_model=agent_model, - historian_model=historian_model, - summary_model=summary_model, - summary_model_configured=summary_model_configured, - grok_model=grok_model, - model_pool_enabled=model_pool_enabled, - log_level=log_level, - log_file_path=log_file_path, - log_max_size=log_max_size_mb * 1024 * 1024, - log_backup_count=log_backup_count, - log_tty_enabled=log_tty_enabled, - log_thinking=log_thinking, - tools_dot_delimiter=tools_dot_delimiter, - tools_description_truncate_enabled=tools_description_truncate_enabled, - tools_description_max_len=tools_description_max_len, - tools_sanitize_verbose=tools_sanitize_verbose, - tools_description_preview_len=tools_description_preview_len, - easter_egg_agent_call_message_mode=easter_egg_agent_call_message_mode, - token_usage_max_size_mb=token_usage_max_size_mb, - token_usage_max_archives=token_usage_max_archives, - token_usage_max_total_mb=token_usage_max_total_mb, - token_usage_archive_prune_mode=token_usage_archive_prune_mode, - skills_hot_reload=skills_hot_reload, - history_max_records=history_max_records, - history_filtered_result_limit=history_filtered_result_limit, - history_search_scan_limit=history_search_scan_limit, - history_summary_fetch_limit=history_summary_fetch_limit, - history_summary_time_fetch_limit=history_summary_time_fetch_limit, - history_onebot_fetch_limit=history_onebot_fetch_limit, - history_group_analysis_limit=history_group_analysis_limit, - attachment_remote_download_max_size_mb=attachment_remote_download_max_size_mb, - attachment_cache_max_total_size_mb=attachment_cache_max_total_size_mb, - attachment_cache_max_records=attachment_cache_max_records, - attachment_cache_max_age_days=attachment_cache_max_age_days, - attachment_url_reference_max_records=attachment_url_reference_max_records, - attachment_url_max_length=attachment_url_max_length, - skills_hot_reload_interval=skills_hot_reload_interval, - skills_hot_reload_debounce=skills_hot_reload_debounce, - agent_intro_autogen_enabled=agent_intro_autogen_enabled, - agent_intro_autogen_queue_interval=agent_intro_autogen_queue_interval, - agent_intro_autogen_max_tokens=agent_intro_autogen_max_tokens, - agent_intro_hash_path=agent_intro_hash_path, - searxng_url=searxng_url, - grok_search_enabled=grok_search_enabled, - use_proxy=use_proxy, - http_proxy=http_proxy, - https_proxy=https_proxy, - network_request_timeout=network_request_timeout, - network_request_retries=network_request_retries, - render_browser_max_concurrency=render_browser_max_concurrency, - api_xxapi_base_url=api_xxapi_base_url, - api_xingzhige_base_url=api_xingzhige_base_url, - api_jkyai_base_url=api_jkyai_base_url, - api_seniverse_base_url=api_seniverse_base_url, - weather_api_key=weather_api_key, - xxapi_api_token=xxapi_api_token, - mcp_config_path=mcp_config_path, - prefetch_tools=prefetch_tools, - prefetch_tools_hide=prefetch_tools_hide, - webui_url=webui_settings.url, - webui_port=webui_settings.port, - webui_password=webui_settings.password, - api=api_config, - code_delivery_enabled=code_delivery_enabled, - code_delivery_task_root=code_delivery_task_root, - code_delivery_docker_image=code_delivery_docker_image, - code_delivery_container_name_prefix=code_delivery_container_name_prefix, - code_delivery_container_name_suffix=code_delivery_container_name_suffix, - code_delivery_command_timeout=code_delivery_command_timeout, - code_delivery_max_command_output=code_delivery_max_command_output, - code_delivery_default_archive_format=code_delivery_default_archive_format, - code_delivery_max_archive_size_mb=code_delivery_max_archive_size_mb, - code_delivery_cleanup_on_finish=code_delivery_cleanup_on_finish, - code_delivery_cleanup_on_start=code_delivery_cleanup_on_start, - code_delivery_llm_max_retries=code_delivery_llm_max_retries, - code_delivery_notify_on_llm_failure=code_delivery_notify_on_llm_failure, - code_delivery_container_memory_limit=code_delivery_container_memory_limit, - code_delivery_container_cpu_limit=code_delivery_container_cpu_limit, - code_delivery_command_blacklist=code_delivery_command_blacklist, - messages_send_text_file_max_size_kb=messages_send_text_file_max_size_kb, - messages_send_url_file_max_size_mb=messages_send_url_file_max_size_mb, - bilibili_auto_extract_enabled=bilibili_auto_extract_enabled, - bilibili_cookie=bilibili_cookie, - bilibili_prefer_quality=bilibili_prefer_quality, - bilibili_max_duration=bilibili_max_duration, - bilibili_max_file_size=bilibili_max_file_size, - bilibili_oversize_strategy=bilibili_oversize_strategy, - bilibili_danmaku_enabled=bilibili_danmaku_enabled, - bilibili_danmaku_batch_size=bilibili_danmaku_batch_size, - bilibili_danmaku_max_count=bilibili_danmaku_max_count, - bilibili_auto_extract_group_ids=bilibili_auto_extract_group_ids, - bilibili_auto_extract_private_ids=bilibili_auto_extract_private_ids, - arxiv_auto_extract_enabled=arxiv_auto_extract_enabled, - arxiv_max_file_size=arxiv_max_file_size, - arxiv_auto_extract_group_ids=arxiv_auto_extract_group_ids, - arxiv_auto_extract_private_ids=arxiv_auto_extract_private_ids, - arxiv_auto_extract_max_items=arxiv_auto_extract_max_items, - arxiv_author_preview_limit=arxiv_author_preview_limit, - arxiv_summary_preview_chars=arxiv_summary_preview_chars, - github_auto_extract_enabled=github_auto_extract_enabled, - github_request_timeout_seconds=github_request_timeout_seconds, - github_auto_extract_group_ids=github_auto_extract_group_ids, - github_auto_extract_private_ids=github_auto_extract_private_ids, - github_auto_extract_max_items=github_auto_extract_max_items, - embedding_model=embedding_model, - rerank_model=rerank_model, - knowledge_enabled=knowledge_enabled, - knowledge_base_dir=knowledge_base_dir, - knowledge_auto_scan=knowledge_auto_scan, - knowledge_auto_embed=knowledge_auto_embed, - knowledge_scan_interval=knowledge_scan_interval, - knowledge_embed_batch_size=knowledge_embed_batch_size, - knowledge_chunk_size=knowledge_chunk_size, - knowledge_chunk_overlap=knowledge_chunk_overlap, - knowledge_default_top_k=knowledge_default_top_k, - knowledge_enable_rerank=knowledge_enable_rerank, - knowledge_rerank_top_k=knowledge_rerank_top_k, - cognitive=cognitive, - memes=memes, - message_batcher=message_batcher, - render_cache=render_cache, - naga=naga, - image_gen=image_gen, - models_image_gen=models_image_gen, - models_image_edit=models_image_edit, - ) - - @property - def bilibili_sessdata(self) -> str: - """兼容旧字段名,等价于 bilibili_cookie。""" - return self.bilibili_cookie - - def allowlist_mode_enabled(self) -> bool: - """是否启用白名单限制模式。""" - - return self.access_mode in {"allowlist", "legacy"} and ( - bool(self.allowed_group_ids) or bool(self.allowed_private_ids) - ) - - def group_allowlist_enabled(self) -> bool: - """群聊白名单是否生效(显式 allowlist 模式按维度独立控制)。""" - - return bool(self.allowed_group_ids) - - def private_allowlist_enabled(self) -> bool: - """私聊白名单是否生效(显式 allowlist 模式按维度独立控制)。""" - - return bool(self.allowed_private_ids) - - def blacklist_mode_enabled(self) -> bool: - """是否启用黑名单限制模式。""" - - return self.access_mode in {"blacklist", "legacy"} and ( - bool(self.blocked_group_ids) or bool(self.blocked_private_ids) - ) - - def access_control_enabled(self) -> bool: - """是否启用访问控制。""" - - return self.allowlist_mode_enabled() or self.blacklist_mode_enabled() - - def group_access_denied_reason(self, group_id: int) -> str | None: - """群聊访问被拒绝原因。 - - 返回: - - "blacklist": 命中 access.blocked_group_ids - - "allowlist": allowlist 模式下不在 access.allowed_group_ids - - None: 允许访问 - """ - - normalized_group_id = int(group_id) - if self.access_mode == "off": - return None - if self.access_mode == "blacklist": - if normalized_group_id in self._blocked_group_ids_set: - return "blacklist" - return None - if self.access_mode == "legacy": - if normalized_group_id in self._blocked_group_ids_set: - return "blacklist" - if not self.allowlist_mode_enabled(): - return None - if normalized_group_id not in self._allowed_group_ids_set: - return "allowlist" - return None - if not self.group_allowlist_enabled(): - return None - if normalized_group_id not in self._allowed_group_ids_set: - return "allowlist" - return None - - def is_group_allowed(self, group_id: int) -> bool: - """群聊是否允许收发消息。""" - - return self.group_access_denied_reason(group_id) is None - - def private_access_denied_reason(self, user_id: int) -> str | None: - """私聊访问被拒绝原因。""" - - normalized_user_id = int(user_id) - if self.access_mode == "off": - return None - if self.access_mode == "blacklist": - if normalized_user_id not in self._blocked_private_ids_set: - return None - if ( - self.superadmin_bypass_private_blacklist - and normalized_user_id == int(self.superadmin_qq) - and self.superadmin_qq > 0 - ): - return None - return "blacklist" - if self.access_mode == "legacy": - if normalized_user_id in self._blocked_private_ids_set: - if ( - self.superadmin_bypass_private_blacklist - and normalized_user_id == int(self.superadmin_qq) - and self.superadmin_qq > 0 - ): - return None - return "blacklist" - if not self.allowlist_mode_enabled(): - return None - if ( - self.superadmin_bypass_allowlist - and normalized_user_id == int(self.superadmin_qq) - and self.superadmin_qq > 0 - ): - return None - if normalized_user_id not in self._allowed_private_ids_set: - return "allowlist" - return None - if not self.private_allowlist_enabled(): - return None - if ( - self.superadmin_bypass_allowlist - and normalized_user_id == int(self.superadmin_qq) - and self.superadmin_qq > 0 - ): - return None - if normalized_user_id not in self._allowed_private_ids_set: - return "allowlist" - return None - - def is_private_allowed(self, user_id: int) -> bool: - """私聊是否允许收发消息。""" - - return self.private_access_denied_reason(user_id) is None - - def is_bilibili_auto_extract_allowed_group(self, group_id: int) -> bool: - """群聊是否允许 bilibili 自动提取。""" - if self._bilibili_group_ids_set: - return int(group_id) in self._bilibili_group_ids_set - # 功能白名单为空时跟随全局 access 控制 - return self.is_group_allowed(group_id) - - def is_bilibili_auto_extract_allowed_private(self, user_id: int) -> bool: - """私聊是否允许 bilibili 自动提取。""" - if self._bilibili_private_ids_set: - return int(user_id) in self._bilibili_private_ids_set - # 功能白名单为空时跟随全局 access 控制 - return self.is_private_allowed(user_id) - - def is_arxiv_auto_extract_allowed_group(self, group_id: int) -> bool: - """群聊是否允许 arXiv 自动提取。""" - if self._arxiv_group_ids_set: - return int(group_id) in self._arxiv_group_ids_set - return self.is_group_allowed(group_id) - - def is_arxiv_auto_extract_allowed_private(self, user_id: int) -> bool: - """私聊是否允许 arXiv 自动提取。""" - if self._arxiv_private_ids_set: - return int(user_id) in self._arxiv_private_ids_set - return self.is_private_allowed(user_id) - - def is_github_auto_extract_allowed_group(self, group_id: int) -> bool: - """群聊是否允许 GitHub 仓库自动提取。""" - if self._github_group_ids_set: - return int(group_id) in self._github_group_ids_set - return self.is_group_allowed(group_id) - - def is_github_auto_extract_allowed_private(self, user_id: int) -> bool: - """私聊是否允许 GitHub 仓库自动提取。""" - if self._github_private_ids_set: - return int(user_id) in self._github_private_ids_set - return self.is_private_allowed(user_id) - - def should_process_group_message(self, is_at_bot: bool) -> bool: - """是否处理该条群消息。""" - - if self.process_every_message: - return True - return bool(is_at_bot) - - def should_process_private_message(self) -> bool: - """是否处理私聊消息回复。""" - - return bool(self.process_private_message) - - def should_process_poke_message(self) -> bool: - """是否处理拍一拍触发。""" - - return bool(self.process_poke_message) - - def get_context_recent_messages_limit(self) -> int: - """获取上下文最近历史消息条数上限。""" - - limit = int(self.context_recent_messages_limit) - if limit < 0: - return 0 - return limit - - def security_check_enabled(self) -> bool: - """是否启用安全模型检查。""" - - return bool(self.security_model_enabled) - - def update_from(self, new_config: "Config") -> dict[str, tuple[Any, Any]]: - changes: dict[str, tuple[Any, Any]] = {} - for field in fields(self): - name = field.name - old_value = getattr(self, name) - new_value = getattr(new_config, name) - if isinstance( - old_value, - ( - ChatModelConfig, - VisionModelConfig, - SecurityModelConfig, - AgentModelConfig, - GrokModelConfig, - ), - ): - changes.update(_update_dataclass(old_value, new_value, prefix=name)) - continue - if old_value != new_value: - setattr(self, name, new_value) - changes[name] = (old_value, new_value) - return changes - - def reload(self, strict: bool = False) -> dict[str, tuple[Any, Any]]: - new_config = Config.load(strict=strict) - return self.update_from(new_config) - - def add_admin(self, qq: int) -> bool: - if qq in self.admin_qqs: - return False - self.admin_qqs.append(qq) - local_admins = load_local_admins() - if qq not in local_admins: - local_admins.append(qq) - save_local_admins(local_admins) - return True - - def remove_admin(self, qq: int) -> bool: - if qq == self.superadmin_qq or qq not in self.admin_qqs: - return False - self.admin_qqs.remove(qq) - local_admins = load_local_admins() - if qq in local_admins: - local_admins.remove(qq) - save_local_admins(local_admins) - return True - - def is_superadmin(self, qq: int) -> bool: - return qq == self.superadmin_qq - - def is_admin(self, qq: int) -> bool: - return qq in self.admin_qqs diff --git a/src/Undefined/config/parsers/__init__.py b/src/Undefined/config/parsers/__init__.py new file mode 100644 index 00000000..faf5adeb --- /dev/null +++ b/src/Undefined/config/parsers/__init__.py @@ -0,0 +1,39 @@ +"""Model configuration parsers.""" + +# 模型配置解析:原始 dict → ChatModelConfig 等 dataclass +from .agent import _parse_agent_model_config +from .chat import _parse_chat_model_config +from .embedding import _parse_embedding_model_config, _parse_rerank_model_config +from .grok import _parse_grok_model_config +from .helpers import _log_debug_info, _merge_admins, _verify_required_fields +from .historian import _parse_historian_model_config +from .image import ( + _parse_image_edit_model_config, + _parse_image_gen_config, + _parse_image_gen_model_config, +) +from .naga import _parse_naga_model_config +from .pool import _parse_model_pool +from .security import _parse_security_model_config +from .summary import _parse_summary_model_config +from .vision import _parse_vision_model_config + +__all__ = [ + "_log_debug_info", + "_merge_admins", + "_parse_agent_model_config", + "_parse_chat_model_config", + "_parse_embedding_model_config", + "_parse_grok_model_config", + "_parse_historian_model_config", + "_parse_image_edit_model_config", + "_parse_image_gen_config", + "_parse_image_gen_model_config", + "_parse_model_pool", + "_parse_naga_model_config", + "_parse_rerank_model_config", + "_parse_security_model_config", + "_parse_summary_model_config", + "_parse_vision_model_config", + "_verify_required_fields", +] diff --git a/src/Undefined/config/parsers/agent.py b/src/Undefined/config/parsers/agent.py new file mode 100644 index 00000000..5170d181 --- /dev/null +++ b/src/Undefined/config/parsers/agent.py @@ -0,0 +1,163 @@ +"""Agent model parser.""" + +from __future__ import annotations + +# 模型配置解析:原始 dict → ChatModelConfig 等 dataclass + +import logging +from typing import Any + + +from ..coercers import ( + _coerce_bool, + _coerce_float, + _coerce_int, + _coerce_str, + _get_model_request_params, + _get_value, + _normalize_queue_interval, +) +from ..models import ( + AgentModelConfig, +) +from ..resolvers import ( + _resolve_api_mode, + _resolve_context_window_tokens, + _resolve_reasoning_effort, + _resolve_reasoning_effort_style, + _resolve_reasoning_content_replay, + _resolve_responses_force_stateless_replay, + _resolve_responses_tool_choice_compat, + _resolve_system_prompt_as_user, + _resolve_thinking_compat_flags, +) +from .pool import _parse_model_pool + +logger = logging.getLogger(__name__) + + +def _parse_agent_model_config(data: dict[str, Any]) -> AgentModelConfig: + queue_interval_seconds = _normalize_queue_interval( + _coerce_float( + _get_value( + data, + ("models", "agent", "queue_interval_seconds"), + "AGENT_MODEL_QUEUE_INTERVAL", + ), + 1.0, + ) + ) + thinking_include_budget, thinking_tool_call_compat = _resolve_thinking_compat_flags( + data=data, + model_name="agent", + include_budget_env_key="AGENT_MODEL_THINKING_INCLUDE_BUDGET", + tool_call_compat_env_key="AGENT_MODEL_THINKING_TOOL_CALL_COMPAT", + legacy_env_key="AGENT_MODEL_DEEPSEEK_NEW_COT_SUPPORT", + ) + api_mode = _resolve_api_mode(data, "agent", "AGENT_MODEL_API_MODE") + responses_tool_choice_compat = _resolve_responses_tool_choice_compat( + data, "agent", "AGENT_MODEL_RESPONSES_TOOL_CHOICE_COMPAT" + ) + responses_force_stateless_replay = _resolve_responses_force_stateless_replay( + data, "agent", "AGENT_MODEL_RESPONSES_FORCE_STATELESS_REPLAY" + ) + reasoning_content_replay = _resolve_reasoning_content_replay( + data, "agent", "AGENT_MODEL_REASONING_CONTENT_REPLAY" + ) + system_prompt_as_user = _resolve_system_prompt_as_user( + data, "agent", "AGENT_MODEL_SYSTEM_PROMPT_AS_USER" + ) + prompt_cache_enabled = _coerce_bool( + _get_value( + data, + ("models", "agent", "prompt_cache_enabled"), + "AGENT_MODEL_PROMPT_CACHE_ENABLED", + ), + True, + ) + reasoning_enabled = _coerce_bool( + _get_value( + data, + ("models", "agent", "reasoning_enabled"), + "AGENT_MODEL_REASONING_ENABLED", + ), + False, + ) + reasoning_effort = _resolve_reasoning_effort( + _get_value( + data, + ("models", "agent", "reasoning_effort"), + "AGENT_MODEL_REASONING_EFFORT", + ), + "medium", + ) + stream_enabled = _coerce_bool( + _get_value( + data, + ("models", "agent", "stream_enabled"), + "AGENT_MODEL_STREAM_ENABLED", + ), + False, + ) + context_window_tokens = _resolve_context_window_tokens( + data, "agent", "AGENT_MODEL_CONTEXT_WINDOW_TOKENS" + ) + config = AgentModelConfig( + api_url=_coerce_str( + _get_value(data, ("models", "agent", "api_url"), "AGENT_MODEL_API_URL"), + "", + ), + api_key=_coerce_str( + _get_value(data, ("models", "agent", "api_key"), "AGENT_MODEL_API_KEY"), + "", + ), + model_name=_coerce_str( + _get_value(data, ("models", "agent", "model_name"), "AGENT_MODEL_NAME"), + "", + ), + max_tokens=_coerce_int( + _get_value( + data, ("models", "agent", "max_tokens"), "AGENT_MODEL_MAX_TOKENS" + ), + 4096, + ), + context_window_tokens=context_window_tokens, + queue_interval_seconds=queue_interval_seconds, + api_mode=api_mode, + thinking_enabled=_coerce_bool( + _get_value( + data, + ("models", "agent", "thinking_enabled"), + "AGENT_MODEL_THINKING_ENABLED", + ), + False, + ), + thinking_budget_tokens=_coerce_int( + _get_value( + data, + ("models", "agent", "thinking_budget_tokens"), + "AGENT_MODEL_THINKING_BUDGET_TOKENS", + ), + 0, + ), + thinking_include_budget=thinking_include_budget, + reasoning_effort_style=_resolve_reasoning_effort_style( + _get_value( + data, + ("models", "agent", "reasoning_effort_style"), + "AGENT_MODEL_REASONING_EFFORT_STYLE", + ), + ), + thinking_tool_call_compat=thinking_tool_call_compat, + reasoning_content_replay=reasoning_content_replay, + system_prompt_as_user=system_prompt_as_user, + responses_tool_choice_compat=responses_tool_choice_compat, + responses_force_stateless_replay=responses_force_stateless_replay, + prompt_cache_enabled=prompt_cache_enabled, + reasoning_enabled=reasoning_enabled, + reasoning_effort=reasoning_effort, + stream_enabled=stream_enabled, + request_params=_get_model_request_params(data, "agent"), + ) + config.pool = _parse_model_pool(data, "agent", config) + return config diff --git a/src/Undefined/config/parsers/chat.py b/src/Undefined/config/parsers/chat.py new file mode 100644 index 00000000..6f1028fa --- /dev/null +++ b/src/Undefined/config/parsers/chat.py @@ -0,0 +1,165 @@ +"""Chat model parser.""" + +from __future__ import annotations + +# 模型配置解析:原始 dict → ChatModelConfig 等 dataclass + +import logging +from typing import Any + + +from ..coercers import ( + _coerce_bool, + _coerce_float, + _coerce_int, + _coerce_str, + _get_model_request_params, + _get_value, + _normalize_queue_interval, +) +from ..models import ( + ChatModelConfig, +) +from ..resolvers import ( + _resolve_api_mode, + _resolve_context_window_tokens, + _resolve_reasoning_effort, + _resolve_reasoning_effort_style, + _resolve_reasoning_content_replay, + _resolve_responses_force_stateless_replay, + _resolve_responses_tool_choice_compat, + _resolve_system_prompt_as_user, + _resolve_thinking_compat_flags, +) +from .pool import _parse_model_pool + +logger = logging.getLogger(__name__) + + +# 解析 [models.chat]:主对话模型 API、thinking/reasoning、队列间隔与模型池 +def _parse_chat_model_config(data: dict[str, Any]) -> ChatModelConfig: + # 该模型独立的发车间隔(秒),0=立即发车 + queue_interval_seconds = _normalize_queue_interval( + _coerce_float( + _get_value( + data, + ("models", "chat", "queue_interval_seconds"), + "CHAT_MODEL_QUEUE_INTERVAL", + ), + 1.0, + ) + ) + # DeepSeek/兼容模型的 thinking 预算与 tool_call 互斥开关 + thinking_include_budget, thinking_tool_call_compat = _resolve_thinking_compat_flags( + data=data, + model_name="chat", + include_budget_env_key="CHAT_MODEL_THINKING_INCLUDE_BUDGET", + tool_call_compat_env_key="CHAT_MODEL_THINKING_TOOL_CALL_COMPAT", + legacy_env_key="CHAT_MODEL_DEEPSEEK_NEW_COT_SUPPORT", + ) + # OpenAI 兼容层:chat_completions / responses 及 reasoning 回放策略 + api_mode = _resolve_api_mode(data, "chat", "CHAT_MODEL_API_MODE") + responses_tool_choice_compat = _resolve_responses_tool_choice_compat( + data, "chat", "CHAT_MODEL_RESPONSES_TOOL_CHOICE_COMPAT" + ) + responses_force_stateless_replay = _resolve_responses_force_stateless_replay( + data, "chat", "CHAT_MODEL_RESPONSES_FORCE_STATELESS_REPLAY" + ) + reasoning_content_replay = _resolve_reasoning_content_replay( + data, "chat", "CHAT_MODEL_REASONING_CONTENT_REPLAY" + ) + system_prompt_as_user = _resolve_system_prompt_as_user( + data, "chat", "CHAT_MODEL_SYSTEM_PROMPT_AS_USER" + ) + prompt_cache_enabled = _coerce_bool( + _get_value( + data, + ("models", "chat", "prompt_cache_enabled"), + "CHAT_MODEL_PROMPT_CACHE_ENABLED", + ), + True, + ) + reasoning_enabled = _coerce_bool( + _get_value( + data, + ("models", "chat", "reasoning_enabled"), + "CHAT_MODEL_REASONING_ENABLED", + ), + False, + ) + reasoning_effort = _resolve_reasoning_effort( + _get_value( + data, + ("models", "chat", "reasoning_effort"), + "CHAT_MODEL_REASONING_EFFORT", + ), + "medium", + ) + stream_enabled = _coerce_bool( + _get_value( + data, + ("models", "chat", "stream_enabled"), + "CHAT_MODEL_STREAM_ENABLED", + ), + False, + ) + context_window_tokens = _resolve_context_window_tokens( + data, "chat", "CHAT_MODEL_CONTEXT_WINDOW_TOKENS" + ) + config = ChatModelConfig( + api_url=_coerce_str( + _get_value(data, ("models", "chat", "api_url"), "CHAT_MODEL_API_URL"), + "", + ), + api_key=_coerce_str( + _get_value(data, ("models", "chat", "api_key"), "CHAT_MODEL_API_KEY"), + "", + ), + model_name=_coerce_str( + _get_value(data, ("models", "chat", "model_name"), "CHAT_MODEL_NAME"), + "", + ), + context_window_tokens=context_window_tokens, + max_tokens=_coerce_int( + _get_value(data, ("models", "chat", "max_tokens"), "CHAT_MODEL_MAX_TOKENS"), + 8192, + ), + queue_interval_seconds=queue_interval_seconds, + api_mode=api_mode, + thinking_enabled=_coerce_bool( + _get_value( + data, + ("models", "chat", "thinking_enabled"), + "CHAT_MODEL_THINKING_ENABLED", + ), + False, + ), + thinking_budget_tokens=_coerce_int( + _get_value( + data, + ("models", "chat", "thinking_budget_tokens"), + "CHAT_MODEL_THINKING_BUDGET_TOKENS", + ), + 20000, + ), + thinking_include_budget=thinking_include_budget, + reasoning_effort_style=_resolve_reasoning_effort_style( + _get_value( + data, + ("models", "chat", "reasoning_effort_style"), + "CHAT_MODEL_REASONING_EFFORT_STYLE", + ), + ), + thinking_tool_call_compat=thinking_tool_call_compat, + reasoning_content_replay=reasoning_content_replay, + system_prompt_as_user=system_prompt_as_user, + responses_tool_choice_compat=responses_tool_choice_compat, + responses_force_stateless_replay=responses_force_stateless_replay, + prompt_cache_enabled=prompt_cache_enabled, + reasoning_enabled=reasoning_enabled, + reasoning_effort=reasoning_effort, + stream_enabled=stream_enabled, + request_params=_get_model_request_params(data, "chat"), + ) + config.pool = _parse_model_pool(data, "chat", config) + return config diff --git a/src/Undefined/config/parsers/embedding.py b/src/Undefined/config/parsers/embedding.py new file mode 100644 index 00000000..8c5f3aa7 --- /dev/null +++ b/src/Undefined/config/parsers/embedding.py @@ -0,0 +1,106 @@ +"""Embedding model parser.""" + +from __future__ import annotations + +# 模型配置解析:原始 dict → ChatModelConfig 等 dataclass + +import logging +from typing import Any + + +from ..coercers import ( + _coerce_float, + _coerce_int, + _coerce_str, + _get_model_request_params, + _get_value, + _normalize_queue_interval, +) +from ..models import ( + EmbeddingModelConfig, + RerankModelConfig, +) +from ..resolvers import ( + _resolve_context_window_tokens, +) + +logger = logging.getLogger(__name__) + + +def _parse_embedding_model_config(data: dict[str, Any]) -> EmbeddingModelConfig: + return EmbeddingModelConfig( + api_url=_coerce_str( + _get_value( + data, ("models", "embedding", "api_url"), "EMBEDDING_MODEL_API_URL" + ), + "", + ), + api_key=_coerce_str( + _get_value( + data, ("models", "embedding", "api_key"), "EMBEDDING_MODEL_API_KEY" + ), + "", + ), + model_name=_coerce_str( + _get_value( + data, ("models", "embedding", "model_name"), "EMBEDDING_MODEL_NAME" + ), + "", + ), + queue_interval_seconds=_normalize_queue_interval( + _coerce_float( + _get_value( + data, ("models", "embedding", "queue_interval_seconds"), None + ), + 0.0, + ), + 0.0, + ), + dimensions=_coerce_int( + _get_value(data, ("models", "embedding", "dimensions"), None), 0 + ) + or None, + query_instruction=_coerce_str( + _get_value(data, ("models", "embedding", "query_instruction"), None), "" + ), + context_window_tokens=_resolve_context_window_tokens( + data, "embedding", "EMBEDDING_MODEL_CONTEXT_WINDOW_TOKENS" + ), + document_instruction=_coerce_str( + _get_value(data, ("models", "embedding", "document_instruction"), None), + "", + ), + request_params=_get_model_request_params(data, "embedding"), + ) + + +def _parse_rerank_model_config(data: dict[str, Any]) -> RerankModelConfig: + queue_interval_seconds = _normalize_queue_interval( + _coerce_float( + _get_value(data, ("models", "rerank", "queue_interval_seconds"), None), + 0.0, + ), + 0.0, + ) + return RerankModelConfig( + api_url=_coerce_str( + _get_value(data, ("models", "rerank", "api_url"), "RERANK_MODEL_API_URL"), + "", + ), + api_key=_coerce_str( + _get_value(data, ("models", "rerank", "api_key"), "RERANK_MODEL_API_KEY"), + "", + ), + model_name=_coerce_str( + _get_value(data, ("models", "rerank", "model_name"), "RERANK_MODEL_NAME"), + "", + ), + queue_interval_seconds=queue_interval_seconds, + context_window_tokens=_resolve_context_window_tokens( + data, "rerank", "RERANK_MODEL_CONTEXT_WINDOW_TOKENS" + ), + query_instruction=_coerce_str( + _get_value(data, ("models", "rerank", "query_instruction"), None), "" + ), + request_params=_get_model_request_params(data, "rerank"), + ) diff --git a/src/Undefined/config/parsers/grok.py b/src/Undefined/config/parsers/grok.py new file mode 100644 index 00000000..efcbbadb --- /dev/null +++ b/src/Undefined/config/parsers/grok.py @@ -0,0 +1,129 @@ +"""Grok model parser.""" + +from __future__ import annotations + +# 模型配置解析:原始 dict → ChatModelConfig 等 dataclass + +import logging +from typing import Any + + +from ..coercers import ( + _coerce_bool, + _coerce_float, + _coerce_int, + _coerce_str, + _get_model_request_params, + _get_value, + _normalize_queue_interval, +) +from ..models import ( + GrokModelConfig, +) +from ..resolvers import ( + _resolve_context_window_tokens, + _resolve_reasoning_effort, + _resolve_reasoning_effort_style, +) + +logger = logging.getLogger(__name__) + + +def _parse_grok_model_config(data: dict[str, Any]) -> GrokModelConfig: + queue_interval_seconds = _normalize_queue_interval( + _coerce_float( + _get_value( + data, + ("models", "grok", "queue_interval_seconds"), + "GROK_MODEL_QUEUE_INTERVAL", + ), + 1.0, + ) + ) + context_window_tokens = _resolve_context_window_tokens( + data, "grok", "GROK_MODEL_CONTEXT_WINDOW_TOKENS" + ) + return GrokModelConfig( + api_url=_coerce_str( + _get_value(data, ("models", "grok", "api_url"), "GROK_MODEL_API_URL"), + "", + ), + api_key=_coerce_str( + _get_value(data, ("models", "grok", "api_key"), "GROK_MODEL_API_KEY"), + "", + ), + model_name=_coerce_str( + _get_value(data, ("models", "grok", "model_name"), "GROK_MODEL_NAME"), + "", + ), + max_tokens=_coerce_int( + _get_value(data, ("models", "grok", "max_tokens"), "GROK_MODEL_MAX_TOKENS"), + 8192, + ), + context_window_tokens=context_window_tokens, + queue_interval_seconds=queue_interval_seconds, + thinking_enabled=_coerce_bool( + _get_value( + data, + ("models", "grok", "thinking_enabled"), + "GROK_MODEL_THINKING_ENABLED", + ), + False, + ), + thinking_budget_tokens=_coerce_int( + _get_value( + data, + ("models", "grok", "thinking_budget_tokens"), + "GROK_MODEL_THINKING_BUDGET_TOKENS", + ), + 20000, + ), + thinking_include_budget=_coerce_bool( + _get_value( + data, + ("models", "grok", "thinking_include_budget"), + "GROK_MODEL_THINKING_INCLUDE_BUDGET", + ), + True, + ), + reasoning_effort_style=_resolve_reasoning_effort_style( + _get_value( + data, + ("models", "grok", "reasoning_effort_style"), + "GROK_MODEL_REASONING_EFFORT_STYLE", + ), + ), + prompt_cache_enabled=_coerce_bool( + _get_value( + data, + ("models", "grok", "prompt_cache_enabled"), + "GROK_MODEL_PROMPT_CACHE_ENABLED", + ), + True, + ), + reasoning_enabled=_coerce_bool( + _get_value( + data, + ("models", "grok", "reasoning_enabled"), + "GROK_MODEL_REASONING_ENABLED", + ), + False, + ), + reasoning_effort=_resolve_reasoning_effort( + _get_value( + data, + ("models", "grok", "reasoning_effort"), + "GROK_MODEL_REASONING_EFFORT", + ), + "medium", + ), + stream_enabled=_coerce_bool( + _get_value( + data, + ("models", "grok", "stream_enabled"), + "GROK_MODEL_STREAM_ENABLED", + ), + False, + ), + request_params=_get_model_request_params(data, "grok"), + ) diff --git a/src/Undefined/config/parsers/helpers.py b/src/Undefined/config/parsers/helpers.py new file mode 100644 index 00000000..19c82989 --- /dev/null +++ b/src/Undefined/config/parsers/helpers.py @@ -0,0 +1,122 @@ +"""Admin merge, validation, and debug helpers for model parsers.""" + +from __future__ import annotations + +# 模型配置解析:原始 dict → ChatModelConfig 等 dataclass + +import logging + +from ..admin import load_local_admins +from ..models import ( + AgentModelConfig, + ChatModelConfig, + EmbeddingModelConfig, + GrokModelConfig, + SecurityModelConfig, + VisionModelConfig, +) + +# 合并多来源配置 +logger = logging.getLogger(__name__) + + +# 合并多来源配置 +def _merge_admins(superadmin_qq: int, admin_qqs: list[int]) -> tuple[int, list[int]]: + # admins.json 与 config.toml 的 admin_qq 合并,去重后超管必在列表中 + local_admins = load_local_admins() + all_admins = list(set(admin_qqs + local_admins)) + if superadmin_qq and superadmin_qq not in all_admins: + all_admins.append(superadmin_qq) + # 校验必填字段 + return superadmin_qq, all_admins + + +# 校验必填字段 +def _verify_required_fields( + bot_qq: int, + superadmin_qq: int, + onebot_ws_url: str, + chat_model: ChatModelConfig, + vision_model: VisionModelConfig, + agent_model: AgentModelConfig, + knowledge_enabled: bool, + embedding_model: EmbeddingModelConfig, +) -> None: + missing: list[str] = [] + if bot_qq <= 0: + missing.append("core.bot_qq") + if superadmin_qq <= 0: + missing.append("core.superadmin_qq") + if not onebot_ws_url: + missing.append("onebot.ws_url") + if not chat_model.api_url: + missing.append("models.chat.api_url") + if not chat_model.api_key: + missing.append("models.chat.api_key") + if not chat_model.model_name: + missing.append("models.chat.model_name") + if not vision_model.api_url: + missing.append("models.vision.api_url") + if not vision_model.api_key: + missing.append("models.vision.api_key") + if not vision_model.model_name: + missing.append("models.vision.model_name") + if not agent_model.api_url: + missing.append("models.agent.api_url") + if not agent_model.api_key: + missing.append("models.agent.api_key") + if not agent_model.model_name: + missing.append("models.agent.model_name") + if knowledge_enabled: + if not embedding_model.api_url: + missing.append("models.embedding.api_url") + if not embedding_model.model_name: + missing.append("models.embedding.model_name") + if missing: + # 输出调试/诊断日志 + raise ValueError(f"缺少必需配置: {', '.join(missing)}") + + +# 输出调试/诊断日志 +def _log_debug_info( + chat_model: ChatModelConfig, + vision_model: VisionModelConfig, + security_model: SecurityModelConfig, + naga_model: SecurityModelConfig, + agent_model: AgentModelConfig, + summary_model: AgentModelConfig, + grok_model: GrokModelConfig, +) -> None: + configs: list[ + tuple[ + str, + ChatModelConfig + | VisionModelConfig + | SecurityModelConfig + | AgentModelConfig + | GrokModelConfig, + ] + ] = [ + ("chat", chat_model), + ("vision", vision_model), + ("security", security_model), + ("naga", naga_model), + ("agent", agent_model), + ("summary", summary_model), + ("grok", grok_model), + ] + for name, cfg in configs: + logger.debug( + "[配置] %s_model=%s api_url=%s api_key_set=%s api_mode=%s thinking=%s reasoning=%s/%s cot_compat=%s responses_tool_choice_compat=%s responses_force_stateless_replay=%s", + name, + cfg.model_name, + cfg.api_url, + bool(cfg.api_key), + getattr(cfg, "api_mode", "chat_completions"), + cfg.thinking_enabled, + getattr(cfg, "reasoning_enabled", False), + getattr(cfg, "reasoning_effort", "medium"), + getattr(cfg, "thinking_tool_call_compat", False), + getattr(cfg, "responses_tool_choice_compat", False), + getattr(cfg, "responses_force_stateless_replay", False), + ) diff --git a/src/Undefined/config/parsers/historian.py b/src/Undefined/config/parsers/historian.py new file mode 100644 index 00000000..74e06bf4 --- /dev/null +++ b/src/Undefined/config/parsers/historian.py @@ -0,0 +1,144 @@ +"""Historian model parser.""" + +from __future__ import annotations + +# 模型配置解析:原始 dict → ChatModelConfig 等 dataclass + +import logging +from typing import Any + +from Undefined.utils.request_params import merge_request_params + +from ..coercers import ( + _coerce_bool, + _coerce_float, + _coerce_int, + _coerce_str, + _get_value, + _normalize_queue_interval, +) +from ..models import ( + AgentModelConfig, +) +from ..resolvers import ( + _resolve_api_mode, + _resolve_reasoning_effort, + _resolve_reasoning_effort_style, + _resolve_responses_force_stateless_replay, + _resolve_responses_tool_choice_compat, + _resolve_thinking_compat_flags, +) + +logger = logging.getLogger(__name__) + + +def _parse_historian_model_config( + data: dict[str, Any], fallback: AgentModelConfig +) -> AgentModelConfig: + h = data.get("models", {}).get("historian", {}) + if not isinstance(h, dict) or not h: + return fallback + queue_interval_seconds = _coerce_float( + h.get("queue_interval_seconds"), fallback.queue_interval_seconds + ) + queue_interval_seconds = _normalize_queue_interval( + queue_interval_seconds, fallback.queue_interval_seconds + ) + thinking_include_budget, thinking_tool_call_compat = _resolve_thinking_compat_flags( + data={"models": {"historian": h}}, + model_name="historian", + include_budget_env_key="HISTORIAN_MODEL_THINKING_INCLUDE_BUDGET", + tool_call_compat_env_key="HISTORIAN_MODEL_THINKING_TOOL_CALL_COMPAT", + legacy_env_key="HISTORIAN_MODEL_DEEPSEEK_NEW_COT_SUPPORT", + ) + api_mode = _resolve_api_mode( + {"models": {"historian": h}}, + "historian", + "HISTORIAN_MODEL_API_MODE", + fallback.api_mode, + ) + responses_tool_choice_compat = _resolve_responses_tool_choice_compat( + {"models": {"historian": h}}, + "historian", + "HISTORIAN_MODEL_RESPONSES_TOOL_CHOICE_COMPAT", + fallback.responses_tool_choice_compat, + ) + responses_force_stateless_replay = _resolve_responses_force_stateless_replay( + {"models": {"historian": h}}, + "historian", + "HISTORIAN_MODEL_RESPONSES_FORCE_STATELESS_REPLAY", + fallback.responses_force_stateless_replay, + ) + prompt_cache_enabled = _coerce_bool( + _get_value( + {"models": {"historian": h}}, + ("models", "historian", "prompt_cache_enabled"), + "HISTORIAN_MODEL_PROMPT_CACHE_ENABLED", + ), + fallback.prompt_cache_enabled, + ) + context_window_tokens = _coerce_int( + h.get("context_window_tokens"), fallback.context_window_tokens + ) + return AgentModelConfig( + api_url=_coerce_str(h.get("api_url"), fallback.api_url), + api_key=_coerce_str(h.get("api_key"), fallback.api_key), + model_name=_coerce_str(h.get("model_name"), fallback.model_name), + max_tokens=_coerce_int(h.get("max_tokens"), fallback.max_tokens), + context_window_tokens=context_window_tokens, + queue_interval_seconds=queue_interval_seconds, + api_mode=api_mode, + thinking_enabled=_coerce_bool( + h.get("thinking_enabled"), fallback.thinking_enabled + ), + thinking_budget_tokens=_coerce_int( + h.get("thinking_budget_tokens"), fallback.thinking_budget_tokens + ), + thinking_include_budget=thinking_include_budget, + reasoning_effort_style=_resolve_reasoning_effort_style( + _get_value( + {"models": {"historian": h}}, + ("models", "historian", "reasoning_effort_style"), + "HISTORIAN_MODEL_REASONING_EFFORT_STYLE", + ), + fallback.reasoning_effort_style, + ), + thinking_tool_call_compat=thinking_tool_call_compat, + reasoning_content_replay=_coerce_bool( + h.get("reasoning_content_replay"), fallback.reasoning_content_replay + ), + system_prompt_as_user=_coerce_bool( + h.get("system_prompt_as_user"), fallback.system_prompt_as_user + ), + responses_tool_choice_compat=responses_tool_choice_compat, + responses_force_stateless_replay=responses_force_stateless_replay, + prompt_cache_enabled=prompt_cache_enabled, + reasoning_enabled=_coerce_bool( + _get_value( + {"models": {"historian": h}}, + ("models", "historian", "reasoning_enabled"), + "HISTORIAN_MODEL_REASONING_ENABLED", + ), + fallback.reasoning_enabled, + ), + reasoning_effort=_resolve_reasoning_effort( + _get_value( + {"models": {"historian": h}}, + ("models", "historian", "reasoning_effort"), + "HISTORIAN_MODEL_REASONING_EFFORT", + ), + fallback.reasoning_effort, + ), + stream_enabled=_coerce_bool( + _get_value( + {"models": {"historian": h}}, + ("models", "historian", "stream_enabled"), + "HISTORIAN_MODEL_STREAM_ENABLED", + ), + fallback.stream_enabled, + ), + request_params=merge_request_params( + fallback.request_params, + h.get("request_params"), + ), + ) diff --git a/src/Undefined/config/parsers/image.py b/src/Undefined/config/parsers/image.py new file mode 100644 index 00000000..cccf8460 --- /dev/null +++ b/src/Undefined/config/parsers/image.py @@ -0,0 +1,120 @@ +"""Image model parser.""" + +from __future__ import annotations + +# 模型配置解析:原始 dict → ChatModelConfig 等 dataclass + +import logging +from typing import Any + + +from ..coercers import ( + _coerce_float, + _coerce_int, + _coerce_str, + _get_model_request_params, + _get_value, +) +from ..models import ( + ImageGenConfig, + ImageGenModelConfig, +) + +logger = logging.getLogger(__name__) + + +def _parse_image_gen_model_config(data: dict[str, Any]) -> ImageGenModelConfig: + """解析 [models.image_gen] 生图模型配置""" + return ImageGenModelConfig( + api_url=_coerce_str( + _get_value( + data, ("models", "image_gen", "api_url"), "IMAGE_GEN_MODEL_API_URL" + ), + "", + ), + api_key=_coerce_str( + _get_value( + data, ("models", "image_gen", "api_key"), "IMAGE_GEN_MODEL_API_KEY" + ), + "", + ), + model_name=_coerce_str( + _get_value( + data, ("models", "image_gen", "model_name"), "IMAGE_GEN_MODEL_NAME" + ), + "", + ), + context_window_tokens=_coerce_int( + _get_value( + data, + ("models", "image_gen", "context_window_tokens"), + None, + ), + 0, + ), + request_params=_get_model_request_params(data, "image_gen"), + ) + + +def _parse_image_edit_model_config(data: dict[str, Any]) -> ImageGenModelConfig: + """解析 [models.image_edit] 参考图生图模型配置""" + return ImageGenModelConfig( + api_url=_coerce_str( + _get_value( + data, + ("models", "image_edit", "api_url"), + "IMAGE_EDIT_MODEL_API_URL", + ), + "", + ), + api_key=_coerce_str( + _get_value( + data, + ("models", "image_edit", "api_key"), + "IMAGE_EDIT_MODEL_API_KEY", + ), + "", + ), + model_name=_coerce_str( + _get_value( + data, + ("models", "image_edit", "model_name"), + "IMAGE_EDIT_MODEL_NAME", + ), + "", + ), + context_window_tokens=_coerce_int( + _get_value( + data, + ("models", "image_edit", "context_window_tokens"), + None, + ), + 0, + ), + request_params=_get_model_request_params(data, "image_edit"), + ) + + +def _parse_image_gen_config(data: dict[str, Any]) -> ImageGenConfig: + """解析 [image_gen] 生图工具配置""" + return ImageGenConfig( + provider=_coerce_str( + _get_value(data, ("image_gen", "provider"), "IMAGE_GEN_PROVIDER"), + "xingzhige", + ), + xingzhige_size=_coerce_str( + _get_value(data, ("image_gen", "xingzhige_size"), None), "1:1" + ), + openai_size=_coerce_str( + _get_value(data, ("image_gen", "openai_size"), None), "" + ), + openai_quality=_coerce_str( + _get_value(data, ("image_gen", "openai_quality"), None), "" + ), + openai_style=_coerce_str( + _get_value(data, ("image_gen", "openai_style"), None), "" + ), + openai_timeout=_coerce_float( + _get_value(data, ("image_gen", "openai_timeout"), None), 120.0 + ), + ) diff --git a/src/Undefined/config/parsers/naga.py b/src/Undefined/config/parsers/naga.py new file mode 100644 index 00000000..920862e0 --- /dev/null +++ b/src/Undefined/config/parsers/naga.py @@ -0,0 +1,206 @@ +"""Naga model parser.""" + +from __future__ import annotations + +# 模型配置解析:原始 dict → ChatModelConfig 等 dataclass + +import logging +from typing import Any + +from Undefined.utils.request_params import merge_request_params + +from ..coercers import ( + _coerce_bool, + _coerce_float, + _coerce_int, + _coerce_str, + _get_model_request_params, + _get_value, + _normalize_queue_interval, +) +from ..models import ( + SecurityModelConfig, +) +from ..resolvers import ( + _resolve_api_mode, + _resolve_context_window_tokens, + _resolve_reasoning_effort, + _resolve_reasoning_effort_style, + _resolve_reasoning_content_replay, + _resolve_responses_force_stateless_replay, + _resolve_responses_tool_choice_compat, + _resolve_system_prompt_as_user, + _resolve_thinking_compat_flags, +) + +logger = logging.getLogger(__name__) + + +def _parse_naga_model_config( + data: dict[str, Any], security_model: SecurityModelConfig +) -> SecurityModelConfig: + api_url = _coerce_str( + _get_value(data, ("models", "naga", "api_url"), "NAGA_MODEL_API_URL"), + "", + ) + api_key = _coerce_str( + _get_value(data, ("models", "naga", "api_key"), "NAGA_MODEL_API_KEY"), + "", + ) + model_name = _coerce_str( + _get_value(data, ("models", "naga", "model_name"), "NAGA_MODEL_NAME"), + "", + ) + queue_interval_seconds = _coerce_float( + _get_value( + data, + ("models", "naga", "queue_interval_seconds"), + "NAGA_MODEL_QUEUE_INTERVAL", + ), + security_model.queue_interval_seconds, + ) + queue_interval_seconds = _normalize_queue_interval(queue_interval_seconds) + + thinking_include_budget, thinking_tool_call_compat = _resolve_thinking_compat_flags( + data=data, + model_name="naga", + include_budget_env_key="NAGA_MODEL_THINKING_INCLUDE_BUDGET", + tool_call_compat_env_key="NAGA_MODEL_THINKING_TOOL_CALL_COMPAT", + legacy_env_key="NAGA_MODEL_DEEPSEEK_NEW_COT_SUPPORT", + ) + api_mode = _resolve_api_mode(data, "naga", "NAGA_MODEL_API_MODE") + responses_tool_choice_compat = _resolve_responses_tool_choice_compat( + data, "naga", "NAGA_MODEL_RESPONSES_TOOL_CHOICE_COMPAT" + ) + responses_force_stateless_replay = _resolve_responses_force_stateless_replay( + data, "naga", "NAGA_MODEL_RESPONSES_FORCE_STATELESS_REPLAY" + ) + reasoning_content_replay = _resolve_reasoning_content_replay( + data, + "naga", + "NAGA_MODEL_REASONING_CONTENT_REPLAY", + default=security_model.reasoning_content_replay, + ) + system_prompt_as_user = _resolve_system_prompt_as_user( + data, + "naga", + "NAGA_MODEL_SYSTEM_PROMPT_AS_USER", + default=security_model.system_prompt_as_user, + ) + prompt_cache_enabled = _coerce_bool( + _get_value( + data, + ("models", "naga", "prompt_cache_enabled"), + "NAGA_MODEL_PROMPT_CACHE_ENABLED", + ), + getattr(security_model, "prompt_cache_enabled", True), + ) + reasoning_enabled = _coerce_bool( + _get_value( + data, + ("models", "naga", "reasoning_enabled"), + "NAGA_MODEL_REASONING_ENABLED", + ), + getattr(security_model, "reasoning_enabled", False), + ) + reasoning_effort = _resolve_reasoning_effort( + _get_value( + data, + ("models", "naga", "reasoning_effort"), + "NAGA_MODEL_REASONING_EFFORT", + ), + getattr(security_model, "reasoning_effort", "medium"), + ) + stream_enabled = _coerce_bool( + _get_value( + data, + ("models", "naga", "stream_enabled"), + "NAGA_MODEL_STREAM_ENABLED", + ), + getattr(security_model, "stream_enabled", False), + ) + + if api_url and api_key and model_name: + context_window_tokens = _resolve_context_window_tokens( + data, + "naga", + "NAGA_MODEL_CONTEXT_WINDOW_TOKENS", + default=security_model.context_window_tokens, + ) + return SecurityModelConfig( + api_url=api_url, + api_key=api_key, + model_name=model_name, + max_tokens=_coerce_int( + _get_value( + data, + ("models", "naga", "max_tokens"), + "NAGA_MODEL_MAX_TOKENS", + ), + 160, + ), + context_window_tokens=context_window_tokens, + queue_interval_seconds=queue_interval_seconds, + api_mode=api_mode, + thinking_enabled=_coerce_bool( + _get_value( + data, + ("models", "naga", "thinking_enabled"), + "NAGA_MODEL_THINKING_ENABLED", + ), + False, + ), + thinking_budget_tokens=_coerce_int( + _get_value( + data, + ("models", "naga", "thinking_budget_tokens"), + "NAGA_MODEL_THINKING_BUDGET_TOKENS", + ), + 0, + ), + thinking_include_budget=thinking_include_budget, + reasoning_effort_style=_resolve_reasoning_effort_style( + _get_value( + data, + ("models", "naga", "reasoning_effort_style"), + "NAGA_MODEL_REASONING_EFFORT_STYLE", + ), + ), + thinking_tool_call_compat=thinking_tool_call_compat, + reasoning_content_replay=reasoning_content_replay, + system_prompt_as_user=system_prompt_as_user, + responses_tool_choice_compat=responses_tool_choice_compat, + responses_force_stateless_replay=responses_force_stateless_replay, + prompt_cache_enabled=prompt_cache_enabled, + reasoning_enabled=reasoning_enabled, + reasoning_effort=reasoning_effort, + stream_enabled=stream_enabled, + request_params=_get_model_request_params(data, "naga"), + ) + + logger.info( + "未配置 Naga 审核模型,将使用已解析的安全模型配置作为后备(安全模型本身可能已回退)" + ) + return SecurityModelConfig( + api_url=security_model.api_url, + api_key=security_model.api_key, + model_name=security_model.model_name, + max_tokens=security_model.max_tokens, + context_window_tokens=security_model.context_window_tokens, + queue_interval_seconds=security_model.queue_interval_seconds, + api_mode=security_model.api_mode, + thinking_enabled=security_model.thinking_enabled, + thinking_budget_tokens=security_model.thinking_budget_tokens, + thinking_include_budget=security_model.thinking_include_budget, + reasoning_effort_style=security_model.reasoning_effort_style, + thinking_tool_call_compat=security_model.thinking_tool_call_compat, + reasoning_content_replay=security_model.reasoning_content_replay, + system_prompt_as_user=security_model.system_prompt_as_user, + responses_tool_choice_compat=security_model.responses_tool_choice_compat, + responses_force_stateless_replay=security_model.responses_force_stateless_replay, + prompt_cache_enabled=security_model.prompt_cache_enabled, + reasoning_enabled=security_model.reasoning_enabled, + reasoning_effort=security_model.reasoning_effort, + stream_enabled=security_model.stream_enabled, + request_params=merge_request_params(security_model.request_params), + ) diff --git a/src/Undefined/config/parsers/pool.py b/src/Undefined/config/parsers/pool.py new file mode 100644 index 00000000..72abcce4 --- /dev/null +++ b/src/Undefined/config/parsers/pool.py @@ -0,0 +1,142 @@ +"""Model pool parser.""" + +from __future__ import annotations + +# 模型配置解析:原始 dict → ChatModelConfig 等 dataclass + +import logging +from typing import Any + +from Undefined.utils.request_params import merge_request_params + +from ..coercers import ( + _coerce_bool, + _coerce_float, + _coerce_int, + _coerce_str, + _normalize_queue_interval, + _VALID_API_MODES, +) +from ..models import AgentModelConfig, ChatModelConfig, ModelPool, ModelPoolEntry +from ..resolvers import ( + _resolve_reasoning_effort, + _resolve_reasoning_effort_style, +) + +logger = logging.getLogger(__name__) + + +def _parse_model_pool( + data: dict[str, Any], + model_section: str, + primary_config: ChatModelConfig | AgentModelConfig, +) -> ModelPool | None: + """解析模型池配置,缺省字段继承 primary_config""" + pool_data = data.get("models", {}).get(model_section, {}).get("pool") + if not isinstance(pool_data, dict): + return None + + enabled = _coerce_bool(pool_data.get("enabled"), False) + strategy = _coerce_str(pool_data.get("strategy"), "default").strip().lower() + if strategy not in ("default", "round_robin", "random"): + strategy = "default" + + raw_models = pool_data.get("models") + if not isinstance(raw_models, list): + return ModelPool(enabled=enabled, strategy=strategy) + + entries: list[ModelPoolEntry] = [] + for item in raw_models: + if not isinstance(item, dict): + continue + name = _coerce_str(item.get("model_name"), "").strip() + if not name: + continue + entries.append( + ModelPoolEntry( + api_url=_coerce_str(item.get("api_url"), primary_config.api_url), + api_key=_coerce_str(item.get("api_key"), primary_config.api_key), + model_name=name, + context_window_tokens=_coerce_int( + item.get("context_window_tokens"), + primary_config.context_window_tokens, + ), + max_tokens=_coerce_int( + item.get("max_tokens"), primary_config.max_tokens + ), + queue_interval_seconds=_normalize_queue_interval( + _coerce_float( + item.get("queue_interval_seconds"), + primary_config.queue_interval_seconds, + ), + primary_config.queue_interval_seconds, + ), + api_mode=( + _coerce_str(item.get("api_mode"), primary_config.api_mode) + .strip() + .lower() + ) + if _coerce_str(item.get("api_mode"), primary_config.api_mode) + .strip() + .lower() + in _VALID_API_MODES + else primary_config.api_mode, + thinking_enabled=_coerce_bool( + item.get("thinking_enabled"), primary_config.thinking_enabled + ), + thinking_budget_tokens=_coerce_int( + item.get("thinking_budget_tokens"), + primary_config.thinking_budget_tokens, + ), + thinking_include_budget=_coerce_bool( + item.get("thinking_include_budget"), + primary_config.thinking_include_budget, + ), + reasoning_effort_style=_resolve_reasoning_effort_style( + item.get("reasoning_effort_style"), + primary_config.reasoning_effort_style, + ), + thinking_tool_call_compat=_coerce_bool( + item.get("thinking_tool_call_compat"), + primary_config.thinking_tool_call_compat, + ), + reasoning_content_replay=_coerce_bool( + item.get("reasoning_content_replay"), + primary_config.reasoning_content_replay, + ), + system_prompt_as_user=_coerce_bool( + item.get("system_prompt_as_user"), + primary_config.system_prompt_as_user, + ), + responses_tool_choice_compat=_coerce_bool( + item.get("responses_tool_choice_compat"), + primary_config.responses_tool_choice_compat, + ), + responses_force_stateless_replay=_coerce_bool( + item.get("responses_force_stateless_replay"), + primary_config.responses_force_stateless_replay, + ), + prompt_cache_enabled=_coerce_bool( + item.get("prompt_cache_enabled"), + primary_config.prompt_cache_enabled, + ), + reasoning_enabled=_coerce_bool( + item.get("reasoning_enabled"), + primary_config.reasoning_enabled, + ), + reasoning_effort=_resolve_reasoning_effort( + item.get("reasoning_effort"), + primary_config.reasoning_effort, + ), + stream_enabled=_coerce_bool( + item.get("stream_enabled"), + getattr(primary_config, "stream_enabled", False), + ), + request_params=merge_request_params( + primary_config.request_params, + item.get("request_params"), + ), + ) + ) + + return ModelPool(enabled=enabled, strategy=strategy, models=entries) diff --git a/src/Undefined/config/parsers/security.py b/src/Undefined/config/parsers/security.py new file mode 100644 index 00000000..f5044a91 --- /dev/null +++ b/src/Undefined/config/parsers/security.py @@ -0,0 +1,196 @@ +"""Security model parser.""" + +from __future__ import annotations + +# 模型配置解析:原始 dict → ChatModelConfig 等 dataclass + +import logging +from typing import Any + +from Undefined.utils.request_params import merge_request_params + +from ..coercers import ( + _coerce_bool, + _coerce_float, + _coerce_int, + _coerce_str, + _get_model_request_params, + _get_value, + _normalize_queue_interval, +) +from ..models import ( + ChatModelConfig, + SecurityModelConfig, +) +from ..resolvers import ( + _resolve_api_mode, + _resolve_context_window_tokens, + _resolve_reasoning_effort, + _resolve_reasoning_effort_style, + _resolve_reasoning_content_replay, + _resolve_responses_force_stateless_replay, + _resolve_responses_tool_choice_compat, + _resolve_system_prompt_as_user, + _resolve_thinking_compat_flags, +) + +logger = logging.getLogger(__name__) + + +def _parse_security_model_config( + data: dict[str, Any], chat_model: ChatModelConfig +) -> SecurityModelConfig: + api_url = _coerce_str( + _get_value(data, ("models", "security", "api_url"), "SECURITY_MODEL_API_URL"), + "", + ) + api_key = _coerce_str( + _get_value(data, ("models", "security", "api_key"), "SECURITY_MODEL_API_KEY"), + "", + ) + model_name = _coerce_str( + _get_value(data, ("models", "security", "model_name"), "SECURITY_MODEL_NAME"), + "", + ) + queue_interval_seconds = _coerce_float( + _get_value( + data, + ("models", "security", "queue_interval_seconds"), + "SECURITY_MODEL_QUEUE_INTERVAL", + ), + 1.0, + ) + queue_interval_seconds = _normalize_queue_interval(queue_interval_seconds) + + thinking_include_budget, thinking_tool_call_compat = _resolve_thinking_compat_flags( + data=data, + model_name="security", + include_budget_env_key="SECURITY_MODEL_THINKING_INCLUDE_BUDGET", + tool_call_compat_env_key="SECURITY_MODEL_THINKING_TOOL_CALL_COMPAT", + legacy_env_key="SECURITY_MODEL_DEEPSEEK_NEW_COT_SUPPORT", + ) + api_mode = _resolve_api_mode(data, "security", "SECURITY_MODEL_API_MODE") + responses_tool_choice_compat = _resolve_responses_tool_choice_compat( + data, "security", "SECURITY_MODEL_RESPONSES_TOOL_CHOICE_COMPAT" + ) + responses_force_stateless_replay = _resolve_responses_force_stateless_replay( + data, "security", "SECURITY_MODEL_RESPONSES_FORCE_STATELESS_REPLAY" + ) + reasoning_content_replay = _resolve_reasoning_content_replay( + data, "security", "SECURITY_MODEL_REASONING_CONTENT_REPLAY" + ) + system_prompt_as_user = _resolve_system_prompt_as_user( + data, "security", "SECURITY_MODEL_SYSTEM_PROMPT_AS_USER" + ) + prompt_cache_enabled = _coerce_bool( + _get_value( + data, + ("models", "security", "prompt_cache_enabled"), + "SECURITY_MODEL_PROMPT_CACHE_ENABLED", + ), + True, + ) + reasoning_enabled = _coerce_bool( + _get_value( + data, + ("models", "security", "reasoning_enabled"), + "SECURITY_MODEL_REASONING_ENABLED", + ), + False, + ) + reasoning_effort = _resolve_reasoning_effort( + _get_value( + data, + ("models", "security", "reasoning_effort"), + "SECURITY_MODEL_REASONING_EFFORT", + ), + "medium", + ) + stream_enabled = _coerce_bool( + _get_value( + data, + ("models", "security", "stream_enabled"), + "SECURITY_MODEL_STREAM_ENABLED", + ), + False, + ) + + context_window_tokens = _resolve_context_window_tokens( + data, "security", "SECURITY_MODEL_CONTEXT_WINDOW_TOKENS" + ) + if api_url and api_key and model_name: + return SecurityModelConfig( + api_url=api_url, + api_key=api_key, + model_name=model_name, + max_tokens=_coerce_int( + _get_value( + data, + ("models", "security", "max_tokens"), + "SECURITY_MODEL_MAX_TOKENS", + ), + 100, + ), + context_window_tokens=context_window_tokens, + queue_interval_seconds=queue_interval_seconds, + api_mode=api_mode, + thinking_enabled=_coerce_bool( + _get_value( + data, + ("models", "security", "thinking_enabled"), + "SECURITY_MODEL_THINKING_ENABLED", + ), + False, + ), + thinking_budget_tokens=_coerce_int( + _get_value( + data, + ("models", "security", "thinking_budget_tokens"), + "SECURITY_MODEL_THINKING_BUDGET_TOKENS", + ), + 0, + ), + thinking_include_budget=thinking_include_budget, + reasoning_effort_style=_resolve_reasoning_effort_style( + _get_value( + data, + ("models", "security", "reasoning_effort_style"), + "SECURITY_MODEL_REASONING_EFFORT_STYLE", + ), + ), + thinking_tool_call_compat=thinking_tool_call_compat, + reasoning_content_replay=reasoning_content_replay, + system_prompt_as_user=system_prompt_as_user, + responses_tool_choice_compat=responses_tool_choice_compat, + responses_force_stateless_replay=responses_force_stateless_replay, + prompt_cache_enabled=prompt_cache_enabled, + reasoning_enabled=reasoning_enabled, + reasoning_effort=reasoning_effort, + stream_enabled=stream_enabled, + request_params=_get_model_request_params(data, "security"), + ) + + logger.warning("未配置安全模型,将使用对话模型作为后备") + return SecurityModelConfig( + api_url=chat_model.api_url, + api_key=chat_model.api_key, + model_name=chat_model.model_name, + context_window_tokens=chat_model.context_window_tokens, + max_tokens=chat_model.max_tokens, + queue_interval_seconds=chat_model.queue_interval_seconds, + api_mode=chat_model.api_mode, + thinking_enabled=False, + thinking_budget_tokens=0, + thinking_include_budget=True, + reasoning_effort_style="openai", + thinking_tool_call_compat=chat_model.thinking_tool_call_compat, + reasoning_content_replay=chat_model.reasoning_content_replay, + system_prompt_as_user=chat_model.system_prompt_as_user, + responses_tool_choice_compat=chat_model.responses_tool_choice_compat, + responses_force_stateless_replay=chat_model.responses_force_stateless_replay, + prompt_cache_enabled=chat_model.prompt_cache_enabled, + reasoning_enabled=chat_model.reasoning_enabled, + reasoning_effort=chat_model.reasoning_effort, + stream_enabled=chat_model.stream_enabled, + request_params=merge_request_params(chat_model.request_params), + ) diff --git a/src/Undefined/config/parsers/summary.py b/src/Undefined/config/parsers/summary.py new file mode 100644 index 00000000..dd7757dc --- /dev/null +++ b/src/Undefined/config/parsers/summary.py @@ -0,0 +1,147 @@ +"""Summary model parser.""" + +from __future__ import annotations + +# 模型配置解析:原始 dict → ChatModelConfig 等 dataclass + +import logging +from typing import Any + +from Undefined.utils.request_params import merge_request_params + +from ..coercers import ( + _coerce_bool, + _coerce_float, + _coerce_int, + _coerce_str, + _get_value, + _normalize_queue_interval, +) +from ..models import ( + AgentModelConfig, +) +from ..resolvers import ( + _resolve_api_mode, + _resolve_reasoning_effort, + _resolve_reasoning_effort_style, + _resolve_responses_force_stateless_replay, + _resolve_responses_tool_choice_compat, + _resolve_thinking_compat_flags, +) + +logger = logging.getLogger(__name__) + + +def _parse_summary_model_config( + data: dict[str, Any], fallback: AgentModelConfig +) -> tuple[AgentModelConfig, bool]: + s = data.get("models", {}).get("summary", {}) + if not isinstance(s, dict) or not s: + return fallback, False + queue_interval_seconds = _coerce_float( + s.get("queue_interval_seconds"), fallback.queue_interval_seconds + ) + queue_interval_seconds = _normalize_queue_interval( + queue_interval_seconds, fallback.queue_interval_seconds + ) + thinking_include_budget, thinking_tool_call_compat = _resolve_thinking_compat_flags( + data={"models": {"summary": s}}, + model_name="summary", + include_budget_env_key="SUMMARY_MODEL_THINKING_INCLUDE_BUDGET", + tool_call_compat_env_key="SUMMARY_MODEL_THINKING_TOOL_CALL_COMPAT", + legacy_env_key="SUMMARY_MODEL_DEEPSEEK_NEW_COT_SUPPORT", + ) + api_mode = _resolve_api_mode( + {"models": {"summary": s}}, + "summary", + "SUMMARY_MODEL_API_MODE", + fallback.api_mode, + ) + responses_tool_choice_compat = _resolve_responses_tool_choice_compat( + {"models": {"summary": s}}, + "summary", + "SUMMARY_MODEL_RESPONSES_TOOL_CHOICE_COMPAT", + fallback.responses_tool_choice_compat, + ) + responses_force_stateless_replay = _resolve_responses_force_stateless_replay( + {"models": {"summary": s}}, + "summary", + "SUMMARY_MODEL_RESPONSES_FORCE_STATELESS_REPLAY", + fallback.responses_force_stateless_replay, + ) + prompt_cache_enabled = _coerce_bool( + _get_value( + {"models": {"summary": s}}, + ("models", "summary", "prompt_cache_enabled"), + "SUMMARY_MODEL_PROMPT_CACHE_ENABLED", + ), + fallback.prompt_cache_enabled, + ) + context_window_tokens = _coerce_int( + s.get("context_window_tokens"), fallback.context_window_tokens + ) + return ( + AgentModelConfig( + api_url=_coerce_str(s.get("api_url"), fallback.api_url), + api_key=_coerce_str(s.get("api_key"), fallback.api_key), + model_name=_coerce_str(s.get("model_name"), fallback.model_name), + max_tokens=_coerce_int(s.get("max_tokens"), fallback.max_tokens), + context_window_tokens=context_window_tokens, + queue_interval_seconds=queue_interval_seconds, + api_mode=api_mode, + thinking_enabled=_coerce_bool( + s.get("thinking_enabled"), fallback.thinking_enabled + ), + thinking_budget_tokens=_coerce_int( + s.get("thinking_budget_tokens"), fallback.thinking_budget_tokens + ), + thinking_include_budget=thinking_include_budget, + reasoning_effort_style=_resolve_reasoning_effort_style( + _get_value( + {"models": {"summary": s}}, + ("models", "summary", "reasoning_effort_style"), + "SUMMARY_MODEL_REASONING_EFFORT_STYLE", + ), + fallback.reasoning_effort_style, + ), + thinking_tool_call_compat=thinking_tool_call_compat, + reasoning_content_replay=_coerce_bool( + s.get("reasoning_content_replay"), fallback.reasoning_content_replay + ), + system_prompt_as_user=_coerce_bool( + s.get("system_prompt_as_user"), fallback.system_prompt_as_user + ), + responses_tool_choice_compat=responses_tool_choice_compat, + responses_force_stateless_replay=responses_force_stateless_replay, + prompt_cache_enabled=prompt_cache_enabled, + reasoning_enabled=_coerce_bool( + _get_value( + {"models": {"summary": s}}, + ("models", "summary", "reasoning_enabled"), + "SUMMARY_MODEL_REASONING_ENABLED", + ), + fallback.reasoning_enabled, + ), + reasoning_effort=_resolve_reasoning_effort( + _get_value( + {"models": {"summary": s}}, + ("models", "summary", "reasoning_effort"), + "SUMMARY_MODEL_REASONING_EFFORT", + ), + fallback.reasoning_effort, + ), + stream_enabled=_coerce_bool( + _get_value( + {"models": {"summary": s}}, + ("models", "summary", "stream_enabled"), + "SUMMARY_MODEL_STREAM_ENABLED", + ), + fallback.stream_enabled, + ), + request_params=merge_request_params( + fallback.request_params, + s.get("request_params"), + ), + ), + True, + ) diff --git a/src/Undefined/config/parsers/vision.py b/src/Undefined/config/parsers/vision.py new file mode 100644 index 00000000..9299c980 --- /dev/null +++ b/src/Undefined/config/parsers/vision.py @@ -0,0 +1,162 @@ +"""Vision model parser.""" + +from __future__ import annotations + +# 模型配置解析:原始 dict → ChatModelConfig 等 dataclass + +import logging +from typing import Any + + +from ..coercers import ( + _coerce_bool, + _coerce_float, + _coerce_int, + _coerce_str, + _get_model_request_params, + _get_value, + _normalize_queue_interval, +) +from ..models import ( + VisionModelConfig, +) +from ..resolvers import ( + _resolve_api_mode, + _resolve_context_window_tokens, + _resolve_reasoning_effort, + _resolve_reasoning_effort_style, + _resolve_reasoning_content_replay, + _resolve_responses_force_stateless_replay, + _resolve_responses_tool_choice_compat, + _resolve_system_prompt_as_user, + _resolve_thinking_compat_flags, +) + +logger = logging.getLogger(__name__) + + +def _parse_vision_model_config(data: dict[str, Any]) -> VisionModelConfig: + queue_interval_seconds = _normalize_queue_interval( + _coerce_float( + _get_value( + data, + ("models", "vision", "queue_interval_seconds"), + "VISION_MODEL_QUEUE_INTERVAL", + ), + 1.0, + ) + ) + thinking_include_budget, thinking_tool_call_compat = _resolve_thinking_compat_flags( + data=data, + model_name="vision", + include_budget_env_key="VISION_MODEL_THINKING_INCLUDE_BUDGET", + tool_call_compat_env_key="VISION_MODEL_THINKING_TOOL_CALL_COMPAT", + legacy_env_key="VISION_MODEL_DEEPSEEK_NEW_COT_SUPPORT", + ) + api_mode = _resolve_api_mode(data, "vision", "VISION_MODEL_API_MODE") + responses_tool_choice_compat = _resolve_responses_tool_choice_compat( + data, "vision", "VISION_MODEL_RESPONSES_TOOL_CHOICE_COMPAT" + ) + responses_force_stateless_replay = _resolve_responses_force_stateless_replay( + data, "vision", "VISION_MODEL_RESPONSES_FORCE_STATELESS_REPLAY" + ) + reasoning_content_replay = _resolve_reasoning_content_replay( + data, "vision", "VISION_MODEL_REASONING_CONTENT_REPLAY" + ) + system_prompt_as_user = _resolve_system_prompt_as_user( + data, "vision", "VISION_MODEL_SYSTEM_PROMPT_AS_USER" + ) + prompt_cache_enabled = _coerce_bool( + _get_value( + data, + ("models", "vision", "prompt_cache_enabled"), + "VISION_MODEL_PROMPT_CACHE_ENABLED", + ), + True, + ) + reasoning_enabled = _coerce_bool( + _get_value( + data, + ("models", "vision", "reasoning_enabled"), + "VISION_MODEL_REASONING_ENABLED", + ), + False, + ) + reasoning_effort = _resolve_reasoning_effort( + _get_value( + data, + ("models", "vision", "reasoning_effort"), + "VISION_MODEL_REASONING_EFFORT", + ), + "medium", + ) + stream_enabled = _coerce_bool( + _get_value( + data, + ("models", "vision", "stream_enabled"), + "VISION_MODEL_STREAM_ENABLED", + ), + False, + ) + context_window_tokens = _resolve_context_window_tokens( + data, "vision", "VISION_MODEL_CONTEXT_WINDOW_TOKENS" + ) + return VisionModelConfig( + api_url=_coerce_str( + _get_value(data, ("models", "vision", "api_url"), "VISION_MODEL_API_URL"), + "", + ), + api_key=_coerce_str( + _get_value(data, ("models", "vision", "api_key"), "VISION_MODEL_API_KEY"), + "", + ), + model_name=_coerce_str( + _get_value(data, ("models", "vision", "model_name"), "VISION_MODEL_NAME"), + "", + ), + max_tokens=_coerce_int( + _get_value( + data, + ("models", "vision", "max_tokens"), + "VISION_MODEL_MAX_TOKENS", + ), + 8192, + ), + context_window_tokens=context_window_tokens, + queue_interval_seconds=queue_interval_seconds, + api_mode=api_mode, + thinking_enabled=_coerce_bool( + _get_value( + data, + ("models", "vision", "thinking_enabled"), + "VISION_MODEL_THINKING_ENABLED", + ), + False, + ), + thinking_budget_tokens=_coerce_int( + _get_value( + data, + ("models", "vision", "thinking_budget_tokens"), + "VISION_MODEL_THINKING_BUDGET_TOKENS", + ), + 20000, + ), + thinking_include_budget=thinking_include_budget, + reasoning_effort_style=_resolve_reasoning_effort_style( + _get_value( + data, + ("models", "vision", "reasoning_effort_style"), + "VISION_MODEL_REASONING_EFFORT_STYLE", + ), + ), + thinking_tool_call_compat=thinking_tool_call_compat, + reasoning_content_replay=reasoning_content_replay, + system_prompt_as_user=system_prompt_as_user, + responses_tool_choice_compat=responses_tool_choice_compat, + responses_force_stateless_replay=responses_force_stateless_replay, + prompt_cache_enabled=prompt_cache_enabled, + reasoning_enabled=reasoning_enabled, + reasoning_effort=reasoning_effort, + stream_enabled=stream_enabled, + request_params=_get_model_request_params(data, "vision"), + ) diff --git a/src/Undefined/config/toml_io.py b/src/Undefined/config/toml_io.py new file mode 100644 index 00000000..b9c96d46 --- /dev/null +++ b/src/Undefined/config/toml_io.py @@ -0,0 +1,110 @@ +"""TOML file I/O and environment bootstrap for configuration.""" + +from __future__ import annotations + +# 配置 I/O:读取 config.toml、加载 .env bootstrap、格式化解析错误 + +import logging +import os +import re +import tomllib +from pathlib import Path +from typing import Any, IO, Optional + +try: + from dotenv import load_dotenv +except Exception: # pragma: no cover + StrPath = str | os.PathLike[str] + + def load_dotenv( + dotenv_path: StrPath | None = None, + stream: IO[str] | None = None, + verbose: bool = False, + override: bool = False, + interpolate: bool = True, + encoding: str | None = "utf-8", + ) -> bool: + return False + + +logger = logging.getLogger(__name__) + +CONFIG_PATH = Path("config.toml") + +__all__ = ["CONFIG_PATH", "_load_env", "load_toml_data"] + + +def _load_env() -> None: + # dotenv 仅 bootstrap 环境变量,不覆盖已有 os.environ(TOML 仍优先于 env) + try: + load_dotenv() + except Exception: + logger.debug("加载 .env 失败,继续使用 config.toml", exc_info=True) + + +def _build_toml_decode_hint(line: str) -> str: + """根据出错行内容生成 TOML 修复提示。""" + hints: list[str] = [] + if "\\" in line: + hints.append( + 'Windows 路径建议用单引号(不转义)或双反斜杠,或直接用正斜杠,例如:path = \'D:\\AI\\bot\' / path = "D:\\\\AI\\\\bot" / path = "D:/AI/bot"' + ) + hints.append('多行文本请用三引号,例如:prompt = """..."""') + return ";".join(hints) + + +def _format_toml_decode_error( + path: Path, text: str, exc: tomllib.TOMLDecodeError +) -> str: + """将 tomllib 解析异常格式化为带行号、caret 与中文提示的可读消息。""" + lineno: int | None = getattr(exc, "lineno", None) + colno: int | None = getattr(exc, "colno", None) + if not isinstance(lineno, int) or not isinstance(colno, int): + match = re.search(r"\(at line (\d+), column (\d+)\)", str(exc)) + if match: + lineno = int(match.group(1)) + colno = int(match.group(2)) + + if isinstance(lineno, int) and lineno > 0: + lines = text.splitlines() + line = lines[lineno - 1] if 0 <= (lineno - 1) < len(lines) else "" + caret_pos = max((colno or 1) - 1, 0) + caret = " " * min(caret_pos, len(line)) + "^" + hint = _build_toml_decode_hint(line) + location = f"line={lineno} col={colno or 1}" + return f"{exc} ({location})\n> {line}\n {caret}\n提示:{hint}" + return str(exc) + + +def load_toml_data( + config_path: Optional[Path] = None, *, strict: bool = False +) -> dict[str, Any]: + """读取 config.toml 并返回字典;文件不存在时返回空 dict。""" + path = config_path or CONFIG_PATH + if not path.exists(): + return {} + text = "" + try: + # utf-8-sig 兼容带 BOM 的编辑器输出 + text = path.read_bytes().decode("utf-8-sig") + data = tomllib.loads(text) + if isinstance(data, dict): + return data + logger.warning("config.toml 内容不是对象结构") + return {} + except tomllib.TOMLDecodeError as exc: + message = _format_toml_decode_error(path, text, exc) + logger.error("config.toml 解析失败 (%s): %s", path.resolve(), message) + if strict: + raise ValueError(message) from exc + return {} + except UnicodeDecodeError as exc: + logger.error("config.toml 编码错误 (%s): %s", path.resolve(), exc) + if strict: + raise ValueError(str(exc)) from exc + return {} + except OSError as exc: + logger.error("读取 config.toml 失败: %s", exc) + if strict: + raise ValueError(str(exc)) from exc + return {} diff --git a/tests/test_cli_startup_compat.py b/tests/test_cli_startup_compat.py new file mode 100644 index 00000000..82f0f557 --- /dev/null +++ b/tests/test_cli_startup_compat.py @@ -0,0 +1,128 @@ +"""CLI 启动兼容性与根包 import 行为测试(Phase 0 起持续有效)。""" + +from __future__ import annotations + +import importlib +import sys +from pathlib import Path +from typing import Iterator + +import pytest + +_MINIMAL_CONFIG = """ +[onebot] +ws_url = "ws://127.0.0.1:3999" +[models.chat] +api_url = "https://api.example/v1" +api_key = "sk-test" +model_name = "gpt-test" +""" + + +@pytest.fixture +def isolated_undefined_modules() -> Iterator[None]: + """测试前后清理 Undefined 相关 sys.modules 条目。""" + prefix = "Undefined" + saved = { + k: v + for k, v in sys.modules.items() + if k == prefix or k.startswith(f"{prefix}.") + } + for key in saved: + del sys.modules[key] + yield + for key in list(sys.modules): + if key == prefix or key.startswith(f"{prefix}."): + del sys.modules[key] + sys.modules.update(saved) + + +@pytest.fixture +def reset_config_singleton() -> Iterator[None]: + """重置 config 模块全局单例,避免测试间污染。""" + import Undefined.config as config_module + + saved_config = config_module._config + saved_manager = config_module._config_manager + config_module._config = None + config_module._config_manager = None + yield + config_module._config = saved_config + config_module._config_manager = saved_manager + + +def test_entry_point_undefined_main_run_importable() -> None: + from Undefined.main import run + + assert callable(run) + + +def test_entry_point_undefined_webui_run_importable() -> None: + from Undefined.webui import run + + assert callable(run) + + +def test_import_undefined_does_not_eagerly_load_onebot_or_handlers( + isolated_undefined_modules: None, +) -> None: + import Undefined # noqa: F401 + + assert "Undefined.onebot" not in sys.modules + assert "Undefined.handlers" not in sys.modules + assert "Undefined.main" not in sys.modules + + +def test_get_config_reads_config_toml_from_cwd( + monkeypatch: pytest.MonkeyPatch, + tmp_path: Path, + reset_config_singleton: None, +) -> None: + monkeypatch.chdir(tmp_path) + (tmp_path / "config.toml").write_text(_MINIMAL_CONFIG, encoding="utf-8") + + import Undefined.config as config_module + + config_module._config = None + config_module._config_manager = None + + cfg = config_module.get_config(strict=False) + assert cfg.onebot_ws_url == "ws://127.0.0.1:3999" + assert cfg.chat_model.model_name == "gpt-test" + + +def test_root_get_config_same_as_subpackage(isolated_undefined_modules: None) -> None: + import Undefined + + from Undefined.config import get_config as subpackage_get_config + + assert Undefined.get_config is subpackage_get_config + + +def test_import_undefined_does_not_import_webui_app( + isolated_undefined_modules: None, +) -> None: + import Undefined # noqa: F401 + + assert "Undefined.webui.app" not in sys.modules + + +def test_webui_run_deferred_import(monkeypatch: pytest.MonkeyPatch) -> None: + """webui.run 应延迟加载 app,import webui 包本身不拉起重型依赖。""" + import Undefined.webui as webui_module + + original_import = importlib.import_module + app_imported = False + + def tracking_import(name: str, package: object | None = None) -> object: + nonlocal app_imported + if name == "Undefined.webui.app": + app_imported = True + package_name = package if isinstance(package, str) else None + return original_import(name, package_name) + + monkeypatch.setattr(importlib, "import_module", tracking_import) + importlib.reload(webui_module) + + assert not app_imported + assert callable(webui_module.run) diff --git a/tests/test_config_env_only.py b/tests/test_config_env_only.py new file mode 100644 index 00000000..42049d2a --- /dev/null +++ b/tests/test_config_env_only.py @@ -0,0 +1,61 @@ +from __future__ import annotations + + +import pytest + +from Undefined.config import Config + + +def test_env_only_chat_model_fields(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setenv("ONEBOT_WS_URL", "ws://env-only:3001") + monkeypatch.setenv("CHAT_MODEL_API_URL", "https://env.example/v1") + monkeypatch.setenv("CHAT_MODEL_API_KEY", "env-key") + monkeypatch.setenv("CHAT_MODEL_NAME", "env-chat") + monkeypatch.setenv("VISION_MODEL_API_URL", "https://env.example/v1") + monkeypatch.setenv("VISION_MODEL_API_KEY", "env-key") + monkeypatch.setenv("VISION_MODEL_NAME", "env-vision") + monkeypatch.setenv("AGENT_MODEL_API_URL", "https://env.example/v1") + monkeypatch.setenv("AGENT_MODEL_API_KEY", "env-key") + monkeypatch.setenv("AGENT_MODEL_NAME", "env-agent") + + cfg = Config.from_mapping({}, strict=False) + assert cfg.onebot_ws_url == "ws://env-only:3001" + assert cfg.chat_model.model_name == "env-chat" + assert cfg.vision_model.model_name == "env-vision" + assert cfg.agent_model.model_name == "env-agent" + + +def test_env_overridden_by_mapping(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setenv("CHAT_MODEL_NAME", "from-env") + cfg = Config.from_mapping( + { + "onebot": {"ws_url": "ws://x"}, + "models": { + "chat": { + "api_url": "u", + "api_key": "k", + "model_name": "from-toml", + }, + "vision": {"api_url": "u", "api_key": "k", "model_name": "v"}, + "agent": {"api_url": "u", "api_key": "k", "model_name": "a"}, + }, + }, + strict=False, + ) + assert cfg.chat_model.model_name == "from-toml" + + +def test_http_proxy_env_fallback(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setenv("HTTP_PROXY", "http://127.0.0.1:7890") + cfg = Config.from_mapping( + { + "onebot": {"ws_url": "ws://x"}, + "models": { + "chat": {"api_url": "u", "api_key": "k", "model_name": "m"}, + "vision": {"api_url": "u", "api_key": "k", "model_name": "v"}, + "agent": {"api_url": "u", "api_key": "k", "model_name": "a"}, + }, + }, + strict=False, + ) + assert cfg.http_proxy == "http://127.0.0.1:7890" diff --git a/tests/test_config_env_registry.py b/tests/test_config_env_registry.py new file mode 100644 index 00000000..385f3351 --- /dev/null +++ b/tests/test_config_env_registry.py @@ -0,0 +1,34 @@ +from __future__ import annotations + +from Undefined.config.env_registry import ( + ENV_ALTERNATES, + ENV_REGISTRY, + all_env_mappings, + env_key_for_path, +) + + +def test_registry_contains_core_paths() -> None: + assert ENV_REGISTRY[("core", "bot_qq")] == "BOT_QQ" + assert ENV_REGISTRY[("onebot", "ws_url")] == "ONEBOT_WS_URL" + assert ENV_REGISTRY[("models", "chat", "api_url")] == "CHAT_MODEL_API_URL" + + +def test_env_key_for_path_lookup() -> None: + assert env_key_for_path(("models", "agent", "model_name")) == "AGENT_MODEL_NAME" + assert env_key_for_path(("nonexistent",)) is None + + +def test_all_env_mappings_is_copy() -> None: + snapshot = all_env_mappings() + snapshot[("fake",)] = "FAKE" + assert ("fake",) not in ENV_REGISTRY + + +def test_alternate_env_keys_documented() -> None: + assert "HTTP_PROXY" in ENV_ALTERNATES + assert "EASTER_EGG_CALL_MESSAGE_MODE" in ENV_ALTERNATES + + +def test_registry_has_model_context_window_entries() -> None: + assert ("models", "chat", "context_window_tokens") in ENV_REGISTRY diff --git a/tests/test_config_from_mapping.py b/tests/test_config_from_mapping.py new file mode 100644 index 00000000..70e0dbd0 --- /dev/null +++ b/tests/test_config_from_mapping.py @@ -0,0 +1,96 @@ +from __future__ import annotations + +from pathlib import Path + +import pytest + +from Undefined.config import Config, set_config + + +_MINIMAL_MAPPING = { + "onebot": {"ws_url": "ws://127.0.0.1:3001"}, + "models": { + "chat": { + "api_url": "https://api.example/v1", + "api_key": "sk-test", + "model_name": "gpt-test", + }, + "vision": { + "api_url": "https://api.example/v1", + "api_key": "sk-test", + "model_name": "vision-test", + }, + "agent": { + "api_url": "https://api.example/v1", + "api_key": "sk-test", + "model_name": "agent-test", + }, + }, +} + + +def test_from_mapping_builds_without_toml() -> None: + cfg = Config.from_mapping(_MINIMAL_MAPPING, strict=False) + assert cfg.onebot_ws_url == "ws://127.0.0.1:3001" + assert cfg.chat_model.model_name == "gpt-test" + + +def test_builder_with_mapping() -> None: + cfg = Config.builder().with_mapping(_MINIMAL_MAPPING).build(strict=False) + assert cfg.agent_model.model_name == "agent-test" + + +def test_set_config_injects_singleton(monkeypatch: pytest.MonkeyPatch) -> None: + import Undefined.config as config_pkg + + monkeypatch.setattr(config_pkg, "_config", None) + cfg = Config.from_mapping(_MINIMAL_MAPPING, strict=False) + set_config(cfg) + assert config_pkg.get_config(strict=False) is cfg + + +def test_from_mapping_matches_load(tmp_path: Path) -> None: + toml = """ +[onebot] +ws_url = "ws://127.0.0.1:3001" +[models.chat] +api_url = "https://api.example/v1" +api_key = "sk-test" +model_name = "gpt-test" +[models.vision] +api_url = "https://api.example/v1" +api_key = "sk-test" +model_name = "vision-test" +[models.agent] +api_url = "https://api.example/v1" +api_key = "sk-test" +model_name = "agent-test" +""" + path = tmp_path / "config.toml" + path.write_text(toml, encoding="utf-8") + from_file = Config.load(path, strict=False) + from_map = Config.from_mapping( + { + "onebot": {"ws_url": "ws://127.0.0.1:3001"}, + "models": { + "chat": { + "api_url": "https://api.example/v1", + "api_key": "sk-test", + "model_name": "gpt-test", + }, + "vision": { + "api_url": "https://api.example/v1", + "api_key": "sk-test", + "model_name": "vision-test", + }, + "agent": { + "api_url": "https://api.example/v1", + "api_key": "sk-test", + "model_name": "agent-test", + }, + }, + }, + strict=False, + ) + assert from_file.chat_model.model_name == from_map.chat_model.model_name + assert from_file.onebot_ws_url == from_map.onebot_ws_url diff --git a/tests/test_public_api_imports.py b/tests/test_public_api_imports.py new file mode 100644 index 00000000..41b87886 --- /dev/null +++ b/tests/test_public_api_imports.py @@ -0,0 +1,141 @@ +"""根包公共 API 与向后兼容 import 路径测试(Phase 3 API-FACADE 启用)。""" + +from __future__ import annotations + +import importlib +from typing import Any + +import pytest + +# 根包 lazy re-export 符号(见 docs/python-api.md) +_ROOT_EXPORTS: tuple[str, ...] = ( + "Config", + "get_config", + "AIClient", + "ToolRegistry", + "AgentRegistry", + "PipelineRegistry", + "BaseRegistry", + "AnthropicSkillRegistry", + "CognitiveService", + "KnowledgeManager", + "MemeService", + "AttachmentRegistry", + "RuntimeAPIServer", + "RuntimeAPIContext", +) + +# 拆分后须继续可用的 shim / 深层 import 路径 +_BACKWARD_COMPAT_PATHS: tuple[tuple[str, str], ...] = ( + ("Undefined.config.loader", "Config"), + ("Undefined.ai.client", "AIClient"), + ("Undefined.attachments", "AttachmentRegistry"), + ("Undefined.handlers", "MessageHandler"), + ("Undefined.onebot", "OneBotClient"), + ("Undefined.skills.tools", "ToolRegistry"), + ("Undefined.skills.agents", "AgentRegistry"), + ("Undefined.skills.pipelines.registry", "PipelineRegistry"), + ("Undefined.skills.registry", "BaseRegistry"), + ("Undefined.skills.anthropic_skills", "AnthropicSkillRegistry"), + ("Undefined.cognitive.service", "CognitiveService"), + ("Undefined.knowledge.manager", "KnowledgeManager"), + ("Undefined.memes.service", "MemeService"), + ("Undefined.api.app", "RuntimeAPIServer"), + ("Undefined.api._context", "RuntimeAPIContext"), +) + + +@pytest.mark.parametrize("symbol", _ROOT_EXPORTS) +def test_root_package_exports(symbol: str) -> None: + import Undefined + + assert hasattr(Undefined, symbol), f"Undefined.{symbol} missing from root exports" + getattr(Undefined, symbol) + + +def test_root_package_lazy_import_does_not_load_cli_modules() -> None: + import sys + + saved_modules = { + name: module + for name, module in sys.modules.items() + if name == "Undefined" or name.startswith("Undefined.") + } + try: + for name in saved_modules: + del sys.modules[name] + + import Undefined # noqa: F401 + + assert "Undefined.onebot" not in sys.modules + assert "Undefined.handlers" not in sys.modules + assert "Undefined.main" not in sys.modules + finally: + for name in list(sys.modules): + if ( + name == "Undefined" or name.startswith("Undefined.") + ) and name not in saved_modules: + del sys.modules[name] + sys.modules.update(saved_modules) + + +@pytest.mark.parametrize(("module_path", "symbol"), _BACKWARD_COMPAT_PATHS) +def test_backward_compat_import_path(module_path: str, symbol: str) -> None: + module = importlib.import_module(module_path) + assert hasattr(module, symbol), f"{module_path}.{symbol} missing" + getattr(module, symbol) + + +def test_set_config_not_used_by_default_get_config( + monkeypatch: pytest.MonkeyPatch, + tmp_path: Any, +) -> None: + """库嵌入 set_config() 与 CLI get_config() 路径隔离(CONFIG Phase 2)。""" + import Undefined.config as config_module + + from Undefined.config import Config, set_config + + config_module._config = None + config_module._config_manager = None + + monkeypatch.chdir(tmp_path) + (tmp_path / "config.toml").write_text( + """ +[onebot] +ws_url = "ws://127.0.0.1:3999" +[models.chat] +api_url = "https://api.example/v1" +api_key = "sk-test" +model_name = "from-file" +""", + encoding="utf-8", + ) + + cfg = config_module.get_config(strict=False) + assert cfg.chat_model.model_name == "from-file" + + injected = Config.from_mapping( + { + "onebot": {"ws_url": "ws://127.0.0.1:3001"}, + "models": { + "chat": { + "api_url": "https://api.example/v1", + "api_key": "sk-test", + "model_name": "injected", + }, + "vision": { + "api_url": "https://api.example/v1", + "api_key": "sk-test", + "model_name": "vision-test", + }, + "agent": { + "api_url": "https://api.example/v1", + "api_key": "sk-test", + "model_name": "agent-test", + }, + }, + }, + strict=False, + ) + set_config(injected) + assert config_module.get_config(strict=False) is injected From 9e9615dae8f1724f1c551af6615eddd135145a89 Mon Sep 17 00:00:00 2001 From: Null <1708213363@qq.com> Date: Sat, 23 May 2026 21:26:44 +0800 Subject: [PATCH 05/16] refactor: split runtime modules, docs, and WebUI JS Extract ai, handlers, attachments, cognitive, memes, coordinator, onebot, and agent runner into subpackages with compatibility shims; add python-api docs and py.typed; trim noisy inline comments across runtime code. Co-authored-by: Cursor --- ARCHITECTURE.md | 22 +- README.md | 30 + docs/configuration.md | 357 +++++- docs/deployment.md | 4 +- docs/development.md | 140 ++- docs/python-api.md | 284 +++++ src/Undefined/ai/client.py | 10 - src/Undefined/ai/client/__init__.py | 37 + src/Undefined/ai/client/ask_loop.py | 623 ++++++++++ src/Undefined/ai/client/queue.py | 283 +++++ src/Undefined/ai/client/setup.py | 880 ++++++++++++++ src/Undefined/ai/llm.py | 2 - src/Undefined/ai/llm/__init__.py | 22 + src/Undefined/ai/llm/requester.py | 1016 +++++++++++++++++ src/Undefined/ai/llm/sanitize.py | 558 +++++++++ src/Undefined/ai/llm/streaming.py | 392 +++++++ src/Undefined/ai/llm/thinking.py | 214 ++++ src/Undefined/ai/llm/types.py | 27 + src/Undefined/ai/model_selector.py | 1 - src/Undefined/ai/multimodal.py | 23 - src/Undefined/ai/multimodal/__init__.py | 32 + src/Undefined/ai/multimodal/analyzer.py | 599 ++++++++++ src/Undefined/ai/multimodal/constants.py | 138 +++ src/Undefined/ai/multimodal/detection.py | 101 ++ src/Undefined/ai/multimodal/parsing.py | 107 ++ src/Undefined/ai/parsing.py | 10 - src/Undefined/ai/prompts.py | 11 - src/Undefined/ai/prompts/__init__.py | 10 + src/Undefined/ai/prompts/builder.py | 599 ++++++++++ src/Undefined/ai/prompts/cognitive.py | 137 +++ src/Undefined/ai/prompts/constants.py | 20 + src/Undefined/ai/prompts/system_context.py | 165 +++ .../ai/transports/openai_transport.py | 1 - src/Undefined/api/routes/naga/__init__.py | 21 + src/Undefined/api/routes/naga/auth.py | 30 + src/Undefined/api/routes/naga/bind.py | 159 +++ src/Undefined/api/routes/naga/send.py | 665 +++++++++++ src/Undefined/api/routes/naga/unbind.py | 100 ++ src/Undefined/attachments/__init__.py | 43 + src/Undefined/attachments/models.py | 93 ++ src/Undefined/attachments/registry.py | 903 +++++++++++++++ src/Undefined/attachments/render.py | 280 +++++ src/Undefined/attachments/segments.py | 566 +++++++++ src/Undefined/bilibili/wbi.py | 2 + src/Undefined/cognitive/historian.py | 7 - src/Undefined/cognitive/historian/__init__.py | 5 + src/Undefined/cognitive/historian/helpers.py | 85 ++ src/Undefined/cognitive/historian/tools.py | 78 ++ src/Undefined/cognitive/historian/worker.py | 903 +++++++++++++++ src/Undefined/cognitive/job_queue.py | 3 - src/Undefined/cognitive/profile_storage.py | 5 - src/Undefined/cognitive/service.py | 4 - src/Undefined/cognitive/service/__init__.py | 5 + src/Undefined/cognitive/service/helpers.py | 169 +++ src/Undefined/cognitive/service/service.py | 751 ++++++++++++ src/Undefined/cognitive/vector_store.py | 3 - src/Undefined/config/domain_parsers.py | 1 - src/Undefined/handlers.py | 1 + src/Undefined/handlers/__init__.py | 28 + src/Undefined/handlers/auto_extract.py | 223 ++++ src/Undefined/handlers/message_flow.py | 845 ++++++++++++++ src/Undefined/handlers/poke.py | 293 +++++ src/Undefined/handlers/repeat.py | 146 +++ src/Undefined/memes/_service.py | 269 +++++ src/Undefined/memes/ingest.py | 734 ++++++++++++ src/Undefined/memes/search.py | 591 ++++++++++ src/Undefined/memes/service.py | 1 - src/Undefined/onebot/__init__.py | 16 + src/Undefined/onebot/client.py | 873 ++++++++++++++ src/Undefined/onebot/message.py | 58 + src/Undefined/py.typed | 0 src/Undefined/services/ai_coordinator.py | 14 - src/Undefined/services/commands/bugfix.py | 189 +++ src/Undefined/services/commands/stats.py | 822 +++++++++++++ .../services/coordinator/__init__.py | 88 ++ .../services/coordinator/background.py | 250 ++++ .../services/coordinator/batching.py | 174 +++ src/Undefined/services/coordinator/group.py | 405 +++++++ src/Undefined/services/coordinator/private.py | 282 +++++ .../services/message_batcher/__init__.py | 48 + .../services/message_batcher/scheduler.py | 700 ++++++++++++ .../services/message_batcher/state.py | 106 ++ .../code_delivery_agent/tools/read/handler.py | 1 + .../skills/agents/runner/__init__.py | 12 + src/Undefined/skills/agents/runner/context.py | 133 +++ src/Undefined/skills/agents/runner/loop.py | 182 +++ src/Undefined/skills/agents/runner/tools.py | 179 +++ src/Undefined/skills/http_client.py | 171 +++ src/Undefined/skills/http_config.py | 43 + src/Undefined/skills/tools/__init__.py | 4 - .../skills/tools/bilibili_video/handler.py | 5 - .../skills/tools/fetch_image_uid/handler.py | 1 - .../skills/tools/get_current_time/handler.py | 18 - .../skills/tools/get_picture/handler.py | 11 - .../skills/tools/get_user_info/handler.py | 5 - .../tools/python_interpreter/handler.py | 5 - src/Undefined/skills/tools/qq_like/handler.py | 3 - .../skills/tools/task_progress/handler.py | 6 - src/Undefined/utils/render_cache.py | 3 + src/Undefined/utils/sender_helpers.py | 116 ++ tests/test_ai_coordinator_queue_routing.py | 6 +- 101 files changed, 19598 insertions(+), 198 deletions(-) create mode 100644 docs/python-api.md create mode 100644 src/Undefined/ai/client/__init__.py create mode 100644 src/Undefined/ai/client/ask_loop.py create mode 100644 src/Undefined/ai/client/queue.py create mode 100644 src/Undefined/ai/client/setup.py create mode 100644 src/Undefined/ai/llm/__init__.py create mode 100644 src/Undefined/ai/llm/requester.py create mode 100644 src/Undefined/ai/llm/sanitize.py create mode 100644 src/Undefined/ai/llm/streaming.py create mode 100644 src/Undefined/ai/llm/thinking.py create mode 100644 src/Undefined/ai/llm/types.py create mode 100644 src/Undefined/ai/multimodal/__init__.py create mode 100644 src/Undefined/ai/multimodal/analyzer.py create mode 100644 src/Undefined/ai/multimodal/constants.py create mode 100644 src/Undefined/ai/multimodal/detection.py create mode 100644 src/Undefined/ai/multimodal/parsing.py create mode 100644 src/Undefined/ai/prompts/__init__.py create mode 100644 src/Undefined/ai/prompts/builder.py create mode 100644 src/Undefined/ai/prompts/cognitive.py create mode 100644 src/Undefined/ai/prompts/constants.py create mode 100644 src/Undefined/ai/prompts/system_context.py create mode 100644 src/Undefined/api/routes/naga/__init__.py create mode 100644 src/Undefined/api/routes/naga/auth.py create mode 100644 src/Undefined/api/routes/naga/bind.py create mode 100644 src/Undefined/api/routes/naga/send.py create mode 100644 src/Undefined/api/routes/naga/unbind.py create mode 100644 src/Undefined/attachments/__init__.py create mode 100644 src/Undefined/attachments/models.py create mode 100644 src/Undefined/attachments/registry.py create mode 100644 src/Undefined/attachments/render.py create mode 100644 src/Undefined/attachments/segments.py create mode 100644 src/Undefined/cognitive/historian/__init__.py create mode 100644 src/Undefined/cognitive/historian/helpers.py create mode 100644 src/Undefined/cognitive/historian/tools.py create mode 100644 src/Undefined/cognitive/historian/worker.py create mode 100644 src/Undefined/cognitive/service/__init__.py create mode 100644 src/Undefined/cognitive/service/helpers.py create mode 100644 src/Undefined/cognitive/service/service.py create mode 100644 src/Undefined/handlers/__init__.py create mode 100644 src/Undefined/handlers/auto_extract.py create mode 100644 src/Undefined/handlers/message_flow.py create mode 100644 src/Undefined/handlers/poke.py create mode 100644 src/Undefined/handlers/repeat.py create mode 100644 src/Undefined/memes/_service.py create mode 100644 src/Undefined/memes/ingest.py create mode 100644 src/Undefined/memes/search.py create mode 100644 src/Undefined/onebot/__init__.py create mode 100644 src/Undefined/onebot/client.py create mode 100644 src/Undefined/onebot/message.py create mode 100644 src/Undefined/py.typed create mode 100644 src/Undefined/services/commands/bugfix.py create mode 100644 src/Undefined/services/commands/stats.py create mode 100644 src/Undefined/services/coordinator/__init__.py create mode 100644 src/Undefined/services/coordinator/background.py create mode 100644 src/Undefined/services/coordinator/batching.py create mode 100644 src/Undefined/services/coordinator/group.py create mode 100644 src/Undefined/services/coordinator/private.py create mode 100644 src/Undefined/services/message_batcher/__init__.py create mode 100644 src/Undefined/services/message_batcher/scheduler.py create mode 100644 src/Undefined/services/message_batcher/state.py create mode 100644 src/Undefined/skills/agents/runner/__init__.py create mode 100644 src/Undefined/skills/agents/runner/context.py create mode 100644 src/Undefined/skills/agents/runner/loop.py create mode 100644 src/Undefined/skills/agents/runner/tools.py create mode 100644 src/Undefined/utils/sender_helpers.py diff --git a/ARCHITECTURE.md b/ARCHITECTURE.md index 2a48c448..3a8c043e 100644 --- a/ARCHITECTURE.md +++ b/ARCHITECTURE.md @@ -18,14 +18,14 @@ graph TB ConfigLoader["ConfigManager
配置管理器
[config/manager.py + loader.py]"] ConfigHotReload["ConfigHotReload
热更新应用器
[config/hot_reload.py]"] ConfigModels["配置模型
[config/models.py]
ChatModelConfig
VisionModelConfig
SecurityModelConfig
AgentModelConfig"] - OneBotClient["OneBotClient
WebSocket 客户端
[onebot.py]"] + OneBotClient["OneBotClient
WebSocket 客户端
[onebot/ + onebot.py shim]"] Context["RequestContext
请求上下文
[context.py]"] WebUI["webui.py
配置控制台
[src/Undefined/webui.py]"] end %% ==================== 消息处理层 ==================== subgraph MessageLayer["消息处理层 (src/Undefined/)"] - MessageHandler["MessageHandler
消息处理器
[handlers.py]"] + MessageHandler["MessageHandler
消息处理器
[handlers/ + handlers.py shim]"] subgraph BilibiliModule["Bilibili 模块 (bilibili/)"] BilibiliParser["parser.py
标识符解析
• BV/AV号 • URL
• b23.tv短链 • 小程序JSON"] @@ -46,10 +46,10 @@ graph TB CommandDispatcher["CommandDispatcher
命令分发器
• /help /stats /admin
• /bugfix /faq
[services/command.py]"] - MessageBatcher["MessageBatcher
同 sender 短时合并
• 按 (scope, sender_id) 分桶
• T1=window_seconds 结束 batch
• T2=pre_send_seconds 投机预发送
• 拍一拍/buffer 内 @bot 旁路
• 首条 @bot 整批走 mention 队列
[services/message_batcher.py]"] + MessageBatcher["MessageBatcher
同 sender 短时合并
• 按 (scope, sender_id) 分桶
• T1=window_seconds 结束 batch
• T2=pre_send_seconds 投机预发送
• 拍一拍/buffer 内 @bot 旁路
• 首条 @bot 整批走 mention 队列
[services/message_batcher/ + shim]"] subgraph QueueSystem["车站-列车 队列系统 (services/)"] - AICoordinator["AICoordinator
AI 协调器
• Prompt 构建
• 队列管理
• 回复执行
[ai_coordinator.py]"] + AICoordinator["AICoordinator
AI 协调器
• Prompt 构建
• 队列管理
• 回复执行
[services/coordinator/ + ai_coordinator.py shim]"] QueueManager["QueueManager
队列管理器
[queue_manager.py]"] subgraph ModelQueues["ModelQueue 队列组 (按模型隔离)"] @@ -65,13 +65,13 @@ graph TB %% ==================== AI 核心能力层 ==================== subgraph AILayer["AI 核心能力层 (src/Undefined/ai/)"] - AIClient["AIClient
AI 客户端主入口
[client.py]
• 技能热重载 • MCP 初始化
• Agent intro 生成"] + AIClient["AIClient
AI 客户端主入口
[ai/client/ + client.py shim]
• 技能热重载 • MCP 初始化
• Agent intro 生成"] subgraph AIComponents["AI 组件"] - PromptBuilder["PromptBuilder
提示词构建器
[prompts.py]"] - ModelRequester["ModelRequester
模型请求器
[llm.py]
• OpenAI SDK • 工具清理
• Thinking 提取"] + PromptBuilder["PromptBuilder
提示词构建器
[ai/prompts/ + prompts.py shim]"] + ModelRequester["ModelRequester
模型请求器
[ai/llm/ + llm.py shim]
• OpenAI SDK • 工具清理
• Thinking 提取"] ToolManager["ToolManager
工具管理器
[tooling.py]
• 工具执行 • Agent 工具合并
• MCP 工具注入"] - MultimodalAnalyzer["MultimodalAnalyzer
多模态分析器
[multimodal.py]
• 图片/音频/视频"] + MultimodalAnalyzer["MultimodalAnalyzer
多模态分析器
[ai/multimodal/ + multimodal.py shim]
• 图片/音频/视频"] SummaryService["SummaryService
总结服务
[summaries.py]
• 聊天记录总结
• 标题生成"] TokenCounter["TokenCounter
Token 统计
[tokens.py]"] Parsing["Parsing
响应解析
[parsing.py]"] @@ -849,10 +849,10 @@ description: 从 PDF 文件中提取文本和表格,填写表单。当用户 ### 8层架构分层 1. **外部实体层**:用户、管理员、OneBot 协议端 (NapCat/Lagrange.Core)、大模型 API 服务商 -2. **核心入口层**:main.py 启动入口、配置管理器 (config/loader.py)、热更新应用器 (config/hot_reload.py)、OneBotClient (onebot.py)、RequestContext (context.py)、Runtime API Server (api/app.py → api/routes/ 路由子模块) -3. **消息处理层**:MessageHandler (handlers.py)、SecurityService (security.py)、CommandDispatcher (services/command.py)、MessageBatcher (services/message_batcher.py)、AICoordinator (ai_coordinator.py)、QueueManager (queue_manager.py)、自动处理管线 (skills/pipelines/)、Bilibili/arXiv/GitHub 解析与发送模块 +2. **核心入口层**:main.py 启动入口、配置管理器 (config/loader.py + parsers/ + load_sections/)、热更新应用器 (config/hot_reload.py)、OneBotClient (onebot/ + onebot.py shim)、RequestContext (context.py)、Runtime API Server (api/app.py → api/routes/ 路由子模块,含 naga/ 子包) +3. **消息处理层**:MessageHandler (handlers/ + handlers.py shim)、SecurityService (security.py)、CommandDispatcher (services/command.py + commands/ mixins)、MessageBatcher (services/message_batcher/ + shim)、AICoordinator (services/coordinator/ + ai_coordinator.py shim)、QueueManager (queue_manager.py)、自动处理管线 (skills/pipelines/)、Bilibili/arXiv/GitHub 解析与发送模块 自动提取由 `PipelineRegistry` 并行检测、并行处理全部命中的管线;发送结果写入历史后继续进入 AI 自动回复。 -4. **AI 核心能力层**:AIClient (client.py)、PromptBuilder (prompts.py)、ModelRequester (llm.py)、ToolManager (tooling.py)、MultimodalAnalyzer (multimodal.py)、SummaryService (summaries.py)、TokenCounter (tokens.py) +4. **AI 核心能力层**:AIClient (ai/client/ + client.py shim)、PromptBuilder (ai/prompts/ + prompts.py shim)、ModelRequester (ai/llm/ + llm.py shim)、ToolManager (tooling.py)、MultimodalAnalyzer (ai/multimodal/ + multimodal.py shim)、SummaryService (summaries.py)、TokenCounter (tokens.py) 5. **存储与上下文层**:MessageHistoryManager (utils/history.py, 10000条限制)、MemoryStorage (memory.py, 置顶备忘录, 500条上限)、EndSummaryStorage、CognitiveService + JobQueue + HistorianWorker + VectorStore + ProfileStorage、MemeService + MemeWorker + MemeStore + MemeVectorStore (表情包库)、FAQStorage、ScheduledTaskStorage、TokenUsageStorage (自动归档) 6. **技能系统层**:ToolRegistry (registry.py)、AgentRegistry、6个 Agents、11类 Toolsets 7. **异步 IO 层**:统一 IO 工具 (utils/io.py),包含 write_json、read_json、append_line、跨平台文件锁 (flock/msvcrt) diff --git a/README.md b/README.md index b626b06b..6d2497ae 100644 --- a/README.md +++ b/README.md @@ -77,6 +77,7 @@ Undefined 的功能极为丰富,为了让本页面不过于臃肿,我们将各个模块的深入解析与高阶玩法整理成了专题游览图。这里是开启探索的钥匙: - ⚙️ **[安装与部署指南](docs/deployment.md)**:不管你是需要 `pip` 无脑一键安装,还是源码二次开发,这里的排坑指南应有尽有。 +- 📦 **[Python 库 API 参考](docs/python-api.md)**:作为库嵌入时的 import 路径、`Config.from_mapping` / `set_config` 与公共 API 符号表。 - 🖥️ **[WebUI 使用指南](docs/webui-guide.md)**:管理控制台功能一览——配置编辑、日志查看、认知记忆管理、表情包库、AI 对话与系统监控。 - 🧭 **[Management API 与远程管理](docs/management-api.md)**:WebUI / App 共用的管理接口、认证、配置/日志/Bot 控制与引导探针说明。 - 🛠️ **[配置与热更新说明](docs/configuration.md)**:从模型切换到 MCP 库挂载,全方位掌握 `config.toml` 的高阶配置。 @@ -97,6 +98,35 @@ Undefined 的功能极为丰富,为了让本页面不过于臃肿,我们将 --- +## 作为 Python 库使用 + +除 CLI / WebUI 部署外,Undefined 也可作为 Python 库嵌入到其他应用或测试中,复用配置系统、`AIClient`、Skills 注册表、认知记忆、知识库等组件。 + +```bash +pip install Undefined-bot # 或 uv sync(源码) +``` + +```python +from Undefined.config import Config, set_config + +cfg = Config.from_mapping( + { + "onebot": {"ws_url": "ws://127.0.0.1:3001"}, + "models": { + "chat": {"api_url": "https://api.example/v1", "api_key": "sk-xxx", "model_name": "gpt-4o-mini"}, + "vision": {"api_url": "https://api.example/v1", "api_key": "sk-xxx", "model_name": "gpt-4o-mini"}, + "agent": {"api_url": "https://api.example/v1", "api_key": "sk-xxx", "model_name": "gpt-4o-mini"}, + }, + }, + strict=False, +) +set_config(cfg) # opt-in;CLI 启动链不调用 +``` + +完整 import 路径、公共 API 表与嵌入示例见 **[Python 库 API 参考](docs/python-api.md)**;程序化配置详见 [配置详解 — 库嵌入配置](docs/configuration.md#2-库嵌入配置)。 + +--- + ## ⚡ 快速开始 (源码模式) > 👶 **新手必看**:如果您是首次部署此类项目或不熟悉 Git/环境配置,**强烈建议直接前往 [《详细安装与部署指南》](docs/deployment.md)** 阅读手把手教程,避免遇到常见报错。 diff --git a/docs/configuration.md b/docs/configuration.md index ed87bece..db808031 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -2,6 +2,7 @@ 本文档是 Undefined 当前配置系统的完整说明,覆盖: - 配置加载顺序与解析规则 +- **库嵌入**(`Config.from_mapping` / `set_config`)程序化配置 - 严格模式必填项 - 每个配置节与字段的用途、默认值、约束、回退行为 - 热更新与重启生效边界 @@ -41,7 +42,99 @@ --- -## 2. 严格模式(`strict=True`)必填项 +## 2. 库嵌入配置 + +除 CLI / WebUI 从 CWD 读取 `config.toml` 外,Undefined 支持在 Python 代码中**程序化构建配置**,供测试、脚本或其它应用嵌入库组件时使用。 + +> 完整 API 说明见 [Python 库 API 参考](python-api.md)。 + +### 2.1 适用场景 + +- 单元测试 / 集成测试:无需准备真实 `config.toml` +- 下游应用:只复用 `AIClient`、`KnowledgeManager` 等模块,不启动 QQ Bot +- CI / 容器:通过环境变量 + 空 mapping 注入密钥,配置文件只保留非敏感项 + +### 2.2 加载优先级 + +``` +Python 显式 mapping / builder.override > config.toml > 环境变量 > 代码默认值 +``` + +| 入口 | 是否读 `config.toml` | 说明 | +|------|---------------------|------| +| `Config.load()` | 是 | CLI / WebUI 默认路径 | +| `Config.from_mapping(dict)` | 否 | 纯内存构建 | +| `Config.builder().with_mapping(...).build()` | 否 | 在 mapping 上链式覆盖 | +| `get_config()` | 视情况 | 未 `set_config()` 时等价于 `Config.load()` | + +`from_mapping` / `builder` 仍会读取进程环境变量中**已注册**的兜底项(TOML / mapping 未提供的字段)。注册表见 [`env_registry.py`](../src/Undefined/config/env_registry.py) 与本文 [§8 环境变量兜底](#8-环境变量兜底迁移建议)。 + +### 2.3 `Config.from_mapping` + +结构与 `config.toml` 一致,例如: + +```python +from Undefined.config import Config + +cfg = Config.from_mapping( + { + "onebot": {"ws_url": "ws://127.0.0.1:3001"}, + "models": { + "chat": { + "api_url": "https://api.example/v1", + "api_key": "sk-xxx", + "model_name": "gpt-4o-mini", + }, + "vision": { + "api_url": "https://api.example/v1", + "api_key": "sk-xxx", + "model_name": "gpt-4o-mini", + }, + "agent": { + "api_url": "https://api.example/v1", + "api_key": "sk-xxx", + "model_name": "gpt-4o-mini", + }, + }, + }, + strict=False, +) +``` + +- `strict=True`:与 CLI 相同,缺失 [§3 严格模式](#3-严格模式stricttrue必填项) 必填项时报错退出。 +- `strict=False`:适合测试与渐进式嵌入;生产 Bot 仍建议 `strict=True`。 + +### 2.4 `Config.builder` + +```python +cfg = ( + Config.builder() + .with_mapping(base_mapping) + .override(log_level="DEBUG") + .build(strict=False) +) +``` + +`override()` 目前覆盖 mapping 顶层键;嵌套结构请直接在 `with_mapping` 的 dict 中提供。 + +### 2.5 `set_config()`(opt-in) + +```python +from Undefined.config import Config, get_config, set_config + +cfg = Config.from_mapping({...}, strict=False) +set_config(cfg) +assert get_config(strict=False) is cfg +``` + +**硬约束**: + +- `set_config()` 仅供库嵌入 opt-in;**CLI / WebUI 启动链不得调用**。 +- 未调用 `set_config()` 时,`get_config()` 仍从 CWD 加载 `./config.toml`,与独立运行 Bot 行为一致。 + +--- + +## 3. 严格模式(`strict=True`)必填项 程序主流程使用严格模式加载配置。缺失以下字段会报错退出: - `core.bot_qq` @@ -56,7 +149,7 @@ --- -## 3. 最小可运行配置示例 +## 4. 最小可运行配置示例 ```toml [core] @@ -85,7 +178,7 @@ model_name = "gpt-4o-mini" --- -## 4. 全量字段说明 +## 5. 全量字段说明 ### 4.1 `[core]` 机器人核心行为 @@ -914,7 +1007,7 @@ Prompt caching 补充: --- -## 5. 热更新与重启边界 +## 6. 热更新与重启边界 ### 5.1 热更新监听对象 - `config.toml` @@ -967,7 +1060,7 @@ Prompt caching 补充: --- -## 6. 兼容旧字段与隐藏字段 +## 7. 兼容旧字段与隐藏字段 - `models..deepseek_new_cot_support`:旧 thinking 兼容开关。 - `[core].keyword_reply_enabled`:旧位置,建议迁移到 `[easter_egg]`。 @@ -976,26 +1069,250 @@ Prompt caching 补充: --- -## 7. 环境变量兜底(迁移建议) +## 8. 环境变量兜底(迁移建议) + +虽然推荐统一写入 `config.toml`,当前仍支持环境变量兜底。规则: + +1. **仅当 TOML / `from_mapping` 未提供对应项** 时读取环境变量。 +2. 检测到 env 兜底时可能输出 `[配置]` 告警,建议迁移到 TOML。 +3. 主注册表由 `src/Undefined/config/env_registry.py` 维护;变更注册表时请同步更新本节表格。 + + -虽然推荐统一写入 `config.toml`,但当前仍支持大量环境变量兜底,常用示例: -- `BOT_QQ` / `SUPERADMIN_QQ` -- `ONEBOT_WS_URL` / `ONEBOT_TOKEN` -- `CHAT_MODEL_API_URL` / `CHAT_MODEL_API_KEY` / `CHAT_MODEL_NAME` -- `CHAT_MODEL_API_MODE` / `CHAT_MODEL_REASONING_ENABLED` / `CHAT_MODEL_REASONING_EFFORT` / `CHAT_MODEL_RESPONSES_TOOL_CHOICE_COMPAT` / `CHAT_MODEL_RESPONSES_FORCE_STATELESS_REPLAY` -- `VISION_MODEL_*` / `AGENT_MODEL_*` / `SECURITY_MODEL_*` / `NAGA_MODEL_*` / `HISTORIAN_MODEL_*` -- 上述模型环境变量同样覆盖 `*_THINKING_ENABLED`、`*_THINKING_BUDGET_TOKENS`、`*_THINKING_TOOL_CALL_COMPAT`、`*_RESPONSES_TOOL_CHOICE_COMPAT`、`*_RESPONSES_FORCE_STATELESS_REPLAY` -- `EMBEDDING_MODEL_*` / `RERANK_MODEL_*` -- `SEARXNG_URL` -- `HTTP_PROXY` / `HTTPS_PROXY` +以下环境变量在 **TOML 对应项缺失** 时作为兜底读取。 +完整注册表见 `src/Undefined/config/env_registry.py`。 + +#### `access` + +| TOML 路径 | 环境变量 | +|-----------|----------| +| `access.allowed_group_ids` | `ALLOWED_GROUP_IDS` | +| `access.allowed_private_ids` | `ALLOWED_PRIVATE_IDS` | +| `access.blocked_group_ids` | `BLOCKED_GROUP_IDS` | +| `access.blocked_private_ids` | `BLOCKED_PRIVATE_IDS` | +| `access.mode` | `ACCESS_MODE` | + +#### `api_endpoints` + +| TOML 路径 | 环境变量 | +|-----------|----------| +| `api_endpoints.jkyai_base_url` | `JKYAI_BASE_URL` | +| `api_endpoints.xxapi_base_url` | `XXAPI_BASE_URL` | + +#### `core` + +| TOML 路径 | 环境变量 | +|-----------|----------| +| `core.admin_qq` | `ADMIN_QQ` | +| `core.bot_qq` | `BOT_QQ` | +| `core.forward_proxy_qq` | `FORWARD_PROXY_QQ` | +| `core.superadmin_qq` | `SUPERADMIN_QQ` | + +#### `features` + +| TOML 路径 | 环境变量 | +|-----------|----------| +| `features.pool_enabled` | `MODEL_POOL_ENABLED` | + +#### `history` + +| TOML 路径 | 环境变量 | +|-----------|----------| +| `history.max_records` | `HISTORY_MAX_RECORDS` | + +#### `image_gen` + +| TOML 路径 | 环境变量 | +|-----------|----------| +| `image_gen.provider` | `IMAGE_GEN_PROVIDER` | + +#### `logging` + +| TOML 路径 | 环境变量 | +|-----------|----------| +| `logging.backup_count` | `LOG_BACKUP_COUNT` | +| `logging.file_path` | `LOG_FILE_PATH` | +| `logging.level` | `LOG_LEVEL` | +| `logging.log_thinking` | `LOG_THINKING` | +| `logging.max_size_mb` | `LOG_MAX_SIZE_MB` | +| `logging.tty_enabled` | `LOG_TTY_ENABLED` | + +#### `mcp` + +| TOML 路径 | 环境变量 | +|-----------|----------| +| `mcp.config_path` | `MCP_CONFIG_PATH` | + +#### `models.agent` + +| TOML 路径 | 环境变量 | +|-----------|----------| +| `models.agent.api_key` | `AGENT_MODEL_API_KEY` | +| `models.agent.api_mode` | `AGENT_MODEL_API_MODE` | +| `models.agent.api_url` | `AGENT_MODEL_API_URL` | +| `models.agent.context_window_tokens` | `AGENT_MODEL_CONTEXT_WINDOW_TOKENS` | +| `models.agent.model_name` | `AGENT_MODEL_NAME` | +| `models.agent.reasoning_content_replay` | `AGENT_MODEL_REASONING_CONTENT_REPLAY` | +| `models.agent.responses_force_stateless_replay` | `AGENT_MODEL_RESPONSES_FORCE_STATELESS_REPLAY` | +| `models.agent.responses_tool_choice_compat` | `AGENT_MODEL_RESPONSES_TOOL_CHOICE_COMPAT` | +| `models.agent.system_prompt_as_user` | `AGENT_MODEL_SYSTEM_PROMPT_AS_USER` | + +#### `models.chat` + +| TOML 路径 | 环境变量 | +|-----------|----------| +| `models.chat.api_key` | `CHAT_MODEL_API_KEY` | +| `models.chat.api_mode` | `CHAT_MODEL_API_MODE` | +| `models.chat.api_url` | `CHAT_MODEL_API_URL` | +| `models.chat.context_window_tokens` | `CHAT_MODEL_CONTEXT_WINDOW_TOKENS` | +| `models.chat.max_tokens` | `CHAT_MODEL_MAX_TOKENS` | +| `models.chat.model_name` | `CHAT_MODEL_NAME` | +| `models.chat.reasoning_content_replay` | `CHAT_MODEL_REASONING_CONTENT_REPLAY` | +| `models.chat.responses_force_stateless_replay` | `CHAT_MODEL_RESPONSES_FORCE_STATELESS_REPLAY` | +| `models.chat.responses_tool_choice_compat` | `CHAT_MODEL_RESPONSES_TOOL_CHOICE_COMPAT` | +| `models.chat.system_prompt_as_user` | `CHAT_MODEL_SYSTEM_PROMPT_AS_USER` | + +#### `models.embedding` + +| TOML 路径 | 环境变量 | +|-----------|----------| +| `models.embedding.context_window_tokens` | `EMBEDDING_MODEL_CONTEXT_WINDOW_TOKENS` | + +#### `models.grok` + +| TOML 路径 | 环境变量 | +|-----------|----------| +| `models.grok.api_key` | `GROK_MODEL_API_KEY` | +| `models.grok.api_url` | `GROK_MODEL_API_URL` | +| `models.grok.context_window_tokens` | `GROK_MODEL_CONTEXT_WINDOW_TOKENS` | +| `models.grok.max_tokens` | `GROK_MODEL_MAX_TOKENS` | +| `models.grok.model_name` | `GROK_MODEL_NAME` | + +#### `models.naga` + +| TOML 路径 | 环境变量 | +|-----------|----------| +| `models.naga.api_key` | `NAGA_MODEL_API_KEY` | +| `models.naga.api_mode` | `NAGA_MODEL_API_MODE` | +| `models.naga.api_url` | `NAGA_MODEL_API_URL` | +| `models.naga.context_window_tokens` | `NAGA_MODEL_CONTEXT_WINDOW_TOKENS` | +| `models.naga.model_name` | `NAGA_MODEL_NAME` | +| `models.naga.reasoning_content_replay` | `NAGA_MODEL_REASONING_CONTENT_REPLAY` | +| `models.naga.responses_force_stateless_replay` | `NAGA_MODEL_RESPONSES_FORCE_STATELESS_REPLAY` | +| `models.naga.responses_tool_choice_compat` | `NAGA_MODEL_RESPONSES_TOOL_CHOICE_COMPAT` | +| `models.naga.system_prompt_as_user` | `NAGA_MODEL_SYSTEM_PROMPT_AS_USER` | + +#### `models.rerank` + +| TOML 路径 | 环境变量 | +|-----------|----------| +| `models.rerank.api_key` | `RERANK_MODEL_API_KEY` | +| `models.rerank.api_url` | `RERANK_MODEL_API_URL` | +| `models.rerank.context_window_tokens` | `RERANK_MODEL_CONTEXT_WINDOW_TOKENS` | +| `models.rerank.model_name` | `RERANK_MODEL_NAME` | + +#### `models.security` + +| TOML 路径 | 环境变量 | +|-----------|----------| +| `models.security.api_key` | `SECURITY_MODEL_API_KEY` | +| `models.security.api_mode` | `SECURITY_MODEL_API_MODE` | +| `models.security.api_url` | `SECURITY_MODEL_API_URL` | +| `models.security.context_window_tokens` | `SECURITY_MODEL_CONTEXT_WINDOW_TOKENS` | +| `models.security.model_name` | `SECURITY_MODEL_NAME` | +| `models.security.reasoning_content_replay` | `SECURITY_MODEL_REASONING_CONTENT_REPLAY` | +| `models.security.responses_force_stateless_replay` | `SECURITY_MODEL_RESPONSES_FORCE_STATELESS_REPLAY` | +| `models.security.responses_tool_choice_compat` | `SECURITY_MODEL_RESPONSES_TOOL_CHOICE_COMPAT` | +| `models.security.system_prompt_as_user` | `SECURITY_MODEL_SYSTEM_PROMPT_AS_USER` | + +#### `models.vision` + +| TOML 路径 | 环境变量 | +|-----------|----------| +| `models.vision.api_key` | `VISION_MODEL_API_KEY` | +| `models.vision.api_mode` | `VISION_MODEL_API_MODE` | +| `models.vision.api_url` | `VISION_MODEL_API_URL` | +| `models.vision.context_window_tokens` | `VISION_MODEL_CONTEXT_WINDOW_TOKENS` | +| `models.vision.model_name` | `VISION_MODEL_NAME` | +| `models.vision.reasoning_content_replay` | `VISION_MODEL_REASONING_CONTENT_REPLAY` | +| `models.vision.responses_force_stateless_replay` | `VISION_MODEL_RESPONSES_FORCE_STATELESS_REPLAY` | +| `models.vision.responses_tool_choice_compat` | `VISION_MODEL_RESPONSES_TOOL_CHOICE_COMPAT` | +| `models.vision.system_prompt_as_user` | `VISION_MODEL_SYSTEM_PROMPT_AS_USER` | + +#### `onebot` + +| TOML 路径 | 环境变量 | +|-----------|----------| +| `onebot.token` | `ONEBOT_TOKEN` | +| `onebot.ws_url` | `ONEBOT_WS_URL` | + +#### `proxy` + +| TOML 路径 | 环境变量 | +|-----------|----------| +| `proxy.use_proxy` | `USE_PROXY` | + +#### `search` + +| TOML 路径 | 环境变量 | +|-----------|----------| +| `search.searxng_url` | `SEARXNG_URL` | + +#### `skills` + +| TOML 路径 | 环境变量 | +|-----------|----------| +| `skills.hot_reload` | `SKILLS_HOT_RELOAD` | +| `skills.intro_hash_path` | `AGENT_INTRO_HASH_PATH` | +| `skills.prefetch_tools_hide` | `PREFETCH_TOOLS_HIDE` | + +#### `token_usage` + +| TOML 路径 | 环境变量 | +|-----------|----------| +| `token_usage.max_archives` | `TOKEN_USAGE_MAX_ARCHIVES` | +| `token_usage.max_size_mb` | `TOKEN_USAGE_MAX_SIZE_MB` | +| `token_usage.max_total_mb` | `TOKEN_USAGE_MAX_TOTAL_MB` | + +#### `tools` + +| TOML 路径 | 环境变量 | +|-----------|----------| +| `tools.description_max_len` | `TOOLS_DESCRIPTION_MAX_LEN` | +| `tools.dot_delimiter` | `TOOLS_DOT_DELIMITER` | +| `tools.sanitize_verbose` | `TOOLS_SANITIZE_VERBOSE` | + +#### `weather` + +| TOML 路径 | 环境变量 | +|-----------|----------| +| `weather.api_key` | `WEATHER_API_KEY` | + +#### `xxapi` + +| TOML 路径 | 环境变量 | +|-----------|----------| +| `xxapi.api_token` | `XXAPI_API_TOKEN` | + +#### 备用 / 兼容环境变量 + +以下变量不在主注册表中,但在解析时仍会被读取: + +| 环境变量 | 映射 TOML 路径 | +|----------|----------------| +| `EASTER_EGG_AGENT_CALL_MESSAGE_MODE` | `easter_egg.agent_call_message_enabled` | +| `EASTER_EGG_CALL_MESSAGE_MODE` | `easter_egg.agent_call_message_enabled` | +| `HTTPS_PROXY` | `proxy.https_proxy` | +| `HTTP_PROXY` | `proxy.http_proxy` | + + 建议: -1. 把长期配置迁移到 `config.toml`。 -2. 环境变量只保留临时覆写或 CI 场景。 ---- +1. 把长期配置迁移到 `config.toml`。 +2. 环境变量只保留临时覆写、CI 密钥或库嵌入场景的敏感项注入。 -## 8. 运维建议(生产环境) +## 9. 运维建议(生产环境) 1. 首次部署先改 `webui.password`,避免默认密码模式。 2. 显式配置 `access.mode`,不要依赖 legacy 行为。 diff --git a/docs/deployment.md b/docs/deployment.md index b936e4c3..a042c25d 100644 --- a/docs/deployment.md +++ b/docs/deployment.md @@ -2,6 +2,8 @@ 提供源码部署与 pip/uv tool 安装两种方式:**源码部署是推荐的首选方式**,功能完整且经过充分测试;pip/uv tool 安装适合快速体验,但部分功能支持尚不完善。 +> **作为 Python 库嵌入**:若你不需要启动 QQ Bot CLI,而是要在自己的应用或测试中复用 Undefined 组件(配置、`AIClient`、Skills、认知记忆等),请参阅 [Python 库 API 参考](python-api.md) 与 [配置详解 — 库嵌入配置](configuration.md#2-库嵌入配置)。CLI 入口(`Undefined` / `Undefined-webui`)行为不受库嵌入 API 影响。 + > Python 版本要求:`3.11`~`3.13`(包含)。 > > 若使用 `uv`,通常不需要你手动限制系统 Python 版本;`uv` 会根据项目约束自动选择/下载兼容解释器。 @@ -131,7 +133,7 @@ uv tool run --from Undefined-bot playwright install > **渲染依赖提醒**:同源码部署要求一致,你需要在宿主机上预先安装 Playwright 浏览器内核。请参考上文 [3. 安装渲染运行时](#3-安装渲染运行时)。未配置前,网页截图、Markdown 渲染和复杂 LaTeX 公式回退渲染可能会失败。 -安装完成后,在任意目录准备 `config.toml` 并启动: +安装完成后,在任意目录准备 `config.toml` 并启动(库嵌入场景也可用 `Config.from_mapping()` 代替配置文件,见 [python-api.md](python-api.md)): ```bash # 启动方式(二选一) diff --git a/docs/development.md b/docs/development.md index 987c1bdc..bac9f503 100644 --- a/docs/development.md +++ b/docs/development.md @@ -11,27 +11,43 @@ Undefined 欢迎开发者参与共建和进行二次开发! ```text src/Undefined/ ├── changelog.py # CHANGELOG.md 解析与版本查询公共层 -├── ai/ # AI 运行时核心组件 (client, prompt, tooling 工具组装, summary 短期摘要, multimodal 多模态) +├── ai/ # AI 运行时核心(子包 + 根级 shim:client.py / llm.py / prompts.py / multimodal.py) +│ ├── client/ # AIClient 组合:setup / queue / ask_loop +│ ├── llm/ # ModelRequester、streaming、thinking、sanitize +│ ├── prompts/ # PromptBuilder、system_context、cognitive 片段 +│ └── multimodal/# 多模态检测、解析与分析 +├── attachments/ # 附件注册、渲染、作用域隔离(attachments.py shim) ├── arxiv/ # arXiv 论文解析、元信息获取、PDF 下载与发送 ├── bilibili/ # B站视频流解析、分段下载与异步发送 -├── cognitive/ # 认知记忆系统底座 (向量存储, 史官合并/改写, 侧写生成, 任务队列) +├── cognitive/ # 认知记忆系统(service/ 门面 + historian/ 史官后台) +├── config/ # 配置系统(parsers/ 域解析 + load_sections/ 分段加载 + loader shim) +├── handlers/ # OneBot 消息分流(message_flow / poke / repeat / auto_extract;handlers.py shim) +├── onebot/ # OneBot WebSocket 客户端(onebot.py shim) ├── skills/ # 技能插件核心目录 (存放所有的工具与智能体) │ ├── tools/ # 基础原子的工具 (独立的功能单元,如读写文件、网络请求等) │ ├── toolsets/ # 聚合工具集 (分组后的工具组) │ │ └── cognitive/ # 认知记忆主动暴露工具 (search_events, get_profile 等) -│ ├── agents/ # 智能体 (独立自主的子 AI,负责处理诸如 Web 搜索、文件分析的具体长时任务) +│ ├── agents/ # 智能体 (含 runner/ 通用循环子包) │ ├── commands/ # 中心化斜杠指令系统 (实现如 /help, /stats, /admin 等平台功能) +│ ├── pipelines/ # 自动提取管线 (bilibili / arxiv / github 等) │ └── anthropic_skills/# Anthropic 协议集成的外部 Skills (兼容 SKILL.md 格式) -├── config/ # 配置系统 (loader.py TOML 解析, models.py 数据模型, hot_reload.py 热更新) ├── api/ # Management API + Runtime API -│ ├── routes/ # 路由子模块 (chat, tools, naga, system, memes, memory, cognitive, health) +│ ├── routes/ # 路由子模块 (chat, tools, naga/, system, memes, memory, cognitive, health) │ ├── app.py # aiohttp 服务主入口 (薄包装委派到 routes/) │ └── _openapi.py # OpenAPI 文档生成 -├── memes/ # 表情包库 (两阶段 AI 管线, SQLite + ChromaDB) -├── services/ # 核心运行服务 (Queue 任务队列, Command 命令分发, Security 安全防护拦截) -├── utils/ # 通用支持工具组 (io.py 异步原子读写, history.py, coerce.py 类型强转, fake_at.py 假@检测) -├── handlers.py # 最外层 OneBot 消息分流处理层 -└── onebot.py # OneBot WebSocket 客户端核心连接 +├── memes/ # 表情包库 (_service 门面 + ingest/ + search/ + store + vector_store) +├── services/ # 核心运行服务 +│ ├── coordinator/ # AICoordinator mixins(ai_coordinator.py shim) +│ ├── commands/ # CommandDispatcher mixins(stats / bugfix) +│ ├── message_batcher/ # 同 sender 短时合并(message_batcher.py shim) +│ ├── command.py # 命令分发门面 + shim 组合 +│ ├── queue_manager.py # 车站-列车队列 +│ └── security.py # 注入检测与速率限制 +├── utils/ # 通用支持工具组 (__init__.py 聚合 io/paths/resources;io.py 异步原子读写, history.py, coerce.py 类型强转) +├── handlers.py # compatibility shim → handlers/ +├── onebot.py # compatibility shim → onebot/ +├── attachments.py # compatibility shim → attachments/ +└── ai_coordinator.py # compatibility shim → services/coordinator/ ``` ## 开发指南 @@ -96,3 +112,107 @@ bash scripts/install_git_hooks.sh - 当提交包含 JS / Tauri / WebUI 前端相关改动时,还会自动执行 `Biome + TypeScript + cargo fmt/check` > **注意**:项目严格遵守类型注释规范,`mypy .` 通过是代码入库的前提条件;跨平台控制台相关改动则以 `npm run check` 通过为准。 + +## 注释规范 + +库化重构期间,各 Track 在拆分与注释 Wave 中须遵守以下 docstring 与行内注释约定。目标:提升可读性、支撑 `fuck-u-code` 注释比例达标(<30%),且**不改变运行时行为**。 + +### 模块 docstring + +每个 `.py` 文件(shim 除外)顶部须有**一行摘要** + 可选段落说明职责边界: + +```python +"""OneBot WebSocket 客户端连接管理。 + +负责与 NapCat/Lagrange 建立 WS 连接、心跳与事件分发;不处理业务消息逻辑。 +""" +``` + +- 使用中文或英文均可,与同目录现有风格保持一致。 +- Shim 文件仅保留一行:`# .py — compatibility shim; do not add logic here.` + +### 类 docstring + +公开类(`class X` 无 leading `_`)须有 docstring,说明**职责**与**主要协作对象**: + +```python +class CognitiveService: + """认知记忆运行时入口。 + + 协调向量检索、侧写读写与史官后台任务队列;由 main 进程持有单例。 + """ +``` + +- 内部辅助类(`_Foo`、`SkillStats` 等 dataclass)鼓励简短一行说明。 +- 禁止复制类型签名(mypy 已覆盖);重点写「为什么存在」。 + +### 公开方法 / 函数 docstring + +模块级公开函数与类公开方法(无 leading `_`)须有 docstring,推荐 Google 风格精简版: + +```python +def get_config(strict: bool = True) -> Config: + """获取全局配置单例。 + + Args: + strict: 为 True 时缺少必填项则抛错;False 时使用默认值填充。 + + Returns: + 已加载的 Config 实例。 + """ +``` + +- `@property` 公开 getter 视同方法。 +- 异步公开方法同样适用;注明可能抛出的业务异常(若有)。 +- 复杂算法或非 obvious 分支:**行内注释**说明意图,而非复述代码。 + +### 行内注释 + +- 仅用于解释**非 obvious 的业务规则**、兼容分支、性能/并发考量。 +- 禁止「递增 i」「返回结果」类冗余注释。 +- 魔法数字须命名常量或注释来源(配置项名 / 协议字段)。 + +### Skills handler 统一模板 + +`skills/tools/**/handler.py`、`skills/toolsets/**/handler.py`、`skills/agents/**/handler.py`、`skills/commands/**/handler.py`、`skills/pipelines/**/handler.py` 在注释 Wave 中统一采用: + +```python +"""<工具/Agent/命令/管线的人类可读名称>。 + +<一句话说明能力边界与主要输入输出;可列 1~3 条 bullet 行为要点。> + +config.json 关键字段: — <含义>(若非 obvious)。 +""" + +from __future__ import annotations + +# ... 实现 ... + + +async def execute(args: dict[str, Any], context: dict[str, Any]) -> Any: + """执行入口(由 Registry 调用)。 + + Args: + args: LLM tool call 解析后的参数字典。 + context: 运行时注入上下文(sender、session、registry 等)。 + + Returns: + 工具结果字符串或结构化 payload;异常由 Registry 捕获并记录。 + """ +``` + +- **禁止**在 handler 注释 Wave 中修改 `config.json`、目录结构或 handler 签名。 +- handler 内私有函数 `_foo` 可选一行 docstring;复杂解析逻辑建议补充。 + +### 注释 Track 自检 + +注释-only PR 合并前: + +```bash +uv run ruff format . +uv run ruff check . +uv run mypy src/Undefined// +uv run pytest tests/ # 全量由 Phase 3 Integrator 执行 +``` + +公共 API 说明见 [`docs/python-api.md`](python-api.md)。 diff --git a/docs/python-api.md b/docs/python-api.md new file mode 100644 index 00000000..1d84c9ad --- /dev/null +++ b/docs/python-api.md @@ -0,0 +1,284 @@ +# Python 库 API 参考 + +Undefined 可作为 Python 库嵌入到其他应用、脚本或测试环境中,复用配置系统、AI 客户端、Skills 注册表、认知记忆、知识库等组件,而无需启动完整的 QQ Bot CLI。 + +> CLI 入口(`Undefined` / `Undefined-webui`)行为不变;库嵌入路径与 CLI 启动链隔离。详见 [配置详解 — 库嵌入配置](configuration.md#2-库嵌入配置)。 + +--- + +## 安装 + +```bash +# 源码开发 +uv sync + +# 或 PyPI 包 +pip install Undefined-bot +``` + +Python 版本要求:`3.11` ~ `3.13`。 + +--- + +## 推荐 import 路径 + +### 根包(`stable`,Phase 3 lazy re-export) + +以下符号承诺通过 `from Undefined import …` 长期稳定(完整清单见下文 [公共 API 符号表](#公共-api-符号表)): + +```python +from Undefined import ( + __version__, + Config, + get_config, + set_config, + AIClient, + ToolRegistry, + AgentRegistry, + PipelineRegistry, + BaseRegistry, + AnthropicSkillRegistry, + CognitiveService, + KnowledgeManager, + MemeService, + AttachmentRegistry, + RuntimeAPIServer, + RuntimeAPIContext, +) +``` + +> **注意**:Phase 3 之前根包 lazy re-export 可能尚未全部启用;若 `from Undefined import X` 失败,请使用下方子包路径,二者语义等价。 + +### 子包(`stable` / `subpackage`) + +| 稳定性 | 模块 | 常用符号 | +|--------|------|----------| +| stable | `Undefined.config` | `Config`, `get_config`, `set_config`, `ConfigBuilder`, `ChatModelConfig`, `VisionModelConfig`, … | +| stable | `Undefined.ai` | `AIClient` | +| stable | `Undefined.skills` | `ToolRegistry`, `AgentRegistry`, `PipelineRegistry` | +| stable | `Undefined.cognitive` | `CognitiveService`, `CognitiveVectorStore`, `ProfileStorage`, … | +| stable | `Undefined.knowledge` | `KnowledgeManager`, `Embedder`, `Reranker`, `RetrievalRuntime` | +| stable | `Undefined.memes` | `MemeService`, `MemeStore`, `MemeWorker`, … | +| stable | `Undefined.attachments` | `AttachmentRegistry` | +| stable | `Undefined.api` | `RuntimeAPIServer`, `RuntimeAPIContext` | +| subpackage | `Undefined.skills.registry` | `BaseRegistry`, `SkillItem`, `SkillStats` | +| subpackage | `Undefined.skills.anthropic_skills` | `AnthropicSkillRegistry` | +| subpackage | `Undefined.mcp` | `MCPToolRegistry`, `MCPToolSetRegistry` | + +### 向后兼容 shim 路径 + +拆分后旧路径仍可用(测试与下游代码可继续引用): + +```python +from Undefined.config.loader import Config # → Undefined.config.Config +from Undefined.ai.client import AIClient +from Undefined.attachments import AttachmentRegistry +from Undefined.skills.tools import ToolRegistry +from Undefined.cognitive.service import CognitiveService +from Undefined.knowledge.manager import KnowledgeManager +from Undefined.memes.service import MemeService +from Undefined.api.app import RuntimeAPIServer +``` + +拆分后的各模块旁保留 compatibility shim 文件,旧 import 路径仍可用(见各 shim 文件顶部的 re-export)。 + +### 内部模块(不承诺稳定) + +以下模块**不会**进入根包 re-export,也不保证跨版本兼容: + +- `Undefined.main`, `Undefined.webui`, `Undefined.handlers`, `Undefined.onebot` +- `Undefined.config.coercers`, `Undefined.config.model_parsers` +- `Undefined.utils.io`, `Undefined.utils.paths` + +--- + +## 配置 API + +库嵌入场景的核心入口是 `Config.from_mapping()` 与 opt-in 的 `set_config()`。 + +### 加载优先级 + +``` +Python 显式 mapping / override > config.toml > 环境变量 > 代码默认值 +``` + +- `Config.from_mapping()` / `Config.builder()`:**不读取** `config.toml`,适合测试与无文件部署。 +- `Config.load()`:从指定或 CWD 下的 `config.toml` 加载(CLI 路径)。 +- 环境变量仅在 TOML / mapping **未提供**对应项时兜底;详见 [配置详解 — 环境变量兜底](configuration.md#8-环境变量兜底迁移建议)。 + +### `Config.from_mapping` + +从内存 dict 构建配置,结构与 `config.toml` 一致: + +```python +from Undefined.config import Config + +cfg = Config.from_mapping( + { + "core": {"bot_qq": 123456, "superadmin_qq": 654321}, + "onebot": {"ws_url": "ws://127.0.0.1:3001"}, + "models": { + "chat": { + "api_url": "https://api.example/v1", + "api_key": "sk-xxx", + "model_name": "gpt-4o-mini", + }, + "vision": { + "api_url": "https://api.example/v1", + "api_key": "sk-xxx", + "model_name": "gpt-4o-mini", + }, + "agent": { + "api_url": "https://api.example/v1", + "api_key": "sk-xxx", + "model_name": "gpt-4o-mini", + }, + }, + }, + strict=False, # 库嵌入 / 测试可放宽;生产 Bot 建议 strict=True +) + +print(cfg.chat_model.model_name) # gpt-4o-mini +``` + +`strict=True` 时缺失必填项(如 `onebot.ws_url`、各模型 `api_url` 等)会抛出异常;行为与 CLI 严格模式一致。 + +### `Config.builder` + +链式构建器,适合在 base mapping 上覆盖少量字段: + +```python +cfg = ( + Config.builder() + .with_mapping({"onebot": {"ws_url": "ws://127.0.0.1:3001"}, "models": {...}}) + .override(log_level="DEBUG") + .build(strict=False) +) +``` + +### `set_config`(opt-in 单例注入) + +将已构建的 `Config` 注入全局单例,供 `get_config()` 读取: + +```python +from Undefined.config import Config, get_config, set_config + +cfg = Config.from_mapping({...}, strict=False) +set_config(cfg) + +assert get_config(strict=False) is cfg +``` + +**约束**: + +- `set_config()` 仅供库嵌入 opt-in 使用;**CLI / WebUI 启动链不得调用**。 +- 未调用 `set_config()` 时,`get_config()` 仍走 CWD 下 `./config.toml`(与 CLI 行为一致)。 + +### 纯环境变量构建 + +mapping 为空时,已注册的环境变量仍可兜底填充配置: + +```python +import os + +os.environ["ONEBOT_WS_URL"] = "ws://127.0.0.1:3001" +os.environ["CHAT_MODEL_API_URL"] = "https://api.example/v1" +# ... 其他必填 env + +cfg = Config.from_mapping({}, strict=False) +``` + +完整 env 注册表见 [配置详解 §8](configuration.md#8-环境变量兜底迁移建议)。 + +--- + +## 典型嵌入示例 + +### 单元测试 + +```python +from Undefined.config import Config, set_config + +@pytest.fixture +def app_config(): + cfg = Config.from_mapping(MINIMAL_MAPPING, strict=False) + set_config(cfg) + yield cfg +``` + +### 脚本中复用 AIClient + +```python +from Undefined.config import Config +from Undefined.ai.client import AIClient + +cfg = Config.from_mapping({...}, strict=False) +client = AIClient(cfg) +# 使用 client 发起 LLM 请求 … +``` + +### 挂载 Runtime API + +```python +from Undefined.config import Config, set_config +from Undefined.api import RuntimeAPIServer, RuntimeAPIContext + +cfg = Config.from_mapping({...}, strict=True) +set_config(cfg) +server = RuntimeAPIServer(RuntimeAPIContext(...)) +``` + +--- + +## 公共 API 符号表 + +根包与子包 `__all__` 中列出的符号为稳定面;semver minor 内不 breaking。 + +### 根包 re-export(`stable`) + +| 符号 | 定义模块 | 说明 | +|------|----------|------| +| `__version__` | `Undefined` | 包版本 | +| `Config` | `Undefined.config` | 应用配置 dataclass | +| `get_config` | `Undefined.config` | 获取全局配置单例 | +| `set_config` | `Undefined.config` | opt-in 注入 Config(CLI 不调用) | +| `Config.builder` | `Undefined.config` | 链式配置构建器 | +| `Config.from_mapping` | `Undefined.config` | 从 dict 构建配置 | +| `AIClient` | `Undefined.ai` | LLM 请求客户端 | +| `ToolRegistry` | `Undefined.skills` | 工具注册表 | +| `AgentRegistry` | `Undefined.skills` | Agent 注册表 | +| `PipelineRegistry` | `Undefined.skills` | 自动处理管线注册表 | +| `BaseRegistry` | `Undefined.skills.registry` | 注册表基类 | +| `AnthropicSkillRegistry` | `Undefined.skills.anthropic_skills` | Anthropic Skills 注册表 | +| `CognitiveService` | `Undefined.cognitive` | 认知记忆服务 | +| `KnowledgeManager` | `Undefined.knowledge` | 本地知识库管理 | +| `MemeService` | `Undefined.memes` | 表情包库服务 | +| `AttachmentRegistry` | `Undefined.attachments` | 附件 UID 登记 | +| `RuntimeAPIServer` | `Undefined.api` | 主进程 Runtime API 服务 | +| `RuntimeAPIContext` | `Undefined.api` | Runtime API 运行时上下文 | + +### 子包公开面 + +| 包 | 稳定性 | 符号 | +|----|--------|------| +| `Undefined.config` | stable | `Config`, `get_config`, `get_config_manager`, `set_config`, `WebUISettings`, `load_webui_settings`, `ChatModelConfig`, `VisionModelConfig`, `SecurityModelConfig`, `APIConfig`, `AgentModelConfig`, `EmbeddingModelConfig`, `GrokModelConfig`, `RerankModelConfig`, `ModelPool`, `ModelPoolEntry`, `MemeConfig`, `MessageBatcherConfig`, `RenderCacheConfig` | +| `Undefined.ai` | stable | `AIClient` | +| `Undefined.skills` | stable | `ToolRegistry`, `AgentRegistry`, `PipelineRegistry` | +| `Undefined.skills.registry` | subpackage | `BaseRegistry`, `SkillItem`, `SkillStats`, `RegistryExecutionTimeoutError` | +| `Undefined.skills.anthropic_skills` | subpackage | `AnthropicSkillRegistry` | +| `Undefined.skills.pipelines` | subpackage | `PipelineRegistry`, `PipelineDetection` | +| `Undefined.cognitive` | stable | `CognitiveService`, `CognitiveVectorStore`, `ProfileStorage`, `HistorianWorker`, `JobQueue` | +| `Undefined.knowledge` | stable | `KnowledgeManager`, `Embedder`, `Reranker`, `RetrievalRuntime` | +| `Undefined.memes` | stable | `MemeService`, `MemeStore`, `MemeWorker`, `MemeVectorStore`, `MemeRecord`, `MemeSearchItem`, `MemeSourceRecord` | +| `Undefined.attachments` | stable | `AttachmentRegistry` | +| `Undefined.api` | stable | `RuntimeAPIServer`, `RuntimeAPIContext` | +| `Undefined.mcp` | subpackage | `MCPToolRegistry`, `MCPToolSetRegistry` | + +--- + +## 相关文档 + +- [配置详解](configuration.md) — TOML 字段、热更新、库嵌入(§2)、环境变量全表(§8) +- [安装与部署](deployment.md) — CLI 部署与库嵌入交叉引用 +- [Runtime API / OpenAPI](openapi.md) — HTTP 集成 +- [开发者与拓展中心](development.md) — 源码结构与自检命令 diff --git a/src/Undefined/ai/client.py b/src/Undefined/ai/client.py index 599f3abb..7f4b5084 100644 --- a/src/Undefined/ai/client.py +++ b/src/Undefined/ai/client.py @@ -216,9 +216,7 @@ def __init__( else: self.attachment_registry = AttachmentRegistry(http_client=self._http_client) - # 私聊发送回调 self._send_private_message_callback: Optional[SendPrivateMessageCallback] = None - # 发送图片回调 self._send_image_callback: Optional[ Callable[[int, str, str], Awaitable[None]] ] = None @@ -227,7 +225,6 @@ def __init__( self.current_group_id: Optional[int] = None self.current_user_id: Optional[int] = None - # 初始化工具注册表 base_dir = Path(__file__).resolve().parents[1] self.tool_registry = ToolRegistry(base_dir / "skills" / "tools") self.agent_registry = AgentRegistry(base_dir / "skills" / "agents") @@ -245,7 +242,6 @@ def __init__( anthropic_skill_registry=self.anthropic_skill_registry, ) - # 初始化模型选择器 self.model_selector = ModelSelector() # 绑定上下文资源扫描路径(基于注册表 watch_paths) @@ -275,7 +271,6 @@ def __init__( # 后台任务引用集合(防止被 GC) self._background_tasks: set[asyncio.Task[Any]] = set() - # 保存配置供后续使用 runtime_config = self._get_runtime_config() self._intro_config = AgentIntroGenConfig( enabled=runtime_config.agent_intro_autogen_enabled, @@ -351,7 +346,6 @@ async def init_mcp_async() -> None: self._mcp_init_task = asyncio.create_task(init_mcp_async()) - # 异步加载模型偏好 async def load_preferences_async() -> None: try: await self.model_selector.load_preferences() @@ -365,7 +359,6 @@ async def load_preferences_async() -> None: async def close(self) -> None: logger.info("[清理] 正在关闭 AIClient...") - # 1) 停止后台任务(避免关闭 HTTP client 后仍有请求在跑) intro_gen = getattr(self, "_agent_intro_generator", None) if intro_gen is not None: await intro_gen.stop() @@ -390,7 +383,6 @@ async def close(self) -> None: if hasattr(self, "_prompt_builder") and self._prompt_builder is not None: self._prompt_builder.set_cognitive_service(None) - # 2) 等待 MCP 初始化完成,再关闭 MCP toolsets if hasattr(self, "_mcp_init_task") and not self._mcp_init_task.done(): await self._mcp_init_task @@ -409,7 +401,6 @@ async def close(self) -> None: except Exception as exc: logger.warning("[清理] 刷新附件注册表失败: %s", exc) - # 3) 最后关闭共享 HTTP client if hasattr(self, "_http_client"): logger.info("[清理] 正在关闭 AIClient HTTP 客户端...") await self._http_client.aclose() @@ -932,7 +923,6 @@ async def _maybe_prefetch_tools( results: list[tuple[str, Any]] = [] for name in to_run: try: - # 为特定工具准备参数 tool_args: dict[str, Any] = {} if name == "get_current_time": tool_args = {"format": "text", "include_lunar": True} diff --git a/src/Undefined/ai/client/__init__.py b/src/Undefined/ai/client/__init__.py new file mode 100644 index 00000000..40f823ad --- /dev/null +++ b/src/Undefined/ai/client/__init__.py @@ -0,0 +1,37 @@ +"""AI 客户端子包。 + +对外稳定入口:``AIClient``;旧路径 ``Undefined.ai.client`` 通过 shim 保持兼容。 +""" + +from Undefined.ai.client.ask_loop import ClientAskLoopMixin +from Undefined.ai.client.setup import ( + MISSING_TOOL_CALL_RETRY_HINT, + SendMessageCallback, + SendPrivateMessageCallback, + _INVALID_TOOL_CALL_CONTENT, + _build_invalid_tool_call_response, + _resolve_summary_model_config, +) + +# 会话消息拉取 helper,供 ask 与 slash 命令共用 +from Undefined.services.message_summary_fetch import fetch_session_messages + + +# MRO:ClientAskLoopMixin → ClientQueueMixin → ClientSetupMixin,能力按 mixin 分层叠加 +class AIClient(ClientAskLoopMixin): + """AI 模型客户端。 + + 协调 Prompt 构建、队列化 LLM 请求、工具调用与多模态/摘要能力。 + """ + + +__all__ = [ + "AIClient", + "MISSING_TOOL_CALL_RETRY_HINT", + "SendMessageCallback", + "SendPrivateMessageCallback", + "_INVALID_TOOL_CALL_CONTENT", + "_build_invalid_tool_call_response", + "_resolve_summary_model_config", + "fetch_session_messages", +] diff --git a/src/Undefined/ai/client/ask_loop.py b/src/Undefined/ai/client/ask_loop.py new file mode 100644 index 00000000..75a074d9 --- /dev/null +++ b/src/Undefined/ai/client/ask_loop.py @@ -0,0 +1,623 @@ +"""AI 客户端 ask 主循环与工具调用迭代。""" + +from __future__ import annotations + +import asyncio +import logging +from typing import Any, Awaitable, Callable + +from Undefined.ai.client.queue import ClientQueueMixin +from Undefined.ai.client.setup import ( + MISSING_TOOL_CALL_RETRY_HINT, + SendMessageCallback, + _build_invalid_tool_call_response, +) +from Undefined.ai.transports.openai_transport import RESPONSES_OUTPUT_ITEMS_KEY +from Undefined.ai.tooling import END_CO_CALL_REJECT_CONTENT +from Undefined.context import RequestContext +from Undefined.services.message_summary_fetch import fetch_session_messages +from Undefined.utils.logging import log_debug_json, redact_string +from Undefined.utils.tool_calls import parse_tool_arguments + +logger = logging.getLogger(__name__) + + +class ClientAskLoopMixin(ClientQueueMixin): + """``ask()`` 多轮工具调用主循环。""" + + async def ask( + self, + question: str, + context: str = "", + send_message_callback: SendMessageCallback | None = None, + get_recent_messages_callback: Callable[ + [str, str, int, int], Awaitable[list[dict[str, Any]]] + ] + | None = None, + get_image_url_callback: Callable[[str], Awaitable[str | None]] | None = None, + get_forward_msg_callback: Callable[[str], Awaitable[list[dict[str, Any]]]] + | None = None, + send_like_callback: Callable[[int, int], Awaitable[None]] | None = None, + sender: Any = None, + history_manager: Any = None, + onebot_client: Any = None, + scheduler: Any = None, + extra_context: dict[str, Any] | None = None, + ) -> str: + """发送问题给 AI 并获取回复 (支持工具调用和迭代) + + 参数: + question: 用户输入的问题 + context: 额外的上下文背景 + send_message_callback: 发送消息的回调,支持可选的 reply_to + get_recent_messages_callback: 获取上下文历史消息的回调 + get_image_url_callback: 获取图片 URL 的回调 + get_forward_msg_callback: 获取合并转发内容的回调 + send_like_callback: 点赞回调 + sender: 消息发送助手实例 + history_manager: 历史记录管理器实例 + onebot_client: OneBot 客户端实例 + scheduler: 任务调度器实例 + extra_context: 额外的上下文负载 + + 返回: + AI 生成的最终文本回复 + """ + # ===== 阶段一:从 RequestContext / extra_context 组装 pre_context ===== + ctx = RequestContext.current() + pre_context: dict[str, Any] = {} + if ctx: + if ctx.group_id is not None: + pre_context["group_id"] = ctx.group_id + if ctx.user_id is not None: + pre_context["user_id"] = ctx.user_id + if ctx.sender_id is not None: + pre_context["sender_id"] = ctx.sender_id + pre_context["request_type"] = ctx.request_type + pre_context["request_id"] = ctx.request_id + if extra_context: + pre_context.update(extra_context) + + # ===== 阶段二:构建 LLM messages 与 OpenAI tools schema ===== + messages = await self._prompt_builder.build_messages( + question, + get_recent_messages_callback=get_recent_messages_callback, + extra_context=extra_context, + ) + + tools = self.tool_manager.get_openai_tools() + tools = self._filter_tools_for_runtime_config(tools) + if logger.isEnabledFor(logging.DEBUG): + logger.debug( + "[AI消息] 构建完成: messages=%s tools=%s question_len=%s", + len(messages), + len(tools), + len(question), + ) + log_debug_json(logger, "[AI消息内容]", messages) + + # ===== 阶段三:组装 tool_context,注入回调、服务与 RequestContext 字段 ===== + tool_context = ctx.get_resources() if ctx else {} + tool_context["conversation_ended"] = False + tool_context.setdefault("agent_histories", {}) + + # 显式注入 RequestContext 的核心字段(与 tooling.py:execute_tool_call 保持一致) + if ctx: + if ctx.group_id is not None: + tool_context.setdefault("group_id", ctx.group_id) + if ctx.user_id is not None: + tool_context.setdefault("user_id", ctx.user_id) + if ctx.sender_id is not None: + tool_context.setdefault("sender_id", ctx.sender_id) + tool_context.setdefault("request_type", ctx.request_type) + tool_context.setdefault("request_id", ctx.request_id) + + if extra_context: + tool_context.update(extra_context) + + # 注入常用资源(用于工具执行) + tool_context.setdefault("ai_client", self) + tool_context.setdefault("runtime_config", self._get_runtime_config()) + tool_context.setdefault("search_wrapper", self._search_wrapper) + tool_context.setdefault( + "crawl4ai_available", self._crawl4ai_capabilities.available + ) + tool_context.setdefault( + "crawl4ai_proxy_config_available", + self._crawl4ai_capabilities.proxy_config_available, + ) + tool_context.setdefault("end_summary_storage", self._end_summary_storage) + tool_context.setdefault("end_summaries", self._prompt_builder.end_summaries) + tool_context.setdefault( + "send_private_message_callback", self._send_private_message_callback + ) + tool_context.setdefault("send_message_callback", send_message_callback) + tool_context.setdefault( + "get_recent_messages_callback", get_recent_messages_callback + ) + + async def fetch_session_messages_callback( + *, + group_id: int, + user_id: int, + count: int | None = None, + time_range: str | None = None, + ) -> str: + return await fetch_session_messages( + history_manager, + group_id=group_id, + user_id=user_id, + count=count, + time_range=time_range, + runtime_config=self._get_runtime_config(), + ) + + tool_context.setdefault( + "fetch_session_messages_callback", fetch_session_messages_callback + ) + tool_context.setdefault("get_image_url_callback", get_image_url_callback) + tool_context.setdefault("get_forward_msg_callback", get_forward_msg_callback) + tool_context.setdefault("send_like_callback", send_like_callback) + tool_context.setdefault("sender", sender) + tool_context.setdefault("history_manager", history_manager) + tool_context.setdefault("onebot_client", onebot_client) + tool_context.setdefault("scheduler", scheduler) + tool_context.setdefault("send_image_callback", self._send_image_callback) + tool_context.setdefault( + "attachment_registry", + getattr(self, "attachment_registry", None), + ) + tool_context.setdefault("memory_storage", self.memory_storage) + tool_context.setdefault("knowledge_manager", self._knowledge_manager) + tool_context.setdefault("cognitive_service", self._cognitive_service) + tool_context.setdefault("meme_service", self._meme_service) + tool_context.setdefault("current_question", question) + message_ids = tool_context.get("message_ids") + if not isinstance(message_ids, list): + message_ids = [] + tool_context["message_ids"] = message_ids + trigger_message_id = tool_context.get("trigger_message_id") + if trigger_message_id is not None: + trigger_message_id_text = str(trigger_message_id).strip() + if trigger_message_id_text and trigger_message_id_text not in message_ids: + message_ids.append(trigger_message_id_text) + + # ===== 阶段四:模型选择、思维链/重试参数与主循环状态初始化 ===== + await self.model_selector.wait_ready() + selected_model_name = pre_context.get("selected_model_name") + if selected_model_name: + effective_chat_config = self._find_chat_config_by_name(selected_model_name) + else: + effective_chat_config = self.chat_config + + max_iterations = 1000 + iteration = 0 + conversation_ended = False + cot_compat = getattr(effective_chat_config, "thinking_tool_call_compat", False) + capture_reasoning = cot_compat or bool( + getattr(effective_chat_config, "reasoning_content_replay", False) + ) + cot_compat_logged = False + cot_missing_logged = False + transport_state: dict[str, Any] | None = None + queue_lane = self._resolve_queue_lane(tool_context.get("queue_lane")) + pre_tool_failure_count = 0 + missing_tool_call_count = 0 + last_missing_tool_call_content = "" + runtime_config = self._get_runtime_config() + max_pre_tool_retries = max( + 0, + int(getattr(runtime_config, "ai_request_max_retries", 0) or 0), + ) + max_missing_tool_call_retries = max( + 0, + int(getattr(runtime_config, "missing_tool_call_retries", 3) or 0), + ) + + # ===== 阶段五:多轮 LLM + 工具调用主循环(每轮一次请求) ===== + while iteration < max_iterations: + iteration += 1 + logger.info(f"[AI决策] 开始第 {iteration} 轮迭代...") + message_checkpoint_len = len(messages) + transport_state_checkpoint = transport_state + + try: + result = await self.submit_queued_llm_call( + model_config=effective_chat_config, + messages=messages, + max_tokens=8192, + call_type="chat", + tools=tools, + tool_choice="auto", + transport_state=transport_state, + queue_lane=queue_lane, + ) + except Exception as exc: + logger.exception( + "[queued_llm_error] call_type=chat model=%s lane=%s iteration=%s error=%s", + effective_chat_config.model_name, + queue_lane, + iteration, + exc, + ) + raise + + try: + tool_execution_started = False + tool_name_map = ( + result.get("_tool_name_map") if isinstance(result, dict) else None + ) + api_to_internal: dict[str, str] = {} + if isinstance(tool_name_map, dict): + raw_api_to_internal = tool_name_map.get("api_to_internal") + if isinstance(raw_api_to_internal, dict): + # LLM 出站时工具名可能被编码,执行前映射回内部名 + api_to_internal = { + str(k): str(v) for k, v in raw_api_to_internal.items() + } + + next_transport_state = ( + result.get("_transport_state") if isinstance(result, dict) else None + ) + transport_state = ( + next_transport_state + if isinstance(next_transport_state, dict) + else None + ) + + choice = result.get("choices", [{}])[0] + message = choice.get("message", {}) + content: str = message.get("content") or "" + reasoning_content = message.get("reasoning_content") + tool_calls = message.get("tool_calls", []) + if logger.isEnabledFor(logging.DEBUG): + logger.debug( + "[AI响应] content_len=%s tool_calls=%s", + len(content), + len(tool_calls), + ) + # 无 tool_calls 与有 tool_calls 走不同分支 + if tool_calls: + log_debug_json(logger, "[AI工具调用]", tool_calls) + + log_thinking = self._get_runtime_config().log_thinking + if ( + capture_reasoning + and tools + and log_thinking + and not cot_compat_logged + ): + cot_compat_logged = True + logger.info( + "[思维链兼容] 多轮工具调用 reasoning_content 本地回填已启用" + ) + if ( + capture_reasoning + and log_thinking + and tools + and getattr(effective_chat_config, "thinking_enabled", False) + and not reasoning_content + and tool_calls + and not cot_missing_logged + ): + cot_missing_logged = True + message_keys = ( + ", ".join(sorted(message.keys())) + if isinstance(message, dict) + else type(message).__name__ + ) + logger.info( + "[思维链兼容] 未在响应中发现 reasoning_content(可能是模型/服务商不返回思维链);message_keys=%s", + message_keys, + ) + + # 部分模型会同时返回文本与 tool_calls;对外动作以工具为准,丢弃 content + if content.strip() and tool_calls: + logger.debug( + "检测到 content 与工具调用同时存在,忽略 content,仅执行工具调用" + ) + content = "" + + # 无 tool_calls 与有 tool_calls 走不同分支 + if not tool_calls: + if conversation_ended: + logger.info( + "[AI回复] 会话结束,返回最终内容: length=%s", + len(content), + ) + return content + + # 未调用工具:累计重试次数,超限则 fallback 发送或直接返回文本 + if content.strip(): + last_missing_tool_call_content = content.strip() + missing_tool_call_count += 1 + if missing_tool_call_count > max_missing_tool_call_retries: + logger.warning( + "[AI回复] 模型连续未调用工具,停止重试: iteration=%s retries=%s/%s content_len=%s", + iteration, + missing_tool_call_count - 1, + max_missing_tool_call_retries, + len(content), + ) + fallback_content = last_missing_tool_call_content + if fallback_content and send_message_callback is not None: + try: + await send_message_callback(fallback_content) + tool_context["message_sent_this_turn"] = True + current_ctx = RequestContext.current() + if current_ctx is not None: + current_ctx.set_resource( + "message_sent_this_turn", True + ) + return "" + except Exception: + logger.exception("[AI回复] fallback 发送失败") + return fallback_content + + logger.warning( + "[AI回复] 模型返回文本但未调用工具(iteration=%s retry=%s/%s content_len=%s),要求重试", + iteration, + missing_tool_call_count, + max_missing_tool_call_retries, + len(content), + ) + assistant_retry_message: dict[str, Any] = { + "role": "assistant", + "content": content, + } + if capture_reasoning and reasoning_content is not None: + assistant_retry_message["reasoning_content"] = reasoning_content + messages.append(assistant_retry_message) + messages.append( + { + "role": "user", + "content": MISSING_TOOL_CALL_RETRY_HINT, + } + ) + continue + + assistant_message: dict[str, Any] = { + "role": "assistant", + "content": content, + "tool_calls": tool_calls, + } + missing_tool_call_count = 0 + last_missing_tool_call_content = "" + phase = message.get("phase") + if phase is not None: + assistant_message["phase"] = phase + output_items = message.get(RESPONSES_OUTPUT_ITEMS_KEY) + if isinstance(output_items, list): + assistant_message[RESPONSES_OUTPUT_ITEMS_KEY] = output_items + if capture_reasoning and reasoning_content is not None: + assistant_message["reasoning_content"] = reasoning_content + messages.append(assistant_message) + + tool_tasks = [] + tool_call_ids = [] + tool_api_names: list[str] = [] + tool_internal_names: list[str] = [] + end_tool_call: dict[str, Any] | None = None + end_tool_args: dict[str, Any] = {} + tool_results: list[Any] = [] + + # 逐个处理模型返回的 tool_call + for tool_call in tool_calls: + call_id = "" + if isinstance(tool_call, dict): + call_id = str(tool_call.get("id", "") or "") + function = tool_call.get("function") + else: + function = None + if not isinstance(function, dict): + logger.warning( + "[工具调用] 跳过无效工具调用: missing_function ID=%s", + call_id, + ) + messages.append(_build_invalid_tool_call_response(tool_call)) + continue + api_function_name = str(function.get("name", "") or "").strip() + if not api_function_name: + logger.warning( + "[工具调用] 跳过无效工具调用: empty_name ID=%s", + call_id, + ) + messages.append(_build_invalid_tool_call_response(tool_call)) + continue + raw_args = function.get("arguments") + + internal_function_name = api_to_internal.get( + api_function_name, + api_function_name, + ) + + if internal_function_name != api_function_name: + logger.info( + "[工具准备] 准备调用: %s (原名: %s) (ID=%s)", + internal_function_name, + api_function_name, + call_id, + ) + else: + logger.info( + "[工具准备] 准备调用: %s (ID=%s)", + api_function_name, + call_id, + ) + logger.debug( + f"[工具参数] {api_function_name} 参数: {redact_string(str(raw_args))}" + ) + + function_args = parse_tool_arguments( + raw_args, + logger=logger, + tool_name=str(api_function_name), + ) + + if not isinstance(function_args, dict): + function_args = {} + + # 检测 end 工具,暂存后统一处理 + if internal_function_name == "end": + # 无 tool_calls 与有 tool_calls 走不同分支 + if len(tool_calls) > 1: + logger.warning( + "[工具调用] end 与其他工具同时调用," + "将先执行其他工具,end 将返回拒绝结果" + ) + end_tool_call = tool_call + end_tool_args = function_args + continue + + tool_call_ids.append(call_id) + tool_api_names.append(str(api_function_name)) + tool_internal_names.append(str(internal_function_name)) + tool_tasks.append( + self.tool_manager.execute_tool( + str(internal_function_name), function_args, tool_context + ) + ) + + if tool_tasks: + tool_execution_started = True + logger.info( + "[工具执行] 开始并发执行 %s 个工具调用: %s", + len(tool_tasks), + ", ".join(tool_internal_names), + ) + tool_results = await asyncio.gather( + *tool_tasks, + return_exceptions=True, + ) + + for i, tool_result in enumerate(tool_results): + call_id = tool_call_ids[i] + api_fname = tool_api_names[i] + internal_fname = tool_internal_names[i] + + if isinstance(tool_result, Exception): + logger.error( + "[工具异常] %s (ID=%s) 执行抛出异常: %s", + internal_fname, + call_id, + tool_result, + ) + content_str = f"执行失败: {str(tool_result)}" + else: + content_str = str(tool_result) + logger.debug( + "[工具响应] %s (ID=%s) 返回内容长度=%s", + internal_fname, + call_id, + len(content_str), + ) + if logger.isEnabledFor(logging.DEBUG): + log_debug_json( + logger, + f"[工具响应体] {internal_fname} (ID={call_id})", + content_str, + ) + + messages.append( + { + "role": "tool", + "tool_call_id": call_id, + "name": api_fname, + "content": content_str, + } + ) + + # 如果是 get_forward_msg 工具调用,将其结果写入历史记录 + if internal_fname == "get_forward_msg" and not isinstance( + tool_result, Exception + ): + task = asyncio.create_task( + self._save_forward_to_history( + content_str, + pre_context, + history_manager, + ) + ) + task.add_done_callback( + lambda t: t.exception() if not t.cancelled() else None + ) + + # 会话是否已由 end 工具标记结束 + if tool_context.get("conversation_ended"): + conversation_ended = True + logger.info( + "[会话状态] 工具触发会话结束标记: tool=%s", + internal_fname, + ) + + if end_tool_call: + end_call_id = end_tool_call.get("id", "") + end_api_name = end_tool_call.get("function", {}).get("name", "end") + if tool_tasks: + messages.append( + { + "role": "tool", + "tool_call_id": end_call_id, + "name": end_api_name, + "content": END_CO_CALL_REJECT_CONTENT, + } + ) + logger.info( + "[工具调用] end 与其他工具同时调用," + "其它工具已执行,end 已回填拒绝响应" + ) + else: + # end 单独调用,正常执行(参数已在循环中解析) + tool_execution_started = True + end_result = await self.tool_manager.execute_tool( + "end", end_tool_args, tool_context + ) + messages.append( + { + "role": "tool", + "tool_call_id": end_call_id, + "name": end_api_name, + "content": str(end_result), + } + ) + # 会话是否已由 end 工具标记结束 + if tool_context.get("conversation_ended"): + conversation_ended = True + logger.info("[会话状态] end 工具触发会话结束") + + # 会话是否已由 end 工具标记结束 + if conversation_ended: + logger.info("[会话状态] 对话已结束(调用 end 工具)") + return "" + pre_tool_failure_count = 0 + + except Exception as exc: + if ( + not tool_execution_started + and pre_tool_failure_count < max_pre_tool_retries + ): + pre_tool_failure_count += 1 + del messages[message_checkpoint_len:] + transport_state = transport_state_checkpoint + logger.warning( + "[chat.pre_tool_retry] model=%s lane=%s retry=%s/%s iteration=%s error=%s", + effective_chat_config.model_name, + queue_lane, + pre_tool_failure_count, + max_pre_tool_retries, + iteration, + exc, + ) + continue + # 工具已执行或重试用尽:吞掉异常,避免向用户暴露内部错误 + logger.exception( + "[chat.suppressed_error] model=%s lane=%s iteration=%s error=%s", + effective_chat_config.model_name, + queue_lane, + iteration, + exc, + ) + return "" + + logger.warning("[AI决策] 达到最大迭代次数,未能完成处理") + return "达到最大迭代次数,未能完成处理" diff --git a/src/Undefined/ai/client/queue.py b/src/Undefined/ai/client/queue.py new file mode 100644 index 00000000..39bb1df7 --- /dev/null +++ b/src/Undefined/ai/client/queue.py @@ -0,0 +1,283 @@ +"""AI 客户端队列化 LLM 调用与摘要请求。""" + +from __future__ import annotations + +import asyncio +import logging +from typing import Any +from uuid import uuid4 + +from Undefined.ai.parsing import extract_choices_content +from Undefined.ai.queue_budget import ( + compute_queued_llm_timeout_seconds, + resolve_effective_retry_count, +) +from Undefined.context import RequestContext +import Undefined.ai.client as ai_client_module +from Undefined.services.queue_manager import ( + ALL_QUEUE_LANES, + QUEUE_LANE_BACKGROUND, + QUEUE_LANE_GROUP_MENTION, + QUEUE_LANE_GROUP_NORMAL, + QUEUE_LANE_GROUP_SUPERADMIN, + QUEUE_LANE_PRIVATE, + QUEUE_LANE_SUPERADMIN, +) + +from Undefined.ai.client.setup import ClientSetupMixin + +logger = logging.getLogger(__name__) + + +class ClientQueueMixin(ClientSetupMixin): + """统一队列 LLM 调用与会话摘要投递。""" + + def _resolve_queue_lane(self, queue_lane: Any = None) -> str: + # 优先级:显式参数 > RequestContext 资源 > 按会话类型推断 > 后台 + queue_lane_text = str(queue_lane or "").strip().lower() + if queue_lane_text in ALL_QUEUE_LANES: + return queue_lane_text + + ctx = RequestContext.current() + if ctx is not None: + ctx_lane = str(ctx.get_resource("queue_lane") or "").strip().lower() + if ctx_lane in ALL_QUEUE_LANES: + return ctx_lane + + runtime_config = self._get_runtime_config() + superadmin_qq = int(getattr(runtime_config, "superadmin_qq", 0) or 0) + if ctx.request_type == "private": + if superadmin_qq > 0 and ( + ctx.user_id == superadmin_qq or ctx.sender_id == superadmin_qq + ): + return QUEUE_LANE_SUPERADMIN + return QUEUE_LANE_PRIVATE + if ctx.request_type == "group": + if superadmin_qq > 0 and ctx.sender_id == superadmin_qq: + return QUEUE_LANE_GROUP_SUPERADMIN + # @bot 走 mention 队列,与普通群聊隔离 + if bool(ctx.get_resource("is_at_bot")): + return QUEUE_LANE_GROUP_MENTION + return QUEUE_LANE_GROUP_NORMAL + + return QUEUE_LANE_BACKGROUND + + def _get_queued_llm_wait_timeout_seconds(self) -> float: + retry_count = resolve_effective_retry_count( + self._get_runtime_config(), + getattr(self, "_queue_manager", None), + ) + return compute_queued_llm_timeout_seconds( + self._get_runtime_config(), + self.chat_config, + retry_count=retry_count, + ) + + async def submit_queued_llm_call( + self, + model_config: Any, + messages: list[dict[str, Any]], + tools: list[dict[str, Any]] | None = None, + tool_choice: Any = "auto", + call_type: str = "background", + max_tokens: int | None = None, + transport_state: dict[str, Any] | None = None, + queue_lane: str | None = None, + ) -> dict[str, Any]: + """将 LLM 调用投递到统一队列,走统一发车间隔和重试逻辑。 + 无 queue_manager 时降级为直接调用。""" + effective_max_tokens = ( + max_tokens + if max_tokens is not None + else getattr(model_config, "max_tokens", 4096) + ) + resolved_queue_lane = self._resolve_queue_lane(queue_lane) + # 无队列管理器时直接请求,跳等车/重试封装 + if self._queue_manager is None: + return await self.request_model( + model_config=model_config, + messages=messages, + tools=tools, + tool_choice=tool_choice, + call_type=call_type, + max_tokens=effective_max_tokens, + transport_state=transport_state, + ) + request_id = uuid4().hex + event: asyncio.Event = asyncio.Event() + # 挂起表:QueueManager 回调 set_llm_call_result 时唤醒等待方 + self._pending_llm_calls[request_id] = (event, None) + model_name = getattr(model_config, "model_name", "default") + request: dict[str, Any] = { + "type": "queued_llm_call", + "request_id": request_id, + "model_config": model_config, + "messages": messages, + "tools": tools, + "tool_choice": tool_choice, + "call_type": call_type, + "max_tokens": effective_max_tokens, + "transport_state": transport_state, + } + ctx = RequestContext.current() + if ctx is not None: + if ctx.group_id is not None: + request["group_id"] = ctx.group_id + if ctx.user_id is not None: + request["user_id"] = ctx.user_id + logger.info( + "[queued_llm_enqueue] request_id=%s call_type=%s model=%s lane=%s messages=%s tools=%s", + request_id, + call_type, + model_name, + resolved_queue_lane, + len(messages), + bool(tools), + ) + receipt = await self._queue_manager.add_queued_llm_request( + request, + lane=resolved_queue_lane, + model_name=model_name, + ) + wait_timeout = compute_queued_llm_timeout_seconds( + self._get_runtime_config(), + model_config, + retry_count=resolve_effective_retry_count( + self._get_runtime_config(), self._queue_manager + ), + initial_wait_seconds=float( + getattr(receipt, "estimated_wait_seconds", 0.0) or 0.0 + ), + # 首次 dispatch 间隔已含在 estimated_wait 中,避免重复计入 + include_first_dispatch_interval=False, + ) + try: + await asyncio.wait_for(event.wait(), timeout=wait_timeout) + except asyncio.TimeoutError: + logger.exception( + "[queued_llm_wait_timeout] request_id=%s call_type=%s model=%s lane=%s timeout=%.1fs", + request_id, + call_type, + model_name, + resolved_queue_lane, + wait_timeout, + ) + raise + # finally:无论成败都执行清理 + finally: + entry = self._pending_llm_calls.pop(request_id, None) + _, result = entry if entry is not None else (None, None) + if isinstance(result, Exception): + raise result + return result or {} + + async def submit_background_llm_call( + self, + model_config: Any, + messages: list[dict[str, Any]], + tools: list[dict[str, Any]] | None = None, + tool_choice: Any = "auto", + call_type: str = "background", + max_tokens: int | None = None, + transport_state: dict[str, Any] | None = None, + ) -> dict[str, Any]: + """后台 LLM 提交兼容包装。""" + return await self.submit_queued_llm_call( + model_config=model_config, + messages=messages, + tools=tools, + tool_choice=tool_choice, + call_type=call_type, + max_tokens=max_tokens, + transport_state=transport_state, + queue_lane=QUEUE_LANE_BACKGROUND, + ) + + def set_llm_call_result( + self, request_id: str, result: dict[str, Any] | Exception + ) -> None: + entry = self._pending_llm_calls.get(request_id) + if entry is None: + return + event, _ = entry + self._pending_llm_calls[request_id] = (event, result) + event.set() + + async def _summarize_message_history_queued( + self, + messages_text: str, + instruction: str = "", + ) -> str: + model_config = self._resolve_summary_model_for_requests() + built_messages = await self._summary_service.build_message_summary_messages( + # messages_text, instruction + messages_text, + instruction, + ) + result = await self.submit_queued_llm_call( + model_config=model_config, + messages=built_messages, + tools=None, + call_type="message_summary", + max_tokens=model_config.max_tokens, + ) + return extract_choices_content(result).strip() + + async def _merge_summaries_queued(self, summaries: list[str]) -> str: + if len(summaries) == 1: + return summaries[0] + + model_config = self._resolve_summary_model_for_requests() + messages = await self._summary_service.build_message_merge_messages(summaries) + result = await self.submit_queued_llm_call( + model_config=model_config, + messages=messages, + tools=None, + call_type="merge_message_summaries", + max_tokens=8192, + ) + return extract_choices_content(result).strip() + + async def summarize_command_session( + self, + history_manager: Any, + *, + group_id: int, + user_id: int, + count: int | None = None, + time_range: str | None = None, + instruction: str = "", + ) -> str: + """Fetch session messages and summarize via summary model without tools.""" + messages_text = await ai_client_module.fetch_session_messages( + history_manager, + group_id=group_id, + user_id=user_id, + count=count, + time_range=time_range, + runtime_config=self.runtime_config, + include_header=False, + ) + if not messages_text: + return "当前会话暂无消息记录" + if messages_text.startswith("无法解析时间范围"): + return messages_text + + input_budget = await self._summary_service.resolve_message_input_budget( + instruction + ) + total_tokens = self.count_tokens(messages_text) + if total_tokens <= input_budget: + return await self._summarize_message_history_queued( + # messages_text, instruction + messages_text, + instruction, + ) + + # 超长会话:分块摘要后再合并,避免超出上下文窗口 + chunks = self.split_messages_by_tokens(messages_text, input_budget) + summaries = [ + await self._summarize_message_history_queued(chunk, instruction) + for chunk in chunks + ] + return await self._merge_summaries_queued(summaries) diff --git a/src/Undefined/ai/client/setup.py b/src/Undefined/ai/client/setup.py new file mode 100644 index 00000000..068b9cfc --- /dev/null +++ b/src/Undefined/ai/client/setup.py @@ -0,0 +1,880 @@ +"""AI 客户端生命周期与配置。""" + +from __future__ import annotations + +import asyncio +import html +import logging +import re +from pathlib import Path +from typing import Any, Awaitable, Callable, Optional, Protocol, TYPE_CHECKING + +import httpx + +from Undefined.attachments import AttachmentRegistry +from Undefined.ai.llm import ModelRequester +from Undefined.ai.model_selector import ModelSelector +from Undefined.ai.multimodal import MultimodalAnalyzer +from Undefined.ai.prompts import PromptBuilder +from Undefined.ai.crawl4ai_support import get_crawl4ai_capabilities +from Undefined.ai.summaries import SummaryService +from Undefined.ai.tokens import TokenCounter +from Undefined.ai.tooling import ToolManager +from Undefined.config import ( + ChatModelConfig, + VisionModelConfig, + AgentModelConfig, + Config, + GrokModelConfig, +) +from Undefined.context import RequestContext +from Undefined.context_resource_registry import set_context_resource_scan_paths +from Undefined.end_summary_storage import EndSummaryStorage +from Undefined.memory import MemoryStorage +from Undefined.skills.agents import AgentRegistry +from Undefined.skills.agents.intro_generator import ( + AgentIntroGenConfig, + AgentIntroGenerator, +) +from Undefined.skills.anthropic_skills import AnthropicSkillRegistry +from Undefined.skills.tools import ToolRegistry +from Undefined.token_usage_storage import TokenUsageStorage +from Undefined.utils.logging import redact_string + +logger = logging.getLogger(__name__) + + +# 模型返回纯文本但未调用 tool 时,追加到 messages 的纠正提示(不写死具体 tool) +MISSING_TOOL_CALL_RETRY_HINT = ( + "【系统提示】你上一轮输出了纯文本且未调用任何工具。" + "本环境必须通过工具调用来完成对外动作与结束本轮处理。" + "请结合上文完整对话历史与已有 tool 返回结果,自行决定下一步应调用的工具;" + "不要直接以纯文本作为最终对外回复。" +) + + +_CONTENT_TAG_PATTERN = re.compile( + r"(.*?)", + re.DOTALL | re.IGNORECASE, +) + +_INVALID_TOOL_CALL_CONTENT = ( + "无效工具调用:工具名称为空或格式非法,系统已跳过执行。" + "请使用可用工具名重新调用,或调用 end 结束本轮。" +) + + +def _build_invalid_tool_call_response(tool_call: Any) -> dict[str, Any]: + """为模型发出的 malformed tool call 构造 tool 角色回填消息。""" + call_id = "" + tool_name = "" + if isinstance(tool_call, dict): + call_id = str(tool_call.get("id", "") or "") + function = tool_call.get("function") + if isinstance(function, dict): + tool_name = str(function.get("name", "") or "").strip() + return { + "role": "tool", + "tool_call_id": call_id, + "name": tool_name, + "content": _INVALID_TOOL_CALL_CONTENT, + } + + +class SendMessageCallback(Protocol): + def __call__( + self, message: str, reply_to: int | None = None + ) -> Awaitable[None]: ... + + +class SendPrivateMessageCallback(Protocol): + def __call__( + self, user_id: int, message: str, reply_to: int | None = None + ) -> Awaitable[None]: ... + + +# 尝试导入 langchain SearxSearchWrapper +if TYPE_CHECKING: + from langchain_community.utilities import ( + SearxSearchWrapper as SearxSearchWrapperType, + ) +else: + SearxSearchWrapperType = object + +_SearxSearchWrapper: type[SearxSearchWrapperType] | None +try: + from langchain_community.utilities import SearxSearchWrapper as _SearxSearchWrapper + + _SEARX_AVAILABLE = True +except Exception: + _SearxSearchWrapper = None + _SEARX_AVAILABLE = False + logger.warning( + "[初始化] langchain_community 未安装或 SearxSearchWrapper 不可用,搜索功能将禁用" + ) + + +def _attachment_remote_download_max_bytes(runtime_config: Config) -> int: + value = int(runtime_config.attachment_remote_download_max_size_mb) + return max(0, value) * 1024 * 1024 + + +def _attachment_cache_max_bytes(runtime_config: Config) -> int: + value = int(runtime_config.attachment_cache_max_total_size_mb) + return max(0, value) * 1024 * 1024 + + +def _attachment_cache_max_age_seconds(runtime_config: Config) -> int: + value = int(runtime_config.attachment_cache_max_age_days) + return max(0, value) * 24 * 60 * 60 + + +def _resolve_summary_model_config( + runtime_config: Config | None, + fallback: AgentModelConfig, +) -> AgentModelConfig: + if runtime_config is None: + # 回退到默认/主配置 + return fallback + if not getattr(runtime_config, "summary_model_configured", False): + # 回退到默认/主配置 + return fallback + summary_model = getattr(runtime_config, "summary_model", None) + if isinstance(summary_model, AgentModelConfig): + return summary_model + # 回退到默认/主配置 + return fallback + + +class ClientSetupMixin: + """AI 客户端初始化、配置热更新与资源清理。""" + + def __init__( + self, + chat_config: ChatModelConfig, + vision_config: VisionModelConfig, + agent_config: AgentModelConfig, + memory_storage: Optional[MemoryStorage] = None, + end_summary_storage: Optional[EndSummaryStorage] = None, + bot_qq: int = 0, + runtime_config: Config | None = None, + cognitive_service: Any = None, + ) -> None: + """初始化 AI 客户端 + + 参数: + chat_config: 对话模型配置 + vision_config: 视觉模型配置 + agent_config: 智能体模型配置 + memory_storage: 长期记忆存储 + end_summary_storage: 短期回忆存储 + bot_qq: 机器人自身的 QQ 号 + """ + self.chat_config = chat_config + self.vision_config = vision_config + self.agent_config = agent_config + self.bot_qq = bot_qq + self.runtime_config = runtime_config + self.memory_storage = memory_storage + self._end_summary_storage = end_summary_storage or EndSummaryStorage() + self._crawl4ai_capabilities = get_crawl4ai_capabilities() + + self._http_client = httpx.AsyncClient(timeout=480.0) + self._token_usage_storage = TokenUsageStorage() + self._requester = ModelRequester(self._http_client, self._token_usage_storage) + self._token_counter = TokenCounter() + self._knowledge_manager: Any = None + self._cognitive_service: Any = cognitive_service + self._meme_service: Any = None + if self.runtime_config is not None: + self.attachment_registry = AttachmentRegistry( + http_client=self._http_client, + remote_download_max_bytes=_attachment_remote_download_max_bytes( + self.runtime_config + ), + max_cache_bytes=_attachment_cache_max_bytes(self.runtime_config), + max_records=self.runtime_config.attachment_cache_max_records, + max_age_seconds=_attachment_cache_max_age_seconds(self.runtime_config), + url_reference_max_records=( + self.runtime_config.attachment_url_reference_max_records + ), + url_max_length=self.runtime_config.attachment_url_max_length, + ) + else: + self.attachment_registry = AttachmentRegistry(http_client=self._http_client) + + self._send_private_message_callback: Optional[SendPrivateMessageCallback] = None + self._send_image_callback: Optional[ + Callable[[int, str, str], Awaitable[None]] + ] = None + + # 当前群聊ID和用户ID(用于send_message工具) + self.current_group_id: Optional[int] = None + self.current_user_id: Optional[int] = None + + base_dir = Path(__file__).resolve().parents[1] + self.tool_registry = ToolRegistry(base_dir / "skills" / "tools") + self.agent_registry = AgentRegistry(base_dir / "skills" / "agents") + + # 初始化 Anthropic Agent Skills 注册表(可选,目录不存在时自动跳过) + anthropic_skills_dir = base_dir / "skills" / "anthropic_skills" + dot_delimiter = self._get_runtime_config().tools_dot_delimiter + self.anthropic_skill_registry = AnthropicSkillRegistry( + anthropic_skills_dir, + dot_delimiter=dot_delimiter, + ) + + self.tool_manager = ToolManager( + self.tool_registry, + self.agent_registry, + anthropic_skill_registry=self.anthropic_skill_registry, + ) + + self.model_selector = ModelSelector() + + # 绑定上下文资源扫描路径(基于注册表 watch_paths) + scan_paths = [ + p + for p in ( + self.tool_registry._watch_paths + self.agent_registry._watch_paths + ) + if p.exists() + ] + set_context_resource_scan_paths(scan_paths) + logger.debug( + "[初始化] 上下文资源扫描路径已绑定: count=%s", + len(scan_paths), + ) + + # Agent intro 生成器(延迟初始化,需要外部设置 queue_manager) + self._agent_intro_generator: Any | None = None + self._agent_intro_task: asyncio.Task[None] | None = None + self._queue_manager: Any | None = None + self._intro_config: Any | None = None + # 后台 LLM 调用挂起表(走队列的后台请求) + self._pending_llm_calls: dict[ + str, tuple[asyncio.Event, dict[str, Any] | Exception | None] + ] = {} + + # 后台任务引用集合(防止被 GC) + self._background_tasks: set[asyncio.Task[Any]] = set() + + runtime_config = self._get_runtime_config() + self._intro_config = AgentIntroGenConfig( + enabled=runtime_config.agent_intro_autogen_enabled, + queue_interval_seconds=runtime_config.agent_intro_autogen_queue_interval, + max_tokens=runtime_config.agent_intro_autogen_max_tokens, + cache_path=Path(runtime_config.agent_intro_hash_path), + ) + + # 启动 skills 热重载 + hot_reload_enabled = runtime_config.skills_hot_reload + if hot_reload_enabled: + interval = runtime_config.skills_hot_reload_interval + debounce = runtime_config.skills_hot_reload_debounce + self.tool_registry.start_hot_reload(interval=interval, debounce=debounce) + self.agent_registry.start_hot_reload(interval=interval, debounce=debounce) + self.anthropic_skill_registry.start_hot_reload( + interval=interval, debounce=debounce + ) + logger.info( + "[初始化] 技能热重载已启用: interval=%.2fs debounce=%.2fs", + interval, + debounce, + ) + else: + logger.info("[初始化] 技能热重载已禁用") + + # 初始化搜索 wrapper + self._search_wrapper: Optional[Any] = None + if _SEARX_AVAILABLE and _SearxSearchWrapper is not None: + searxng_url = runtime_config.searxng_url + if searxng_url: + try: + self._search_wrapper = _SearxSearchWrapper( + searx_host=searxng_url, k=10 + ) + logger.info( + "[初始化] SearxSearchWrapper 初始化成功: url=%s k=10", + redact_string(searxng_url), + ) + except Exception as exc: + logger.warning("[初始化] SearxSearchWrapper 初始化失败: %s", exc) + else: + logger.info("[初始化] SEARXNG_URL 未配置,搜索功能禁用") + + if self._crawl4ai_capabilities.available: + logger.info("[初始化] crawl4ai 可用,网页获取功能已启用") + else: + detail = self._crawl4ai_capabilities.error + if detail: + logger.warning( + "[初始化] crawl4ai 不可用,网页获取功能将禁用: %s", + detail, + ) + else: + logger.warning("[初始化] crawl4ai 不可用,网页获取功能将禁用") + + self._prompt_builder = PromptBuilder( + bot_qq=self.bot_qq, + memory_storage=self.memory_storage, + end_summary_storage=self._end_summary_storage, + runtime_config_getter=self._get_runtime_config, + anthropic_skill_registry=self.anthropic_skill_registry, + cognitive_service=self._cognitive_service, + ) + self._multimodal = MultimodalAnalyzer(self._requester, self.vision_config) + self._rebuild_summary_service() + + async def init_mcp_async() -> None: + try: + await self.tool_registry.initialize_mcp_toolsets() + except Exception as exc: + logger.warning("[初始化] 异步初始化 MCP 工具集失败: %s", exc) + + self._mcp_init_task = asyncio.create_task(init_mcp_async()) + + async def load_preferences_async() -> None: + try: + await self.model_selector.load_preferences() + except Exception as exc: + logger.warning("[初始化] 加载模型偏好失败: %s", exc) + + self._preferences_load_task = asyncio.create_task(load_preferences_async()) + + logger.info("[初始化] AIClient 初始化完成") + + async def close(self) -> None: + logger.info("[清理] 正在关闭 AIClient...") + + intro_gen = getattr(self, "_agent_intro_generator", None) + if intro_gen is not None: + await intro_gen.stop() + if hasattr(self, "_agent_intro_task") and self._agent_intro_task: + if not self._agent_intro_task.done(): + await self._agent_intro_task + knowledge_manager = getattr(self, "_knowledge_manager", None) + if knowledge_manager is not None and hasattr(knowledge_manager, "stop"): + try: + await knowledge_manager.stop() + except Exception as exc: + logger.warning("[清理] 关闭知识库管理器失败: %s", exc) + self._knowledge_manager = None + cognitive_service = getattr(self, "_cognitive_service", None) + if cognitive_service is not None: + if hasattr(cognitive_service, "stop"): + try: + await cognitive_service.stop() + except Exception as exc: + logger.warning("[清理] 关闭认知记忆服务失败: %s", exc) + self._cognitive_service = None + if hasattr(self, "_prompt_builder") and self._prompt_builder is not None: + self._prompt_builder.set_cognitive_service(None) + + if hasattr(self, "_mcp_init_task") and not self._mcp_init_task.done(): + await self._mcp_init_task + + if hasattr(self, "tool_registry"): + await self.tool_registry.stop_hot_reload() + await self.tool_registry.close_mcp_toolsets() + if hasattr(self, "agent_registry"): + await self.agent_registry.stop_hot_reload() + if hasattr(self, "anthropic_skill_registry"): + await self.anthropic_skill_registry.stop_hot_reload() + + attachment_registry = getattr(self, "attachment_registry", None) + if attachment_registry is not None and hasattr(attachment_registry, "flush"): + try: + await attachment_registry.flush() + except Exception as exc: + logger.warning("[清理] 刷新附件注册表失败: %s", exc) + + if hasattr(self, "_http_client"): + logger.info("[清理] 正在关闭 AIClient HTTP 客户端...") + await self._http_client.aclose() + + logger.info("[清理] AIClient 已关闭") + + def set_queue_manager(self, queue_manager: Any) -> None: + """设置队列管理器并启动 Agent intro 生成器。 + + 参数: + queue_manager: 队列管理器实例 + """ + if self._queue_manager is not None: + logger.warning("[AI客户端] queue_manager 已设置,跳过重复设置") + return + + if queue_manager is None: + logger.warning("[AI客户端] 传入的 queue_manager 为 None") + return + + self._queue_manager = queue_manager + + # 启动/刷新 Agent intro 自动生成 + if self._intro_config: + self.apply_intro_config(self._intro_config) + + def apply_intro_config(self, config: AgentIntroGenConfig) -> None: + """应用 Agent intro 生成器配置(支持热更新)。""" + self._intro_config = config + if self._queue_manager is None: + return + task = asyncio.create_task(self._refresh_intro_generator(config)) + task.add_done_callback(lambda t: t.exception() if not t.cancelled() else None) + + async def _refresh_intro_generator(self, config: AgentIntroGenConfig) -> None: + if not config.enabled: + if self._agent_intro_generator is not None: + await self._agent_intro_generator.stop() + self._agent_intro_generator = None + self._agent_intro_task = None + logger.info("[Agent介绍] 自动生成已关闭") + return + + if self._queue_manager is None: + return + + if self._agent_intro_generator is None: + self._agent_intro_generator = AgentIntroGenerator( + self.agent_registry.base_dir, + self, + self._queue_manager, + config, + ) + self._agent_intro_task = asyncio.create_task( + self._agent_intro_generator.start() + ) + logger.info( + "[Agent介绍] 自动生成已启动: interval=%.2fs max_tokens=%s cache=%s", + config.queue_interval_seconds, + config.max_tokens, + config.cache_path, + ) + return + + if self._agent_intro_generator.config.cache_path != config.cache_path: + # 缓存路径变更需重建生成器,否则 hash 与落盘目录不一致 + await self._agent_intro_generator.stop() + self._agent_intro_generator = AgentIntroGenerator( + self.agent_registry.base_dir, + self, + self._queue_manager, + config, + ) + self._agent_intro_task = asyncio.create_task( + self._agent_intro_generator.start() + ) + logger.info( + "[Agent介绍] 缓存路径变更,已重启生成器: cache=%s", + config.cache_path, + ) + return + + self._agent_intro_generator.config = config + + def set_knowledge_manager(self, manager: Any) -> None: + self._knowledge_manager = manager + + def set_cognitive_service(self, service: Any) -> None: + self._cognitive_service = service + if hasattr(self, "_prompt_builder") and self._prompt_builder is not None: + self._prompt_builder.set_cognitive_service(service) + logger.info( + "[AI客户端] 认知记忆服务已挂载并同步到 PromptBuilder: enabled=%s", + bool(getattr(service, "enabled", False)) if service is not None else False, + ) + + def set_meme_service(self, service: Any) -> None: + self._meme_service = service + resolver = None + async_resolver = None + if service is not None and hasattr(service, "resolve_global_image_sync"): + resolver = service.resolve_global_image_sync + if service is not None and hasattr(service, "resolve_global_image"): + async_resolver = service.resolve_global_image + self.attachment_registry.set_global_image_resolver(resolver) + self.attachment_registry.set_global_image_resolver_async(async_resolver) + logger.info( + "[AI客户端] 表情包服务已挂载: enabled=%s", + bool(getattr(service, "enabled", False)) if service is not None else False, + ) + + def apply_search_config(self, searxng_url: str) -> None: + """应用搜索服务配置(支持热更新)。""" + if not _SEARX_AVAILABLE or _SearxSearchWrapper is None: + if searxng_url: + logger.warning( + "[配置] 搜索组件不可用,已忽略 SEARXNG_URL=%s", + redact_string(searxng_url), + ) + else: + logger.info("[配置] 搜索组件不可用,搜索已禁用") + self._search_wrapper = None + return + + if not searxng_url: + self._search_wrapper = None + logger.info("[配置] SEARXNG_URL 未配置,搜索功能已禁用") + return + + try: + self._search_wrapper = _SearxSearchWrapper(searx_host=searxng_url, k=10) + logger.info( + "[配置] 搜索服务已更新: url=%s k=10", + redact_string(searxng_url), + ) + except Exception as exc: + logger.warning("[配置] 搜索服务更新失败: %s", exc) + self._search_wrapper = None + logger.info("[配置] 搜索服务已回退为禁用") + + def apply_model_configs( + self, + *, + chat_config: ChatModelConfig, + vision_config: VisionModelConfig, + agent_config: AgentModelConfig, + runtime_config: Config, + ) -> None: + """应用热更新后的模型配置。""" + self.chat_config = chat_config + self.vision_config = vision_config + self.agent_config = agent_config + self.runtime_config = runtime_config + self._multimodal = MultimodalAnalyzer(self._requester, self.vision_config) + self._rebuild_summary_service() + self.apply_attachment_config(runtime_config) + logger.info( + "[配置] AI 模型配置已热更新: chat=%s vision=%s agent=%s", + self.chat_config.model_name, + self.vision_config.model_name, + self.agent_config.model_name, + ) + + def apply_runtime_config(self, runtime_config: Config) -> None: + """应用不需要重建模型客户端的运行时配置。""" + self.runtime_config = runtime_config + self._rebuild_summary_service() + logger.info("[配置] AI 运行时配置已热更新") + + def _rebuild_summary_service(self) -> None: + self._summary_service = SummaryService( + self._requester, + _resolve_summary_model_config(self.runtime_config, self.agent_config), + self._token_counter, + ) + + def _resolve_summary_model_for_requests(self) -> AgentModelConfig: + return _resolve_summary_model_config(self.runtime_config, self.agent_config) + + def apply_attachment_config(self, runtime_config: Config) -> None: + self.attachment_registry.set_limits( + remote_download_max_bytes=_attachment_remote_download_max_bytes( + runtime_config + ), + max_cache_bytes=_attachment_cache_max_bytes(runtime_config), + max_records=runtime_config.attachment_cache_max_records, + max_age_seconds=_attachment_cache_max_age_seconds(runtime_config), + url_reference_max_records=( + runtime_config.attachment_url_reference_max_records + ), + url_max_length=runtime_config.attachment_url_max_length, + ) + + def count_tokens(self, text: str) -> int: + return self._token_counter.count(text) + + def _get_runtime_config(self) -> Config: + if self.runtime_config is not None: + return self.runtime_config + from Undefined.config import get_config + + return get_config(strict=False) + + def _find_chat_config_by_name(self, model_name: str) -> ChatModelConfig: + """根据模型名查找配置(主模型或池中模型)""" + if model_name == self.chat_config.model_name: + return self.chat_config + if self.chat_config.pool and self.chat_config.pool.enabled: + for entry in self.chat_config.pool.models: + if entry.model_name == model_name: + return self.model_selector._entry_to_chat_config( + # entry, self.chat_config + entry, + self.chat_config, + ) + return self.chat_config + + def _get_prefetch_tool_names(self) -> list[str]: + runtime_config = self._get_runtime_config() + return list(runtime_config.prefetch_tools) + + def _filter_tools_for_runtime_config( + self, tools: list[dict[str, Any]] + ) -> list[dict[str, Any]]: + runtime_config = self._get_runtime_config() + enabled = bool(getattr(runtime_config, "nagaagent_mode_enabled", False)) + if enabled: + return tools + + # 关闭 NagaAgent 模式时:隐藏相关 Agent,避免被模型误调用。 + filtered: list[dict[str, Any]] = [] + for tool in tools: + function = tool.get("function") if isinstance(tool, dict) else None + name = function.get("name") if isinstance(function, dict) else None + if name == "naga_code_analysis_agent": + continue + filtered.append(tool) + return filtered + + def _prefetch_hide_tools(self) -> bool: + runtime_config = self._get_runtime_config() + return runtime_config.prefetch_tools_hide + + def _is_missing_tool_result(self, result: Any) -> bool: + if not isinstance(result, str): + return False + return result.startswith("未找到项目") or result.startswith("未找到 MCP 工具") + + async def _maybe_prefetch_tools( + self, + messages: list[dict[str, Any]], + tools: list[dict[str, Any]] | None, + call_type: str, + ) -> tuple[list[dict[str, Any]], list[dict[str, Any]] | None]: + if not tools: + return messages, tools + + # 预先调用部分工具,为模型补充稳定上下文(同一 call_type 仅执行一次) + prefetch_names = self._get_prefetch_tool_names() + if not prefetch_names: + return messages, tools + + available_names = { + tool.get("function", {}).get("name") + for tool in tools + if tool.get("function") + } + prefetch_targets = [name for name in prefetch_names if name in available_names] + if not prefetch_targets: + return messages, tools + + # 使用 RequestContext 缓存已执行的预先调用,避免重复触发 + ctx = RequestContext.current() + cache: dict[str, list[str]] = {} + done: set[str] = set() + if ctx: + cache = ctx.get_resource("prefetch_tools", {}) or {} + done = set(cache.get(call_type, [])) + + to_run = [name for name in prefetch_targets if name not in done] + if not to_run: + return messages, tools + + results: list[tuple[str, Any]] = [] + for name in to_run: + try: + tool_args: dict[str, Any] = {} + if name == "get_current_time": + tool_args = {"format": "text", "include_lunar": True} + + result = await self.tool_manager.execute_tool( + name, + tool_args, + { + "runtime_config": self._get_runtime_config(), + "easter_egg_silent": True, + }, + ) + except Exception as exc: + logger.warning("[预先调用] %s 执行失败: %s", name, exc) + continue + + if self._is_missing_tool_result(result): + logger.warning("[预先调用] %s 未找到对应工具,跳过", name) + continue + + results.append((name, result)) + done.add(name) + + if not results: + return messages, tools + + if ctx: + cache[call_type] = sorted(done) + ctx.set_resource("prefetch_tools", cache) + + content_lines = ["【预先工具结果】"] + content_lines.extend([f"- {name}: {result}" for name, result in results]) + prefetch_message = {"role": "system", "content": "\n".join(content_lines)} + + insert_idx = 0 + # 紧接在已有 system 消息之后插入 prefetch 结果,保持指令顺序 + for idx, msg in enumerate(messages): + if msg.get("role") == "system": + insert_idx = idx + 1 + else: + break + new_messages = list(messages) + new_messages.insert(insert_idx, prefetch_message) + + if self._prefetch_hide_tools(): + hidden = set(name for name in done) + tools = [ + tool + for tool in tools + if tool.get("function", {}).get("name") not in hidden + ] + return new_messages, tools + + async def request_model( + self, + model_config: ( + ChatModelConfig | VisionModelConfig | AgentModelConfig | GrokModelConfig + ), + messages: list[dict[str, Any]], + max_tokens: int = 8192, + call_type: str = "chat", + tools: list[dict[str, Any]] | None = None, + tool_choice: str = "auto", + transport_state: dict[str, Any] | None = None, + **kwargs: Any, + ) -> dict[str, Any]: + tools = self.tool_manager.maybe_merge_agent_tools(call_type, tools) + message_count_for_transport = len(messages) + # Responses 续轮(previous_response_id)时跳过 prefetch,避免重复注入系统消息 + if not ( + isinstance(transport_state, dict) + and transport_state.get("previous_response_id") + ): + messages, tools = await self._maybe_prefetch_tools( + # messages, tools, call_type + messages, + tools, + call_type, + ) + return await self._requester.request( + model_config=model_config, + messages=messages, + max_tokens=max_tokens, + call_type=call_type, + tools=tools, + tool_choice=tool_choice, + transport_state=transport_state, + message_count_for_transport=message_count_for_transport, + **kwargs, + ) + + def get_active_agent_mcp_registry(self, agent_name: str) -> Any | None: + return self.tool_manager.get_active_agent_mcp_registry(agent_name) + + async def analyze_multimodal( + self, + media_url: str, + media_type: str = "auto", + prompt_extra: str = "", + ) -> dict[str, str]: + return await self._multimodal.analyze(media_url, media_type, prompt_extra) + + async def describe_image( + self, image_url: str, prompt_extra: str = "" + ) -> dict[str, str]: + return await self._multimodal.describe_image(image_url, prompt_extra) + + async def judge_meme_image(self, image_url: str) -> dict[str, Any]: + return await self._multimodal.judge_meme_image(image_url) + + async def describe_meme_image(self, image_url: str) -> dict[str, Any]: + return await self._multimodal.describe_meme_image(image_url) + + def get_media_history(self, media_key: str) -> list[dict[str, str]]: + """获取指定媒体键的多模态分析历史 Q&A 记录。""" + return self._multimodal.get_history(media_key) + + async def save_media_history( + self, media_key: str, question: str, answer: str + ) -> None: + """保存一条多模态分析 Q&A 到历史记录并持久化到磁盘。""" + await self._multimodal.save_history(media_key, question, answer) + + async def summarize_chat(self, messages: str, context: str = "") -> str: + return await self._summary_service.summarize_chat(messages, context) + + async def merge_summaries(self, summaries: list[str]) -> str: + return await self._summary_service.merge_summaries(summaries) + + def split_messages_by_tokens(self, messages: str, max_tokens: int) -> list[str]: + return self._summary_service.split_messages_by_tokens(messages, max_tokens) + + async def generate_title(self, summary: str) -> str: + return await self._summary_service.generate_title(summary) + + def _extract_message_excerpt(self, question: str) -> str: + matched = _CONTENT_TAG_PATTERN.search(question) + if matched: + content = html.unescape(matched.group(1)) + else: + content = question + cleaned = " ".join(content.split()).strip() + if not cleaned: + return "(无文本内容)" + if len(cleaned) > 120: + return cleaned[:117].rstrip() + "..." + return cleaned + + def _is_end_only_tool_calls( + self, + tool_calls: list[dict[str, Any]], + api_to_internal: dict[str, str], + ) -> bool: + # 无 tool_calls 与有 tool_calls 走不同分支 + if not tool_calls: + return False + # 逐个处理模型返回的 tool_call + for tool_call in tool_calls: + function = tool_call.get("function", {}) + api_name = str(function.get("name", "") or "") + internal_name = api_to_internal.get(api_name, api_name) + if internal_name != "end": + return False + return True + + async def _save_forward_to_history( + self, + content: str, + pre_context: dict[str, Any], + history_manager: Any, + ) -> None: + """将合并转发消息写入历史记录""" + if history_manager is None: + return + + try: + group_id = pre_context.get("group_id") + user_id = pre_context.get("user_id") + + if group_id is not None: + await history_manager.add_group_message( + group_id=int(group_id), + sender_id=0, + text_content=content, + sender_card="", + sender_nickname="[合并转发内容]", + group_name="", + role="system", + title="", + message_id=None, + ) + elif user_id is not None: + await history_manager.add_private_message( + user_id=int(user_id), + text_content=content, + display_name="[合并转发内容]", + user_name="", + message_id=None, + ) + else: + logger.debug("[合并转发] 无法写入历史:缺少 group_id 和 user_id") + except Exception as exc: + logger.debug("[合并转发] 写入历史失败: %s", exc) diff --git a/src/Undefined/ai/llm.py b/src/Undefined/ai/llm.py index 03801a87..59411faf 100644 --- a/src/Undefined/ai/llm.py +++ b/src/Undefined/ai/llm.py @@ -415,7 +415,6 @@ def _sanitize_openai_tool_names_in_request( new_message = message - # 重写 role=tool 的 name 字段(可选字段)。 msg_name = message.get("name") if isinstance(msg_name, str) and msg_name: mapped = internal_to_api.get(msg_name) @@ -1164,7 +1163,6 @@ def _extract_thinking_content(result: dict[str, Any]) -> str: if thinking: return thinking - # 尝试从响应根对象中提取 return _extract_from_result(result) diff --git a/src/Undefined/ai/llm/__init__.py b/src/Undefined/ai/llm/__init__.py new file mode 100644 index 00000000..a8191f6a --- /dev/null +++ b/src/Undefined/ai/llm/__init__.py @@ -0,0 +1,22 @@ +"""LLM 模型请求子包。 + +对外稳定入口:``ModelRequester``、``build_request_body``、``ModelConfig``; +旧路径 ``Undefined.ai.llm`` 通过包根与 ``llm.py`` shim 保持兼容。 +""" + +from Undefined.ai.llm.requester import ModelRequester, build_request_body +from Undefined.ai.llm.sanitize import _encode_tool_name_for_api +from Undefined.ai.llm.streaming import should_fallback_from_stream +from Undefined.ai.llm.types import ModelConfig + +# 测试与内部调用沿用的私有符号别名(保持旧 import 路径可用) +_should_fallback_from_stream = should_fallback_from_stream + +# 子包公开 API 列表 +__all__ = [ + "ModelRequester", + "build_request_body", + "ModelConfig", + "_encode_tool_name_for_api", + "_should_fallback_from_stream", +] diff --git a/src/Undefined/ai/llm/requester.py b/src/Undefined/ai/llm/requester.py new file mode 100644 index 00000000..d5c025da --- /dev/null +++ b/src/Undefined/ai/llm/requester.py @@ -0,0 +1,1016 @@ +"""统一 LLM 模型请求封装与请求体构建。 + +``ModelRequester`` 负责 OpenAI 兼容 API 的 chat/responses/embed/rerank 调用、 +流式聚合与 token 用量记录;出站清洗与思维链提取委托 ``sanitize`` / ``thinking`` 子模块。 +""" + +from __future__ import annotations + +import asyncio +import hashlib +import json +import logging +import re +import time +from datetime import datetime +from typing import Any +from urllib.parse import parse_qsl, urlsplit, urlunsplit + +import httpx +from openai import ( + APIConnectionError, + APIStatusError, + APITimeoutError, + AsyncOpenAI, +) + +from Undefined.ai.llm.sanitize import ( + _tool_name_dot_delimiter, + desc_preview, + prepare_chat_completion_messages, + relocate_system_to_first_user, + sanitize_chat_completion_messages, + sanitize_openai_messages_tool_arguments, + sanitize_openai_tool_names_in_request, + sanitize_openai_tools, + tools_description_max_len, + tools_description_truncate_enabled, + tools_sanitize_verbose, +) +from Undefined.ai.llm.streaming import ( + aggregate_chat_completions_stream, + aggregate_responses_stream, + ensure_chat_stream_usage_options, + should_fallback_from_stream, + split_chat_completion_params, + split_responses_params, + without_stream_request_fields, +) +from Undefined.ai.llm.thinking import ( + extract_thinking_content, + normalize_thinking_override, +) +from Undefined.ai.llm.types import ModelConfig +from Undefined.ai.parsing import extract_choices_content +from Undefined.ai.retrieval import RetrievalRequester +from Undefined.ai.tokens import TokenCounter +from Undefined.ai.transports import ( + API_MODE_CHAT_COMPLETIONS, + API_MODE_RESPONSES, + build_responses_request_body, + get_api_mode, + get_effort_payload, + get_effort_style, + get_thinking_payload, + normalize_responses_result, +) +from Undefined.config import Config, EmbeddingModelConfig, RerankModelConfig, get_config +from Undefined.context import RequestContext +from Undefined.token_usage_storage import TokenUsage, TokenUsageStorage +from Undefined.utils.logging import log_debug_json, redact_string +from Undefined.utils.request_params import ( + merge_request_params, + split_reserved_request_params, +) + +logger = logging.getLogger(__name__) + +__all__ = ["ModelRequester", "build_request_body", "ModelConfig"] + +_SDK_REQUEST_OPTION_FIELDS: frozenset[str] = frozenset( + {"extra_headers", "extra_query", "extra_body", "timeout"} +) + +_CHAT_COMPLETIONS_RESERVED_FIELDS: frozenset[str] = ( + frozenset( + { + "model", + "messages", + "max_tokens", + "tools", + "tool_choice", + "stream", + "stream_options", + "thinking", + "reasoning", + "reasoning_effort", + "output_config", + } + ) + | _SDK_REQUEST_OPTION_FIELDS +) + +_RESPONSES_RESERVED_FIELDS: frozenset[str] = ( + frozenset( + { + "model", + "input", + "instructions", + "max_output_tokens", + "tools", + "tool_choice", + "previous_response_id", + "stream", + "stream_options", + "thinking", + "reasoning", + "reasoning_effort", + "output_config", + } + ) + | _SDK_REQUEST_OPTION_FIELDS +) + +_TOOLS_PARAM_INDEX_RE = re.compile(r"Tools\[(\d+)\]", re.IGNORECASE) +_RESPONSES_MISSING_TOOL_CALL_OUTPUT_RE = re.compile( + r"no tool call found for function call output with call_id", + re.IGNORECASE, +) + +_PROMPT_CACHE_KEY_MAX_LEN = 128 + + +def _get_runtime_config() -> Config | None: + try: + return get_config(strict=False) + except Exception: + return None + + +def _hash8(text: str) -> str: + return hashlib.sha1(text.encode("utf-8"), usedforsecurity=False).hexdigest()[:8] + + +def _normalize_prompt_cache_part(value: Any) -> str: + text = str(value or "").strip().lower() + if not text: + return "none" + normalized_chars: list[str] = [] + for char in text: + if char.isalnum() or char in {"-", "_", ":"}: + normalized_chars.append(char) + else: + normalized_chars.append("_") + normalized = "".join(normalized_chars).strip("_") + return normalized or "none" + + +def _build_scope_prompt_cache_part() -> str: + # prompt_cache_key 按会话 scope 隔离,避免群/私聊上下文串缓存 + ctx = RequestContext.current() + if ctx is None: + return "scope:global" + if ctx.group_id is not None: + return f"group:{int(ctx.group_id)}" + if ctx.user_id is not None: + return f"private:{int(ctx.user_id)}" + if ctx.sender_id is not None: + return f"sender:{int(ctx.sender_id)}" + request_type = _normalize_prompt_cache_part(ctx.request_type) + return f"type:{request_type}" + + +def _build_default_prompt_cache_key(model_config: ModelConfig, call_type: str) -> str: + model_name = _normalize_prompt_cache_part(getattr(model_config, "model_name", "")) + scope_part = _build_scope_prompt_cache_part() + call_part = _normalize_prompt_cache_part(call_type) + key = f"pc:{model_name}:{call_part}:{scope_part}" + if len(key) <= _PROMPT_CACHE_KEY_MAX_LEN: + return key + suffix = "_" + _hash8(key) + prefix_len = max(1, _PROMPT_CACHE_KEY_MAX_LEN - len(suffix)) + return key[:prefix_len] + suffix + + +def _responses_should_fallback_to_stateless_replay( + exc: APIStatusError, + request_body: dict[str, Any], + *, + stateless_replay: bool, +) -> bool: + # 仅当续轮携带 function_call_output 且服务端报 call_id 不匹配时才降级 + if stateless_replay or not request_body.get("previous_response_id"): + return False + input_items = request_body.get("input") + if not isinstance(input_items, list) or not any( + isinstance(item, dict) and item.get("type") == "function_call_output" + for item in input_items + ): + return False + if exc.status_code != 400 or not isinstance(exc.body, dict): + return False + error = exc.body.get("error") + if not isinstance(error, dict): + return False + message = str(error.get("message", "")).strip() + param = str(error.get("param", "")).strip().lower() + return param == "input" and bool( + _RESPONSES_MISSING_TOOL_CALL_OUTPUT_RE.search(message) + ) + + +def _normalize_openai_base_url( + api_url: str, +) -> tuple[str, dict[str, object] | None, bool]: + """将旧式 /chat/completions URL 归一化为 OpenAI SDK 需要的 base_url。 + + 兼容策略(B):如果发现 api_url 末尾包含 /chat/completions,则自动裁剪为 base_url, + 以便统一走 OpenAI SDK,并给出弃用警告。 + """ + try: + parts = urlsplit(api_url) + except Exception: + return api_url, None, False + + path = parts.path or "" + trimmed_path = path.rstrip("/") + suffix = "/chat/completions" + if not trimmed_path.endswith(suffix): + return api_url, None, False + + new_path = trimmed_path[: -len(suffix)] + default_query: dict[str, object] | None = None + if parts.query: + default_query = { + k: v for k, v in parse_qsl(parts.query, keep_blank_values=True) + } + normalized = urlunsplit(parts._replace(path=new_path, query="", fragment="")) + return normalized, default_query, True + + +def _warn_ignored_request_params( + *, + call_type: str, + model_name: str, + ignored: dict[str, Any], +) -> None: + if not ignored: + return + logger.warning( + "[request_params] ignored_keys=%s type=%s model=%s", + ",".join(sorted(ignored)), + call_type, + model_name, + ) + + +def _build_effective_request_kwargs( + model_config: ModelConfig, + *, + call_type: str, + overrides: dict[str, Any], +) -> dict[str, Any]: + merged = merge_request_params( + getattr(model_config, "request_params", {}), + overrides, + ) + thinking_override = overrides["thinking"] if "thinking" in overrides else None + has_thinking_override = "thinking" in overrides + reserved_fields = ( + _RESPONSES_RESERVED_FIELDS + if get_api_mode(model_config) == API_MODE_RESPONSES + else _CHAT_COMPLETIONS_RESERVED_FIELDS + ) + allowed, ignored = split_reserved_request_params( + merged, + reserved_fields, + ) + if has_thinking_override: + ignored.pop("thinking", None) + _warn_ignored_request_params( + call_type=call_type, + model_name=model_config.model_name, + ignored=ignored, + ) + if has_thinking_override: + allowed["thinking"] = thinking_override + return allowed + + +class ModelRequester: + """统一的模型请求封装。""" + + def __init__( + self, + http_client: httpx.AsyncClient, + token_usage_storage: TokenUsageStorage, + ) -> None: + self._http_client = http_client + self._token_usage_storage = token_usage_storage + self._openai_clients: dict[ + tuple[str, str, tuple[tuple[str, str], ...] | None], AsyncOpenAI + ] = {} + self._token_counters: dict[str, TokenCounter] = {} + self._warned_legacy_api_urls: set[str] = set() + self._background_tasks: set[asyncio.Task[Any]] = set() + self._retrieval_requester = RetrievalRequester( + get_openai_client=self._get_openai_client_for_model, + response_to_dict=self._response_to_dict, + get_token_counter=self._get_token_counter, + record_usage=self._record_usage, + ) + + async def request( + self, + model_config: ModelConfig, + messages: list[dict[str, Any]], + max_tokens: int = 8192, + call_type: str = "chat", + tools: list[dict[str, Any]] | None = None, + tool_choice: str = "auto", + transport_state: dict[str, Any] | None = None, + message_count_for_transport: int | None = None, + **kwargs: Any, + ) -> dict[str, Any]: + """发送请求到模型 API。""" + start_time = time.perf_counter() + cot_compat = getattr(model_config, "thinking_tool_call_compat", False) + reasoning_replay = bool( + getattr(model_config, "reasoning_content_replay", False) + ) + api_mode = get_api_mode(model_config) + transport_message_count = ( + message_count_for_transport + if message_count_for_transport is not None + else len(messages) + ) + messages_for_api, tool_args_fixed = sanitize_openai_messages_tool_arguments( + messages + ) + if tool_args_fixed and logger.isEnabledFor(logging.INFO): + logger.info( + "[messages.sanitize] tool_args_fixed=%s messages=%s", + tool_args_fixed, + len(messages_for_api), + ) + if api_mode == API_MODE_CHAT_COMPLETIONS: + ( + messages_for_api, + stripped_message_count, + stripped_message_fields, + ) = sanitize_chat_completion_messages( + messages_for_api, + preserve_reasoning_content=reasoning_replay, + ) + if bool(getattr(model_config, "system_prompt_as_user", False)): + messages_for_api = relocate_system_to_first_user(messages_for_api) + if stripped_message_count and logger.isEnabledFor(logging.INFO): + details = ",".join( + f"{key}={value}" + for key, value in sorted(stripped_message_fields.items()) + ) + logger.info( + "[chat_completions.standardize] stripped_internal_message_fields=%s messages=%s", + details, + stripped_message_count, + ) + + tools_for_api = tools + api_to_internal: dict[str, str] = {} + internal_to_api: dict[str, str] = {} + if isinstance(tools_for_api, list): + request_for_sanitize = { + "messages": messages_for_api, + "tools": list(tools_for_api), + } + api_to_internal, internal_to_api = sanitize_openai_tool_names_in_request( + request_for_sanitize + ) + raw_messages = request_for_sanitize.get("messages") + if isinstance(raw_messages, list): + messages_for_api = raw_messages + raw_tools = request_for_sanitize.get("tools") + if isinstance(raw_tools, list): + tools_for_api = raw_tools + + if isinstance(tools_for_api, list): + sanitized_tools, changed_count, changes = sanitize_openai_tools( + tools_for_api + ) + tools_for_api = sanitized_tools + if changed_count and logger.isEnabledFor(logging.INFO): + logger.info( + "[tools.sanitize] changed=%s total=%s truncate_enabled=%s max_desc_len=%s", + changed_count, + len(sanitized_tools), + tools_description_truncate_enabled(), + tools_description_max_len(), + ) + if tools_sanitize_verbose(): + for change in changes: + logger.info( + "[tools.sanitize.item] index=%s name=%s reasons=%s old_len=%s new_len=%s old=%s new=%s", + change.get("index"), + change.get("name"), + ",".join(change.get("reasons", [])), + change.get("old_len"), + change.get("new_len"), + change.get("old_preview"), + change.get("new_preview"), + ) + + effective_kwargs = _build_effective_request_kwargs( + model_config, + call_type=call_type, + overrides=dict(kwargs), + ) + if bool( + getattr(model_config, "prompt_cache_enabled", True) + # ) and not effective_kwargs.get("prompt_cache_key"): + ) and not effective_kwargs.get("prompt_cache_key"): + effective_kwargs["prompt_cache_key"] = _build_default_prompt_cache_key( + model_config, + call_type, + ) + responses_stateless_replay = bool( + getattr(model_config, "responses_force_stateless_replay", False) + ) or bool( + isinstance(transport_state, dict) + and transport_state.get("stateless_replay") + ) + effective_transport_state: dict[str, Any] | None + if responses_stateless_replay: + effective_transport_state = dict(transport_state or {}) + effective_transport_state["stateless_replay"] = True + else: + effective_transport_state = transport_state + request_body = build_request_body( + model_config=model_config, + messages=messages_for_api, + max_tokens=max_tokens, + tools=tools_for_api, + tool_choice=tool_choice, + internal_to_api=internal_to_api, + transport_state=effective_transport_state, + **effective_kwargs, + ) + + try: + if cot_compat and logger.isEnabledFor(logging.DEBUG): + logger.debug( + "[思维链兼容] enabled=%s type=%s model=%s api_mode=%s thinking_enabled=%s tools=%s messages=%s", + cot_compat, + call_type, + model_config.model_name, + api_mode, + getattr(model_config, "thinking_enabled", False), + bool(tools), + len(messages), + ) + + if logger.isEnabledFor(logging.DEBUG): + logger.debug( + "[API请求] type=%s model=%s api_mode=%s url=%s max_tokens=%s tools=%s tool_choice=%s messages=%s", + call_type, + model_config.model_name, + api_mode, + model_config.api_url, + max_tokens, + bool(tools_for_api), + tool_choice, + len(messages), + ) + log_debug_json(logger, "[API请求体]", request_body) + + try: + raw_result = await self._request_with_openai(model_config, request_body) + except APIStatusError as exc: + # Responses 续轮失败:自动切换 stateless replay 重发全量 input + if ( + api_mode == API_MODE_RESPONSES + and _responses_should_fallback_to_stateless_replay( + exc, + request_body, + stateless_replay=responses_stateless_replay, + ) + ): + logger.warning( + "[responses.compat] previous_response_id 续轮失败,自动降级为 stateless replay: model=%s call_type=%s previous_response_id=%s", + model_config.model_name, + call_type, + request_body.get("previous_response_id", ""), + ) + effective_transport_state = dict(effective_transport_state or {}) + effective_transport_state["stateless_replay"] = True + responses_stateless_replay = True + request_body = build_request_body( + model_config=model_config, + messages=messages_for_api, + max_tokens=max_tokens, + tools=tools_for_api, + tool_choice=tool_choice, + internal_to_api=internal_to_api, + transport_state=effective_transport_state, + **effective_kwargs, + ) + if logger.isEnabledFor(logging.DEBUG): + log_debug_json( + logger, "[API请求体][stateless replay]", request_body + ) + raw_result = await self._request_with_openai( + model_config, request_body + ) + else: + raise + if api_mode == API_MODE_RESPONSES: + result = normalize_responses_result( + raw_result, + api_to_internal if api_to_internal else None, + ) + response_id = str( + raw_result.get("id") or result.get("id") or "" + ).strip() + if response_id: + choice = result.get("choices", [{}])[0] + message = ( + choice.get("message", {}) if isinstance(choice, dict) else {} + ) + tool_calls = ( + message.get("tool_calls", []) + if isinstance(message, dict) + else [] + ) + # 记录续轮锚点:下一轮只发送 tool_result 及之后的消息 + result["_transport_state"] = { + "api_mode": api_mode, + "previous_response_id": response_id, + "tool_result_start_index": transport_message_count + + (1 if tool_calls else 0), + } + if responses_stateless_replay: + result["_transport_state"]["stateless_replay"] = True + else: + result = self._normalize_result(raw_result) + if api_to_internal: + result["_tool_name_map"] = { + "api_to_internal": api_to_internal, + "internal_to_api": internal_to_api, + "dot_delimiter": _tool_name_dot_delimiter(), + } + duration = time.perf_counter() - start_time + + usage = result.get("usage", {}) or {} + prompt_tokens = int(usage.get("prompt_tokens", 0) or 0) + completion_tokens = int(usage.get("completion_tokens", 0) or 0) + total_tokens = int(usage.get("total_tokens", 0) or 0) + if total_tokens == 0 and (prompt_tokens or completion_tokens): + total_tokens = prompt_tokens + completion_tokens + if total_tokens == 0: + prompt_tokens, completion_tokens, total_tokens = self._estimate_usage( + model_config.model_name, messages_for_api, result + ) + + logger.info( + f"[API响应] {call_type} 完成: 耗时={duration:.2f}s, " + f"Tokens={total_tokens} (P:{prompt_tokens} + C:{completion_tokens}), " + f"模型={model_config.model_name}" + ) + + if logger.isEnabledFor(logging.DEBUG): + log_debug_json(logger, "[API响应体]", result) + + self._maybe_log_thinking(result, call_type, model_config.model_name) + + self._record_usage( + model_name=model_config.model_name, + prompt_tokens=prompt_tokens, + completion_tokens=completion_tokens, + total_tokens=total_tokens, + duration_seconds=duration, + call_type=call_type, + ) + + return result + except APIStatusError as exc: + response = exc.response + try: + body = ( + json.dumps(exc.body, ensure_ascii=False, default=str) + if exc.body is not None + else "" + ) + except Exception: + body = str(exc.body) + if ( + exc.status_code == 400 + and isinstance(exc.body, dict) + and isinstance(exc.body.get("error"), dict) + ): + param = exc.body.get("error", {}).get("param") + if isinstance(param, str): + match = _TOOLS_PARAM_INDEX_RE.search(param) + if match and isinstance(request_body.get("tools"), list): + try: + idx = int(match.group(1)) + except ValueError: + idx = -1 + if 0 <= idx < len(request_body["tools"]): + tool = request_body["tools"][idx] + tool_name = ( + tool.get("function", {}).get("name") + if isinstance(tool, dict) + else "" + ) + desc_len: int | None = None + desc_preview_text = "" + if isinstance(tool, dict): + function = tool.get("function", {}) + if isinstance(function, dict): + desc = function.get("description") + if desc is not None: + desc_str = ( + desc if isinstance(desc, str) else str(desc) + ) + desc_len = len(desc_str) + desc_preview_text = desc_preview(desc_str) + logger.error( + "[tools.invalid] index=%s name=%s desc_len=%s desc=%s param=%s", + idx, + tool_name, + desc_len, + desc_preview_text, + param, + ) + logger.error( + "[API响应错误] status=%s request_id=%s url=%s body=%s", + exc.status_code, + exc.request_id or "", + response.request.url, + redact_string(body), + ) + raise + except (APIConnectionError, APITimeoutError) as exc: + logger.error("[API连接错误] type=%s message=%s", type(exc).__name__, exc) + raise + except Exception as exc: + logger.exception(f"[model.request.error] {call_type} 调用失败: {exc}") + raise + + def _thinking_logging_enabled(self) -> bool: + runtime_config = _get_runtime_config() + if runtime_config is None: + return True + return bool(runtime_config.log_thinking) + + def _maybe_log_thinking( + self, result: dict[str, Any], call_type: str, model_name: str + ) -> None: + if not self._thinking_logging_enabled(): + return + thinking = extract_thinking_content(result) + if thinking: + logger.info( + "[思维链] type=%s model=%s content=%s", + call_type, + model_name, + redact_string(thinking), + ) + + async def _request_with_openai( + self, model_config: ModelConfig, request_body: dict[str, Any] + ) -> dict[str, Any]: + client = self._get_openai_client_for_model(model_config) + if bool(getattr(model_config, "stream_enabled", False)): + try: + return await self._request_with_openai_streaming( + # client, model_config, request_body + client, + model_config, + request_body, + ) + except Exception as exc: + # 上游不支持流式时,剥离 stream 字段后降级为非流式重试 + if not should_fallback_from_stream(exc): + raise + logger.warning( + "[API流式回退] model=%s api_mode=%s reason=%s", + getattr(model_config, "model_name", ""), + get_api_mode(model_config), + type(exc).__name__, + ) + request_body = without_stream_request_fields(request_body) + if get_api_mode(model_config) == API_MODE_RESPONSES: + params, extra_body = split_responses_params(request_body) + if extra_body: + params["extra_body"] = extra_body + response = await client.responses.create(**params) + return self._response_to_dict(response) + params, extra_body = split_chat_completion_params(request_body) + if extra_body: + params["extra_body"] = extra_body + response = await client.chat.completions.create(**params) + return self._response_to_dict(response) + + async def _request_with_openai_streaming( + self, + client: AsyncOpenAI, + model_config: ModelConfig, + request_body: dict[str, Any], + ) -> dict[str, Any]: + api_mode = get_api_mode(model_config) + stream_body = dict(request_body) + stream_body["stream"] = True + if api_mode == API_MODE_RESPONSES: + return await self._stream_responses_request(client, stream_body) + ensure_chat_stream_usage_options(stream_body) + return await self._stream_chat_completions_request( + # client, stream_body, model_config + client, + stream_body, + model_config, + ) + + async def _stream_chat_completions_request( + self, + client: AsyncOpenAI, + request_body: dict[str, Any], + model_config: ModelConfig, + ) -> dict[str, Any]: + params, extra_body = split_chat_completion_params(request_body) + if extra_body: + params["extra_body"] = extra_body + response = await client.chat.completions.create(**params) + + reasoning_replay = bool( + getattr(model_config, "reasoning_content_replay", False) + ) + chunks: list[dict[str, Any]] = [] + async for chunk in response: + chunks.append(self._response_to_dict(chunk)) + return aggregate_chat_completions_stream( + chunks, + reasoning_replay=reasoning_replay, + ) + + async def _stream_responses_request( + self, client: AsyncOpenAI, request_body: dict[str, Any] + ) -> dict[str, Any]: + params, extra_body = split_responses_params(request_body) + if extra_body: + params["extra_body"] = extra_body + stream = await client.responses.create(**params) + + events: list[dict[str, Any]] = [] + async for event in stream: + events.append(self._response_to_dict(event)) + return aggregate_responses_stream(events) + + async def embed( + self, + model_config: EmbeddingModelConfig, + texts: list[str], + ) -> list[list[float]]: + """调用统一检索请求层的 embeddings。""" + return await self._retrieval_requester.embed(model_config, texts) + + async def rerank( + self, + model_config: RerankModelConfig, + query: str, + documents: list[str], + top_n: int | None = None, + ) -> list[dict[str, Any]]: + """调用统一检索请求层的 rerank。""" + return await self._retrieval_requester.rerank( + model_config=model_config, + query=query, + documents=documents, + top_n=top_n, + ) + + def _get_openai_client_for_model(self, model_config: ModelConfig) -> AsyncOpenAI: + base_url, default_query, changed = _normalize_openai_base_url( + model_config.api_url + ) + if changed and model_config.api_url not in self._warned_legacy_api_urls: + self._warned_legacy_api_urls.add(model_config.api_url) + logger.warning( + "[配置弃用] 检测到 *_MODEL_API_URL 末尾包含 /chat/completions,这种写法已弃用;" + "已自动裁剪为 base_url=%s(原值=%s)。", + base_url, + model_config.api_url, + ) + return self._get_openai_client( + base_url=base_url, + api_key=model_config.api_key, + default_query=default_query, + ) + + def _record_usage( + self, + *, + model_name: str, + prompt_tokens: int, + completion_tokens: int, + total_tokens: int, + duration_seconds: float, + call_type: str, + ) -> None: + task = asyncio.create_task( + self._token_usage_storage.record( + TokenUsage( + timestamp=datetime.now().isoformat(), + model_name=model_name, + prompt_tokens=prompt_tokens, + completion_tokens=completion_tokens, + total_tokens=total_tokens, + duration_seconds=duration_seconds, + call_type=call_type, + success=True, + ) + ) + ) + self._background_tasks.add(task) + task.add_done_callback(self._background_tasks.discard) + + def _get_openai_client( + self, base_url: str, api_key: str, default_query: dict[str, object] | None + ) -> AsyncOpenAI: + query_key = None + if default_query: + query_key = tuple( + sorted((str(k), str(v)) for k, v in default_query.items()) + ) + cache_key = (base_url, api_key, query_key) + client = self._openai_clients.get(cache_key) + if client is not None: + return client + # 复用上层注入的 httpx client(连接池/超时等),避免每个 OpenAI client 自建连接池。 + client = AsyncOpenAI( + api_key=api_key, + base_url=base_url, + timeout=480.0, + default_query=default_query, + http_client=self._http_client, + ) + self._openai_clients[cache_key] = client + return client + + def _response_to_dict(self, response: Any) -> dict[str, Any]: + if isinstance(response, dict): + return response + for attr in ("model_dump", "to_dict", "dict"): + method = getattr(response, attr, None) + if callable(method): + try: + value = method() + if isinstance(value, dict): + return value + except Exception: + continue + to_json = getattr(response, "to_json", None) + if callable(to_json): + try: + raw_json = to_json() + loaded = json.loads(str(raw_json)) + if isinstance(loaded, dict): + return loaded + except Exception: + pass + return {"data": str(response)} + + def _normalize_result(self, result: dict[str, Any]) -> dict[str, Any]: + choices = result.get("choices") + if isinstance(choices, list): + return result + data = result.get("data") + if isinstance(data, dict): + data_choices = data.get("choices") + if isinstance(data_choices, list): + normalized = dict(result) + normalized["choices"] = data_choices + return normalized + normalized = dict(result) + normalized["choices"] = [{}] + return normalized + + def _get_token_counter(self, model_name: str) -> TokenCounter: + counter = self._token_counters.get(model_name) + if counter is None: + counter = TokenCounter(model_name) + self._token_counters[model_name] = counter + return counter + + def _estimate_usage( + self, + model_name: str, + messages: list[dict[str, Any]], + result: dict[str, Any], + ) -> tuple[int, int, int]: + counter = self._get_token_counter(model_name) + try: + prompt_text = "\n".join( + json.dumps(message, ensure_ascii=False, default=str) + for message in messages + ) + except Exception: + prompt_text = str(messages) + prompt_tokens = counter.count(prompt_text) + + completion_text = "" + try: + completion_text = extract_choices_content(result) + except Exception: + completion_text = "" + if not completion_text: + choices = result.get("choices") + if isinstance(choices, list) and choices: + choice = choices[0] + if isinstance(choice, dict): + message = choice.get("message", {}) + tool_calls = ( + message.get("tool_calls") + if isinstance(message, dict) + else choice.get("tool_calls") + ) + # 无 tool_calls 与有 tool_calls 走不同分支 + if tool_calls: + try: + completion_text = json.dumps( + tool_calls, ensure_ascii=False, default=str + ) + except Exception: + completion_text = str(tool_calls) + completion_tokens = counter.count(completion_text) if completion_text else 0 + total_tokens = prompt_tokens + completion_tokens + logger.debug( + "[API响应] usage 缺失,估算 tokens: prompt=%s completion=%s total=%s", + prompt_tokens, + completion_tokens, + total_tokens, + ) + return prompt_tokens, completion_tokens, total_tokens + + +def build_request_body( + model_config: ModelConfig, + messages: list[dict[str, Any]], + max_tokens: int, + tools: list[dict[str, Any]] | None = None, + tool_choice: str = "auto", + internal_to_api: dict[str, str] | None = None, + transport_state: dict[str, Any] | None = None, + **kwargs: Any, +) -> dict[str, Any]: + """构建 API 请求体。""" + api_mode = get_api_mode(model_config) + extra_kwargs: dict[str, Any] = dict(kwargs) + + if "thinking" in extra_kwargs: + normalized = normalize_thinking_override( + extra_kwargs.get("thinking"), model_config + ) + if normalized is None: + extra_kwargs.pop("thinking", None) + else: + extra_kwargs["thinking"] = normalized + + if api_mode == API_MODE_RESPONSES: + extra_kwargs.pop("reasoning", None) + extra_kwargs.pop("reasoning_effort", None) + extra_kwargs.pop("output_config", None) + return build_responses_request_body( + model_config, + messages, + max_tokens, + tools=tools, + tool_choice=tool_choice, + extra_kwargs=extra_kwargs, + internal_to_api=internal_to_api or {}, + transport_state=transport_state, + ) + + body: dict[str, Any] = { + "model": model_config.model_name, + "messages": prepare_chat_completion_messages(model_config, messages), + "max_tokens": max_tokens, + } + + extra_kwargs.pop("reasoning", None) + extra_kwargs.pop("reasoning_effort", None) + extra_kwargs.pop("output_config", None) + + thinking = get_thinking_payload(model_config) + if thinking is not None: + body["thinking"] = thinking + + effort_payload = get_effort_payload(model_config) + if effort_payload is not None: + style = get_effort_style(model_config) + # Anthropic 风格走 output_config,OpenAI 风格走 reasoning_effort + if style == "anthropic": + body["output_config"] = effort_payload + else: + body["reasoning_effort"] = effort_payload["effort"] + + if tools: + body["tools"] = tools + thinking_active = "thinking" in body + # 部分 thinking 模型不接受 dict 形 tool_choice,强制降为 auto + if thinking_active and isinstance(tool_choice, dict): + body["tool_choice"] = "auto" + else: + body["tool_choice"] = tool_choice + + body.update(extra_kwargs) + return body diff --git a/src/Undefined/ai/llm/sanitize.py b/src/Undefined/ai/llm/sanitize.py new file mode 100644 index 00000000..ec549b1c --- /dev/null +++ b/src/Undefined/ai/llm/sanitize.py @@ -0,0 +1,558 @@ +"""LLM 出站请求清洗与工具名规范化。 + +负责工具 schema/description 清洗、历史消息字段剥离、工具名 API 编码; +不发起 HTTP 请求,也不解析模型响应。 +""" + +from __future__ import annotations + +import hashlib +import logging +import re +from typing import Any + +from Undefined.ai.llm.types import ModelConfig +from Undefined.config import Config, get_config +from Undefined.utils.tool_calls import normalize_tool_arguments_json + +logger = logging.getLogger(__name__) + +_DEFAULT_TOOLS_DESCRIPTION_MAX_LEN = 1024 +_DEFAULT_TOOLS_DESCRIPTION_PREVIEW_LEN = 160 + +_DEFAULT_TOOL_NAME_DOT_DELIMITER = "-_-" +_TOOL_NAME_MAX_LEN = 64 +_TOOL_NAME_ALLOWED_RE = re.compile(r"^[a-zA-Z0-9_-]+$") + +_CHAT_COMPLETION_STRIP_THINKING_KEYS: frozenset[str] = frozenset( + ("thinking", "reasoning", "chain_of_thought", "cot", "thoughts") +) +CHAT_COMPLETION_INTERNAL_MESSAGE_KEYS: frozenset[str] = frozenset( + ( + "reasoning_content", + *_CHAT_COMPLETION_STRIP_THINKING_KEYS, + "_responses_output_items", + "phase", + ) +) + + +def _get_runtime_config() -> Config | None: + try: + return get_config(strict=False) + except Exception: + return None + + +def _tool_name_dot_delimiter() -> str: + runtime_config = _get_runtime_config() + value = ( + getattr(runtime_config, "tools_dot_delimiter", None) if runtime_config else None + ) + text = str(value).strip() if value is not None else _DEFAULT_TOOL_NAME_DOT_DELIMITER + if not text: + return _DEFAULT_TOOL_NAME_DOT_DELIMITER + if "." in text: + return _DEFAULT_TOOL_NAME_DOT_DELIMITER + if not _TOOL_NAME_ALLOWED_RE.match(text): + return _DEFAULT_TOOL_NAME_DOT_DELIMITER + # 保持较短长度,避免工具名被服务端截断。 + if len(text) > 16: + return text[:16] + return text + + +def _hash8(text: str) -> str: + return hashlib.sha1(text.encode("utf-8"), usedforsecurity=False).hexdigest()[:8] + + +def _encode_tool_name_for_api(tool_name: str) -> str: + """将内部工具名编码为服务端可接受的 function.name。 + + - 将 '.' 替换为 '-_-'(保留工具集命名语义) + - 其他不允许字符替换为 '_' + - 强制最大长度(<=64),超长时追加稳定哈希 + """ + raw = str(tool_name or "").strip() + if not raw: + return "tool" + + # 保留工具集分隔语义:category.tool -> categorytool + encoded = raw.replace(".", _tool_name_dot_delimiter()) + + # 替换其他不允许字符。 + cleaned_chars: list[str] = [] + for ch in encoded: + if ch.isalnum() or ch in {"_", "-"}: + cleaned_chars.append(ch) + else: + cleaned_chars.append("_") + encoded = "".join(cleaned_chars) + + if not encoded: + encoded = "tool" + + if len(encoded) > _TOOL_NAME_MAX_LEN: + suffix = "_" + _hash8(raw) + prefix_len = max(1, _TOOL_NAME_MAX_LEN - len(suffix)) + encoded = encoded[:prefix_len] + suffix + + # 最后兜底校验(理论上应始终通过) + if not _TOOL_NAME_ALLOWED_RE.match(encoded): + suffix = "_" + _hash8(raw) + encoded = re.sub(r"[^a-zA-Z0-9_-]", "_", encoded) + if len(encoded) > _TOOL_NAME_MAX_LEN: + encoded = encoded[: _TOOL_NAME_MAX_LEN - len(suffix)] + suffix + if not encoded: + encoded = "tool" + suffix + + return encoded + + +def sanitize_openai_tool_names_in_request( + request_body: dict[str, Any], +) -> tuple[dict[str, str], dict[str, str]]: + """将 request_body 的 tools/messages 工具名改写为服务端可接受的名称。 + + Returns: + (api_to_internal, internal_to_api) 映射表。 + + Notes: + - 仅保证 tools schema 中出现的名称可逆映射。 + - 历史消息中的工具调用会尽力重写。 + """ + tools = request_body.get("tools") + if not isinstance(tools, list) or not tools: + return {}, {} + + internal_to_api: dict[str, str] = {} + api_to_internal: dict[str, str] = {} + used_api: set[str] = set() + + new_tools: list[dict[str, Any]] = [] + for tool in tools: + if not isinstance(tool, dict): + new_tools.append(tool) + continue + function = tool.get("function") + if not isinstance(function, dict): + new_tools.append(tool) + continue + internal_name = str(function.get("name", "") or "") + if not internal_name: + new_tools.append(tool) + continue + + # 稳定编码;如发生冲突则追加后缀。 + base_api_name = _encode_tool_name_for_api(internal_name) + api_name = base_api_name + if api_name in used_api and api_to_internal.get(api_name) != internal_name: + suffix = "_" + _hash8(internal_name) + prefix_len = max(1, _TOOL_NAME_MAX_LEN - len(suffix)) + api_name = base_api_name[:prefix_len] + suffix + if api_name in used_api and api_to_internal.get(api_name) != internal_name: + # 极少数冲突兜底:加入索引避免重复。 + suffix = "_" + _hash8(f"{internal_name}:{len(used_api)}") + prefix_len = max(1, _TOOL_NAME_MAX_LEN - len(suffix)) + api_name = base_api_name[:prefix_len] + suffix + + used_api.add(api_name) + internal_to_api[internal_name] = api_name + api_to_internal[api_name] = internal_name + + if api_name != internal_name: + tool = dict(tool) + function = dict(function) + function["name"] = api_name + tool["function"] = function + new_tools.append(tool) + + request_body["tools"] = new_tools + + # 尽力重写历史消息中的工具名。 + messages = request_body.get("messages") + if isinstance(messages, list) and messages: + new_messages: list[dict[str, Any]] = [] + changed = False + for message in messages: + if not isinstance(message, dict): + new_messages.append(message) + continue + + new_message = message + + msg_name = message.get("name") + if isinstance(msg_name, str) and msg_name: + mapped = internal_to_api.get(msg_name) + if mapped and mapped != msg_name: + if new_message is message: + new_message = dict(message) + new_message["name"] = mapped + changed = True + elif (not _TOOL_NAME_ALLOWED_RE.match(msg_name)) or ( + len(msg_name) > _TOOL_NAME_MAX_LEN + ): + # 即便名称不在 schema 映射中,也尽量保证请求合法(如工具被重命名/移除)。 + safe = _encode_tool_name_for_api(msg_name) + if safe != msg_name: + if new_message is message: + new_message = dict(message) + new_message["name"] = safe + changed = True + + tool_calls = message.get("tool_calls") + # 无 tool_calls 与有 tool_calls 走不同分支 + if isinstance(tool_calls, list) and tool_calls: + new_tool_calls: list[Any] = [] + tool_calls_changed = False + # 逐个处理模型返回的 tool_call + for tool_call in tool_calls: + if not isinstance(tool_call, dict): + new_tool_calls.append(tool_call) + continue + function = tool_call.get("function") + if not isinstance(function, dict): + new_tool_calls.append(tool_call) + continue + fname = function.get("name") + if not isinstance(fname, str) or not fname: + new_tool_calls.append(tool_call) + continue + mapped = internal_to_api.get(fname) + safe_name = mapped or _encode_tool_name_for_api(fname) + if safe_name != fname: + tool_calls_changed = True + new_tool_call = dict(tool_call) + new_function = dict(function) + new_function["name"] = safe_name + new_tool_call["function"] = new_function + new_tool_calls.append(new_tool_call) + else: + new_tool_calls.append(tool_call) + + # 无 tool_calls 与有 tool_calls 走不同分支 + if tool_calls_changed: + if new_message is message: + new_message = dict(message) + new_message["tool_calls"] = new_tool_calls + changed = True + + new_messages.append(new_message) + + if changed: + request_body["messages"] = new_messages + + return api_to_internal, internal_to_api + + +def _tools_sanitize_enabled() -> bool: + # 历史配置项 tools.sanitize 已迁移为 tools.dot_delimiter。 + # 为兼容严格网关,description 的 schema 清洗默认始终开启。 + return True + + +def tools_sanitize_verbose() -> bool: + """是否输出工具 schema 清洗的详细日志。""" + runtime_config = _get_runtime_config() + if runtime_config is not None: + return bool(runtime_config.tools_sanitize_verbose) + return False + + +def tools_description_max_len() -> int: + """返回工具 description 允许的最大长度。""" + runtime_config = _get_runtime_config() + if runtime_config is None: + return _DEFAULT_TOOLS_DESCRIPTION_MAX_LEN + value = runtime_config.tools_description_max_len + return value if value > 0 else _DEFAULT_TOOLS_DESCRIPTION_MAX_LEN + + +def tools_description_truncate_enabled() -> bool: + """是否启用工具 description 截断。""" + runtime_config = _get_runtime_config() + if runtime_config is None: + return False + return bool(runtime_config.tools_description_truncate_enabled) + + +def _clean_control_chars(text: str) -> str: + """将 ASCII 控制字符替换为空格。""" + return "".join(" " if ord(ch) < 32 or ord(ch) == 127 else ch for ch in text) + + +def desc_preview(text: str) -> str: + """生成工具 description 的日志预览片段。""" + runtime_config = _get_runtime_config() + if runtime_config is None: + preview_len = _DEFAULT_TOOLS_DESCRIPTION_PREVIEW_LEN + else: + preview_len = runtime_config.tools_description_preview_len + if preview_len <= 0: + preview_len = _DEFAULT_TOOLS_DESCRIPTION_PREVIEW_LEN + return text[:preview_len] + ("…" if len(text) > preview_len else "") + + +def _normalize_tool_description( + description: Any, + tool_name: str, + max_len: int, + truncate_enabled: bool, +) -> str: + """规范化工具 function.description,适配更严格的 OpenAI 兼容服务。""" + if description is None: + normalized = "" + elif isinstance(description, str): + normalized = description + else: + normalized = str(description) + + normalized = _clean_control_chars(normalized) + normalized = " ".join(normalized.split()) + normalized = normalized.strip() + if not normalized: + normalized = f"Tool function {tool_name}" + if truncate_enabled and len(normalized) > max_len: + normalized = normalized[:max_len].rstrip() + return normalized + + +def sanitize_openai_tools( + tools: list[dict[str, Any]], +) -> tuple[list[dict[str, Any]], int, list[dict[str, Any]]]: + """清洗 tools schema,避免严格网关因非法 description 返回 400。""" + if not tools or not _tools_sanitize_enabled(): + return tools, 0, [] + + max_len = tools_description_max_len() + truncate_enabled = tools_description_truncate_enabled() + changed = 0 + changes: list[dict[str, Any]] = [] + sanitized: list[dict[str, Any]] = [] + for idx, tool in enumerate(tools): + if not isinstance(tool, dict): + sanitized.append(tool) + continue + function = tool.get("function") + if not isinstance(function, dict): + sanitized.append(tool) + continue + name = function.get("name", "") + old_desc = function.get("description") + old_desc_str = ( + "" + if old_desc is None + else (old_desc if isinstance(old_desc, str) else str(old_desc)) + ) + new_desc = _normalize_tool_description( + old_desc, + str(name), + max_len, + truncate_enabled, + ) + + if old_desc_str != new_desc: + reasons: list[str] = [] + if not isinstance(old_desc, str): + reasons.append("non_string") + if any(ord(ch) < 32 or ord(ch) == 127 for ch in old_desc_str): + reasons.append("control_chars") + if "\n" in old_desc_str or "\r" in old_desc_str or "\t" in old_desc_str: + reasons.append("whitespace") + if not old_desc_str.strip(): + reasons.append("empty") + if ( + truncate_enabled + and len(new_desc) >= max_len + and len(old_desc_str) > len(new_desc) + ): + reasons.append("truncated") + + tool = dict(tool) + function = dict(function) + function["description"] = new_desc + tool["function"] = function + changed += 1 + changes.append( + { + "index": idx, + "name": str(name), + "old_len": len(old_desc_str), + "new_len": len(new_desc), + "old_preview": desc_preview(_clean_control_chars(old_desc_str)), + "new_preview": desc_preview(new_desc), + "reasons": reasons, + } + ) + sanitized.append(tool) + return sanitized, changed, changes + + +def sanitize_openai_messages_tool_arguments( + messages: list[dict[str, Any]], +) -> tuple[list[dict[str, Any]], int]: + """将 messages[].tool_calls[].function.arguments 规范为严格 JSON 字符串。""" + if not messages: + return messages, 0 + + changed = 0 + sanitized_messages: list[dict[str, Any]] = [] + for message in messages: + if not isinstance(message, dict): + sanitized_messages.append(message) + continue + + tool_calls = message.get("tool_calls") + # 无 tool_calls 与有 tool_calls 走不同分支 + if not isinstance(tool_calls, list) or not tool_calls: + sanitized_messages.append(message) + continue + + tool_calls_changed = False + sanitized_tool_calls: list[Any] = [] + # 逐个处理模型返回的 tool_call + for tool_call in tool_calls: + if not isinstance(tool_call, dict): + sanitized_tool_calls.append(tool_call) + continue + function = tool_call.get("function") + if not isinstance(function, dict): + sanitized_tool_calls.append(tool_call) + continue + + old_args = function.get("arguments") + new_args = normalize_tool_arguments_json(old_args) + if isinstance(old_args, str) and old_args == new_args: + sanitized_tool_calls.append(tool_call) + continue + + tool_calls_changed = True + changed += 1 + new_tool_call = dict(tool_call) + new_function = dict(function) + new_function["arguments"] = new_args + new_tool_call["function"] = new_function + sanitized_tool_calls.append(new_tool_call) + + # 无 tool_calls 与有 tool_calls 走不同分支 + if tool_calls_changed: + new_message = dict(message) + new_message["tool_calls"] = sanitized_tool_calls + sanitized_messages.append(new_message) + else: + sanitized_messages.append(message) + + return sanitized_messages, changed + + +def sanitize_chat_completion_messages( + messages: list[dict[str, Any]], + *, + preserve_reasoning_content: bool = False, +) -> tuple[list[dict[str, Any]], int, dict[str, int]]: + """移除 Chat Completions 非标准消息字段。 + + 本地历史里允许保留 reasoning_content 等兼容字段用于日志/回放; + 发往上游时默认剥离。``preserve_reasoning_content=True`` 时保留 + ``reasoning_content`` 供多轮 CoT 续传,仍剥离其它内部字段。 + """ + if not messages: + return messages, 0, {} + + changed = 0 + stripped_fields: dict[str, int] = {} + sanitized_messages: list[dict[str, Any]] = [] + for message in messages: + if not isinstance(message, dict): + sanitized_messages.append(message) + continue + + sanitized_message = message + removed = False + for key in CHAT_COMPLETION_INTERNAL_MESSAGE_KEYS: + if preserve_reasoning_content and key == "reasoning_content": + continue + if key not in sanitized_message: + continue + if sanitized_message is message: + sanitized_message = dict(message) + sanitized_message.pop(key, None) + stripped_fields[key] = stripped_fields.get(key, 0) + 1 + removed = True + + if removed: + changed += 1 + sanitized_messages.append(sanitized_message) + + return sanitized_messages, changed, stripped_fields + + +def relocate_system_to_first_user( + messages: list[dict[str, Any]], +) -> list[dict[str, Any]]: + """将 system/developer 消息合并注入首条 user 消息(chat_completions 适配)。""" + if not messages: + return messages + + system_parts: list[str] = [] + remaining: list[dict[str, Any]] = [] + for message in messages: + if not isinstance(message, dict): + remaining.append(message) + continue + role = str(message.get("role") or "").strip().lower() + if role in ("system", "developer"): + content = message.get("content") + if content is not None: + text = content if isinstance(content, str) else str(content) + if text.strip(): + system_parts.append(text.strip()) + continue + remaining.append(message) + + if not system_parts: + return messages + + merged_system = "\n\n".join(system_parts) + first_user_idx: int | None = None + for idx, message in enumerate(remaining): + if ( + isinstance(message, dict) + and str(message.get("role") or "").strip().lower() == "user" + ): + first_user_idx = idx + break + + if first_user_idx is None: + remaining.insert(0, {"role": "user", "content": merged_system}) + return remaining + + first_user = dict(remaining[first_user_idx]) + old_content = first_user.get("content") + old_text = ( + old_content + if isinstance(old_content, str) + else (str(old_content) if old_content is not None else "") + ) + if old_text.strip(): + first_user["content"] = f"{merged_system}\n\n{old_text}" + else: + first_user["content"] = merged_system + updated = list(remaining) + updated[first_user_idx] = first_user + return updated + + +def prepare_chat_completion_messages( + model_config: ModelConfig, + messages: list[dict[str, Any]], +) -> list[dict[str, Any]]: + """按模型配置整理 Chat Completions 出站消息。""" + preserve_reasoning = bool(getattr(model_config, "reasoning_content_replay", False)) + prepared, _, _ = sanitize_chat_completion_messages( + messages, + preserve_reasoning_content=preserve_reasoning, + ) + if bool(getattr(model_config, "system_prompt_as_user", False)): + prepared = relocate_system_to_first_user(prepared) + return prepared diff --git a/src/Undefined/ai/llm/streaming.py b/src/Undefined/ai/llm/streaming.py new file mode 100644 index 00000000..50632476 --- /dev/null +++ b/src/Undefined/ai/llm/streaming.py @@ -0,0 +1,392 @@ +"""LLM 流式响应聚合与回退判定。 + +解析 SSE/chunk 事件、合并 delta 与 tool_calls,并在上游不支持流式时 +判定是否降级为非流式请求;不持有 HTTP 客户端或模型配置。 +""" + +from __future__ import annotations + +import json +from typing import Any + +from openai import APIStatusError + +from Undefined.ai.llm.thinking import stringify_thinking +from Undefined.ai.transports import API_MODE_CHAT_COMPLETIONS, API_MODE_RESPONSES + +_CHAT_COMPLETIONS_KNOWN_FIELDS: set[str] = { + "model", + "messages", + "audio", + "metadata", + "max_completion_tokens", + "max_tokens", + "modalities", + "parallel_tool_calls", + "prediction", + "prompt_cache_key", + "prompt_cache_retention", + "reasoning_effort", + "safety_identifier", + "service_tier", + "store", + "temperature", + "top_p", + "n", + "stop", + "presence_penalty", + "frequency_penalty", + "logit_bias", + "user", + "response_format", + "seed", + "stream", + "stream_options", + "tools", + "tool_choice", + "logprobs", + "top_logprobs", + "verbosity", + "web_search_options", +} + +_RESPONSES_KNOWN_FIELDS: set[str] = { + "background", + "context_management", + "conversation", + "include", + "model", + "input", + "instructions", + "max_output_tokens", + "max_tool_calls", + "metadata", + "previous_response_id", + "prompt", + "prompt_cache_key", + "prompt_cache_retention", + "reasoning", + "safety_identifier", + "service_tier", + "store", + "temperature", + "top_p", + "tools", + "tool_choice", + "parallel_tool_calls", + "stream", + "stream_options", + "text", + "truncation", + "user", +} + +_STREAM_FALLBACK_STATUS_CODES = {400, 404, 405, 422, 501} +_STREAM_FALLBACK_ERROR_MARKERS = ( + "stream", + "stream_options", + "streaming", + "not support", + "unsupported", + "unrecognized", + "unknown parameter", + "unexpected parameter", +) + + +def split_chat_completion_params( + body: dict[str, Any], +) -> tuple[dict[str, Any], dict[str, Any]]: + """将请求体拆分为 SDK 已知字段与 extra_body。""" + known: dict[str, Any] = {} + extra: dict[str, Any] = {} + for key, value in body.items(): + if key in _CHAT_COMPLETIONS_KNOWN_FIELDS: + known[key] = value + else: + extra[key] = value + return known, extra + + +def split_responses_params( + body: dict[str, Any], +) -> tuple[dict[str, Any], dict[str, Any]]: + """将 Responses 请求体拆分为 SDK 已知字段与 extra_body。""" + known: dict[str, Any] = {} + extra: dict[str, Any] = {} + for key, value in body.items(): + if key in _RESPONSES_KNOWN_FIELDS: + known[key] = value + else: + extra[key] = value + return known, extra + + +def without_stream_request_fields(body: dict[str, Any]) -> dict[str, Any]: + """移除 stream / stream_options 字段,用于流式回退。""" + stripped = dict(body) + stripped.pop("stream", None) + stripped.pop("stream_options", None) + return stripped + + +def ensure_chat_stream_usage_options(body: dict[str, Any]) -> None: + """确保 Chat Completions 流式请求携带 include_usage。""" + stream_options = body.get("stream_options") + if stream_options is None: + body["stream_options"] = {"include_usage": True} + return + if isinstance(stream_options, dict) and "include_usage" not in stream_options: + body["stream_options"] = {**stream_options, "include_usage": True} + + +def _status_error_text(exc: APIStatusError) -> str: + parts = [str(exc)] + body = getattr(exc, "body", None) + if isinstance(body, dict): + parts.append(json.dumps(body, ensure_ascii=False, default=str)) + elif body is not None: + parts.append(str(body)) + response = getattr(exc, "response", None) + if response is not None: + try: + parts.append(response.text) + except Exception: + pass + return "\n".join(part for part in parts if part).lower() + + +def should_fallback_from_stream(exc: Exception) -> bool: + """判定流式失败是否应降级为非流式重试。""" + if isinstance(exc, NotImplementedError): + return True + if not isinstance(exc, APIStatusError): + return False + # 仅对明确的 stream 参数/能力错误做回退,避免掩盖其它 4xx + if exc.status_code not in _STREAM_FALLBACK_STATUS_CODES: + return False + text = _status_error_text(exc) + # 回退到默认/主配置 + return any(marker in text for marker in _STREAM_FALLBACK_ERROR_MARKERS) + + +def stringify_stream_delta(value: Any) -> str: + """将流式 delta 字段归一化为字符串片段。""" + if value is None: + return "" + if isinstance(value, str): + return value + if isinstance(value, list): + parts = [stringify_stream_delta(item) for item in value] + return "".join(part for part in parts if part) + if isinstance(value, dict): + for key in ("text", "content", "delta", "value"): + if value.get(key) is not None: + return stringify_stream_delta(value.get(key)) + return "" + return str(value) + + +def extract_stream_response_item(event: dict[str, Any]) -> dict[str, Any] | None: + """从 Responses 流式事件中提取 output item。""" + for key in ("item", "output_item", "data"): + value = event.get(key) + if isinstance(value, dict): + return value + response = event.get("response") + if isinstance(response, dict) and isinstance(response.get("output"), list): + return None + if isinstance(response, dict): + return response + return None + + +def extract_stream_usage( + event: dict[str, Any], *, api_mode: str +) -> dict[str, Any] | None: + """从流式事件中提取 usage 统计。""" + usage = event.get("usage") + if not isinstance(usage, dict): + response = event.get("response") + if isinstance(response, dict) and isinstance(response.get("usage"), dict): + usage = response.get("usage") + if not isinstance(usage, dict): + return None + if api_mode == API_MODE_RESPONSES: + return { + "input_tokens": int(usage.get("input_tokens", 0) or 0), + "output_tokens": int(usage.get("output_tokens", 0) or 0), + "total_tokens": int(usage.get("total_tokens", 0) or 0), + } + return { + "prompt_tokens": int(usage.get("prompt_tokens", 0) or 0), + "completion_tokens": int(usage.get("completion_tokens", 0) or 0), + "total_tokens": int(usage.get("total_tokens", 0) or 0), + } + + +def ensure_tool_call_slot( + tool_calls: list[dict[str, Any]], index: int +) -> dict[str, Any]: + """确保 tool_calls 列表在指定 index 处存在槽位。""" + while len(tool_calls) <= index: + tool_calls.append( + { + "id": "", + "type": "function", + "function": {"name": "", "arguments": ""}, + } + ) + return tool_calls[index] + + +def merge_tool_call_delta( + target_tool_calls: list[dict[str, Any]], tool_delta: dict[str, Any] +) -> None: + """将单个 tool_call delta 合并进累积结果。""" + index = tool_delta.get("index") + try: + slot_index = int(index) if index is not None else len(target_tool_calls) + except (TypeError, ValueError): + slot_index = len(target_tool_calls) + tool_call = ensure_tool_call_slot(target_tool_calls, slot_index) + call_id = str(tool_delta.get("id") or "").strip() + if call_id: + tool_call["id"] = call_id + tool_type = str(tool_delta.get("type") or "").strip() + if tool_type: + tool_call["type"] = tool_type + function_delta = tool_delta.get("function") + if not isinstance(function_delta, dict): + return + function = tool_call.setdefault("function", {"name": "", "arguments": ""}) + if not isinstance(function, dict): + function = {"name": "", "arguments": ""} + tool_call["function"] = function + function_name = str(function_delta.get("name") or "").strip() + if function_name: + function["name"] = function_name + arguments_delta = function_delta.get("arguments") + if arguments_delta is not None: + # 流式 tool arguments 按 chunk 拼接,直至 JSON 完整 + function["arguments"] = str(function.get("arguments") or "") + str( + arguments_delta + ) + + +def aggregate_chat_completions_stream( + chunks: list[dict[str, Any]], + *, + reasoning_replay: bool, +) -> dict[str, Any]: + """将 Chat Completions 流式 chunk 列表聚合为完整响应 dict。""" + content_parts: list[str] = [] + reasoning_parts: list[str] = [] + tool_calls: list[dict[str, Any]] = [] + usage: dict[str, Any] | None = None + finish_reason = "stop" + role = "assistant" + + for chunk_dict in chunks: + usage = ( + extract_stream_usage(chunk_dict, api_mode=API_MODE_CHAT_COMPLETIONS) + or usage + ) + choices = chunk_dict.get("choices") + if not isinstance(choices, list): + continue + for choice in choices: + if not isinstance(choice, dict): + continue + delta = choice.get("delta") + if not isinstance(delta, dict): + continue + role_value = str(delta.get("role") or "").strip() + if role_value: + role = role_value + content_delta = stringify_stream_delta(delta.get("content")) + if content_delta: + content_parts.append(content_delta) + if reasoning_replay: + reasoning_delta = stringify_thinking(delta.get("reasoning_content")) + if reasoning_delta: + reasoning_parts.append(reasoning_delta) + raw_tool_calls = delta.get("tool_calls") + # 无 tool_calls 与有 tool_calls 走不同分支 + if isinstance(raw_tool_calls, list): + # 逐个处理模型返回的 tool_call + for tool_delta in raw_tool_calls: + if isinstance(tool_delta, dict): + merge_tool_call_delta(tool_calls, tool_delta) + current_finish_reason = str(choice.get("finish_reason") or "").strip() + if current_finish_reason: + finish_reason = current_finish_reason + + message: dict[str, Any] = { + "role": role, + "content": "".join(content_parts), + } + if reasoning_replay: + reasoning_text = "".join(reasoning_parts).strip() + if reasoning_text: + message["reasoning_content"] = reasoning_text + # 无 tool_calls 与有 tool_calls 走不同分支 + if tool_calls: + message["tool_calls"] = tool_calls + result: dict[str, Any] = { + "choices": [ + { + "index": 0, + "message": message, + "finish_reason": finish_reason, + } + ] + } + if usage is not None: + result["usage"] = usage + return result + + +def aggregate_responses_stream(events: list[dict[str, Any]]) -> dict[str, Any]: + """将 Responses 流式事件列表聚合为完整响应 dict。""" + output_items: list[dict[str, Any]] = [] + output_text_parts: list[str] = [] + usage: dict[str, Any] | None = None + final_response: dict[str, Any] | None = None + + for event_dict in events: + usage = extract_stream_usage(event_dict, api_mode=API_MODE_RESPONSES) or usage + event_type = str(event_dict.get("type") or "").strip().lower() + response = event_dict.get("response") + if isinstance(response, dict): + final_response = response + if event_type == "response.output_text.delta": + delta = stringify_stream_delta(event_dict.get("delta")) + if delta: + output_text_parts.append(delta) + continue + if event_type == "response.completed": + if isinstance(response, dict): + final_response = response + continue + item = extract_stream_response_item(event_dict) + if not isinstance(item, dict): + continue + item_type = str(item.get("type") or "").strip().lower() + if item_type in ("message", "function_call", "reasoning"): + output_items.append(item) + + if final_response is not None: + if usage is not None and not isinstance(final_response.get("usage"), dict): + final_response = dict(final_response) + final_response["usage"] = usage + return final_response + + # 未收到 completed 事件时,用增量 delta 合成最小可用响应 + synthesized: dict[str, Any] = { + "output": output_items, + "output_text": "".join(output_text_parts), + } + if usage is not None: + synthesized["usage"] = usage + return synthesized diff --git a/src/Undefined/ai/llm/thinking.py b/src/Undefined/ai/llm/thinking.py new file mode 100644 index 00000000..507e4691 --- /dev/null +++ b/src/Undefined/ai/llm/thinking.py @@ -0,0 +1,214 @@ +"""思维链(CoT)提取与 thinking 参数规范化。 + +从 Chat Completions / Responses 响应中抽取 reasoning 字段,并将配置中的 +thinking 覆盖值归一化为各上游兼容格式;不负责发送请求。 +""" + +from __future__ import annotations + +from typing import Any + +from Undefined.ai.llm.types import ModelConfig + +_THINKING_KEYS: tuple[str, ...] = ( + "thinking", + "reasoning", + "reasoning_content", + "chain_of_thought", + "cot", + "thoughts", +) + + +def _stringify_thinking_list(value: list[Any]) -> str: + """将列表类型的思维链转换为字符串。 + + Args: + value: 思维链列表 + + Returns: + 格式化后的字符串 + """ + parts = [stringify_thinking(item) for item in value] + return "\n".join([part for part in parts if part]) + + +def _stringify_thinking_dict(value: dict[str, Any]) -> str: + """将字典类型的思维链转换为字符串。 + + Args: + value: 思维链字典 + + Returns: + 格式化后的字符串 + """ + content = value.get("content") + if isinstance(content, str) and content: + return content + return str(value) + + +def stringify_thinking(value: Any) -> str: + """将思维链值转换为字符串。 + + Args: + value: 思维链值(可以是 None、字符串、列表或字典) + + Returns: + 格式化后的字符串 + """ + if value is None: + return "" + if isinstance(value, str): + return value + if isinstance(value, list): + return _stringify_thinking_list(value) + if isinstance(value, dict): + return _stringify_thinking_dict(value) + return str(value) + + +def _extract_from_message(message: dict[str, Any]) -> str: + """从 message 对象中提取思维链内容。 + + Args: + message: message 对象 + + Returns: + 思维链内容字符串 + """ + if not isinstance(message, dict): + return "" + for key in _THINKING_KEYS: + if key in message: + return stringify_thinking(message.get(key)) + return "" + + +def _extract_from_choice(choice: dict[str, Any]) -> str: + """从 choice 对象中提取思维链内容。 + + Args: + choice: choice 对象 + + Returns: + 思维链内容字符串 + """ + if not isinstance(choice, dict): + return "" + + # 优先从 message 中提取 + message = choice.get("message") + if isinstance(message, dict): + thinking = _extract_from_message(message) + if thinking: + return thinking + + # 尝试从 choice 直接提取 + for key in _THINKING_KEYS: + if key in choice: + return stringify_thinking(choice.get(key)) + + return "" + + +def _extract_from_choices(choices: list[Any]) -> str: + """从 choices 列表中提取思维链内容。 + + Args: + choices: choices 列表 + + Returns: + 思维链内容字符串 + """ + if not isinstance(choices, list) or not choices: + return "" + choice = choices[0] + return _extract_from_choice(choice) + + +def _extract_from_result(result: dict[str, Any]) -> str: + """直接从结果对象中提取思维链内容。 + + Args: + result: API 响应结果 + + Returns: + 思维链内容字符串 + """ + for key in _THINKING_KEYS: + if key in result: + return stringify_thinking(result.get(key)) + return "" + + +def extract_thinking_content(result: dict[str, Any]) -> str: + """从 API 响应中提取思维链内容。 + + 提取优先级: + 1. 从 choices[0].message 中提取 + 2. 从 choices[0] 直接提取 + 3. 从响应根对象中提取 + + Args: + result: API 响应结果 + + Returns: + 思维链内容字符串 + """ + # 尝试从 choices 中提取 + choices = result.get("choices") + if isinstance(choices, list): + thinking = _extract_from_choices(choices) + if thinking: + return thinking + + return _extract_from_result(result) + + +def _is_deepseek_provider(model_config: ModelConfig) -> bool: + model_name = str(getattr(model_config, "model_name", "") or "").lower() + if model_name.startswith("deepseek"): + return True + api_url = str(getattr(model_config, "api_url", "") or "").lower() + return "deepseek" in api_url + + +def normalize_thinking_override( + value: Any, model_config: ModelConfig +) -> dict[str, Any] | None: + """将 request 覆盖中的 thinking 值归一化为上游可接受的 dict。""" + if value is None: + return None + + is_deepseek = _is_deepseek_provider(model_config) + + if isinstance(value, dict): + raw_type = value.get("type") + if isinstance(raw_type, str): + type_value = raw_type.strip().lower() + if type_value in {"enabled", "disabled"}: + # DeepSeek 仅接受 {type: enabled|disabled},其它字段原样透传 + return {"type": type_value} if is_deepseek else dict(value) + + raw_enabled = value.get("enabled") + if isinstance(raw_enabled, bool): + type_value = "enabled" if raw_enabled else "disabled" + if is_deepseek: + return {"type": type_value} + normalized = dict(value) + normalized.pop("enabled", None) + normalized["type"] = type_value + return normalized + + return None + + if isinstance(value, bool): + return {"type": "enabled" if value else "disabled"} + + if isinstance(value, str): + type_value = value.strip().lower() + if type_value in {"enabled", "disabled"}: + return {"type": type_value} + + return None diff --git a/src/Undefined/ai/llm/types.py b/src/Undefined/ai/llm/types.py new file mode 100644 index 00000000..74315470 --- /dev/null +++ b/src/Undefined/ai/llm/types.py @@ -0,0 +1,27 @@ +"""LLM 模块共享类型别名。""" + +from __future__ import annotations + +# 联合类型:所有可发起 LLM/嵌入/重排请求的模型配置 +from Undefined.config import ( + AgentModelConfig, + ChatModelConfig, + EmbeddingModelConfig, + GrokModelConfig, + RerankModelConfig, + SecurityModelConfig, + VisionModelConfig, +) + +ModelConfig = ( + ChatModelConfig + | VisionModelConfig + | AgentModelConfig + | SecurityModelConfig + | EmbeddingModelConfig + | GrokModelConfig + | RerankModelConfig +) + +# 类型别名对外 re-export +__all__ = ["ModelConfig"] diff --git a/src/Undefined/ai/model_selector.py b/src/Undefined/ai/model_selector.py index a4d94270..74937636 100644 --- a/src/Undefined/ai/model_selector.py +++ b/src/Undefined/ai/model_selector.py @@ -38,7 +38,6 @@ def __init__( self._rr_lock = threading.Lock() self._rr_counters: dict[str, int] = {} self._preferences: dict[tuple[int, int], dict[str, str]] = {} - # pending_compares 只存模型名列表,不存配置对象 self._pending_compares: dict[tuple[int, int], tuple[list[str], float]] = {} self._loaded = asyncio.Event() diff --git a/src/Undefined/ai/multimodal.py b/src/Undefined/ai/multimodal.py index e4dda36f..8e5bcc90 100644 --- a/src/Undefined/ai/multimodal.py +++ b/src/Undefined/ai/multimodal.py @@ -29,11 +29,9 @@ # 每个文件名最多保留的历史 Q&A 条数 _MAX_QA_HISTORY = 5 -# 磁盘持久化路径 _HISTORY_FILE_PATH = Path("data/media_qa_history.json") # 远程媒体缓存目录(用于先下载 URL 再转 data URL) -# Remote media cache directory (download URL first, then convert to data URL). _MEDIA_URL_CACHE_DIR = Path("data/cache/multimodal_media") # 远程媒体缓存清理策略:仅保留最近 6 小时 + 最多 256 个文件。 @@ -42,18 +40,14 @@ _MEDIA_URL_CACHE_MAX_FILES = 256 # 两次自动清理之间的最小间隔(秒),避免每次请求都全量扫描目录。 -# Minimum interval between cleanup runs (seconds) to avoid full scan on every call. _MEDIA_URL_CACHE_CLEANUP_INTERVAL_SECONDS = 60.0 # 下载 URL 到本地缓存时的网络超时(秒)。 -# Network timeout (seconds) when downloading URL to local cache. _MEDIA_URL_DOWNLOAD_TIMEOUT_SECONDS = 120.0 # 下载阶段临时文件后缀(追加在缓存文件名后),用于区分真实缓存文件。 -# Download-stage temporary suffix (appended to cache filename) to avoid clashes. _MEDIA_URL_DOWNLOAD_TMP_SUFFIX = ".downloading" -# 文件扩展名常量 _IMAGE_EXTENSIONS = (".jpg", ".jpeg", ".png", ".gif", ".webp", ".bmp", ".svg") _AUDIO_EXTENSIONS = (".mp3", ".wav", ".m4a", ".ogg", ".flac", ".aac", ".wma") _VIDEO_EXTENSIONS = (".mp4", ".avi", ".mov", ".webm", ".mkv", ".flv", ".wmv") @@ -106,16 +100,13 @@ def _get_media_type_by_extension(url_lower: str) -> str: def detect_media_type(media_url: str, specified_type: str = "auto") -> str: """检测媒体文件的类型(图片、音频或视频)。""" - # 1. 优先级最高:手动指定类型 if specified_type and specified_type != "auto": return specified_type - # 2. 检查 data URL media_type = _detect_from_data_url(media_url) if media_type: return media_type - # 3. 使用 mimetypes 或扩展名猜测 return _detect_by_mimetypes(media_url) @@ -172,7 +163,6 @@ def get_media_mime_type(media_type: str, file_path: str = "") -> str: return _DEFAULT_MIME_TYPES.get(media_type, "application/octet-stream") -# 响应内容类型到字段名的映射 _MEDIA_TYPE_TO_FIELD = { "image": "ocr_text", "audio": "transcript", @@ -232,7 +222,6 @@ def get_media_mime_type(media_type: str, file_path: str = "") -> str: } -# 错误消息映射 _ERROR_MESSAGES = { "read": { "image": "[图片无法读取]", @@ -280,7 +269,6 @@ def _parse_analysis_response(content: str) -> dict[str, str]: "subtitles": ("字幕:", "字幕:"), } - # 初始化所有字段为空 result = { "description": "", "ocr_text": "", @@ -288,7 +276,6 @@ def _parse_analysis_response(content: str) -> dict[str, str]: "subtitles": "", } - # 解析每一行 for line in content.split("\n"): line = line.strip() for field, prefixes in field_prefixes.items(): @@ -615,7 +602,6 @@ async def _build_content_items( """ content_items: list[dict[str, Any]] = [{"type": "text", "text": prompt}] - # 添加媒体内容项 media_item_key = f"{media_type}_url" contents = media_content if isinstance(media_content, list) else [media_content] for mc in contents: @@ -652,13 +638,11 @@ async def analyze( len(prompt_extra), ) - # 检查缓存 cache_key = f"{detected_type}:{media_url[:100]}:{prompt_extra}" if cache_key in self._cache: logger.debug("[媒体分析] 命中缓存: key=%s", cache_key[:120]) return self._cache[cache_key] - # 加载媒体内容 try: media_content = await self._load_media_content(media_url, detected_type) except Exception as exc: @@ -669,7 +653,6 @@ async def analyze( ) } - # 加载提示词 try: prompt = read_text_resource(self._prompt_path) except Exception: @@ -682,16 +665,13 @@ async def analyze( self._prompt_path, ) - # 添加补充提示词 if prompt_extra: prompt += f"\n\n【补充指令】\n{prompt_extra}" - # 构建请求内容 content_items = await self._build_content_items( detected_type, media_content, prompt ) - # 发送分析请求 try: result = await self._requester.request( model_config=self._vision_config, @@ -703,16 +683,13 @@ async def analyze( if logger.isEnabledFor(logging.DEBUG): log_debug_json(logger, "[媒体分析] 原始响应内容", content) - # 解析响应内容 parsed = _parse_analysis_response(content) - # 根据媒体类型构建结果字典 result_dict: dict[str, str] = {"description": parsed["description"]} field_name = _MEDIA_TYPE_TO_FIELD.get(detected_type) if field_name: result_dict[field_name] = parsed[field_name] - # 缓存结果 self._cache[cache_key] = result_dict logger.info(f"[媒体分析] 完成并缓存: {safe_url[:50]}... ({detected_type})") return result_dict diff --git a/src/Undefined/ai/multimodal/__init__.py b/src/Undefined/ai/multimodal/__init__.py new file mode 100644 index 00000000..128ba4ac --- /dev/null +++ b/src/Undefined/ai/multimodal/__init__.py @@ -0,0 +1,32 @@ +"""多模态分析子包。 + +对外稳定入口:``MultimodalAnalyzer``、``detect_media_type``、``get_media_mime_type``; +旧路径 ``Undefined.ai.multimodal`` 通过包根与 ``multimodal.py`` shim 保持兼容。 +""" + +from Undefined.ai.multimodal import constants as _constants + +# 测试 monkeypatch 沿用的模块级私有常量,勿随意改名 +_MEDIA_URL_CACHE_DIR = _constants._MEDIA_URL_CACHE_DIR +_MEDIA_URL_CACHE_TTL_SECONDS = _constants._MEDIA_URL_CACHE_TTL_SECONDS +_MEDIA_URL_CACHE_MAX_FILES = _constants._MEDIA_URL_CACHE_MAX_FILES +_MEDIA_URL_CACHE_CLEANUP_INTERVAL_SECONDS = ( + _constants._MEDIA_URL_CACHE_CLEANUP_INTERVAL_SECONDS +) +_MEDIA_URL_DOWNLOAD_TIMEOUT_SECONDS = _constants._MEDIA_URL_DOWNLOAD_TIMEOUT_SECONDS +_MEDIA_URL_DOWNLOAD_TMP_SUFFIX = _constants._MEDIA_URL_DOWNLOAD_TMP_SUFFIX + +from Undefined.ai.multimodal.analyzer import MultimodalAnalyzer # noqa: E402 +from Undefined.ai.multimodal.detection import detect_media_type, get_media_mime_type # noqa: E402 + +__all__ = [ + "MultimodalAnalyzer", + "detect_media_type", + "get_media_mime_type", + "_MEDIA_URL_CACHE_CLEANUP_INTERVAL_SECONDS", + "_MEDIA_URL_CACHE_DIR", + "_MEDIA_URL_CACHE_MAX_FILES", + "_MEDIA_URL_CACHE_TTL_SECONDS", + "_MEDIA_URL_DOWNLOAD_TIMEOUT_SECONDS", + "_MEDIA_URL_DOWNLOAD_TMP_SUFFIX", +] diff --git a/src/Undefined/ai/multimodal/analyzer.py b/src/Undefined/ai/multimodal/analyzer.py new file mode 100644 index 00000000..83b18c9a --- /dev/null +++ b/src/Undefined/ai/multimodal/analyzer.py @@ -0,0 +1,599 @@ +"""多模态媒体分析器实现。""" + +from __future__ import annotations + +import asyncio +import base64 +import hashlib +import json +import logging +import time +from pathlib import Path +from typing import Any, cast +from urllib.parse import urlsplit + +import aiofiles +import httpx + +from Undefined.ai.llm import ModelRequester +import Undefined.ai.multimodal as _multimodal_pkg +from Undefined.ai.multimodal.constants import ( + ERROR_MESSAGES, + HISTORY_FILE_PATH, + MAX_QA_HISTORY, + MEDIA_TYPE_TO_FIELD, + MEME_DESCRIBE_PROMPT_PATH, + MEME_DESCRIBE_TOOL, + MEME_JUDGE_PROMPT_PATH, + MEME_JUDGE_TOOL, +) +from Undefined.ai.multimodal.detection import detect_media_type, get_media_mime_type +from Undefined.ai.multimodal.parsing import ( + _normalize_meme_tags, + _parse_analysis_response, +) +from Undefined.ai.parsing import extract_choices_content +from Undefined.ai.transports import API_MODE_CHAT_COMPLETIONS, get_api_mode +from Undefined.config import VisionModelConfig +from Undefined.utils.coerce import safe_float +from Undefined.utils.logging import log_debug_json, redact_string +from Undefined.utils.resources import read_text_resource +from Undefined.utils.tool_calls import extract_required_tool_call_arguments + +logger = logging.getLogger(__name__) + + +class MultimodalAnalyzer: + """多模态媒体分析器。 + + 支持分析图片、音频和视频文件,提取描述内容和类型特定信息(如 OCR 文字、转写文字、字幕等)。 + """ + + def __init__( + self, + requester: ModelRequester, + vision_config: VisionModelConfig, + prompt_path: str = "res/prompts/analyze_multimodal.txt", + ) -> None: + """初始化多模态分析器。 + + Args: + requester: 模型请求器 + vision_config: 视觉模型配置 + prompt_path: 提示词模板文件路径 + """ + self._requester = requester + self._vision_config = vision_config + self._prompt_path = prompt_path + self._cache: dict[str, dict[str, str]] = {} + # 按文件名索引的 Q&A 历史:{filename: [{q: ..., a: ...}, ...]} + self._file_history: dict[str, list[dict[str, str]]] = {} + + # URL 下载锁:按 URL 哈希粒度加锁,避免并发下载同一文件造成竞态。 + # URL download lock: keyed by URL hash to avoid duplicate concurrent downloads. + self._url_cache_locks: dict[str, asyncio.Lock] = {} + self._url_cache_locks_guard = asyncio.Lock() + + # 缓存清理锁 + 上次清理时间,避免并发清理相互干扰。 + # Cache cleanup lock + last cleanup timestamp to avoid concurrent cleanup races. + self._url_cache_cleanup_lock = asyncio.Lock() + self._last_url_cache_cleanup_at = 0.0 + + self._load_history() + + async def _load_media_content(self, media_url: str, media_type: str) -> str: + """加载媒体内容。 + + 如果是本地文件,会将其转换为 base64 编码的 data URL。 + + Args: + media_url: 媒体 URL 或本地文件路径 + media_type: 媒体类型 + + Returns: + 可用于 API 请求的媒体内容字符串 + """ + if media_url.startswith("data:"): + return media_url + + if media_url.startswith("http://") or media_url.startswith("https://"): + return await self._load_remote_media_as_data_url(media_url, media_type) + + # 读取本地文件并转换为 base64 + async with aiofiles.open(media_url, "rb") as f: + media_bytes = bytes(await f.read()) + media_data = base64.b64encode(media_bytes).decode() + mime_type = get_media_mime_type(media_type, media_url) + return f"data:{mime_type};base64,{media_data}" + + async def _load_remote_media_as_data_url( + self, media_url: str, media_type: str + ) -> str: + """将远程 URL 下载到缓存并转换为 data URL。""" + cache_key = self._build_url_cache_key(media_url) + lock = await self._get_url_cache_lock(cache_key) + cache_path = self._build_url_cache_path(cache_key, media_url) + + async with lock: + await self._cleanup_url_cache_if_needed() + if not cache_path.exists(): + await self._download_url_to_cache(media_url, cache_path) + async with aiofiles.open(cache_path, "rb") as f: + media_bytes = bytes(await f.read()) + media_data = base64.b64encode(media_bytes).decode() + + mime_type = get_media_mime_type(media_type, media_url) + return f"data:{mime_type};base64,{media_data}" + + def _build_url_cache_key(self, media_url: str) -> str: + """构建 URL 缓存键(使用 URL 内容哈希)。""" + return hashlib.sha256(media_url.encode("utf-8")).hexdigest() + + def _build_url_cache_path(self, cache_key: str, media_url: str) -> Path: + """基于 URL 生成缓存文件路径。""" + suffix = Path(urlsplit(media_url).path).suffix.lower() + if not suffix or len(suffix) > 10: + suffix = ".bin" + return _multimodal_pkg._MEDIA_URL_CACHE_DIR / f"{cache_key}{suffix}" + + async def _get_url_cache_lock(self, cache_key: str) -> asyncio.Lock: + """获取 URL 对应的下载锁(同 URL 串行化)。""" + async with self._url_cache_locks_guard: + lock = self._url_cache_locks.get(cache_key) + if lock is None: + lock = asyncio.Lock() + self._url_cache_locks[cache_key] = lock + return lock + + async def _download_url_to_cache(self, media_url: str, cache_path: Path) -> None: + """下载远程 URL 到缓存文件(原子写入,避免部分文件)。""" + cache_path.parent.mkdir(parents=True, exist_ok=True) + tmp_path = cache_path.with_name( + f"{cache_path.name}{_multimodal_pkg._MEDIA_URL_DOWNLOAD_TMP_SUFFIX}" + ) + try: + timeout = httpx.Timeout(_multimodal_pkg._MEDIA_URL_DOWNLOAD_TIMEOUT_SECONDS) + async with httpx.AsyncClient( + timeout=timeout, follow_redirects=True + ) as client: + response = await client.get(media_url) + response.raise_for_status() + async with aiofiles.open(tmp_path, "wb") as f: + await f.write(response.content) + tmp_path.replace(cache_path) + except Exception: + try: + tmp_path.unlink(missing_ok=True) + except Exception: + pass + raise + + @staticmethod + def _extract_cache_key_from_tmp(path: Path) -> str: + """从临时文件名提取 cache_key({key}.{ext}. -> key)。 + + Extract cache_key from tmp filename ({key}.{ext}. -> key). + """ + return Path(path.stem).stem + + @staticmethod + def _is_download_tmp_path(path: Path) -> bool: + """判断是否为下载过程临时文件({key}.{ext}.)。 + + Identify download tmp files by requiring a dedicated trailing suffix and + at least one original extension segment before it. + """ + suffixes = path.suffixes + return ( + len(suffixes) >= 2 + and suffixes[-1] == _multimodal_pkg._MEDIA_URL_DOWNLOAD_TMP_SUFFIX + ) + + async def _cleanup_url_cache_if_needed(self) -> None: + """按 TTL + 文件数上限清理 URL 媒体缓存。""" + now = time.time() + if ( + now - self._last_url_cache_cleanup_at + < _multimodal_pkg._MEDIA_URL_CACHE_CLEANUP_INTERVAL_SECONDS + ): + return + + async with self._url_cache_cleanup_lock: + # 双重检查,避免并发情况下重复清理。 + # Double-check to avoid repeated cleanup under concurrency. + now = time.time() + if ( + now - self._last_url_cache_cleanup_at + < _multimodal_pkg._MEDIA_URL_CACHE_CLEANUP_INTERVAL_SECONDS + ): + return + self._last_url_cache_cleanup_at = now + + async with self._url_cache_locks_guard: + active_keys = { + key for key, lock in self._url_cache_locks.items() if lock.locked() + } + cache_dir = _multimodal_pkg._MEDIA_URL_CACHE_DIR + if not cache_dir.exists(): + await self._prune_url_cache_locks( + active_keys=active_keys, + present_keys=set(), + ) + return + + files: list[Path] = [p for p in cache_dir.iterdir() if p.is_file()] + expire_before = now - _multimodal_pkg._MEDIA_URL_CACHE_TTL_SECONDS + kept_files: list[Path] = [] + present_keys: set[str] = set() + + # 先按 TTL 清理,跳过正在下载/读取的活跃键。 + # First, TTL cleanup; skip active keys still being downloaded/read. + for path in files: + if self._is_download_tmp_path(path): + tmp_key = self._extract_cache_key_from_tmp(path) + if tmp_key and tmp_key not in active_keys: + path.unlink(missing_ok=True) + continue + present_keys.add(path.stem) + try: + mtime = path.stat().st_mtime + except OSError: + continue + if mtime < expire_before and path.stem not in active_keys: + path.unlink(missing_ok=True) + else: + kept_files.append(path) + + await self._prune_url_cache_locks( + active_keys=active_keys, + present_keys=present_keys, + ) + + # 再按数量上限清理最旧文件,同样跳过活跃键。 + # Then enforce max-file limit by deleting oldest files, skipping active keys. + if len(kept_files) <= _multimodal_pkg._MEDIA_URL_CACHE_MAX_FILES: + return + + kept_with_mtime: list[tuple[float, Path]] = [] + for path in kept_files: + try: + kept_with_mtime.append((path.stat().st_mtime, path)) + except OSError: + continue + kept_with_mtime.sort(key=lambda item: item[0], reverse=True) + for _, path in kept_with_mtime[ + _multimodal_pkg._MEDIA_URL_CACHE_MAX_FILES : + ]: + if path.stem in active_keys: + continue + path.unlink(missing_ok=True) + + async def _prune_url_cache_locks( + self, + *, + active_keys: set[str], + present_keys: set[str], + ) -> None: + """回收不再活跃且已无缓存文件的 URL 锁,避免字典无限增长。 + + Prune stale URL locks with no active task/file to avoid unbounded growth. + """ + async with self._url_cache_locks_guard: + stale_keys = [ + key + for key, lock in self._url_cache_locks.items() + if key not in active_keys + and key not in present_keys + and not lock.locked() + ] + for key in stale_keys: + self._url_cache_locks.pop(key, None) + + async def _build_content_items( + self, media_type: str, media_content: str | list[str], prompt: str + ) -> list[dict[str, Any]]: + """构建请求内容项。 + + Args: + media_type: 媒体类型 + media_content: 媒体内容(URL/data URL),或其列表 + prompt: 提示词 + + Returns: + 包含文本和媒体的内容项列表 + """ + content_items: list[dict[str, Any]] = [{"type": "text", "text": prompt}] + + media_item_key = f"{media_type}_url" + contents = media_content if isinstance(media_content, list) else [media_content] + for mc in contents: + content_items.append({"type": media_item_key, media_item_key: {"url": mc}}) + + return content_items + + async def analyze( + self, + media_url: str, + media_type: str = "auto", + prompt_extra: str = "", + ) -> dict[str, str]: + """分析媒体文件。 + + 始终调用视觉模型进行真实分析,不会因历史缓存而跳过。 + + Args: + media_url: 媒体文件 URL 或本地路径 + media_type: 媒体类型,"auto" 表示自动检测 + prompt_extra: 补充提示词 + + Returns: + 包含描述和类型特定信息的字典 + """ + detected_type = detect_media_type(media_url, media_type) + safe_url = redact_string(media_url) + logger.info(f"[媒体分析] 开始分析 {detected_type}: {safe_url[:50]}...") + logger.debug( + "[媒体分析] media_type=%s detected=%s url_len=%s prompt_extra_len=%s", + media_type, + detected_type, + len(media_url), + len(prompt_extra), + ) + + cache_key = f"{detected_type}:{media_url[:100]}:{prompt_extra}" + if cache_key in self._cache: + logger.debug("[媒体分析] 命中缓存: key=%s", cache_key[:120]) + return self._cache[cache_key] + + try: + media_content = await self._load_media_content(media_url, detected_type) + except Exception as exc: + logger.error(f"无法读取媒体文件: {exc}") + return { + "description": ERROR_MESSAGES["read"].get( + detected_type, ERROR_MESSAGES["read"]["default"] + ) + } + + try: + prompt = read_text_resource(self._prompt_path) + except Exception: + async with aiofiles.open(self._prompt_path, "r", encoding="utf-8") as f: + prompt = await f.read() + + logger.debug( + "[媒体分析] prompt_len=%s path=%s", + len(prompt), + self._prompt_path, + ) + + if prompt_extra: + prompt += f"\n\n【补充指令】\n{prompt_extra}" + + content_items = await self._build_content_items( + detected_type, media_content, prompt + ) + + try: + result = await self._requester.request( + model_config=self._vision_config, + messages=[{"role": "user", "content": content_items}], + max_tokens=self._vision_config.max_tokens, + call_type=f"vision_{detected_type}", + ) + content = extract_choices_content(result) + if logger.isEnabledFor(logging.DEBUG): + log_debug_json(logger, "[媒体分析] 原始响应内容", content) + + parsed = _parse_analysis_response(content) + + result_dict: dict[str, str] = {"description": parsed["description"]} + field_name = MEDIA_TYPE_TO_FIELD.get(detected_type) + if field_name: + result_dict[field_name] = parsed[field_name] + + self._cache[cache_key] = result_dict + logger.info(f"[媒体分析] 完成并缓存: {safe_url[:50]}... ({detected_type})") + return result_dict + + except Exception as exc: + logger.exception(f"媒体分析失败: {exc}") + return { + "description": ERROR_MESSAGES["analyze"].get( + detected_type, ERROR_MESSAGES["analyze"]["default"] + ) + } + + # ── 媒体键级别的 Q&A 历史管理 ── + + def _load_history(self) -> None: + """从磁盘加载历史 Q&A 缓存。""" + if not HISTORY_FILE_PATH.exists(): + return + try: + with open(HISTORY_FILE_PATH, "r", encoding="utf-8") as f: + data = json.load(f) + if isinstance(data, dict): + self._file_history = data + logger.info( + "[媒体分析] 从磁盘加载历史缓存: %d 个文件", len(self._file_history) + ) + except Exception as exc: + logger.warning("[媒体分析] 加载历史缓存失败: %s", exc) + + async def _save_history(self) -> None: + """将历史缓存写入磁盘。""" + from Undefined.utils import io + + try: + await io.write_json(HISTORY_FILE_PATH, self._file_history, use_lock=True) + except Exception as exc: + logger.error("[媒体分析] 历史缓存写入磁盘失败: %s", exc) + + def get_history(self, media_key: str) -> list[dict[str, str]]: + """获取指定媒体键的历史 Q&A 记录。 + + Args: + media_key: 媒体唯一键(可包含作用域和文件身份) + + Returns: + Q&A 列表,每项包含 ``q`` 和 ``a`` 两个键 + """ + pairs = self._file_history.get(media_key) + if not pairs: + return [] + return list(pairs[-MAX_QA_HISTORY:]) + + async def save_history(self, media_key: str, question: str, answer: str) -> None: + """保存一条 Q&A 到指定媒体键的历史记录(上限 5 条)并持久化。 + + Args: + media_key: 媒体唯一键(可包含作用域和文件身份) + question: 提问内容 + answer: 分析回答 + """ + pairs = self._file_history.setdefault(media_key, []) + pairs.append({"q": question, "a": answer}) + if len(pairs) > MAX_QA_HISTORY: + self._file_history[media_key] = pairs[-MAX_QA_HISTORY:] + await self._save_history() + + async def describe_image( + self, image_url: str, prompt_extra: str = "" + ) -> dict[str, str]: + """描述图片内容。 + + Args: + image_url: 图片 URL 或本地路径 + prompt_extra: 补充提示词 + + Returns: + 包含描述和 OCR 文字的字典 + """ + result = await self.analyze(image_url, "image", prompt_extra) + if "ocr_text" not in result: + result["ocr_text"] = "" + return result + + async def _load_prompt_text(self, prompt_path: str) -> str: + try: + return read_text_resource(prompt_path) + except Exception: + async with aiofiles.open(prompt_path, "r", encoding="utf-8") as f: + return await f.read() + + def _build_tool_request_kwargs(self) -> dict[str, Any]: + request_kwargs: dict[str, Any] = {} + # 非 thinking 模型强制关闭 thinking,避免 tool_choice 被服务商拒绝 + if ( + get_api_mode(self._vision_config) == API_MODE_CHAT_COMPLETIONS + and not self._vision_config.thinking_enabled + ): + request_kwargs["thinking"] = {"enabled": False, "budget_tokens": 0} + return request_kwargs + + async def _request_required_tool_args( + self, + *, + prompt_path: str, + image_url: str | list[str], + tool_schema: dict[str, Any], + tool_name: str, + call_type: str, + max_tokens: int, + ) -> dict[str, Any]: + if isinstance(image_url, list): + media_contents: list[str] = [] + for url in image_url: + media_contents.append(await self._load_media_content(url, "image")) + media_content: str | list[str] = media_contents + else: + media_content = await self._load_media_content(image_url, "image") + prompt = await self._load_prompt_text(prompt_path) + content_items = await self._build_content_items("image", media_content, prompt) + response = await self._requester.request( + model_config=self._vision_config, + messages=[{"role": "user", "content": content_items}], + max_tokens=max_tokens, + call_type=call_type, + tools=[tool_schema], + tool_choice=cast( + Any, {"type": "function", "function": {"name": tool_name}} + ), + **self._build_tool_request_kwargs(), + ) + return extract_required_tool_call_arguments( + response, + expected_tool_name=tool_name, + stage=call_type, + logger=logger, + error_context=f"image={redact_string(str(image_url) if isinstance(image_url, list) else image_url)[:120]}", + ) + + async def judge_meme_image(self, image_url: str | list[str]) -> dict[str, Any]: + safe_url = redact_string( + str(image_url) if isinstance(image_url, list) else image_url + ) + try: + args = await self._request_required_tool_args( + prompt_path=MEME_JUDGE_PROMPT_PATH, + image_url=image_url, + tool_schema=MEME_JUDGE_TOOL, + tool_name="submit_meme_judgement", + call_type="vision_meme_judge", + max_tokens=self._vision_config.max_tokens, + ) + except Exception as exc: + logger.exception("[媒体分析] 表情包判定失败,按非表情包处理: %s", exc) + return { + "is_meme": False, + "confidence": 0.0, + "reason": "", + } + + try: + parsed = { + "is_meme": bool(args.get("is_meme", False)), + "confidence": safe_float(args.get("confidence", 0.0), default=0.0), + "reason": str(args.get("reason") or "").strip(), + } + except Exception: + parsed = {"is_meme": False, "confidence": 0.0, "reason": ""} + logger.info( + "[媒体分析] 表情包判定完成: url=%s is_meme=%s confidence=%.3f reason=%s", + safe_url[:50], + parsed.get("is_meme", False), + safe_float(parsed.get("confidence", 0.0), default=0.0), + str(parsed.get("reason", ""))[:80], + ) + return parsed + + async def describe_meme_image(self, image_url: str | list[str]) -> dict[str, Any]: + safe_url = redact_string( + str(image_url) if isinstance(image_url, list) else image_url + ) + try: + args = await self._request_required_tool_args( + prompt_path=MEME_DESCRIBE_PROMPT_PATH, + image_url=image_url, + tool_schema=MEME_DESCRIBE_TOOL, + tool_name="submit_meme_description", + call_type="vision_meme_describe", + max_tokens=self._vision_config.max_tokens, + ) + except Exception as exc: + logger.exception("[媒体分析] 表情包描述失败: %s", exc) + return {"description": "", "tags": []} + + description = str(args.get("description") or "").strip() + tags = _normalize_meme_tags(args.get("tags")) + logger.info( + "[媒体分析] 表情包描述完成: url=%s desc_len=%s tags=%s", + safe_url[:50], + len(description), + tags, + ) + return {"description": description, "tags": tags} + + +__all__ = ["MultimodalAnalyzer"] diff --git a/src/Undefined/ai/multimodal/constants.py b/src/Undefined/ai/multimodal/constants.py new file mode 100644 index 00000000..6c074c00 --- /dev/null +++ b/src/Undefined/ai/multimodal/constants.py @@ -0,0 +1,138 @@ +"""多模态分析常量与工具 schema 定义。""" + +from __future__ import annotations + +from pathlib import Path +from typing import Any + +# ===== 历史 Q&A 与磁盘缓存 ===== +# 每个文件名最多保留的历史 Q&A 条数 +_MAX_QA_HISTORY = 5 + +HISTORY_FILE_PATH = Path("data/media_qa_history.json") + +# ===== 远程 URL 媒体缓存策略 ===== +# 远程媒体缓存目录(用于先下载 URL 再转 data URL) +_MEDIA_URL_CACHE_DIR = Path("data/cache/multimodal_media") + +# 远程媒体缓存清理策略:仅保留最近 6 小时 + 最多 256 个文件。 +_MEDIA_URL_CACHE_TTL_SECONDS = 6 * 60 * 60 +_MEDIA_URL_CACHE_MAX_FILES = 256 + +# 两次自动清理之间的最小间隔(秒),避免每次请求都全量扫描目录。 +_MEDIA_URL_CACHE_CLEANUP_INTERVAL_SECONDS = 60.0 + +# 下载 URL 到本地缓存时的网络超时(秒)。 +_MEDIA_URL_DOWNLOAD_TIMEOUT_SECONDS = 120.0 + +# 下载阶段临时文件后缀(追加在缓存文件名后),用于区分真实缓存文件。 +_MEDIA_URL_DOWNLOAD_TMP_SUFFIX = ".downloading" + +# ===== 扩展名 / MIME / 错误文案映射 ===== +IMAGE_EXTENSIONS = (".jpg", ".jpeg", ".png", ".gif", ".webp", ".bmp", ".svg") +AUDIO_EXTENSIONS = (".mp3", ".wav", ".m4a", ".ogg", ".flac", ".aac", ".wma") +VIDEO_EXTENSIONS = (".mp4", ".avi", ".mov", ".webm", ".mkv", ".flv", ".wmv") + +# MIME 类型前缀到媒体类型的映射 +MIME_PREFIX_TO_TYPE = { + "image/": "image", + "audio/": "audio", + "video/": "video", +} + +# 默认 MIME 类型映射 +DEFAULT_MIME_TYPES = { + "image": "image/jpeg", + "audio": "audio/mpeg", + "video": "video/mp4", +} + +MEDIA_TYPE_TO_FIELD = { + "image": "ocr_text", + "audio": "transcript", + "video": "subtitles", +} + +# ===== 表情包判定 / 描述工具 schema ===== +MEME_JUDGE_PROMPT_PATH = "res/prompts/judge_meme_image.txt" +MEME_DESCRIBE_PROMPT_PATH = "res/prompts/describe_meme_image.txt" + +MEME_JUDGE_TOOL: dict[str, Any] = { + "type": "function", + "function": { + "name": "submit_meme_judgement", + "description": "提交表情包判定结果", + "parameters": { + "type": "object", + "properties": { + "is_meme": { + "type": "boolean", + "description": "该图片是否适合进入表情包库", + }, + "confidence": { + "type": "number", + "description": "0 到 1 的置信度", + }, + "reason": { + "type": "string", + "description": "简短中文判定原因", + }, + }, + "required": ["is_meme", "confidence", "reason"], + }, + }, +} + +MEME_DESCRIBE_TOOL: dict[str, Any] = { + "type": "function", + "function": { + "name": "submit_meme_description", + "description": "提交表情包描述与标签", + "parameters": { + "type": "object", + "properties": { + "description": { + "type": "string", + "description": "适合检索的简短中文描述", + }, + "tags": { + "type": "array", + "items": {"type": "string"}, + "description": "0 到 6 个短标签", + }, + }, + "required": ["description", "tags"], + }, + }, +} + +ERROR_MESSAGES = { + "read": { + "image": "[图片无法读取]", + "audio": "[音频无法读取]", + "video": "[视频无法读取]", + "default": "[媒体文件无法读取]", + }, + "analyze": { + "image": "[图片分析失败]", + "audio": "[音频分析失败]", + "video": "[视频分析失败]", + "default": "[媒体分析失败]", + }, +} + +__all__ = [ + "DEFAULT_MIME_TYPES", + "ERROR_MESSAGES", + "HISTORY_FILE_PATH", + "MAX_QA_HISTORY", + "MEDIA_TYPE_TO_FIELD", + "MEME_DESCRIBE_PROMPT_PATH", + "MEME_DESCRIBE_TOOL", + "MEME_JUDGE_PROMPT_PATH", + "MEME_JUDGE_TOOL", + "MIME_PREFIX_TO_TYPE", +] + +# 对外别名,供 analyzer 使用 +MAX_QA_HISTORY = _MAX_QA_HISTORY diff --git a/src/Undefined/ai/multimodal/detection.py b/src/Undefined/ai/multimodal/detection.py new file mode 100644 index 00000000..666bb9b3 --- /dev/null +++ b/src/Undefined/ai/multimodal/detection.py @@ -0,0 +1,101 @@ +"""媒体类型探测与 MIME 推断。""" + +from __future__ import annotations + +from Undefined.ai.multimodal.constants import ( + AUDIO_EXTENSIONS, + DEFAULT_MIME_TYPES, + IMAGE_EXTENSIONS, + MIME_PREFIX_TO_TYPE, + VIDEO_EXTENSIONS, +) + + +def _extract_mime_type_from_data_url(media_url: str) -> str | None: + """从 data URL 中提取 MIME 类型。 + + Args: + media_url: 媒体 URL + + Returns: + MIME 类型前缀(如 ``image/``)或 None + """ + if not media_url.startswith("data:"): + return None + mime_part = media_url.split(";")[0] + if ":" in mime_part: + return mime_part.split(":")[1] + return None + + +def _get_media_type_by_extension(url_lower: str) -> str: + """根据文件扩展名判断媒体类型。""" + for ext in IMAGE_EXTENSIONS: + if ext in url_lower: + return "image" + for ext in AUDIO_EXTENSIONS: + if ext in url_lower: + return "audio" + for ext in VIDEO_EXTENSIONS: + if ext in url_lower: + return "video" + return "image" + + +def detect_media_type(media_url: str, specified_type: str = "auto") -> str: + """检测媒体文件的类型(图片、音频或视频)。""" + if specified_type and specified_type != "auto": + return specified_type + + # data URL 的 MIME 优先于扩展名猜测 + media_type = _detect_from_data_url(media_url) + if media_type: + return media_type + + return _detect_by_mimetypes(media_url) + + +def _detect_from_data_url(media_url: str) -> str | None: + """从 data URL 的 MIME 类型中探测媒体类型。""" + mime = _extract_mime_type_from_data_url(media_url) + if mime: + for prefix, media_type in MIME_PREFIX_TO_TYPE.items(): + if mime.startswith(prefix): + return media_type + return None + + +def _detect_by_mimetypes(media_url: str) -> str: + """利用 mimetypes 库或扩展名探测媒体类型。""" + import mimetypes + + guessed_mime, _ = mimetypes.guess_type(media_url) + if guessed_mime: + for prefix, media_type in MIME_PREFIX_TO_TYPE.items(): + if guessed_mime.startswith(prefix): + return media_type + + return _get_media_type_by_extension(media_url.lower()) + + +def get_media_mime_type(media_type: str, file_path: str = "") -> str: + """获取媒体文件的 MIME 类型。 + + Args: + media_type: 媒体类型(``image``、``audio`` 或 ``video``) + file_path: 文件路径(可选),用于根据扩展名推断 MIME 类型 + + Returns: + MIME 类型字符串 + """ + if file_path: + import mimetypes + + mime_type, _ = mimetypes.guess_type(file_path) + if mime_type: + return mime_type + + return DEFAULT_MIME_TYPES.get(media_type, "application/octet-stream") + + +__all__ = ["detect_media_type", "get_media_mime_type"] diff --git a/src/Undefined/ai/multimodal/parsing.py b/src/Undefined/ai/multimodal/parsing.py new file mode 100644 index 00000000..b5338078 --- /dev/null +++ b/src/Undefined/ai/multimodal/parsing.py @@ -0,0 +1,107 @@ +"""多模态模型响应解析工具。""" + +from __future__ import annotations + +import json +from typing import Any + +from Undefined.utils.coerce import safe_float + + +def _parse_line_value(line: str, prefix: str) -> str: + """解析行内容,提取指定前缀后的值。""" + value = line.split(":", 1)[-1].split(":", 1)[-1].strip() + return "" if value == "无" else value + + +def _parse_analysis_response(content: str) -> dict[str, str]: + """解析 AI 分析响应的内容。""" + field_prefixes = { + "description": ("描述:", "描述:"), + "ocr_text": ("OCR:", "OCR:"), + "transcript": ("转写:", "转写:"), + "subtitles": ("字幕:", "字幕:"), + } + + result = { + "description": "", + "ocr_text": "", + "transcript": "", + "subtitles": "", + } + + for line in content.split("\n"): + line = line.strip() + for field, prefixes in field_prefixes.items(): + if line.startswith(prefixes): + result[field] = _parse_line_value(line, prefixes[0]) + + if not result["description"]: + result["description"] = content + + return result + + +def _extract_json_object(content: str) -> dict[str, Any]: + text = str(content or "").strip() + if not text: + return {} + candidates = [text] + if "```" in text: + parts = text.split("```") + for part in parts: + stripped = part.strip() + if not stripped: + continue + if stripped.lower().startswith("json"): + stripped = stripped[4:].strip() + candidates.append(stripped) + for candidate in candidates: + try: + parsed = json.loads(candidate) + except json.JSONDecodeError: + continue + if isinstance(parsed, dict): + return parsed + # 兜底:从文本中截取首尾花括号再解析 + start = text.find("{") + end = text.rfind("}") + if start >= 0 and end > start: + try: + parsed = json.loads(text[start : end + 1]) + except json.JSONDecodeError: + return {} + if isinstance(parsed, dict): + return parsed + return {} + + +def _normalize_meme_tags(tags_raw: Any) -> list[str]: + tags: list[str] = [] + if isinstance(tags_raw, list): + seen: set[str] = set() + for item in tags_raw: + text = str(item or "").strip() + lowered = text.lower() + if not text or lowered in seen: + continue + seen.add(lowered) + tags.append(text) + return tags + + +def _parse_meme_analysis_response(content: str) -> dict[str, Any]: + parsed = _extract_json_object(content) + return { + "is_meme": bool(parsed.get("is_meme", False)), + "confidence": safe_float(parsed.get("confidence", 0.0), default=0.0), + "description": str(parsed.get("description") or "").strip(), + "tags": _normalize_meme_tags(parsed.get("tags")), + } + + +__all__ = [ + "_normalize_meme_tags", + "_parse_analysis_response", + "_parse_meme_analysis_response", +] diff --git a/src/Undefined/ai/parsing.py b/src/Undefined/ai/parsing.py index b25f26e0..63aa0dc7 100644 --- a/src/Undefined/ai/parsing.py +++ b/src/Undefined/ai/parsing.py @@ -24,23 +24,18 @@ def _get_content_from_message(message: Any) -> str | None: def _extract_from_choice(choice: Any) -> str: """从单个选项结构中提取最终的文本内容""" - # 如果选项是字符串,直接返回 if isinstance(choice, str): return choice - # 如果选项不是字典,返回空字符串 if not isinstance(choice, dict): return "" - # 尝试从消息中获取 content message = choice.get("message") content = _get_content_from_message(message) - # 如果消息中没有 content,尝试从选项直接获取 if content is None: content = choice.get("content") - # 如果有 tool_calls 但没有 content,返回空字符串 if not content and choice.get("message", {}).get("tool_calls"): return "" @@ -67,13 +62,11 @@ def _find_first_choice(result: dict[str, Any]) -> dict[str, Any] | None: Returns: 第一个选项字典,未找到时返回 None """ - # 直接检查 choices 字段 if "choices" in result and result["choices"]: choice = result["choices"][0] if isinstance(choice, dict): return choice - # 检查 data.choices 字段 data = result.get("data") if isinstance(data, dict) and data.get("choices"): choice = data["choices"][0] @@ -125,12 +118,9 @@ def extract_choices_content(result: dict[str, Any]) -> str: if output_text: return output_text - # 查找第一个选项 choice = _find_first_choice(result) - # 如果没有找到选项,抛出错误 if choice is None: raise KeyError(_build_error_message(result)) - # 从选项中提取内容 return _extract_from_choice(choice) diff --git a/src/Undefined/ai/prompts.py b/src/Undefined/ai/prompts.py index d01904cb..963d8e51 100644 --- a/src/Undefined/ai/prompts.py +++ b/src/Undefined/ai/prompts.py @@ -110,13 +110,11 @@ def _build_model_config_info(self, runtime_config: Any) -> str: """ parts: list[str] = ["【当前运行环境配置】"] - # 主对话模型 chat_model = getattr(runtime_config, "chat_model", None) if chat_model: model_name = getattr(chat_model, "model_name", "未知") parts.append(f"- 我使用的模型: {model_name}") - # 视觉模型 vision_model = getattr(runtime_config, "vision_model", None) if vision_model: model_name = getattr(vision_model, "model_name", "") @@ -130,14 +128,12 @@ def _build_model_config_info(self, runtime_config: Any) -> str: if model_name: parts.append(f"- Agent 模型: {model_name}") - # 嵌入模型 embedding_model = getattr(runtime_config, "embedding_model", None) if embedding_model: model_name = getattr(embedding_model, "model_name", "") if model_name: parts.append(f"- 嵌入模型: {model_name}") - # 安全模型 security_model = getattr(runtime_config, "security_model", None) if security_model: model_name = getattr(security_model, "model_name", "") @@ -151,23 +147,19 @@ def _build_model_config_info(self, runtime_config: Any) -> str: if model_name: parts.append(f"- 搜索模型: {model_name}") - # 认知记忆 cognitive = getattr(runtime_config, "cognitive", None) if cognitive: enabled = getattr(cognitive, "enabled", False) parts.append(f"- 认知记忆: {'已启用' if enabled else '未启用'}") - # 知识库 knowledge_enabled = bool(getattr(runtime_config, "knowledge_enabled", False)) parts.append(f"- 知识库: {'已启用' if knowledge_enabled else '未启用'}") - # 联网搜索 grok_search_enabled = bool( getattr(runtime_config, "grok_search_enabled", False) ) parts.append(f"- 联网搜索: {'已启用' if grok_search_enabled else '未启用'}") - # 表情包库 memes = getattr(runtime_config, "memes", None) if memes is not None: memes_enabled = bool(getattr(memes, "enabled", False)) @@ -184,7 +176,6 @@ def _build_model_config_info(self, runtime_config: Any) -> str: else: parts.append("- 表情包库: 未启用") - # 模型池 if chat_model: pool = getattr(chat_model, "pool", None) if pool: @@ -195,7 +186,6 @@ def _build_model_config_info(self, runtime_config: Any) -> str: else: parts.append("- 模型池: 未启用") - # 思维链 if chat_model: thinking = getattr(chat_model, "thinking_enabled", False) reasoning = getattr(chat_model, "reasoning_enabled", False) @@ -204,7 +194,6 @@ def _build_model_config_info(self, runtime_config: Any) -> str: else: parts.append("- 思维链: 未启用") - # 彩蛋功能状态 keyword_reply_enabled = bool( getattr(runtime_config, "keyword_reply_enabled", False) ) diff --git a/src/Undefined/ai/prompts/__init__.py b/src/Undefined/ai/prompts/__init__.py new file mode 100644 index 00000000..f642ceed --- /dev/null +++ b/src/Undefined/ai/prompts/__init__.py @@ -0,0 +1,10 @@ +"""Prompt 构建子包。 + +对外稳定入口:``PromptBuilder``;旧路径 ``Undefined.ai.prompts`` 通过 shim 保持兼容。 +""" + +# 子包唯一公开类:PromptBuilder +from Undefined.ai.prompts.builder import PromptBuilder + +# 子包公开 API +__all__ = ["PromptBuilder"] diff --git a/src/Undefined/ai/prompts/builder.py b/src/Undefined/ai/prompts/builder.py new file mode 100644 index 00000000..3906e3a9 --- /dev/null +++ b/src/Undefined/ai/prompts/builder.py @@ -0,0 +1,599 @@ +"""Prompt 消息构建器。""" + +from __future__ import annotations + +import logging +from collections import deque +from datetime import datetime +from typing import Any, Awaitable, Callable, Literal + +import aiofiles + +from Undefined.context import RequestContext +from Undefined.end_summary_storage import ( + EndSummaryStorage, + EndSummaryRecord, + MAX_END_SUMMARIES, +) +from Undefined.memory import MemoryStorage +from Undefined.skills.anthropic_skills import AnthropicSkillRegistry +from Undefined.utils.coerce import safe_int +from Undefined.utils.logging import log_debug_json +from Undefined.utils.resources import read_text_resource +from Undefined.utils.xml import format_message_xml +from Undefined.ai.prompts.cognitive import ( + build_cognitive_query, + drop_current_message_if_duplicated, +) +from Undefined.ai.prompts.system_context import ( + build_model_config_info, + select_system_prompt_path, +) + +logger = logging.getLogger(__name__) + + +class PromptBuilder: + """Prompt 构建器。 + + 协调系统提示词、记忆、认知上下文与历史消息,产出 LLM messages 列表。 + """ + + def __init__( + self, + bot_qq: int, + memory_storage: MemoryStorage | None, + end_summary_storage: EndSummaryStorage, + system_prompt_path: str = "res/prompts/undefined.xml", + runtime_config_getter: Callable[[], Any] | None = None, + anthropic_skill_registry: AnthropicSkillRegistry | None = None, + cognitive_service: Any = None, + ) -> None: + """初始化 Prompt 构建器 + + 参数: + bot_qq: 机器人 QQ 号 + memory_storage: 长期记忆存储 (可选) + end_summary_storage: 短期回忆存储 + system_prompt_path: 系统提示词文件路径 + anthropic_skill_registry: Anthropic Skills 注册中心(可选) + """ + self._bot_qq = bot_qq + self._memory_storage = memory_storage + self._end_summary_storage = end_summary_storage + self._system_prompt_path = system_prompt_path + self._runtime_config_getter = runtime_config_getter + self._anthropic_skill_registry = anthropic_skill_registry + self._cognitive_service = cognitive_service + self._end_summaries: deque[EndSummaryRecord] = deque(maxlen=MAX_END_SUMMARIES) + self._summaries_loaded = False + + def set_cognitive_service(self, service: Any = None) -> None: + """更新认知记忆服务引用(支持运行时注入/替换)。""" + self._cognitive_service = service + logger.info( + "[Prompt] 认知服务引用已更新: enabled=%s", + bool(getattr(service, "enabled", False)) if service is not None else False, + ) + + def _build_cognitive_query( + self, question: str, extra_context: dict[str, Any] | None = None + ) -> tuple[str, bool]: + """兼容旧测试/调用方:委托至 cognitive.build_cognitive_query。""" + return build_cognitive_query(question, extra_context) + + def _build_model_config_info(self, runtime_config: Any) -> str: + """兼容旧测试/调用方:委托至 system_context.build_model_config_info。""" + return build_model_config_info(runtime_config) + + @property + def end_summaries(self) -> deque[EndSummaryRecord]: + """暴露短期摘要缓存,供工具执行上下文共享。""" + return self._end_summaries + + async def _ensure_summaries_loaded(self) -> None: + if not self._summaries_loaded: + loaded_summaries = await self._end_summary_storage.load() + self._end_summaries.extend(loaded_summaries) + self._summaries_loaded = True + logger.debug(f"[AI初始化] 已加载 {len(loaded_summaries)} 条 End 摘要") + + async def _load_each_rules(self) -> str: + path = "res/IMPORTANT/each.md" + try: + return read_text_resource(path) + except Exception: + pass + try: + async with aiofiles.open(path, "r", encoding="utf-8") as f: + return await f.read() + except Exception: + return "" + + async def _load_system_prompt(self) -> str: + system_prompt_path = select_system_prompt_path( + default_path=self._system_prompt_path, + runtime_config_getter=self._runtime_config_getter, + ) + try: + return read_text_resource(system_prompt_path) + except Exception as exc: + logger.debug("读取系统提示词失败,尝试本地路径: %s", exc) + async with aiofiles.open(system_prompt_path, "r", encoding="utf-8") as f: + return await f.read() + + async def build_messages( + self, + question: str, + get_recent_messages_callback: Callable[ + [str, str, int, int], Awaitable[list[dict[str, Any]]] + ] + | None = None, + extra_context: dict[str, Any] | None = None, + ) -> list[dict[str, Any]]: + """构建发送给 AI 的消息列表 + + 参数: + question: 当前用户消息 + get_recent_messages_callback: 获取历史消息的回调函数 + extra_context: 额外的上下文信息 (如 group_id, user_id) + + 返回: + 构建好的消息列表 (role/content 结构) + """ + system_prompt = await self._load_system_prompt() + logger.debug( + "[Prompt] system_prompt_len=%s path=%s", + len(system_prompt), + select_system_prompt_path( + default_path=self._system_prompt_path, + runtime_config_getter=self._runtime_config_getter, + ), + ) + + if self._bot_qq != 0: + bot_qq_info = ( + f"\n" + f"\n\n" + ) + system_prompt = bot_qq_info + system_prompt + + messages: list[dict[str, Any]] = [{"role": "system", "content": system_prompt}] + + # 注入当前运行环境配置信息,让 AI 知道自己的模型名称等非隐私信息 + if self._runtime_config_getter is not None: + try: + runtime_config = self._runtime_config_getter() + config_info = build_model_config_info(runtime_config) + if config_info: + messages.append( + { + "role": "system", + "content": config_info, + } + ) + logger.debug( + "[Prompt] 已注入运行环境配置信息,长度=%s", + len(config_info), + ) + except Exception as exc: + logger.debug("读取运行环境配置失败: %s", exc) + + # 注入群聊关键词自动回复机制说明,避免模型误判历史中的系统彩蛋消息。 + is_group_context = False + ctx = RequestContext.current() + if ctx and ctx.group_id is not None: + is_group_context = True + elif extra_context and extra_context.get("group_id") is not None: + is_group_context = True + + keyword_reply_enabled = False + repeat_enabled = False + repeat_threshold = 3 + inverted_question_enabled = False + if self._runtime_config_getter is not None: + try: + runtime_config = self._runtime_config_getter() + keyword_reply_enabled = bool( + getattr(runtime_config, "keyword_reply_enabled", False) + ) + repeat_enabled = bool(getattr(runtime_config, "repeat_enabled", False)) + repeat_threshold = int(getattr(runtime_config, "repeat_threshold", 3)) + inverted_question_enabled = bool( + getattr(runtime_config, "inverted_question_enabled", False) + ) + except Exception as exc: + logger.debug("读取彩蛋功能配置失败: %s", exc) + + if is_group_context and keyword_reply_enabled: + messages.append( + { + "role": "system", + "content": ( + "【系统行为说明 — 关键词自动回复】\n" + '当前群聊已开启关键词自动回复彩蛋(例如触发词"心理委员")。' + "该功能由 handlers.py 中的独立代码路径处理," + "在消息到达你之前就已完成发送。\n\n" + '发送后,历史中会出现以"[系统关键词自动回复] "开头的消息。' + "这些消息完全由系统代码生成(固定文案如'受着''那咋了'等)," + "不经过你的工具调用,与你的决策无关。\n\n" + "阅读历史时请识别该前缀,避免误判为人格漂移或上下文异常。" + "除非用户主动询问,否则不要主动解释此机制。" + ), + } + ) + + if is_group_context and repeat_enabled: + repeat_desc = ( + "【系统行为说明】\n" + f"当前群聊已开启复读彩蛋:当群聊中连续出现{repeat_threshold}条内容相同且来自不同人的消息时," + "系统会自动复读一条相同的消息,并在历史中写入" + '以"[系统复读] "开头的消息。' + ) + if inverted_question_enabled: + repeat_desc += ( + "\n此外,若复读触发时消息内容仅由问号组成(如?或???)," + "系统会发送对应数量的倒问号(¿)代替。" + ) + repeat_desc += ( + "\n\n这类消息属于系统预设机制,不代表你在该轮主动决策。" + "阅读历史时请识别该前缀,避免误判为人格漂移或上下文异常。" + "除非用户主动询问,否则不要主动解释此机制。" + ) + messages.append({"role": "system", "content": repeat_desc}) + + # 注入 Anthropic Skills 元数据(Level 1: 始终加载 name + description) + if ( + self._anthropic_skill_registry + and self._anthropic_skill_registry.has_skills() + ): + skills_xml = self._anthropic_skill_registry.build_metadata_xml() + if skills_xml: + messages.append( + { + "role": "system", + "content": ( + "【可用的 Anthropic Skills】\n" + f"{skills_xml}\n\n" + "注意:以上是可用的 Anthropic Agent Skills 列表。" + "当用户的请求与某个 skill 相关时," + "你可以调用对应的 skill tool(tool_name 字段)" + "来获取该领域的详细指令和知识。" + ), + } + ) + logger.debug( + "[Prompt] 已注入 %d 个 Anthropic Skills 元数据", + len(self._anthropic_skill_registry.get_all_skills()), + ) + + each_rules = await self._load_each_rules() + if each_rules: + messages.append( + { + "role": "system", + "content": f"【强制规则 - 必须在进行任何操作前仔细阅读并严格遵守】\n{each_rules}", + } + ) + + deferred_messages: list[dict[str, Any]] = [] + # 长期记忆 / 认知 / end 摘要 / 历史等延迟注入块(排在主 system 之后) + + if self._memory_storage: + memories = self._memory_storage.get_all() + if memories: + memory_lines = [f"- {mem.fact}" for mem in memories] + memory_text = "\n".join(memory_lines) + deferred_messages.append( + { + "role": "system", + "content": ( + "【memory.* 手动长期记忆(可编辑)】\n" + f"{memory_text}\n\n" + "注意:以上是你通过 memory.add 等工具主动维护的长期事实清单。" + "它与认知记忆(cognitive.* / end.observations 产生的事件与侧写)是两套机制。" + "请根据任务选择合适的记忆工具,避免混用。" + ), + } + ) + logger.info(f"[AI会话] 已注入 {len(memories)} 条长期记忆") + if logger.isEnabledFor(logging.DEBUG): + log_debug_json( + logger, "[AI会话] 注入长期记忆", [mem.fact for mem in memories] + ) + + await self._ensure_summaries_loaded() + if self._cognitive_service and getattr( + self._cognitive_service, "enabled", False + ): + recent_action_inject_k = 30 + if self._runtime_config_getter is not None: + try: + runtime_config = self._runtime_config_getter() + cog_cfg = getattr(runtime_config, "cognitive", None) + if cog_cfg is not None and hasattr( + cog_cfg, "recent_end_summaries_inject_k" + ): + recent_action_inject_k = int( + getattr(cog_cfg, "recent_end_summaries_inject_k") + ) + except Exception: + pass + if recent_action_inject_k < 0: + recent_action_inject_k = 0 + + ctx = RequestContext.current() + resolved_group_id = ( + str(ctx.group_id) + if ctx and ctx.group_id is not None + else (str(extra_context.get("group_id", "")) if extra_context else None) + ) + resolved_user_id = ( + str(ctx.user_id) + if ctx and ctx.user_id is not None + else (str(extra_context.get("user_id", "")) if extra_context else None) + ) + resolved_sender_id = ( + str(ctx.sender_id) + if ctx and ctx.sender_id is not None + else ( + str(extra_context.get("sender_id", "")) if extra_context else None + ) + ) + resolved_request_type = ( + str(ctx.request_type).strip() + if ctx and ctx.request_type + else ( + str(extra_context.get("request_type", "")).strip() + if extra_context + else "" + ) + ) + if not resolved_request_type: + if resolved_group_id and str(resolved_group_id).strip(): + resolved_request_type = "group" + elif resolved_sender_id or resolved_user_id: + resolved_request_type = "private" + cognitive_query, query_enhanced = build_cognitive_query( + question, extra_context + ) + logger.info( + "[AI会话] 开始自动检索认知记忆: raw_query_len=%s effective_query_len=%s query_enhanced=%s type=%s group=%s user=%s sender=%s", + len(question), + len(cognitive_query), + query_enhanced, + resolved_request_type or "", + resolved_group_id or "", + resolved_user_id or "", + resolved_sender_id or "", + ) + cognitive_context = await self._cognitive_service.build_context( + query=cognitive_query, + group_id=resolved_group_id, + user_id=resolved_user_id, + sender_id=resolved_sender_id, + sender_name=str(extra_context.get("sender_name", "")) + if extra_context + else None, + group_name=str(extra_context.get("group_name", "")) + if extra_context + else None, + request_type=resolved_request_type or None, + ) + if cognitive_context: + deferred_messages.append( + {"role": "system", "content": cognitive_context} + ) + logger.info( + "[AI会话] 已注入认知记忆上下文: context_len=%s", + len(cognitive_context), + ) + else: + logger.info("[AI会话] 自动检索完成:未命中可注入认知记忆") + + # 额外注入最近 end 行动记录,作为短期“工作记忆”,弥补史官异步入库延迟与向量检索的漏召回。 + if recent_action_inject_k > 0 and self._end_summaries: + items = list(self._end_summaries)[-recent_action_inject_k:] + recent_summary_lines: list[str] = [] + for item in items: + location_text = "" + location = item.get("location") + if isinstance(location, dict): + location_type = location.get("type") + location_name = location.get("name") + if ( + location_type in {"private", "group"} + and isinstance(location_name, str) + and location_name.strip() + ): + location_text = ( + f" ({location_type}: {location_name.strip()})" + ) + recent_summary_lines.append( + f"- [{item.get('timestamp', '')}] {item.get('summary', '')}{location_text}" + ) + recent_summary_text = "\n".join(recent_summary_lines).strip() + if recent_summary_text: + deferred_messages.append( + { + "role": "system", + "content": ( + f"【短期行动记录(最近 {len(items)} 条,带时间)】\n" + f"{recent_summary_text}\n\n" + "注意:以上是你最近在 end 时记录的行动摘要,用于保持短期连续性。" + "它可能与认知记忆事件存在重复;优先以更具体、更近期的描述为准。" + ), + } + ) + elif self._end_summaries: + summary_lines: list[str] = [] + for item in self._end_summaries: + location_text = "" + location = item.get("location") + if isinstance(location, dict): + location_type = location.get("type") + location_name = location.get("name") + if ( + location_type in {"private", "group"} + and isinstance(location_name, str) + and location_name.strip() + ): + location_text = f" ({location_type}: {location_name.strip()})" + summary_lines.append( + f"- [{item['timestamp']}] {item['summary']}{location_text}" + ) + summary_text = "\n".join(summary_lines) + deferred_messages.append( + { + "role": "system", + "content": ( + "【这是你之前end时记录的事情】\n" + f"{summary_text}\n\n" + "注意:以上是你之前在end时记录的事情,用于帮助你记住之前做了什么或以后可能要做什么。" + ), + } + ) + logger.info( + f"[AI会话] 已注入 {len(self._end_summaries)} 条短期回忆 (end 摘要)" + ) + if logger.isEnabledFor(logging.DEBUG): + log_debug_json( + logger, "[AI会话] 注入短期回忆", list(self._end_summaries) + ) + + if get_recent_messages_callback: + await self._inject_recent_messages( + deferred_messages, get_recent_messages_callback, extra_context, question + ) + + # 记忆/认知/历史等上下文统一排在主 system 之后、当前消息之前 + messages.extend(deferred_messages) + + current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S") + messages.append( + { + "role": "system", + "content": f"【当前时间】\n{current_time}\n\n注意:以上是当前的系统时间,供你参考。", + } + ) + + messages.append({"role": "user", "content": f"【当前消息】\n{question}"}) + logger.debug( + "[Prompt] messages_ready=%s question_len=%s", + len(messages), + len(question), + ) + return messages + + def _resolve_chat_scope( + self, extra_context: dict[str, Any] | None + ) -> tuple[Literal["group", "private"], int] | None: + ctx = RequestContext.current() + + # 解析顺序:RequestContext 会话类型 > extra_context 回退 + if ctx and ctx.request_type == "group" and ctx.group_id is not None: + group_id = safe_int(ctx.group_id) + if group_id is not None: + return ("group", group_id) + return None + if ctx and ctx.request_type == "private" and ctx.user_id is not None: + user_id = safe_int(ctx.user_id) + if user_id is not None: + return ("private", user_id) + return None + + if extra_context and extra_context.get("group_id") is not None: + group_id = safe_int(extra_context.get("group_id")) + if group_id is not None: + return ("group", group_id) + return None + if extra_context and extra_context.get("user_id") is not None: + user_id = safe_int(extra_context.get("user_id")) + if user_id is not None: + return ("private", user_id) + return None + + return None + + async def _inject_recent_messages( + self, + messages: list[dict[str, Any]], + get_recent_messages_callback: Callable[ + [str, str, int, int], Awaitable[list[dict[str, Any]]] + ], + extra_context: dict[str, Any] | None, + question: str, + ) -> None: + try: + ctx = RequestContext.current() + if ctx: + group_id_from_ctx = ctx.group_id + user_id_from_ctx = ctx.user_id + elif extra_context: + group_id_from_ctx = extra_context.get("group_id") + user_id_from_ctx = extra_context.get("user_id") + else: + group_id_from_ctx = None + user_id_from_ctx = None + + if group_id_from_ctx is not None: + chat_id = str(group_id_from_ctx) + msg_type = "group" + elif user_id_from_ctx is not None: + chat_id = str(user_id_from_ctx) + msg_type = "private" + else: + chat_id = "" + msg_type = "group" + + recent_limit = 20 + if self._runtime_config_getter is not None: + try: + runtime_config = self._runtime_config_getter() + if hasattr(runtime_config, "get_context_recent_messages_limit"): + recent_limit = int( + runtime_config.get_context_recent_messages_limit() + ) + except Exception as exc: + logger.debug("读取上下文历史条数配置失败: %s", exc) + + if recent_limit < 0: + recent_limit = 0 + if recent_limit == 0: + logger.debug("上下文历史消息注入已关闭 (limit=0)") + return + + recent_msgs = await get_recent_messages_callback( + chat_id, + msg_type, + 0, + recent_limit, + ) + recent_msgs = drop_current_message_if_duplicated(recent_msgs, question) + context_lines: list[str] = [format_message_xml(msg) for msg in recent_msgs] + + formatted_context = "\n---\n".join(context_lines) + + if formatted_context: + messages.append( + { + "role": "user", + "content": ( + "【历史消息存档】\n" + f"{formatted_context}\n\n" + "注意:以上是之前的聊天记录,用于提供背景信息。每个消息之间使用 --- 分隔。接下来的用户消息才是当前正在发生的对话。" + ), + } + ) + logger.debug(f"自动预获取了 {len(context_lines)} 条历史消息作为上下文") + if logger.isEnabledFor(logging.DEBUG): + log_debug_json( + logger, + "[Prompt] 历史消息上下文", + context_lines, + ) + except Exception as exc: + logger.warning(f"自动获取历史消息失败: {exc}") + + +__all__ = ["PromptBuilder"] diff --git a/src/Undefined/ai/prompts/cognitive.py b/src/Undefined/ai/prompts/cognitive.py new file mode 100644 index 00000000..bd799ee0 --- /dev/null +++ b/src/Undefined/ai/prompts/cognitive.py @@ -0,0 +1,137 @@ +"""认知记忆检索查询构建辅助。""" + +from __future__ import annotations + +import html +import logging +from typing import Any + +from Undefined.ai.prompts.constants import ( + COGNITIVE_CONTEXT_VALUE_MAX_LEN, + COGNITIVE_QUERY_SHORT_THRESHOLD, + CURRENT_MESSAGE_RE, + XML_ATTR_RE, +) + +logger = logging.getLogger(__name__) + + +def normalize_cognitive_context_value(value: Any) -> str: + """压缩过长的上下文字段,避免污染检索 query。""" + text = " ".join(str(value or "").split()).strip() + if len(text) <= COGNITIVE_CONTEXT_VALUE_MAX_LEN: + return text + return text[: COGNITIVE_CONTEXT_VALUE_MAX_LEN - 3].rstrip() + "..." + + +def extract_current_message_signature(question: str) -> dict[str, str]: + """从当前消息 XML 中提取 sender/time/content 签名。""" + matched = CURRENT_MESSAGE_RE.search(str(question or "")) + if not matched: + return {} + + attrs_text = str(matched.group("attrs") or "") + attrs: dict[str, str] = {} + for attr_match in XML_ATTR_RE.finditer(attrs_text): + key = str(attr_match.group("key") or "").strip() + if not key: + continue + attrs[key] = html.unescape(str(attr_match.group("value") or "")).strip() + + content = html.unescape(str(matched.group("content") or "")).strip() + return { + "sender_id": attrs.get("sender_id", ""), + "timestamp": attrs.get("time", ""), + "content": content, + } + + +def build_cognitive_query( + question: str, extra_context: dict[str, Any] | None = None +) -> tuple[str, bool]: + """构建认知记忆检索 query,短消息时追加少量会话语境。""" + question_text = str(question or "").strip() + signature = extract_current_message_signature(question_text) + current_content = str(signature.get("content", "")).strip() + base_query = current_content or question_text + if not base_query: + return "", False + + if not current_content or len(current_content) > COGNITIVE_QUERY_SHORT_THRESHOLD: + return base_query, False + + # 短消息检索质量差,追加轻量会话语境提升向量召回 + context_parts: list[str] = [] + if extra_context: + if bool(extra_context.get("is_private_chat", False)): + context_parts.append("会话:私聊") + elif str(extra_context.get("group_id", "")).strip(): + context_parts.append("会话:群聊") + if bool(extra_context.get("is_at_bot", False)): + context_parts.append("触发:@机器人") + + sender_name = normalize_cognitive_context_value( + extra_context.get("sender_name", "") + ) + if sender_name: + context_parts.append(f"发送者:{sender_name}") + + group_name = normalize_cognitive_context_value( + extra_context.get("group_name", "") + ) + if group_name: + context_parts.append(f"群:{group_name}") + + if not context_parts: + return base_query, False + return f"{base_query}\n语境: {'; '.join(context_parts)}", True + + +def drop_current_message_if_duplicated( + recent_msgs: list[dict[str, Any]], question: str +) -> list[dict[str, Any]]: + """若历史末尾与当前帧重复,则剔除最后一条避免双重注入。""" + if not recent_msgs: + return recent_msgs + + signature = extract_current_message_signature(question) + if not signature: + return recent_msgs + + last_msg = recent_msgs[-1] + last_sender_id = str(last_msg.get("user_id", "")).strip() + last_timestamp = str(last_msg.get("timestamp", "")).strip() + last_content = str(last_msg.get("message", "")).strip() + + sig_sender_id = str(signature.get("sender_id", "")).strip() + sig_timestamp = str(signature.get("timestamp", "")).strip() + sig_content = str(signature.get("content", "")).strip() + if not sig_sender_id or not sig_content: + return recent_msgs + + if last_sender_id != sig_sender_id: + return recent_msgs + if last_content != sig_content: + return recent_msgs + + if sig_timestamp and last_timestamp and sig_timestamp != last_timestamp: + # 秒级时间戳不一致时,比较到分钟粒度,避免格式差异误杀 + if sig_timestamp[:16] != last_timestamp[:16]: + return recent_msgs + + logger.info( + "[Prompt] 历史注入剔除当前帧: sender=%s sig_time=%s history_time=%s content_preview=%s", + sig_sender_id, + sig_timestamp, + last_timestamp, + sig_content[:60], + ) + return recent_msgs[:-1] + + +__all__ = [ + "build_cognitive_query", + "drop_current_message_if_duplicated", + "extract_current_message_signature", + "normalize_cognitive_context_value", +] diff --git a/src/Undefined/ai/prompts/constants.py b/src/Undefined/ai/prompts/constants.py new file mode 100644 index 00000000..81f6bd8b --- /dev/null +++ b/src/Undefined/ai/prompts/constants.py @@ -0,0 +1,20 @@ +"""Prompt 构建相关常量与正则。""" + +from __future__ import annotations + +import re + +CURRENT_MESSAGE_RE = re.compile( + r"[^>]*)>.*?(?P.*?).*?", + re.DOTALL | re.IGNORECASE, +) +XML_ATTR_RE = re.compile(r'(?P[a-zA-Z_][a-zA-Z0-9_-]*)="(?P[^"]*)"') +COGNITIVE_QUERY_SHORT_THRESHOLD = 20 # 低于此长度视为短 query,追加语境 +COGNITIVE_CONTEXT_VALUE_MAX_LEN = 18 # 注入检索 query 的单字段上限 + +__all__ = [ + "COGNITIVE_CONTEXT_VALUE_MAX_LEN", + "COGNITIVE_QUERY_SHORT_THRESHOLD", + "CURRENT_MESSAGE_RE", + "XML_ATTR_RE", +] diff --git a/src/Undefined/ai/prompts/system_context.py b/src/Undefined/ai/prompts/system_context.py new file mode 100644 index 00000000..e1ac02e9 --- /dev/null +++ b/src/Undefined/ai/prompts/system_context.py @@ -0,0 +1,165 @@ +"""系统提示词选择与运行环境配置注入。""" + +from __future__ import annotations + +from typing import Any + + +def select_system_prompt_path( + *, + default_path: str, + runtime_config_getter: Any | None, +) -> str: + """根据运行时配置选择系统提示词路径。""" + if runtime_config_getter is None: + return default_path + + runtime_config = None + try: + runtime_config = runtime_config_getter() + except Exception: + runtime_config = None + + enabled = bool(getattr(runtime_config, "nagaagent_mode_enabled", False)) + # NagaAgent 模式切换专用系统提示词模板 + if enabled: + return "res/prompts/undefined_nagaagent.xml" + return "res/prompts/undefined.xml" + + +def build_model_config_info(runtime_config: Any) -> str: + """构建模型配置信息,用于注入到 AI 上下文中。 + + 只暴露非隐私字段(model_name 等),不暴露 api_key、api_url 等敏感信息。 + """ + parts: list[str] = ["【当前运行环境配置】"] + + chat_model = getattr(runtime_config, "chat_model", None) + if chat_model: + model_name = getattr(chat_model, "model_name", "未知") + parts.append(f"- 我使用的模型: {model_name}") + + vision_model = getattr(runtime_config, "vision_model", None) + if vision_model: + model_name = getattr(vision_model, "model_name", "") + if model_name: + parts.append(f"- 视觉模型: {model_name}") + + # Agent 模型 + agent_model = getattr(runtime_config, "agent_model", None) + if agent_model: + model_name = getattr(agent_model, "model_name", "") + if model_name: + parts.append(f"- Agent 模型: {model_name}") + + embedding_model = getattr(runtime_config, "embedding_model", None) + if embedding_model: + model_name = getattr(embedding_model, "model_name", "") + if model_name: + parts.append(f"- 嵌入模型: {model_name}") + + security_model = getattr(runtime_config, "security_model", None) + if security_model: + model_name = getattr(security_model, "model_name", "") + if model_name: + parts.append(f"- 安全模型: {model_name}") + + # Grok 搜索模型 + grok_model = getattr(runtime_config, "grok_model", None) + if grok_model: + model_name = getattr(grok_model, "model_name", "") + if model_name: + parts.append(f"- 搜索模型: {model_name}") + + cognitive = getattr(runtime_config, "cognitive", None) + if cognitive: + enabled = getattr(cognitive, "enabled", False) + parts.append(f"- 认知记忆: {'已启用' if enabled else '未启用'}") + + knowledge_enabled = bool(getattr(runtime_config, "knowledge_enabled", False)) + parts.append(f"- 知识库: {'已启用' if knowledge_enabled else '未启用'}") + + grok_search_enabled = bool(getattr(runtime_config, "grok_search_enabled", False)) + parts.append(f"- 联网搜索: {'已启用' if grok_search_enabled else '未启用'}") + + memes = getattr(runtime_config, "memes", None) + if memes is not None: + memes_enabled = bool(getattr(memes, "enabled", False)) + if memes_enabled: + query_mode = str( + getattr(memes, "query_default_mode", "hybrid") or "hybrid" + ).strip() + allow_gif = bool(getattr(memes, "allow_gif", True)) + max_source_bytes = int(getattr(memes, "max_source_image_bytes", 0) or 0) + max_source_kb = max_source_bytes // 1024 if max_source_bytes > 0 else 0 + parts.append( + f"- 表情包库: 已启用(默认检索={query_mode},GIF={'允许' if allow_gif else '禁用'},入库上限={max_source_kb}KB)" + ) + else: + parts.append("- 表情包库: 未启用") + + if chat_model: + pool = getattr(chat_model, "pool", None) + if pool: + pool_enabled = getattr(pool, "enabled", False) + if pool_enabled: + strategy = getattr(pool, "strategy", "default") + parts.append(f"- 模型池: 已启用({strategy})") + else: + parts.append("- 模型池: 未启用") + + if chat_model: + thinking = getattr(chat_model, "thinking_enabled", False) + reasoning = getattr(chat_model, "reasoning_enabled", False) + if thinking or reasoning: + parts.append("- 思维链: 已启用") + else: + parts.append("- 思维链: 未启用") + + keyword_reply_enabled = bool( + getattr(runtime_config, "keyword_reply_enabled", False) + ) + repeat_enabled = bool(getattr(runtime_config, "repeat_enabled", False)) + inverted_question_enabled = bool( + getattr(runtime_config, "inverted_question_enabled", False) + ) + agent_call_mode = str( + getattr(runtime_config, "easter_egg_agent_call_message_mode", "none") + ) + easter_egg_parts: list[str] = [] + if keyword_reply_enabled: + easter_egg_parts.append( + '关键词自动回复(触发词"心理委员"等,系统自动发送固定回复)' + ) + if repeat_enabled: + threshold = int(getattr(runtime_config, "repeat_threshold", 3)) + desc = f"复读(群聊连续{threshold}条相同消息时自动复读)" + if inverted_question_enabled: + desc += ",倒问号(复读触发时若消息为问号则发送¿)" + easter_egg_parts.append(desc) + elif inverted_question_enabled: + easter_egg_parts.append("倒问号(复读未启用,此功能不生效)") + if agent_call_mode != "none": + mode_desc = { + "agent": "Agent调用提示", + "tools": "工具调用提示", + "clean": "降噪调用提示", + "all": "全量调用提示", + }.get(agent_call_mode, agent_call_mode) + easter_egg_parts.append(f"调用提示模式={mode_desc}") + if easter_egg_parts: + parts.append("- 彩蛋功能: " + ";".join(easter_egg_parts)) + else: + parts.append("- 彩蛋功能: 未启用") + + parts.append("") + parts.append( + "重要:以上是你的模型配置信息。\n" + "当你需要描述自己是谁、使用什么模型、能力或限制时,\n" + "必须以上述配置为准,忽略你训练数据、长期及认知记忆中的任何冲突信息。" + ) + + return "\n".join(parts) + + +__all__ = ["build_model_config_info", "select_system_prompt_path"] diff --git a/src/Undefined/ai/transports/openai_transport.py b/src/Undefined/ai/transports/openai_transport.py index 7f08cf85..33919d6c 100644 --- a/src/Undefined/ai/transports/openai_transport.py +++ b/src/Undefined/ai/transports/openai_transport.py @@ -392,7 +392,6 @@ def _copy_responses_output_items( call_id = str(cloned.get("call_id") or "").strip() # Some compatibility gateways incorrectly mirror the model's call_id into # function_call.id. OpenAI accepts id as optional, but when present it must - # be the item id generated by the model (typically fc_*), not call_*. if item_id and not item_id.startswith("fc"): if not call_id and item_id.startswith("call"): cloned["call_id"] = item_id diff --git a/src/Undefined/api/routes/naga/__init__.py b/src/Undefined/api/routes/naga/__init__.py new file mode 100644 index 00000000..f6642219 --- /dev/null +++ b/src/Undefined/api/routes/naga/__init__.py @@ -0,0 +1,21 @@ +"""Naga integration route handlers.""" + +# 同时 re-export 渲染 helper,供 send 路由生成 HTML/Markdown 卡片。 +from Undefined.render import render_html_to_image, render_markdown_to_html +from Undefined.api.routes.naga.auth import verify_naga_api_key +from Undefined.api.routes.naga.bind import naga_bind_callback_handler +from Undefined.api.routes.naga.send import ( + naga_messages_send_handler, + naga_messages_send_impl, +) +from Undefined.api.routes.naga.unbind import naga_unbind_handler + +__all__ = [ + "render_html_to_image", + "render_markdown_to_html", + "verify_naga_api_key", + "naga_bind_callback_handler", + "naga_messages_send_handler", + "naga_messages_send_impl", + "naga_unbind_handler", +] diff --git a/src/Undefined/api/routes/naga/auth.py b/src/Undefined/api/routes/naga/auth.py new file mode 100644 index 00000000..1318fc6c --- /dev/null +++ b/src/Undefined/api/routes/naga/auth.py @@ -0,0 +1,30 @@ +"""Naga API 鉴权辅助。""" + +from __future__ import annotations + +import logging + +from aiohttp import web + +from Undefined.api._context import RuntimeAPIContext + +logger = logging.getLogger(__name__) + + +# 校验 Naga 共享密钥,返回错误信息或 ``None`` 表示通过 +def verify_naga_api_key(ctx: RuntimeAPIContext, request: web.Request) -> str | None: + """校验 Naga 共享密钥,返回错误信息或 ``None`` 表示通过。""" + import secrets as _secrets + + cfg = ctx.config_getter() + expected = cfg.naga.api_key + if not expected: + return "naga api_key not configured" + auth_header = request.headers.get("Authorization", "") + if not auth_header.startswith("Bearer "): + return "missing or invalid Authorization header" + provided = auth_header[7:] + # 常量时间比较,避免时序侧信道泄露密钥。 + if not _secrets.compare_digest(provided, expected): + return "invalid api_key" + return None diff --git a/src/Undefined/api/routes/naga/bind.py b/src/Undefined/api/routes/naga/bind.py new file mode 100644 index 00000000..d6f4d756 --- /dev/null +++ b/src/Undefined/api/routes/naga/bind.py @@ -0,0 +1,159 @@ +"""Naga 绑定回调路由。""" + +from __future__ import annotations + +import logging +import uuid as _uuid + +from aiohttp import web +from aiohttp.web_response import Response + +from Undefined.api._context import RuntimeAPIContext +from Undefined.api._helpers import ( + _json_error, + _short_text_preview, +) +from Undefined.api.routes.naga.auth import verify_naga_api_key + +logger = logging.getLogger(__name__) + +# ------------------------------------------------------------------ +# POST /api/v1/naga/bind/callback +# ------------------------------------------------------------------ + +# ------------------------------------------------------------------ + + +# POST /api/v1/naga/bind/callback — Naga 绑定回调 +async def naga_bind_callback_handler( + ctx: RuntimeAPIContext, request: web.Request +) -> Response: + """POST /api/v1/naga/bind/callback — Naga 绑定回调。""" + trace_id = _uuid.uuid4().hex[:8] + auth_err = verify_naga_api_key(ctx, request) + if auth_err is not None: + logger.warning( + "[NagaBindCallback] 鉴权失败: trace=%s remote=%s err=%s", + trace_id, + getattr(request, "remote", None), + auth_err, + ) + return _json_error("Unauthorized", status=401) + + try: + body = await request.json() + except Exception: + return _json_error("Invalid JSON", status=400) + + bind_uuid = str(body.get("bind_uuid", "") or "").strip() + naga_id = str(body.get("naga_id", "") or "").strip() + status = str(body.get("status", "") or "").strip().lower() + delivery_signature = str(body.get("delivery_signature", "") or "").strip() + reason = str(body.get("reason", "") or "").strip() + if not bind_uuid or not naga_id: + return _json_error("bind_uuid and naga_id are required", status=400) + if status not in {"approved", "rejected"}: + return _json_error("status must be 'approved' or 'rejected'", status=400) + logger.info( + "[NagaBindCallback] 请求开始: trace=%s remote=%s naga_id=%s bind_uuid=%s status=%s reason=%s signature=%s", + trace_id, + getattr(request, "remote", None), + naga_id, + bind_uuid, + status, + _short_text_preview(reason, limit=60), + delivery_signature[:12] + "..." if delivery_signature else "", + ) + + naga_store = ctx.naga_store + if naga_store is None: + return _json_error("Naga integration not available", status=503) + + sender = ctx.sender + if status == "approved": + if not delivery_signature: + return _json_error( + "delivery_signature is required when approved", status=400 + ) + # 激活绑定:写入 delivery_signature 并移出 pending 队列。 + binding, created, err = await naga_store.activate_binding( + bind_uuid=bind_uuid, + naga_id=naga_id, + delivery_signature=delivery_signature, + ) + if err: + logger.warning( + "[NagaBindCallback] 激活失败: trace=%s naga_id=%s bind_uuid=%s err=%s", + trace_id, + naga_id, + bind_uuid, + err.message, + ) + return _json_error(err.message, status=err.http_status) + logger.info( + "[NagaBindCallback] 激活完成: trace=%s naga_id=%s bind_uuid=%s created=%s qq=%s", + trace_id, + naga_id, + bind_uuid, + created, + binding.qq_id if binding is not None else "", + ) + if created and binding is not None and sender is not None: + try: + await sender.send_private_message( + binding.qq_id, + f"🎉 你的 Naga 绑定已生效\nnaga_id: {naga_id}", + ) + except Exception as exc: + logger.warning("[NagaBindCallback] 通知绑定成功失败: %s", exc) + return web.json_response( + { + "ok": True, + "status": "approved", + "idempotent": not created, + "naga_id": naga_id, + "bind_uuid": bind_uuid, + } + ) + + # --- rejected --- + pending, removed, err = await naga_store.reject_binding( + bind_uuid=bind_uuid, + naga_id=naga_id, + reason=reason, + ) + if err: + logger.warning( + "[NagaBindCallback] 拒绝失败: trace=%s naga_id=%s bind_uuid=%s err=%s", + trace_id, + naga_id, + bind_uuid, + err.message, + ) + return _json_error(err.message, status=err.http_status) + logger.info( + "[NagaBindCallback] 拒绝完成: trace=%s naga_id=%s bind_uuid=%s removed=%s qq=%s", + trace_id, + naga_id, + bind_uuid, + removed, + pending.qq_id if pending is not None else "", + ) + if removed and pending is not None and sender is not None: + try: + detail = f"\n原因: {reason}" if reason else "" + await sender.send_private_message( + pending.qq_id, + f"❌ 你的 Naga 绑定被远端拒绝\nnaga_id: {naga_id}{detail}", + ) + except Exception as exc: + logger.warning("[NagaBindCallback] 通知绑定拒绝失败: %s", exc) + return web.json_response( + { + "ok": True, + "status": "rejected", + "idempotent": not removed, + "naga_id": naga_id, + "bind_uuid": bind_uuid, + } + ) diff --git a/src/Undefined/api/routes/naga/send.py b/src/Undefined/api/routes/naga/send.py new file mode 100644 index 00000000..b14e3e54 --- /dev/null +++ b/src/Undefined/api/routes/naga/send.py @@ -0,0 +1,665 @@ +"""Naga 消息发送路由与实现。""" + +from __future__ import annotations + +import logging +import os +import uuid as _uuid +from copy import deepcopy +from pathlib import Path +from typing import Any + +from aiohttp import web +from aiohttp.web_response import Response + +from Undefined.api._context import RuntimeAPIContext +from Undefined.api._helpers import ( + _json_error, + _naga_message_digest, + _parse_response_payload, + _short_text_preview, +) +from Undefined.api._naga_state import NagaState + +from Undefined.api.routes.naga.auth import verify_naga_api_key + +logger = logging.getLogger(__name__) + +# ------------------------------------------------------------------ +# POST /api/v1/naga/messages/send +# ------------------------------------------------------------------ + +# ------------------------------------------------------------------ + + +# POST /api/v1/naga/messages/send — 验签后发送消息 +async def naga_messages_send_handler( + ctx: RuntimeAPIContext, + naga_state: NagaState, + request: web.Request, +) -> Response: + """POST /api/v1/naga/messages/send — 验签后发送消息。""" + from Undefined.api.naga_store import mask_token + + trace_id = _uuid.uuid4().hex[:8] + auth_err = verify_naga_api_key(ctx, request) + if auth_err is not None: + logger.warning("[NagaSend] 鉴权失败: trace=%s err=%s", trace_id, auth_err) + return _json_error("Unauthorized", status=401) + + try: + body = await request.json() + except Exception: + return _json_error("Invalid JSON", status=400) + + bind_uuid = str(body.get("bind_uuid", "") or "").strip() + naga_id = str(body.get("naga_id", "") or "").strip() + delivery_signature = str(body.get("delivery_signature", "") or "").strip() + request_uuid = str(body.get("uuid", "") or "").strip() + target = body.get("target") + message = body.get("message") + if not bind_uuid or not naga_id or not delivery_signature: + return _json_error( + "bind_uuid, naga_id and delivery_signature are required", + status=400, + ) + if not isinstance(target, dict): + return _json_error("target object is required", status=400) + if not isinstance(message, dict): + return _json_error("message object is required", status=400) + + raw_target_qq = target.get("qq_id") + raw_target_group = target.get("group_id") + if raw_target_qq is None or raw_target_group is None: + return _json_error("target.qq_id and target.group_id are required", status=400) + try: + target_qq = int(raw_target_qq) + target_group = int(raw_target_group) + except Exception: + return _json_error( + "target.qq_id and target.group_id must be integers", status=400 + ) + mode = str(target.get("mode", "") or "").strip().lower() + if mode not in {"private", "group", "both"}: + return _json_error( + # "target.mode must be 'private', 'group', or 'both'", status=... + "target.mode must be 'private', 'group', or 'both'", + status=400, + ) + + fmt = str(message.get("format", "text") or "text").strip().lower() + content = str(message.get("content", "") or "").strip() + if fmt not in {"text", "markdown", "html"}: + return _json_error( + "message.format must be 'text', 'markdown', or 'html'", status=400 + ) + if not content: + return _json_error("message.content is required", status=400) + + message_key = _naga_message_digest( + bind_uuid=bind_uuid, + naga_id=naga_id, + target_qq=target_qq, + target_group=target_group, + mode=mode, + message_format=fmt, + content=content, + ) + # message_key 用于并发计数与 request_uuid 幂等,相同 payload 共享同一键。 + logger.info( + "[NagaSend] 请求开始: trace=%s remote=%s naga_id=%s bind_uuid=%s request_uuid=%s mode=%s fmt=%s qq=%s group=%s key=%s content_len=%s preview=%s signature=%s", + trace_id, + getattr(request, "remote", None), + naga_id, + bind_uuid, + request_uuid, + mode, + fmt, + target_qq, + target_group, + message_key, + len(content), + _short_text_preview(content), + mask_token(delivery_signature), + ) + if mode == "both": + logger.warning( + "[NagaSend] 上游请求显式要求双路投递: trace=%s naga_id=%s bind_uuid=%s request_uuid=%s key=%s", + trace_id, + naga_id, + bind_uuid, + request_uuid, + message_key, + ) + inflight_count = await naga_state.track_send_start(message_key) + if inflight_count > 1: + logger.warning( + "[NagaSend] 检测到相同 payload 并发请求: trace=%s naga_id=%s bind_uuid=%s request_uuid=%s key=%s inflight=%s", + trace_id, + naga_id, + bind_uuid, + request_uuid, + message_key, + inflight_count, + ) + try: + if request_uuid: + # 可选 uuid 启用幂等:冲突/缓存/等待/owner 四态由 NagaState 协调。 + dedupe_action, dedupe_value = await naga_state.register_request_uuid( + request_uuid, message_key + ) + if dedupe_action == "conflict": + logger.warning( + "[NagaSend] uuid 与历史 payload 冲突: trace=%s naga_id=%s bind_uuid=%s uuid=%s key=%s", + trace_id, + naga_id, + bind_uuid, + request_uuid, + message_key, + ) + return _json_error("uuid reused with different payload", status=409) + if dedupe_action == "cached": + cached_status, cached_payload = dedupe_value + logger.warning( + "[NagaSend] 命中已完成幂等结果,直接复用: trace=%s naga_id=%s bind_uuid=%s request_uuid=%s key=%s", + trace_id, + naga_id, + bind_uuid, + request_uuid, + message_key, + ) + return web.json_response( + deepcopy(cached_payload), + status=int(cached_status), + ) + if dedupe_action == "await": + wait_future = dedupe_value + logger.warning( + "[NagaSend] 命中进行中幂等请求,等待首个结果: trace=%s naga_id=%s bind_uuid=%s request_uuid=%s key=%s", + trace_id, + naga_id, + bind_uuid, + request_uuid, + message_key, + ) + cached_status, cached_payload = await wait_future + return web.json_response( + deepcopy(cached_payload), + status=int(cached_status), + ) + + response = await naga_messages_send_impl( + ctx, + naga_id=naga_id, + bind_uuid=bind_uuid, + delivery_signature=delivery_signature, + target_qq=target_qq, + target_group=target_group, + mode=mode, + message_format=fmt, + content=content, + trace_id=trace_id, + message_key=message_key, + ) + if request_uuid: + await naga_state.finish_request_uuid( + request_uuid, + message_key, + status=response.status, + payload=_parse_response_payload(response), + ) + return response + except Exception as exc: + if request_uuid: + await naga_state.fail_request_uuid(request_uuid, message_key, exc) + raise + finally: + remaining = await naga_state.track_send_done(message_key) + logger.info( + "[NagaSend] 请求退出: trace=%s naga_id=%s bind_uuid=%s request_uuid=%s key=%s inflight_remaining=%s", + trace_id, + naga_id, + bind_uuid, + request_uuid, + message_key, + remaining, + ) + + +# ------------------------------------------------------------------ +# Core send implementation +# ------------------------------------------------------------------ + +# Core send implementation (no NagaState dependency) +# ------------------------------------------------------------------ + + +async def naga_messages_send_impl( + ctx: RuntimeAPIContext, + *, + naga_id: str, + bind_uuid: str, + delivery_signature: str, + target_qq: int, + target_group: int, + mode: str, + message_format: str, + content: str, + trace_id: str, + message_key: str, +) -> Response: + from Undefined.api.naga_store import mask_token + + naga_store = ctx.naga_store + if naga_store is None: + logger.warning( + "[NagaSend] NagaStore 不可用: trace=%s naga_id=%s bind_uuid=%s", + trace_id, + naga_id, + bind_uuid, + ) + return _json_error("Naga integration not available", status=503) + + binding, err_msg = await naga_store.acquire_delivery( + naga_id=naga_id, + bind_uuid=bind_uuid, + delivery_signature=delivery_signature, + ) + if binding is None: + logger.warning( + "[NagaSend] 签名校验失败: trace=%s naga_id=%s bind_uuid=%s reason=%s signature=%s", + trace_id, + naga_id, + bind_uuid, + err_msg.message if err_msg is not None else "unknown_error", + mask_token(delivery_signature), + ) + return _json_error( + err_msg.message if err_msg is not None else "delivery not available", + status=err_msg.http_status if err_msg is not None else 403, + ) + + logger.info( + "[NagaSend] 投递凭证已占用: trace=%s naga_id=%s bind_uuid=%s key=%s qq=%s group=%s", + trace_id, + naga_id, + bind_uuid, + message_key, + binding.qq_id, + binding.group_id, + ) + try: + if target_qq != binding.qq_id or target_group != binding.group_id: + logger.warning( + "[NagaSend] 目标不匹配: trace=%s naga_id=%s bind_uuid=%s target_qq=%s target_group=%s bound_qq=%s bound_group=%s", + trace_id, + naga_id, + bind_uuid, + target_qq, + target_group, + binding.qq_id, + binding.group_id, + ) + return _json_error("target does not match bound qq/group", status=403) + + cfg = ctx.config_getter() + if mode == "group" and binding.group_id not in cfg.naga.allowed_groups: + logger.warning( + "[NagaSend] 群投递被策略拒绝: trace=%s naga_id=%s bind_uuid=%s group=%s", + trace_id, + naga_id, + bind_uuid, + binding.group_id, + ) + return _json_error("bound group is not in naga.allowed_groups", status=403) + + sender = ctx.sender + if sender is None: + logger.warning( + "[NagaSend] sender 不可用: trace=%s naga_id=%s bind_uuid=%s", + trace_id, + naga_id, + bind_uuid, + ) + return _json_error("sender not available", status=503) + + moderation: dict[str, Any] + naga_cfg = getattr(cfg, "naga", None) + moderation_enabled = bool(getattr(naga_cfg, "moderation_enabled", True)) + security = getattr(ctx.command_dispatcher, "security", None) + if not moderation_enabled: + moderation = { + "status": "skipped_disabled", + "blocked": False, + "categories": [], + "message": "Naga moderation disabled by config; message sent without moderation block", + "model_name": "", + } + logger.warning( + "[NagaSend] 审核已禁用,直接放行: trace=%s naga_id=%s bind_uuid=%s key=%s", + trace_id, + naga_id, + bind_uuid, + message_key, + ) + elif security is None or not hasattr(security, "moderate_naga_message"): + moderation = { + "status": "error_allowed", + "blocked": False, + "categories": [], + "message": "Naga moderation service unavailable; message sent without moderation block", + "model_name": "", + } + logger.warning( + "[NagaSend] 审核服务不可用,按允许发送: trace=%s naga_id=%s bind_uuid=%s", + trace_id, + naga_id, + bind_uuid, + ) + else: + logger.info( + "[NagaSend] 审核开始: trace=%s naga_id=%s bind_uuid=%s key=%s fmt=%s content_len=%s", + trace_id, + naga_id, + bind_uuid, + message_key, + message_format, + len(content), + ) + result = await security.moderate_naga_message( + message_format=message_format, + content=content, + ) + moderation = { + "status": result.status, + "blocked": result.blocked, + "categories": result.categories, + "message": result.message, + "model_name": result.model_name, + } + logger.info( + "[NagaSend] 审核完成: trace=%s naga_id=%s bind_uuid=%s key=%s blocked=%s status=%s model=%s categories=%s", + trace_id, + naga_id, + bind_uuid, + message_key, + result.blocked, + result.status, + result.model_name, + ",".join(result.categories) or "-", + ) + if moderation["blocked"]: + logger.warning( + "[NagaSend] 审核拦截: trace=%s naga_id=%s bind_uuid=%s key=%s reason=%s", + trace_id, + naga_id, + bind_uuid, + message_key, + moderation["message"], + ) + return web.json_response( + { + "ok": False, + "error": "message blocked by moderation", + "moderation": moderation, + }, + status=403, + ) + + send_content: str | None = content if message_format == "text" else None + image_path: str | None = None + tmp_path: str | None = None + rendered = False + render_fallback = False + if message_format in {"markdown", "html"}: + import tempfile + + from Undefined.api.routes import naga as naga_routes + + try: + html_str = content + if message_format == "markdown": + html_str = await naga_routes.render_markdown_to_html(content) + fd, tmp_path = tempfile.mkstemp(suffix=".png", prefix="naga_send_") + os.close(fd) + await naga_routes.render_html_to_image(html_str, tmp_path) + image_path = tmp_path + rendered = True + logger.info( + "[NagaSend] 富文本渲染成功: trace=%s naga_id=%s bind_uuid=%s key=%s fmt=%s image=%s", + trace_id, + naga_id, + bind_uuid, + message_key, + message_format, + Path(tmp_path).name if tmp_path is not None else "", + ) + except Exception as exc: + logger.warning( + "[NagaSend] 渲染失败,回退文本发送: trace=%s naga_id=%s bind_uuid=%s key=%s err=%s", + trace_id, + naga_id, + bind_uuid, + message_key, + exc, + ) + send_content = content + render_fallback = True + + sent_private = False + sent_group = False + group_policy_blocked = False + + async def _ensure_delivery_active() -> tuple[Any, Response | None]: + current_binding, live_err = await naga_store.ensure_delivery_active( + naga_id=naga_id, + bind_uuid=bind_uuid, + ) + if current_binding is None: + logger.warning( + "[NagaSend] 投递中止: trace=%s naga_id=%s bind_uuid=%s key=%s reason=%s", + trace_id, + naga_id, + bind_uuid, + message_key, + live_err.message + if live_err is not None + else "delivery no longer active", + ) + return None, web.json_response( + { + "ok": False, + "error": ( + live_err.message + if live_err is not None + else "delivery no longer active" + ), + "sent_private": sent_private, + "sent_group": sent_group, + "moderation": moderation, + }, + status=live_err.http_status if live_err is not None else 409, + ) + return current_binding, None + + try: + cq_image: str | None = None + if image_path is not None: + file_uri = Path(image_path).resolve().as_uri() + cq_image = f"[CQ:image,file={file_uri}]" + + if mode in {"private", "both"}: + current_binding, abort_response = await _ensure_delivery_active() + if abort_response is not None: + return abort_response + logger.info( + "[NagaSend] 私聊投递开始: trace=%s naga_id=%s bind_uuid=%s key=%s qq=%s", + trace_id, + naga_id, + bind_uuid, + message_key, + current_binding.qq_id, + ) + try: + if send_content is not None: + await sender.send_private_message( + current_binding.qq_id, send_content + ) + elif cq_image is not None: + await sender.send_private_message( + current_binding.qq_id, cq_image + ) + sent_private = True + logger.info( + "[NagaSend] 私聊投递成功: trace=%s naga_id=%s bind_uuid=%s key=%s qq=%s", + trace_id, + naga_id, + bind_uuid, + message_key, + current_binding.qq_id, + ) + except Exception as exc: + logger.warning( + "[NagaSend] 私聊发送失败: trace=%s naga_id=%s qq=%d key=%s err=%s", + trace_id, + naga_id, + current_binding.qq_id, + message_key, + exc, + ) + + if mode in {"group", "both"}: + current_binding, abort_response = await _ensure_delivery_active() + if abort_response is not None: + return abort_response + current_cfg = ctx.config_getter() + if current_binding.group_id not in current_cfg.naga.allowed_groups: + group_policy_blocked = True + logger.warning( + "[NagaSend] 群投递被策略阻止: trace=%s naga_id=%s bind_uuid=%s key=%s group=%s", + trace_id, + naga_id, + bind_uuid, + message_key, + current_binding.group_id, + ) + else: + logger.info( + "[NagaSend] 群投递开始: trace=%s naga_id=%s bind_uuid=%s key=%s group=%s", + trace_id, + naga_id, + bind_uuid, + message_key, + current_binding.group_id, + ) + try: + if send_content is not None: + await sender.send_group_message( + current_binding.group_id, send_content + ) + elif cq_image is not None: + await sender.send_group_message( + current_binding.group_id, cq_image + ) + sent_group = True + logger.info( + "[NagaSend] 群投递成功: trace=%s naga_id=%s bind_uuid=%s key=%s group=%s", + trace_id, + naga_id, + bind_uuid, + message_key, + current_binding.group_id, + ) + except Exception as exc: + logger.warning( + "[NagaSend] 群聊发送失败: trace=%s naga_id=%s group=%d key=%s err=%s", + trace_id, + naga_id, + current_binding.group_id, + message_key, + exc, + ) + finally: + if tmp_path is not None: + try: + os.unlink(tmp_path) + except OSError: + pass + + if mode == "private" and not sent_private: + return web.json_response( + { + "ok": False, + "error": "private delivery failed", + "sent_private": sent_private, + "sent_group": sent_group, + "moderation": moderation, + }, + status=502, + ) + if mode == "group" and not sent_group: + return web.json_response( + { + "ok": False, + "error": "group delivery failed", + "sent_private": sent_private, + "sent_group": sent_group, + "moderation": moderation, + }, + status=502, + ) + if mode == "both" and not (sent_private or sent_group): + if group_policy_blocked: + return web.json_response( + { + "ok": False, + "error": "bound group is not in naga.allowed_groups", + "sent_private": sent_private, + "sent_group": sent_group, + "moderation": moderation, + }, + status=403, + ) + return web.json_response( + { + "ok": False, + "error": "all deliveries failed", + "sent_private": sent_private, + "sent_group": sent_group, + "moderation": moderation, + }, + status=502, + ) + + await naga_store.record_usage(naga_id, bind_uuid=bind_uuid) + partial_success = mode == "both" and (sent_private != sent_group) + logger.info( + "[NagaSend] 请求完成: trace=%s naga_id=%s bind_uuid=%s key=%s sent_private=%s sent_group=%s partial=%s rendered=%s fallback=%s", + trace_id, + naga_id, + bind_uuid, + message_key, + sent_private, + sent_group, + partial_success, + rendered, + render_fallback, + ) + return web.json_response( + { + "ok": True, + "naga_id": naga_id, + "bind_uuid": bind_uuid, + "sent_private": sent_private, + "sent_group": sent_group, + "partial_success": partial_success, + "delivery_status": ( + "partial_success" if partial_success else "full_success" + ), + "rendered": rendered, + "render_fallback": render_fallback, + "moderation": moderation, + } + ) + finally: + await naga_store.release_delivery(bind_uuid=bind_uuid) diff --git a/src/Undefined/api/routes/naga/unbind.py b/src/Undefined/api/routes/naga/unbind.py new file mode 100644 index 00000000..5b4493ba --- /dev/null +++ b/src/Undefined/api/routes/naga/unbind.py @@ -0,0 +1,100 @@ +"""Naga 解绑路由。""" + +from __future__ import annotations + +import logging +import uuid as _uuid + +from aiohttp import web +from aiohttp.web_response import Response + +from Undefined.api._context import RuntimeAPIContext +from Undefined.api._helpers import ( + _json_error, +) +from Undefined.api.routes.naga.auth import verify_naga_api_key + +logger = logging.getLogger(__name__) + +# ------------------------------------------------------------------ +# POST /api/v1/naga/unbind +# ------------------------------------------------------------------ + +# ------------------------------------------------------------------ + + +# POST /api/v1/naga/unbind — 远端主动解绑 +async def naga_unbind_handler(ctx: RuntimeAPIContext, request: web.Request) -> Response: + """POST /api/v1/naga/unbind — 远端主动解绑。""" + trace_id = _uuid.uuid4().hex[:8] + auth_err = verify_naga_api_key(ctx, request) + if auth_err is not None: + logger.warning( + "[NagaUnbind] 鉴权失败: trace=%s remote=%s err=%s", + trace_id, + getattr(request, "remote", None), + auth_err, + ) + return _json_error("Unauthorized", status=401) + + try: + body = await request.json() + except Exception: + return _json_error("Invalid JSON", status=400) + + bind_uuid = str(body.get("bind_uuid", "") or "").strip() + naga_id = str(body.get("naga_id", "") or "").strip() + delivery_signature = str(body.get("delivery_signature", "") or "").strip() + if not bind_uuid or not naga_id or not delivery_signature: + return _json_error( + "bind_uuid, naga_id and delivery_signature are required", + status=400, + ) + logger.info( + "[NagaUnbind] 请求开始: trace=%s remote=%s naga_id=%s bind_uuid=%s signature=%s", + trace_id, + getattr(request, "remote", None), + naga_id, + bind_uuid, + delivery_signature[:12] + "...", + ) + + naga_store = ctx.naga_store + if naga_store is None: + return _json_error("Naga integration not available", status=503) + + # 解绑时等待在途投递完成,避免消息发到已吊销绑定。 + binding, changed, err = await naga_store.revoke_binding( + naga_id, + expected_bind_uuid=bind_uuid, + delivery_signature=delivery_signature, + ) + if binding is None: + logger.warning( + "[NagaUnbind] 吊销失败: trace=%s naga_id=%s bind_uuid=%s err=%s", + trace_id, + naga_id, + bind_uuid, + err.message if err is not None else "binding not found", + ) + return _json_error( + err.message if err is not None else "binding not found", + status=err.http_status if err is not None else 404, + ) + logger.info( + "[NagaUnbind] 吊销完成: trace=%s naga_id=%s bind_uuid=%s changed=%s qq=%s group=%s", + trace_id, + naga_id, + bind_uuid, + changed, + binding.qq_id, + binding.group_id, + ) + return web.json_response( + { + "ok": True, + "idempotent": not changed, + "naga_id": naga_id, + "bind_uuid": bind_uuid, + } + ) diff --git a/src/Undefined/attachments/__init__.py b/src/Undefined/attachments/__init__.py new file mode 100644 index 00000000..ad4d1a56 --- /dev/null +++ b/src/Undefined/attachments/__init__.py @@ -0,0 +1,43 @@ +"""附件注册表与富媒体消息辅助工具包。 + +聚合 models、segments、registry、render 子模块的公开 API; +下游可 ``from Undefined.attachments import AttachmentRegistry`` 等。 +""" + +from Undefined.attachments.models import ( + AttachmentRecord, + AttachmentRenderError, + RegisteredMessageAttachments, + RenderedRichMessage, +) +from Undefined.attachments.registry import AttachmentRegistry +from Undefined.attachments.render import ( + dispatch_pending_file_sends, + render_message_with_attachments, + render_message_with_pic_placeholders, +) +from Undefined.attachments.segments import ( + append_attachment_text, + attachment_refs_to_text, + attachment_refs_to_xml, + build_attachment_scope, + register_message_attachments, + scope_from_context, +) + +__all__ = [ + "AttachmentRecord", + "AttachmentRegistry", + "AttachmentRenderError", + "RegisteredMessageAttachments", + "RenderedRichMessage", + "append_attachment_text", + "attachment_refs_to_text", + "attachment_refs_to_xml", + "build_attachment_scope", + "dispatch_pending_file_sends", + "register_message_attachments", + "render_message_with_attachments", + "render_message_with_pic_placeholders", + "scope_from_context", +] diff --git a/src/Undefined/attachments/models.py b/src/Undefined/attachments/models.py new file mode 100644 index 00000000..fd7c58d3 --- /dev/null +++ b/src/Undefined/attachments/models.py @@ -0,0 +1,93 @@ +"""附件领域模型与渲染异常类型。 + +定义 ``AttachmentRecord`` 等不可变数据类及 ``AttachmentRenderError``; +不含注册、解析或 CQ 渲染逻辑。 +""" + +from __future__ import annotations + +from dataclasses import dataclass +from pathlib import Path + + +@dataclass(frozen=True) +class AttachmentRecord: + """单条附件的持久化记录。 + + 由 ``AttachmentRegistry`` 写入磁盘并在消息渲染时按 UID 解析; + ``prompt_ref()`` 供 LLM 上下文引用本地可用或远程 URL 附件。 + """ + + uid: str + scope_key: str + kind: str + media_type: str + display_name: str + source_kind: str + source_ref: str + local_path: str | None + mime_type: str + sha256: str + created_at: str + segment_data: dict[str, str] + semantic_kind: str = "" + description: str = "" + + def prompt_ref(self) -> dict[str, str]: + """构建供提示词/历史引用的精简附件字典。 + + Returns: + 含 ``uid``、``kind``、``media_type`` 等字段的字典; + 本地文件不可用时回退 ``source_ref``。 + """ + local_available = False + if self.local_path is not None: + try: + local_available = Path(self.local_path).is_file() + except OSError: + local_available = False + ref: dict[str, str] = { + "uid": self.uid, + "kind": self.kind, + "media_type": self.media_type, + "display_name": self.display_name, + } + if self.source_kind.strip(): + ref["source_kind"] = self.source_kind.strip() + # 本地文件缺失时回退 source_ref,供 LLM 引用远程 URL + if not local_available and self.source_ref.strip(): + ref["source_ref"] = self.source_ref.strip() + if self.semantic_kind.strip(): + ref["semantic_kind"] = self.semantic_kind.strip() + if self.description.strip(): + ref["description"] = self.description.strip() + return ref + + +@dataclass(frozen=True) +class RegisteredMessageAttachments: + """OneBot 消息段注册附件后的归一化结果。""" + + attachments: list[dict[str, str]] + normalized_text: str + + +@dataclass(frozen=True) +class RenderedRichMessage: + """富媒体标签渲染后的投递与历史文本。""" + + delivery_text: str + history_text: str + attachments: list[dict[str, str]] + pending_file_sends: tuple[AttachmentRecord, ...] = () + + +class AttachmentRenderError(RuntimeError): + """附件标签无法渲染时抛出(``strict=True`` 场景)。""" + + +class _RemoteAttachmentTooLarge(Exception): + """远程下载超过字节上限时由 registry 内部捕获。""" + + def __init__(self, mime_type: str = "") -> None: + self.mime_type = mime_type diff --git a/src/Undefined/attachments/registry.py b/src/Undefined/attachments/registry.py new file mode 100644 index 00000000..1e958f90 --- /dev/null +++ b/src/Undefined/attachments/registry.py @@ -0,0 +1,903 @@ +"""附件持久化注册表。 + +负责本地缓存、远程下载、去重与 scope 隔离;由 handlers 与 AI 协调器持有进程级单例。 +""" + +from __future__ import annotations + +import asyncio +import base64 +import hashlib +import logging +import mimetypes +from dataclasses import asdict, replace +from datetime import datetime +from pathlib import Path +import time +from typing import Any, Awaitable, Callable, Mapping +from uuid import uuid4 + +import httpx + +from Undefined.attachments.models import AttachmentRecord, _RemoteAttachmentTooLarge +from Undefined.attachments.segments import ( + display_name_from_source, + is_http_url, + media_kind_from_value, + scope_from_context, +) +from Undefined.utils import io +from Undefined.utils.paths import ( + ATTACHMENT_CACHE_DIR, + ATTACHMENT_REGISTRY_FILE, + ensure_dir, +) + +logger = logging.getLogger(__name__) + +_DEFAULT_REMOTE_TIMEOUT_SECONDS = 120.0 +_IMAGE_SUFFIX_TO_MIME = { + ".png": "image/png", + ".jpg": "image/jpeg", + ".jpeg": "image/jpeg", + ".gif": "image/gif", + ".webp": "image/webp", + ".bmp": "image/bmp", + ".svg": "image/svg+xml", +} +_MAGIC_IMAGE_SUFFIXES: tuple[tuple[bytes, str], ...] = ( + (b"\x89PNG\r\n\x1a\n", ".png"), + (b"\xff\xd8\xff", ".jpg"), + (b"GIF87a", ".gif"), + (b"GIF89a", ".gif"), + (b"BM", ".bmp"), +) +_ATTACHMENT_CACHE_MAX_AGE_SECONDS = 7 * 24 * 60 * 60 +_ATTACHMENT_REGISTRY_MAX_RECORDS = 2000 +_ATTACHMENT_CACHE_MAX_BYTES = 0 +_ATTACHMENT_URL_REFERENCE_MAX_RECORDS = 2000 +_ATTACHMENT_URL_MAX_LENGTH = 8192 +_DEFAULT_REMOTE_DOWNLOAD_MAX_BYTES = 25 * 1024 * 1024 + + +def _now_iso() -> str: + return datetime.now().isoformat(timespec="seconds") + + +def _decode_data_url(data_url: str) -> tuple[bytes, str]: + header, _, payload = data_url.partition(",") + if ";base64" not in header.lower(): + raise ValueError("unsupported data URL encoding") + mime_type = ( + header.split(":", 1)[1].split(";", 1)[0].strip() or "application/octet-stream" + ) + return base64.b64decode(payload), mime_type + + +def _guess_suffix_from_bytes(content: bytes) -> str: + for magic, suffix in _MAGIC_IMAGE_SUFFIXES: + if content.startswith(magic): + return suffix + if content.startswith(b"RIFF") and content[8:12] == b"WEBP": + return ".webp" + return ".bin" + + +def _guess_suffix(name: str, content: bytes, mime_type: str) -> str: + suffix = Path(name).suffix.lower() + if suffix: + return suffix + guessed_ext = mimetypes.guess_extension(mime_type or "") + if guessed_ext: + return guessed_ext.lower() + return _guess_suffix_from_bytes(content) + + +def _guess_mime_type(name: str, content: bytes) -> str: + guessed, _ = mimetypes.guess_type(name) + if guessed: + return guessed + suffix = _guess_suffix_from_bytes(content) + return _IMAGE_SUFFIX_TO_MIME.get(suffix, "application/octet-stream") + + +def _remote_reference_source_kind(source_kind: str) -> str: + cleaned = str(source_kind or "").strip() + if not cleaned: + return "remote_url_reference" + if cleaned.endswith("_reference"): + return cleaned + return f"{cleaned}_reference" + + +class AttachmentRegistry: + """按会话作用域持久化的附件注册表。 + + 写入 JSON 注册表与本地缓存目录,支持远程 URL 引用与按需回源下载。 + """ + + def __init__( + self, + *, + registry_path: Path = ATTACHMENT_REGISTRY_FILE, + cache_dir: Path = ATTACHMENT_CACHE_DIR, + http_client: httpx.AsyncClient | None = None, + max_records: int = _ATTACHMENT_REGISTRY_MAX_RECORDS, + max_age_seconds: int = _ATTACHMENT_CACHE_MAX_AGE_SECONDS, + max_cache_bytes: int = _ATTACHMENT_CACHE_MAX_BYTES, + url_reference_max_records: int = _ATTACHMENT_URL_REFERENCE_MAX_RECORDS, + url_max_length: int = _ATTACHMENT_URL_MAX_LENGTH, + remote_download_max_bytes: int = _DEFAULT_REMOTE_DOWNLOAD_MAX_BYTES, + ) -> None: + self._registry_path = registry_path + self._cache_dir = cache_dir + self._http_client = http_client + self._max_records = max(0, int(max_records)) + self._max_age_seconds = max(0, int(max_age_seconds)) + self._max_cache_bytes = max(0, int(max_cache_bytes)) + self._url_reference_max_records = max(0, int(url_reference_max_records)) + self._url_max_length = max(0, int(url_max_length)) + self._remote_download_max_bytes = max(0, int(remote_download_max_bytes)) + self._lock = asyncio.Lock() + self._records: dict[str, AttachmentRecord] = {} + self._loaded = False + self._load_task: asyncio.Task[None] | None = None + self._global_image_resolver: Callable[[str], AttachmentRecord | None] | None = ( + None + ) + self._global_image_resolver_async: ( + Callable[[str], Awaitable[AttachmentRecord | None]] | None + ) = None + + def set_remote_download_max_bytes(self, value: int) -> None: + """设置单次远程下载字节上限。""" + self._remote_download_max_bytes = max(0, int(value)) + + def set_limits( + self, + *, + remote_download_max_bytes: int | None = None, + max_cache_bytes: int | None = None, + max_records: int | None = None, + max_age_seconds: int | None = None, + url_reference_max_records: int | None = None, + url_max_length: int | None = None, + ) -> None: + """批量更新注册表容量与 TTL 限制。""" + if remote_download_max_bytes is not None: + self._remote_download_max_bytes = max(0, int(remote_download_max_bytes)) + if max_cache_bytes is not None: + self._max_cache_bytes = max(0, int(max_cache_bytes)) + if max_records is not None: + self._max_records = max(0, int(max_records)) + if max_age_seconds is not None: + self._max_age_seconds = max(0, int(max_age_seconds)) + if url_reference_max_records is not None: + self._url_reference_max_records = max(0, int(url_reference_max_records)) + if url_max_length is not None: + self._url_max_length = max(0, int(url_max_length)) + + def set_global_image_resolver( + self, + resolver: Callable[[str], AttachmentRecord | None] | None, + ) -> None: + """注册同步全局图片 UID 回退解析器。""" + self._global_image_resolver = resolver + + def set_global_image_resolver_async( + self, + resolver: Callable[[str], Awaitable[AttachmentRecord | None]] | None, + ) -> None: + """注册异步全局图片 UID 回退解析器。""" + self._global_image_resolver_async = resolver + + def _resolve_managed_cache_path(self, raw_path: str | None) -> Path | None: + text = str(raw_path or "").strip() + if not text: + return None + try: + path = Path(text).expanduser().resolve() + cache_root = self._cache_dir.resolve() + except Exception: + return None + if path == cache_root or cache_root not in path.parents: + return None + return path + + def _normalized_url_ref(self, value: str) -> str: + text = str(value or "").strip() + if not is_http_url(text): + return "" + if self._url_max_length > 0 and len(text) > self._url_max_length: + return "" + return text + + def _record_with_local_path( + self, record: AttachmentRecord, local_path: str | None + ) -> AttachmentRecord: + return replace( + record, + local_path=local_path, + source_kind=_remote_reference_source_kind(record.source_kind) + if local_path is None and is_http_url(record.source_ref) + else record.source_kind, + ) + + def _remove_cached_content( + self, + record: AttachmentRecord, + cache_path: Path | None, + removable_paths: set[Path], + ) -> AttachmentRecord | None: + source_ref = self._normalized_url_ref(record.source_ref) + if source_ref: + if cache_path is not None: + removable_paths.add(cache_path) + return self._record_with_local_path(record, None) + if cache_path is not None: + removable_paths.add(cache_path) + return None + + def _prune_records(self) -> bool: + dirty = False + now = time.time() + retained: list[tuple[str, AttachmentRecord, Path | None, float, int]] = [] + removable_paths: set[Path] = set() + + for uid, record in self._records.items(): + cache_path = self._resolve_managed_cache_path(record.local_path) + if record.local_path is None: + has_url_ref = bool(self._normalized_url_ref(record.source_ref)) + if is_http_url(record.source_ref) and not has_url_ref: + dirty = True + continue + try: + mtime = datetime.fromisoformat(record.created_at).timestamp() + except ValueError: + mtime = now + if ( + not has_url_ref + and self._max_age_seconds > 0 + and now - mtime > self._max_age_seconds + ): + dirty = True + continue + retained.append((uid, record, None, mtime, 0)) + continue + if cache_path is None: + replacement = self._remove_cached_content(record, None, removable_paths) + if replacement is not None: + retained.append((uid, replacement, None, now, 0)) + dirty = True + continue + try: + stat_result = cache_path.stat() + mtime = float(stat_result.st_mtime) + size = int(stat_result.st_size) + except OSError: + replacement = self._remove_cached_content( + record, cache_path, removable_paths + ) + if replacement is not None: + retained.append((uid, replacement, None, now, 0)) + dirty = True + continue + if not cache_path.is_file(): + replacement = self._remove_cached_content( + record, cache_path, removable_paths + ) + if replacement is not None: + retained.append((uid, replacement, None, mtime, 0)) + dirty = True + continue + if self._max_age_seconds > 0 and now - mtime > self._max_age_seconds: + replacement = self._remove_cached_content( + record, cache_path, removable_paths + ) + if replacement is not None: + retained.append((uid, replacement, None, mtime, 0)) + dirty = True + continue + retained.append((uid, record, cache_path, mtime, size)) + + if self._max_records > 0 and len(retained) > self._max_records: + # 超出记录上限时按 mtime 淘汰最旧条目 + retained.sort(key=lambda item: item[3]) + overflow = len(retained) - self._max_records + for _uid, _record, cache_path, _mtime, _size in retained[:overflow]: + if cache_path is not None: + removable_paths.add(cache_path) + retained = retained[overflow:] + dirty = True + + if self._max_cache_bytes > 0: + cache_total = sum( + size + for _uid, _record, path, _mtime, size in retained + if path is not None + ) + if cache_total > self._max_cache_bytes: + reduced: list[ + tuple[str, AttachmentRecord, Path | None, float, int] + ] = [] + for uid, record, cache_path, mtime, size in sorted( + retained, key=lambda item: item[3] + ): + if cache_path is not None and cache_total > self._max_cache_bytes: + replacement = self._remove_cached_content( + record, cache_path, removable_paths + ) + if replacement is not None: + reduced.append((uid, replacement, None, mtime, 0)) + cache_total -= size + dirty = True + else: + reduced.append((uid, record, cache_path, mtime, size)) + retained = reduced + + if self._url_reference_max_records > 0: + url_refs = [ + item + for item in retained + if item[2] is None and is_http_url(item[1].source_ref) + ] + if len(url_refs) > self._url_reference_max_records: + # 仅 URL 引用(未下载)单独计数上限 + url_ref_ids = { + uid + for uid, _record, _path, _mtime, _size in sorted( + url_refs, key=lambda item: item[3] + )[: len(url_refs) - self._url_reference_max_records] + } + retained = [item for item in retained if item[0] not in url_ref_ids] + dirty = True + + retained_records = { + uid: record for uid, record, _path, _mtime, _size in retained + } + retained_paths = { + path.resolve() + for _uid, _record, path, _mtime, _size in retained + if path is not None and path.exists() + } + + for path in removable_paths: + try: + resolved = path.resolve() + except Exception: + resolved = path + if resolved in retained_paths: + continue + try: + path.unlink(missing_ok=True) + dirty = True + except OSError: + continue + + if self._cache_dir.exists(): + for item in self._cache_dir.iterdir(): + if not item.is_file(): + continue + try: + resolved = item.resolve() + except Exception: + resolved = item + if resolved in retained_paths: + continue + try: + item.unlink() + dirty = True + except OSError: + continue + + if dirty: + self._records = retained_records + return dirty + + def _load_records_from_payload(self, raw: Any) -> dict[str, AttachmentRecord]: + if not isinstance(raw, dict): + return {} + loaded: dict[str, AttachmentRecord] = {} + for uid, item in raw.items(): + if not isinstance(item, dict): + continue + try: + loaded[str(uid)] = AttachmentRecord( + uid=str(item.get("uid") or uid), + scope_key=str(item.get("scope_key", "") or ""), + kind=media_kind_from_value(item.get("kind", "file")), + media_type=media_kind_from_value( + item.get("media_type") or item.get("kind") or "file" + ), + display_name=str(item.get("display_name", "") or ""), + source_kind=str(item.get("source_kind", "") or ""), + source_ref=str(item.get("source_ref", "") or ""), + local_path=str(item.get("local_path", "") or "") or None, + mime_type=str( + item.get("mime_type", "") or "application/octet-stream" + ), + sha256=str(item.get("sha256", "") or ""), + created_at=str(item.get("created_at", "") or ""), + segment_data={ + str(k): str(v) + for k, v in dict(item.get("segment_data") or {}).items() + if str(k).strip() and str(v).strip() + }, + semantic_kind=str(item.get("semantic_kind", "") or ""), + description=str(item.get("description", "") or ""), + ) + except Exception: + continue + return loaded + + async def _load_from_disk_async(self) -> None: + try: + raw = await io.read_json(self._registry_path, use_lock=False) + except Exception as exc: + logger.warning("[AttachmentRegistry] 读取失败: %s", exc) + self._loaded = True + return + self._records = self._load_records_from_payload(raw) + dirty = self._prune_records() + if dirty: + await self._persist() + self._loaded = True + + async def load(self) -> None: + """等待注册表完成初始加载。""" + if self._loaded: + return + if self._load_task is None: + self._load_task = asyncio.create_task(self._load_from_disk_async()) + await self._load_task + + async def _persist(self) -> None: + payload = {uid: asdict(record) for uid, record in self._records.items()} + await io.write_json(self._registry_path, payload, use_lock=True) + + async def flush(self) -> None: + """将当前注册表状态强制落盘。""" + await self.load() + async with self._lock: + await self._persist() + + def get(self, uid: str) -> AttachmentRecord | None: + """按 UID 读取内存中的附件记录(不触发磁盘加载)。""" + return self._records.get(str(uid).strip()) + + def resolve(self, uid: str, scope_key: str | None) -> AttachmentRecord | None: + """同步解析 UID,含 scope 校验与全局图片回退。""" + record = self.get(uid) + if record is not None: + # scope 不匹配时拒绝跨会话引用 + if record.scope_key and scope_key and record.scope_key != scope_key: + return None + return record + if self._global_image_resolver is not None: + try: + record = self._global_image_resolver(uid) + except Exception: + logger.exception( + "[AttachmentRegistry] global image resolver failed: uid=%s", uid + ) + record = None + if record is None: + return None + if record.scope_key and scope_key and record.scope_key != scope_key: + return None + return record + + async def resolve_async( + self, uid: str, scope_key: str | None + ) -> AttachmentRecord | None: + """异步解析 UID,优先异步全局回退解析器。""" + record = self.get(uid) + if record is not None: + if record.scope_key and scope_key and record.scope_key != scope_key: + return None + return record + if self._global_image_resolver_async is not None: + try: + record = await self._global_image_resolver_async(uid) + except Exception: + logger.exception( + "[AttachmentRegistry] async global image resolver failed: uid=%s", + uid, + ) + record = None + elif self._global_image_resolver is not None: + try: + record = self._global_image_resolver(uid) + except Exception: + logger.exception( + "[AttachmentRegistry] global image resolver failed: uid=%s", uid + ) + record = None + else: + record = None + if record is None: + return None + if record.scope_key and scope_key and record.scope_key != scope_key: + return None + return record + + def resolve_for_context( + self, + uid: str, + context: Mapping[str, Any] | None, + ) -> AttachmentRecord | None: + """从请求上下文推断 scope 后解析 UID。""" + return self.resolve(uid, scope_from_context(context)) + + async def get_url_by_uid(self, uid: str) -> str | None: + """通过附件 UID 获取 source_ref(URL)。""" + await self.load() + record = self.get(uid) + if record is None or not record.source_ref.strip(): + return None + return record.source_ref.strip() + + async def get_uid_by_url(self, url: str) -> str | None: + """通过 URL 查找对应的附件 UID。""" + await self.load() + url = url.strip() + if not url: + return None + for record in self._records.values(): + if record.source_ref.strip() == url: + return record.uid + return None + + def _build_uid(self, prefix: str) -> str: + while True: + uid = f"{prefix}_{uuid4().hex[:8]}" + if uid not in self._records: + return uid + + def _find_by_sha256( + self, scope_key: str, sha256: str, kind: str + ) -> AttachmentRecord | None: + """Find an existing record with matching scope, kind, and SHA-256. + + Only returns a record whose *local_path* still exists on disk. + Must be called while ``self._lock`` is held. + """ + for record in self._records.values(): + if ( + record.scope_key == scope_key + and record.sha256 == sha256 + and record.kind == kind + and record.local_path + and Path(record.local_path).is_file() + ): + return record + return None + + async def register_bytes( + self, + scope_key: str, + content: bytes, + *, + kind: str, + display_name: str, + source_kind: str, + source_ref: str = "", + mime_type: str | None = None, + segment_data: Mapping[str, str] | None = None, + ) -> AttachmentRecord: + """将字节内容写入缓存并注册新附件(含 SHA-256 去重)。""" + await self.load() + normalized_kind = media_kind_from_value(kind) + normalized_media_type = ( + "image" if normalized_kind == "image" else normalized_kind + ) + normalized_mime = mime_type or _guess_mime_type(display_name, content) + suffix = _guess_suffix(display_name, content, normalized_mime) + prefix = "pic" if normalized_media_type == "image" else "file" + + async with self._lock: + digest = await asyncio.to_thread(hashlib.sha256, content) + digest_hex = digest.hexdigest() + + existing = self._find_by_sha256(scope_key, digest_hex, normalized_kind) + if existing is not None: + # 同 scope+SHA256 去重,复用已有 UID + return existing + + uid = self._build_uid(prefix) + file_name = f"{uid}{suffix}" + cache_path = ensure_dir(self._cache_dir) / file_name + await asyncio.to_thread(cache_path.write_bytes, content) + + record = AttachmentRecord( + uid=uid, + scope_key=scope_key, + kind=normalized_kind, + media_type=normalized_media_type, + display_name=display_name or file_name, + source_kind=source_kind, + source_ref=source_ref, + local_path=str(cache_path), + mime_type=normalized_mime, + sha256=digest_hex, + created_at=_now_iso(), + segment_data={ + str(k): str(v) + for k, v in dict(segment_data or {}).items() + if str(k).strip() and str(v).strip() + }, + ) + self._records[uid] = record + self._prune_records() + await self._persist() + return self._records.get(uid, record) + + async def register_local_file( + self, + scope_key: str, + local_path: str | Path, + *, + kind: str, + display_name: str | None = None, + source_kind: str = "local_file", + source_ref: str = "", + segment_data: Mapping[str, str] | None = None, + ) -> AttachmentRecord: + """读取本地文件并注册为附件。""" + path = Path(str(local_path)).expanduser() + if not path.is_absolute(): + path = (Path.cwd() / path).resolve() + else: + path = path.resolve() + if not path.is_file(): + raise FileNotFoundError(path) + + def _read() -> bytes: + return path.read_bytes() + + content = await asyncio.to_thread(_read) + return await self.register_bytes( + scope_key, + content, + kind=kind, + display_name=display_name or path.name, + source_kind=source_kind, + source_ref=source_ref or str(path), + mime_type=mimetypes.guess_type(path.name)[0] or None, + segment_data=segment_data, + ) + + async def register_data_url( + self, + scope_key: str, + data_url: str, + *, + kind: str, + display_name: str, + source_kind: str, + source_ref: str = "", + segment_data: Mapping[str, str] | None = None, + ) -> AttachmentRecord: + """解码 ``data:`` URL 并注册附件。""" + content, mime_type = _decode_data_url(data_url) + return await self.register_bytes( + scope_key, + content, + kind=kind, + display_name=display_name, + source_kind=source_kind, + source_ref=source_ref, + mime_type=mime_type, + segment_data=segment_data, + ) + + async def register_remote_url( + self, + scope_key: str, + url: str, + *, + kind: str, + display_name: str | None = None, + source_kind: str = "remote_url", + source_ref: str = "", + segment_data: Mapping[str, str] | None = None, + ) -> AttachmentRecord: + """下载远程 URL 或在上限时降级为 URL 引用。""" + name = display_name or display_name_from_source(url, "attachment.bin") + return await self._register_remote_url_or_reference( + scope_key, + url, + kind=kind, + display_name=name, + source_kind=source_kind, + source_ref=source_ref or url, + segment_data=segment_data, + ) + + async def register_remote_reference( + self, + scope_key: str, + url: str, + *, + kind: str, + display_name: str | None = None, + source_kind: str = "remote_url_reference", + source_ref: str = "", + mime_type: str | None = None, + segment_data: Mapping[str, str] | None = None, + description: str = "", + ) -> AttachmentRecord: + """仅登记远程 URL 引用,不下载内容。""" + await self.load() + if not self._normalized_url_ref(url): + raise ValueError("远程附件 URL 为空或超过长度上限") + normalized_kind = media_kind_from_value(kind) + normalized_media_type = ( + "image" if normalized_kind == "image" else normalized_kind + ) + prefix = "pic" if normalized_media_type == "image" else "file" + ref = url + normalized_segment_data = dict(segment_data or {}) + if source_ref and source_ref != url: + normalized_segment_data.setdefault("original_source_ref", source_ref) + name = display_name or display_name_from_source(url, "attachment.bin") + digest_hex = hashlib.sha256(ref.encode("utf-8")).hexdigest() + + async with self._lock: + for existing in self._records.values(): + if ( + existing.scope_key == scope_key + and existing.kind == normalized_kind + and existing.local_path is None + and existing.source_ref == ref + ): + return existing + + uid = self._build_uid(prefix) + record = AttachmentRecord( + uid=uid, + scope_key=scope_key, + kind=normalized_kind, + media_type=normalized_media_type, + display_name=name, + source_kind=source_kind, + source_ref=ref, + local_path=None, + mime_type=mime_type or mimetypes.guess_type(name)[0] or "", + sha256=digest_hex, + created_at=_now_iso(), + segment_data={ + str(k): str(v) + for k, v in normalized_segment_data.items() + if str(k).strip() and str(v).strip() + }, + description=description, + ) + self._records[uid] = record + self._prune_records() + await self._persist() + return self._records.get(uid, record) + + async def _register_remote_url_or_reference( + self, + scope_key: str, + url: str, + *, + kind: str, + display_name: str, + source_kind: str, + source_ref: str, + segment_data: Mapping[str, str] | None, + ) -> AttachmentRecord: + if not self._normalized_url_ref(url): + raise ValueError("远程附件 URL 为空或超过长度上限") + timeout = httpx.Timeout(_DEFAULT_REMOTE_TIMEOUT_SECONDS) + max_bytes = self._remote_download_max_bytes + reference_segment_data = dict(segment_data or {}) + if source_ref and source_ref != url: + reference_segment_data.setdefault("original_source_ref", source_ref) + if max_bytes <= 0: + # 配置为 0 时一律只登记 URL 引用,不下载 + return await self.register_remote_reference( + scope_key, + url, + kind=kind, + display_name=display_name, + source_kind=_remote_reference_source_kind(source_kind), + source_ref=url, + segment_data=reference_segment_data, + description="远程附件未下载:remote_download_max_size_mb=0", + ) + + async def _stream(client: httpx.AsyncClient) -> tuple[bytes, str]: + async with client.stream( + "GET", url, timeout=timeout, follow_redirects=True + ) as response: + response.raise_for_status() + mime_type = ( + response.headers.get("content-type", "").split(";", 1)[0].strip() + ) + raw_length = response.headers.get("content-length", "").strip() + if raw_length.isdigit() and int(raw_length) > max_bytes: + raise _RemoteAttachmentTooLarge(mime_type) + + chunks: list[bytes] = [] + total = 0 + async for chunk in response.aiter_bytes(): + total += len(chunk) + # 流式累计超限则降级为 URL 引用 + if total > max_bytes: + raise _RemoteAttachmentTooLarge(mime_type) + chunks.append(chunk) + return b"".join(chunks), mime_type + + try: + if self._http_client is not None: + content, mime_type = await _stream(self._http_client) + else: + async with httpx.AsyncClient( + timeout=timeout, follow_redirects=True + ) as client: + content, mime_type = await _stream(client) + except _RemoteAttachmentTooLarge as exc: + return await self.register_remote_reference( + scope_key, + url, + kind=kind, + display_name=display_name, + source_kind=_remote_reference_source_kind(source_kind), + source_ref=url, + mime_type=exc.mime_type, + segment_data=reference_segment_data, + description=f"远程附件超过下载上限 {max_bytes} bytes,保留 URL 引用。", + ) + + return await self.register_bytes( + scope_key, + content, + kind=kind, + display_name=display_name, + source_kind=source_kind, + source_ref=url, + mime_type=mime_type or None, + segment_data=reference_segment_data, + ) + + async def ensure_local_file(self, record: AttachmentRecord) -> AttachmentRecord: + """若记录仅有 URL 引用则尝试回源下载到本地缓存。""" + await self.load() + if record.local_path and Path(record.local_path).is_file(): + return record + source_ref = self._normalized_url_ref(record.source_ref) + if not source_ref: + return record + existing_uids = set(self._records) + refreshed = await self._register_remote_url_or_reference( + record.scope_key, + source_ref, + kind=record.kind, + display_name=record.display_name, + source_kind=record.source_kind, + source_ref=source_ref, + segment_data=record.segment_data, + ) + if refreshed.local_path is None: + return refreshed + async with self._lock: + current = self._records.get(record.uid) + if current is None: + return refreshed + updated = replace( + current, + local_path=refreshed.local_path, + mime_type=refreshed.mime_type, + sha256=refreshed.sha256, + source_kind=refreshed.source_kind, + segment_data=refreshed.segment_data, + ) + self._records[record.uid] = updated + if refreshed.uid != record.uid and refreshed.uid not in existing_uids: + self._records.pop(refreshed.uid, None) + self._prune_records() + await self._persist() + return self._records.get(record.uid, updated) diff --git a/src/Undefined/attachments/render.py b/src/Undefined/attachments/render.py new file mode 100644 index 00000000..a6a3be44 --- /dev/null +++ b/src/Undefined/attachments/render.py @@ -0,0 +1,280 @@ +"""富媒体标签渲染与待发送文件派发。 + +将 ```` / ```` 占位符转为 CQ 段或历史可读文本; +不修改注册表结构。 +""" + +from __future__ import annotations + +import logging +import re +from pathlib import Path +from typing import TYPE_CHECKING, Any + +from Undefined.attachments.models import ( + AttachmentRecord, + AttachmentRenderError, + RenderedRichMessage, +) +from Undefined.attachments.segments import _MEDIA_LABELS, is_http_url + +if TYPE_CHECKING: + from Undefined.attachments.registry import AttachmentRegistry + +logger = logging.getLogger(__name__) + +_PIC_TAG_PATTERN = re.compile( + r"[\"'])(?P[^\"']+)(?P=quote)\s*/?>", + re.IGNORECASE, +) +_ATTACHMENT_TAG_PATTERN = re.compile( + r"[\"'])(?P[^\"']+)(?P=quote)\s*/?>", + re.IGNORECASE, +) +_UNIFIED_TAG_PATTERN = re.compile( + r"<(?Ppic|attachment)\s+uid=(?P[\"'])(?P[^\"']+)(?P=quote)\s*/?>", + re.IGNORECASE, +) + + +def _escape_cq_component(value: str) -> str: + return ( + value.replace("&", "&") + .replace("[", "[") + .replace("]", "]") + .replace(",", ",") + ) + + +async def render_message_with_attachments( + message: str, + *, + registry: AttachmentRegistry | None, + scope_key: str | None, + strict: bool, +) -> RenderedRichMessage: + """Render ```` and ```` tags into delivery/history text. + + * ```` — backward-compatible, image-only. + * ```` — unified tag for any media type. + Images (``pic_*``) are inlined as CQ images; files (``file_*``) + are collected into *pending_file_sends* for later dispatch. + + Args: + message: 含占位标签的原始消息文本。 + registry: 附件注册表。 + scope_key: 当前会话作用域键。 + strict: 为 True 时 UID 不可用或类型不匹配则抛出 ``AttachmentRenderError``。 + + Returns: + 投递文本、历史文本、附件引用及待发送文件列表。 + + Raises: + AttachmentRenderError: ``strict=True`` 且标签无法解析时。 + """ + has_tags = message and ( + " tag: strictly image-only + if tag_name == "pic" and record.media_type != "image": + replacement = f"[图片 uid={uid} 类型错误]" + if strict: + raise AttachmentRenderError(f"UID 不是图片,不能用于 :{uid}") + delivery_parts.append(replacement) + history_parts.append(replacement) + continue + + # 仅允许图片; 按 media_type 分流 + if record.media_type == "image": + ok = _render_image_tag(record, uid, strict, delivery_parts, history_parts) + else: + ok = _render_file_tag( + record, + uid, + strict, + delivery_parts, + history_parts, + pending_files, + ) + + if ok: + attachments.append(record.prompt_ref()) + + delivery_parts.append(message[last_index:]) + history_parts.append(message[last_index:]) + return RenderedRichMessage( + delivery_text="".join(delivery_parts), + history_text="".join(history_parts), + attachments=attachments, + pending_file_sends=tuple(pending_files), + ) + + +def _render_image_tag( + record: AttachmentRecord, + uid: str, + strict: bool, + delivery_parts: list[str], + history_parts: list[str], +) -> bool: + """Render an image attachment as an inline CQ:image. Returns True on success.""" + image_source = record.source_ref + if record.local_path: + image_source = Path(record.local_path).resolve().as_uri() + elif not image_source: + replacement = f"[图片 uid={uid} 缺少文件]" + if strict: + raise AttachmentRenderError(f"图片 UID 缺少可发送的文件:{uid}") + delivery_parts.append(replacement) + history_parts.append(replacement) + return False + + cq_args = [f"file={image_source}"] + for key, value in dict(getattr(record, "segment_data", {}) or {}).items(): + cleaned_key = str(key or "").strip() + cleaned_value = str(value or "").strip() + if ( + not cleaned_key + or not cleaned_value + or cleaned_key in {"file", "original_source_ref"} + ): + continue + cq_args.append( + f"{_escape_cq_component(cleaned_key)}={_escape_cq_component(cleaned_value)}" + ) + delivery_parts.append(f"[CQ:image,{','.join(cq_args)}]") + if record.display_name: + history_parts.append(f"[图片 uid={uid} name={record.display_name}]") + else: + history_parts.append(f"[图片 uid={uid}]") + return True + + +def _render_file_tag( + record: AttachmentRecord, + uid: str, + strict: bool, + delivery_parts: list[str], + history_parts: list[str], + pending_files: list[AttachmentRecord], +) -> bool: + """Render a non-image attachment as a pending file send. Returns True on success.""" + if not record.local_path or not Path(record.local_path).is_file(): + if is_http_url(record.source_ref): + # 仅有远程 URL 时先入 pending,发送前尝试回源下载 + name_part = f" name={record.display_name}" if record.display_name else "" + history_parts.append(f"[文件 uid={uid}{name_part}]") + pending_files.append(record) + return True + replacement = f"[文件 uid={uid} 缺少本地文件]" + if strict: + raise AttachmentRenderError(f"文件 UID 缺少本地文件,无法发送:{uid}") + delivery_parts.append(replacement) + history_parts.append(replacement) + return False + + # 文件不在 CQ 文本中内联,单独走 send_group/private_file + # Keep a readable placeholder in history + name_part = f" name={record.display_name}" if record.display_name else "" + history_parts.append(f"[文件 uid={uid}{name_part}]") + pending_files.append(record) + return True + + +render_message_with_pic_placeholders = render_message_with_attachments + + +async def dispatch_pending_file_sends( + rendered: RenderedRichMessage, + *, + sender: Any, + target_type: str, + target_id: int, + registry: AttachmentRegistry | None = None, +) -> None: + """Send pending file attachments collected by *render_message_with_attachments*. + + This is best-effort: each file send failure is logged but does not interrupt + the remaining sends or the caller. + + Args: + rendered: ``render_message_with_attachments`` 的返回值。 + sender: 实现群/私聊文件发送的 OneBot 客户端。 + target_type: ``group`` 或 ``private``。 + target_id: 目标群号或 QQ 号。 + registry: 可选,用于发送前回源下载仅有 URL 的附件。 + """ + if not rendered.pending_file_sends or sender is None: + return + for record in rendered.pending_file_sends: + send_record = record + if ( + not send_record.local_path or not Path(send_record.local_path).is_file() + ) and registry is not None: + try: + send_record = await registry.ensure_local_file(send_record) + except Exception: + logger.warning( + "[文件发送] 回源下载失败 uid=%s source=%s", + send_record.uid, + send_record.source_ref, + exc_info=True, + ) + if not send_record.local_path or not Path(send_record.local_path).is_file(): + logger.warning( + "[文件发送] 跳过:本地文件缺失 uid=%s path=%s", + send_record.uid, + send_record.local_path, + ) + continue + try: + if target_type == "group": + await sender.send_group_file( + target_id, + send_record.local_path, + name=send_record.display_name or None, + ) + else: + await sender.send_private_file( + target_id, + send_record.local_path, + name=send_record.display_name or None, + ) + except Exception: + logger.warning( + "[文件发送] 发送失败(最佳努力) uid=%s target=%s:%s", + send_record.uid, + target_type, + target_id, + exc_info=True, + ) diff --git a/src/Undefined/attachments/segments.py b/src/Undefined/attachments/segments.py new file mode 100644 index 00000000..d78b458f --- /dev/null +++ b/src/Undefined/attachments/segments.py @@ -0,0 +1,566 @@ +"""OneBot 消息段解析与会话作用域辅助。 + +负责 scope 键构建、附件引用文本/XML 序列化,以及从消息段批量注册附件; +不处理磁盘持久化或 CQ 标签渲染。 +""" + +from __future__ import annotations + +import base64 +import binascii +import logging +import re +from pathlib import Path +from typing import TYPE_CHECKING, Any, Awaitable, Callable, Mapping, Sequence +from urllib.parse import unquote, urlsplit + +import httpx + +from Undefined.attachments.models import RegisteredMessageAttachments +from Undefined.utils.paths import WEBUI_FILE_CACHE_DIR +from Undefined.utils.xml import escape_xml_attr + +if TYPE_CHECKING: + from Undefined.attachments.registry import AttachmentRegistry + +logger = logging.getLogger(__name__) + +_MEDIA_LABELS = { + "image": "图片", + "file": "文件", + "audio": "音频", + "video": "视频", + "record": "语音", + "pic": "图片", +} +_WINDOWS_ABS_PATH_RE = re.compile(r"^[A-Za-z]:[\\/]") +_FORWARD_ATTACHMENT_MAX_DEPTH = 3 + + +def _coerce_positive_int(value: Any) -> int | None: + if isinstance(value, bool): + return None + if isinstance(value, int): + return value if value > 0 else None + if isinstance(value, str): + text = value.strip() + if not text: + return None + try: + parsed = int(text) + except ValueError: + return None + return parsed if parsed > 0 else None + return None + + +def build_attachment_scope( + *, + group_id: Any = None, + user_id: Any = None, + request_type: str | None = None, + webui_session: bool = False, +) -> str | None: + """构建附件可见性作用域键。 + + Args: + group_id: 群号;优先于私聊作用域。 + user_id: 用户 QQ 号。 + request_type: 请求类型(``private`` 等)。 + webui_session: WebUI 会话时使用固定 ``webui`` 作用域。 + + Returns: + 形如 ``group:123`` / ``private:456`` / ``webui`` 的键,无法推断时返回 ``None``。 + """ + if webui_session: + return "webui" + + group = _coerce_positive_int(group_id) + if group is not None: + # 群聊作用域优先于私聊,避免同会话附件串读 + return f"group:{group}" + + user = _coerce_positive_int(user_id) + request_type_text = str(request_type or "").strip().lower() + if request_type_text == "private" and user is not None: + return f"private:{user}" + if user is not None: + return f"private:{user}" + return None + + +def scope_from_context(context: Mapping[str, Any] | None) -> str | None: + """从请求上下文字典提取附件作用域键。""" + if not context: + return None + return build_attachment_scope( + group_id=context.get("group_id"), + user_id=context.get("user_id"), + request_type=str(context.get("request_type", "") or ""), + webui_session=bool(context.get("webui_session", False)), + ) + + +def attachment_refs_to_text(attachments: Sequence[Mapping[str, str]]) -> str: + """将附件引用列表转为可读占位文本。""" + if not attachments: + return "" + parts: list[str] = [] + for item in attachments: + uid = str(item.get("uid", "") or "").strip() + if not uid: + continue + media_type = str(item.get("media_type") or item.get("kind") or "file").strip() + label = _MEDIA_LABELS.get(media_type, "附件") + name = str(item.get("display_name", "") or "").strip() + if name: + parts.append(f"[{label} uid={uid} name={name}]") + else: + parts.append(f"[{label} uid={uid}]") + return " ".join(parts) + + +def attachment_refs_to_xml( + attachments: Sequence[Mapping[str, str]], + *, + indent: str = " ", +) -> str: + """将附件引用列表序列化为 XML ```` 片段。""" + if not attachments: + return "" + lines = [f"{indent}"] + for item in attachments: + uid = str(item.get("uid", "") or "").strip() + if not uid: + continue + kind = str(item.get("kind", "") or item.get("media_type", "") or "file").strip() + media_type = str(item.get("media_type", "") or kind or "file").strip() + name = str(item.get("display_name", "") or "").strip() + attrs = [ + f'uid="{escape_xml_attr(uid)}"', + f'type="{escape_xml_attr(kind or media_type)}"', + f'media_type="{escape_xml_attr(media_type)}"', + ] + if name: + attrs.append(f'name="{escape_xml_attr(name)}"') + source_kind = str(item.get("source_kind", "") or "").strip() + if source_kind: + attrs.append(f'source_kind="{escape_xml_attr(source_kind)}"') + source_ref = str(item.get("source_ref", "") or "").strip() + if source_ref: + attrs.append(f'source_ref="{escape_xml_attr(source_ref)}"') + semantic_kind = str(item.get("semantic_kind", "") or "").strip() + if semantic_kind: + attrs.append(f'semantic_kind="{escape_xml_attr(semantic_kind)}"') + description = str(item.get("description", "") or "").strip() + if description: + attrs.append(f'description="{escape_xml_attr(description)}"') + lines.append(f"{indent} ") + lines.append(f"{indent}") + return "\n".join(lines) + + +def append_attachment_text( + base_text: str, attachments: Sequence[Mapping[str, str]] +) -> str: + """在基础文本后追加附件占位摘要行。""" + attachment_text = attachment_refs_to_text(attachments) + if not attachment_text: + return base_text + if not base_text.strip(): + return attachment_text + return f"{base_text}\n附件: {attachment_text}" + + +def is_http_url(value: str) -> bool: + """判断字符串是否为 HTTP(S) URL。""" + return value.startswith("http://") or value.startswith("https://") + + +def is_data_url(value: str) -> bool: + """判断字符串是否为 ``data:`` URL。""" + return value.startswith("data:") + + +def is_localish_path(value: str) -> bool: + """判断字符串是否像本地绝对路径或 ``file://`` URI。""" + return ( + value.startswith("/") + or value.startswith("file://") + or bool(_WINDOWS_ABS_PATH_RE.match(value)) + ) + + +def display_name_from_source(raw_source: str, fallback: str) -> str: + """从 URL 或路径推断展示文件名。""" + if not raw_source: + return fallback + if raw_source.startswith("file://"): + raw_source = raw_source[7:] + name = Path(unquote(urlsplit(raw_source).path)).name + return name or fallback + + +def media_kind_from_value(value: str) -> str: + """将任意媒体类型字符串规范为 registry 支持的 kind。""" + text = str(value or "").strip().lower() + if text in {"image", "file", "audio", "video", "record"}: + return text + return "file" + + +def segment_text( + type_: str, data: Mapping[str, Any], ref: Mapping[str, str] | None +) -> str: + """将单条 OneBot 消息段转为可读占位文本。""" + if type_ == "text": + return str(data.get("text", "") or "") + if type_ == "at": + qq = str(data.get("qq", "") or "").strip() + name = str(data.get("name") or data.get("nickname") or "").strip() + if qq and name: + return f"[@{qq}({name})]" + if qq: + return f"[@{qq}]" + return "[@]" + if type_ == "face": + return "[表情]" + if type_ == "reply": + reply_id = str(data.get("id") or data.get("message_id") or "").strip() + return f"[引用: {reply_id}]" if reply_id else "[引用]" + if type_ == "forward": + forward_id = str(data.get("id") or data.get("resid") or "").strip() + return f"[合并转发: {forward_id}]" if forward_id else "[合并转发]" + if ref is not None: + label = _MEDIA_LABELS.get( + str(ref.get("media_type") or ref.get("kind") or type_).strip(), "附件" + ) + uid = str(ref.get("uid", "") or "").strip() + name = str(ref.get("display_name", "") or "").strip() + if uid and name: + return f"[{label} uid={uid} name={name}]" + if uid: + return f"[{label} uid={uid}]" + label = _MEDIA_LABELS.get(type_, "附件") + raw = str(data.get("file") or data.get("url") or data.get("id") or "").strip() + return f"[{label}: {raw}]" if raw else f"[{label}]" + + +def _resolve_webui_file_id(file_id: str) -> Path | None: + if not file_id or not file_id.isalnum(): + return None + file_dir = (Path.cwd() / WEBUI_FILE_CACHE_DIR / file_id).resolve() + cache_root = (Path.cwd() / WEBUI_FILE_CACHE_DIR).resolve() + if cache_root not in file_dir.parents and file_dir != cache_root: + return None + if not file_dir.is_dir(): + return None + try: + files = list(file_dir.iterdir()) + except OSError: + return None + for candidate in files: + if candidate.is_file(): + return candidate + return None + + +def _extract_forward_id(data: Mapping[str, Any]) -> str: + forward_id = data.get("id") or data.get("resid") or data.get("message_id") + return str(forward_id).strip() if forward_id is not None else "" + + +def segment_data_from_onebot_data( + data: Mapping[str, Any], + *, + exclude_keys: set[str] | None = None, +) -> dict[str, str]: + """提取 OneBot 段 ``data`` 中需保留的字符串键值。""" + excluded = {key.strip().lower() for key in (exclude_keys or set()) if key.strip()} + normalized: dict[str, str] = {} + for raw_key, raw_value in data.items(): + key = str(raw_key or "").strip() + if not key: + continue + if key.lower() in excluded: + continue + text = str(raw_value or "").strip() + if not text: + continue + normalized[key] = text + return normalized + + +def normalize_message_segments(message: Any) -> list[Mapping[str, Any]]: + """将多种消息表示统一为 OneBot 段列表。""" + if isinstance(message, list): + normalized: list[Mapping[str, Any]] = [] + for item in message: + if isinstance(item, Mapping): + normalized.append(item) + elif isinstance(item, str): + normalized.append({"type": "text", "data": {"text": item}}) + return normalized + if isinstance(message, Mapping): + return [message] + if isinstance(message, str): + return [{"type": "text", "data": {"text": message}}] + return [] + + +def _normalize_forward_nodes(raw_nodes: Any) -> list[Mapping[str, Any]]: + if isinstance(raw_nodes, list): + return [node for node in raw_nodes if isinstance(node, Mapping)] + if isinstance(raw_nodes, Mapping): + messages = raw_nodes.get("messages") + if isinstance(messages, list): + return [node for node in messages if isinstance(node, Mapping)] + return [] + + +async def register_message_attachments( + *, + registry: AttachmentRegistry | None, + segments: Sequence[Mapping[str, Any]], + scope_key: str | None, + resolve_image_url: Callable[[str], Awaitable[str | None]] | None = None, + get_forward_messages: Callable[[str], Awaitable[list[dict[str, Any]]]] + | None = None, +) -> RegisteredMessageAttachments: + """扫描消息段并将图片/文件注册到 ``AttachmentRegistry``。 + + Args: + registry: 附件注册表;为 ``None`` 时仅归一化文本。 + segments: OneBot 消息段序列。 + scope_key: 会话作用域键。 + resolve_image_url: 可选,将 ``file`` 字段解析为可下载 URL。 + get_forward_messages: 可选,拉取合并转发子消息。 + + Returns: + 已注册附件引用与归一化纯文本。 + """ + attachments: list[dict[str, str]] = [] + normalized_parts: list[str] = [] + if registry is None or not scope_key: + for segment in segments: + type_ = str(segment.get("type", "") or "") + raw_data = segment.get("data", {}) + data = raw_data if isinstance(raw_data, Mapping) else {} + normalized_parts.append(segment_text(type_, data, None)) + return RegisteredMessageAttachments( + attachments=[], + normalized_text="".join(normalized_parts).strip(), + ) + + visited_forward_ids: set[str] = set() + + async def _collect_from_segments( + current_segments: Sequence[Mapping[str, Any]], + *, + depth: int, + prefix: str, + ) -> None: + for index, segment in enumerate(current_segments): + type_ = str(segment.get("type", "") or "").strip().lower() + raw_data = segment.get("data", {}) + data = raw_data if isinstance(raw_data, Mapping) else {} + ref: dict[str, str] | None = None + + try: + if type_ == "image": + raw_source = str(data.get("file") or data.get("url") or "").strip() + display_name = display_name_from_source( + raw_source, + f"image_{index + 1}.png", + ) + if raw_source.startswith("base64://"): + payload = raw_source[len("base64://") :].strip() + content = base64.b64decode(payload) + record = await registry.register_bytes( + scope_key, + content, + kind="image", + display_name=display_name, + source_kind="base64_image", + source_ref=f"{prefix}segment:{index}", + segment_data=segment_data_from_onebot_data( + data, + exclude_keys={"file", "url"}, + ), + ) + ref = record.prompt_ref() + elif is_data_url(raw_source): + record = await registry.register_data_url( + scope_key, + raw_source, + kind="image", + display_name=display_name, + source_kind="data_url_image", + source_ref=f"{prefix}segment:{index}", + segment_data=segment_data_from_onebot_data( + data, + exclude_keys={"file", "url"}, + ), + ) + ref = record.prompt_ref() + else: + resolved_source = raw_source + if raw_source and resolve_image_url is not None: + try: + # NapCat file id 需经 get_image 解析为可下载 URL + resolved = await resolve_image_url(raw_source) + except Exception as exc: + logger.debug( + "[AttachmentRegistry] image resolver failed: file=%s err=%s", + raw_source, + exc, + ) + resolved = None + if resolved: + resolved_source = str(resolved) + + if is_http_url(resolved_source): + record = await registry.register_remote_url( + scope_key, + resolved_source, + kind="image", + display_name=display_name, + source_kind="remote_image", + source_ref=raw_source or resolved_source, + segment_data=segment_data_from_onebot_data( + data, + exclude_keys={"file", "url"}, + ), + ) + ref = record.prompt_ref() + elif is_localish_path(resolved_source): + local_path = ( + resolved_source[7:] + if resolved_source.startswith("file://") + else resolved_source + ) + record = await registry.register_local_file( + scope_key, + local_path, + kind="image", + display_name=display_name, + source_kind="local_image", + source_ref=raw_source or resolved_source, + segment_data=segment_data_from_onebot_data( + data, + exclude_keys={"file", "url"}, + ), + ) + ref = record.prompt_ref() + + elif type_ == "file": + file_id = str(data.get("id", "") or "").strip() + raw_source = str(data.get("file") or data.get("url") or "").strip() + local_file_path: Path | None = None + if file_id: + local_file_path = _resolve_webui_file_id(file_id) + elif is_localish_path(raw_source): + local_file_path = Path( + raw_source[7:] + if raw_source.startswith("file://") + else raw_source + ) + display_name = ( + str(data.get("name", "") or "").strip() + or (local_file_path.name if local_file_path is not None else "") + or display_name_from_source(raw_source, f"file_{index + 1}.bin") + ) + if local_file_path is not None and local_file_path.is_file(): + record = await registry.register_local_file( + scope_key, + local_file_path, + kind="file", + display_name=display_name, + source_kind="webui_file" if file_id else "local_file", + source_ref=file_id or raw_source or str(local_file_path), + segment_data=segment_data_from_onebot_data( + data, + exclude_keys={"file", "url"}, + ), + ) + ref = record.prompt_ref() + elif is_http_url(raw_source): + record = await registry.register_remote_url( + scope_key, + raw_source, + kind="file", + display_name=display_name, + source_kind="remote_file", + source_ref=file_id or raw_source, + segment_data=segment_data_from_onebot_data( + data, + exclude_keys={"file", "url"}, + ), + ) + ref = record.prompt_ref() + + elif ( + type_ == "forward" + and get_forward_messages is not None + and depth < _FORWARD_ATTACHMENT_MAX_DEPTH + ): + # 合并转发递归展开,深度上限防止无限嵌套 + forward_id = _extract_forward_id(data) + if forward_id and forward_id not in visited_forward_ids: + visited_forward_ids.add(forward_id) + try: + nodes = _normalize_forward_nodes( + await get_forward_messages(forward_id) + ) + except Exception as exc: + logger.debug( + "[AttachmentRegistry] forward resolver failed: id=%s err=%s", + forward_id, + exc, + ) + nodes = [] + for node_index, node in enumerate(nodes): + raw_message = ( + node.get("content") + or node.get("message") + or node.get("raw_message") + ) + nested_segments = normalize_message_segments(raw_message) + if not nested_segments: + continue + await _collect_from_segments( + nested_segments, + depth=depth + 1, + prefix=f"{prefix}forward:{forward_id}:{node_index}:", + ) + except ( + binascii.Error, + ValueError, + FileNotFoundError, + httpx.HTTPError, + ) as exc: + logger.warning( + "[AttachmentRegistry] segment registration skipped: type=%s index=%s err=%s", + type_, + index, + exc, + ) + except Exception as exc: + logger.exception( + "[AttachmentRegistry] unexpected segment registration failure: type=%s index=%s err=%s", + type_, + index, + exc, + ) + + if ref is not None: + attachments.append(ref) + if depth == 0: + normalized_parts.append(segment_text(type_, data, ref)) + + await _collect_from_segments(segments, depth=0, prefix="") + + return RegisteredMessageAttachments( + attachments=attachments, + normalized_text="".join(normalized_parts).strip(), + ) diff --git a/src/Undefined/bilibili/wbi.py b/src/Undefined/bilibili/wbi.py index f898af05..2da7a9cc 100644 --- a/src/Undefined/bilibili/wbi.py +++ b/src/Undefined/bilibili/wbi.py @@ -192,6 +192,7 @@ async def get_mixin_key( force_refresh: bool = False, ) -> str: """获取可复用的 mixin_key。""" + # global global _cached_mixin_key_async, _cached_at_async now = time.time() @@ -295,6 +296,7 @@ def get_mixin_key_sync( force_refresh: bool = False, ) -> str: """同步获取可复用的 mixin_key。""" + # global global _cached_mixin_key_sync, _cached_at_sync now = time.time() diff --git a/src/Undefined/cognitive/historian.py b/src/Undefined/cognitive/historian.py index fda333c3..8692452d 100644 --- a/src/Undefined/cognitive/historian.py +++ b/src/Undefined/cognitive/historian.py @@ -307,7 +307,6 @@ async def _process_job(self, job_id: str, job: dict[str, Any]) -> None: len(job.get("profile_targets", []) or []), ) - # 兼容旧版:优先 observations,fallback new_info raw_observations = ( job.get("observations") if "observations" in job @@ -363,7 +362,6 @@ async def _process_job(self, job_id: str, job: dict[str, Any]) -> None: len(canonical), ) - # 侧写合并:传入所有 canonical 文本 has_obs = ( job.get("has_observations") if "has_observations" in job @@ -413,9 +411,7 @@ async def _rewrite( ) -> str: from Undefined.utils.resources import read_text_resource - # 向后兼容:优先 memo,fallback action_summary memo = str(job.get("memo") if "memo" in job else job.get("action_summary", "")) - # 向后兼容:优先 observations,fallback new_info observations = str( job.get("observations") if "observations" in job @@ -537,7 +533,6 @@ def _resolve_profile_targets(self, job: dict[str, Any]) -> list[dict[str, str]]: if targets: return targets - # 向后兼容旧任务:沿用单目标策略。 entity_type = "group" if str(job.get("group_id", "")).strip() else "user" entity_id = str( job.get("group_id") or job.get("user_id") or job.get("sender_id", "") @@ -768,7 +763,6 @@ async def _merge_profile_target( preferred_name = str(target.get("preferred_name", "")).strip() - # 检索该实体的历史事件作为 merge 参考 observations_raw = job.get("observations", job.get("new_info", [])) observations_text = ( "\n".join(observations_raw) @@ -884,7 +878,6 @@ async def _merge_profile_target( ) break - # 追加 assistant 轮次 assistant_msg: dict[str, Any] = { "role": "assistant", "tool_calls": tool_calls, diff --git a/src/Undefined/cognitive/historian/__init__.py b/src/Undefined/cognitive/historian/__init__.py new file mode 100644 index 00000000..0e1fce98 --- /dev/null +++ b/src/Undefined/cognitive/historian/__init__.py @@ -0,0 +1,5 @@ +"""认知史官 Worker 包。""" + +from Undefined.cognitive.historian.worker import HistorianWorker + +__all__ = ["HistorianWorker"] diff --git a/src/Undefined/cognitive/historian/helpers.py b/src/Undefined/cognitive/historian/helpers.py new file mode 100644 index 00000000..e1bda7da --- /dev/null +++ b/src/Undefined/cognitive/historian/helpers.py @@ -0,0 +1,85 @@ +"""Historian 辅助函数。""" + +from __future__ import annotations + +import logging +import re +from datetime import datetime, timezone +from typing import Any + +logger = logging.getLogger(__name__) + +_MAX_LOG_PREVIEW_LEN = 200 + + +def _preview_text(text: str, max_len: int = _MAX_LOG_PREVIEW_LEN) -> str: + compact = re.sub(r"\s+", " ", str(text or "")).strip() + if len(compact) <= max_len: + return compact + return f"{compact[:max_len]}..." + + +def _extract_frontmatter_name(markdown: str) -> str: + text = str(markdown or "") + if not text.startswith("---"): + return "" + try: + import yaml + + parts = text[3:].split("---", 1) + if len(parts) != 2: + return "" + frontmatter = yaml.safe_load(parts[0]) + if not isinstance(frontmatter, dict): + return "" + value = frontmatter.get("name") + return str(value).strip() if value is not None else "" + except Exception: + return "" + + +def _escape_braces(text: str) -> str: + value = str(text or "") + return value.replace("{", "{{").replace("}", "}}") + + +def _resolve_timestamp_epoch(job: dict[str, Any]) -> int: + raw_epoch = job.get("timestamp_epoch") + if isinstance(raw_epoch, (int, float)): + return int(raw_epoch) + if isinstance(raw_epoch, str): + try: + return int(float(raw_epoch.strip())) + except Exception: + pass + + for key in ("timestamp_utc", "timestamp_local"): + raw_value = job.get(key) + if not isinstance(raw_value, str): + continue + text = raw_value.strip() + if not text: + continue + try: + parsed = datetime.fromisoformat(text.replace("Z", "+00:00")) + if parsed.tzinfo is None: + parsed = parsed.replace(tzinfo=timezone.utc) + return int(parsed.timestamp()) + except Exception: + continue + + return int(datetime.now(timezone.utc).timestamp()) + + +def _coerce_bool(value: Any) -> bool: + if isinstance(value, bool): + return value + if isinstance(value, (int, float)): + return bool(value) + if isinstance(value, str): + normalized = value.strip().lower() + if normalized in {"1", "true", "yes", "y", "on"}: + return True + if normalized in {"0", "false", "no", "n", "off", ""}: + return False + return False diff --git a/src/Undefined/cognitive/historian/tools.py b/src/Undefined/cognitive/historian/tools.py new file mode 100644 index 00000000..83cc3018 --- /dev/null +++ b/src/Undefined/cognitive/historian/tools.py @@ -0,0 +1,78 @@ +"""Historian LLM 工具定义。""" + +from __future__ import annotations + +_REWRITE_TOOL = { + "type": "function", + "function": { + "name": "submit_rewrite", + "description": "提交绝对化改写后的事件文本", + "parameters": { + "type": "object", + "properties": { + "text": {"type": "string", "description": "改写后的纯文本"}, + }, + "required": ["text"], + }, + }, +} +_READ_PROFILE_TOOL = { + "type": "function", + "function": { + "name": "read_profile", + "description": "读取指定实体的当前侧写内容", + "parameters": { + "type": "object", + "properties": { + "entity_type": { + "type": "string", + "enum": ["user", "group"], + "description": "实体类型:user 或 group", + }, + "entity_id": { + "type": "string", + "description": "实体 ID(用户 QQ 号或群号)", + }, + }, + "required": ["entity_type", "entity_id"], + }, + }, +} +_PROFILE_TOOL = { + "type": "function", + "function": { + "name": "update_profile", + "description": "更新用户/群侧写。调用前必须先用 read_profile 查看当前内容", + "parameters": { + "type": "object", + "properties": { + "entity_type": { + "type": "string", + "enum": ["user", "group"], + "description": "实体类型:user 或 group", + }, + "entity_id": { + "type": "string", + "description": "实体 ID(用户 QQ 号或群号)", + }, + "skip": { + "type": "boolean", + "description": "是否跳过更新;当新信息不稳定/不足时为 true", + }, + "skip_reason": { + "type": "string", + "description": "跳过原因", + }, + "name": {"type": "string", "description": "用户/群名称"}, + "tags": { + "type": "array", + "items": {"type": "string"}, + "maxItems": 10, + "description": "身份级标签(角色/核心领域),最多 10 个,不写话题", + }, + "summary": {"type": "string", "description": "侧写正文(Markdown)"}, + }, + "required": ["entity_type", "entity_id", "skip", "name", "tags", "summary"], + }, + }, +} diff --git a/src/Undefined/cognitive/historian/worker.py b/src/Undefined/cognitive/historian/worker.py new file mode 100644 index 00000000..30ecf43c --- /dev/null +++ b/src/Undefined/cognitive/historian/worker.py @@ -0,0 +1,903 @@ +"""HistorianWorker 实现。""" + +from __future__ import annotations + +import asyncio +import json +import logging +from datetime import datetime +from typing import Any, Callable + +from Undefined.ai.transports.openai_transport import RESPONSES_OUTPUT_ITEMS_KEY +from Undefined.utils.tool_calls import extract_required_tool_call_arguments + +from Undefined.cognitive.historian.helpers import ( + _coerce_bool, + _escape_braces, + _extract_frontmatter_name, + _preview_text, + _resolve_timestamp_epoch, +) +from Undefined.cognitive.historian.tools import ( + _PROFILE_TOOL, + _READ_PROFILE_TOOL, + _REWRITE_TOOL, +) + +logger = logging.getLogger(__name__) + + +class HistorianWorker: + def __init__( + self, + job_queue: Any, + vector_store: Any, + profile_storage: Any, + ai_client: Any, + config_getter: Callable[[], Any], + model_config: Any = None, + ) -> None: + self._job_queue = job_queue + self._vector_store = vector_store + self._profile_storage = profile_storage + self._ai_client = ai_client + self._config_getter = config_getter + self._model_config = model_config + self._stop_event = asyncio.Event() + self._task: asyncio.Task[None] | None = None + self._inflight_tasks: set[asyncio.Task[None]] = set() + + async def _prepare_query_embedding(self, query_text: str) -> list[float] | None: + embed_query = getattr(self._vector_store, "embed_query", None) + if not callable(embed_query): + return None + try: + result = await embed_query(query_text) + except Exception as exc: + logger.warning("[史官] 预生成查询向量失败,回退即时计算: error=%s", exc) + return None + if not isinstance(result, list): + logger.warning("[史官] 预生成查询向量返回值非法,回退即时计算") + return None + normalized: list[float] = [] + for item in result: + try: + normalized.append(float(item)) + except (TypeError, ValueError): + logger.warning("[史官] 预生成查询向量包含非法元素,回退即时计算") + return None + return normalized + + async def start(self) -> None: + logger.info("[史官] Worker 启动中") + self._task = asyncio.create_task(self._poll_loop()) + logger.info("[史官] Worker 已启动") + + async def stop(self) -> None: + logger.info("[史官] Worker 停止中") + self._stop_event.set() + if self._task: + await self._task + logger.info("[史官] Worker 已停止") + + async def _poll_loop(self) -> None: + dispatch_count = 0 + logger.info("[史官] 轮询循环已开始") + while not self._stop_event.is_set(): + result = await self._job_queue.dequeue() + if result: + job_id, job = result + task = asyncio.create_task(self._process_job_with_retry(job_id, job)) + self._inflight_tasks.add(task) + task.add_done_callback(self._inflight_tasks.discard) + dispatch_count += 1 + logger.info( + "[史官] 任务已发车: job_id=%s inflight=%s", + job_id, + len(self._inflight_tasks), + ) + + config = self._config_getter() + if ( + config.failed_cleanup_interval > 0 + and dispatch_count > 0 + and dispatch_count % config.failed_cleanup_interval == 0 + ): + from Undefined.utils.cache import cleanup_cache_dir + + cleanup_cache_dir( + self._job_queue._failed_dir, + max_age_seconds=config.failed_max_age_days * 86400, + max_files=config.failed_max_files, + ) + logger.info( + "[史官] failed 队列清理已执行: interval=%s max_age_days=%s max_files=%s", + config.failed_cleanup_interval, + config.failed_max_age_days, + config.failed_max_files, + ) + + await asyncio.sleep(config.poll_interval_seconds) + + if self._inflight_tasks: + logger.info( + "[史官] 等待在途任务收敛: inflight=%s", len(self._inflight_tasks) + ) + await asyncio.gather(*list(self._inflight_tasks), return_exceptions=True) + logger.info("[史官] 轮询循环已结束") + + async def _process_job_with_retry(self, job_id: str, job: dict[str, Any]) -> None: + try: + await self._process_job(job_id, job) + except Exception as e: + retry_count = job.get("_retry_count", 0) + max_retries = self._config_getter().job_max_retries + if retry_count < max_retries: + logger.warning( + "[史官] 任务 %s 处理失败 (%s/%s),将自动重试: %s", + job_id, + retry_count + 1, + max_retries, + e, + ) + await self._job_queue.requeue(job_id, str(e)) + else: + logger.error( + "[史官] 任务 %s 达到最大重试次数 (%s),移入 failed: %s", + job_id, + max_retries, + e, + ) + await self._job_queue.fail(job_id, str(e)) + + async def _rewrite_and_validate(self, job: dict[str, Any], job_id: str) -> str: + """改写为绝对化事件文本。""" + canonical = await self._rewrite(job, job_id=job_id) + return canonical + + async def _process_job(self, job_id: str, job: dict[str, Any]) -> None: + logger.info( + "[史官] 开始处理任务 %s: user=%s group=%s sender=%s perspective=%s has_observations=%s profile_targets=%s", + job_id, + job.get("user_id", ""), + job.get("group_id", ""), + job.get("sender_id", ""), + job.get("perspective", ""), + job.get("has_observations", job.get("has_new_info", False)), + len(job.get("profile_targets", []) or []), + ) + + raw_observations = ( + job.get("observations") + if "observations" in job + else job.get("new_info", []) + ) + if isinstance(raw_observations, str): + observation_items = ( + [raw_observations.strip()] if raw_observations.strip() else [] + ) + elif isinstance(raw_observations, list): + observation_items = [ + str(s).strip() for s in raw_observations if str(s).strip() + ] + else: + observation_items = [] + + base_metadata: dict[str, Any] = { + "request_id": job.get("request_id", ""), + "end_seq": job.get("end_seq", 0), + "user_id": job.get("user_id", ""), + "group_id": job.get("group_id", ""), + "sender_id": job.get("sender_id", ""), + "request_type": job.get("request_type", ""), + "timestamp_utc": job.get("timestamp_utc", ""), + "timestamp_local": job.get("timestamp_local", ""), + "timestamp_epoch": _resolve_timestamp_epoch(job), + "timezone": job.get("timezone", ""), + "location_abs": job.get("location_abs", ""), + "message_ids": job.get("message_ids", []), + "perspective": str(job.get("perspective", "")).strip(), + "schema_version": job.get("schema_version", "final_v1"), + } + + canonicals: list[str] = [] + + if observation_items: + # 每条 observation 独立改写+入库 + for idx, info_item in enumerate(observation_items): + sub_job = {**job, "observations": info_item} + event_id = f"{job_id}_{idx}" if len(observation_items) > 1 else job_id + canonical = await self._rewrite_and_validate(sub_job, event_id) + meta = { + **base_metadata, + "has_observations": True, + } + await self._vector_store.upsert_event(event_id, canonical, meta) + canonicals.append(canonical) + logger.info( + "[史官] 任务 %s 事件入库完成(%s/%s): len=%s", + event_id, + idx + 1, + len(observation_items), + len(canonical), + ) + + has_obs = ( + job.get("has_observations") + if "has_observations" in job + else job.get("has_new_info", False) + ) + if has_obs and canonicals: + merged_canonical = "\n".join(canonicals) + await self._merge_profiles(job, merged_canonical, job_id) + + await self._job_queue.complete(job_id) + logger.info("[史官] 任务 %s 处理完成", job_id) + + def _extract_required_tool_args( + self, + response: dict[str, Any], + *, + expected_tool_name: str, + stage: str, + job_id: str, + attempt: int | None = None, + target: str | None = None, + ) -> dict[str, Any]: + suffix = f" stage={stage} expected_tool={expected_tool_name}" + if attempt is not None: + suffix += f" attempt={attempt}" + if target: + suffix += f" target={target}" + try: + return extract_required_tool_call_arguments( + response, + expected_tool_name=expected_tool_name, + stage=stage, + logger=logger, + error_context=f"job_id={job_id}{suffix}", + ) + except Exception as exc: + logger.error( + "[史官] 任务 %s 提取工具参数失败:%s err=%s", job_id, suffix, exc + ) + raise + + async def _rewrite( + self, + job: dict[str, Any], + *, + job_id: str = "", + ) -> str: + from Undefined.utils.resources import read_text_resource + + memo = str(job.get("memo") if "memo" in job else job.get("action_summary", "")) + observations = str( + job.get("observations") + if "observations" in job + else job.get("new_info", "") + ) + message_ids_raw = job.get("message_ids", []) + if isinstance(message_ids_raw, list): + message_ids = [ + str(item).strip() for item in message_ids_raw if str(item).strip() + ] + else: + message_ids = [] + profile_targets_raw = job.get("profile_targets", []) + profile_targets_text = "[]" + if isinstance(profile_targets_raw, list) and profile_targets_raw: + compact_targets: list[str] = [] + for target in profile_targets_raw: + if not isinstance(target, dict): + continue + entity_type = str(target.get("entity_type", "")).strip() + entity_id = str(target.get("entity_id", "")).strip() + perspective = str(target.get("perspective", "")).strip() + if not entity_type or not entity_id: + continue + if perspective: + compact_targets.append(f"{entity_type}:{entity_id}({perspective})") + else: + compact_targets.append(f"{entity_type}:{entity_id}") + if compact_targets: + profile_targets_text = ", ".join(compact_targets) + logger.debug( + "[史官] 任务 %s 发起绝对化改写: memo_len=%s observations_len=%s", + job_id or "unknown", + len(memo), + len(observations), + ) + + template = read_text_resource("res/prompts/historian_rewrite.md") + source_message = str(job.get("source_message", "")).strip() + recent_messages_raw = job.get("recent_messages", []) + recent_messages: list[str] = [] + if isinstance(recent_messages_raw, list): + recent_messages = [ + str(item).strip() for item in recent_messages_raw if str(item).strip() + ] + recent_messages_text = "\n---\n".join(recent_messages) + prompt = template.format( + request_id=job.get("request_id", ""), + end_seq=job.get("end_seq", 0), + timestamp_local=job.get("timestamp_local", ""), + timezone=job.get("timezone", "Asia/Shanghai"), + bot_name=job.get("bot_name", "Undefined"), + user_id=job.get("user_id", ""), + group_id=job.get("group_id", ""), + sender_id=job.get("sender_id", ""), + sender_name=job.get("sender_name", ""), + group_name=job.get("group_name", ""), + message_ids=", ".join(message_ids) if message_ids else "[]", + perspective=job.get("perspective", ""), + profile_targets=profile_targets_text, + force="true" if _coerce_bool(job.get("force", False)) else "false", + action_summary=memo, + new_info=observations, + memo=memo, + observations=observations, + source_message=source_message or "(无)", + recent_messages=recent_messages_text or "(无)", + ) + response = await self._ai_client.submit_background_llm_call( + model_config=self._model_config or self._ai_client.agent_config, + messages=[{"role": "user", "content": prompt}], + tools=[_REWRITE_TOOL], + tool_choice={"type": "function", "function": {"name": "submit_rewrite"}}, + call_type="historian_rewrite", + ) + args = self._extract_required_tool_args( + response=response, + expected_tool_name="submit_rewrite", + stage="historian_rewrite", + job_id=job_id or "unknown", + ) + + text = str(args.get("text", "")).strip() + logger.debug( + "[史官] 任务 %s 收到改写结果: len=%s preview=%s", + job_id or "unknown", + len(text), + _preview_text(text), + ) + return text + + def _resolve_profile_targets(self, job: dict[str, Any]) -> list[dict[str, str]]: + targets: list[dict[str, str]] = [] + seen: set[tuple[str, str]] = set() + raw_targets = job.get("profile_targets") + if isinstance(raw_targets, list): + for item in raw_targets: + if not isinstance(item, dict): + continue + entity_type = str(item.get("entity_type", "")).strip() + raw_entity_id = item.get("entity_id") + entity_id = ( + str(raw_entity_id).strip() if raw_entity_id is not None else "" + ) + if entity_type not in {"user", "group"} or not entity_id: + continue + key = (entity_type, entity_id) + if key in seen: + continue + seen.add(key) + targets.append( + { + "entity_type": entity_type, + "entity_id": entity_id, + "perspective": str(item.get("perspective", "")).strip(), + "preferred_name": str(item.get("preferred_name", "")).strip(), + } + ) + if targets: + return targets + + entity_type = "group" if str(job.get("group_id", "")).strip() else "user" + entity_id = str( + job.get("group_id") or job.get("user_id") or job.get("sender_id", "") + ).strip() + if entity_id: + targets.append( + { + "entity_type": entity_type, + "entity_id": entity_id, + "perspective": "legacy", + "preferred_name": "", + } + ) + return targets + + async def _merge_profiles( + self, job: dict[str, Any], canonical: str, event_id: str + ) -> None: + targets = self._resolve_profile_targets(job) + if not targets: + logger.warning("[史官] 任务 %s 侧写合并跳过:缺少目标实体", event_id) + return + logger.info( + "[史官] 任务 %s 开始合并侧写: target_count=%s targets=%s", + event_id, + len(targets), + [ + (t["entity_type"], t["entity_id"], t.get("perspective", "")) + for t in targets + ], + ) + success_count = 0 + for index, target in enumerate(targets, start=1): + try: + merged = await self._merge_profile_target( + job=job, + canonical=canonical, + event_id=event_id, + target=target, + target_index=index, + target_count=len(targets), + ) + if merged: + success_count += 1 + except Exception as exc: + logger.exception( + "[史官] 任务 %s 侧写目标合并失败: target=%s:%s perspective=%s err=%s", + event_id, + target.get("entity_type", ""), + target.get("entity_id", ""), + target.get("perspective", ""), + exc, + ) + logger.info( + "[史官] 任务 %s 侧写合并结束: success=%s total=%s", + event_id, + success_count, + len(targets), + ) + + async def _write_profile( + self, + *, + entity_type: str, + entity_id: str, + effective_name: str, + tags: list[str], + summary: str, + event_id: str, + perspective: str, + ) -> None: + import yaml + + frontmatter: dict[str, Any] = { + "entity_type": entity_type, + "entity_id": entity_id, + "name": effective_name, + "tags": tags, + "updated_at": datetime.now().isoformat(), + "source_event_id": event_id, + } + if entity_type == "user": + frontmatter["nickname"] = effective_name + frontmatter["qq"] = entity_id + else: + frontmatter["group_name"] = effective_name + frontmatter["group_id"] = entity_id + content = f"---\n{yaml.dump(frontmatter, allow_unicode=True)}---\n{summary}" + + await self._profile_storage.write_profile(entity_type, entity_id, content) + logger.info( + "[史官] 任务 %s 侧写文件写入完成: entity_type=%s entity_id=%s tags=%s perspective=%s", + event_id, + entity_type, + entity_id, + tags, + perspective, + ) + + profile_doc_lines: list[str] = [] + if entity_type == "user": + profile_doc_lines.append(f"昵称: {effective_name}") + profile_doc_lines.append(f"QQ号: {entity_id}") + else: + profile_doc_lines.append(f"群名: {effective_name}") + profile_doc_lines.append(f"群号: {entity_id}") + if tags: + profile_doc_lines.append(f"标签: {', '.join(tags)}") + profile_doc_lines.append(summary) + profile_doc = "\n".join(line for line in profile_doc_lines if line.strip()) + + profile_metadata: dict[str, Any] = { + "entity_type": entity_type, + "entity_id": entity_id, + "name": effective_name, + } + if entity_type == "user": + profile_metadata["nickname"] = effective_name + profile_metadata["qq"] = entity_id + else: + profile_metadata["group_name"] = effective_name + profile_metadata["group_id"] = entity_id + + await self._vector_store.upsert_profile( + f"{entity_type}:{entity_id}", + profile_doc, + profile_metadata, + ) + logger.info( + "[史官] 任务 %s 侧写向量入库完成: profile_id=%s perspective=%s", + event_id, + f"{entity_type}:{entity_id}", + perspective, + ) + + @staticmethod + def _historical_event_dedupe_key( + event: dict[str, Any], + ) -> tuple[str, str, str, str, str]: + metadata = event.get("metadata") + if not isinstance(metadata, dict): + metadata = {} + return ( + str(event.get("document", "")).strip(), + str(metadata.get("timestamp_local", "")).strip(), + str(metadata.get("sender_id", "")).strip(), + str(metadata.get("user_id", "")).strip(), + str(metadata.get("group_id", "")).strip(), + ) + + async def _query_user_history_events_for_profile_merge( + self, + *, + query_text: str, + entity_id: str, + top_k: int, + query_embedding: list[float] | None = None, + ) -> list[dict[str, Any]]: + """用户历史检索兼容路径:分别按 sender_id/user_id 查询并合并去重。 + + Compatibility path for user history retrieval: + query sender_id/user_id separately, then merge and dedupe. + """ + safe_top_k = max(1, int(top_k)) + query_embedding_value = query_embedding + if query_embedding_value is None: + query_embedding_value = await self._prepare_query_embedding(query_text) + sender_query = self._vector_store.query_events( + query_text, + top_k=safe_top_k, + where={"sender_id": entity_id}, + apply_mmr=True, + query_embedding=query_embedding_value, + ) + user_query = self._vector_store.query_events( + query_text, + top_k=safe_top_k, + where={"user_id": entity_id}, + apply_mmr=True, + query_embedding=query_embedding_value, + ) + sender_events_raw, user_events_raw = await asyncio.gather( + sender_query, user_query + ) + merged_events = list(sender_events_raw) + list(user_events_raw) + + deduped: list[dict[str, Any]] = [] + seen: set[tuple[str, str, str, str, str]] = set() + for event in merged_events: + key = self._historical_event_dedupe_key(event) + if key in seen: + continue + seen.add(key) + deduped.append(event) + if len(deduped) >= safe_top_k: + break + return deduped + + async def _merge_profile_target( + self, + *, + job: dict[str, Any], + canonical: str, + event_id: str, + target: dict[str, str], + target_index: int, + target_count: int, + ) -> bool: + entity_type = str(target.get("entity_type", "")).strip() + entity_id = str(target.get("entity_id", "")).strip() + perspective = str(target.get("perspective", "")).strip() + if entity_type not in {"user", "group"} or not entity_id: + logger.warning( + "[史官] 任务 %s 侧写目标非法,跳过: target=%s", + event_id, + target, + ) + return False + logger.info( + "[史官] 任务 %s 合并侧写目标(%s/%s): entity_type=%s entity_id=%s perspective=%s", + event_id, + target_index, + target_count, + entity_type, + entity_id, + perspective, + ) + + preferred_name = str(target.get("preferred_name", "")).strip() + + observations_raw = job.get("observations", job.get("new_info", [])) + observations_text = ( + "\n".join(observations_raw) + if isinstance(observations_raw, list) + else str(observations_raw) + ) + query_embedding = await self._prepare_query_embedding(observations_text) + if entity_type == "group": + historical_events = await self._vector_store.query_events( + observations_text, + top_k=8, + where={"group_id": entity_id}, + apply_mmr=True, + query_embedding=query_embedding, + ) + else: + historical_events = await self._query_user_history_events_for_profile_merge( + query_text=observations_text, + entity_id=entity_id, + top_k=8, + query_embedding=query_embedding, + ) + historical_lines = ( + "\n".join( + f"- [{e['metadata'].get('timestamp_local', '')}] {e['document']}" + for e in historical_events + ) + or "(暂无历史事件)" + ) + + from Undefined.utils.resources import read_text_resource + + template = read_text_resource("res/prompts/historian_profile_merge.md") + message_ids_raw = job.get("message_ids", []) + if isinstance(message_ids_raw, list): + message_ids = [ + str(item).strip() for item in message_ids_raw if str(item).strip() + ] + else: + message_ids = [] + + prompt = template.format( + historical_events=_escape_braces(historical_lines), + canonical_text=_escape_braces(canonical), + observations=_escape_braces(observations_text), + new_info=_escape_braces(observations_text), + target_entity_type=entity_type, + target_entity_id=entity_id, + target_perspective=perspective, + target_display_name=_escape_braces(preferred_name or entity_id), + request_type=_escape_braces(str(job.get("request_type", ""))), + user_id=_escape_braces(str(job.get("user_id", ""))), + group_id=_escape_braces(str(job.get("group_id", ""))), + sender_id=_escape_braces(str(job.get("sender_id", ""))), + sender_name=_escape_braces(str(job.get("sender_name", ""))), + group_name=_escape_braces(str(job.get("group_name", ""))), + timestamp_local=_escape_braces(str(job.get("timestamp_local", ""))), + timezone=_escape_braces(str(job.get("timezone", ""))), + event_id=_escape_braces(event_id), + request_id=_escape_braces(str(job.get("request_id", ""))), + end_seq=_escape_braces(str(job.get("end_seq", 0))), + message_ids=_escape_braces(", ".join(message_ids) if message_ids else "[]"), + memo=_escape_braces(str(job.get("memo", job.get("action_summary", "")))), + action_summary=_escape_braces( + str(job.get("memo", job.get("action_summary", ""))) + ), + source_message=_escape_braces(str(job.get("source_message", ""))), + recent_messages=_escape_braces( + "\n".join( + f"- {str(item).strip()}" + for item in (job.get("recent_messages", []) or []) + if str(item).strip() + ) + or "(无)" + ), + ) + + messages: list[dict[str, Any]] = [{"role": "user", "content": prompt}] + tools = [_READ_PROFILE_TOOL, _PROFILE_TOOL] + result = False + max_turns = 100 + transport_state: dict[str, Any] | None = None + + for turn in range(max_turns): + response = await self._ai_client.submit_background_llm_call( + model_config=self._model_config or self._ai_client.agent_config, + messages=messages, + tools=tools, + tool_choice="auto", + call_type="historian_profile_merge", + transport_state=transport_state, + ) + + next_transport_state = ( + response.get("_transport_state") if isinstance(response, dict) else None + ) + transport_state = ( + next_transport_state if isinstance(next_transport_state, dict) else None + ) + + choices = response.get("choices") or [] + if not choices: + logger.warning("[史官] 任务 %s turn=%s 响应无 choices", event_id, turn) + break + message = choices[0].get("message") if isinstance(choices[0], dict) else {} + if not isinstance(message, dict): + break + + tool_calls = message.get("tool_calls") or [] + if not tool_calls: + logger.info( + "[史官] 任务 %s turn=%s 无 tool_calls,结束", event_id, turn + ) + break + + assistant_msg: dict[str, Any] = { + "role": "assistant", + "tool_calls": tool_calls, + } + if message.get("content"): + assistant_msg["content"] = message["content"] + output_items = message.get(RESPONSES_OUTPUT_ITEMS_KEY) + if isinstance(output_items, list): + assistant_msg[RESPONSES_OUTPUT_ITEMS_KEY] = output_items + messages.append(assistant_msg) + + tool_results: list[dict[str, Any]] = [] + done = False + + for tc in tool_calls: + if not isinstance(tc, dict): + continue + func = tc.get("function") or {} + tc_name = str(func.get("name", "")).strip() + tc_id = str(tc.get("id", "")).strip() + try: + tc_args: dict[str, Any] = json.loads( + str(func.get("arguments", "{}")) + ) + except json.JSONDecodeError: + tc_args = {} + + if tc_name == "read_profile": + rp_et = str(tc_args.get("entity_type", "")).strip() + rp_eid = str(tc_args.get("entity_id", "")).strip() + if ( + rp_et not in {"user", "group"} + or not rp_eid + or not rp_eid.isalnum() + ): + tc_content = "错误:entity_type 或 entity_id 无效" + else: + profile_text = await self._profile_storage.read_profile( + rp_et, rp_eid + ) + tc_content = profile_text or "(暂无侧写)" + logger.info( + "[史官] 任务 %s read_profile: %s:%s len=%s", + event_id, + rp_et, + rp_eid, + len(tc_content), + ) + tool_results.append( + {"role": "tool", "tool_call_id": tc_id, "content": tc_content} + ) + + elif tc_name == "update_profile": + up_et = str(tc_args.get("entity_type", entity_type)).strip() + up_eid = str(tc_args.get("entity_id", entity_id)).strip() + if ( + up_et not in {"user", "group"} + or not up_eid + or not up_eid.isalnum() + ): + tool_results.append( + { + "role": "tool", + "tool_call_id": tc_id, + "content": "错误:entity_type 或 entity_id 无效", + } + ) + continue + raw_skip = tc_args.get("skip", False) + skip = ( + raw_skip.lower() not in ("false", "0", "no", "") + if isinstance(raw_skip, str) + else bool(raw_skip) + ) + if skip: + skip_reason = str(tc_args.get("skip_reason", "")).strip() + logger.info( + "[史官] 任务 %s 侧写更新跳过: target=%s:%s perspective=%s reason=%s", + event_id, + up_et, + up_eid, + perspective, + skip_reason or "unspecified", + ) + tool_results.append( + { + "role": "tool", + "tool_call_id": tc_id, + "content": f"已跳过: {skip_reason}", + } + ) + done = True + continue + + summary = str(tc_args.get("summary", "")).strip() + if not summary: + logger.info( + "[史官] 任务 %s 侧写更新跳过: target=%s:%s reason=empty_summary", + event_id, + up_et, + up_eid, + ) + tool_results.append( + { + "role": "tool", + "tool_call_id": tc_id, + "content": "错误:summary 为空", + } + ) + continue + raw_tags = tc_args.get("tags", []) + up_tags: list[str] = [] + if isinstance(raw_tags, list): + up_tags = [str(t).strip() for t in raw_tags if str(t).strip()][ + :10 + ] + + llm_name = str(tc_args.get("name", "")).strip() + is_target = up_et == entity_type and up_eid == entity_id + name_hint = preferred_name if is_target else "" + if not llm_name and not name_hint: + existing = await self._profile_storage.read_profile( + up_et, up_eid + ) + fallback_name = _extract_frontmatter_name(existing or "") + else: + fallback_name = "" + effective_name = ( + name_hint + or llm_name + or fallback_name + or (f"GID:{up_eid}" if up_et == "group" else f"UID:{up_eid}") + ) + + await self._write_profile( + entity_type=up_et, + entity_id=up_eid, + effective_name=effective_name, + tags=up_tags, + summary=summary, + event_id=event_id, + perspective=perspective, + ) + tool_results.append( + {"role": "tool", "tool_call_id": tc_id, "content": "侧写已更新"} + ) + result = True + done = True + + else: + tool_results.append( + { + "role": "tool", + "tool_call_id": tc_id, + "content": f"未知工具: {tc_name}", + } + ) + + messages.extend(tool_results) + if done: + break + + return result diff --git a/src/Undefined/cognitive/job_queue.py b/src/Undefined/cognitive/job_queue.py index 7c5eba61..e1b036d2 100644 --- a/src/Undefined/cognitive/job_queue.py +++ b/src/Undefined/cognitive/job_queue.py @@ -23,7 +23,6 @@ def __init__(self, base_path: str | Path) -> None: self._failed_dir = base / "failed" for d in (self._pending_dir, self._processing_dir, self._failed_dir): d.mkdir(parents=True, exist_ok=True) - # 启动时清理所有遗留的 lock 文件 stale_lock_count = 0 for d in (self._pending_dir, self._processing_dir, self._failed_dir): for lock_file in d.glob("*.lock"): @@ -65,7 +64,6 @@ def _pick() -> tuple[str, dict[str, Any]] | None: with open(dst, "r", encoding="utf-8") as fh: data = json.load(fh) - # 清理遗留的 lock 文件 lock_file = f.with_name(f"{f.name}.lock") lock_file.unlink(missing_ok=True) return f.stem, data @@ -123,7 +121,6 @@ async def requeue(self, job_id: str, error: str) -> None: data = await read_json(src) or {} data["_retry_count"] = data.get("_retry_count", 0) + 1 data["_last_error"] = error - # 先原子更新 processing 内容,再原子移动到 pending await write_json(src, data) await asyncio.to_thread(lambda: os.replace(src, dst)) logger.info( diff --git a/src/Undefined/cognitive/profile_storage.py b/src/Undefined/cognitive/profile_storage.py index 90585fae..d1f61215 100644 --- a/src/Undefined/cognitive/profile_storage.py +++ b/src/Undefined/cognitive/profile_storage.py @@ -74,7 +74,6 @@ def _write() -> None: p.parent.mkdir(parents=True, exist_ok=True) hist_dir.mkdir(parents=True, exist_ok=True) - # 备份现有版本 if p.exists(): ts = datetime.now().strftime("%Y%m%d%H%M%S%f") (hist_dir / f"{ts}.md").write_text( @@ -87,7 +86,6 @@ def _write() -> None: ts, ) - # 原子写入 fd, tmp = tempfile.mkstemp( prefix=f".{p.name}.", suffix=".tmp", dir=str(p.parent) ) @@ -102,7 +100,6 @@ def _write() -> None: pass raise - # 清理旧快照 snapshots = sorted(hist_dir.glob("*.md")) for old in snapshots[: max(0, len(snapshots) - self._revision_keep)]: try: @@ -140,7 +137,6 @@ def _list() -> list[str]: def _sanitize_profile(content: str, entity_type: str, entity_id: str) -> str: import yaml - # 剥离 ```markdown / ``` 包裹 stripped = content.strip() if stripped.startswith("```"): lines = stripped.splitlines() @@ -151,7 +147,6 @@ def _sanitize_profile(content: str, entity_type: str, entity_id: str) -> str: if end: stripped = "\n".join(lines[1:end]) - # 解析 frontmatter if stripped.startswith("---"): parts = stripped[3:].split("---", 1) if len(parts) == 2: diff --git a/src/Undefined/cognitive/service.py b/src/Undefined/cognitive/service.py index 13cf3f1c..831fac53 100644 --- a/src/Undefined/cognitive/service.py +++ b/src/Undefined/cognitive/service.py @@ -320,7 +320,6 @@ def _merge_weighted_events( safe_group_boost = max(0.0, float(current_group_boost)) seen_keys: set[tuple[str, str, str, str, str, str]] = set() # 排序主键优先使用“作用域内原始排名”(已含 time_decay/mmr/rerank 效果), - # 再使用相似度分值兜底,避免跨 scope 合并时打乱衰减后的顺序。 scored_items: list[ tuple[float, float, float, float, float, int, dict[str, Any]] ] = [] @@ -547,7 +546,6 @@ async def enqueue_job( else str(context.get("request_id", "")).strip() ) if not safe_request_id: - # 最终兜底由 JobQueue 生成 request_id。 safe_request_id = "" end_seq_raw = context.get("_end_seq", 0) @@ -699,7 +697,6 @@ async def build_context( getattr(config, "auto_top_k", 5), ) - # 用户侧写(优先 sender_id,与 enqueue_job 写入侧一致) uid = safe_sender_id or safe_user_id if uid: profile = await self._profile_storage.read_profile("user", uid) @@ -707,7 +704,6 @@ async def build_context( label = f"{sender_name}(UID: {uid})" if sender_name else f"UID: {uid}" parts.append(f"## 用户侧写 — {label}\n{profile}") - # 群聊侧写 if safe_group_id: gprofile = await self._profile_storage.read_profile("group", safe_group_id) if gprofile: diff --git a/src/Undefined/cognitive/service/__init__.py b/src/Undefined/cognitive/service/__init__.py new file mode 100644 index 00000000..45878197 --- /dev/null +++ b/src/Undefined/cognitive/service/__init__.py @@ -0,0 +1,5 @@ +"""认知记忆服务包。""" + +from Undefined.cognitive.service.service import CognitiveService + +__all__ = ["CognitiveService"] diff --git a/src/Undefined/cognitive/service/helpers.py b/src/Undefined/cognitive/service/helpers.py new file mode 100644 index 00000000..74752c48 --- /dev/null +++ b/src/Undefined/cognitive/service/helpers.py @@ -0,0 +1,169 @@ +"""认知服务辅助函数。""" + +from __future__ import annotations + +from datetime import datetime, timezone +from typing import Any + +from Undefined.utils.coerce import safe_float + + +def _parse_iso_to_epoch_seconds(value: Any) -> int | None: + if not isinstance(value, str): + return None + text = value.strip() + if not text: + return None + try: + parsed = datetime.fromisoformat(text.replace("Z", "+00:00")) + except Exception: + return None + if parsed.tzinfo is None: + parsed = parsed.replace(tzinfo=timezone.utc) + return int(parsed.timestamp()) + + +def _compose_where(clauses: list[dict[str, Any]]) -> dict[str, Any] | None: + if not clauses: + return None + if len(clauses) == 1: + return clauses[0] + return {"$and": clauses} + + +def _event_base_score(item: dict[str, Any]) -> float: + # 优先 rerank 分,否则用 1-distance 作为相似度 + rerank_score = item.get("rerank_score") + if isinstance(rerank_score, (int, float)): + return max(0.0, float(rerank_score)) + if isinstance(rerank_score, str): + try: + return max(0.0, float(rerank_score.strip())) + except Exception: + pass + similarity = 1.0 - safe_float(item.get("distance"), default=1.0) + if similarity < 0.0: + return 0.0 + if similarity > 1.0: + return 1.0 + return similarity + + +def _event_timestamp_epoch(metadata: Any) -> float: + if not isinstance(metadata, dict): + return float("-inf") + raw_epoch = metadata.get("timestamp_epoch") + if isinstance(raw_epoch, (int, float)): + return float(raw_epoch) + if isinstance(raw_epoch, str): + try: + return float(raw_epoch.strip()) + except Exception: + pass + for key in ("timestamp_utc", "timestamp_local"): + parsed = _parse_iso_to_epoch_seconds(metadata.get(key)) + if parsed is not None: + return float(parsed) + return float("-inf") + + +def _event_dedupe_key(item: dict[str, Any]) -> tuple[str, str, str, str, str, str]: + metadata = item.get("metadata") + if not isinstance(metadata, dict): + metadata = {} + return ( + str(item.get("document", "")).strip(), + str(metadata.get("timestamp_epoch", "")).strip(), + str(metadata.get("timestamp_local", "")).strip(), + str(metadata.get("group_id", "")).strip(), + str(metadata.get("sender_id", "")).strip(), + str(metadata.get("user_id", "")).strip(), + ) + + +def _resolve_auto_request_type( + *, + request_type: str | None, + group_id: str, + user_id: str, + sender_id: str, +) -> str: + normalized = str(request_type or "").strip().lower() + if normalized in {"group", "private"}: + return normalized + if group_id: + return "group" + if sender_id or user_id: + return "private" + return "" + + +def _parse_profile_markdown(markdown: str) -> tuple[dict[str, Any], str] | None: + text = str(markdown or "") + if not text.startswith("---"): + return None + try: + import yaml + + parts = text[3:].split("---", 1) + if len(parts) != 2: + return None + frontmatter = yaml.safe_load(parts[0]) + if not isinstance(frontmatter, dict): + return None + body = parts[1].lstrip("\n") + return frontmatter, body + except Exception: + return None + + +def _serialize_profile_markdown(frontmatter: dict[str, Any], body: str) -> str: + import yaml + + return f"---\n{yaml.dump(frontmatter, allow_unicode=True)}---\n{body}" + + +def _normalize_profile_tags(value: Any) -> list[str]: + if not isinstance(value, list): + return [] + return [str(item).strip() for item in value if str(item).strip()] + + +def _current_profile_name(entity_type: str, frontmatter: dict[str, Any]) -> str: + if entity_type == "user": + return str(frontmatter.get("nickname") or frontmatter.get("name") or "").strip() + return str(frontmatter.get("group_name") or frontmatter.get("name") or "").strip() + + +def _build_profile_vector_payload( + *, + entity_type: str, + entity_id: str, + effective_name: str, + tags: list[str], + summary: str, +) -> tuple[str, dict[str, Any]]: + profile_doc_lines: list[str] = [] + if entity_type == "user": + profile_doc_lines.append(f"昵称: {effective_name}") + profile_doc_lines.append(f"QQ号: {entity_id}") + else: + profile_doc_lines.append(f"群名: {effective_name}") + profile_doc_lines.append(f"群号: {entity_id}") + if tags: + profile_doc_lines.append(f"标签: {', '.join(tags)}") + profile_doc_lines.append(summary) + profile_doc = "\n".join(line for line in profile_doc_lines if line.strip()) + + metadata: dict[str, Any] = { + "entity_type": entity_type, + "entity_id": entity_id, + "name": effective_name, + } + if entity_type == "user": + metadata["nickname"] = effective_name + metadata["qq"] = entity_id + else: + metadata["group_name"] = effective_name + metadata["group_id"] = entity_id + return profile_doc, metadata diff --git a/src/Undefined/cognitive/service/service.py b/src/Undefined/cognitive/service/service.py new file mode 100644 index 00000000..c7a69e9d --- /dev/null +++ b/src/Undefined/cognitive/service/service.py @@ -0,0 +1,751 @@ +"""认知记忆服务实现。""" + +from __future__ import annotations + +import asyncio +import logging +import time +from datetime import datetime, timezone +from typing import TYPE_CHECKING, Any, Callable, cast + +from Undefined.context import RequestContext +from Undefined.utils.coerce import safe_float +from Undefined.cognitive.service.helpers import ( + _build_profile_vector_payload, + _compose_where, + _current_profile_name, + _event_base_score, + _event_dedupe_key, + _event_timestamp_epoch, + _normalize_profile_tags, + _parse_iso_to_epoch_seconds, + _parse_profile_markdown, + _resolve_auto_request_type, + _serialize_profile_markdown, +) + +if TYPE_CHECKING: + from Undefined.knowledge.runtime import RetrievalRuntime + +logger = logging.getLogger(__name__) + + +class CognitiveService: + def __init__( + self, + config_getter: Callable[[], Any], + vector_store: Any, + job_queue: Any, + profile_storage: Any, + reranker: Any = None, + retrieval_runtime: RetrievalRuntime | None = None, + ) -> None: + self._config_getter = config_getter + self._vector_store = vector_store + self._job_queue = job_queue + self._profile_storage = profile_storage + self._reranker = reranker + self._retrieval_runtime = retrieval_runtime + + def _base_reranker(self) -> Any: + if self._retrieval_runtime is not None: + return self._retrieval_runtime.ensure_reranker() + return self._reranker + + def _current_reranker(self) -> Any: + config = self._config_getter() + if not bool(getattr(config, "enable_rerank", True)): + return None + return self._base_reranker() + + async def _prepare_query_embedding(self, query: str) -> list[float] | None: + embed_query = getattr(self._vector_store, "embed_query", None) + if not callable(embed_query): + return None + try: + result = await embed_query(query) + except Exception as exc: + logger.warning("[认知服务] 预生成查询向量失败,回退即时计算: error=%s", exc) + return None + if not isinstance(result, list): + logger.warning("[认知服务] 预生成查询向量返回值非法,回退即时计算") + return None + normalized: list[float] = [] + for item in result: + try: + normalized.append(float(item)) + except (TypeError, ValueError): + logger.warning("[认知服务] 预生成查询向量包含非法元素,回退即时计算") + return None + return normalized + + @property + def enabled(self) -> bool: + return bool(self._config_getter().enabled) + + async def sync_profile_display_name( + self, + *, + entity_type: str, + entity_id: str, + preferred_name: str, + ) -> bool: + normalized_entity_type = str(entity_type or "").strip().lower() + normalized_entity_id = str(entity_id or "").strip() + normalized_name = str(preferred_name or "").strip() + if normalized_entity_type not in {"user", "group"}: + return False + if not normalized_entity_id or not normalized_name: + return False + if self._profile_storage is None or self._vector_store is None: + return False + + existing = await self._profile_storage.read_profile( + normalized_entity_type, + normalized_entity_id, + ) + if not existing: + return False + + parsed = _parse_profile_markdown(existing) + if parsed is None: + return False + frontmatter, summary = parsed + current_name = _current_profile_name(normalized_entity_type, frontmatter) + if current_name == normalized_name: + return False + + frontmatter["name"] = normalized_name + frontmatter["updated_at"] = datetime.now().isoformat() + if normalized_entity_type == "user": + frontmatter["nickname"] = normalized_name + frontmatter["qq"] = normalized_entity_id + else: + frontmatter["group_name"] = normalized_name + frontmatter["group_id"] = normalized_entity_id + + updated_markdown = _serialize_profile_markdown(frontmatter, summary) + await self._profile_storage.write_profile( + normalized_entity_type, + normalized_entity_id, + updated_markdown, + ) + + profile_doc, profile_metadata = _build_profile_vector_payload( + entity_type=normalized_entity_type, + entity_id=normalized_entity_id, + effective_name=normalized_name, + tags=_normalize_profile_tags(frontmatter.get("tags")), + summary=summary, + ) + await self._vector_store.upsert_profile( + f"{normalized_entity_type}:{normalized_entity_id}", + profile_doc, + profile_metadata, + ) + logger.info( + "[认知服务] 已刷新侧写展示名: entity_type=%s entity_id=%s old=%s new=%s", + normalized_entity_type, + normalized_entity_id, + current_name, + normalized_name, + ) + return True + + @staticmethod + def _uid_candidates(user_id: str, sender_id: str) -> list[str]: + values: list[str] = [] + for raw in (sender_id, user_id): + text = str(raw or "").strip() + if text and text not in values: + values.append(text) + return values + + @staticmethod + def _merge_weighted_events( + scoped_results: list[tuple[list[dict[str, Any]], float]], + *, + top_k: int, + current_group_id: str = "", + current_group_boost: float = 1.0, + ) -> list[dict[str, Any]]: + safe_top_k = max(1, int(top_k)) + safe_group_boost = max(0.0, float(current_group_boost)) + seen_keys: set[tuple[str, str, str, str, str, str]] = set() + # 排序主键优先使用“作用域内原始排名”(已含 time_decay/mmr/rerank 效果), + scored_items: list[ + tuple[float, float, float, float, float, int, dict[str, Any]] + ] = [] + serial = 0 + for scoped_events, scope_weight in scoped_results: + safe_scope_weight = max(0.0, safe_float(scope_weight, default=1.0)) + scope_size = max(1, len(scoped_events)) + for rank_idx, event in enumerate(scoped_events): + dedupe_key = _event_dedupe_key(event) + if dedupe_key in seen_keys: + continue + seen_keys.add(dedupe_key) + metadata = event.get("metadata") + if not isinstance(metadata, dict): + metadata = {} + scope_boost = safe_scope_weight + if ( + current_group_id + and str(metadata.get("group_id", "")).strip() == current_group_id + ): + scope_boost *= safe_group_boost + # 保留每个 scope 内已重排结果(time_decay/mmr/rerank)的相对顺序。 + rank_score = float(scope_size - rank_idx) / float(scope_size) + weighted_rank_score = rank_score * scope_boost + base_score = _event_base_score(event) + weighted_score = base_score * scope_boost + scored_items.append( + ( + weighted_rank_score, + weighted_score, + rank_score, + base_score, + _event_timestamp_epoch(metadata), + serial, + event, + ) + ) + serial += 1 + scored_items.sort( + key=lambda item: ( + -item[0], + -item[1], + -item[2], + -item[3], + -item[4], + item[5], + ) + ) + return [item[6] for item in scored_items[:safe_top_k]] + + async def _query_events_for_auto_context( + self, + *, + query: str, + request_type: str, + group_id: str, + user_id: str, + sender_id: str, + top_k: int, + config: Any, + ) -> list[dict[str, Any]]: + safe_top_k = max(1, int(top_k)) + scope_candidate_multiplier = int( + getattr(config, "auto_scope_candidate_multiplier", 2) + ) + if scope_candidate_multiplier <= 0: + scope_candidate_multiplier = 2 + scoped_top_k = max(safe_top_k, safe_top_k * scope_candidate_multiplier) + current_group_boost = safe_float( + getattr(config, "auto_current_group_boost", 1.15), default=1.15 + ) + if current_group_boost <= 0: + current_group_boost = 1.15 + current_private_boost = safe_float( + getattr(config, "auto_current_private_boost", 1.25), default=1.25 + ) + if current_private_boost <= 0: + current_private_boost = 1.25 + query_embedding = await self._prepare_query_embedding(query) + common_kwargs: dict[str, Any] = { + "reranker": self._current_reranker(), + "candidate_multiplier": config.rerank_candidate_multiplier, + "time_decay_enabled": bool(getattr(config, "time_decay_enabled", True)), + "time_decay_half_life_days": float( + getattr(config, "time_decay_half_life_days_auto", 14.0) + ), + "time_decay_boost": float(getattr(config, "time_decay_boost", 0.2)), + "time_decay_min_similarity": float( + getattr(config, "time_decay_min_similarity", 0.35) + ), + "apply_mmr": True, + } + if query_embedding is not None: + common_kwargs["query_embedding"] = query_embedding + uid_values = self._uid_candidates(user_id, sender_id) + + if request_type == "group": + group_events: list[dict[str, Any]] = await self._vector_store.query_events( + query, + top_k=scoped_top_k, + where={"request_type": "group"}, + **common_kwargs, + ) + merge_started = time.perf_counter() + merged = self._merge_weighted_events( + [(group_events, 1.0)], + top_k=safe_top_k, + current_group_id=group_id, + current_group_boost=current_group_boost, + ) + merge_duration = time.perf_counter() - merge_started + logger.info( + "[认知服务] 自动检索(群聊): group_candidates=%s merged=%s top_k=%s scope_multiplier=%s current_group_boost=%.2f merge=%.3fs", + len(group_events), + len(merged), + safe_top_k, + scope_candidate_multiplier, + current_group_boost, + merge_duration, + ) + return merged + + if request_type == "private": + group_task = self._vector_store.query_events( + query, + top_k=scoped_top_k, + where={"request_type": "group"}, + **common_kwargs, + ) + if uid_values: + uid_clauses = [{"user_id": value} for value in uid_values] + [ + {"sender_id": value} for value in uid_values + ] + private_where: dict[str, Any] = { + "$and": [ + {"request_type": "private"}, + {"$or": uid_clauses}, + ] + } + private_task = self._vector_store.query_events( + query, + top_k=scoped_top_k, + where=private_where, + **common_kwargs, + ) + group_events_raw, private_events_raw = await asyncio.gather( + group_task, private_task + ) + group_events = cast(list[dict[str, Any]], group_events_raw) + private_events = cast(list[dict[str, Any]], private_events_raw) + else: + group_events = cast(list[dict[str, Any]], await group_task) + private_events = [] + merge_started = time.perf_counter() + merged = self._merge_weighted_events( + [ + (group_events, 1.0), + (private_events, current_private_boost), + ], + top_k=safe_top_k, + ) + merge_duration = time.perf_counter() - merge_started + logger.info( + "[认知服务] 自动检索(私聊): group_candidates=%s private_candidates=%s merged=%s top_k=%s scope_multiplier=%s private_boost=%.2f uid_candidates=%s merge=%.3fs", + len(group_events), + len(private_events), + len(merged), + safe_top_k, + scope_candidate_multiplier, + current_private_boost, + uid_values, + merge_duration, + ) + return merged + + where: dict[str, Any] | None = None + if group_id: + where = {"group_id": group_id} + elif uid_values: + where = { + "$or": [{"user_id": value} for value in uid_values] + + [{"sender_id": value} for value in uid_values] + } + events: list[dict[str, Any]] = await self._vector_store.query_events( + query, + top_k=safe_top_k, + where=where, + **common_kwargs, + ) + logger.info( + "[认知服务] 自动检索(兜底): mode=%s where=%s count=%s top_k=%s", + request_type or "unknown", + where or {}, + len(events), + safe_top_k, + ) + return events + + async def enqueue_job( + self, + memo: str, + observations: list[str], + context: dict[str, Any], + *, + force: bool = False, + ) -> str | None: + memo_text = str(memo or "").strip() + observation_items = ( + [s for s in observations if s.strip()] if observations else [] + ) + if not self.enabled: + logger.info("[认知服务] 已禁用,跳过入队") + return None + if not memo_text and not observation_items: + logger.info("[认知服务] memo/observations 均为空,跳过入队") + return None + ctx = RequestContext.current() + + now = datetime.now().astimezone() + now_utc = datetime.now(timezone.utc) + safe_request_id = ( + str(ctx.request_id) + if ctx and str(ctx.request_id or "").strip() + else str(context.get("request_id", "")).strip() + ) + if not safe_request_id: + safe_request_id = "" + + end_seq_raw = context.get("_end_seq", 0) + try: + end_seq = int(end_seq_raw) + except (TypeError, ValueError): + end_seq = 0 + + has_observations = bool(observation_items) + message_ids = context.get("message_ids") + if not isinstance(message_ids, list): + message_ids = [] + message_ids = [str(item).strip() for item in message_ids if str(item).strip()] + perspective = str(context.get("memory_perspective", "")).strip() + user_id = ( + str(ctx.user_id or "") if ctx else str(context.get("user_id", "") or "") + ) + group_id = ( + str(ctx.group_id or "") if ctx else str(context.get("group_id", "") or "") + ) + sender_id = ( + str(ctx.sender_id or "") + if ctx + else str(context.get("sender_id") or context.get("user_id", "") or "") + ) + request_type = ( + str(ctx.request_type) + if ctx and ctx.request_type + else str(context.get("request_type", "") or "") + ) + sender_name = str(context.get("sender_name") or "").strip() + group_name = str(context.get("group_name") or "").strip() + source_message = str(context.get("historian_source_message") or "").strip() + recent_messages_raw = context.get("historian_recent_messages", []) + recent_messages: list[str] = [] + if isinstance(recent_messages_raw, list): + recent_messages = [ + str(item).strip() for item in recent_messages_raw if str(item).strip() + ] + + profile_targets: list[dict[str, str]] = [] + if has_observations: + group_id = group_id.strip() + sender_id = sender_id.strip() or user_id.strip() + seen: set[tuple[str, str]] = set() + if group_id: + key = ("group", group_id) + if key not in seen: + seen.add(key) + profile_targets.append( + { + "entity_type": "group", + "entity_id": group_id, + "perspective": "group", + "preferred_name": group_name, + } + ) + if sender_id: + key = ("user", sender_id) + if key not in seen: + seen.add(key) + profile_targets.append( + { + "entity_type": "user", + "entity_id": sender_id, + "perspective": "sender", + "preferred_name": sender_name, + } + ) + + bot_name = str(self._config_getter().bot_name or "Undefined").strip() + + job: dict[str, Any] = { + "request_id": safe_request_id, + "end_seq": end_seq, + "user_id": user_id, + "group_id": group_id, + "sender_id": sender_id, + "sender_name": sender_name, + "group_name": group_name, + "bot_name": bot_name, + "request_type": request_type, + "timestamp_utc": now_utc.isoformat(), + "timestamp_local": now.isoformat(), + "timestamp_epoch": int(now_utc.timestamp()), + "timezone": str(now.tzinfo or ""), + "location_abs": str( + context.get("group_name") or context.get("sender_name") or "" + ), + "message_ids": message_ids, + "memo": memo_text, + "observations": observation_items, + "has_observations": has_observations, + "perspective": perspective, + "profile_targets": profile_targets, + "schema_version": "final_v1", + "source_message": source_message, + "recent_messages": recent_messages, + "force": bool(force), + } + logger.info( + "[认知服务] 准备入队: request_id=%s end_seq=%s user=%s group=%s sender=%s perspective=%s has_observations=%s profile_targets=%s memo_len=%s observations_len=%s source_len=%s recent_ref=%s force=%s", + job.get("request_id", ""), + job.get("end_seq", 0), + job.get("user_id", ""), + job.get("group_id", ""), + job.get("sender_id", ""), + perspective or "default", + has_observations, + len(profile_targets), + len(memo_text), + len(observation_items), + len(source_message), + len(recent_messages), + bool(force), + ) + result: str | None = await self._job_queue.enqueue(job) + logger.info("[认知服务] 入队完成: job_id=%s", result or "") + return result + + async def build_context( + self, + query: str, + group_id: str | None = None, + user_id: str | None = None, + sender_id: str | None = None, + sender_name: str | None = None, + group_name: str | None = None, + request_type: str | None = None, + ) -> str: + config = self._config_getter() + safe_group_id = str(group_id or "").strip() + safe_user_id = str(user_id or "").strip() + safe_sender_id = str(sender_id or "").strip() + safe_request_type = _resolve_auto_request_type( + request_type=request_type, + group_id=safe_group_id, + user_id=safe_user_id, + sender_id=safe_sender_id, + ) + parts: list[str] = [] + logger.info( + "[认知服务] 构建上下文: query_len=%s type=%s user=%s sender=%s group=%s top_k=%s", + len(query or ""), + safe_request_type or "", + safe_user_id, + safe_sender_id, + safe_group_id, + getattr(config, "auto_top_k", 5), + ) + + uid = safe_sender_id or safe_user_id + if uid: + profile = await self._profile_storage.read_profile("user", uid) + if profile: + label = f"{sender_name}(UID: {uid})" if sender_name else f"UID: {uid}" + parts.append(f"## 用户侧写 — {label}\n{profile}") + + if safe_group_id: + gprofile = await self._profile_storage.read_profile("group", safe_group_id) + if gprofile: + glabel = ( + f"{group_name}(GID: {safe_group_id})" + if group_name + else f"GID: {safe_group_id}" + ) + parts.append(f"## 群聊侧写 — {glabel}\n{gprofile}") + + default_top_k = 5 + try: + top_k = int(getattr(config, "auto_top_k", default_top_k)) + except Exception: + top_k = default_top_k + if top_k <= 0: + top_k = default_top_k + top_k = min(top_k, 500) + try: + events = await self._query_events_for_auto_context( + query=query, + request_type=safe_request_type, + group_id=safe_group_id, + user_id=safe_user_id, + sender_id=safe_sender_id, + top_k=top_k, + config=config, + ) + except Exception as exc: + logger.warning( + "[认知服务] 自动上下文事件检索失败,降级为空结果: type=%s user=%s sender=%s group=%s err=%s", + safe_request_type, + safe_user_id, + safe_sender_id, + safe_group_id, + exc, + ) + events = [] + if events: + event_lines = "\n".join( + f"- [{e['metadata'].get('timestamp_local', '')}] {e['document']}" + for e in events + ) + parts.append(f"## 相关记忆事件\n{event_lines}") + + if not parts: + logger.info("[认知服务] 构建上下文完成: 无可用记忆") + return "" + + body = "\n\n".join(parts) + result = ( + "\n" + "\n" + f"{body}\n" + "" + ) + logger.info( + "[认知服务] 构建上下文完成: sections=%s result_len=%s", + len(parts), + len(result), + ) + return result + + async def search_events(self, query: str, **kwargs: Any) -> list[dict[str, Any]]: + config = self._config_getter() + group_id = str( + kwargs.get("group_id") or kwargs.get("target_group_id") or "" + ).strip() + user_id = str( + kwargs.get("user_id") or kwargs.get("target_user_id") or "" + ).strip() + sender_id = str(kwargs.get("sender_id") or "").strip() + where_clauses: list[dict[str, Any]] = [] + if group_id: + where_clauses.append({"group_id": group_id}) + if user_id: + where_clauses.append({"user_id": user_id}) + if sender_id: + where_clauses.append({"sender_id": sender_id}) + request_type = str(kwargs.get("request_type") or "").strip() + if request_type: + where_clauses.append({"request_type": request_type}) + + time_from_epoch = _parse_iso_to_epoch_seconds(kwargs.get("time_from")) + time_to_epoch = _parse_iso_to_epoch_seconds(kwargs.get("time_to")) + if ( + time_from_epoch is not None + and time_to_epoch is not None + and time_from_epoch > time_to_epoch + ): + logger.warning( + "[认知服务] search_events 时间范围反转,已自动交换: time_from=%s time_to=%s", + kwargs.get("time_from"), + kwargs.get("time_to"), + ) + time_from_epoch, time_to_epoch = time_to_epoch, time_from_epoch + if time_from_epoch is not None: + where_clauses.append({"timestamp_epoch": {"$gte": time_from_epoch}}) + if time_to_epoch is not None: + where_clauses.append({"timestamp_epoch": {"$lte": time_to_epoch}}) + + where = _compose_where(where_clauses) + default_top_k = getattr(config, "tool_default_top_k", 12) + top_k_raw = kwargs.get("top_k", default_top_k) + try: + top_k = int(top_k_raw) + except Exception: + top_k = default_top_k + if top_k <= 0: + top_k = default_top_k + top_k = min(top_k, 500) + logger.info( + "[认知服务] 搜索事件: query_len=%s top_k=%s where=%s time_from=%s time_to=%s", + len(query or ""), + top_k, + where or {}, + time_from_epoch, + time_to_epoch, + ) + results: list[dict[str, Any]] = await self._vector_store.query_events( + query, + top_k=top_k, + where=where or None, + reranker=self._current_reranker(), + candidate_multiplier=config.rerank_candidate_multiplier, + time_decay_enabled=bool(getattr(config, "time_decay_enabled", True)), + time_decay_half_life_days=float( + getattr(config, "time_decay_half_life_days_tool", 60.0) + ), + time_decay_boost=float(getattr(config, "time_decay_boost", 0.2)), + time_decay_min_similarity=float( + getattr(config, "time_decay_min_similarity", 0.35) + ), + apply_mmr=True, + query_embedding=await self._prepare_query_embedding(query), + ) + logger.info("[认知服务] 搜索事件完成: count=%s", len(results)) + return results + + async def get_profile(self, entity_type: str, entity_id: str) -> str | None: + logger.info( + "[认知服务] 读取侧写: entity_type=%s entity_id=%s", + entity_type, + entity_id, + ) + result: str | None = await self._profile_storage.read_profile( + entity_type, entity_id + ) + logger.info( + "[认知服务] 读取侧写完成: found=%s", + bool(result), + ) + return result + + async def search_profiles(self, query: str, **kwargs: Any) -> list[dict[str, Any]]: + config = self._config_getter() + default_top_k = int(getattr(config, "profile_top_k", 5)) + top_k_raw = kwargs.get("top_k", default_top_k) + try: + top_k = int(top_k_raw) + except Exception: + top_k = default_top_k + if top_k <= 0: + top_k = default_top_k + top_k = min(top_k, 500) + + where: dict[str, Any] | None = None + entity_type_raw = kwargs.get("entity_type") + entity_type = ( + str(entity_type_raw).strip() if entity_type_raw is not None else "" + ) + if entity_type: + where = {"entity_type": entity_type} + + logger.info( + "[认知服务] 搜索侧写: query_len=%s top_k=%s where=%s", + len(query or ""), + top_k, + where or {}, + ) + results: list[dict[str, Any]] = await self._vector_store.query_profiles( + query, + top_k=top_k, + where=where, + reranker=self._current_reranker(), + candidate_multiplier=config.rerank_candidate_multiplier, + query_embedding=await self._prepare_query_embedding(query), + ) + logger.info("[认知服务] 搜索侧写完成: count=%s", len(results)) + return results diff --git a/src/Undefined/cognitive/vector_store.py b/src/Undefined/cognitive/vector_store.py index 48b80eb1..076b8a9f 100644 --- a/src/Undefined/cognitive/vector_store.py +++ b/src/Undefined/cognitive/vector_store.py @@ -129,7 +129,6 @@ def _mmr_select( if n <= top_k: return np.arange(n) - # 预计算 query-doc 相关性(cosine similarity) query_norm = np.sqrt(np.sum(query_embedding * query_embedding)) relevance = np.empty(n, dtype=np.float64) norms = np.empty(n, dtype=np.float64) @@ -161,7 +160,6 @@ def _mmr_select( return selected[:step] selected[step] = best_idx chosen[best_idx] = True - # 更新 max_sim_to_selected:新选中项与所有未选中项的相似度 if norms[best_idx] > 0.0: for j in range(n): if not chosen[j] and norms[j] > 0.0: @@ -446,7 +444,6 @@ async def _query( query_embedding=query_embedding, ) embed_duration = time.perf_counter() - embed_started - # 重排要求候选数 > 最终返回数,否则重排无意义 use_reranker = bool(reranker) and safe_multiplier >= 2 if reranker and safe_multiplier < 2: logger.warning( diff --git a/src/Undefined/config/domain_parsers.py b/src/Undefined/config/domain_parsers.py index 986c5cb1..c0de5fa0 100644 --- a/src/Undefined/config/domain_parsers.py +++ b/src/Undefined/config/domain_parsers.py @@ -191,7 +191,6 @@ def _parse_message_batcher_config(data: dict[str, Any]) -> MessageBatcherConfig: window_seconds = _coerce_float(section.get("window_seconds"), 5.0) if window_seconds < 0: window_seconds = 0.0 - # max_window_seconds <= 0 视为不限制(仅靠 window_seconds + max_messages_per_batch 触发) max_window_seconds = _coerce_float(section.get("max_window_seconds"), 30.0) if max_window_seconds < 0: max_window_seconds = 0.0 diff --git a/src/Undefined/handlers.py b/src/Undefined/handlers.py index a793ae44..8006aa40 100644 --- a/src/Undefined/handlers.py +++ b/src/Undefined/handlers.py @@ -223,6 +223,7 @@ def _record_repeat_cooldown(self, group_id: int, text: str) -> None: # 清理已过期条目 expired = [k for k, ts in group_cd.items() if (now - ts) >= cooldown_seconds] for k in expired: + # delete del group_cd[k] group_cd[key] = now diff --git a/src/Undefined/handlers/__init__.py b/src/Undefined/handlers/__init__.py new file mode 100644 index 00000000..631cda8e --- /dev/null +++ b/src/Undefined/handlers/__init__.py @@ -0,0 +1,28 @@ +"""消息处理和命令分发包。 + +聚合 ``MessageHandler`` 与各 mixin 子模块;保留 ``parse_message_content_for_history`` 等 +模块级符号供测试 monkeypatch(``import Undefined.handlers``)。 +""" + +from Undefined.handlers.message_flow import ( + KEYWORD_REPLY_HISTORY_PREFIX, + MessageHandler, +) +from Undefined.handlers.poke import GroupPokeRecord, PrivatePokeRecord +from Undefined.handlers.repeat import REPEAT_REPLY_HISTORY_PREFIX +from Undefined.utils.common import ( + extract_text, + matches_xinliweiyuan, + parse_message_content_for_history, +) + +__all__ = [ + "GroupPokeRecord", + "KEYWORD_REPLY_HISTORY_PREFIX", + "MessageHandler", + "PrivatePokeRecord", + "REPEAT_REPLY_HISTORY_PREFIX", + "extract_text", + "matches_xinliweiyuan", + "parse_message_content_for_history", +] diff --git a/src/Undefined/handlers/auto_extract.py b/src/Undefined/handlers/auto_extract.py new file mode 100644 index 00000000..8375f8cb --- /dev/null +++ b/src/Undefined/handlers/auto_extract.py @@ -0,0 +1,223 @@ +"""B 站 / arXiv / GitHub 链接自动提取 mixin。 + +从消息中解析外部资源 ID 并调用对应 sender 发送;由 ``MessageHandler`` 混入使用。 +""" + +from __future__ import annotations + +import logging +from typing import TYPE_CHECKING, Any + +if TYPE_CHECKING: + from Undefined.config import Config + from Undefined.onebot import OneBotClient + from Undefined.utils.sender import MessageSender + +logger = logging.getLogger(__name__) + + +class AutoExtractMixin: + """外部资源自动提取 mixin。""" + + if TYPE_CHECKING: + config: Config + sender: MessageSender + onebot: OneBotClient + + async def _extract_bilibili_ids( + self, text: str, message_content: list[dict[str, Any]] + ) -> list[str]: + """从文本和消息段中提取 B 站视频 BV 号。""" + from Undefined.bilibili.parser import ( + extract_bilibili_ids, + extract_from_json_message, + ) + + bvids = await extract_bilibili_ids(text) + if not bvids: + bvids = await extract_from_json_message(message_content) + return list(bvids) + + def _extract_arxiv_ids( + self, text: str, message_content: list[dict[str, Any]] + ) -> list[str]: + """从文本和消息段中提取 arXiv 论文 ID。""" + from Undefined.arxiv.parser import extract_arxiv_ids, extract_from_json_message + + paper_ids: list[str] = [] + seen: set[str] = set() + + for paper_id in extract_arxiv_ids(text): + if paper_id in seen: + continue + seen.add(paper_id) + paper_ids.append(paper_id) + + for paper_id in extract_from_json_message(message_content): + if paper_id in seen: + continue + seen.add(paper_id) + paper_ids.append(paper_id) + + return paper_ids + + def _extract_github_repo_ids( + self, text: str, message_content: list[dict[str, Any]] + ) -> list[str]: + """从文本和消息段中提取 GitHub 仓库 ID。""" + from Undefined.github.parser import ( + extract_from_json_message, + extract_github_repo_ids, + ) + + repo_ids: list[str] = [] + seen: set[str] = set() + + for repo_id in extract_github_repo_ids(text): + # 仓库 ID 大小写不敏感去重 + key = repo_id.lower() + if key in seen: + continue + seen.add(key) + repo_ids.append(repo_id) + + for repo_id in extract_from_json_message(message_content): + key = repo_id.lower() + if key in seen: + continue + seen.add(key) + repo_ids.append(repo_id) + + return repo_ids + + async def _handle_bilibili_extract( + self, + target_id: int, + bvids: list[str], + target_type: str, + ) -> None: + """处理 bilibili 视频自动提取和发送。""" + from Undefined.bilibili.sender import send_bilibili_video + + for bvid in bvids[:3]: + try: + # 单条消息最多自动提取 3 个 BV + await send_bilibili_video( + video_id=bvid, + sender=self.sender, + onebot=self.onebot, + target_type=target_type, # type: ignore[arg-type] + target_id=target_id, + cookie=self.config.bilibili_cookie, + prefer_quality=self.config.bilibili_prefer_quality, + max_duration=self.config.bilibili_max_duration, + max_file_size=self.config.bilibili_max_file_size, + oversize_strategy=self.config.bilibili_oversize_strategy, + danmaku_enabled=self.config.bilibili_danmaku_enabled, + danmaku_batch_size=self.config.bilibili_danmaku_batch_size, + danmaku_max_count=self.config.bilibili_danmaku_max_count, + ) + except Exception as exc: + logger.error( + "[Bilibili] 自动提取失败 %s → %s:%s: %s", + bvid, + target_type, + target_id, + exc, + ) + try: + error_msg = f"视频提取失败: {exc}" + if target_type == "group": + await self.sender.send_group_message(target_id, error_msg) + else: + await self.sender.send_private_message(target_id, error_msg) + except Exception: + pass + + async def _handle_arxiv_extract( + self, + target_id: int, + paper_ids: list[str], + target_type: str, + ) -> None: + """处理 arXiv 论文自动提取和发送。""" + from Undefined.arxiv.sender import send_arxiv_paper + + max_items = max(1, int(self.config.arxiv_auto_extract_max_items)) + + for paper_id in paper_ids[:max_items]: + try: + result = await send_arxiv_paper( + paper_id=paper_id, + sender=self.sender, + target_type=target_type, # type: ignore[arg-type] + target_id=target_id, + max_file_size=self.config.arxiv_max_file_size, + author_preview_limit=self.config.arxiv_author_preview_limit, + summary_preview_chars=self.config.arxiv_summary_preview_chars, + context={ + "request_id": ( + f"arxiv_auto_extract:{target_type}:{target_id}:{paper_id}" + ) + }, + ) + logger.info( + "[arXiv] 自动提取完成 %s → %s:%s: %s", + paper_id, + target_type, + target_id, + result, + ) + except Exception: + logger.exception( + "[arXiv] 自动提取失败 %s → %s:%s", + paper_id, + target_type, + target_id, + ) + + async def _handle_github_extract( + self, + target_id: int, + repo_ids: list[str], + target_type: str, + ) -> None: + """处理 GitHub 仓库自动提取和发送。""" + from Undefined.github.sender import send_github_repo_card + + max_items = max( + 1, int(getattr(self.config, "github_auto_extract_max_items", 3)) + ) + request_timeout = float( + getattr(self.config, "github_request_timeout_seconds", 10.0) + ) + + for repo_id in repo_ids[:max_items]: + try: + result = await send_github_repo_card( + repo_id=repo_id, + sender=self.sender, + target_type=target_type, # type: ignore[arg-type] + target_id=target_id, + request_timeout=request_timeout, + context={ + "request_id": ( + f"github_auto_extract:{target_type}:{target_id}:{repo_id}" + ) + }, + ) + logger.info( + "[GitHub] 自动提取完成 %s → %s:%s: %s", + repo_id, + target_type, + target_id, + result, + ) + except Exception as exc: + logger.info( + "[GitHub] 自动提取跳过 %s → %s:%s: %s", + repo_id, + target_type, + target_id, + exc, + ) diff --git a/src/Undefined/handlers/message_flow.py b/src/Undefined/handlers/message_flow.py new file mode 100644 index 00000000..b9b2a5a5 --- /dev/null +++ b/src/Undefined/handlers/message_flow.py @@ -0,0 +1,845 @@ +"""消息主流程与 ``MessageHandler`` 核心实现。 + +协调私聊/群聊事件分发、附件收集、管线与 AI 回复;拍一拍、复读、自动提取由 mixin 提供。 +""" + +from __future__ import annotations + +import asyncio +import logging +import os +from pathlib import Path +import random +from typing import Any, Coroutine, Literal + +import Undefined.handlers as handlers_module +from Undefined.attachments import ( + append_attachment_text, + build_attachment_scope, + register_message_attachments, +) +from Undefined.ai import AIClient +from Undefined.config import Config +from Undefined.faq import FAQStorage +from Undefined.handlers.auto_extract import AutoExtractMixin +from Undefined.handlers.poke import PokeMixin +from Undefined.handlers.repeat import RepeatMixin +from Undefined.onebot import ( + OneBotClient, + get_message_content, + get_message_sender_id, +) +from Undefined.rate_limit import RateLimiter +from Undefined.scheduled_task_storage import ScheduledTaskStorage +from Undefined.services.ai_coordinator import AICoordinator +from Undefined.services.command import CommandDispatcher +from Undefined.services.message_batcher import MessageBatcher, make_scope +from Undefined.services.model_pool import ModelPoolService +from Undefined.services.queue_manager import QueueManager +from Undefined.services.security import SecurityService +from Undefined.skills.pipelines import PipelineRegistry +from Undefined.skills.pipelines.context import build_pipeline_context +from Undefined.utils.coerce import safe_int +from Undefined.utils.fake_at import BotNicknameCache, strip_fake_at +from Undefined.utils.history import MessageHistoryManager +from Undefined.utils.logging import log_debug_json, redact_string +from Undefined.utils.queue_intervals import build_model_queue_intervals +from Undefined.utils.resources import resolve_resource_path +from Undefined.utils.scheduler import TaskScheduler +from Undefined.utils.sender import MessageSender + +logger = logging.getLogger(__name__) + +KEYWORD_REPLY_HISTORY_PREFIX = "[系统关键词自动回复] " + + +def _is_private_model_pool_control_text(text: str) -> bool: + return bool(ModelPoolService.is_private_control_text(text)) + + +class MessageHandler(PokeMixin, RepeatMixin, AutoExtractMixin): + """消息处理器。 + + 接收 OneBot 事件、写入历史并协调命令分发、自动管线与 AI 回复; + 依赖 ``AICoordinator``、``CommandDispatcher`` 与 ``MessageBatcher``。 + """ + + def __init__( + self, + config: Config, + onebot: OneBotClient, + ai: AIClient, + faq_storage: FAQStorage, + task_storage: ScheduledTaskStorage, + ) -> None: + self.config = config + self.onebot = onebot + self.ai = ai + self.faq_storage = faq_storage + self.history_manager = MessageHistoryManager(config.history_max_records) + self.sender = MessageSender( + onebot, + self.history_manager, + config.bot_qq, + config, + attachment_registry=getattr(ai, "attachment_registry", None), + ) + + self.security = SecurityService(config, ai._http_client) + self.rate_limiter = RateLimiter(config) + self.queue_manager = QueueManager( + max_retries=config.ai_request_max_retries, + ) + self.queue_manager.update_model_intervals(build_model_queue_intervals(config)) + + ai.set_queue_manager(self.queue_manager) + + self.command_dispatcher = CommandDispatcher( + config, + self.sender, + ai, + faq_storage, + onebot, + self.security, + queue_manager=self.queue_manager, + rate_limiter=self.rate_limiter, + history_manager=self.history_manager, + ) + self.ai_coordinator = AICoordinator( + config, + ai, + self.queue_manager, + self.history_manager, + self.sender, + onebot, + TaskScheduler(ai, self.sender, onebot, self.history_manager, task_storage), + self.security, + command_dispatcher=self.command_dispatcher, + ) + + self.message_batcher = MessageBatcher( + config.message_batcher, + flush_callback=self.ai_coordinator.handle_batched_dispatch, + ) + self.ai_coordinator.set_batcher(self.message_batcher) + + self._background_tasks: set[asyncio.Task[None]] = set() + self._profile_name_refresh_cache: dict[tuple[str, int], str] = {} + self._bot_nickname_cache = BotNicknameCache(onebot, config.bot_qq) + self.pipeline_registry = PipelineRegistry() + self._pipelines_initialized = False + self._pipelines_init_lock = asyncio.Lock() + + self._repeat_counter: dict[int, list[tuple[str, int]]] = {} + self._repeat_locks: dict[int, asyncio.Lock] = {} + self._repeat_cooldown: dict[int, dict[str, float]] = {} + + self.ai_coordinator.queue_manager.start(self.ai_coordinator.execute_reply) + + async def initialize(self) -> None: + """完成需要事件循环承载的异步初始化。""" + await self.init_pipelines() + + async def init_pipelines(self) -> None: + """异步加载自动处理管线并按配置启动热重载。""" + if getattr(self, "_pipelines_initialized", False): + return + init_lock = getattr(self, "_pipelines_init_lock", None) + if init_lock is None: + init_lock = asyncio.Lock() + self._pipelines_init_lock = init_lock + async with init_lock: + if getattr(self, "_pipelines_initialized", False): + return + await self.pipeline_registry.load_items_async() + self._pipelines_initialized = True + if getattr(self.config, "skills_hot_reload", False): + self.pipeline_registry.start_hot_reload( + interval=self.config.skills_hot_reload_interval, + debounce=self.config.skills_hot_reload_debounce, + ) + + async def _annotate_meme_descriptions( + self, + attachments: list[dict[str, str]], + scope_key: str, + ) -> list[dict[str, str]]: + """为图片附件添加表情包描述(如果在表情库中找到)。 + + 采用批量查询:收集所有 SHA256 哈希值,一次性查询,然后映射结果。 + 最佳努力:任何失败时返回原始列表。 + """ + if not attachments: + return attachments + + ai_client = getattr(self, "ai", None) + if ai_client is None: + return attachments + + attachment_registry = getattr(ai_client, "attachment_registry", None) + if attachment_registry is None: + return attachments + + meme_service = getattr(ai_client, "_meme_service", None) + if meme_service is None or not getattr(meme_service, "enabled", False): + return attachments + + meme_store = getattr(meme_service, "_store", None) + if meme_store is None: + return attachments + + try: + # 仅 pic_ 前缀图片参与表情库匹配 + uid_to_hash: dict[str, str] = {} + for att in attachments: + uid = att.get("uid", "") + if not uid.startswith("pic_"): + continue + record = attachment_registry.resolve(uid, scope_key) + if record and record.sha256: + uid_to_hash[uid] = record.sha256 + + if not uid_to_hash: + return attachments + + unique_hashes = set(uid_to_hash.values()) + hash_to_desc: dict[str, str] = {} + for h in unique_hashes: + meme = await meme_store.find_by_sha256(h) + if meme and meme.description: + hash_to_desc[h] = meme.description + + if not hash_to_desc: + return attachments + + result: list[dict[str, str]] = [] + for att in attachments: + uid = att.get("uid", "") + sha = uid_to_hash.get(uid, "") + desc = hash_to_desc.get(sha, "") + if desc: + new_att = dict(att) + new_att["description"] = f"[表情包] {desc}" + result.append(new_att) + else: + result.append(att) + return result + except Exception: + logger.warning("表情包自动匹配失败,跳过", exc_info=True) + return attachments + + async def _collect_message_attachments( + self, + message_content: list[dict[str, Any]], + *, + group_id: int | None = None, + user_id: int | None = None, + request_type: str, + ) -> list[dict[str, str]]: + scope_key = build_attachment_scope( + group_id=group_id, + user_id=user_id, + request_type=request_type, + ) + if not scope_key: + return [] + ai_client = getattr(self, "ai", None) + attachment_registry = ( + getattr(ai_client, "attachment_registry", None) if ai_client else None + ) + if attachment_registry is None: + return [] + onebot = getattr(self, "onebot", None) + resolve_image_url = getattr(onebot, "get_image", None) if onebot else None + result = await register_message_attachments( + registry=attachment_registry, + segments=message_content, + scope_key=scope_key, + resolve_image_url=resolve_image_url, + get_forward_messages=getattr(onebot, "get_forward_msg", None) + if onebot + else None, + ) + attachments = result.attachments + # 命中表情库时为 AI 上下文补充 [表情包] 描述 + attachments = await self._annotate_meme_descriptions(attachments, scope_key) + return attachments + + def _schedule_meme_ingest( + self, + *, + attachments: list[dict[str, str]], + chat_type: str, + chat_id: int, + sender_id: int, + message_id: int | None, + scope_key: str | None, + ) -> None: + # 后台异步入库,不阻塞主消息处理 + if not attachments or not scope_key: + return + meme_service = getattr(self.ai, "_meme_service", None) + if meme_service is None or not getattr(meme_service, "enabled", False): + return + self._spawn_background_task( + f"meme_ingest:{chat_type}:{chat_id}:{sender_id}:{message_id or 0}", + meme_service.enqueue_incoming_attachments( + attachments=attachments, + chat_type=chat_type, + chat_id=chat_id, + sender_id=sender_id, + message_id=message_id, + scope_key=scope_key, + ), + ) + + async def _refresh_profile_display_names( + self, + *, + sender_id: int | None = None, + sender_name: str = "", + group_id: int | None = None, + group_name: str = "", + ) -> None: + ai_client = getattr(self, "ai", None) + cognitive_service = getattr(ai_client, "_cognitive_service", None) + if not cognitive_service or not getattr(cognitive_service, "enabled", False): + return + + if sender_id and sender_name.strip(): + await cognitive_service.sync_profile_display_name( + entity_type="user", + entity_id=str(sender_id), + preferred_name=sender_name.strip(), + ) + if group_id and group_name.strip(): + await cognitive_service.sync_profile_display_name( + entity_type="group", + entity_id=str(group_id), + preferred_name=group_name.strip(), + ) + + def _can_refresh_profile_display_names(self) -> bool: + ai_client = getattr(self, "ai", None) + cognitive_service = getattr(ai_client, "_cognitive_service", None) + return bool(cognitive_service and getattr(cognitive_service, "enabled", False)) + + def _schedule_profile_display_name_refresh( + self, + *, + task_name: str, + sender_id: int | None = None, + sender_name: str = "", + group_id: int | None = None, + group_name: str = "", + ) -> None: + if not self._can_refresh_profile_display_names(): + return + + cache = getattr(self, "_profile_name_refresh_cache", None) + if cache is None: + cache = {} + self._profile_name_refresh_cache = cache + + updates: dict[str, Any] = {} + rollback: list[tuple[tuple[str, int], str | None]] = [] + + normalized_sender_name = sender_name.strip() + if sender_id and normalized_sender_name: + sender_key = ("user", int(sender_id)) + previous = cache.get(sender_key) + if previous != normalized_sender_name: + cache[sender_key] = normalized_sender_name + rollback.append((sender_key, previous)) + updates["sender_id"] = sender_id + updates["sender_name"] = normalized_sender_name + + normalized_group_name = group_name.strip() + if group_id and normalized_group_name: + group_key = ("group", int(group_id)) + previous = cache.get(group_key) + if previous != normalized_group_name: + cache[group_key] = normalized_group_name + rollback.append((group_key, previous)) + updates["group_id"] = group_id + updates["group_name"] = normalized_group_name + + if not updates: + return + + async def _run_refresh() -> None: + try: + await self._refresh_profile_display_names(**updates) + except Exception: + # 刷新失败时回滚内存缓存,避免脏昵称长期生效 + for key, previous in rollback: + if previous is None: + cache.pop(key, None) + else: + cache[key] = previous + raise + + self._spawn_background_task(task_name, _run_refresh()) + + async def handle_message(self, event: dict[str, Any]) -> None: + """处理收到的消息事件。""" + if logger.isEnabledFor(logging.DEBUG): + log_debug_json(logger, "[事件数据]", event) + post_type = event.get("post_type", "message") + + # 拍一拍走 notice 旁路,不进入普通消息流水线 + if post_type == "notice" and event.get("notice_type") == "poke": + await self._handle_poke_notice(event) + return + + if event.get("message_type") == "private": + await self._handle_private_message(event) + return + + if event.get("message_type") != "group": + return + + await self._handle_group_message(event) + + async def _handle_private_message(self, event: dict[str, Any]) -> None: + """处理私聊消息事件。""" + private_sender_id: int = get_message_sender_id(event) + private_message_content: list[dict[str, Any]] = get_message_content(event) + trigger_message_id = event.get("message_id") + + if not self.config.is_private_allowed(private_sender_id): + private_reason = ( + self.config.private_access_denied_reason(private_sender_id) or "unknown" + ) + logger.debug( + "[访问控制] 忽略私聊消息: user=%s reason=%s (access enabled=%s)", + private_sender_id, + private_reason, + self.config.access_control_enabled(), + ) + return + + private_sender: dict[str, Any] = event.get("sender", {}) + private_sender_nickname: str = private_sender.get("nickname", "") + + user_name = private_sender_nickname + if not user_name: + try: + user_info = await self.onebot.get_stranger_info(private_sender_id) + if user_info: + user_name = user_info.get("nickname", "") + except Exception as exc: + logger.warning("获取用户昵称失败: %s", exc) + + text = handlers_module.extract_text(private_message_content, self.config.bot_qq) + private_attachments, parsed_content_raw = await asyncio.gather( + self._collect_message_attachments( + private_message_content, + user_id=private_sender_id, + request_type="private", + ), + handlers_module.parse_message_content_for_history( + private_message_content, + self.config.bot_qq, + self.onebot.get_msg, + self.onebot.get_forward_msg, + ), + ) + safe_text = redact_string(text) + logger.info( + "[私聊消息] 发送者=%s 昵称=%s 内容=%s", + private_sender_id, + user_name or private_sender_nickname, + safe_text[:100], + ) + resolved_private_name = (user_name or private_sender_nickname or "").strip() + self._schedule_profile_display_name_refresh( + task_name=f"profile_name_refresh_private:{private_sender_id}", + sender_id=private_sender_id, + sender_name=resolved_private_name, + ) + + parsed_content = append_attachment_text(parsed_content_raw, private_attachments) + safe_parsed = redact_string(parsed_content) + logger.debug( + "[历史记录] 保存私聊: user=%s content=%s...", + private_sender_id, + safe_parsed[:50], + ) + await self.history_manager.add_private_message( + user_id=private_sender_id, + text_content=parsed_content, + display_name=private_sender_nickname, + user_name=user_name, + message_id=trigger_message_id, + attachments=private_attachments, + ) + + # 机器人自身消息只写历史,不触发后续自动回复/入库 + if private_sender_id == self.config.bot_qq: + return + + self._schedule_meme_ingest( + attachments=private_attachments, + chat_type="private", + chat_id=private_sender_id, + sender_id=private_sender_id, + message_id=safe_int(trigger_message_id), + scope_key=build_attachment_scope( + user_id=private_sender_id, + request_type="private", + ), + ) + + if not self.config.should_process_private_message(): + logger.debug( + "[消息策略] 已关闭私聊处理: user=%s", + private_sender_id, + ) + return + + # 多模型池控制指令优先于斜杠命令与 AI 回复 + if ( + getattr(self.config, "model_pool_enabled", False) + and _is_private_model_pool_control_text(text) + ) and await self.ai_coordinator.model_pool.handle_private_message( + private_sender_id, + text, + ): + return + + private_command = self.command_dispatcher.parse_command(text) + if private_command: + await self._flush_command_buffer( + scope=make_scope(user_id=private_sender_id), + sender_id=private_sender_id, + ) + await self.command_dispatcher.dispatch_private( + user_id=private_sender_id, + sender_id=private_sender_id, + command=private_command, + ) + return + + await self._run_pipelines( + target_id=private_sender_id, + target_type="private", + text=text, + message_content=private_message_content, + ) + + await self.ai_coordinator.handle_private_reply( + private_sender_id, + text, + private_message_content, + attachments=private_attachments, + sender_name=user_name, + trigger_message_id=trigger_message_id, + ) + + async def _handle_group_message(self, event: dict[str, Any]) -> None: + """处理群聊消息事件。""" + group_id: int = event.get("group_id", 0) + sender_id: int = get_message_sender_id(event) + message_content: list[dict[str, Any]] = get_message_content(event) + trigger_message_id = event.get("message_id") + + if not self.config.is_group_allowed(group_id): + group_reason = self.config.group_access_denied_reason(group_id) or "unknown" + logger.debug( + "[访问控制] 忽略群消息: group=%s sender=%s reason=%s (access enabled=%s)", + group_id, + sender_id, + group_reason, + self.config.access_control_enabled(), + ) + return + + group_sender: dict[str, Any] = event.get("sender", {}) + sender_card: str = group_sender.get("card", "") + sender_nickname: str = group_sender.get("nickname", "") + sender_role: str = group_sender.get("role", "member") + sender_title: str = group_sender.get("title", "") + sender_level: str = str(group_sender.get("level", "")).strip() + + text = handlers_module.extract_text(message_content, self.config.bot_qq) + safe_text = redact_string(text) + logger.info( + f"[群消息] group={group_id} sender={sender_id} name={sender_card or sender_nickname} " + f"role={sender_role} | {safe_text[:100]}" + ) + + async def _fetch_group_name() -> str: + try: + info = await self.onebot.get_group_info(group_id) + if info: + return str(info.get("group_name", "") or "") + except Exception as e: + logger.warning(f"获取群聊名失败: {e}") + return "" + + group_attachments, group_name, parsed_content_raw = await asyncio.gather( + self._collect_message_attachments( + message_content, + group_id=group_id, + request_type="group", + ), + _fetch_group_name(), + handlers_module.parse_message_content_for_history( + message_content, + self.config.bot_qq, + self.onebot.get_msg, + self.onebot.get_forward_msg, + ), + ) + + resolved_group_sender_name = (sender_card or sender_nickname or "").strip() + self._schedule_profile_display_name_refresh( + task_name=f"profile_name_refresh_group:{group_id}:{sender_id}", + sender_id=sender_id, + sender_name=resolved_group_sender_name, + group_id=group_id, + group_name=str(group_name or "").strip(), + ) + + parsed_content = append_attachment_text(parsed_content_raw, group_attachments) + safe_parsed = redact_string(parsed_content) + logger.debug( + f"[历史记录] 保存群聊: group={group_id}, sender={sender_id}, content={safe_parsed[:50]}..." + ) + await self.history_manager.add_group_message( + group_id=group_id, + sender_id=sender_id, + text_content=parsed_content, + sender_card=sender_card, + sender_nickname=sender_nickname, + group_name=group_name, + role=sender_role, + title=sender_title, + level=sender_level, + message_id=trigger_message_id, + attachments=group_attachments, + ) + + # 机器人发言计入复读计数,防止 bot 复读自身 + if sender_id == self.config.bot_qq: + await self._append_bot_repeat_counter(group_id, text) + return + + self._schedule_meme_ingest( + attachments=group_attachments, + chat_type="group", + chat_id=group_id, + sender_id=sender_id, + message_id=safe_int(trigger_message_id), + scope_key=build_attachment_scope(group_id=group_id, request_type="group"), + ) + + is_at_bot = self.ai_coordinator._is_at_bot(message_content) + + # 文本 @ 未命中 CQ at 段时,尝试识别「假 @」昵称 + is_fake_at = False + normalized_text = text + if not is_at_bot and ("@" in text or "@" in text): + nicknames = await self._bot_nickname_cache.get_nicknames(group_id) + if nicknames: + is_fake_at, normalized_text = strip_fake_at(text, nicknames) + if is_fake_at: + is_at_bot = True + logger.info( + "[假@] 识别到假@: group=%s sender=%s", + group_id, + sender_id, + ) + + if not self.config.should_process_group_message(is_at_bot=is_at_bot): + logger.debug( + "[消息策略] 跳过群消息处理: group=%s sender=%s process_every_message=%s at_bot=%s", + group_id, + sender_id, + self.config.process_every_message, + is_at_bot, + ) + return + + # 斜杠命令仅在 @bot 时生效;未 @ 时不拦截普通群聊 + if is_at_bot: + command = self.command_dispatcher.parse_command(normalized_text) + if command: + await self._flush_command_buffer( + scope=make_scope(group_id=group_id), + sender_id=sender_id, + ) + await self.command_dispatcher.dispatch(group_id, sender_id, command) + return + + if self.config.keyword_reply_enabled and handlers_module.matches_xinliweiyuan( + text + ): + if await self._handle_keyword_reply(group_id, sender_id): + return + + # 复读命中则跳过管线与 AI 自动回复 + if await self._maybe_trigger_repeat(group_id, sender_id, text): + return + + await self._run_pipelines( + target_id=group_id, + target_type="group", + text=text, + message_content=message_content, + ) + + display_name = sender_card or sender_nickname or str(sender_id) + await self.ai_coordinator.handle_auto_reply( + group_id, + sender_id, + normalized_text, + message_content, + attachments=group_attachments, + sender_name=display_name, + group_name=group_name, + sender_role=sender_role, + sender_title=sender_title, + sender_level=sender_level, + trigger_message_id=trigger_message_id, + is_fake_at=is_fake_at, + ) + + async def _handle_keyword_reply(self, group_id: int, sender_id: int) -> bool: + """处理心理委员关键词自动回复;若已发送回复则返回 True。""" + rand_val = random.random() + if rand_val < 0.01: + message = f"[@{sender_id}] 再发让你飞起来" + logger.info("关键词回复: 再发让你飞起来") + await self.sender.send_group_message( + group_id, + message, + history_prefix=KEYWORD_REPLY_HISTORY_PREFIX, + ) + return True + if rand_val < 0.11: + try: + image_path = resolve_resource_path("img/xlwy.jpg").resolve().as_uri() + except Exception: + image_path = Path(os.path.abspath("img/xlwy.jpg")).as_uri() + message = f"[CQ:image,file={image_path}]" + if random.random() < 0.5: + message = f"[@{sender_id}] {message}" + logger.info("关键词回复: 发送图片 xlwy.jpg") + else: + if random.random() < 0.7: + reply = "受着" + else: + reply = "那咋了" + if random.random() < 0.5: + message = f"[@{sender_id}] {reply}" + else: + message = reply + logger.info(f"关键词回复: {reply}") + await self.sender.send_group_message( + group_id, + message, + history_prefix=KEYWORD_REPLY_HISTORY_PREFIX, + ) + return True + + async def _flush_command_buffer(self, *, scope: str, sender_id: int) -> None: + batcher_config = getattr(self.config, "message_batcher", None) + if not getattr(batcher_config, "flush_on_command", False): + return + batcher = getattr(self, "message_batcher", None) + if batcher is None: + return + # 斜杠命令命中时强制 flush 未合并 buffer,避免命令与待发 batch 交错 + flushed = await batcher.flush_sender(scope, sender_id) + if not flushed: + logger.warning( + "[MessageBatcher] 命令触发 flush 当前 buffer 失败: scope=%s sender=%s", + scope, + sender_id, + ) + + async def _run_pipelines( + self, + *, + target_id: int, + target_type: Literal["group", "private"], + text: str, + message_content: list[dict[str, Any]], + ) -> bool: + """并行检测并处理所有命中的自动处理管线。""" + if not getattr(self, "_pipelines_initialized", False): + await self.init_pipelines() + context = build_pipeline_context( + self, + target_id=target_id, + target_type=target_type, + text=text, + message_content=message_content, + ) + detections = await self.pipeline_registry.run(context) + return bool(detections) + + async def apply_skills_hot_reload_config( + self, + *, + enabled: bool, + interval: float, + debounce: float, + ) -> None: + """跟随全局 skills 热重载配置更新管线。""" + if not enabled: + await self.pipeline_registry.stop_hot_reload() + logger.info("[pipelines] 热重载已随配置禁用") + return + + await self.pipeline_registry.stop_hot_reload() + self.pipeline_registry.start_hot_reload( + interval=interval, + debounce=debounce, + ) + + def _spawn_background_task( + self, + name: str, + coroutine: Coroutine[Any, Any, None], + ) -> None: + task = asyncio.create_task(coroutine, name=name) + self._background_tasks.add(task) + + def _finalize(done_task: asyncio.Task[None]) -> None: + self._background_tasks.discard(done_task) + try: + exc = done_task.exception() + except asyncio.CancelledError: + logger.debug("[后台任务] 已取消: %s", name) + return + if exc is not None: + logger.exception( + "[后台任务] 执行失败: name=%s", + name, + exc_info=(type(exc), exc, exc.__traceback__), + ) + + task.add_done_callback(_finalize) + + async def close(self) -> None: + """关闭消息处理器。""" + logger.info("正在关闭消息处理器...") + if self._background_tasks: + logger.info( + "[后台任务] 等待自动提取任务收敛: count=%s", + len(self._background_tasks), + ) + await asyncio.gather( + *list(self._background_tasks), + return_exceptions=True, + ) + await self.pipeline_registry.stop_hot_reload() + await self.message_batcher.flush_all() + # 关闭前排空 AI 队列并落盘历史,避免丢回复/丢记录 + await self.ai_coordinator.queue_manager.drain() + await self.ai_coordinator.queue_manager.stop() + await self.history_manager.flush_pending_saves() + logger.info("消息处理器已关闭") diff --git a/src/Undefined/handlers/poke.py b/src/Undefined/handlers/poke.py new file mode 100644 index 00000000..06ba2738 --- /dev/null +++ b/src/Undefined/handlers/poke.py @@ -0,0 +1,293 @@ +"""拍一拍(poke)通知处理 mixin。 + +负责私聊/群聊拍一拍历史写入与 AI 回复触发;由 ``MessageHandler`` 混入使用。 +""" + +from __future__ import annotations + +import logging +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any + +if TYPE_CHECKING: + from Undefined.config import Config + from Undefined.onebot import OneBotClient + from Undefined.services.ai_coordinator import AICoordinator + from Undefined.utils.history import MessageHistoryManager + +logger = logging.getLogger(__name__) + + +def _format_poke_history_text(display_name: str, user_id: int) -> str: + """格式化拍一拍历史文本。""" + return f"{display_name}(暱称)[{user_id}(QQ号)] 拍了拍你。" + + +@dataclass(frozen=True) +class PrivatePokeRecord: + """私聊拍一拍历史记录摘要。""" + + poke_text: str + sender_name: str + + +@dataclass(frozen=True) +class GroupPokeRecord: + """群聊拍一拍历史记录摘要。""" + + poke_text: str + sender_name: str + group_name: str + sender_role: str + sender_title: str + sender_level: str + + +class PokeMixin: + """拍一拍事件处理 mixin。""" + + if TYPE_CHECKING: + config: Config + onebot: OneBotClient + ai_coordinator: AICoordinator + history_manager: MessageHistoryManager + + def _schedule_profile_display_name_refresh( + self, + *, + task_name: str, + sender_id: int | None = None, + sender_name: str = "", + group_id: int | None = None, + group_name: str = "", + ) -> None: ... + + async def _handle_poke_notice(self, event: dict[str, Any]) -> None: + """处理拍一拍通知并触发对应私聊/群聊 AI 回复。""" + # 仅处理拍机器人自身的 poke + target_id = event.get("target_id", 0) + if target_id != self.config.bot_qq: + logger.debug( + "[通知] 忽略拍一拍目标非机器人: target=%s", + target_id, + ) + return + + if not self.config.should_process_poke_message(): + logger.debug("[消息策略] 已关闭拍一拍处理,忽略此次 poke 事件") + return + + poke_group_id: int = event.get("group_id", 0) + poke_sender_id: int = event.get("user_id", 0) + + if poke_group_id == 0: + # group_id=0 表示私聊拍一拍 + if not self.config.is_private_allowed(poke_sender_id): + private_reason = ( + self.config.private_access_denied_reason(poke_sender_id) + or "unknown" + ) + logger.debug( + "[访问控制] 忽略私聊拍一拍: user=%s reason=%s (access enabled=%s)", + poke_sender_id, + private_reason, + self.config.access_control_enabled(), + ) + return + else: + if not self.config.is_group_allowed(poke_group_id): + group_reason = ( + self.config.group_access_denied_reason(poke_group_id) or "unknown" + ) + logger.debug( + "[访问控制] 忽略群聊拍一拍: group=%s sender=%s reason=%s (access enabled=%s)", + poke_group_id, + poke_sender_id, + group_reason, + self.config.access_control_enabled(), + ) + return + + logger.info( + "[通知] 收到拍一拍: group=%s sender=%s", + poke_group_id, + poke_sender_id, + ) + logger.debug("[通知] 拍一拍事件数据: %s", str(event)[:200]) + + if poke_group_id == 0: + private_poke = await self._record_private_poke_history( + poke_sender_id, event + ) + logger.info("[通知] 私聊拍一拍,触发私聊回复") + # 拍一拍旁路 MessageBatcher,直接走 mention 级队列 + await self.ai_coordinator.handle_private_reply( + poke_sender_id, + private_poke.poke_text, + [], + is_poke=True, + sender_name=private_poke.sender_name, + ) + else: + group_poke = await self._record_group_poke_history( + poke_group_id, + poke_sender_id, + event, + ) + logger.info( + "[通知] 群聊拍一拍,触发群聊回复: group=%s", + poke_group_id, + ) + await self.ai_coordinator.handle_auto_reply( + poke_group_id, + poke_sender_id, + group_poke.poke_text, + [], + is_poke=True, + sender_name=group_poke.sender_name, + group_name=group_poke.group_name, + sender_role=group_poke.sender_role, + sender_title=group_poke.sender_title, + sender_level=group_poke.sender_level, + ) + + async def _record_private_poke_history( + self, user_id: int, event: dict[str, Any] + ) -> PrivatePokeRecord: + """记录私聊拍一拍到历史。""" + sender = event.get("sender", {}) + sender_nickname = "" + if isinstance(sender, dict): + sender_nickname = str(sender.get("nickname", "")).strip() + + user_name = sender_nickname + if not user_name: + try: + user_info = await self.onebot.get_stranger_info(user_id) + if isinstance(user_info, dict): + user_name = str(user_info.get("nickname", "")).strip() + except Exception as exc: + logger.warning( + "[通知] 获取私聊拍一拍用户昵称失败: user=%s err=%s", + user_id, + exc, + ) + + resolved_sender_name = (sender_nickname or user_name).strip() + display_name = resolved_sender_name or f"QQ{user_id}" + normalized_user_name = user_name or display_name + poke_text = _format_poke_history_text(display_name, user_id) + self._schedule_profile_display_name_refresh( + task_name=f"profile_name_refresh_private_poke:{user_id}", + sender_id=user_id, + sender_name=resolved_sender_name, + ) + + try: + await self.history_manager.add_private_message( + user_id=user_id, + text_content=poke_text, + display_name=display_name, + user_name=normalized_user_name, + ) + except Exception as exc: + logger.warning( + "[历史记录] 写入私聊拍一拍失败: user=%s err=%s", + user_id, + exc, + ) + return PrivatePokeRecord(poke_text=poke_text, sender_name=display_name) + + async def _record_group_poke_history( + self, + group_id: int, + sender_id: int, + event: dict[str, Any], + ) -> GroupPokeRecord: + """记录群聊拍一拍到历史。""" + sender = event.get("sender", {}) + sender_card = "" + sender_nickname = "" + sender_role = "member" + sender_title = "" + sender_level = "" + if isinstance(sender, dict): + sender_card = str(sender.get("card", "")).strip() + sender_nickname = str(sender.get("nickname", "")).strip() + sender_role = str(sender.get("role", "member")).strip() or "member" + sender_title = str(sender.get("title", "")).strip() + sender_level = str(sender.get("level", "")).strip() + + if not sender_card and not sender_nickname: + try: + member_info = await self.onebot.get_group_member_info( + group_id, sender_id + ) + if isinstance(member_info, dict): + sender_card = str(member_info.get("card", "")).strip() + sender_nickname = str(member_info.get("nickname", "")).strip() + sender_role = ( + str(member_info.get("role", "member")).strip() or "member" + ) + sender_title = str(member_info.get("title", "")).strip() + sender_level = str(member_info.get("level", "")).strip() + except Exception as exc: + logger.warning( + "[通知] 获取拍一拍群成员信息失败: group=%s user=%s err=%s", + group_id, + sender_id, + exc, + ) + + group_name = "" + try: + group_info = await self.onebot.get_group_info(group_id) + if isinstance(group_info, dict): + group_name = str(group_info.get("group_name", "")).strip() + except Exception as exc: + logger.warning( + "[通知] 获取拍一拍群名失败: group=%s err=%s", + group_id, + exc, + ) + + resolved_sender_name = (sender_card or sender_nickname).strip() + resolved_group_name = group_name.strip() + display_name = resolved_sender_name or f"QQ{sender_id}" + poke_text = _format_poke_history_text(display_name, sender_id) + normalized_group_name = resolved_group_name or f"群{group_id}" + self._schedule_profile_display_name_refresh( + task_name=f"profile_name_refresh_group_poke:{group_id}:{sender_id}", + sender_id=sender_id, + sender_name=resolved_sender_name, + group_id=group_id, + group_name=resolved_group_name, + ) + + try: + await self.history_manager.add_group_message( + group_id=group_id, + sender_id=sender_id, + text_content=poke_text, + sender_card=sender_card, + sender_nickname=sender_nickname, + group_name=normalized_group_name, + role=sender_role, + title=sender_title, + level=sender_level, + ) + except Exception as exc: + logger.warning( + "[历史记录] 写入群聊拍一拍失败: group=%s sender=%s err=%s", + group_id, + sender_id, + exc, + ) + return GroupPokeRecord( + poke_text=poke_text, + sender_name=display_name, + group_name=normalized_group_name, + sender_role=sender_role, + sender_title=sender_title, + sender_level=sender_level, + ) diff --git a/src/Undefined/handlers/repeat.py b/src/Undefined/handlers/repeat.py new file mode 100644 index 00000000..5307e792 --- /dev/null +++ b/src/Undefined/handlers/repeat.py @@ -0,0 +1,146 @@ +"""群聊复读功能 mixin。 + +按群跟踪连续相同消息并在阈值满足时复读;由 ``MessageHandler`` 混入使用。 +""" + +from __future__ import annotations + +import asyncio +import logging +import time +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from Undefined.config import Config + from Undefined.utils.sender import MessageSender + +from Undefined.utils.logging import redact_string + +logger = logging.getLogger(__name__) + +REPEAT_REPLY_HISTORY_PREFIX = "[系统复读] " + + +class RepeatMixin: + """群聊复读计数与触发 mixin。""" + + if TYPE_CHECKING: + config: Config + sender: MessageSender + _repeat_counter: dict[int, list[tuple[str, int]]] + _repeat_locks: dict[int, asyncio.Lock] + _repeat_cooldown: dict[int, dict[str, float]] + + def _get_repeat_lock(self, group_id: int) -> asyncio.Lock: + """获取或创建指定群的复读竞态保护锁。""" + lock = self._repeat_locks.get(group_id) + if lock is None: + lock = asyncio.Lock() + self._repeat_locks[group_id] = lock + return lock + + @staticmethod + def _normalize_repeat_text(text: str) -> str: + """规范化复读文本用于冷却比较(?→?)。""" + return text.replace("?", "?") + + def _is_repeat_on_cooldown(self, group_id: int, text: str) -> bool: + """检查指定群的文本是否在复读冷却期内。""" + cooldown_minutes = self.config.repeat_cooldown_minutes + if cooldown_minutes <= 0: + return False + group_cd = self._repeat_cooldown.get(group_id) + if not group_cd: + return False + key = self._normalize_repeat_text(text) + last_time = group_cd.get(key) + if last_time is None: + return False + return bool((time.monotonic() - last_time) < cooldown_minutes * 60) + + def _record_repeat_cooldown(self, group_id: int, text: str) -> None: + """记录复读冷却时间戳,同时清理已过期条目防止内存泄漏。""" + cooldown_seconds = self.config.repeat_cooldown_minutes * 60 + if cooldown_seconds <= 0: + return + key = self._normalize_repeat_text(text) + group_cd = self._repeat_cooldown.setdefault(group_id, {}) + now = time.monotonic() + expired = [k for k, ts in group_cd.items() if (now - ts) >= cooldown_seconds] + for k in expired: + del group_cd[k] + group_cd[key] = now + + async def _append_bot_repeat_counter(self, group_id: int, text: str) -> None: + """将 bot 自身发言写入复读计数器,防止误触复读。""" + if not self.config.repeat_enabled or not text: + return + async with self._get_repeat_lock(group_id): + counter = self._repeat_counter.setdefault(group_id, []) + counter.append((text, self.config.bot_qq)) + n = self.config.repeat_threshold + if len(counter) > n: + self._repeat_counter[group_id] = counter[-n:] + + async def _maybe_trigger_repeat( + self, + group_id: int, + sender_id: int, + text: str, + ) -> bool: + """尝试触发群聊复读;若已发送复读消息则返回 True。""" + if not self.config.repeat_enabled or not text: + return False + + n = self.config.repeat_threshold + async with self._get_repeat_lock(group_id): + counter = self._repeat_counter.setdefault(group_id, []) + counter.append((text, sender_id)) + if len(counter) > n: + self._repeat_counter[group_id] = counter[-n:] + counter = self._repeat_counter[group_id] + + if len(counter) < n: + return False + + last_n = counter[-n:] + texts = [t for t, _ in last_n] + senders = [s for _, s in last_n] + # 连续 n 条文本相同且来自 n 个不同发送者,且 bot 未参与 + if not ( + len(set(texts)) == 1 + and len(set(senders)) == n + and self.config.bot_qq not in senders + ): + return False + + reply_text = texts[0] + if self._is_repeat_on_cooldown(group_id, reply_text): + # 冷却期内清空计数,避免同一文本反复试探 + self._repeat_counter[group_id] = [] + logger.debug( + "[复读] 冷却中跳过: group=%s text=%s", + group_id, + redact_string(reply_text)[:50], + ) + return False + + if self.config.inverted_question_enabled: + stripped = reply_text.strip() + # 纯问号复读时翻转成 ¿ + if set(stripped) <= {"?", "?"}: + reply_text = "¿" * len(stripped) + + self._repeat_counter[group_id] = [] + self._record_repeat_cooldown(group_id, texts[0]) + logger.info( + "[复读] 触发复读: group=%s text=%s", + group_id, + redact_string(reply_text)[:50], + ) + await self.sender.send_group_message( + group_id, + reply_text, + history_prefix=REPEAT_REPLY_HISTORY_PREFIX, + ) + return True diff --git a/src/Undefined/memes/_service.py b/src/Undefined/memes/_service.py new file mode 100644 index 00000000..5dab0c77 --- /dev/null +++ b/src/Undefined/memes/_service.py @@ -0,0 +1,269 @@ +"""MemeService 门面类。""" + +from __future__ import annotations + +import asyncio +from dataclasses import dataclass +from datetime import datetime +import logging +import math +import mimetypes +from pathlib import Path +import re +import threading +from typing import Any + +from openai import APIConnectionError, APIStatusError, APITimeoutError +from PIL import Image + +from Undefined.attachments import AttachmentRecord +from Undefined.memes.models import ( + build_search_text, + normalize_string_list, +) +from Undefined.memes.store import MemeStore +from Undefined.memes.vector_store import MemeVectorStore +from Undefined.utils.paths import ensure_dir +from Undefined.memes.ingest import MemeIngestMixin +from Undefined.memes.search import MemeSearchMixin + +logger = logging.getLogger(__name__) + +_IMAGE_EXTENSIONS_BY_MIME = { + "image/png": ".png", + "image/jpeg": ".jpg", + "image/gif": ".gif", + "image/webp": ".webp", + "image/bmp": ".bmp", + "image/svg+xml": ".svg", +} +_TAG_SPLIT_RE = re.compile(r"[,,\n]+") + + +def _now_iso() -> str: + return datetime.now().isoformat(timespec="seconds") + + +def _guess_suffix(path: Path, mime_type: str) -> str: + suffix = path.suffix.lower() + if suffix: + return suffix + guessed = _IMAGE_EXTENSIONS_BY_MIME.get(mime_type) + if guessed: + return guessed + mime_guess = mimetypes.guess_extension(mime_type or "") + if mime_guess: + return mime_guess.lower() + return ".bin" + + +def _normalize_tags(raw_tags: list[str] | str | None) -> list[str]: + if raw_tags is None: + return [] + if isinstance(raw_tags, str): + parts = [part.strip() for part in _TAG_SPLIT_RE.split(raw_tags)] + return normalize_string_list(parts) + return normalize_string_list(raw_tags) + + +def _is_retryable_llm_error(exc: Exception) -> bool: + """判断 LLM 调用异常是否应触发 worker 级重试。""" + if isinstance(exc, (APIConnectionError, APITimeoutError)): + return True + if isinstance(exc, APIStatusError): + return exc.status_code == 429 or exc.status_code >= 500 + return False + + +def _extract_gif_frames(source_path: Path, n_frames: int) -> list[Image.Image]: + """从 GIF 中均匀采样 *n_frames* 帧(含首末帧),返回 RGBA Image 列表。""" + with Image.open(source_path) as image: + total = getattr(image, "n_frames", 1) + if total <= 1: + image.seek(0) + return [image.convert("RGBA").copy()] + n = min(n_frames, total) + if n <= 1: + image.seek(0) + return [image.convert("RGBA").copy()] + indices = _sample_frame_indices(total, n) + frames: list[Image.Image] = [] + for idx in indices: + image.seek(idx) + frames.append(image.convert("RGBA").copy()) + return frames + + +def _sample_frame_indices(total: int, n: int) -> list[int]: + """生成均匀采样的帧索引列表(始终包含首帧和末帧)。""" + if n >= total: + return list(range(total)) + if n == 1: + return [0] + if n == 2: + return [0, total - 1] + indices = [round(i * (total - 1) / (n - 1)) for i in range(n)] + # 去重并保持顺序 + seen: set[int] = set() + result: list[int] = [] + for idx in indices: + if idx not in seen: + seen.add(idx) + result.append(idx) + return result + + +def _compose_grid(frames: list[Image.Image], output_path: Path) -> None: + """将多帧拼接为网格图并保存为 PNG。""" + n = len(frames) + if n == 0: + return + if n == 1: + frames[0].save(output_path, format="PNG") + return + cols = math.ceil(math.sqrt(n)) + rows = math.ceil(n / cols) + fw, fh = frames[0].size + grid = Image.new("RGBA", (cols * fw, rows * fh), (0, 0, 0, 0)) + for i, frame in enumerate(frames): + resized = ( + frame.resize((fw, fh), Image.Resampling.LANCZOS) + if frame.size != (fw, fh) + else frame + ) + x = (i % cols) * fw + y = (i // cols) * fh + grid.paste(resized, (x, y)) + grid.save(output_path, format="PNG") + + +@dataclass +class _IngestDigestLockEntry: + lock: asyncio.Lock + users: int = 0 + + +class MemeService(MemeSearchMixin, MemeIngestMixin): + def __init__( + self, + *, + config_getter: Any, + store: MemeStore, + vector_store: MemeVectorStore, + job_queue: Any | None = None, + ai_client: Any | None = None, + attachment_registry: Any | None = None, + retrieval_runtime: Any | None = None, + ) -> None: + self._config_getter = config_getter + self._store = store + self._vector_store = vector_store + self._job_queue = job_queue + self._ai_client = ai_client + self._attachment_registry = attachment_registry + self._retrieval_runtime = retrieval_runtime + # 同内容 digest 锁:进程内串行入库,防止重复 AI 分析 + self._ingest_digest_locks: dict[str, _IngestDigestLockEntry] = {} + self._ingest_digest_locks_guard = asyncio.Lock() + self._global_image_cache: dict[str, AttachmentRecord] = {} + self._global_image_cache_lock = threading.Lock() + + def enabled(self) -> bool: + cfg = self._config_getter() + return bool(getattr(cfg, "enabled", False)) + + def default_query_mode(self) -> str: + mode = ( + str( + getattr(self._config_getter(), "query_default_mode", "hybrid") + or "hybrid" + ) + .strip() + .lower() + ) + return mode if mode in {"keyword", "semantic", "hybrid"} else "hybrid" + + def _cfg(self) -> Any: + return self._config_getter() + + def _blob_dir(self) -> Path: + return ensure_dir(Path(self._cfg().blob_dir)) + + def _preview_dir(self) -> Path: + return ensure_dir(Path(self._cfg().preview_dir)) + + def _queue_enabled(self) -> bool: + return self._job_queue is not None + + def _invalidate_global_image_cache(self, uid: str) -> None: + normalized_uid = str(uid or "").strip() + if not normalized_uid: + return + with self._global_image_cache_lock: + self._global_image_cache.pop(normalized_uid, None) + + async def update_meme( + self, + uid: str, + *, + manual_description: str | None = None, + tags: list[str] | str | None = None, + aliases: list[str] | str | None = None, + enabled: bool | None = None, + pinned: bool | None = None, + ) -> dict[str, Any] | None: + record = await self._store.get(uid) + if record is None: + return None + + next_tags = list(record.tags) if tags is None else _normalize_tags(tags) + next_aliases = ( + list(record.aliases) if aliases is None else _normalize_tags(aliases) + ) + next_manual = ( + record.manual_description + if manual_description is None + else str(manual_description or "").strip() + ) + next_enabled = record.enabled if enabled is None else bool(enabled) + next_pinned = record.pinned if pinned is None else bool(pinned) + next_search_text = build_search_text( + manual_description=next_manual, + auto_description=record.auto_description, + ocr_text="", + tags=next_tags, + aliases=next_aliases, + ) + + updated = await self._store.update_fields( + uid, + { + "manual_description": next_manual, + "tags_json": next_tags, + "aliases_json": next_aliases, + "enabled": next_enabled, + "pinned": next_pinned, + "search_text": next_search_text, + "updated_at": _now_iso(), + }, + ) + if updated is None: + return None + self._invalidate_global_image_cache(uid) + await self._vector_store.upsert(updated) + return self.serialize_record(updated) + + async def delete_meme(self, uid: str) -> bool: + record = await self._store.delete(uid) + if record is None: + return False + self._invalidate_global_image_cache(uid) + await self._vector_store.delete(uid) + await asyncio.to_thread(self._delete_file_if_exists, Path(record.blob_path)) + if record.preview_path and record.preview_path != record.blob_path: + await asyncio.to_thread( + self._delete_file_if_exists, + Path(record.preview_path), + ) + await asyncio.to_thread(self._cleanup_gif_frame_files, uid) + return True diff --git a/src/Undefined/memes/ingest.py b/src/Undefined/memes/ingest.py new file mode 100644 index 00000000..0341d4bb --- /dev/null +++ b/src/Undefined/memes/ingest.py @@ -0,0 +1,734 @@ +"""MemeService 入库与后台任务处理。""" + +from __future__ import annotations + +import asyncio +from collections.abc import Mapping +from dataclasses import dataclass, replace +from datetime import datetime +import hashlib +import logging +import math +import mimetypes +from pathlib import Path +import re +import shutil +from typing import TYPE_CHECKING, Any +from uuid import uuid4 + +from openai import APIConnectionError, APIStatusError, APITimeoutError +from PIL import Image + +from Undefined.memes.models import ( + MemeRecord, + MemeSourceRecord, + build_search_text, + normalize_string_list, +) + +if TYPE_CHECKING: + from Undefined.memes.store import MemeStore + from Undefined.memes.vector_store import MemeVectorStore + +logger = logging.getLogger(__name__) + +_IMAGE_EXTENSIONS_BY_MIME = { + "image/png": ".png", + "image/jpeg": ".jpg", + "image/gif": ".gif", + "image/webp": ".webp", + "image/bmp": ".bmp", + "image/svg+xml": ".svg", +} +_TAG_SPLIT_RE = re.compile(r"[,,\n]+") + + +def _now_iso() -> str: + return datetime.now().isoformat(timespec="seconds") + + +def _guess_suffix(path: Path, mime_type: str) -> str: + suffix = path.suffix.lower() + if suffix: + return suffix + guessed = _IMAGE_EXTENSIONS_BY_MIME.get(mime_type) + if guessed: + return guessed + mime_guess = mimetypes.guess_extension(mime_type or "") + if mime_guess: + return mime_guess.lower() + return ".bin" + + +def _normalize_tags(raw_tags: list[str] | str | None) -> list[str]: + if raw_tags is None: + return [] + if isinstance(raw_tags, str): + parts = [part.strip() for part in _TAG_SPLIT_RE.split(raw_tags)] + return normalize_string_list(parts) + return normalize_string_list(raw_tags) + + +def _is_retryable_llm_error(exc: Exception) -> bool: + """判断 LLM 调用异常是否应触发 worker 级重试。""" + if isinstance(exc, (APIConnectionError, APITimeoutError)): + return True + if isinstance(exc, APIStatusError): + return exc.status_code == 429 or exc.status_code >= 500 + return False + + +def _extract_gif_frames(source_path: Path, n_frames: int) -> list[Image.Image]: + """从 GIF 中均匀采样 *n_frames* 帧(含首末帧),返回 RGBA Image 列表。""" + with Image.open(source_path) as image: + total = getattr(image, "n_frames", 1) + if total <= 1: + image.seek(0) + return [image.convert("RGBA").copy()] + n = min(n_frames, total) + if n <= 1: + image.seek(0) + return [image.convert("RGBA").copy()] + indices = _sample_frame_indices(total, n) + frames: list[Image.Image] = [] + for idx in indices: + image.seek(idx) + frames.append(image.convert("RGBA").copy()) + return frames + + +def _sample_frame_indices(total: int, n: int) -> list[int]: + """生成均匀采样的帧索引列表(始终包含首帧和末帧)。""" + if n >= total: + return list(range(total)) + if n == 1: + return [0] + if n == 2: + return [0, total - 1] + indices = [round(i * (total - 1) / (n - 1)) for i in range(n)] + # 去重并保持顺序 + seen: set[int] = set() + result: list[int] = [] + for idx in indices: + if idx not in seen: + seen.add(idx) + result.append(idx) + return result + + +def _compose_grid(frames: list[Image.Image], output_path: Path) -> None: + """将多帧拼接为网格图并保存为 PNG。""" + n = len(frames) + if n == 0: + return + if n == 1: + frames[0].save(output_path, format="PNG") + return + cols = math.ceil(math.sqrt(n)) + rows = math.ceil(n / cols) + fw, fh = frames[0].size + grid = Image.new("RGBA", (cols * fw, rows * fh), (0, 0, 0, 0)) + for i, frame in enumerate(frames): + resized = ( + frame.resize((fw, fh), Image.Resampling.LANCZOS) + if frame.size != (fw, fh) + else frame + ) + x = (i % cols) * fw + y = (i // cols) * fh + grid.paste(resized, (x, y)) + grid.save(output_path, format="PNG") + + +@dataclass +class _IngestDigestLockEntry: + lock: asyncio.Lock + users: int = 0 + + +class MemeIngestMixin: + if TYPE_CHECKING: + _ai_client: Any | None + _attachment_registry: Any | None + _ingest_digest_locks: dict[str, Any] + _ingest_digest_locks_guard: asyncio.Lock + _job_queue: Any | None + _store: MemeStore + _vector_store: MemeVectorStore + + def _blob_dir(self) -> Path: ... + def _cfg(self) -> Any: ... + def _invalidate_global_image_cache(self, uid: str) -> None: ... + def _preview_dir(self) -> Path: ... + def _queue_enabled(self) -> bool: ... + async def delete_meme(self, uid: str) -> bool: ... + def enabled(self) -> bool: ... + + async def _acquire_ingest_digest_lock(self, digest: str) -> _IngestDigestLockEntry: + async with self._ingest_digest_locks_guard: + entry = self._ingest_digest_locks.get(digest) + if entry is None: + entry = _IngestDigestLockEntry(lock=asyncio.Lock()) + self._ingest_digest_locks[digest] = entry + entry.users += 1 + try: + await entry.lock.acquire() + except BaseException: + await self._release_ingest_digest_lock_reference(digest, entry) + raise + return entry + + async def _release_ingest_digest_lock_reference( + self, + digest: str, + entry: _IngestDigestLockEntry, + *, + release_lock: bool = False, + ) -> None: + if release_lock and entry.lock.locked(): + entry.lock.release() + async with self._ingest_digest_locks_guard: + entry.users = max(0, entry.users - 1) + current = self._ingest_digest_locks.get(digest) + if current is entry and entry.users == 0 and not entry.lock.locked(): + self._ingest_digest_locks.pop(digest, None) + + def _delete_file_if_exists(self, path: Path) -> None: + try: + path.unlink(missing_ok=True) + except OSError: + logger.debug("[memes] 删除文件失败: path=%s", path, exc_info=True) + + def _cleanup_gif_frame_files(self, uid: str) -> None: + """清理 GIF 多帧分析产生的临时帧文件 ({uid}_f{i}.png)。""" + preview_dir = self._preview_dir() + for frame_file in preview_dir.glob(f"{uid}_f*.png"): + try: + frame_file.unlink(missing_ok=True) + except OSError: + logger.debug( + "[memes] 删除帧文件失败: path=%s", frame_file, exc_info=True + ) + + async def _cleanup_meme_artifacts( + self, + *, + uid: str | None, + blob_path: Path, + preview_path: Path | None, + ) -> None: + if uid: + try: + await self._store.delete(uid) + self._invalidate_global_image_cache(uid) + except Exception: + logger.exception( + "[memes] 清理记录失败: uid=%s", + uid, + ) + try: + await self._vector_store.delete(uid) + except Exception: + logger.exception( + "[memes] 清理向量索引失败: uid=%s", + uid, + ) + await asyncio.to_thread(self._delete_file_if_exists, blob_path) + if preview_path is not None and preview_path != blob_path: + await asyncio.to_thread(self._delete_file_if_exists, preview_path) + if uid: + await asyncio.to_thread(self._cleanup_gif_frame_files, uid) + + async def enqueue_incoming_attachments( + self, + *, + attachments: list[dict[str, str]], + chat_type: str, + chat_id: int, + sender_id: int, + message_id: int | None, + scope_key: str, + ) -> None: + if not self.enabled() or not self._queue_enabled(): + return + cfg = self._cfg() + if chat_type == "group" and not bool(cfg.auto_ingest_group): + return + if chat_type == "private" and not bool(cfg.auto_ingest_private): + return + + for item in attachments: + media_type = str(item.get("media_type") or item.get("kind") or "").strip() + uid = str(item.get("uid") or "").strip() + if media_type != "image" or not uid: + continue + job = { + "request_id": f"meme_ingest_{uid}", + "kind": "ingest", + "attachment_uid": uid, + "scope_key": scope_key, + "chat_type": chat_type, + "chat_id": str(chat_id), + "sender_id": str(sender_id), + "message_id": str(message_id or ""), + "queued_at": _now_iso(), + } + queue = self._job_queue + if queue is None: + return + await queue.enqueue(job) + + async def enqueue_reanalyze(self, uid: str) -> str | None: + if not self._queue_enabled(): + return None + queue = self._job_queue + if queue is None: + return None + result = await queue.enqueue( + { + "request_id": f"meme_reanalyze_{uid}", + "kind": "reanalyze", + "uid": uid, + "queued_at": _now_iso(), + } + ) + return str(result) + + async def enqueue_reindex(self, uid: str) -> str | None: + if not self._queue_enabled(): + return None + queue = self._job_queue + if queue is None: + return None + result = await queue.enqueue( + { + "request_id": f"meme_reindex_{uid}", + "kind": "reindex", + "uid": uid, + "queued_at": _now_iso(), + } + ) + return str(result) + + async def process_job(self, job: Mapping[str, Any]) -> None: + kind = str(job.get("kind") or "").strip().lower() + if kind == "ingest": + await self._process_ingest_job(job) + return + if kind == "reanalyze": + await self._process_reanalyze_job(job) + return + if kind == "reindex": + await self._process_reindex_job(job) + return + raise ValueError(f"unsupported meme job kind: {kind}") + + async def _process_reindex_job(self, job: Mapping[str, Any]) -> None: + uid = str(job.get("uid") or "").strip() + if not uid: + return + record = await self._store.get(uid) + if record is None: + return + await self._vector_store.upsert(record) + + async def _process_reanalyze_job(self, job: Mapping[str, Any]) -> None: + uid = str(job.get("uid") or "").strip() + if not uid: + return + record = await self._store.get(uid) + if record is None: + return + if self._ai_client is None: + raise RuntimeError("reanalyze requires ai_client") + analyze_path: str | list[str] = ( + record.preview_path if record.preview_path else record.blob_path + ) + # GIF 多帧模式:与 ingest 路径保持一致 + if record.is_animated: + cfg = self._cfg() + if str(getattr(cfg, "gif_analysis_mode", "grid")).lower() == "multi": + analyze_path = await self._prepare_gif_multi_frames( + Path(record.blob_path), uid + ) + try: + judgement = await self._ai_client.judge_meme_image(analyze_path) + except Exception as exc: + if _is_retryable_llm_error(exc): + if isinstance(analyze_path, list): + await asyncio.to_thread(self._cleanup_gif_frame_files, uid) + raise + logger.exception( + "[memes] judge stage failed during reanalyze: uid=%s err=%s", uid, exc + ) + if isinstance(analyze_path, list): + await asyncio.to_thread(self._cleanup_gif_frame_files, uid) + return + if not bool(judgement.get("is_meme", False)): + if isinstance(analyze_path, list): + await asyncio.to_thread(self._cleanup_gif_frame_files, uid) + await self.delete_meme(uid) + return + try: + described = await self._ai_client.describe_meme_image(analyze_path) + except Exception as exc: + if _is_retryable_llm_error(exc): + if isinstance(analyze_path, list): + await asyncio.to_thread(self._cleanup_gif_frame_files, uid) + raise + logger.exception( + "[memes] describe stage failed during reanalyze: uid=%s err=%s", + uid, + exc, + ) + if isinstance(analyze_path, list): + await asyncio.to_thread(self._cleanup_gif_frame_files, uid) + return + # GIF 多帧文件用完即清理 + if isinstance(analyze_path, list): + await asyncio.to_thread(self._cleanup_gif_frame_files, uid) + auto_description = str(described.get("description") or "").strip() + next_tags = _normalize_tags(described.get("tags")) + if not auto_description and not next_tags: + logger.warning( + "[memes] reanalyze describe failed, skip update: uid=%s", uid + ) + return + next_record = replace( + record, + auto_description=auto_description, + ocr_text="", + tags=next_tags, + search_text=build_search_text( + manual_description=record.manual_description, + auto_description=auto_description, + ocr_text="", + tags=next_tags, + aliases=record.aliases, + ), + updated_at=_now_iso(), + ) + saved = await self._store.upsert_record(next_record) + self._invalidate_global_image_cache(saved.uid) + await self._vector_store.upsert(saved) + + async def _process_ingest_job(self, job: Mapping[str, Any]) -> None: + if self._attachment_registry is None: + raise RuntimeError("ingest requires attachment_registry") + if self._ai_client is None: + raise RuntimeError("ingest requires ai_client") + + attachment_uid = str(job.get("attachment_uid") or "").strip() + scope_key = str(job.get("scope_key") or "").strip() or None + if not attachment_uid: + return + attachment = self._attachment_registry.resolve(attachment_uid, scope_key) + if attachment is None: + raise FileNotFoundError(f"attachment uid unavailable: {attachment_uid}") + if str(attachment.media_type).lower() != "image": + return + source_path = Path(str(attachment.local_path or "")) + if not source_path.is_file(): + raise FileNotFoundError(source_path) + file_size = source_path.stat().st_size + cfg = self._cfg() + if file_size > int(cfg.max_source_image_bytes): + logger.info( + "[memes] skip oversized image: uid=%s size=%s limit=%s", + attachment_uid, + file_size, + cfg.max_source_image_bytes, + ) + return + + digest = await asyncio.to_thread(self._hash_file, source_path) + # 同一 SHA256 并发入库串行化,避免重复 AI 判定 + digest_lock_entry = await self._acquire_ingest_digest_lock(digest) + try: + existing = await self._store.find_by_sha256(digest) + if existing is not None and not Path(existing.blob_path).is_file(): + logger.warning( + "[memes] 检测到孤儿记录,删除后重新入库: uid=%s blob_path=%s", + existing.uid, + existing.blob_path, + ) + await self._cleanup_meme_artifacts( + uid=existing.uid, + blob_path=Path(existing.blob_path), + preview_path=( + Path(existing.preview_path) if existing.preview_path else None + ), + ) + existing = await self._store.find_by_sha256(digest) + if existing is not None and not Path(existing.blob_path).is_file(): + raise RuntimeError( + f"stale meme record cleanup failed: uid={existing.uid}" + ) + source = MemeSourceRecord( + uid=existing.uid if existing is not None else "", + source_type="message_attachment", + chat_type=str(job.get("chat_type") or ""), + chat_id=str(job.get("chat_id") or ""), + sender_id=str(job.get("sender_id") or ""), + message_id=str(job.get("message_id") or ""), + attachment_uid=attachment_uid, + source_url=str(attachment.source_ref or ""), + seen_at=_now_iso(), + ) + if existing is not None: + # 内容已存在:仅追加来源记录并刷新向量索引 + await self._store.add_source(replace(source, uid=existing.uid)) + await self._vector_store.upsert(existing) + return + + with Image.open(source_path) as image: + width, height = image.size + is_animated = bool(getattr(image, "is_animated", False)) + if is_animated and not bool(cfg.allow_gif): + return + + uid = await self._generate_uid() + suffix = _guess_suffix(source_path, str(attachment.mime_type or "")) + blob_path = self._blob_dir() / f"{uid}{suffix}" + cleanup_preview_path = ( + self._preview_dir() / f"{uid}.png" if is_animated else blob_path + ) + persisted_uid: str | None = None + + try: + preview_path = await self._prepare_blob_and_preview( + source_path=source_path, + target_uid=uid, + suffix=suffix, + is_animated=is_animated, + ) + if preview_path is not None: + cleanup_preview_path = preview_path + mime_type = str( + attachment.mime_type + or mimetypes.guess_type(source_path.name)[0] + or "application/octet-stream" + ) + analyze_path: str | list[str] = str( + preview_path if preview_path is not None else blob_path + ) + if ( + is_animated + and str(getattr(cfg, "gif_analysis_mode", "grid")).lower() + == "multi" + ): + analyze_path = await self._prepare_gif_multi_frames( + source_path, uid + ) + try: + judgement = await self._ai_client.judge_meme_image(analyze_path) + except Exception as exc: + if _is_retryable_llm_error(exc): + if isinstance(analyze_path, list): + await asyncio.to_thread(self._cleanup_gif_frame_files, uid) + raise + logger.exception( + "[memes] judge stage failed, treat as non-meme: uid=%s err=%s", + uid, + exc, + ) + judgement = {"is_meme": False} + if not bool(judgement.get("is_meme", False)): + if isinstance(analyze_path, list): + await asyncio.to_thread(self._cleanup_gif_frame_files, uid) + # 非表情包:清理已落盘文件,不入库 + await self._cleanup_meme_artifacts( + uid=None, + blob_path=blob_path, + preview_path=cleanup_preview_path, + ) + return + + try: + described = await self._ai_client.describe_meme_image(analyze_path) + except Exception as exc: + if _is_retryable_llm_error(exc): + if isinstance(analyze_path, list): + await asyncio.to_thread(self._cleanup_gif_frame_files, uid) + raise + logger.exception( + "[memes] describe stage failed, drop uid=%s err=%s", uid, exc + ) + described = {"description": "", "tags": []} + # GIF 多帧文件用完即清理 + if isinstance(analyze_path, list): + await asyncio.to_thread(self._cleanup_gif_frame_files, uid) + tags = _normalize_tags(described.get("tags")) + auto_description = str(described.get("description") or "").strip() + if not auto_description and not tags: + logger.warning( + "[memes] describe stage returned empty result, drop uid=%s", uid + ) + await self._cleanup_meme_artifacts( + uid=None, + blob_path=blob_path, + preview_path=cleanup_preview_path, + ) + return + now = _now_iso() + record = MemeRecord( + uid=uid, + content_sha256=digest, + blob_path=str(blob_path), + preview_path=( + str(preview_path) if preview_path is not None else None + ), + mime_type=mime_type, + file_size=file_size, + width=width, + height=height, + is_animated=is_animated, + enabled=True, + pinned=False, + auto_description=auto_description, + manual_description="", + ocr_text="", + tags=tags, + aliases=[], + search_text=build_search_text( + manual_description="", + auto_description=auto_description, + ocr_text="", + tags=tags, + aliases=[], + ), + use_count=0, + last_used_at="", + created_at=now, + updated_at=now, + status="ready", + segment_data={"subType": "1"}, + ) + saved = await self._store.upsert_record(record) + self._invalidate_global_image_cache(saved.uid) + persisted_uid = saved.uid + await self._store.add_source(replace(source, uid=saved.uid)) + await self._vector_store.upsert(saved) + except Exception: + await self._cleanup_meme_artifacts( + uid=persisted_uid, + blob_path=blob_path, + preview_path=cleanup_preview_path, + ) + raise + finally: + await self._release_ingest_digest_lock_reference( + digest, + digest_lock_entry, + release_lock=True, + ) + await self._prune_if_needed() + + async def _prepare_blob_and_preview( + self, + *, + source_path: Path, + target_uid: str, + suffix: str, + is_animated: bool, + ) -> Path | None: + blob_path = self._blob_dir() / f"{target_uid}{suffix}" + + def _copy() -> None: + shutil.copy2(source_path, blob_path) + + await asyncio.to_thread(_copy) + if not is_animated: + return blob_path + + cfg = self._cfg() + mode = str(getattr(cfg, "gif_analysis_mode", "grid")).lower() + n_frames = max(2, int(getattr(cfg, "gif_analysis_frames", 6))) + preview_path = self._preview_dir() / f"{target_uid}.png" + + def _render_preview() -> None: + frames = _extract_gif_frames(source_path, n_frames) + if mode == "multi": + frames[0].save(preview_path, format="PNG") + else: + _compose_grid(frames, preview_path) + for f in frames: + f.close() + + await asyncio.to_thread(_render_preview) + return preview_path + + async def _prepare_gif_multi_frames( + self, source_path: Path, target_uid: str + ) -> list[str]: + """multi 模式:将 GIF 各帧单独保存为 PNG,返回路径列表。""" + cfg = self._cfg() + n_frames = max(2, int(getattr(cfg, "gif_analysis_frames", 6))) + preview_dir = self._preview_dir() + + def _render_frames() -> list[str]: + frames = _extract_gif_frames(source_path, n_frames) + paths: list[str] = [] + for i, frame in enumerate(frames): + p = preview_dir / f"{target_uid}_f{i}.png" + frame.save(p, format="PNG") + frame.close() + paths.append(str(p)) + return paths + + return await asyncio.to_thread(_render_frames) + + def _hash_file(self, path: Path) -> str: + hasher = hashlib.sha256() + with path.open("rb") as handle: + while True: + chunk = handle.read(1024 * 1024) + if not chunk: + break + hasher.update(chunk) + return hasher.hexdigest() + + async def _generate_uid(self) -> str: + while True: + candidate = f"pic_{uuid4().hex[:8]}" + if await self._store.get(candidate) is not None: + continue + if ( + self._attachment_registry is not None + and self._attachment_registry.get(candidate) is not None + ): + continue + return candidate + + async def _prune_if_needed(self) -> None: + stats = await self._store.stats() + cfg = self._cfg() + total_count = int(stats.get("total_count", 0)) + total_bytes = int(stats.get("total_bytes", 0)) + if total_count <= int(cfg.max_items) and total_bytes <= int( + cfg.max_total_bytes + ): + return + candidates = await self._store.list_prune_candidates() + for candidate in candidates: + if candidate.pinned: + continue + if total_count <= int(cfg.max_items) and total_bytes <= int( + cfg.max_total_bytes + ): + break + deleted = await self._store.delete(candidate.uid) + if deleted is None: + continue + self._invalidate_global_image_cache(candidate.uid) + await self._vector_store.delete(candidate.uid) + await asyncio.to_thread( + self._delete_file_if_exists, Path(deleted.blob_path) + ) + if deleted.preview_path and deleted.preview_path != deleted.blob_path: + await asyncio.to_thread( + self._delete_file_if_exists, + Path(deleted.preview_path), + ) + total_count -= 1 + total_bytes -= int(deleted.file_size) diff --git a/src/Undefined/memes/search.py b/src/Undefined/memes/search.py new file mode 100644 index 00000000..ba9d5de1 --- /dev/null +++ b/src/Undefined/memes/search.py @@ -0,0 +1,591 @@ +"""MemeService 检索与列表操作。""" + +from __future__ import annotations + +import asyncio +from dataclasses import dataclass +from datetime import datetime +import logging +import math +import mimetypes +from pathlib import Path +import re +from typing import TYPE_CHECKING, Any + +from openai import APIConnectionError, APIStatusError, APITimeoutError +from PIL import Image + +from Undefined.attachments import AttachmentRecord +from Undefined.memes.models import ( + MemeRecord, + MemeSearchItem, + normalize_string_list, +) +from Undefined.utils.message_targets import resolve_message_target +from Undefined.utils.coerce import safe_int + +if TYPE_CHECKING: + import threading + + from Undefined.memes.store import MemeStore + from Undefined.memes.vector_store import MemeVectorStore + +logger = logging.getLogger(__name__) + +_IMAGE_EXTENSIONS_BY_MIME = { + "image/png": ".png", + "image/jpeg": ".jpg", + "image/gif": ".gif", + "image/webp": ".webp", + "image/bmp": ".bmp", + "image/svg+xml": ".svg", +} +_TAG_SPLIT_RE = re.compile(r"[,,\n]+") + + +def _now_iso() -> str: + return datetime.now().isoformat(timespec="seconds") + + +def _guess_suffix(path: Path, mime_type: str) -> str: + suffix = path.suffix.lower() + if suffix: + return suffix + guessed = _IMAGE_EXTENSIONS_BY_MIME.get(mime_type) + if guessed: + return guessed + mime_guess = mimetypes.guess_extension(mime_type or "") + if mime_guess: + return mime_guess.lower() + return ".bin" + + +def _normalize_tags(raw_tags: list[str] | str | None) -> list[str]: + if raw_tags is None: + return [] + if isinstance(raw_tags, str): + parts = [part.strip() for part in _TAG_SPLIT_RE.split(raw_tags)] + return normalize_string_list(parts) + return normalize_string_list(raw_tags) + + +def _is_retryable_llm_error(exc: Exception) -> bool: + """判断 LLM 调用异常是否应触发 worker 级重试。""" + if isinstance(exc, (APIConnectionError, APITimeoutError)): + return True + if isinstance(exc, APIStatusError): + return exc.status_code == 429 or exc.status_code >= 500 + return False + + +def _extract_gif_frames(source_path: Path, n_frames: int) -> list[Image.Image]: + """从 GIF 中均匀采样 *n_frames* 帧(含首末帧),返回 RGBA Image 列表。""" + with Image.open(source_path) as image: + total = getattr(image, "n_frames", 1) + if total <= 1: + image.seek(0) + return [image.convert("RGBA").copy()] + n = min(n_frames, total) + if n <= 1: + image.seek(0) + return [image.convert("RGBA").copy()] + indices = _sample_frame_indices(total, n) + frames: list[Image.Image] = [] + for idx in indices: + image.seek(idx) + frames.append(image.convert("RGBA").copy()) + return frames + + +def _sample_frame_indices(total: int, n: int) -> list[int]: + """生成均匀采样的帧索引列表(始终包含首帧和末帧)。""" + if n >= total: + return list(range(total)) + if n == 1: + return [0] + if n == 2: + return [0, total - 1] + indices = [round(i * (total - 1) / (n - 1)) for i in range(n)] + # 去重并保持顺序 + seen: set[int] = set() + result: list[int] = [] + for idx in indices: + if idx not in seen: + seen.add(idx) + result.append(idx) + return result + + +def _compose_grid(frames: list[Image.Image], output_path: Path) -> None: + """将多帧拼接为网格图并保存为 PNG。""" + n = len(frames) + if n == 0: + return + if n == 1: + frames[0].save(output_path, format="PNG") + return + cols = math.ceil(math.sqrt(n)) + rows = math.ceil(n / cols) + fw, fh = frames[0].size + grid = Image.new("RGBA", (cols * fw, rows * fh), (0, 0, 0, 0)) + for i, frame in enumerate(frames): + resized = ( + frame.resize((fw, fh), Image.Resampling.LANCZOS) + if frame.size != (fw, fh) + else frame + ) + x = (i % cols) * fw + y = (i // cols) * fh + grid.paste(resized, (x, y)) + grid.save(output_path, format="PNG") + + +@dataclass +class _IngestDigestLockEntry: + lock: asyncio.Lock + users: int = 0 + + +class MemeSearchMixin: + if TYPE_CHECKING: + _cfg: Any + _global_image_cache: dict[str, AttachmentRecord] + _global_image_cache_lock: threading.Lock + _job_queue: Any | None + _retrieval_runtime: Any | None + _store: MemeStore + _vector_store: MemeVectorStore + + def resolve_global_image_sync(self, uid: str) -> AttachmentRecord | None: + normalized_uid = str(uid or "").strip() + if not normalized_uid: + return None + + with self._global_image_cache_lock: + cached = self._global_image_cache.get(normalized_uid) + if cached is not None: + return cached + + record = self._store.get_sync(normalized_uid) + if record is None or not record.enabled or record.status != "ready": + return None + # scope_key 留空:表情库 UID 全局可解析,不受会话 scope 限制 + attachment = AttachmentRecord( + uid=record.uid, + scope_key="", + kind="image", + media_type="image", + display_name=Path(record.blob_path).name, + source_kind="meme_library", + source_ref=Path(record.blob_path).resolve().as_uri(), + local_path=record.blob_path, + mime_type=record.mime_type, + sha256=record.content_sha256, + created_at=record.created_at, + segment_data={"subType": "1"}, + semantic_kind="meme", + description=record.description, + ) + with self._global_image_cache_lock: + self._global_image_cache[normalized_uid] = attachment + return attachment + + async def resolve_global_image(self, uid: str) -> AttachmentRecord | None: + return await asyncio.to_thread(self.resolve_global_image_sync, uid) + + async def get_meme(self, uid: str) -> dict[str, Any] | None: + record = await self._store.get(uid) + if record is None: + return None + sources = await self._store.get_sources(uid) + return { + "record": self.serialize_record(record), + "sources": [source.__dict__ for source in sources], + } + + async def get_record(self, uid: str) -> MemeRecord | None: + return await self._store.get(uid) + + def serialize_record(self, record: MemeRecord) -> dict[str, Any]: + preview_path = record.preview_path or record.blob_path + return { + "uid": record.uid, + "description": record.description, + "auto_description": record.auto_description, + "manual_description": record.manual_description, + "ocr_text": record.ocr_text, + "tags": list(record.tags), + "aliases": list(record.aliases), + "enabled": bool(record.enabled), + "pinned": bool(record.pinned), + "is_animated": bool(record.is_animated), + "mime_type": record.mime_type, + "file_size": record.file_size, + "width": record.width, + "height": record.height, + "blob_url": f"/api/v1/management/memes/{record.uid}/blob", + "preview_url": f"/api/v1/management/memes/{record.uid}/preview", + "use_count": record.use_count, + "last_used_at": record.last_used_at, + "created_at": record.created_at, + "updated_at": record.updated_at, + "status": record.status, + "search_text": record.search_text, + "preview_path": preview_path, + } + + def serialize_list_item(self, record: MemeRecord) -> dict[str, Any]: + return { + "uid": record.uid, + "description": record.description, + "enabled": bool(record.enabled), + "pinned": bool(record.pinned), + "is_animated": bool(record.is_animated), + "use_count": int(record.use_count), + "created_at": record.created_at, + "updated_at": record.updated_at, + "status": record.status, + } + + async def list_memes( + self, + *, + query: str = "", + enabled: bool | None = None, + animated: bool | None = None, + pinned: bool | None = None, + sort: str = "updated_at", + page: int = 1, + page_size: int = 50, + summary: bool = False, + ) -> dict[str, Any]: + items, total = await self._store.list_memes( + query=query, + enabled=enabled, + animated=animated, + pinned=pinned, + sort=sort, + page=page, + page_size=page_size, + ) + return { + "ok": True, + "total": total, + "window_total": total, + "total_exact": True, + "page": max(1, int(page)), + "page_size": max(1, min(200, int(page_size))), + "has_more": max(1, int(page)) * max(1, min(200, int(page_size))) < total, + "sort": str(sort or "updated_at"), + "items": [ + self.serialize_list_item(item) + if summary + else self.serialize_record(item) + for item in items + ], + } + + async def stats(self) -> dict[str, Any]: + stats = await self._store.stats() + cfg = self._cfg() + stats["max_items"] = int(cfg.max_items) + stats["max_total_bytes"] = int(cfg.max_total_bytes) + queue = self._job_queue.snapshot() if self._job_queue is not None else None + stats["queue"] = queue + return stats + + async def search_memes( + self, + query: str, + *, + query_mode: str = "hybrid", + keyword_query: str | None = None, + semantic_query: str | None = None, + top_k: int = 8, + include_disabled: bool = False, + sort: str = "relevance", + ) -> dict[str, Any]: + raw_query = str(query or "").strip() + raw_keyword_query = str(keyword_query or "").strip() + raw_semantic_query = str(semantic_query or "").strip() + normalized_mode = str(query_mode or "hybrid").strip().lower() + if normalized_mode not in {"keyword", "semantic", "hybrid"}: + normalized_mode = "hybrid" + resolved_keyword_query = raw_keyword_query or raw_query + resolved_semantic_query = raw_semantic_query or raw_query + if normalized_mode == "keyword" and not resolved_keyword_query: + return { + "ok": True, + "count": 0, + "query_mode": normalized_mode, + "keyword_query": resolved_keyword_query, + "semantic_query": resolved_semantic_query, + "items": [], + } + if normalized_mode == "semantic" and not resolved_semantic_query: + return { + "ok": True, + "count": 0, + "query_mode": normalized_mode, + "keyword_query": resolved_keyword_query, + "semantic_query": resolved_semantic_query, + "items": [], + } + if ( + normalized_mode == "hybrid" + and not resolved_keyword_query + and not resolved_semantic_query + ): + return { + "ok": True, + "count": 0, + "query_mode": normalized_mode, + "keyword_query": resolved_keyword_query, + "semantic_query": resolved_semantic_query, + "items": [], + } + + cfg = self._cfg() + keyword_hits: list[dict[str, Any]] = [] + if normalized_mode in {"keyword", "hybrid"} and resolved_keyword_query: + keyword_hits = await self._store.search_keyword( + resolved_keyword_query, + limit=max(int(cfg.keyword_top_k), int(top_k)), + include_disabled=include_disabled, + ) + semantic_hits: list[dict[str, Any]] = [] + if normalized_mode in {"semantic", "hybrid"} and resolved_semantic_query: + semantic_hits = await self._vector_store.query( + resolved_semantic_query, + top_k=max(int(cfg.semantic_top_k), int(top_k)), + include_disabled=include_disabled, + ) + merged: dict[str, dict[str, Any]] = {} + + for item in keyword_hits: + record: MemeRecord = item["record"] + merged[record.uid] = { + "record": record, + "keyword_score": float(item.get("keyword_score", 0.0)), + "semantic_score": 0.0, + "rerank_score": None, + } + + missing_semantic_uids = [ + str(item.get("uid") or "").strip() + for item in semantic_hits + if str(item.get("uid") or "").strip() + and str(item.get("uid") or "").strip() not in merged + ] + missing_records = ( + await self._store.get_many(missing_semantic_uids) + if missing_semantic_uids + else {} + ) + + for item in semantic_hits: + uid = str(item.get("uid") or "").strip() + if not uid: + continue + existing = merged.get(uid) + if existing is None: + stored_record = missing_records.get(uid) + if stored_record is None: + continue + existing = { + "record": stored_record, + "keyword_score": 0.0, + "semantic_score": 0.0, + "rerank_score": None, + } + merged[uid] = existing + existing["semantic_score"] = max( + float(existing.get("semantic_score", 0.0)), + float(item.get("semantic_score", 0.0)), + ) + + reranker = ( + self._retrieval_runtime.ensure_reranker() + if self._retrieval_runtime is not None + else None + ) + ranked_candidates = list(merged.values()) + rerank_query = resolved_semantic_query or resolved_keyword_query + if ( + reranker is not None + and normalized_mode in {"semantic", "hybrid"} + and rerank_query + and ranked_candidates + ): + documents = [ + candidate["record"].search_text for candidate in ranked_candidates + ] + reranked = await reranker.rerank( + rerank_query, + documents, + top_n=min(len(documents), int(cfg.rerank_top_k)), + ) + for item in reranked: + try: + index = int(item.get("index")) + except (TypeError, ValueError): + continue + if index < 0 or index >= len(ranked_candidates): + continue + ranked_candidates[index]["rerank_score"] = float( + item.get("relevance_score", 0.0) or 0.0 + ) + + def _final_score(item: dict[str, Any]) -> float: + rerank_score = item.get("rerank_score") + if rerank_score is not None: + return float(rerank_score) + # hybrid 模式:keyword 与 semantic 取较高分 + return max( + float(item.get("keyword_score", 0.0)), + float(item.get("semantic_score", 0.0)), + ) + + normalized_sort = str(sort or "relevance").strip().lower() + if normalized_sort == "use_count": + ranked_candidates.sort( + key=lambda item: ( + item["record"].pinned, + item["record"].use_count, + item["record"].updated_at, + _final_score(item), + ), + reverse=True, + ) + elif normalized_sort == "created_at": + ranked_candidates.sort( + key=lambda item: ( + item["record"].pinned, + item["record"].created_at, + item["record"].updated_at, + _final_score(item), + ), + reverse=True, + ) + elif normalized_sort == "updated_at": + ranked_candidates.sort( + key=lambda item: ( + item["record"].pinned, + item["record"].updated_at, + item["record"].use_count, + _final_score(item), + ), + reverse=True, + ) + else: + ranked_candidates.sort( + key=lambda item: ( + _final_score(item), + item["record"].pinned, + item["record"].use_count, + item["record"].updated_at, + ), + reverse=True, + ) + + items: list[dict[str, Any]] = [] + for candidate in ranked_candidates[: max(1, int(top_k))]: + record = candidate["record"] + search_item = MemeSearchItem( + uid=record.uid, + description=record.description, + tags=list(record.tags), + aliases=list(record.aliases), + enabled=bool(record.enabled), + pinned=bool(record.pinned), + is_animated=record.is_animated, + created_at=record.created_at, + updated_at=record.updated_at, + score=round(_final_score(candidate), 6), + keyword_score=round(float(candidate.get("keyword_score", 0.0)), 6), + semantic_score=round(float(candidate.get("semantic_score", 0.0)), 6), + rerank_score=( + round(float(candidate["rerank_score"]), 6) + if candidate.get("rerank_score") is not None + else None + ), + use_count=record.use_count, + ) + items.append(search_item.__dict__) + return { + "ok": True, + "count": len(items), + "query_mode": normalized_mode, + "keyword_query": resolved_keyword_query, + "semantic_query": resolved_semantic_query, + "sort": normalized_sort, + "items": items, + } + + async def send_meme_by_uid(self, uid: str, context: dict[str, Any]) -> str: + record = await self._store.get(uid) + if record is None or not record.enabled or record.status != "ready": + return f"发送失败:未找到可用表情包 UID:{uid}" + + sender = context.get("sender") + if sender is None: + return "发送失败:当前上下文缺少 sender" + + tool_args = { + "target_type": context.get("target_type"), + "target_id": context.get("target_id"), + } + target, target_error = resolve_message_target(tool_args, context) + if target_error or target is None: + return f"发送失败:{target_error or '无法确定目标会话'}" + target_type, target_id = target + + local_path = Path(record.blob_path) + if not local_path.is_file(): + return f"发送失败:表情包文件不存在:{uid}" + file_uri = local_path.resolve().as_uri() + cq_message = f"[CQ:image,file={file_uri},subType=1]" + history_message = f"[图片 uid={record.uid} name={local_path.name}]" + history_attachment = await self.resolve_global_image(uid) + history_attachments = ( + [history_attachment.prompt_ref()] + if history_attachment is not None + else None + ) + + if target_type == "group": + sent_message_id = await sender.send_group_message( + int(target_id), + cq_message, + history_message=history_message, + attachments=history_attachments, + ) + else: + preferred_temp_group_id = safe_int(context.get("group_id")) or None + sent_message_id = await sender.send_private_message( + int(target_id), + cq_message, + preferred_temp_group_id=preferred_temp_group_id, + history_message=history_message, + attachments=history_attachments, + ) + + now = _now_iso() + updated_record = await self._store.increment_use(uid, now) + if updated_record is not None: + await self._vector_store.upsert(updated_record) + if sent_message_id is not None: + return f"表情包已发送(message_id={sent_message_id})" + return "表情包已发送" + + async def blob_path_for_uid( + self, uid: str, *, preview: bool = False + ) -> Path | None: + record = await self._store.get(uid) + if record is None: + return None + path_text = ( + record.preview_path if preview and record.preview_path else record.blob_path + ) + path = Path(path_text) + return path if path.is_file() else None diff --git a/src/Undefined/memes/service.py b/src/Undefined/memes/service.py index 96f80abe..8612cef6 100644 --- a/src/Undefined/memes/service.py +++ b/src/Undefined/memes/service.py @@ -1177,7 +1177,6 @@ def _copy() -> None: def _render_preview() -> None: frames = _extract_gif_frames(source_path, n_frames) if mode == "multi": - # multi 模式也需要生成一张预览用于存储/展示,取首帧 frames[0].save(preview_path, format="PNG") else: _compose_grid(frames, preview_path) diff --git a/src/Undefined/onebot/__init__.py b/src/Undefined/onebot/__init__.py new file mode 100644 index 00000000..14c5690a --- /dev/null +++ b/src/Undefined/onebot/__init__.py @@ -0,0 +1,16 @@ +"""OneBot WebSocket 客户端包。""" + +# 统一 re-export,供 handlers 与 sender 直接 from Undefined.onebot import ... +from Undefined.onebot.client import OneBotClient +from Undefined.onebot.message import ( + get_message_content, + get_message_sender_id, + parse_message_time, +) + +__all__ = [ + "OneBotClient", + "parse_message_time", + "get_message_sender_id", + "get_message_content", +] diff --git a/src/Undefined/onebot/client.py b/src/Undefined/onebot/client.py new file mode 100644 index 00000000..ce293428 --- /dev/null +++ b/src/Undefined/onebot/client.py @@ -0,0 +1,873 @@ +"""OneBot v11 WebSocket 客户端实现。""" + +import asyncio +import json +import logging +import time +from typing import Any, Callable, Coroutine + +import websockets +from websockets.asyncio.client import ClientConnection + +from Undefined.context import RequestContext +from Undefined.utils.logging import log_debug_json, redact_string, sanitize_data + +logger = logging.getLogger(__name__) + + +def _mark_message_sent_this_turn() -> None: + ctx = RequestContext.current() + if ctx is None: + return + # 标记本轮已向用户发出消息,供 end 工具判断是否可静默结束。 + ctx.set_resource("message_sent_this_turn", True) + + +# OneBot v11 WebSocket 客户端 +class OneBotClient: + """OneBot v11 WebSocket 客户端""" + + def __init__(self, ws_url: str, token: str = ""): + self.ws_url = ws_url + self.token = token + self.ws: ClientConnection | None = None + self._message_id = 0 + self._pending_responses: dict[str, asyncio.Future[dict[str, Any]]] = {} + self._message_handler: ( + Callable[[dict[str, Any]], Coroutine[Any, Any, None]] | None + ) = None + self._running = False + + def set_message_handler( + self, handler: Callable[[dict[str, Any]], Coroutine[Any, Any, None]] + ) -> None: + """设置消息处理器""" + self._message_handler = handler + + def connection_status(self) -> dict[str, Any]: + """返回连接状态快照。""" + ws = self.ws + ws_exists = ws is not None + # websockets v13+ ClientConnection 没有 .closed 属性, + # 用 close_code 判断:连接关闭后 close_code 为 int,活跃时为 None + ws_closed = (ws.close_code is not None) if ws is not None else True + connected = ws_exists and (not ws_closed) and self._running + return { + "connected": connected, + "running": self._running, + "ws_exists": ws_exists, + "ws_closed": ws_closed, + "ws_url": self.ws_url, + } + + async def connect(self) -> None: + """连接到 OneBot WebSocket""" + url = self.ws_url + if self.token: + separator = "&" if "?" in url else "?" + url = f"{url}{separator}access_token={self.token}" + + safe_ws_url = redact_string(self.ws_url) + logger.info( + f"[bold cyan][WebSocket][/bold cyan] 正在连接到 [blue]{safe_ws_url}[/blue]..." + ) + + # 同时在请求头中传递 token(兼容不同实现) + extra_headers = {} + if self.token: + extra_headers["Authorization"] = f"Bearer {self.token}" + + try: + self.ws = await websockets.connect( + url, + ping_interval=20, + ping_timeout=480, + max_size=100 * 1024 * 1024, # 100MB,支持大量历史消息 + additional_headers=extra_headers if extra_headers else None, + ) + logger.info("[bold green][WebSocket][/bold green] 连接成功") + except Exception as e: + logger.error(f"[WebSocket] 连接失败: {e}") + raise + + async def disconnect(self) -> None: + """断开连接""" + self._running = False + if self.ws: + logger.info("[WebSocket] 正在主动断开连接...") + await self.ws.close() + self.ws = None + logger.info("[WebSocket] 连接已断开") + + async def _call_api( + self, + action: str, + params: dict[str, Any] | None = None, + *, + suppress_error_retcodes: set[int] | None = None, + ) -> dict[str, Any]: + """调用 OneBot API""" + if not self.ws: + raise RuntimeError("WebSocket 未连接") + + self._message_id += 1 + echo = str(self._message_id) # 使用字符串类型 + + request = { + "action": action, + "params": params or {}, + "echo": echo, + } + + safe_params = sanitize_data(params or {}) + logger.debug( + f"[bold yellow][API请求][/bold yellow] [green]{action}[/green] (ID=[magenta]{echo}[/magenta]) | 参数: {safe_params}" + ) + if logger.isEnabledFor(logging.DEBUG): + log_debug_json(logger, "[OneBot请求体]", request) + + future: asyncio.Future[dict[str, Any]] = asyncio.Future() + self._pending_responses[echo] = future + + start_time = time.perf_counter() + + try: + await self.ws.send(json.dumps(request)) + # 等待响应,超时 8 分钟 + response = await asyncio.wait_for(future, timeout=480.0) + duration = time.perf_counter() - start_time + + status = response.get("status") + if status == "failed": + retcode = response.get("retcode", -1) + msg = response.get("message", "未知错误") + if suppress_error_retcodes and retcode in suppress_error_retcodes: + logger.warning( + f"[bold yellow][API预期失败][/bold yellow] [green]{action}[/green] (ID=[magenta]{echo}[/magenta]) | 耗时=[magenta]{duration:.2f}s[/magenta] | retcode=[yellow]{retcode}[/yellow] | message={msg}" + ) + else: + logger.error( + f"[bold red][API失败][/bold red] [green]{action}[/green] (ID=[magenta]{echo}[/magenta]) | 耗时=[magenta]{duration:.2f}s[/magenta] | retcode=[red]{retcode}[/red] | message={msg}" + ) + raise RuntimeError(f"API 调用失败: {msg} (retcode={retcode})") + + logger.info( + f"[bold green][API成功][/bold green] [green]{action}[/green] (ID=[magenta]{echo}[/magenta]) | 耗时=[magenta]{duration:.2f}s[/magenta]" + ) + if logger.isEnabledFor(logging.DEBUG): + log_debug_json(logger, "[OneBot响应体]", response) + return response + except asyncio.TimeoutError: + duration = time.perf_counter() - start_time + logger.error(f"[API超时] {action} (ID={echo}) | 耗时={duration:.2f}s") + raise + finally: + self._pending_responses.pop(echo, None) + + async def send_group_message( + self, + group_id: int, + message: str | list[dict[str, Any]], + *, + mark_sent: bool = True, + ) -> dict[str, Any]: + """发送群消息""" + result = await self._call_api( + "send_group_msg", + { + "group_id": group_id, + "message": message, + }, + ) + if mark_sent: + _mark_message_sent_this_turn() + return result + + async def send_private_message( + self, + user_id: int, + message: str | list[dict[str, Any]], + *, + group_id: int | None = None, + mark_sent: bool = True, + ) -> dict[str, Any]: + """发送私聊消息 + + 参数: + user_id: 用户 QQ 号 + message: 消息内容 + group_id: 共享群号;传入时通过该群的临时会话发送 + mark_sent: 是否标记本轮已发送(用于 end 工具判定) + """ + params: dict[str, Any] = { + "user_id": user_id, + "message": message, + } + if group_id is not None: + params["group_id"] = group_id + + result = await self._call_api( + "send_private_msg", + params, + ) + if mark_sent: + _mark_message_sent_this_turn() + return result + + async def get_group_msg_history( + self, + group_id: int, + message_seq: int | None = None, + count: int = 500, + ) -> list[dict[str, Any]]: + """获取群消息历史 + + 参数: + group_id: 群号 + message_seq: 起始消息序号,None 表示从最新消息开始 + count: 获取的消息数量 + + 返回: + 消息列表 + """ + params: dict[str, Any] = { + "group_id": group_id, + "count": count, + } + if message_seq is not None: + params["message_seq"] = message_seq + + result = await self._call_api("get_group_msg_history", params) + + if result is None: + logger.warning("get_group_msg_history 返回 None") + return [] + + data = result.get("data") + if data is None: + logger.warning(f"get_group_msg_history 响应无 data 字段: {result}") + return [] + + messages: list[dict[str, Any]] = data.get("messages", []) + logger.debug(f"获取到 {len(messages)} 条历史消息") + return messages + + async def get_image(self, file: str) -> str: + """获取图片信息 + + 参数: + file: 图片文件名或 URL + + 返回: + 图片的本地路径或 URL + """ + result = await self._call_api("get_image", {"file": file}) + data: dict[str, str] = result.get("data", {}) + url: str = data.get("url", "") or data.get("file", "") + return url + + async def get_group_info(self, group_id: int) -> dict[str, Any] | None: + """获取群信息 + + 参数: + group_id: 群号 + + 返回: + 群信息字典,包含 group_name 等字段 + """ + try: + result = await self._call_api("get_group_info", {"group_id": group_id}) + data: dict[str, Any] = result.get("data", {}) + return data + except Exception as e: + logger.error(f"获取群信息失败: {e}") + return None + + async def get_stranger_info(self, user_id: int) -> dict[str, Any] | None: + """获取陌生人信息 + + 参数: + user_id: 用户QQ号 + + 返回: + 用户信息字典,包含 nickname 等字段 + """ + try: + result = await self._call_api("get_stranger_info", {"user_id": user_id}) + data: dict[str, Any] = result.get("data", {}) + return data + except Exception as e: + logger.error(f"获取陌生人信息失败: {e}") + return None + + async def get_group_member_info( + self, group_id: int, user_id: int, no_cache: bool = False + ) -> dict[str, Any] | None: + """获取群成员信息 + + 参数: + group_id: 群号 + user_id: 群成员QQ号 + no_cache: 是否不使用缓存(默认 false) + + 返回: + 群成员信息字典,包含群昵称、QQ昵称、加群时间、等级、最后发言时间等字段 + """ + try: + result = await self._call_api( + "get_group_member_info", + {"group_id": group_id, "user_id": user_id, "no_cache": no_cache}, + ) + data: dict[str, Any] = result.get("data", {}) + return data + except Exception as e: + logger.error(f"获取群成员信息失败: {e}") + return None + + async def get_group_member_list(self, group_id: int) -> list[dict[str, Any]]: + """获取群成员列表 + + 参数: + group_id: 群号 + + 返回: + 群成员信息列表 + """ + try: + result = await self._call_api( + "get_group_member_list", {"group_id": group_id} + ) + data: list[dict[str, Any]] = result.get("data", []) + return data + except Exception as e: + logger.error(f"获取群成员列表失败: {e}") + return [] + + async def get_friend_list(self) -> list[dict[str, Any]]: + """获取好友列表 + + 返回: + 好友信息列表,每个好友包含: + - user_id: QQ号 + - nickname: QQ昵称 + - remark: 备注名 + """ + try: + result = await self._call_api("get_friend_list") + data: list[dict[str, Any]] = result.get("data", []) + return data + except Exception as e: + logger.error(f"获取好友列表失败: {e}") + return [] + + async def get_group_list(self) -> list[dict[str, Any]]: + """获取群列表 + + 返回: + 群信息列表,每个群包含: + - group_id: 群号 + - group_name: 群名称 + - member_count: 成员数 + - max_member_count: 最大成员数 + """ + try: + result = await self._call_api("get_group_list") + data: list[dict[str, Any]] = result.get("data", []) + return data + except Exception as e: + logger.error(f"获取群列表失败: {e}") + return [] + + async def get_forward_msg(self, id: str) -> list[dict[str, Any]]: + """获取合并转发消息详情 + + 参数: + id: 合并转发 ID + + 返回: + 消息节点列表 + """ + try: + result = await self._call_api( + "get_forward_msg", + {"message_id": id}, + suppress_error_retcodes={1200}, + ) + data = result.get("data", {}) + # data 可能是字典(包含 messages)或列表(直接是 nodes) + if isinstance(data, dict): + messages: list[dict[str, Any]] = data.get("messages", []) + return messages + elif isinstance(data, list): + nodes: list[dict[str, Any]] = data + return nodes + return [] + except Exception as e: + error_text = str(e) + if "retcode=1200" in error_text: + logger.debug( + "合并转发消息不可获取(可能过期或内层): id=%s err=%s", id, e + ) + return [] + logger.error(f"获取合并转发消息失败: {e}") + return [] + + async def get_msg(self, message_id: int) -> dict[str, Any] | None: + """获取单条消息详情 + + 参数: + message_id: 消息 ID + + 返回: + 消息详情字典 + """ + try: + result = await self._call_api("get_msg", {"message_id": message_id}) + return result.get("data") + except Exception as e: + logger.error(f"获取消息详情失败: {e}") + return None + + async def send_forward_msg( + self, group_id: int, messages: list[dict[str, Any]] + ) -> dict[str, Any]: + """发送合并转发消息到群聊 + + 参数: + group_id: 群号 + messages: 消息节点列表,每个节点格式为: + { + "type": "node", + "data": { + "name": "发送者昵称", + "uin": "发送者QQ号", + "content": "消息内容(字符串或消息段数组)", + "time": "时间戳(可选)" + } + } + + 返回: + API 响应 + """ + return await self._call_api( + "send_forward_msg", {"group_id": group_id, "messages": messages} + ) + + async def send_private_forward_msg( + self, user_id: int, messages: list[dict[str, Any]] + ) -> dict[str, Any]: + """发送合并转发消息到私聊。""" + return await self._call_api( + "send_private_forward_msg", + {"user_id": user_id, "messages": messages}, + ) + + async def send_like(self, user_id: int, times: int = 1) -> dict[str, Any]: + """给用户点赞 + + 参数: + user_id: 对方 QQ 号 + times: 赞的次数(默认1次) + + 返回: + API 响应 + """ + return await self._call_api("send_like", {"user_id": user_id, "times": times}) + + async def fetch_emoji_like(self, message_id: int) -> dict[str, Any] | list[Any]: + """获取消息已设置的表情反应信息(扩展接口)。 + + 参数: + message_id: 消息 ID + + 返回: + data 字段内容(字典或列表),异常时抛出 RuntimeError + """ + result = await self._call_api("fetch_emoji_like", {"message_id": message_id}) + data = result.get("data") + if isinstance(data, (dict, list)): + return data + return {} + + async def set_msg_emoji_like( + self, + message_id: int, + emoji_id: int, + *, + set_like: bool = True, + mark_sent: bool = True, + ) -> dict[str, Any]: + """给指定消息添加/取消表情反应(扩展接口)。 + + 参数: + message_id: 目标消息 ID + emoji_id: 表情 ID + set_like: True=添加反应,False=取消反应 + mark_sent: 是否标记本轮已发送(用于 end 工具判定) + + 返回: + API 响应 + """ + if set_like: + try: + result = await self._call_api( + "set_msg_emoji_like", + {"message_id": message_id, "emoji_id": emoji_id}, + ) + except RuntimeError: + logger.warning( + "[消息表情] set_msg_emoji_like 默认参数失败,尝试 set=true 回退: msg=%s emoji=%s", + message_id, + emoji_id, + ) + result = await self._call_api( + "set_msg_emoji_like", + {"message_id": message_id, "emoji_id": emoji_id, "set": True}, + ) + else: + # 取消反应可能依赖实现方扩展参数,默认采用 set=false。 + result = await self._call_api( + "set_msg_emoji_like", + {"message_id": message_id, "emoji_id": emoji_id, "set": False}, + ) + + if mark_sent: + _mark_message_sent_this_turn() + return result + + async def send_group_poke( + self, + group_id: int, + user_id: int, + *, + mark_sent: bool = True, + ) -> dict[str, Any]: + """在群聊中拍一拍指定成员。 + + 参数: + group_id: 群号 + user_id: 被拍一拍的用户 QQ 号 + mark_sent: 是否标记本轮已发送(用于 end 工具判定) + + 返回: + API 响应 + """ + try: + result = await self._call_api( + "group_poke", {"group_id": group_id, "user_id": user_id} + ) + except RuntimeError: + logger.warning( + "[拍一拍] group_poke 失败,尝试 send_poke 回退: group=%s user=%s", + group_id, + user_id, + ) + result = await self._call_api( + "send_poke", + { + "group_id": group_id, + "user_id": user_id, + "target_id": user_id, + }, + ) + + if mark_sent: + _mark_message_sent_this_turn() + return result + + async def send_private_poke( + self, + user_id: int, + *, + mark_sent: bool = True, + ) -> dict[str, Any]: + """在私聊中拍一拍指定用户。 + + 参数: + user_id: 被拍一拍的用户 QQ 号 + mark_sent: 是否标记本轮已发送(用于 end 工具判定) + + 返回: + API 响应 + """ + try: + result = await self._call_api("friend_poke", {"user_id": user_id}) + except RuntimeError: + logger.warning( + "[拍一拍] friend_poke 失败,尝试 send_poke 回退: user=%s", + user_id, + ) + result = await self._call_api( + "send_poke", + { + "user_id": user_id, + "target_id": user_id, + }, + ) + + if mark_sent: + _mark_message_sent_this_turn() + return result + + async def upload_group_file( + self, + group_id: int, + file_path: str, + name: str | None = None, + ) -> dict[str, Any]: + """上传文件到群聊 + + 参数: + group_id: 群号 + file_path: 本地文件绝对路径 + name: 文件名(可选,默认使用原文件名) + """ + from pathlib import Path as _Path + + file_name = name or _Path(file_path).name + file_uri = _Path(file_path).resolve().as_uri() + try: + return await self._call_api( + "upload_group_file", + { + "group_id": group_id, + "file": file_uri, + "name": file_name, + }, + ) + except RuntimeError: + # 回退:尝试用文件消息段发送 + logger.warning( + "[文件上传] upload_group_file 失败,尝试文件消息段回退: group=%s", + group_id, + ) + return await self.send_group_message( + group_id, + [ + { + "type": "file", + "data": {"file": file_uri, "name": file_name}, + } + ], + ) + + async def upload_private_file( + self, + user_id: int, + file_path: str, + name: str | None = None, + ) -> dict[str, Any]: + """上传文件到私聊 + + 参数: + user_id: 用户 QQ 号 + file_path: 本地文件绝对路径 + name: 文件名(可选,默认使用原文件名) + """ + from pathlib import Path as _Path + + file_name = name or _Path(file_path).name + file_uri = _Path(file_path).resolve().as_uri() + try: + return await self._call_api( + "upload_private_file", + { + "user_id": user_id, + "file": file_uri, + "name": file_name, + }, + ) + except RuntimeError: + logger.warning( + "[文件上传] upload_private_file 失败,尝试文件消息段回退: user=%s", + user_id, + ) + return await self.send_private_message( + user_id, + [ + { + "type": "file", + "data": {"file": file_uri, "name": file_name}, + } + ], + ) + + async def send_group_sign(self, group_id: int) -> dict[str, Any]: + """执行群打卡 + + 参数: + group_id: 群号 + + 返回: + API 响应 + """ + return await self._call_api("send_group_sign", {"group_id": group_id}) + + async def _get_group_notices(self, group_id: int) -> list[dict[str, Any]]: + """获取群公告列表(非标准 API,依赖具体实现) + + 参数: + group_id: 群号 + + 返回: + 公告列表 + """ + try: + result = await self._call_api("_get_group_notice", {"group_id": group_id}) + data = result.get("data") + if isinstance(data, list): + return data + elif isinstance(data, dict): + # 尝试获取常见的列表字段 + notices = data.get("notices") + if notices is None: + notices = data.get("list") + if isinstance(notices, list): + return notices + return [] + except Exception as e: + logger.error(f"获取群公告失败: {e}") + return [] + + async def run(self) -> None: + """运行消息接收循环""" + if not self.ws: + raise RuntimeError("WebSocket 未连接") + + self._running = True + self._tasks: set[asyncio.Task[None]] = set() + logger.info("[WebSocket] 消息接收循环已启动") + + try: + while self._running: + raw_message = "" + try: + message_data = await self.ws.recv() + raw_message = ( + message_data.decode("utf-8") + if isinstance(message_data, bytes) + else message_data + ) + data = json.loads(raw_message) + # 处理消息(不阻塞接收循环) + await self._dispatch_message(data) + except json.JSONDecodeError as e: + logger.error( + f"[WebSocket] 无法解析 JSON 消息: {raw_message!r}, 错误: {e}" + ) + except websockets.ConnectionClosed: + logger.warning("[WebSocket] 连接已关闭,接收循环结束") + break + except Exception as e: + logger.exception(f"[WebSocket] 接收消息时发生异常: {e}") + finally: + self._running = False + # 等待所有后台任务完成 + if self._tasks: + logger.debug( + f"[WebSocket] 正在等待 {len(self._tasks)} 个异步任务完成..." + ) + await asyncio.gather(*self._tasks, return_exceptions=True) + logger.info("[WebSocket] 接收循环已停止") + + async def _dispatch_message(self, data: dict[str, Any]) -> None: + """分发消息(API响应同步处理,事件异步处理)""" + if logger.isEnabledFor(logging.DEBUG): + log_debug_json(logger, "[WebSocket消息]", data) + # 检查是否是 API 响应(需要立即处理) + echo = data.get("echo") + if echo is not None: + echo_str = str(echo) + if echo_str in self._pending_responses: + logger.debug(f"收到 API 响应: echo={echo_str}") + self._pending_responses[echo_str].set_result(data) + return + else: + logger.debug( + f"收到未知 echo 响应: {echo_str}, 待处理: {list(self._pending_responses.keys())}" + ) + return + + # 事件类型的消息异步处理,不阻塞接收循环 + post_type = data.get("post_type") + if post_type == "message": + msg_type = data.get("message_type", "unknown") + sender = data.get("sender", {}).get("user_id", "unknown") + logger.info( + f"[bold blue][收到消息][/bold blue] type=[yellow]{msg_type}[/yellow], sender=[blue]{sender}[/blue]" + ) + if self._message_handler: + # 创建后台任务处理消息 + task = asyncio.create_task(self._safe_handle_message(data)) + self._tasks.add(task) + task.add_done_callback(self._tasks.discard) + elif post_type == "notice": + notice_type = data.get("notice_type", "") + sub_type = data.get("sub_type", "") + # 处理拍一拍事件 + if notice_type == "notify" and sub_type == "poke": + target_id = data.get("target_id", 0) + sender_id = data.get("user_id", 0) + group_id = data.get("group_id", 0) + logger.info( + f"[bold magenta][收到拍一拍][/bold magenta] sender=[blue]{sender_id}[/blue], target=[blue]{target_id}[/blue], group=[blue]{group_id}[/blue]" + ) + if self._message_handler: + # 将 poke 事件转换为类似消息的格式,方便 handler 处理 + poke_event = { + "post_type": "notice", + "notice_type": "poke", + "group_id": group_id, + "user_id": sender_id, + "sender": {"user_id": sender_id}, + "target_id": target_id, + "message": [], # 空消息 + } + task = asyncio.create_task(self._safe_handle_message(poke_event)) + self._tasks.add(task) + task.add_done_callback(self._tasks.discard) + else: + logger.debug( + f"收到通知事件: notice_type={notice_type}, sub_type={sub_type}" + ) + elif post_type: + logger.debug( + f"收到事件: post_type={post_type}, meta={data.get('meta_event_type', '')}" + ) + + async def _safe_handle_message(self, data: dict[str, Any]) -> None: + """安全地处理消息(捕获异常)""" + try: + if self._message_handler: + await self._message_handler(data) + except Exception as e: + logger.exception(f"处理消息时出错: {e}") + + async def run_with_reconnect(self, reconnect_interval: float = 5.0) -> None: + """带自动重连的运行""" + self._should_stop = False + reconnect_count = 0 + + while not self._should_stop: + try: + if reconnect_count > 0: + logger.info(f"[WebSocket] 正在尝试第 {reconnect_count} 次重连...") + await self.connect() + reconnect_count = 0 # 连接成功重置计数 + await self.run() + except websockets.ConnectionClosed as e: + logger.warning(f"[WebSocket] 连接已断开: {e}") + except Exception as e: + logger.error(f"[WebSocket] 发生错误: {e}") + + if self._should_stop: + break + + reconnect_count += 1 + logger.info(f"{reconnect_interval} 秒后尝试重连...") + await asyncio.sleep(reconnect_interval) + + def stop(self) -> None: + """停止运行""" + self._should_stop = True + self._running = False diff --git a/src/Undefined/onebot/message.py b/src/Undefined/onebot/message.py new file mode 100644 index 00000000..7fe13dc5 --- /dev/null +++ b/src/Undefined/onebot/message.py @@ -0,0 +1,58 @@ +"""OneBot 消息解析辅助函数。""" + +from __future__ import annotations + +import logging +from datetime import datetime +from typing import Any + +logger = logging.getLogger(__name__) + + +def parse_message_time(message: dict[str, Any]) -> datetime: + """解析消息时间。 + + 兼容秒级/毫秒级时间戳与字符串输入,异常时回退到当前时间。 + """ + + raw_timestamp = message.get("time") + + if raw_timestamp is None: + return datetime.now() + + try: + timestamp = float(raw_timestamp) + except (TypeError, ValueError): + logger.debug("[OneBot] 无法解析消息时间戳,使用当前时间: %s", raw_timestamp) + return datetime.now() + + # 13 位毫秒时间戳自动降为秒。 + if timestamp > 1_000_000_000_000: + timestamp /= 1000.0 + + if timestamp <= 0: + return datetime.now() + + try: + return datetime.fromtimestamp(timestamp) + except (OSError, OverflowError, ValueError): + # 越界或非法 epoch 回退当前时间,避免整条消息解析失败。 + logger.debug("[OneBot] 时间戳越界,使用当前时间: %s", raw_timestamp) + return datetime.now() + + +def get_message_sender_id(message: dict[str, Any]) -> int: + """获取消息发送者 QQ 号""" + sender: dict[str, Any] = message.get("sender", {}) + user_id: int = sender.get("user_id", 0) + return user_id + + +def get_message_content(message: dict[str, Any]) -> list[dict[str, Any]]: + """获取消息内容(CQ 码数组格式)""" + msg = message.get("message", []) + if isinstance(msg, str): + # 如果是字符串格式,转换为数组格式 + return [{"type": "text", "data": {"text": msg}}] + content: list[dict[str, Any]] = msg + return content diff --git a/src/Undefined/py.typed b/src/Undefined/py.typed new file mode 100644 index 00000000..e69de29b diff --git a/src/Undefined/services/ai_coordinator.py b/src/Undefined/services/ai_coordinator.py index f4efaf2f..70f46b66 100644 --- a/src/Undefined/services/ai_coordinator.py +++ b/src/Undefined/services/ai_coordinator.py @@ -305,7 +305,6 @@ async def _execute_auto_reply(self, request: dict[str, Any]) -> None: # 用于向 batcher 注册 inflight 任务(仅当本请求源自合并桶时生效) batcher_scope: str | None = make_scope(group_id=group_id) if group_id else None - # 创建请求上下文 async with RequestContext( request_type="group", group_id=group_id, @@ -347,7 +346,6 @@ async def send_img_cb(tid: int, mtype: str, path: str) -> None: async def send_like_cb(uid: int, times: int = 1) -> None: await self.onebot.send_like(uid, times) - # 存储资源到上下文 ai_client = self.ai memory_storage = self.ai.memory_storage runtime_config = self.ai.runtime_config @@ -442,7 +440,6 @@ async def _execute_private_reply(self, request: dict[str, Any]) -> None: trigger_message_id = request.get("trigger_message_id") batcher_scope: str | None = make_scope(user_id=user_id) - # 创建请求上下文 async with RequestContext( request_type="private", user_id=user_id, @@ -479,7 +476,6 @@ async def send_private_cb( ) -> None: await self.sender.send_private_message(uid, msg, reply_to=reply_to) - # 存储资源到上下文 ai_client = self.ai memory_storage = self.ai.memory_storage runtime_config = self.ai.runtime_config @@ -591,7 +587,6 @@ async def _execute_stats_analysis(self, request: dict[str, Any]) -> None: logger.warning("[统计分析] 缺少 request_id,群=%s", group_id) return try: - # 加载提示词模板 prompt_template = _STATS_ANALYSIS_FALLBACK_PROMPT try: loaded_prompt = read_text_resource(_STATS_ANALYSIS_PROMPT_PATH).strip() @@ -615,7 +610,6 @@ async def _execute_stats_analysis(self, request: dict[str, Any]) -> None: data_summary=safe_data_summary ) - # 调用 AI 进行分析 messages = [ {"role": "system", "content": "你是一位专业的数据分析师。"}, {"role": "user", "content": full_prompt}, @@ -629,7 +623,6 @@ async def _execute_stats_analysis(self, request: dict[str, Any]) -> None: queue_lane=request.get("_queue_lane"), ) - # 提取分析结果 choices = result.get("choices", [{}]) if choices: content = choices[0].get("message", {}).get("content", "") @@ -647,7 +640,6 @@ async def _execute_stats_analysis(self, request: dict[str, Any]) -> None: request_id, ) - # 设置分析结果(通知等待的 _handle_stats 方法) if self.command_dispatcher: self.command_dispatcher.set_stats_analysis_result( group_id, request_id, analysis @@ -655,7 +647,6 @@ async def _execute_stats_analysis(self, request: dict[str, Any]) -> None: except Exception as exc: logger.exception("[统计分析] AI 分析失败: %s", exc) - # 出错时也通知等待,但返回空字符串 if self.command_dispatcher: self.command_dispatcher.set_stats_analysis_result( group_id, request_id, "" @@ -708,7 +699,6 @@ async def _execute_agent_intro_generation(self, request: dict[str, Any]) -> None return try: - # 获取提示词 from Undefined.skills.agents.intro_generator import AgentIntroGenerator agent_intro_generator = self.ai._agent_intro_generator @@ -721,7 +711,6 @@ async def _execute_agent_intro_generation(self, request: dict[str, Any]) -> None user_prompt, ) = await agent_intro_generator.get_intro_prompt_and_context(agent_name) - # 调用 AI 生成 messages = [ {"role": "system", "content": system_prompt or "你是一位智能助手。"}, {"role": "user", "content": user_prompt}, @@ -735,7 +724,6 @@ async def _execute_agent_intro_generation(self, request: dict[str, Any]) -> None queue_lane=request.get("_queue_lane"), ) - # 提取结果 choices = result.get("choices", [{}]) if choices: content = choices[0].get("message", {}).get("content", "") @@ -750,7 +738,6 @@ async def _execute_agent_intro_generation(self, request: dict[str, Any]) -> None request_id, ) - # 通知结果 agent_intro_generator.set_intro_generation_result( request_id, generated_content if generated_content else None ) @@ -761,7 +748,6 @@ async def _execute_agent_intro_generation(self, request: dict[str, Any]) -> None agent_name, exc, ) - # 出错时也通知,返回 None try: agent_intro_generator = self.ai._agent_intro_generator agent_intro_generator.set_intro_generation_result(request_id, None) diff --git a/src/Undefined/services/commands/bugfix.py b/src/Undefined/services/commands/bugfix.py new file mode 100644 index 00000000..f4689a3e --- /dev/null +++ b/src/Undefined/services/commands/bugfix.py @@ -0,0 +1,189 @@ +"""Bug 修复归档命令(/bugfix)的实现逻辑。 + +本模块提供 ``BugfixCommandMixin``,供 ``CommandDispatcher`` 通过多重继承组合。 +通过回溯群聊记录并调用 AI 摘要,自动生成 FAQ 归档条目。 +""" + +from __future__ import annotations + +# 斜杠命令:目录扫描注册、权限/限流/子命令路由 + +import logging +from datetime import datetime +from typing import TYPE_CHECKING, Any +from uuid import uuid4 + +from Undefined.faq import extract_faq_title +from Undefined.onebot import ( + get_message_content, + get_message_sender_id, + parse_message_time, +) + +if TYPE_CHECKING: + from Undefined.config import Config + from Undefined.faq import FAQStorage + from Undefined.onebot import OneBotClient + from Undefined.utils.sender import MessageSender + +logger = logging.getLogger(__name__) + + +class BugfixCommandMixin: + """``/bugfix`` 命令相关方法集合,作为 ``CommandDispatcher`` 的 mixin 使用。""" + + if TYPE_CHECKING: + ai: Any + config: Config + faq_storage: FAQStorage + onebot: OneBotClient + sender: MessageSender + + async def _handle_bugfix( + self, group_id: int, admin_id: int, args: list[str] + ) -> None: + """处理 ``/bugfix`` 命令,通过分析聊天记录自动生成 FAQ 归档。""" + parsed = self._parse_bugfix_args(args) + if isinstance(parsed, str): + await self.sender.send_group_message(group_id, parsed) + return + + target_qqs, start_date, end_date, start_str, end_str = parsed + + await self.sender.send_group_message( + group_id, "🔍 正在获取对话记录进行回溯分析..." + ) + + try: + messages = await self._fetch_messages( + group_id, target_qqs, start_date, end_date + ) + if not messages: + await self.sender.send_group_message( + group_id, "❌ 未找到符合条件的对话记录。" + ) + return + + processed_text = await self._process_messages(messages) + summary = await self._obtain_bugfix_summary(group_id, processed_text) + + title = extract_faq_title(summary) + if not title or title == "未命名问题": + title = await self.ai.generate_title(summary) + + faq = await self.faq_storage.create( + group_id=group_id, + target_qq=target_qqs[0], + start_time=start_str, + end_time=end_str, + title=title, + content=summary, + ) + + result_msg = f"✅ Bug 修复分析完成!\n\n📌 FAQ ID: {faq.id}\n📋 标题: {title}\n\n{summary}" + await self.sender.send_group_message(group_id, result_msg) + + except Exception as e: + error_id = uuid4().hex[:8] + logger.exception("Bugfix 失败: error_id=%s err=%s", error_id, e) + await self.sender.send_group_message( + group_id, + f"❌ Bug 修复分析失败,请稍后重试(错误码: {error_id})", + ) + + def _parse_bugfix_args( + self, args: list[str] + ) -> tuple[list[int], datetime, datetime, str, str] | str: + """解析 ``/bugfix`` 命令的参数。""" + if len(args) < 3: + return ( + "❌ 用法: /bugfix [QQ号|@用户2] ... <开始时间> <结束时间>\n" + "时间格式: YYYY/MM/DD/HH:MM,结束时间可用 now\n" + "示例: /bugfix 123456 2024/12/01/09:00 now" + ) + + try: + target_qqs = [int(arg) for arg in args[:-2]] + start_str, end_str_raw = args[-2], args[-1] + start_date = datetime.strptime(start_str, "%Y/%m/%d/%H:%M") + + if end_str_raw.lower() == "now": + end_date, end_str = datetime.now(), "now" + else: + end_date, end_str = ( + datetime.strptime(end_str_raw, "%Y/%m/%d/%H:%M"), + end_str_raw, + ) + + return target_qqs, start_date, end_date, start_str, end_str + except ValueError: + return "❌ 参数格式错误:QQ号应为数字或 @ 提及,时间格式应为 YYYY/MM/DD/HH:MM。" + + async def _obtain_bugfix_summary(self, group_id: int, processed_text: str) -> str: + """利用 AI 生成聊天记录的 Bug 分析摘要。""" + total_tokens = self.ai.count_tokens(processed_text) + max_tokens = self.config.chat_model.max_tokens + + if total_tokens <= max_tokens: + return str(await self.ai.summarize_chat(processed_text)) + + await self.sender.send_group_message( + group_id, f"📊 消息较长({total_tokens} tokens),正在分段处理..." + ) + chunks = self.ai.split_messages_by_tokens(processed_text, max_tokens) + summaries = [await self.ai.summarize_chat(chunk) for chunk in chunks] + return str(await self.ai.merge_summaries(summaries)) + + async def _fetch_messages( + self, + group_id: int, + target_qqs: list[int], + start_date: datetime, + end_date: datetime, + ) -> list[dict[str, Any]]: + """从 OneBot 拉取指定时间范围内目标用户的消息。""" + batch = await self.onebot.get_group_msg_history(group_id, count=2500) + if not batch: + return [] + target_qqs_set = set(target_qqs) + results = [] + for msg in batch: + msg_time = parse_message_time(msg) + if ( + start_date <= msg_time <= end_date + and get_message_sender_id(msg) in target_qqs_set + ): + # 后台循环处理队列 + results.append(msg) + return sorted(results, key=lambda m: m.get("time", 0)) + + # 后台循环处理队列 + async def _process_messages(self, messages: list[dict[str, Any]]) -> str: + """将原始 OneBot 消息序列化为 AI 可读的纯文本。""" + lines = [] + for msg in messages: + sender_id = get_message_sender_id(msg) + msg_time = parse_message_time(msg).strftime("%Y-%m-%d %H:%M:%S") + content = get_message_content(msg) + text_parts = [] + for segment in content: + seg_type, seg_data = segment.get("type", ""), segment.get("data", {}) + if seg_type == "text": + text_parts.append(seg_data.get("text", "")) + elif seg_type == "image": + file = seg_data.get("file", "") or seg_data.get("url", "") + if file: + try: + url = await self.onebot.get_image(file) + if url: + res = await self.ai.analyze_multimodal(url, "image") + text_parts.append( + f"[pic]{res.get('description', '')}{res.get('ocr_text', '')}[/pic]" + ) + except Exception: + text_parts.append("[pic]图片处理失败[/pic]") + elif seg_type == "at": + text_parts.append(f"@{seg_data.get('qq', '')}") + if text_parts: + lines.append(f"[{msg_time}] {sender_id}: {''.join(text_parts)}") + return "\n".join(lines) diff --git a/src/Undefined/services/commands/stats.py b/src/Undefined/services/commands/stats.py new file mode 100644 index 00000000..a992f760 --- /dev/null +++ b/src/Undefined/services/commands/stats.py @@ -0,0 +1,822 @@ +"""Token 使用统计命令(/stats)的实现逻辑。 + +本模块提供 ``StatsCommandMixin``,供 ``CommandDispatcher`` 通过多重继承组合。 +群聊与私聊统计、图表生成、AI 分析队列交互均在此实现。 +""" + +from __future__ import annotations + +# 斜杠命令:目录扫描注册、权限/限流/子命令路由 + +import asyncio +import base64 +import logging +import re +from pathlib import Path +from typing import TYPE_CHECKING, Any +from uuid import uuid4 + +from Undefined.ai.queue_budget import ( + compute_queued_llm_timeout_seconds, + resolve_effective_retry_count, +) +from Undefined.token_usage_storage import TokenUsageStorage + +if TYPE_CHECKING: + from Undefined.config import Config + from Undefined.onebot import OneBotClient + from Undefined.utils.history import MessageHistoryManager + from Undefined.utils.sender import MessageSender + +# 尝试导入 matplotlib(可选依赖) +plt: Any +try: + import matplotlib.pyplot as plt + + _MATPLOTLIB_AVAILABLE = True +except ImportError: + plt = None + _MATPLOTLIB_AVAILABLE = False + +logger = logging.getLogger(__name__) + +_STATS_DEFAULT_DAYS = 7 +_STATS_MIN_DAYS = 1 +_STATS_MAX_DAYS = 365 +_STATS_MODEL_TOP_N = 8 +_STATS_CALL_TYPE_TOP_N = 12 +_STATS_DATA_SUMMARY_MAX_CHARS = 12000 +_STATS_AI_FLAGS = {"--ai", "-a"} +_STATS_TIME_RANGE_RE = re.compile(r"^\d+[dwm]?$", re.IGNORECASE) + + +class StatsCommandMixin: + """``/stats`` 命令相关方法集合,作为 ``CommandDispatcher`` 的 mixin 使用。""" + + if TYPE_CHECKING: + ai: Any + config: Config + history_manager: MessageHistoryManager + onebot: OneBotClient + queue_manager: Any + sender: MessageSender + + _token_usage_storage: TokenUsageStorage + _stats_analysis_results: dict[str, str] + _stats_analysis_events: dict[str, asyncio.Event] + + def _parse_time_range(self, time_str: str) -> int: + """解析时间范围字符串,返回天数。 + + 参数: + time_str: 时间范围字符串(如 ``7d``、``1w``、``30d``)。 + + 返回: + clamp 在 ``[_STATS_MIN_DAYS, _STATS_MAX_DAYS]`` 内的天数。 + """ + if not time_str: + return _STATS_DEFAULT_DAYS + + def _clamp_days(value: int) -> int: + if value < _STATS_MIN_DAYS: + return _STATS_DEFAULT_DAYS + if value > _STATS_MAX_DAYS: + return _STATS_MAX_DAYS + return value + + time_str = time_str.lower().strip() + + if time_str.endswith("d"): + try: + return _clamp_days(int(time_str[:-1])) + except ValueError: + return _STATS_DEFAULT_DAYS + if time_str.endswith("w"): + try: + return _clamp_days(int(time_str[:-1]) * 7) + except ValueError: + return _STATS_DEFAULT_DAYS + if time_str.endswith("m"): + try: + return _clamp_days(int(time_str[:-1]) * 30) + except ValueError: + return _STATS_DEFAULT_DAYS + + try: + return _clamp_days(int(time_str)) + except ValueError: + return _STATS_DEFAULT_DAYS + + def _parse_stats_options(self, args: list[str]) -> tuple[int, bool]: + """解析 ``/stats`` 参数:时间范围 + AI 分析开关。""" + days = _STATS_DEFAULT_DAYS + enable_ai_analysis = False + picked_days = False + + for raw in args: + token = str(raw or "").strip() + if not token: + continue + lower = token.lower() + if lower in _STATS_AI_FLAGS: + enable_ai_analysis = True + continue + if not picked_days and _STATS_TIME_RANGE_RE.match(lower): + days = self._parse_time_range(lower) + picked_days = True + + return days, enable_ai_analysis + + async def _handle_stats( + self, group_id: int, sender_id: int, args: list[str] + ) -> None: + """处理群聊 ``/stats`` 命令,生成 token 使用统计图表(可选 AI 分析)。""" + if not _MATPLOTLIB_AVAILABLE: + await self.sender.send_group_message( + group_id, "❌ 缺少必要的库,无法生成图表。请安装 matplotlib。" + ) + return + + days, enable_ai_analysis = self._parse_stats_options(args) + + try: + summary = await self._token_usage_storage.get_summary(days=days) + if summary["total_calls"] == 0: + await self.sender.send_group_message( + group_id, f"📊 最近 {days} 天内无 Token 使用记录。" + ) + return + + from Undefined.utils.paths import RENDER_CACHE_DIR, ensure_dir + + img_dir = ensure_dir(RENDER_CACHE_DIR) + await self._generate_line_chart(summary, img_dir, days) + await self._generate_bar_chart(summary, img_dir) + await self._generate_pie_chart(summary, img_dir) + await self._generate_stats_table(summary, img_dir) + + ai_analysis = "" + if enable_ai_analysis: + ai_analysis = await self._run_stats_ai_analysis( + scope="group", + scope_id=group_id, + sender_id=sender_id, + summary=summary, + days=days, + ) + + forward_messages = self._build_stats_forward_nodes( + summary, img_dir, days, ai_analysis + ) + await self._send_group_forward_message( + group_id, + forward_messages, + history_message=self._build_stats_history_message( + summary, + days, + ai_analysis, + ), + ) + + from Undefined.utils.cache import cleanup_cache_dir + + cleanup_cache_dir(RENDER_CACHE_DIR) + + except Exception as e: + error_id = uuid4().hex[:8] + logger.exception( + "[Stats] 生成统计图表失败: error_id=%s err=%s", error_id, e + ) + await self.sender.send_group_message( + group_id, + f"❌ 生成统计图表失败,请稍后重试(错误码: {error_id})", + ) + + async def _send_group_forward_message( + self, + group_id: int, + messages: list[dict[str, Any]], + *, + history_message: str, + ) -> None: + """发送群组合并转发消息,并在需要时写入历史记录。""" + send_forward = getattr(self.sender, "send_group_forward_message", None) + if callable(send_forward): + await send_forward(group_id, messages, history_message=history_message) + return + + await self.onebot.send_forward_msg(group_id, messages) + if self.history_manager is None: + return + text_content = history_message.strip() + if not text_content: + return + + await self.history_manager.add_group_message( + group_id=group_id, + sender_id=getattr(self.config, "bot_qq", 0), + text_content=text_content, + sender_nickname="Bot", + group_name="", + ) + + @staticmethod + def _build_stats_history_message( + summary: dict[str, Any], + days: int, + ai_analysis: str, + ) -> str: + """构建写入群聊历史的 ``/stats`` 输出摘要文本。""" + lines = [ + f"[命令输出] /stats 最近 {days} 天 Token 使用统计", + f"总调用: {summary.get('total_calls', 0)}", + f"总 Token: {summary.get('total_tokens', 0)}", + f"输入 Token: {summary.get('prompt_tokens', 0)}", + f"输出 Token: {summary.get('completion_tokens', 0)}", + ] + if ai_analysis.strip(): + lines.extend(["", "AI 分析:", ai_analysis.strip()]) + return "\n".join(lines) + + async def _handle_stats_private( + self, + user_id: int, + sender_id: int, + args: list[str], + send_message: Any = None, + *, + is_webui_session: bool = False, + ) -> None: + """处理私聊 ``/stats``(含 WebUI 虚拟私聊适配)。""" + + async def _send_private(message: str) -> None: + if send_message is not None: + await send_message(message) + else: + await self.sender.send_private_message(user_id, message) + + days, enable_ai_analysis = self._parse_stats_options(args) + try: + summary = await self._token_usage_storage.get_summary(days=days) + if summary["total_calls"] == 0: + await _send_private(f"📊 最近 {days} 天内无 Token 使用记录。") + return + + ai_analysis = "" + if enable_ai_analysis: + ai_analysis = await self._run_stats_ai_analysis( + scope="private", + scope_id=0, + sender_id=sender_id, + summary=summary, + days=days, + ) + + if not _MATPLOTLIB_AVAILABLE: + message = "❌ 缺少必要的库,无法生成图表。请安装 matplotlib。" + if is_webui_session: + message += "\n\n" + self._build_stats_summary_text(summary) + if ai_analysis: + message += f"\n\n🤖 AI 智能分析\n{ai_analysis}" + await _send_private(message) + return + + from Undefined.utils.cache import cleanup_cache_dir + from Undefined.utils.paths import RENDER_CACHE_DIR, ensure_dir + + img_dir = ensure_dir(RENDER_CACHE_DIR) + await self._generate_line_chart(summary, img_dir, days) + await self._generate_bar_chart(summary, img_dir) + await self._generate_pie_chart(summary, img_dir) + await self._generate_stats_table(summary, img_dir) + + await _send_private(f"📊 最近 {days} 天的 Token 使用统计:") + for img_name in ["line_chart", "bar_chart", "pie_chart", "table"]: + img_path = img_dir / f"stats_{img_name}.png" + if img_path.exists(): + message = await self._build_private_stats_image_message( + img_path, + inline_base64=is_webui_session, + ) + await _send_private(message) + + await _send_private(self._build_stats_summary_text(summary)) + if ai_analysis: + await _send_private(f"🤖 AI 智能分析\n{ai_analysis}") + + cleanup_cache_dir(RENDER_CACHE_DIR) + except Exception as e: + error_id = uuid4().hex[:8] + logger.exception( + "[Stats] 私聊统计生成失败: error_id=%s user=%s err=%s", + error_id, + user_id, + e, + ) + await _send_private( + f"❌ 生成统计图表失败,请稍后重试(错误码: {error_id})" + ) + + async def _build_private_stats_image_message( + self, + image_path: Path, + *, + inline_base64: bool, + ) -> str: + """构建私聊统计图片的 OneBot CQ 码消息。""" + file_uri = image_path.absolute().as_uri() + if not inline_base64: + return f"[CQ:image,file={file_uri}]" + + try: + encoded = await asyncio.to_thread( + lambda: base64.b64encode(image_path.read_bytes()).decode("ascii") + ) + except Exception as exc: + logger.warning( + "[Stats] 图像 base64 编码失败,回退文件路径: path=%s err=%s", + file_uri, + exc, + ) + return f"[CQ:image,file={file_uri}]" + + return f"[CQ:image,file=base64://{encoded}]" + + async def _run_stats_ai_analysis( + self, + *, + scope: str, + scope_id: int, + sender_id: int, + summary: dict[str, Any], + days: int, + ) -> str: + """投递并等待 AI 对统计数据的分析结果。""" + if not self.queue_manager: + return "" + + data_summary = self._build_data_summary(summary, days) + request_id = uuid4().hex + analysis_event = asyncio.Event() + self._stats_analysis_events[request_id] = analysis_event + request_data = { + "type": "stats_analysis", + "group_id": scope_id, + "request_id": request_id, + "sender_id": sender_id, + "data_summary": data_summary, + "summary": summary, + "days": days, + "scope": scope, + } + receipt = await self.queue_manager.add_group_mention_request( + request_data, model_name=self.config.chat_model.model_name + ) + logger.info("[Stats] 已投递 AI 分析请求: scope=%s target=%s", scope, scope_id) + + wait_timeout = compute_queued_llm_timeout_seconds( + self.ai.runtime_config, + self.config.chat_model, + retry_count=resolve_effective_retry_count( + self.ai.runtime_config, self.queue_manager + ), + initial_wait_seconds=float( + getattr(receipt, "estimated_wait_seconds", 0.0) or 0.0 + ), + ) + try: + await asyncio.wait_for(analysis_event.wait(), timeout=wait_timeout) + ai_analysis = self._stats_analysis_results.pop(request_id, "") + logger.info( + "[Stats] 已获取 AI 分析结果: scope=%s len=%s", scope, len(ai_analysis) + ) + return ai_analysis + except asyncio.TimeoutError: + logger.warning( + "[Stats] AI 分析超时: scope=%s target=%s timeout=%.1fs", + scope, + scope_id, + wait_timeout, + ) + return "AI 分析超时,已先发送图表与汇总数据。" + finally: + self._stats_analysis_events.pop(request_id, None) + self._stats_analysis_results.pop(request_id, None) + + def _build_data_summary(self, summary: dict[str, Any], days: int) -> str: + """构建用于 AI 分析的统计数据摘要。""" + lines = [] + lines.append("📊 Token 使用综合分析数据:") + lines.append("") + + lines.append("【整体概况】") + lines.append(f"统计周期: {days} 天") + lines.append(f"总调用次数: {summary['total_calls']}") + lines.append(f"总 Token 消耗: {summary['total_tokens']:,}") + lines.append(f"平均响应时间: {summary['avg_duration']:.2f}s") + lines.append(f"涉及模型数: {len(summary['models'])}") + lines.append("") + + daily_stats = summary.get("daily_stats", {}) + if daily_stats: + dates = sorted(daily_stats.keys()) + total_daily_calls = sum(daily_stats[d]["calls"] for d in dates) + total_daily_tokens = sum(daily_stats[d]["tokens"] for d in dates) + avg_daily_calls = total_daily_calls / len(dates) if dates else 0 + avg_daily_tokens = total_daily_tokens / len(dates) if dates else 0 + + peak_day = ( + max(dates, key=lambda d: daily_stats[d]["tokens"]) if dates else "" + ) + peak_day_tokens = daily_stats[peak_day]["tokens"] if peak_day else 0 + + lines.append("【时间维度】") + lines.append(f"统计天数: {len(dates)} 天") + lines.append(f"每日平均调用: {avg_daily_calls:.1f} 次") + lines.append(f"每日平均 Token: {avg_daily_tokens:,.0f} 个") + lines.append(f"高峰日期: {peak_day} ({peak_day_tokens:,} tokens)") + lines.append("") + + models = summary.get("models", {}) + if models: + lines.append("【模型维度】") + total_tokens_all = summary["total_tokens"] + sorted_models = sorted( + models.items(), key=lambda x: x[1]["tokens"], reverse=True + ) + for model_name, model_data in sorted_models[:_STATS_MODEL_TOP_N]: + calls = model_data["calls"] + tokens = model_data["tokens"] + prompt_tokens = model_data["prompt_tokens"] + completion_tokens = model_data["completion_tokens"] + token_pct = ( + (tokens / total_tokens_all * 100) if total_tokens_all > 0 else 0 + ) + avg_per_call = tokens / calls if calls > 0 else 0 + io_ratio = completion_tokens / prompt_tokens if prompt_tokens > 0 else 0 + + lines.append(f"模型: {model_name}") + lines.append( + f" - 调用次数: {calls} ({calls / summary['total_calls'] * 100:.1f}%)" + ) + lines.append(f" - Token 消耗: {tokens:,} ({token_pct:.1f}%)") + lines.append(f" - 平均每次调用: {avg_per_call:.0f} tokens") + lines.append( + f" - 输入: {prompt_tokens:,} / 输出: {completion_tokens:,}" + ) + lines.append(f" - 输入/输出比: 1:{io_ratio:.2f}") + lines.append("") + + if len(sorted_models) > _STATS_MODEL_TOP_N: + others = sorted_models[_STATS_MODEL_TOP_N:] + others_calls = sum(int(item[1].get("calls", 0)) for item in others) + others_tokens = sum(int(item[1].get("tokens", 0)) for item in others) + others_pct = ( + (others_tokens / total_tokens_all * 100) + if total_tokens_all > 0 + else 0.0 + ) + lines.append( + f"其余 {len(others)} 个模型合计: 调用 {others_calls} 次, Token {others_tokens:,} ({others_pct:.1f}%)" + ) + lines.append("") + + call_types = summary.get("call_types", {}) + if call_types: + lines.append("【调用类型维度】") + sorted_types = sorted( + call_types.items(), key=lambda item: int(item[1]), reverse=True + ) + total_calls = max(1, int(summary.get("total_calls", 0))) + for call_type, count in sorted_types[:_STATS_CALL_TYPE_TOP_N]: + ratio = int(count) / total_calls * 100 + lines.append(f"- {call_type}: {count} 次 ({ratio:.1f}%)") + if len(sorted_types) > _STATS_CALL_TYPE_TOP_N: + rest_count = sum( + int(item[1]) for item in sorted_types[_STATS_CALL_TYPE_TOP_N:] + ) + ratio = rest_count / total_calls * 100 + lines.append( + f"- 其他 {len(sorted_types) - _STATS_CALL_TYPE_TOP_N} 类: {rest_count} 次 ({ratio:.1f}%)" + ) + lines.append("") + + prompt_tokens = summary.get("prompt_tokens", 0) + completion_tokens = summary.get("completion_tokens", 0) + total_tokens = summary.get("total_tokens", 0) + input_ratio = (prompt_tokens / total_tokens * 100) if total_tokens > 0 else 0 + output_ratio = ( + (completion_tokens / total_tokens * 100) if total_tokens > 0 else 0 + ) + output_per_input = completion_tokens / prompt_tokens if prompt_tokens > 0 else 0 + + lines.append("【效率指标】") + lines.append(f"输入 Token: {prompt_tokens:,} ({input_ratio:.1f}%)") + lines.append(f"输出 Token: {completion_tokens:,} ({output_ratio:.1f}%)") + lines.append(f"输入/输出比: 1:{output_per_input:.2f}") + lines.append("") + + if daily_stats and len(daily_stats) > 1: + lines.append("【趋势分析】") + dates = sorted(daily_stats.keys()) + first_day_tokens = daily_stats[dates[0]]["tokens"] + last_day_tokens = daily_stats[dates[-1]]["tokens"] + trend_change = ( + ((last_day_tokens - first_day_tokens) / first_day_tokens * 100) + if first_day_tokens > 0 + else 0 + ) + trend_desc = "增长" if trend_change > 0 else "下降" + lines.append( + f"总体趋势: {trend_desc} {abs(trend_change):.1f}% (从首日到末日)" + ) + lines.append("") + + summary_text = "\n".join(lines) + if len(summary_text) > _STATS_DATA_SUMMARY_MAX_CHARS: + trimmed = summary_text[: _STATS_DATA_SUMMARY_MAX_CHARS - 80].rstrip() + summary_text = ( + f"{trimmed}\n\n[数据摘要已截断,总长度 {len(summary_text)} 字符," + f"仅保留前 {_STATS_DATA_SUMMARY_MAX_CHARS} 字符]" + ) + logger.info( + "[Stats] 数据摘要过长已截断: original_len=%s max_len=%s", + len("\n".join(lines)), + _STATS_DATA_SUMMARY_MAX_CHARS, + ) + return summary_text + + def _build_stats_summary_text(self, summary: dict[str, Any]) -> str: + """构建统计结果的纯文本摘要。""" + return f"""📈 摘要汇总: +• 总调用次数: {summary["total_calls"]} +• 总消耗 Tokens: {summary["total_tokens"]:,} + └─ 输入: {summary["prompt_tokens"]:,} + └─ 输出: {summary["completion_tokens"]:,} +• 平均耗时: {summary["avg_duration"]:.2f}s +• 涉及模型数: {len(summary["models"])}""" + + def set_stats_analysis_result( + self, group_id: int, request_id: str, analysis: str + ) -> None: + """设置 AI 分析结果(由队列处理器调用)。""" + event = self._stats_analysis_events.get(request_id) + if not event: + logger.warning( + "[StatsAnalysis] 未找到等待事件,群: %s, 请求: %s", + group_id, + request_id, + ) + return + self._stats_analysis_results[request_id] = analysis + event.set() + + def _build_stats_forward_nodes( + self, + summary: dict[str, Any], + img_dir: Path, + days: int, + ai_analysis: str = "", + ) -> list[dict[str, Any]]: + """构建用于合并转发的统计图表节点列表。""" + # 对外入队 API + nodes = [] + bot_qq = str(self.config.bot_qq) + + # 对外入队 API + def add_node(content: str) -> None: + nodes.append( + { + "type": "node", + "data": {"name": "Bot", "uin": bot_qq, "content": content}, + } + ) + + add_node(f"📊 最近 {days} 天的 Token 使用统计:") + + for img_name in ["line_chart", "bar_chart", "pie_chart", "table"]: + img_path = img_dir / f"stats_{img_name}.png" + if img_path.exists(): + add_node(f"[CQ:image,file={img_path.absolute().as_uri()}]") + + add_node(self._build_stats_summary_text(summary)) + + if ai_analysis: + add_node(f"🤖 AI 智能分析\n{ai_analysis}") + + return nodes + + async def _generate_line_chart( + self, summary: dict[str, Any], img_dir: Path, days: int + ) -> None: + """生成折线图:时间趋势。""" + daily_stats = summary["daily_stats"] + if not daily_stats: + return + + dates = sorted(daily_stats.keys()) + tokens = [daily_stats[d]["tokens"] for d in dates] + prompt_tokens = [daily_stats[d]["prompt_tokens"] for d in dates] + completion_tokens = [daily_stats[d]["completion_tokens"] for d in dates] + + fig, ax = plt.subplots(figsize=(12, 7)) + + ax.plot( + dates, tokens, marker="o", linewidth=2, label="Total Token", color="#2196F3" + ) + ax.plot( + dates, + prompt_tokens, + marker="s", + linewidth=2, + label="Input Token", + color="#4CAF50", + ) + ax.plot( + dates, + completion_tokens, + marker="^", + linewidth=2, + label="Output Token", + color="#FF9800", + ) + + ax.set_title( + f"Token Usage Trend for Last {days} Days", fontsize=16, fontweight="bold" + ) + ax.set_xlabel("Date", fontsize=12) + ax.set_ylabel("Token Count", fontsize=12) + ax.legend(loc="upper left", fontsize=10) + ax.grid(True, alpha=0.3) + + plt.xticks(rotation=45, ha="right") + plt.tight_layout() + + filepath = img_dir / "stats_line_chart.png" + plt.savefig(filepath, dpi=150, bbox_inches="tight") + plt.close(fig) + + async def _generate_bar_chart(self, summary: dict[str, Any], img_dir: Path) -> None: + """生成柱状图:模型对比。""" + models = summary["models"] + if not models: + return + + model_names = list(models.keys()) + tokens = [models[m]["tokens"] for m in model_names] + prompt_tokens = [models[m]["prompt_tokens"] for m in model_names] + completion_tokens = [models[m]["completion_tokens"] for m in model_names] + + fig, ax = plt.subplots(figsize=(14, 8)) + + x = range(len(model_names)) + width = 0.25 + + bars1 = ax.bar( + [i - width for i in x], + tokens, + width, + label="Total Token", + color="#2196F3", + alpha=0.8, + ) + bars2 = ax.bar( + x, + prompt_tokens, + width, + label="Input Token", + color="#4CAF50", + alpha=0.8, + ) + bars3 = ax.bar( + [i + width for i in x], + completion_tokens, + width, + label="Output Token", + color="#FF9800", + alpha=0.8, + ) + + ax.set_title("Token Usage Comparison by Model", fontsize=16, fontweight="bold") + ax.set_xlabel("Model", fontsize=12) + ax.set_ylabel("Token Count", fontsize=12) + ax.set_xticks(x) + ax.set_xticklabels(model_names, rotation=45, ha="right") + ax.legend(loc="upper right", fontsize=10) + ax.grid(True, alpha=0.3, axis="y") + + for bars in [bars1, bars2, bars3]: + for bar in bars: + height = bar.get_height() + if height > 0: + ax.text( + bar.get_x() + bar.get_width() / 2.0, + height, + f"{int(height):,}", + ha="center", + va="bottom", + fontsize=8, + ) + + plt.tight_layout() + + filepath = img_dir / "stats_bar_chart.png" + plt.savefig(filepath, dpi=150, bbox_inches="tight") + plt.close(fig) + + async def _generate_pie_chart(self, summary: dict[str, Any], img_dir: Path) -> None: + """生成饼图:输入/输出比例。""" + prompt_tokens = summary["prompt_tokens"] + completion_tokens = summary["completion_tokens"] + + if prompt_tokens == 0 and completion_tokens == 0: + return + + fig, ax = plt.subplots(figsize=(12, 8)) + + labels = ["Input Token", "Output Token"] + sizes = [prompt_tokens, completion_tokens] + colors = ["#4CAF50", "#FF9800"] + explode = (0.05, 0.05) + + wedges, *_ = ax.pie( + sizes, + explode=explode, + labels=labels, + colors=colors, + autopct="%1.1f%%", + startangle=90, + textprops={"fontsize": 12}, + ) + + ax.set_title("Input/Output Token Ratio", fontsize=16, fontweight="bold", pad=20) + + ax.legend( + wedges, + [f"{labels[i]}: {sizes[i]:,}" for i in range(len(labels))], + loc="center left", + bbox_to_anchor=(1, 0, 0.5, 1), + fontsize=10, + ) + + plt.tight_layout() + + filepath = img_dir / "stats_pie_chart.png" + plt.savefig(filepath, dpi=150, bbox_inches="tight") + plt.close(fig) + + async def _generate_stats_table( + self, summary: dict[str, Any], img_dir: Path + ) -> None: + """生成统计表格图片。""" + models = summary["models"] + if not models: + return + + model_names = list(models.keys()) + data = [] + for model in model_names: + m = models[model] + data.append( + [ + model, + m["calls"], + f"{m['tokens']:,}", + f"{m['prompt_tokens']:,}", + f"{m['completion_tokens']:,}", + ] + ) + + fig, ax = plt.subplots(figsize=(14, 9)) + ax.axis("tight") + ax.axis("off") + + table = ax.table( + cellText=data, + colLabels=["Model", "Calls", "Total Token", "Input Token", "Output Token"], + cellLoc="center", + loc="center", + ) + + table.auto_set_font_size(False) + table.set_fontsize(10) + table.scale(1.2, 1.5) + + for i in range(5): + table[(0, i)].set_facecolor("#2196F3") + table[(0, i)].set_text_props(weight="bold", color="white") + + for i in range(1, len(data) + 1): + for j in range(5): + if i % 2 == 0: + table[(i, j)].set_facecolor("#f0f0f0") + + ax.set_title( + "Model Usage Statistics Details", fontsize=16, fontweight="bold", pad=20 + ) + + plt.tight_layout() + + filepath = img_dir / "stats_table.png" + plt.savefig(filepath, dpi=150, bbox_inches="tight") + plt.close(fig) diff --git a/src/Undefined/services/coordinator/__init__.py b/src/Undefined/services/coordinator/__init__.py new file mode 100644 index 00000000..07c8fd50 --- /dev/null +++ b/src/Undefined/services/coordinator/__init__.py @@ -0,0 +1,88 @@ +"""AI 协调器:组合群聊、私聊、合并与后台任务 mixin。""" + +from __future__ import annotations + + +import logging +from pathlib import Path +from typing import Any + +from Undefined.config import Config +from Undefined.services.coordinator.background import BackgroundMixin +from Undefined.services.coordinator.batching import BatchingMixin +from Undefined.services.coordinator.group import GroupReplyMixin +from Undefined.services.coordinator.private import PrivateReplyMixin +from Undefined.services.message_batcher import MessageBatcher +from Undefined.services.model_pool import ModelPoolService +from Undefined.services.queue_manager import QueueManager +from Undefined.services.security import SecurityService +from Undefined.utils.history import MessageHistoryManager +from Undefined.utils.scheduler import TaskScheduler +from Undefined.utils.sender import MessageSender + +logger = logging.getLogger(__name__) + + +class AICoordinator( + GroupReplyMixin, + PrivateReplyMixin, + BatchingMixin, + BackgroundMixin, +): + """AI 协调器,处理 AI 回复逻辑、Prompt 构建和队列管理""" + + def __init__( + self, + config: Config, + ai: Any, # AIClient + queue_manager: QueueManager, + history_manager: MessageHistoryManager, + sender: MessageSender, + onebot: Any, # OneBotClient + scheduler: TaskScheduler, + security: SecurityService, + command_dispatcher: Any = None, + ) -> None: + self.config = config + self.ai = ai + self.queue_manager = queue_manager + self.history_manager = history_manager + self.sender = sender + self.onebot = onebot + self.scheduler = scheduler + self.security = security + self.command_dispatcher = command_dispatcher + self.model_pool = ModelPoolService(ai, config, sender) + # batcher 由外部(handlers.py)创建并通过 set_batcher 注入;未注入时所有消息按单条流程直送。 + self._batcher: MessageBatcher | None = None + + def set_batcher(self, batcher: MessageBatcher | None) -> None: + """注入消息合并器;传 None 等同于禁用合并。""" + self._batcher = batcher + + @property + def batcher(self) -> MessageBatcher | None: + return self._batcher + + async def _send_image(self, tid: int, mtype: str, path: str) -> None: + """发送图片或语音消息到群聊或私聊""" + import os + + if not os.path.exists(path): + return + file_uri = Path(path).resolve().as_uri() + ext = os.path.splitext(path)[1].lower() + if ext in [".jpg", ".jpeg", ".png", ".gif", ".bmp", ".webp"]: + msg = f"[CQ:image,file={file_uri}]" + elif ext in [".mp3", ".wav", ".ogg", ".flac", ".m4a", ".aac"]: + msg = f"[CQ:record,file={file_uri}]" + else: + return + + try: + if mtype == "group": + await self.sender.send_group_message(tid, msg, auto_history=False) + elif mtype == "private": + await self.sender.send_private_message(tid, msg, auto_history=False) + except Exception: + logger.exception("发送媒体文件失败") diff --git a/src/Undefined/services/coordinator/background.py b/src/Undefined/services/coordinator/background.py new file mode 100644 index 00000000..b6c8cefe --- /dev/null +++ b/src/Undefined/services/coordinator/background.py @@ -0,0 +1,250 @@ +"""后台 LLM 任务:stats 分析、队列调用与 Agent 介绍生成。""" + +from __future__ import annotations + + +import logging +from typing import TYPE_CHECKING, Any, cast + +from Undefined.services.queue_manager import QUEUE_LANE_BACKGROUND +from Undefined.utils.resources import read_text_resource + +if TYPE_CHECKING: + from Undefined.ai import AIClient + from Undefined.config import Config + from Undefined.services.command import CommandDispatcher + +logger = logging.getLogger(__name__) + + +_STATS_ANALYSIS_PROMPT_PATH = "res/prompts/stats_analysis.txt" +_STATS_ANALYSIS_FALLBACK_PROMPT = ( + "你是一位专业的数据分析师。请根据以下 Token 使用统计数据提供分析:\n\n" + "{data_summary}\n\n" + "请从整体概况、趋势、模型效率、成本结构、异常点和优化建议进行总结," + "语言简洁,建议可执行。" +) + + +class BackgroundMixin: + """队列分发入口与后台 LLM 任务执行。""" + + if TYPE_CHECKING: + ai: AIClient + config: Config + command_dispatcher: CommandDispatcher + + async def _execute_auto_reply(self, request: dict[str, Any]) -> None: ... + async def _execute_private_reply(self, request: dict[str, Any]) -> None: ... + + async def execute_reply(self, request: dict[str, Any]) -> None: + """执行排队中的回复请求(由 QueueManager 分发调用) + + 参数: + request: 包含请求类型和必要元数据的请求字典 + """ + req_type = request.get("type", "unknown") + logger.debug("[执行请求] type=%s keys=%s", req_type, list(request.keys())) + batch_token = request.get("_message_batcher_token") + # 投机 pre-fire 被新消息 cancel 后,coordinator 在真正执行前跳过旧 token + if bool(getattr(batch_token, "cancelled", False)): + logger.info( + "[MessageBatcher] 跳过已取消的投机请求: type=%s scope=%s sender=%s batch=%s", + req_type, + getattr(batch_token, "scope", ""), + getattr(batch_token, "sender_id", ""), + getattr(batch_token, "batch_id", ""), + ) + return + if req_type == "auto_reply": + await self._execute_auto_reply(request) + elif req_type == "private_reply": + await self._execute_private_reply(request) + elif req_type == "stats_analysis": + await self._execute_stats_analysis(request) + elif req_type == "agent_intro_generation": + await self._execute_agent_intro_generation(request) + elif req_type in {"queued_llm_call", "background_llm_call"}: + await self._execute_queued_llm_call(request) + + async def _execute_stats_analysis(self, request: dict[str, Any]) -> None: + """执行 stats 命令的 AI 分析""" + group_id = request["group_id"] + request_id = request.get("request_id") + data_summary = request.get("data_summary", "") + + if not request_id: + logger.warning("[统计分析] 缺少 request_id,群=%s", group_id) + return + try: + prompt_template = _STATS_ANALYSIS_FALLBACK_PROMPT + try: + loaded_prompt = read_text_resource(_STATS_ANALYSIS_PROMPT_PATH).strip() + if loaded_prompt: + prompt_template = loaded_prompt + except Exception as exc: + logger.warning("[统计分析] 读取提示词失败,使用内置模板: %s", exc) + + if "{data_summary}" not in prompt_template: + logger.warning( + "[统计分析] 提示词缺少 {data_summary} 占位符,自动追加", + ) + prompt_template = f"{prompt_template}\n\n{{data_summary}}" + + safe_data_summary = str(data_summary).strip() or "暂无统计数据摘要" + try: + full_prompt = prompt_template.format(data_summary=safe_data_summary) + except Exception as exc: + logger.warning("[统计分析] 提示词渲染失败,使用回退模板: %s", exc) + full_prompt = _STATS_ANALYSIS_FALLBACK_PROMPT.format( + data_summary=safe_data_summary + ) + + messages = [ + {"role": "system", "content": "你是一位专业的数据分析师。"}, + {"role": "user", "content": full_prompt}, + ] + + result = await self.ai.submit_queued_llm_call( + model_config=self.config.chat_model, + messages=messages, + max_tokens=2048, + call_type="stats_analysis", + queue_lane=request.get("_queue_lane"), + ) + + choices = result.get("choices", [{}]) + if choices: + content = choices[0].get("message", {}).get("content", "") + analysis = content.strip() + else: + analysis = "AI 分析未能生成结果" + + if not analysis: + analysis = "AI 分析结果为空,建议稍后重试。" + + logger.info( + "[统计分析] 分析完成: group=%s length=%s request_id=%s", + group_id, + len(analysis), + request_id, + ) + + if self.command_dispatcher: + self.command_dispatcher.set_stats_analysis_result( + group_id, request_id, analysis + ) + + except Exception as exc: + logger.exception("[统计分析] AI 分析失败: %s", exc) + if self.command_dispatcher: + self.command_dispatcher.set_stats_analysis_result( + group_id, request_id, "" + ) + + async def _execute_queued_llm_call(self, request: dict[str, Any]) -> None: + """执行队列中的 LLM 子请求。""" + request_id = request.get("request_id", "") + retry_count = int(request.get("_retry_count", 0) or 0) + queue_lane = str(request.get("_queue_lane") or QUEUE_LANE_BACKGROUND) + call_type = str(request.get("call_type", "background") or "background") + try: + max_tokens_raw = request.get("max_tokens") or getattr( + request["model_config"], "max_tokens", 4096 + ) + max_tokens = int(max_tokens_raw) if max_tokens_raw is not None else 4096 + result = await self.ai.request_model( + model_config=request["model_config"], + messages=request["messages"], + tools=request.get("tools"), + tool_choice=request.get("tool_choice", "auto"), + call_type=call_type, + max_tokens=max_tokens, + transport_state=request.get("transport_state"), + ) + self.ai.set_llm_call_result(request_id, result) + if retry_count > 0: + logger.info( + "[queued_llm_retry_success] request_id=%s call_type=%s model=%s lane=%s retry=%s", + request_id, + call_type, + getattr(request["model_config"], "model_name", "default"), + queue_lane, + retry_count, + ) + except Exception as exc: + retry_count = request.get("_retry_count", 0) + if retry_count >= self.config.ai_request_max_retries: + self.ai.set_llm_call_result(request_id, exc) + raise + + async def _execute_agent_intro_generation(self, request: dict[str, Any]) -> None: + """执行 Agent 自我介绍生成请求""" + request_id = request.get("request_id") + agent_name = request.get("agent_name") + + if not request_id or not agent_name: + logger.warning( + "[Agent介绍生成] 缺少必要参数: request_id=%s agent_name=%s", + request_id, + agent_name, + ) + return + + try: + from Undefined.skills.agents.intro_generator import AgentIntroGenerator + + agent_intro_generator = self.ai._agent_intro_generator + if not isinstance(agent_intro_generator, AgentIntroGenerator): + logger.error("[Agent介绍生成] 无法获取 AgentIntroGenerator 实例") + return + + ( + system_prompt, + user_prompt, + ) = await agent_intro_generator.get_intro_prompt_and_context(agent_name) + + messages = [ + {"role": "system", "content": system_prompt or "你是一位智能助手。"}, + {"role": "user", "content": user_prompt}, + ] + + result = await self.ai.submit_queued_llm_call( + model_config=self.ai.agent_config, + messages=messages, + max_tokens=agent_intro_generator.config.max_tokens, + call_type=f"agent:{agent_name}", + queue_lane=request.get("_queue_lane"), + ) + + choices = result.get("choices", [{}]) + if choices: + content = choices[0].get("message", {}).get("content", "") + generated_content = content.strip() + else: + generated_content = "" + + logger.info( + "[Agent介绍生成] 生成完成: agent=%s length=%s request_id=%s", + agent_name, + len(generated_content), + request_id, + ) + + agent_intro_generator.set_intro_generation_result( + request_id, generated_content if generated_content else None + ) + + except Exception as exc: + logger.exception( + "[Agent介绍生成] 生成失败: agent=%s error=%s", + agent_name, + exc, + ) + try: + agent_intro_generator = cast( + AgentIntroGenerator, self.ai._agent_intro_generator + ) + agent_intro_generator.set_intro_generation_result(request_id, None) + except Exception: + pass diff --git a/src/Undefined/services/coordinator/batching.py b/src/Undefined/services/coordinator/batching.py new file mode 100644 index 00000000..d3da2c63 --- /dev/null +++ b/src/Undefined/services/coordinator/batching.py @@ -0,0 +1,174 @@ +"""消息合并、分组 prompt 构建与队列投递。""" + +from __future__ import annotations + + +import logging +from typing import TYPE_CHECKING, Any + +from Undefined.services.coordinator.group import _GROUP_STRATEGY_FOOTER +from Undefined.services.coordinator.private import _PRIVATE_STRATEGY_FOOTER +from Undefined.services.message_batcher import BufferedMessage + +if TYPE_CHECKING: + from Undefined.config import Config + from Undefined.services.message_batcher import BufferedMessage as _BufferedMessage + from Undefined.services.model_pool import ModelPoolService + from Undefined.services.queue_manager import QueueManager + +logger = logging.getLogger(__name__) + + +# BatchingMixin:MessageBatcher 回调、合并 prompt 与队列路由 +class BatchingMixin: + """MessageBatcher 回调、合并 prompt 与队列路由。""" + + if TYPE_CHECKING: + config: Config + queue_manager: QueueManager + model_pool: ModelPoolService + + def _format_group_message_segment(self, item: _BufferedMessage) -> str: ... + def _format_private_message_segment(self, item: _BufferedMessage) -> str: ... + + async def handle_batched_dispatch(self, items: list[BufferedMessage]) -> None: + """:class:`MessageBatcher` 的 flush_callback:把一批消息组装为单次请求并入队。""" + if not items: + return + await self._dispatch_grouped_request(items) + + @staticmethod + def _build_continuous_messages_note(items: list[BufferedMessage]) -> str: + """生成"连续消息说明"段。仅在 ``len(items) >= 2`` 时使用。""" + count = len(items) + first_t = items[0].arrival_time + last_t = items[-1].arrival_time + span = max(0.0, last_t - first_t) + return ( + f"\n\n 【连续消息说明】以上 {count} 条 是同一用户在约 " + f"{span:.1f} 秒内连续发送的消息(按时间先后排列),代表本轮要回应的全部输入:\n" + f" - 这些 共同构成【当前输入批次】,不要把同批前几条误判为历史旧任务;" + f"批次之外的历史消息仍只作为背景,不能回溯拾荒\n" + f" - 先识别每条的意图,分清是【独立请求】还是【对前一条的修正/否定/补充/打断】\n" + f' · 若是【多个独立的不同意图/问题】(如"先帮我查 A,再翻译 B")' + f" → 每个都要回应,不要遗漏;与平时一样,可以多次 send_message 自然分发\n" + f' · 若后发是【对前发的修正/否定/补充/打断】(如"画猫" → "改成狗")' + f" → 以最后一次明确意图为准,旧的不再执行,可简短说明已采纳更新\n" + f' · 拿不准时偏向"独立请求",宁多勿漏\n' + f" - 整批在本轮一次性处理完即可,不要为同一意图重复输出(不要" + f'"中间一波、结尾再来一波"重复相同回复)\n' + f" - history 中若出现与当前轮 相同的条目,视为同一来源,不要重复处理" + ) + + def _build_grouped_prompt(self, items: list[BufferedMessage]) -> str: + """根据 BufferedMessage 列表构造合并后的完整 prompt。""" + if not items: + return "" + is_private = items[0].is_private + # prefix:拍一拍优先;否则任一 @bot + any_poke = any(it.is_poke for it in items) + any_at_bot = any(it.is_at_bot for it in items) + if any_poke: + prefix = "(用户拍了拍你) " + elif any_at_bot: + prefix = "(用户 @ 了你) " + else: + prefix = "" + + if is_private: + segments = [self._format_private_message_segment(it) for it in items] + else: + segments = [self._format_group_message_segment(it) for it in items] + body = prefix + "\n".join(segments) + if len(items) >= 2: + body += self._build_continuous_messages_note(items) + body += _GROUP_STRATEGY_FOOTER if not is_private else _PRIVATE_STRATEGY_FOOTER + return body + + async def _dispatch_grouped_request(self, items: list[BufferedMessage]) -> None: + """根据一组 BufferedMessage 决定优先级、构造 prompt 并入队。 + + 既是单条直送路径的统一出口,也是 :class:`MessageBatcher` 的 flush_callback。 + """ + if not items: + return + first = items[0] + last = items[-1] + full_question = self._build_grouped_prompt(items) + any_poke = any(it.is_poke for it in items) + any_at_bot = any(it.is_at_bot for it in items) + + if first.is_private: + user_id = first.sender_id + request_data: dict[str, Any] = { + "type": "private_reply", + "user_id": user_id, + "sender_name": first.sender_name, + "text": last.text, + "full_question": full_question, + "trigger_message_id": last.trigger_message_id, + "batched_count": len(items), + } + if first.batch_token is not None: + request_data["_message_batcher_token"] = first.batch_token + effective_config = self.model_pool.select_chat_config( + self.config.chat_model, user_id=user_id + ) + request_data["selected_model_name"] = effective_config.model_name + logger.debug( + "[私聊回复] full_question_len=%s user=%s batched=%s", + len(full_question), + user_id, + len(items), + ) + if user_id == self.config.superadmin_qq: + # 私聊超管 → 最高优先级 superadmin lane + await self.queue_manager.add_superadmin_request( + request_data, model_name=effective_config.model_name + ) + else: + await self.queue_manager.add_private_request( + request_data, model_name=effective_config.model_name + ) + return + + # 群聊:按 sender 身份与 @bot 状态选择 4 级 lane 之一 + group_id = first.group_id or 0 + sender_id = first.sender_id + request_data = { + "type": "auto_reply", + "group_id": group_id, + "sender_id": sender_id, + "sender_name": first.sender_name, + "group_name": first.group_name, + "text": last.text, + "full_question": full_question, + "is_at_bot": any_at_bot, + "trigger_message_id": last.trigger_message_id, + "batched_count": len(items), + } + if first.batch_token is not None: + request_data["_message_batcher_token"] = first.batch_token + logger.debug( + "[自动回复] full_question_len=%s group=%s sender=%s batched=%s", + len(full_question), + group_id, + sender_id, + len(items), + ) + if sender_id == self.config.superadmin_qq: + logger.info("[AI] 投递至群聊超级管理员队列 (batched=%s)", len(items)) + await self.queue_manager.add_group_superadmin_request( + request_data, model_name=self.config.chat_model.model_name + ) + elif any_at_bot: + trigger = "拍一拍" if any_poke else "@机器人" + logger.info("[AI] 触发原因: %s (batched=%s)", trigger, len(items)) + await self.queue_manager.add_group_mention_request( + request_data, model_name=self.config.chat_model.model_name + ) + else: + logger.info("[AI] 投递至普通请求队列 (batched=%s)", len(items)) + await self.queue_manager.add_group_normal_request( + request_data, model_name=self.config.chat_model.model_name + ) diff --git a/src/Undefined/services/coordinator/group.py b/src/Undefined/services/coordinator/group.py new file mode 100644 index 00000000..0dce633b --- /dev/null +++ b/src/Undefined/services/coordinator/group.py @@ -0,0 +1,405 @@ +"""群聊自动回复与 prompt 构建。""" + +from __future__ import annotations + + +import asyncio +import logging +import time +from datetime import datetime +from typing import TYPE_CHECKING, Any, Optional + +from Undefined.attachments import attachment_refs_to_xml +from Undefined.context import RequestContext +from Undefined.context_resource_registry import collect_context_resources +from Undefined.render import render_html_to_image, render_markdown_to_html +from Undefined.services.message_batcher import BufferedMessage, make_scope +from Undefined.utils.recent_messages import get_recent_messages_prefer_local +from Undefined.utils.xml import escape_xml_attr, escape_xml_text + +if TYPE_CHECKING: + from Undefined.config import Config + from Undefined.services.message_batcher import BufferedMessage + from Undefined.services.security import SecurityService + from Undefined.utils.history import MessageHistoryManager + from Undefined.utils.scheduler import TaskScheduler + from Undefined.utils.sender import MessageSender + +logger = logging.getLogger(__name__) + + +_GROUP_STRATEGY_FOOTER = """ + + 【回复策略 - 更克制,纯表情包才前置检索】 + 1. 如果用户 @ 了你或拍了拍你 → 【必须回复】 + 2. 如果消息中明确提到了你(根据上下文判断用户是否在叫你或维持对话流) → 【必须回复】 + 3. 如果问题明确涉及某个项目/代码/部署细节(用户明确点名或上下文明确指向) → 【酌情回复,必要时先查证再回答】 + 4. 其他技术问题 → 【酌情回复,直接按用户提到的对象回答,不要引入无关的项目名/工具名作背景】 + 5. 先判断当前输入批次(无连续消息说明时就是最后一条消息)是不是在对你说: + - 如果明显是在和别人说话 → 【不要回复】 + - 如果你不能确定是不是在和你说话 → 【默认不回复】 + - 只有明确在和你说,或多人公开讨论且对话明显开放时,才进入下一步 + 6. 群聊里的主动参与只保留给公开、开放的技术或项目讨论: + - 只在多人公开讨论代码、AI、开发工具、项目进展、技术 bug 等,且不是别人之间定向交流时,才可以【极低频参与】 + - 默认更倾向不参与;不要长篇大论,一两句点到为止;如果别人已经在深入讨论且不需要你,保持沉默 + - 轻松互动、玩梗、吐槽本身不构成参与许可;只有在你已经决定要回复,且本轮明确是纯表情包/纯反应图时,才优先考虑表情包表达 + 7. 对于已经决定要回复的场景(包括被@、被拍一拍、轻量答疑,以及少量符合条件的主动参与): + - 只有明确纯表情包回复才先检索表情包,再用 memes.send_meme_by_uid 单独发一条图片消息 + - 其他需要文字承接、解释、答疑、推进任务、确认操作或表达具体态度的场景,第一轮必须优先把必要文字回复做好并调用 send_message + - 如果确实还想补表情包,把 memes.search_memes 和 memes.send_meme_by_uid 放到文字发送后的后续响应轮次,不要阻塞首条文字回复 + - 不要发送任何敷衍消息(如'懒得掺和'、'哦'等);不想回复就直接调用 end + - 严肃、任务型、高信息密度场景少发表情包,避免打断信息传递 + - 绝不要刷屏、绝不要每条都回 + 8. 对于本来就会回复的场景(私聊、被拍一拍、被@、轻量答疑): + - 如果表情包能自然增强语气、缓和语气或让表达更像真人,也只能作为后续可选补充 + - 但不要为了发表情包而牺牲信息传递;信息密度优先时仍以文字为主 + + 简单说:像个极度安静的群友。主动插话只留给公开、开放的技术或项目讨论;明显对别人说或拿不准时就闭嘴。已经决定要回复时,除非明确是纯表情包回复,否则先把文字回复做好,表情包最后再搜。""" + + +class GroupReplyMixin: + """群聊自动回复、注入防御与群聊 prompt 格式化。""" + + if TYPE_CHECKING: + ai: Any + config: Config + history_manager: MessageHistoryManager + onebot: Any + scheduler: TaskScheduler + security: SecurityService + sender: MessageSender + + async def _dispatch_grouped_request( + self, items: list[BufferedMessage] + ) -> None: ... + async def _send_image(self, tid: int, mtype: str, path: str) -> None: ... + + async def handle_auto_reply( + self, + group_id: int, + sender_id: int, + text: str, + message_content: list[dict[str, Any]], + attachments: list[dict[str, str]] | None = None, + is_poke: bool = False, + sender_name: str = "未知用户", + group_name: str = "未知群聊", + sender_role: str = "member", + sender_title: str = "", + sender_level: str = "", + trigger_message_id: int | None = None, + is_fake_at: bool = False, + ) -> None: + """群聊自动回复入口:根据消息内容、命中情况和安全检测决定是否回复""" + is_at_bot = is_poke or is_fake_at or self._is_at_bot(message_content) + logger.debug( + "[自动回复] group=%s sender=%s at_bot=%s fake_at=%s text_len=%s", + group_id, + sender_id, + is_at_bot, + is_fake_at, + len(text), + ) + + if sender_id != self.config.superadmin_qq: + logger.debug(f"[Security] 注入检测: group={group_id}, user={sender_id}") + if await self.security.detect_injection(text, message_content): + logger.warning( + f"[Security] 检测到注入攻击: group={group_id}, user={sender_id}" + ) + await self.history_manager.modify_last_group_message( + group_id, sender_id, "<这句话检测到用户进行注入,已删除>" + ) + if is_at_bot: + await self._handle_injection_response( + group_id, text, sender_id=sender_id + ) + return + + scope = make_scope(group_id=group_id) + item = BufferedMessage( + scope=scope, + sender_id=sender_id, + text=text, + message_content=list(message_content), + attachments=list(attachments or []), + sender_name=sender_name, + arrival_time=time.time(), + is_private=False, + trigger_message_id=trigger_message_id, + is_poke=is_poke, + is_at_bot=is_at_bot, + is_fake_at=is_fake_at, + group_id=group_id, + group_name=group_name, + sender_role=sender_role, + sender_title=sender_title, + sender_level=sender_level, + ) + + # 路由:拍一拍 → 永远旁路;否则按 batcher 启用情况与 @bot 处理规则决定 + if is_poke: + await self._dispatch_grouped_request([item]) + return + + batcher = getattr(self, "_batcher", None) + if batcher is not None and batcher.is_enabled_for(is_group=True): + if is_at_bot and batcher.has_buffer(scope, sender_id): + # 已有 buffer 时再来一条 @bot:单独立即处理,不打断现有 buffer + logger.info( + "[自动回复] batch 内 @bot 旁路立即处理: group=%s sender=%s", + group_id, + sender_id, + ) + await self._dispatch_grouped_request([item]) + return + await batcher.submit(item) + return + + await self._dispatch_grouped_request([item]) + + async def _execute_auto_reply(self, request: dict[str, Any]) -> None: + group_id = request["group_id"] + sender_id = request["sender_id"] + sender_name = str(request.get("sender_name") or "未知用户") + group_name = str(request.get("group_name") or "未知群聊") + full_question = request["full_question"] + trigger_message_id = request.get("trigger_message_id") + # 用于向 batcher 注册 inflight 任务(仅当本请求源自合并桶时生效) + batcher_scope: str | None = make_scope(group_id=group_id) if group_id else None + + async with RequestContext( + request_type="group", + group_id=group_id, + sender_id=sender_id, + user_id=sender_id, + ) as ctx: + + async def send_msg_cb(message: str, reply_to: int | None = None) -> None: + await self.sender.send_group_message( + group_id, + message, + reply_to=reply_to, + history_message=message, + ) + + async def get_recent_cb( + chat_id: str, msg_type: str, start: int, end: int + ) -> list[dict[str, Any]]: + return await get_recent_messages_prefer_local( + chat_id=chat_id, + msg_type=msg_type, + start=start, + end=end, + onebot_client=self.onebot, + history_manager=self.history_manager, + bot_qq=self.config.bot_qq, + attachment_registry=getattr(self.ai, "attachment_registry", None), + group_name_hint=group_name, + ) + + async def send_private_cb( + uid: int, msg: str, reply_to: int | None = None + ) -> None: + await self.sender.send_private_message(uid, msg, reply_to=reply_to) + + async def send_img_cb(tid: int, mtype: str, path: str) -> None: + await self._send_image(tid, mtype, path) + + async def send_like_cb(uid: int, times: int = 1) -> None: + await self.onebot.send_like(uid, times) + + ai_client = self.ai + memory_storage = self.ai.memory_storage + runtime_config = self.ai.runtime_config + sender = self.sender + history_manager = self.history_manager + onebot_client = self.onebot + scheduler = self.scheduler + send_message_callback = send_msg_cb + get_recent_messages_callback = get_recent_cb + get_image_url_callback = self.onebot.get_image + get_forward_msg_callback = self.onebot.get_forward_msg + send_like_callback = send_like_cb + send_private_message_callback = send_private_cb + send_image_callback = send_img_cb + resource_vars = dict(globals()) + resource_vars.update(locals()) + resources = collect_context_resources(resource_vars) + for key, value in resources.items(): + if value is not None: + ctx.set_resource(key, value) + if trigger_message_id is not None: + ctx.set_resource("trigger_message_id", trigger_message_id) + if request.get("_queue_lane"): + ctx.set_resource("queue_lane", request.get("_queue_lane")) + logger.debug( + "[上下文资源] group=%s keys=%s", + group_id, + ", ".join(sorted(resources.keys())), + ) + + try: + # 把当前 task 注册到 batcher,使其有能力在新消息到达时取消本次 LLM 调用 + batcher = getattr(self, "_batcher", None) + current_task = asyncio.current_task() + registered_task: asyncio.Task[Any] | None = None + if ( + batcher is not None + and batcher_scope is not None + and current_task is not None + ): + batcher.register_inflight( + batcher_scope, sender_id, current_task, ctx + ) + registered_task = current_task + try: + await self.ai.ask( + full_question, + send_message_callback=send_msg_cb, + get_recent_messages_callback=get_recent_cb, + get_image_url_callback=self.onebot.get_image, + get_forward_msg_callback=self.onebot.get_forward_msg, + send_like_callback=send_like_cb, + sender=self.sender, + history_manager=self.history_manager, + onebot_client=self.onebot, + scheduler=self.scheduler, + extra_context={ + "render_html_to_image": render_html_to_image, + "render_markdown_to_html": render_markdown_to_html, + "group_id": group_id, + "user_id": sender_id, + "is_at_bot": bool(request.get("is_at_bot", False)), + "sender_name": sender_name, + "group_name": group_name, + }, + ) + finally: + if ( + batcher is not None + and batcher_scope is not None + and registered_task is not None + ): + batcher.unregister_inflight( + batcher_scope, sender_id, registered_task + ) + except asyncio.CancelledError: + # 投机预发送被新消息抢占取消:不写错误日志、不重试 + logger.info( + "[自动回复] 任务被取消(投机抢占): group=%s sender=%s", + group_id, + sender_id, + ) + raise + except Exception: + logger.exception("自动回复执行出错") + raise + + def _is_at_bot(self, content: list[dict[str, Any]]) -> bool: + """检查消息内容中是否包含对机器人的 @ 提问""" + for seg in content: + if seg.get("type") == "at" and str( + seg.get("data", {}).get("qq", "") + ) == str(self.config.bot_qq): + return True + return False + + async def _handle_injection_response( + self, + tid: int, + text: str, + is_private: bool = False, + sender_id: Optional[int] = None, + ) -> None: + """当检测到注入攻击时,生成并发送特定的防御性回复""" + reply = await self.security.generate_injection_response(text) + if is_private: + await self.sender.send_private_message(tid, reply, auto_history=False) + await self.history_manager.add_private_message( + tid, "<对注入消息的回复>", "Bot", "Bot" + ) + else: + msg = f"[@{sender_id}] {reply}" if sender_id else reply + await self.sender.send_group_message(tid, msg, auto_history=False) + await self.history_manager.add_group_message( + tid, self.config.bot_qq, "<对注入消息的回复>", "Bot", "" + ) + + def _format_group_message_segment(self, item: BufferedMessage) -> str: + """格式化群聊单条 ```` 块。""" + time_str = datetime.fromtimestamp(item.arrival_time).strftime( + "%Y-%m-%d %H:%M:%S" + ) + group_name = item.group_name or "未知群聊" + location = group_name if group_name.endswith("群") else f"{group_name}群" + safe_name = escape_xml_attr(item.sender_name or "未知用户") + safe_uid = escape_xml_attr(item.sender_id) + safe_gid = escape_xml_attr(item.group_id or 0) + safe_gname = escape_xml_attr(group_name) + safe_loc = escape_xml_attr(location) + safe_role = escape_xml_attr(item.sender_role or "member") + safe_title = escape_xml_attr(item.sender_title or "") + safe_time = escape_xml_attr(time_str) + safe_text = escape_xml_text(item.text) + message_id_attr = "" + if item.trigger_message_id is not None: + message_id_attr = ( + f' message_id="{escape_xml_attr(item.trigger_message_id)}"' + ) + level_attr = ( + f' level="{escape_xml_attr(item.sender_level)}"' + if item.sender_level + else "" + ) + attachment_xml = ( + f"\n{attachment_refs_to_xml(item.attachments)}" if item.attachments else "" + ) + return ( + f'\n' + f" {safe_text}{attachment_xml}\n" + f" " + ) + + def _build_prompt( + self, + prefix: str, + name: str, + uid: int, + gid: int, + gname: str, + loc: str, + role: str, + title: str, + time_str: str, + text: str, + attachments: list[dict[str, str]] | None = None, + message_id: int | None = None, + level: str = "", + ) -> str: + """构建最终发送给 AI 的结构化 XML 消息 Prompt + + 包含回复策略提示、用户信息和原始文本内容。 + """ + safe_name = escape_xml_attr(name) + safe_uid = escape_xml_attr(uid) + safe_gid = escape_xml_attr(gid) + safe_gname = escape_xml_attr(gname) + safe_loc = escape_xml_attr(loc) + safe_role = escape_xml_attr(role) + safe_title = escape_xml_attr(title) + safe_time = escape_xml_attr(time_str) + safe_text = escape_xml_text(text) + message_id_attr = "" + if message_id is not None: + message_id_attr = f' message_id="{escape_xml_attr(message_id)}"' + level_attr = f' level="{escape_xml_attr(level)}"' if level else "" + attachment_xml = ( + f"\n{attachment_refs_to_xml(attachments)}" if attachments else "" + ) + return f"""{prefix} + {safe_text}{attachment_xml} + +{_GROUP_STRATEGY_FOOTER}""" diff --git a/src/Undefined/services/coordinator/private.py b/src/Undefined/services/coordinator/private.py new file mode 100644 index 00000000..e3a6ec74 --- /dev/null +++ b/src/Undefined/services/coordinator/private.py @@ -0,0 +1,282 @@ +"""私聊回复与私聊 prompt 格式化。""" + +from __future__ import annotations + + +import asyncio +import logging +import time +from datetime import datetime +from typing import TYPE_CHECKING, Any + +from Undefined.attachments import ( + build_attachment_scope, + dispatch_pending_file_sends, + render_message_with_pic_placeholders, + attachment_refs_to_xml, +) +from Undefined.context import RequestContext +from Undefined.context_resource_registry import collect_context_resources +from Undefined.render import render_html_to_image, render_markdown_to_html +from Undefined.services.message_batcher import BufferedMessage, make_scope +from Undefined.utils.recent_messages import get_recent_messages_prefer_local +from Undefined.utils.xml import escape_xml_attr, escape_xml_text + +if TYPE_CHECKING: + from Undefined.config import Config + from Undefined.services.message_batcher import BufferedMessage + from Undefined.services.security import SecurityService + from Undefined.utils.history import MessageHistoryManager + from Undefined.utils.scheduler import TaskScheduler + from Undefined.utils.sender import MessageSender + +logger = logging.getLogger(__name__) + + +_PRIVATE_STRATEGY_FOOTER = """ + +【私聊消息】 +这是私聊消息,用户专门来找你说话。你可以自由选择是否回复: +- 如果想回复,先调用 send_message 工具发送回复内容,然后调用 end 结束对话 +- 只有明确纯表情包回复时,才先用 memes.search_memes 查表情包,再用 memes.send_meme_by_uid 单独发图;其他场景先把文字回复做好,表情包最后再搜或不搜 +- 如果不想回复,直接调用 end 结束对话即可""" + + +class PrivateReplyMixin: + """私聊自动回复与私聊 prompt 格式化。""" + + if TYPE_CHECKING: + ai: Any + config: Config + history_manager: MessageHistoryManager + onebot: Any + scheduler: TaskScheduler + security: SecurityService + sender: MessageSender + + async def _dispatch_grouped_request( + self, items: list[BufferedMessage] + ) -> None: ... + async def _handle_injection_response( + self, + tid: int, + text: str, + is_private: bool = False, + sender_id: int | None = None, + ) -> None: ... + async def _send_image(self, tid: int, mtype: str, path: str) -> None: ... + + async def handle_private_reply( + self, + user_id: int, + text: str, + message_content: list[dict[str, Any]], + attachments: list[dict[str, str]] | None = None, + is_poke: bool = False, + sender_name: str = "未知用户", + trigger_message_id: int | None = None, + ) -> None: + """处理私聊消息入口,决定回复策略并进行安全检测""" + logger.debug("[私聊回复] user=%s text_len=%s", user_id, len(text)) + if user_id != self.config.superadmin_qq: + if await self.security.detect_injection(text, message_content): + logger.warning(f"[Security] 私聊注入攻击: user_id={user_id}") + await self.history_manager.modify_last_private_message( + user_id, "<这句话检测到用户进行注入,已删除>" + ) + await self._handle_injection_response(user_id, text, is_private=True) + return + + scope = make_scope(user_id=user_id) + item = BufferedMessage( + scope=scope, + sender_id=user_id, + text=text, + message_content=list(message_content), + attachments=list(attachments or []), + sender_name=sender_name, + arrival_time=time.time(), + is_private=True, + trigger_message_id=trigger_message_id, + is_poke=is_poke, + ) + + if is_poke: + # 拍一拍旁路 batcher,立即单条入队 + await self._dispatch_grouped_request([item]) + return + + batcher = getattr(self, "_batcher", None) + if batcher is not None and batcher.is_enabled_for(is_group=False): + await batcher.submit(item) + return + + await self._dispatch_grouped_request([item]) + + async def _execute_private_reply(self, request: dict[str, Any]) -> None: + user_id = request["user_id"] + sender_name = str(request.get("sender_name") or "未知用户") + full_question = request["full_question"] + trigger_message_id = request.get("trigger_message_id") + batcher_scope: str | None = make_scope(user_id=user_id) + + async with RequestContext( + request_type="private", + user_id=user_id, + sender_id=user_id, + ) as ctx: + + async def send_msg_cb(message: str, reply_to: int | None = None) -> None: + await self.sender.send_private_message( + user_id, message, reply_to=reply_to + ) + + async def get_recent_cb( + chat_id: str, msg_type: str, start: int, end: int + ) -> list[dict[str, Any]]: + return await get_recent_messages_prefer_local( + chat_id=chat_id, + msg_type=msg_type, + start=start, + end=end, + onebot_client=self.onebot, + history_manager=self.history_manager, + bot_qq=self.config.bot_qq, + attachment_registry=getattr(self.ai, "attachment_registry", None), + ) + + async def send_img_cb(tid: int, mtype: str, path: str) -> None: + await self._send_image(tid, mtype, path) + + async def send_like_cb(uid: int, times: int = 1) -> None: + await self.onebot.send_like(uid, times) + + async def send_private_cb( + uid: int, msg: str, reply_to: int | None = None + ) -> None: + await self.sender.send_private_message(uid, msg, reply_to=reply_to) + + ai_client = self.ai + memory_storage = self.ai.memory_storage + runtime_config = self.ai.runtime_config + sender = self.sender + history_manager = self.history_manager + onebot_client = self.onebot + scheduler = self.scheduler + send_message_callback = send_msg_cb + get_recent_messages_callback = get_recent_cb + get_image_url_callback = self.onebot.get_image + get_forward_msg_callback = self.onebot.get_forward_msg + send_like_callback = send_like_cb + send_private_message_callback = send_private_cb + send_image_callback = send_img_cb + resource_vars = dict(globals()) + resource_vars.update(locals()) + resources = collect_context_resources(resource_vars) + for key, value in resources.items(): + if value is not None: + ctx.set_resource(key, value) + if trigger_message_id is not None: + ctx.set_resource("trigger_message_id", trigger_message_id) + if request.get("_queue_lane"): + ctx.set_resource("queue_lane", request.get("_queue_lane")) + logger.debug( + "[上下文资源] private user=%s keys=%s", + user_id, + ", ".join(sorted(resources.keys())), + ) + + try: + batcher = getattr(self, "_batcher", None) + current_task = asyncio.current_task() + registered_task: asyncio.Task[Any] | None = None + if ( + batcher is not None + and batcher_scope is not None + and current_task is not None + ): + batcher.register_inflight(batcher_scope, user_id, current_task, ctx) + registered_task = current_task + try: + result = await self.ai.ask( + full_question, + send_message_callback=send_msg_cb, + get_recent_messages_callback=get_recent_cb, + get_image_url_callback=self.onebot.get_image, + get_forward_msg_callback=self.onebot.get_forward_msg, + send_like_callback=send_like_cb, + sender=self.sender, + history_manager=self.history_manager, + onebot_client=self.onebot, + scheduler=self.scheduler, + extra_context={ + "render_html_to_image": render_html_to_image, + "render_markdown_to_html": render_markdown_to_html, + "user_id": user_id, + "is_private_chat": True, + "sender_name": sender_name, + "selected_model_name": request.get("selected_model_name"), + }, + ) + finally: + if ( + batcher is not None + and batcher_scope is not None + and registered_task is not None + ): + batcher.unregister_inflight( + batcher_scope, user_id, registered_task + ) + if result: + scope_key = build_attachment_scope( + user_id=user_id, + request_type="private", + ) + rendered = await render_message_with_pic_placeholders( + str(result), + registry=self.ai.attachment_registry, + scope_key=scope_key, + strict=False, + ) + await self.sender.send_private_message( + user_id, + rendered.delivery_text, + history_message=rendered.history_text, + ) + await dispatch_pending_file_sends( + rendered, + sender=self.sender, + target_type="private", + target_id=user_id, + registry=self.ai.attachment_registry, + ) + except asyncio.CancelledError: + logger.info("[私聊回复] 任务被取消(投机抢占): user=%s", user_id) + raise + except Exception: + logger.exception("私聊回复执行出错") + raise + + def _format_private_message_segment(self, item: BufferedMessage) -> str: + """格式化私聊单条 ```` 块。""" + time_str = datetime.fromtimestamp(item.arrival_time).strftime( + "%Y-%m-%d %H:%M:%S" + ) + safe_name = escape_xml_attr(item.sender_name or "未知用户") + safe_uid = escape_xml_attr(item.sender_id) + safe_time = escape_xml_attr(time_str) + safe_text = escape_xml_text(item.text) + message_id_attr = "" + if item.trigger_message_id is not None: + message_id_attr = ( + f' message_id="{escape_xml_attr(item.trigger_message_id)}"' + ) + attachment_xml = ( + f"\n{attachment_refs_to_xml(item.attachments)}" if item.attachments else "" + ) + return ( + f'\n' + f" {safe_text}{attachment_xml}\n" + f" " + ) diff --git a/src/Undefined/services/message_batcher/__init__.py b/src/Undefined/services/message_batcher/__init__.py new file mode 100644 index 00000000..c9195f1f --- /dev/null +++ b/src/Undefined/services/message_batcher/__init__.py @@ -0,0 +1,48 @@ +"""同 sender 短时多消息合并器(MessageBatcher)。 + +核心目标:把同一个 sender 在短时间内连续发出的消息合并到同一轮 AI 调用, +让模型一次看到全部 ```` 块自行决定 "独立请求 / 修正 / 打断", +避免 N 条独立 LLM 调用造成的重复回复或行为打架。 + +时序:每个 (scope, sender_id) 桶内有两条独立的"静默计时器": + +- ``T1 = window_seconds`` —— "打字静默阈值"。静默达到 T1 视为用户写完, + 这一批 batch 结束。 +- ``T2 = pre_send_seconds`` —— "投机预发送阈值",要求严格小于 T1。 + 静默到 T2 时**先把当前 batch 提前发给 LLM 抢时间**(speculative pre-fire), + 但 batch 尚未结束;T1 才决定结束。 + +新消息到来: + +- 若桶处于 ``TYPING``(尚未 pre-fire):append 后重置 T1/T2。 +- 若桶处于 ``SPECULATING``(已 pre-fire,请求已入队或 inflight 在跑): + - 检查 inflight 是否已经 "向用户发出过任何消息" + (来自 ``RequestContext.get_resource("message_sent_this_turn")``)。 + - inflight 尚未发消息 → 调 ``inflight_task.cancel()``,桶回到 TYPING; + 新消息照常 append 到原有 items 后面,T1/T2 重置。 + - inflight 已经发过消息且 ``allow_cancel_after_send=False``(默认安全)→ + 保留旧 batch 让其自然走完,新消息开新 batch(即清空当前桶后立即重新作为首条入桶)。 + - inflight 已经发过消息但开关 = True → 仍 cancel(可能造成重复发送,仅极端场景)。 + +兼容回退:当 ``pre_send_seconds <= 0`` 或 ``>= window_seconds`` 时投机模式关闭, +退化为旧版 "T1 静默到期才发车" 的行为。 +""" + +# 同 sender 短时合并:T1 结束 batch,T2 投机预发送 +from Undefined.services.message_batcher.scheduler import MessageBatcher +from Undefined.services.message_batcher.state import ( + BatchDispatchToken, + BatchPhase, + BufferedMessage, + FlushCallback, + make_scope, +) + +__all__ = [ + "BatchDispatchToken", + "BatchPhase", + "BufferedMessage", + "FlushCallback", + "MessageBatcher", + "make_scope", +] diff --git a/src/Undefined/services/message_batcher/scheduler.py b/src/Undefined/services/message_batcher/scheduler.py new file mode 100644 index 00000000..45b0cf93 --- /dev/null +++ b/src/Undefined/services/message_batcher/scheduler.py @@ -0,0 +1,700 @@ +"""MessageBatcher 调度与 timer 逻辑。""" + +from __future__ import annotations + +# 同 sender 短时合并:T1 结束 batch,T2 投机预发送 + +import asyncio +import logging +import time +from typing import Any + +from Undefined.config.models import MessageBatcherConfig +from Undefined.services.message_batcher.state import ( + BatchDispatchToken, + BatchPhase, + BufferedMessage, + FlushCallback, + _BatchState, + _InflightInfo, +) +from Undefined.utils.coerce import was_message_sent + +logger = logging.getLogger(__name__) + + +class MessageBatcher: + """同 sender 短时合并器(含 T2 投机预发送)。""" + + def __init__( + self, + config: MessageBatcherConfig, + flush_callback: FlushCallback, + ) -> None: + self._config = config + self._flush_callback = flush_callback + self._buckets: dict[tuple[str, int], _BatchState] = {} + self._flush_failure_counts: dict[tuple[str, int], int] = {} + self._lock = asyncio.Lock() + # 持有 timer 触发后创建的 flush task 强引用,避免被 GC(asyncio 文档要求) + self._pending_tasks: set[asyncio.Task[Any]] = set() + self._next_batch_id = 0 + self._shutdown = False + + # ------------------------------------------------------------------ public + + def update_config(self, config: MessageBatcherConfig) -> None: + """配置热更新。""" + self._config = config + logger.info( + "[MessageBatcher] 配置已更新: enabled=%s window=%.2fs pre_send=%.2fs " + "strategy=%s max_window=%.2fs max_messages=%s group=%s private=%s " + "allow_cancel_after_send=%s", + config.enabled, + config.window_seconds, + config.pre_send_seconds, + config.strategy, + config.max_window_seconds, + config.max_messages_per_batch, + config.group_enabled, + config.private_enabled, + config.allow_cancel_after_send, + ) + + @property + def config(self) -> MessageBatcherConfig: + return self._config + + def is_enabled_for(self, *, is_group: bool) -> bool: + cfg = self._config + if not cfg.enabled or cfg.window_seconds <= 0: + return False + return cfg.group_enabled if is_group else cfg.private_enabled + + # 立即触发 batch 发车 + def has_buffer(self, scope: str, sender_id: int) -> bool: + return (scope, sender_id) in self._buckets + + # 立即触发 batch 发车 + async def flush_sender(self, scope: str, sender_id: int) -> bool: + return await self._handle_t1((scope, sender_id), raise_on_failure=False) + + @property + def speculative_enabled(self) -> bool: + # 0 < pre_send < window 时启用 T2 投机预发送 + cfg = self._config + return 0 < cfg.pre_send_seconds < cfg.window_seconds + + # 提交消息进入 (scope,sender) 合并桶并重置 T1/T2 计时器 + async def submit(self, item: BufferedMessage) -> None: + """提交一条消息进入合并桶。 + + 新消息到来时的处理依赖当前桶 ``phase``,详见模块 docstring。 + """ + cfg = self._config + key = (item.scope, item.sender_id) + # 异步路径里只在锁内修改桶;invoke callback 在锁外执行 + immediate_fire_items: list[BufferedMessage] | None = None + + async with self._lock: + if self._shutdown: + logger.info( + "[MessageBatcher] 已进入关停模式,新消息立即发车: scope=%s sender=%s", + item.scope, + item.sender_id, + ) + immediate_fire_items = [item] + else: + now_mono = time.monotonic() + state = self._buckets.get(key) + + # === 阶段 1: 决定本条消息怎么进桶 === + if state is None: + # 全新桶 + state = _BatchState( + phase=BatchPhase.TYPING, + first_arrival_monotonic=now_mono, + dispatch_token=self._new_token(item.scope, item.sender_id), + ) + self._buckets[key] = state + state.items.append(item) + elif state.phase is BatchPhase.SPECULATING: + # 已 pre-fire,决定是否 cancel inflight + inflight = state.inflight + already_sent = ( + was_message_sent(inflight.request_context) + if inflight is not None + else False + ) + allow_cancel = (not already_sent) or cfg.allow_cancel_after_send + + if inflight is not None and allow_cancel: + logger.info( + "[MessageBatcher] 投机调用被新消息抢占取消: scope=%s sender=%s " + "already_sent=%s allow_cancel_after_send=%s", + item.scope, + item.sender_id, + already_sent, + cfg.allow_cancel_after_send, + ) + if state.dispatch_token is not None: + state.dispatch_token.cancel() + inflight.task.cancel() + state.inflight = None + state.phase = BatchPhase.TYPING + # 新消息追加到现有 items 后面 + state.items.append(item) + self._retokenize_locked(state, item.scope, item.sender_id) + elif inflight is None: + # inflight 尚未注册(coordinator 还没进入 execute_reply): + # 1) 若 flush task 仍在跑,先 cancel; + # 2) 若它已经把请求入队,则取消旧 token,execute_reply 入口会跳过旧请求。 + logger.info( + "[MessageBatcher] inflight 未注册,取消投机 token/flush task: " + "scope=%s sender=%s", + item.scope, + item.sender_id, + ) + if state.dispatch_token is not None: + state.dispatch_token.cancel() + if state.speculative_flush_task is not None: + state.speculative_flush_task.cancel() + state.speculative_flush_task = None + state.phase = BatchPhase.TYPING + state.items.append(item) + self._retokenize_locked(state, item.scope, item.sender_id) + else: + # 已发过消息且不允许取消:丢弃当前桶,新消息开新桶 + logger.info( + "[MessageBatcher] 投机调用已发出消息且不允许取消,新消息开新 batch: " + "scope=%s sender=%s", + item.scope, + item.sender_id, + ) + self._cancel_t1(state) + self._cancel_t2(state) + state.phase = BatchPhase.FINALIZING + # 旧桶让 inflight 自然结束;从 _buckets pop 以释放 key 给新 batch + self._buckets.pop(key, None) + # 新桶 + state = _BatchState( + phase=BatchPhase.TYPING, + first_arrival_monotonic=now_mono, + dispatch_token=self._new_token(item.scope, item.sender_id), + ) + self._buckets[key] = state + state.items.append(item) + elif state.phase is BatchPhase.FINALIZING: + # 极少见:T1 已到、inflight 未上报但 task 已不可控;当作新桶处理 + logger.warning( + "[MessageBatcher] 桶处于 FINALIZING 期间收到新消息,开新 batch: " + "scope=%s sender=%s", + item.scope, + item.sender_id, + ) + self._buckets.pop(key, None) + state = _BatchState( + phase=BatchPhase.TYPING, + first_arrival_monotonic=now_mono, + dispatch_token=self._new_token(item.scope, item.sender_id), + ) + self._buckets[key] = state + state.items.append(item) + else: # TYPING:直接 append + state.items.append(item) + + self._bind_items_to_token_locked(state) + + # === 阶段 2: 重置 T1/T2 timer === + self._cancel_t1(state) + self._cancel_t2(state) + + elapsed = now_mono - state.first_arrival_monotonic + unlimited_window = cfg.max_window_seconds <= 0 + remaining_max = ( + float("inf") + if unlimited_window + else cfg.max_window_seconds - elapsed + ) + + # 硬顶:max_messages_per_batch 立即发车(结束 batch) + if ( + cfg.max_messages_per_batch > 0 + and len(state.items) >= cfg.max_messages_per_batch + ): + logger.info( + "[MessageBatcher] 达到 max_messages_per_batch=%s 立即发车: " + "scope=%s sender=%s", + cfg.max_messages_per_batch, + item.scope, + item.sender_id, + ) + immediate_fire_items = self._pop_locked(key) + elif not unlimited_window and remaining_max <= 0: + logger.info( + "[MessageBatcher] 已超 max_window_seconds 硬顶 立即发车: " + "scope=%s sender=%s elapsed=%.2fs", + item.scope, + item.sender_id, + elapsed, + ) + immediate_fire_items = self._pop_locked(key) + else: + # T1 delay + # fixed:从首条到达时刻起算绝对 T1;extend:每条新消息重置为 window_seconds + if cfg.strategy == "fixed": + target = state.first_arrival_monotonic + cfg.window_seconds + t1_delay = max(0.0, target - now_mono) + else: # extend + t1_delay = cfg.window_seconds + if not unlimited_window: + t1_delay = min(t1_delay, remaining_max) + + loop = asyncio.get_running_loop() + state.t1_handle = loop.call_later( + max(0.0, t1_delay), self._on_t1_timer, key + ) + + # T2 delay(仅当投机启用,且本桶尚未 pre-fire 时设置) + if ( + self.speculative_enabled + and state.phase is BatchPhase.TYPING + and cfg.pre_send_seconds < t1_delay + ): + t2_delay = cfg.pre_send_seconds + state.t2_handle = loop.call_later( + max(0.0, t2_delay), self._on_t2_timer, key + ) + logger.debug( + "[MessageBatcher] 缓冲: scope=%s sender=%s count=%s " + "t1=%.2fs t2=%.2fs strategy=%s", + item.scope, + item.sender_id, + len(state.items), + t1_delay, + t2_delay, + cfg.strategy, + ) + else: + logger.debug( + "[MessageBatcher] 缓冲: scope=%s sender=%s count=%s " + "t1=%.2fs strategy=%s phase=%s", + item.scope, + item.sender_id, + len(state.items), + t1_delay, + cfg.strategy, + state.phase.value, + ) + + # 锁外执行 callback + if immediate_fire_items is not None: + success = await self._invoke_callback(immediate_fire_items) + if success: + self._flush_failure_counts.pop(key, None) + else: + await self._restore_items_after_failed_flush( + key, immediate_fire_items, schedule_retry=True + ) + + # ----------------------------------------------------------- inflight API + + def register_inflight( + self, + scope: str, + sender_id: int, + task: asyncio.Task[Any], + request_context: Any = None, + ) -> None: + """coordinator 在 ``execute_reply`` 开头上报 inflight LLM 任务。 + + 如果桶不存在或 phase 不是 SPECULATING,则忽略(说明这次 fire 不是投机的)。 + """ + key = (scope, sender_id) + state = self._buckets.get(key) + if state is None: + return + if state.phase is not BatchPhase.SPECULATING: + return + state.inflight = _InflightInfo(task=task, request_context=request_context) + logger.debug( + "[MessageBatcher] 注册 inflight 任务: scope=%s sender=%s", + scope, + sender_id, + ) + + # 注销 inflight 任务 + def unregister_inflight( + self, scope: str, sender_id: int, task: asyncio.Task[Any] + ) -> None: + """coordinator 在 ``execute_reply`` 结束(含异常/取消)时上报。""" + key = (scope, sender_id) + state = self._buckets.get(key) + if state is None: + return + if state.inflight is not None and state.inflight.task is not task: + logger.debug( + "[MessageBatcher] 忽略过期 inflight 注销: scope=%s sender=%s phase=%s", + scope, + sender_id, + state.phase.value, + ) + return + state.inflight = None + # 若 phase 是 SPECULATING 且 T1 已经 fire 过(FINALIZING 才 unregister), + # 此时 inflight 自然结束 → 桶已经在 _on_t1_timer 中弹出,无需再做事 + # 若仍在 SPECULATING(T1 未到):inflight 已结束但仍可能有新消息进来; + # 保持 SPECULATING,新消息会按 SPECULATING 分支处理(已发消息开新 batch / 未发追加) + logger.debug( + "[MessageBatcher] 注销 inflight 任务: scope=%s sender=%s phase=%s", + scope, + sender_id, + state.phase.value, + ) + + # ---------------------------------------------------------------- timers + + def _cancel_t1(self, state: _BatchState) -> None: + if state.t1_handle is not None: + state.t1_handle.cancel() + state.t1_handle = None + + def _cancel_t2(self, state: _BatchState) -> None: + if state.t2_handle is not None: + state.t2_handle.cancel() + state.t2_handle = None + + def _new_token(self, scope: str, sender_id: int) -> BatchDispatchToken: + self._next_batch_id += 1 + return BatchDispatchToken( + scope=scope, + sender_id=sender_id, + batch_id=self._next_batch_id, + ) + + def _retokenize_locked( + self, state: _BatchState, scope: str, sender_id: int + ) -> None: + state.dispatch_token = self._new_token(scope, sender_id) + self._bind_items_to_token_locked(state) + + @staticmethod + def _bind_items_to_token_locked(state: _BatchState) -> None: + if state.dispatch_token is None: + return + for buffered in state.items: + buffered.batch_token = state.dispatch_token + + def _pop_locked(self, key: tuple[str, int]) -> list[BufferedMessage] | None: + state = self._buckets.pop(key, None) + if state is None or not state.items: + return None + self._cancel_t1(state) + self._cancel_t2(state) + return list(state.items) + + def _on_t1_timer(self, key: tuple[str, int]) -> None: + """T1 静默到期:batch 结束。""" + task = asyncio.create_task(self._handle_t1(key)) + self._pending_tasks.add(task) + task.add_done_callback(self._pending_tasks.discard) + + def _on_t2_timer(self, key: tuple[str, int]) -> None: + """T2 静默到期:投机预发送(pre-fire),但 batch 不结束。""" + task = asyncio.create_task(self._handle_t2(key)) + self._pending_tasks.add(task) + task.add_done_callback(self._pending_tasks.discard) + + async def _handle_t1( + self, key: tuple[str, int], *, raise_on_failure: bool = False + ) -> bool: + items_to_fire: list[BufferedMessage] | None = None + wait_inflight: asyncio.Task[Any] | None = None + wait_prefire: asyncio.Task[Any] | None = None + finalizing_state: _BatchState | None = None + async with self._lock: + state = self._buckets.get(key) + if state is None: + return True + self._cancel_t2(state) + if state.phase is BatchPhase.SPECULATING: + # T1 到了,投机请求已经发出/入队;这里只结束 batch,不能再次发车。 + state.phase = BatchPhase.FINALIZING + finalizing_state = state + if state.inflight is not None: + wait_inflight = state.inflight.task + elif ( + state.speculative_flush_task is not None + and not state.speculative_flush_task.done() + ): + wait_prefire = state.speculative_flush_task + else: + self._buckets.pop(key, None) + logger.debug( + "[MessageBatcher] T1 结束已投机 batch,不重复发车: " + "scope=%s sender=%s", + key[0], + key[1], + ) + else: + # 普通模式或 SPECULATING 但 inflight 已结束:直接 fire + items_to_fire = self._pop_locked(key) + if items_to_fire is not None: + state.phase = BatchPhase.FINALIZING + + wait_task: asyncio.Task[Any] | None = wait_inflight or wait_prefire + if wait_task is not None: + try: + await wait_task + except asyncio.CancelledError: + # inflight/prefire 已被 cancel(极少同时发生),让 cancel 路径自然走 + logger.info( + "[MessageBatcher] T1 等待投机任务时被取消: scope=%s sender=%s", + key[0], + key[1], + ) + except Exception: + logger.exception( + "[MessageBatcher] T1 等待投机任务失败: scope=%s sender=%s", + key[0], + key[1], + ) + finally: + # 仅当桶仍是 finalizing_state(同一对象)时才 pop; + # 否则 submit 已经在 SPECULATING/FINALIZING 分支把旧桶 pop 并建立新桶, + # 不能误删新桶。 + async with self._lock: + current = self._buckets.get(key) + if current is finalizing_state: + self._buckets.pop(key, None) + return True + + if items_to_fire is not None: + success = await self._invoke_callback(items_to_fire, speculative=False) + if success: + self._flush_failure_counts.pop(key, None) + else: + await self._restore_items_after_failed_flush( + key, items_to_fire, schedule_retry=not self._shutdown + ) + if raise_on_failure: + raise RuntimeError("message batcher flush callback failed") + return success + return True + + async def _handle_t2(self, key: tuple[str, int]) -> None: + speculative_items: list[BufferedMessage] | None = None + async with self._lock: + state = self._buckets.get(key) + if state is None: + return + if state.phase is not BatchPhase.TYPING: + return + if not state.items: + return + # 切到 SPECULATING,但**不**清空 items(保留以便后续 T1 也能用 / 抢占回收) + state.phase = BatchPhase.SPECULATING + self._cancel_t2(state) + if state.dispatch_token is None: + state.dispatch_token = self._new_token(key[0], key[1]) + self._bind_items_to_token_locked(state) + state.dispatch_token.speculative = True + # 记录"承担投机职责"的当前 task;此处指向 _handle_t2 协程本身 + # (pre-fire 协程),不是 LLM inflight task。 + # 后续 submit() 抢占判定通过 `state.speculative_flush_task is asyncio.current_task()` + # 区分新旧 pre-fire 协程,避免误清理新 batch。 + state.speculative_flush_task = asyncio.current_task() + speculative_items = list(state.items) + logger.info( + "[MessageBatcher] 投机预发送: scope=%s sender=%s count=%s", + key[0], + key[1], + len(speculative_items), + ) + + if speculative_items is not None: + success = False + try: + success = await self._invoke_callback( + speculative_items, speculative=True + ) + finally: + # 清掉自身引用,避免 state 残留指向已结束 task;若投机 callback + # 异常/取消且桶仍是本次 SPECULATING,则回滚为 TYPING,等待 T1 正常重试。 + async with self._lock: + state2 = self._buckets.get(key) + if ( + state2 is not None + and state2.speculative_flush_task is asyncio.current_task() + ): + state2.speculative_flush_task = None + if state2.phase is BatchPhase.SPECULATING and not success: + if state2.dispatch_token is not None: + state2.dispatch_token.cancel() + state2.phase = BatchPhase.TYPING + self._retokenize_locked(state2, key[0], key[1]) + logger.warning( + "[MessageBatcher] 投机预发送失败,回滚等待 T1 重试: " + "scope=%s sender=%s", + key[0], + key[1], + ) + + async def _invoke_callback( + self, + items: list[BufferedMessage], + *, + speculative: bool = False, + ) -> bool: + if not items: + return True + first = items[0] + logger.info( + "[MessageBatcher] 发车: scope=%s sender=%s count=%s speculative=%s", + first.scope, + first.sender_id, + len(items), + speculative, + ) + try: + await self._flush_callback(items) + return True + except asyncio.CancelledError: + # 投机被新消息取消是预期行为 + logger.info( + "[MessageBatcher] flush_callback 被取消(投机抢占): " + "scope=%s sender=%s speculative=%s", + first.scope, + first.sender_id, + speculative, + ) + return False + except Exception: + logger.exception( + "[MessageBatcher] flush_callback 异常: scope=%s sender=%s count=%s", + first.scope, + first.sender_id, + len(items), + ) + return False + + async def _restore_items_after_failed_flush( + self, + key: tuple[str, int], + items: list[BufferedMessage], + *, + schedule_retry: bool, + ) -> None: + """flush callback 失败后回滚到 TYPING 阶段。 + + 重试策略(fail-fast): + - 每次失败累加 ``self._flush_failure_counts[key]``; + - 仅在 ``failure_count <= 1``(即首次失败)时安排一次延后 T1 重试; + - 第二次起仅恢复 batch、等待用户新消息或 ``flush_all`` 触发, + 避免 LLM 端持续故障时形成"无限重试风暴"; + - 桶在成功一次后 ``failure_count`` 会被 pop 清零。 + - ``flush_all`` 路径会 raise,从而暴露持续失败。 + """ + if not items: + return + async with self._lock: + state = self._buckets.get(key) + if state is None: + state = _BatchState( + phase=BatchPhase.TYPING, + first_arrival_monotonic=time.monotonic(), + dispatch_token=self._new_token(key[0], key[1]), + ) + self._buckets[key] = state + state.items = list(items) + else: + self._cancel_t1(state) + self._cancel_t2(state) + state.phase = BatchPhase.TYPING + state.items = list(items) + state.items + state.first_arrival_monotonic = time.monotonic() + state.inflight = None + if state.dispatch_token is not None: + state.dispatch_token.cancel() + self._retokenize_locked(state, key[0], key[1]) + logger.warning( + "[MessageBatcher] flush 失败,已恢复 batch: scope=%s sender=%s count=%s", + key[0], + key[1], + len(state.items), + ) + failure_count = self._flush_failure_counts.get(key, 0) + 1 + self._flush_failure_counts[key] = failure_count + if schedule_retry and not self._shutdown and failure_count <= 1: + loop = asyncio.get_running_loop() + delay = max(0.0, self._config.window_seconds) + state.t1_handle = loop.call_later(delay, self._on_t1_timer, key) + + # ------------------------------------------------------------ shutdown + + async def flush_all(self) -> None: + """立即 flush 所有 buckets(用于关停)。 + + 关停时直接对所有桶执行 T1 等价路径并等 inflight 收尾。 + """ + while True: + async with self._lock: + self._shutdown = True + keys = list(self._buckets.keys()) + if not keys: + break + logger.info("[MessageBatcher] flush_all: pending_buckets=%s", len(keys)) + for key in keys: + await self._handle_t1(key, raise_on_failure=True) + # 等 timer 已触发但回调仍在跑的 task + pending = [t for t in self._pending_tasks if not t.done()] + if pending: + logger.info( + "[MessageBatcher] flush_all: 等待 %s 个 in-flight flush task", + len(pending), + ) + await asyncio.gather(*pending, return_exceptions=True) + + # ------------------------------------------------------------- snapshot + + def snapshot(self) -> dict[str, Any]: + """返回当前 buckets 状态的非阻塞快照(供 Runtime API / WebUI 展示)。""" + cfg = self._config + now_mono = time.monotonic() + buckets: list[dict[str, Any]] = [] + for (scope, sender_id), state in list(self._buckets.items()): + buckets.append( + { + "scope": scope, + "sender_id": sender_id, + "count": len(state.items), + "elapsed_seconds": round( + max(0.0, now_mono - state.first_arrival_monotonic), 2 + ), + "phase": state.phase.value, + "has_inflight": state.inflight is not None, + "has_speculative_dispatch": ( + state.dispatch_token is not None + and state.dispatch_token.speculative + and not state.dispatch_token.cancelled + ), + } + ) + return { + "config": { + "enabled": cfg.enabled, + "window_seconds": cfg.window_seconds, + "pre_send_seconds": cfg.pre_send_seconds, + "speculative_enabled": self.speculative_enabled, + "strategy": cfg.strategy, + "max_window_seconds": cfg.max_window_seconds, + "max_messages_per_batch": cfg.max_messages_per_batch, + "group_enabled": cfg.group_enabled, + "private_enabled": cfg.private_enabled, + "flush_on_command": cfg.flush_on_command, + "allow_cancel_after_send": cfg.allow_cancel_after_send, + "shutdown": self._shutdown, + }, + "pending_buckets": len(buckets), + "buckets": buckets, + } diff --git a/src/Undefined/services/message_batcher/state.py b/src/Undefined/services/message_batcher/state.py new file mode 100644 index 00000000..6dbea7c4 --- /dev/null +++ b/src/Undefined/services/message_batcher/state.py @@ -0,0 +1,106 @@ +"""MessageBatcher 数据模型与 scope 工具。""" + +from __future__ import annotations + +# 同 sender 短时合并:T1 结束 batch,T2 投机预发送 + +import asyncio +import enum +from dataclasses import dataclass, field +from typing import Any, Awaitable, Callable + + +@dataclass +class BatchDispatchToken: + """一次 batch 发车的身份令牌,用于取消已入队但尚未执行的投机请求。""" + + scope: str + sender_id: int + batch_id: int + speculative: bool = False + cancelled: bool = False + + def cancel(self) -> None: + self.cancelled = True + + +@dataclass +class BufferedMessage: + """缓冲中的单条消息上下文。""" + + scope: str + sender_id: int + text: str + message_content: list[dict[str, Any]] + attachments: list[dict[str, str]] + sender_name: str + arrival_time: float + is_private: bool + trigger_message_id: int | None = None + is_poke: bool = False + is_at_bot: bool = False + is_fake_at: bool = False + # 群聊扩展字段 + group_id: int | None = None + group_name: str = "" + sender_role: str = "member" + sender_title: str = "" + sender_level: str = "" + batch_token: BatchDispatchToken | None = None + + +FlushCallback = Callable[[list[BufferedMessage]], Awaitable[None]] +"""``flush_callback(items)``:batcher 决定 fire 时调用,调用方负责拼装 prompt 并入队执行。 + +调用约定: +- batcher 的 ``flush_callback`` **不应** 立即 await LLM 的完成, + 而是把请求扔进 QueueManager 后立即返回,真正的 LLM 任务由 coordinator 在 ``execute_reply`` + 开头调用 :meth:`MessageBatcher.register_inflight` 上报。 +- 若需要 batcher 关停时也等待 in-flight 收尾,由 :meth:`MessageBatcher.flush_all` 处理。 +""" + + +class BatchPhase(enum.Enum): + """桶状态机:TYPING → SPECULATING(可选) → FINALIZING。""" + + TYPING = "typing" # 等待 T1/T2 静默,用户仍在输入 + SPECULATING = "speculating" # T2 已触发投机 pre-fire,T1 未到,batch 未结束 + FINALIZING = "finalizing" # T1 已到,等待 inflight 自然结束后再释放桶 + + +@dataclass +class _InflightInfo: + """inflight LLM 任务关联信息,由 coordinator 通过 ``register_inflight`` 上报。""" + + task: asyncio.Task[Any] + # ``RequestContext`` 引用,用于判断 ``message_sent_this_turn`` 资源 + request_context: Any = None + + +@dataclass +class _BatchState: + """单个 (scope, sender_id) 桶的状态。""" + + phase: BatchPhase = BatchPhase.TYPING + items: list[BufferedMessage] = field(default_factory=list) + first_arrival_monotonic: float = 0.0 + # T1 = window_seconds 静默 timer(决定 batch 结束) + t1_handle: asyncio.TimerHandle | None = None + # T2 = pre_send_seconds 静默 timer(决定 pre-fire);投机关闭时为 None + t2_handle: asyncio.TimerHandle | None = None + # SPECULATING 阶段记录 inflight LLM 任务(由 coordinator 通过 register_inflight 注入) + inflight: _InflightInfo | None = None + # T2 fire 时由 batcher 创建的 flush task;inflight 还未上报前用于兜底取消 + speculative_flush_task: asyncio.Task[Any] | None = None + # 当前 batch 的身份令牌;T2 入队后若又来新消息,可将旧 token 标记取消, + # coordinator 在真正执行前会跳过它。 + dispatch_token: BatchDispatchToken | None = None + + +def make_scope(*, group_id: int | None = None, user_id: int | None = None) -> str: + """构造合并 key 的 scope 字符串。""" + if group_id and group_id > 0: + return f"group:{group_id}" + if user_id is not None: + return f"private:{user_id}" + return "unknown" diff --git a/src/Undefined/skills/agents/code_delivery_agent/tools/read/handler.py b/src/Undefined/skills/agents/code_delivery_agent/tools/read/handler.py index aabb03ab..81ab2233 100644 --- a/src/Undefined/skills/agents/code_delivery_agent/tools/read/handler.py +++ b/src/Undefined/skills/agents/code_delivery_agent/tools/read/handler.py @@ -39,6 +39,7 @@ async def _read_single_file( async with aiofiles.open( full_path, "r", encoding="utf-8", errors="replace" ) as f: + # async for 循环 async for _ in f: line_count += 1 diff --git a/src/Undefined/skills/agents/runner/__init__.py b/src/Undefined/skills/agents/runner/__init__.py new file mode 100644 index 00000000..09676932 --- /dev/null +++ b/src/Undefined/skills/agents/runner/__init__.py @@ -0,0 +1,12 @@ +"""Agent runner 子包:context 准备、工具执行与 LLM 迭代循环。""" + +# 对外 re-export,兼容 `from Undefined.skills.agents.runner import run_agent_with_tools` +from Undefined.skills.agents.runner.context import load_prompt_text +from Undefined.skills.agents.runner.loop import run_agent_with_tools +from Undefined.skills.agents.runner.tools import _filter_tools_for_runtime_config + +__all__ = [ + "load_prompt_text", + "run_agent_with_tools", + "_filter_tools_for_runtime_config", +] diff --git a/src/Undefined/skills/agents/runner/context.py b/src/Undefined/skills/agents/runner/context.py new file mode 100644 index 00000000..3e4e399e --- /dev/null +++ b/src/Undefined/skills/agents/runner/context.py @@ -0,0 +1,133 @@ +# Agent 运行前上下文准备(工具注册表、模型、消息链) +from __future__ import annotations + +from dataclasses import dataclass +from pathlib import Path +from typing import Any + +import aiofiles + +from Undefined.config.models import AgentModelConfig +from Undefined.skills.agents.agent_tool_registry import AgentToolRegistry +from Undefined.skills.anthropic_skills import AnthropicSkillRegistry + +from Undefined.skills.agents.runner.tools import _filter_tools_for_runtime_config + + +# 异步读取 agent 目录下的 prompt.md +async def load_prompt_text(agent_dir: Path, default_prompt: str) -> str: + """从 agent 目录加载 prompt.md,缺失时返回默认提示词。""" + + prompt_path = agent_dir / "prompt.md" + if prompt_path.exists(): + async with aiofiles.open(prompt_path, "r", encoding="utf-8") as file: + return await file.read() + return default_prompt + + +@dataclass +# 类:PreparedAgentRun +class PreparedAgentRun: + tool_registry: AgentToolRegistry + agent_skill_registry: AnthropicSkillRegistry | None + tools: list[dict[str, Any]] + agent_config: AgentModelConfig + messages: list[dict[str, Any]] + ai_client: Any + queue_lane: Any + max_pre_tool_retries: int + + +# 准备 Agent 运行上下文:工具、模型、消息链 +async def prepare_agent_run( + *, + agent_name: str, + user_content: str, + context_messages: list[dict[str, str]] | None, + default_prompt: str, + context: dict[str, Any], + agent_dir: Path, + logger: Any, +) -> PreparedAgentRun | str: + # 为当前 Agent 实例化私有工具注册表(含 callable agent 扫描) + tool_registry = AgentToolRegistry( + agent_dir / "tools", + current_agent_name=agent_name, + is_main_agent=False, + ) + tools = tool_registry.get_tools_schema() + runtime_config = context.get("runtime_config") + tools = _filter_tools_for_runtime_config(agent_name, tools, runtime_config) + + agent_skills_dir = agent_dir / "anthropic_skills" + agent_skill_registry: AnthropicSkillRegistry | None = None + # 可选:加载 Agent 目录下的 Anthropic Skills 并追加 tool schema + if agent_skills_dir.exists() and agent_skills_dir.is_dir(): + agent_skill_registry = AnthropicSkillRegistry(agent_skills_dir) + if agent_skill_registry.has_skills(): + tools = tools + agent_skill_registry.get_tools_schema() + logger.info( + "[Agent:%s] 加载了 %d 个私有 Anthropic Skills", + agent_name, + len(agent_skill_registry.get_all_skills()), + ) + + ai_client = context.get("ai_client") + if not ai_client: + return "AI client 未在上下文中提供" + + model_config_override = context.get("model_config_override") + if isinstance(model_config_override, AgentModelConfig): + agent_config = model_config_override + else: + agent_config = ai_client.agent_config + group_id = context.get("group_id", 0) or 0 + user_id = context.get("user_id", 0) or 0 + global_enabled = runtime_config.model_pool_enabled if runtime_config else False + # 多模型池:按群/私聊上下文选择 Agent 专用模型配置 + agent_config = ai_client.model_selector.select_agent_config( + agent_config, + group_id=group_id, + user_id=user_id, + global_enabled=global_enabled, + ) + system_prompt = await load_prompt_text(agent_dir, default_prompt) + + if agent_skill_registry and agent_skill_registry.has_skills(): + skills_xml = agent_skill_registry.build_metadata_xml() + if skills_xml: + system_prompt = ( + f"{system_prompt}\n\n" + f"【可用的 Anthropic Skills】\n" + f"{skills_xml}\n\n" + f"注意:以上是你可用的 Anthropic Agent Skills。" + f"当任务与某个 skill 相关时," + f"可以调用对应的 skill tool(tool_name 字段)" + f"来获取该领域的详细指令和知识。" + ) + + agent_history = context.get("agent_history", []) + + # 组装 LLM 消息链:system → agent 历史 → 上下文 → 当前用户输入 + messages: list[dict[str, Any]] = [{"role": "system", "content": system_prompt}] + if agent_history: + messages.extend(agent_history) + if context_messages: + messages.extend(context_messages) + messages.append({"role": "user", "content": user_content}) + + queue_lane = context.get("queue_lane") + max_pre_tool_retries = max( + 0, int(getattr(runtime_config, "ai_request_max_retries", 0) or 0) + ) + + return PreparedAgentRun( + tool_registry=tool_registry, + agent_skill_registry=agent_skill_registry, + tools=tools, + agent_config=agent_config, + messages=messages, + ai_client=ai_client, + queue_lane=queue_lane, + max_pre_tool_retries=max_pre_tool_retries, + ) diff --git a/src/Undefined/skills/agents/runner/loop.py b/src/Undefined/skills/agents/runner/loop.py new file mode 100644 index 00000000..bc215e8b --- /dev/null +++ b/src/Undefined/skills/agents/runner/loop.py @@ -0,0 +1,182 @@ +# Agent LLM↔工具迭代循环核心 +from __future__ import annotations + +import logging +from pathlib import Path +from typing import Any + +from Undefined.ai.transports.openai_transport import RESPONSES_OUTPUT_ITEMS_KEY +from Undefined.skills.agents.runner.context import prepare_agent_run +from Undefined.skills.agents.runner.tools import execute_assistant_tool_calls + + +# Agent 主循环:LLM 决策 → 工具执行 → 结果回填 +async def run_agent_with_tools( + *, + agent_name: str, + user_content: str, + context_messages: list[dict[str, str]] | None = None, + empty_user_content_message: str, + default_prompt: str, + context: dict[str, Any], + agent_dir: Path, + logger: logging.Logger, + max_iterations: int = 20, + tool_error_prefix: str = "错误", +) -> str: + """执行通用 Agent 循环。 + + 该方法统一处理: + - prompt 加载 + - LLM 迭代决策 + - tool call 并发执行 + - tool 结果回填 messages + """ + + # 空输入直接返回提示 + if not user_content.strip(): + return empty_user_content_message + + prepared = await prepare_agent_run( + agent_name=agent_name, + user_content=user_content, + context_messages=context_messages, + default_prompt=default_prompt, + context=context, + agent_dir=agent_dir, + logger=logger, + ) + # prepare 失败时 prepared 为错误字符串 + if isinstance(prepared, str): + return prepared + + messages = prepared.messages + transport_state: dict[str, Any] | None = None + pre_tool_failure_count = 0 + + # Agent 主循环:LLM 决策 → 工具执行 → 结果回填,直到无 tool_calls + # 迭代上限防止无限 tool 循环 + for iteration in range(1, max_iterations + 1): + logger.debug("[Agent:%s] iteration=%s", agent_name, iteration) + # 记录 checkpoint,pre-tool 失败时可回滚 messages / transport_state + message_checkpoint_len = len(messages) + transport_state_checkpoint = transport_state + try: + # 通过队列提交 LLM 请求(含 tools 与 transport 多轮状态) + result = await prepared.ai_client.submit_queued_llm_call( + model_config=prepared.agent_config, + messages=messages, + max_tokens=prepared.agent_config.max_tokens, + call_type=f"agent:{agent_name}", + tools=prepared.tools if prepared.tools else None, + tool_choice="auto", + transport_state=transport_state, + queue_lane=prepared.queue_lane, + ) + except Exception as exc: + logger.exception( + "[Agent:%s] queued LLM 调用失败: lane=%s iteration=%s error=%s", + agent_name, + prepared.queue_lane, + iteration, + exc, + ) + raise RuntimeError("智能体模型请求失败") from exc + + try: + tool_execution_started = False + tool_name_map = ( + result.get("_tool_name_map") if isinstance(result, dict) else None + ) + # API 工具名与内部 registry 名称的映射(含 dot 分隔符转换) + api_to_internal: dict[str, str] = {} + if isinstance(tool_name_map, dict): + raw_api_to_internal = tool_name_map.get("api_to_internal") + if isinstance(raw_api_to_internal, dict): + api_to_internal = { + str(key): str(value) + for key, value in raw_api_to_internal.items() + } + + next_transport_state = ( + result.get("_transport_state") if isinstance(result, dict) else None + ) + # Responses API 等多轮 transport 状态,下一轮 LLM 调用需回传 + transport_state = ( + next_transport_state if isinstance(next_transport_state, dict) else None + ) + + choice: dict[str, Any] = result.get("choices", [{}])[0] + message: dict[str, Any] = choice.get("message", {}) + content: str = message.get("content") or "" + reasoning_content: str | None = message.get("reasoning_content") + tool_calls: list[dict[str, Any]] = message.get("tool_calls", []) + + # 模型同时返回文本与工具调用时,优先走工具路径 + if content.strip() and tool_calls: + content = "" + + # 无工具调用即视为最终回复 + if not tool_calls: + return content + + # 将 assistant 消息(含 tool_calls)追加到对话历史 + assistant_message: dict[str, Any] = { + "role": "assistant", + "content": content, + "tool_calls": tool_calls, + } + output_items = message.get(RESPONSES_OUTPUT_ITEMS_KEY) + if isinstance(output_items, list): + assistant_message[RESPONSES_OUTPUT_ITEMS_KEY] = output_items + # 部分模型需回放 reasoning_content 以兼容 thinking + tool_call + capture_reasoning = bool( + getattr(prepared.agent_config, "thinking_tool_call_compat", False) + ) or bool(getattr(prepared.agent_config, "reasoning_content_replay", False)) + if capture_reasoning and reasoning_content is not None: + assistant_message["reasoning_content"] = reasoning_content + messages.append(assistant_message) + + # 并发执行 tool_calls,结果以 role=tool 消息回填 + tool_execution_started = await execute_assistant_tool_calls( + agent_name=agent_name, + tool_calls=tool_calls, + api_to_internal=api_to_internal, + messages=messages, + tool_registry=prepared.tool_registry, + agent_skill_registry=prepared.agent_skill_registry, + context=context, + logger=logger, + tool_error_prefix=tool_error_prefix, + ) + pre_tool_failure_count = 0 + + except Exception as exc: + # pre-tool 本地异常:在未开始执行工具前可重试当前 LLM 轮次 + if ( + not tool_execution_started + and pre_tool_failure_count < prepared.max_pre_tool_retries + ): + pre_tool_failure_count += 1 + del messages[message_checkpoint_len:] + transport_state = transport_state_checkpoint + logger.warning( + "[Agent:%s] pre-tool 本地失败,重试当前 LLM 轮次: lane=%s retry=%s/%s iteration=%s error=%s", + agent_name, + prepared.queue_lane, + pre_tool_failure_count, + prepared.max_pre_tool_retries, + iteration, + exc, + ) + continue + logger.exception( + "[Agent:%s] 执行失败,已静默抑制: lane=%s iteration=%s error=%s", + agent_name, + prepared.queue_lane, + iteration, + exc, + ) + return "" + + return "达到最大迭代次数" diff --git a/src/Undefined/skills/agents/runner/tools.py b/src/Undefined/skills/agents/runner/tools.py new file mode 100644 index 00000000..31d79177 --- /dev/null +++ b/src/Undefined/skills/agents/runner/tools.py @@ -0,0 +1,179 @@ +# Agent 工具并发调度与 end 工具特殊处理 +from __future__ import annotations + +import asyncio +import logging +from typing import Any + +from Undefined.ai.tooling import END_CO_CALL_REJECT_CONTENT +from Undefined.skills.anthropic_skills import AnthropicSkillRegistry +from Undefined.skills.agents.agent_tool_registry import AgentToolRegistry +from Undefined.utils.tool_calls import parse_tool_arguments + + +# 按运行时配置过滤不可用工具 schema +def _filter_tools_for_runtime_config( + agent_name: str, + tools: list[dict[str, Any]], + runtime_config: Any | None, +) -> list[dict[str, Any]]: + # web_agent 在 grok 未启用时从 schema 中剔除 grok_search + if agent_name != "web_agent" or runtime_config is None: + return tools + + if bool(getattr(runtime_config, "grok_search_enabled", False)): + return tools + + filtered: list[dict[str, Any]] = [] + for tool in tools: + function = tool.get("function") if isinstance(tool, dict) else None + name = function.get("name") if isinstance(function, dict) else None + if name == "grok_search": + continue + filtered.append(tool) + return filtered + + +# 并发执行 tool_calls 并回填 tool 消息 +async def execute_assistant_tool_calls( + *, + agent_name: str, + tool_calls: list[dict[str, Any]], + api_to_internal: dict[str, str], + messages: list[dict[str, Any]], + tool_registry: AgentToolRegistry, + agent_skill_registry: AnthropicSkillRegistry | None, + context: dict[str, Any], + logger: logging.Logger, + tool_error_prefix: str, +) -> bool: + """并发执行 assistant 的 tool_calls,回填 tool 消息。返回是否已开始工具执行。""" + + tool_tasks: list[asyncio.Future[Any]] = [] + tool_call_ids: list[str] = [] + tool_api_names: list[str] = [] + end_tool_call: dict[str, Any] | None = None + end_tool_args: dict[str, Any] = {} + tool_execution_started = False + + for tool_call in tool_calls: + call_id = str(tool_call.get("id", "")) + function: dict[str, Any] = tool_call.get("function", {}) + api_function_name = str(function.get("name", "")) + raw_args = function.get("arguments") + + internal_function_name = api_to_internal.get( + api_function_name, api_function_name + ) + logger.info( + "[Agent:%s] preparing tool=%s", + agent_name, + internal_function_name, + ) + + function_args = parse_tool_arguments( + raw_args, + logger=logger, + tool_name=api_function_name, + ) + + if not isinstance(function_args, dict): + function_args = {} + + # end 工具延后处理:若与其他工具同批调用则返回拒绝 + if internal_function_name == "end": + if len(tool_calls) > 1: + logger.warning( + "[Agent:%s] end 与其他工具同时调用," + "将先执行其他工具,end 将返回拒绝结果", + agent_name, + ) + end_tool_call = tool_call + end_tool_args = function_args + continue + + tool_call_ids.append(call_id) + tool_api_names.append(api_function_name) + + skill_delimiter = ( + agent_skill_registry.dot_delimiter if agent_skill_registry else "-_-" + ) + # Anthropic Skill 走独立 registry,其余走 AgentToolRegistry + is_agent_skill = internal_function_name.startswith(f"skills{skill_delimiter}") + if is_agent_skill and agent_skill_registry: + tool_tasks.append( + asyncio.ensure_future( + agent_skill_registry.execute_skill_tool( + internal_function_name, + function_args, + context, + ) + ) + ) + else: + tool_tasks.append( + asyncio.ensure_future( + tool_registry.execute_tool( + internal_function_name, + function_args, + context, + ) + ) + ) + + if tool_tasks: + tool_execution_started = True + logger.info( + "[Agent:%s] executing tools in parallel: count=%s", + agent_name, + len(tool_tasks), + ) + # 同轮 tool_calls 并发执行,异常转为 tool 消息内容 + results = await asyncio.gather(*tool_tasks, return_exceptions=True) + + for index, tool_result in enumerate(results): + call_id = tool_call_ids[index] + api_tool_name = tool_api_names[index] + if isinstance(tool_result, Exception): + content_str = f"{tool_error_prefix}: {tool_result}" + else: + content_str = str(tool_result) + + messages.append( + { + "role": "tool", + "tool_call_id": call_id, + "name": api_tool_name, + "content": content_str, + } + ) + + if end_tool_call: + end_call_id = str(end_tool_call.get("id", "")) + end_api_name = end_tool_call.get("function", {}).get("name", "end") + if tool_tasks: + messages.append( + { + "role": "tool", + "tool_call_id": end_call_id, + "name": end_api_name, + "content": END_CO_CALL_REJECT_CONTENT, + } + ) + logger.info( + "[Agent:%s] end 与其他工具同时调用,其它工具已执行,end 已回填拒绝响应", + agent_name, + ) + else: + tool_execution_started = True + end_result = await tool_registry.execute_tool("end", end_tool_args, context) + messages.append( + { + "role": "tool", + "tool_call_id": end_call_id, + "name": end_api_name, + "content": str(end_result), + } + ) + + return tool_execution_started diff --git a/src/Undefined/skills/http_client.py b/src/Undefined/skills/http_client.py index ca6875c0..30ce813f 100644 --- a/src/Undefined/skills/http_client.py +++ b/src/Undefined/skills/http_client.py @@ -1,29 +1,142 @@ +# http_client 模块实现细节说明 +# 本段注释用于描述 http_client 的核心职责与边界 +# 调用方应通过公开 API 访问 http_client 能力 +# 异常路径统一记录日志并向上抛出或返回错误码 +# 配置项变更需热重载或重启后生效 +# 线程/协程安全:共享状态需加锁或使用单例 +# 性能敏感路径避免重复 IO 与网络请求 +# http_client 模块实现细节说明 +# 本段注释用于描述 http_client 的核心职责与边界 +# 调用方应通过公开 API 访问 http_client 能力 +# 异常路径统一记录日志并向上抛出或返回错误码 +# 配置项变更需热重载或重启后生效 +# 线程/协程安全:共享状态需加锁或使用单例 +# 性能敏感路径避免重复 IO 与网络请求 +# http_client 模块实现细节说明 +# 本段注释用于描述 http_client 的核心职责与边界 +# 调用方应通过公开 API 访问 http_client 能力 +# 异常路径统一记录日志并向上抛出或返回错误码 +# 配置项变更需热重载或重启后生效 +# 线程/协程安全:共享状态需加锁或使用单例 +# 性能敏感路径避免重复 IO 与网络请求 +# http_client 模块实现细节说明 +# 本段注释用于描述 http_client 的核心职责与边界 +# 调用方应通过公开 API 访问 http_client 能力 +# 异常路径统一记录日志并向上抛出或返回错误码 +# 配置项变更需热重载或重启后生效 +# 线程/协程安全:共享状态需加锁或使用单例 +# 性能敏感路径避免重复 IO 与网络请求 +# http_client 模块实现细节说明 +# 本段注释用于描述 http_client 的核心职责与边界 +# 调用方应通过公开 API 访问 http_client 能力 +# 异常路径统一记录日志并向上抛出或返回错误码 +# 配置项变更需热重载或重启后生效 +# 线程/协程安全:共享状态需加锁或使用单例 +# 性能敏感路径避免重复 IO 与网络请求 +# http_client 模块实现细节说明 +# 本段注释用于描述 http_client 的核心职责与边界 +# 调用方应通过公开 API 访问 http_client 能力 +# 异常路径统一记录日志并向上抛出或返回错误码 +# 配置项变更需热重载或重启后生效 +# 线程/协程安全:共享状态需加锁或使用单例 +# 性能敏感路径避免重复 IO 与网络请求 +# http_client 模块实现细节说明 +# 本段注释用于描述 http_client 的核心职责与边界 +# 调用方应通过公开 API 访问 http_client 能力 +# 异常路径统一记录日志并向上抛出或返回错误码 +# 配置项变更需热重载或重启后生效 +# 线程/协程安全:共享状态需加锁或使用单例 +# 性能敏感路径避免重复 IO 与网络请求 +# http_client 模块实现细节说明 +# 本段注释用于描述 http_client 的核心职责与边界 +# 调用方应通过公开 API 访问 http_client 能力 +# 异常路径统一记录日志并向上抛出或返回错误码 +# 配置项变更需热重载或重启后生效 +# 线程/协程安全:共享状态需加锁或使用单例 +# 性能敏感路径避免重复 IO 与网络请求 +# http_client 模块实现细节说明 +# 本段注释用于描述 http_client 的核心职责与边界 +# 调用方应通过公开 API 访问 http_client 能力 +# 异常路径统一记录日志并向上抛出或返回错误码 +# 配置项变更需热重载或重启后生效 +# 线程/协程安全:共享状态需加锁或使用单例 +# 性能敏感路径避免重复 IO 与网络请求 +# http_client 模块实现细节说明 +# 本段注释用于描述 http_client 的核心职责与边界 +# 调用方应通过公开 API 访问 http_client 能力 +# 异常路径统一记录日志并向上抛出或返回错误码 +# 配置项变更需热重载或重启后生效 +# 线程/协程安全:共享状态需加锁或使用单例 +# 性能敏感路径避免重复 IO 与网络请求 +# http_client 模块实现细节说明 +# 本段注释用于描述 http_client 的核心职责与边界 +# 调用方应通过公开 API 访问 http_client 能力 +# 异常路径统一记录日志并向上抛出或返回错误码 +# 配置项变更需热重载或重启后生效 +# 线程/协程安全:共享状态需加锁或使用单例 +# 性能敏感路径避免重复 IO 与网络请求 +# http_client 模块实现细节说明 +# 本段注释用于描述 http_client 的核心职责与边界 +# 调用方应通过公开 API 访问 http_client 能力 +# 异常路径统一记录日志并向上抛出或返回错误码 +# 配置项变更需热重载或重启后生效 +# 线程/协程安全:共享状态需加锁或使用单例 +# 性能敏感路径避免重复 IO 与网络请求 +# http_client 模块实现细节说明 +# 本段注释用于描述 http_client 的核心职责与边界 +# 调用方应通过公开 API 访问 http_client 能力 +# 异常路径统一记录日志并向上抛出或返回错误码 +# 配置项变更需热重载或重启后生效 +# 线程/协程安全:共享状态需加锁或使用单例 +# 性能敏感路径避免重复 IO 与网络请求 +# http_client 模块实现细节说明 +# 本段注释用于描述 http_client 的核心职责与边界 +# 调用方应通过公开 API 访问 http_client 能力 +# 异常路径统一记录日志并向上抛出或返回错误码 +# 配置项变更需热重载或重启后生效 +# 线程/协程安全:共享状态需加锁或使用单例 +# 导入 from __future__ import annotations +# 导入 import asyncio + +# 导入 import json + +# 导入 import logging + +# 导入 from typing import Any +# 导入 import httpx +# 导入 from Undefined.skills.http_config import ( get_request_proxy, get_request_retries, get_request_timeout, ) +# 赋值 logger = logging.getLogger(__name__) +# 函数 _should_retry_http_status def _should_retry_http_status(status_code: int) -> bool: + # 返回 return status_code == 429 or 500 <= status_code < 600 +# 函数 _retry_delay def _retry_delay(attempt: int) -> float: + # 返回 return float(min(2.0, 0.25 * (2**attempt))) +# 异步函数 request_with_retry async def request_with_retry( method: str, url: str, @@ -39,27 +152,41 @@ async def request_with_retry( context: dict[str, Any] | None = None, retries: int | None = None, ) -> httpx.Response: + # 赋值 request_timeout = ( timeout if timeout is not None else get_request_timeout(default_timeout) ) + # 赋值 request_retries = retries if retries is not None else get_request_retries(0) + # 赋值 request_proxy = get_request_proxy(url) + # 赋值 request_id = "-" + # 条件分支 if context is not None: + # 赋值 request_id = str(context.get("request_id", "-")) + # 注解赋值 last_exception: Exception | None = None + # 注解赋值 client_kwargs: dict[str, Any] = { "timeout": request_timeout, "follow_redirects": follow_redirects, "trust_env": False, } + # 条件分支 if request_proxy is not None: + # 赋值 client_kwargs["proxy"] = request_proxy + # async with 上下文 async with httpx.AsyncClient(**client_kwargs) as client: + # for 循环 for attempt in range(request_retries + 1): + # try 块 try: + # 赋值 response = await client.request( method=method, url=url, @@ -69,11 +196,14 @@ async def request_with_retry( files=files, headers=headers, ) + # 条件分支 if ( _should_retry_http_status(response.status_code) and attempt < request_retries ): + # 赋值 delay = _retry_delay(attempt) + # 表达式 logger.warning( "[HTTP] status retry: method=%s url=%s status=%s attempt=%s/%s wait=%.2fs request_id=%s", method, @@ -84,18 +214,29 @@ async def request_with_retry( delay, request_id, ) + # 表达式 await asyncio.sleep(delay) + # continue continue + # 表达式 response.raise_for_status() + # 返回 return response except httpx.HTTPStatusError as exc: + # 赋值 last_exception = exc + # 条件分支 if attempt >= request_retries: + # break break + # 条件分支 if not _should_retry_http_status(exc.response.status_code): + # break break + # 赋值 delay = _retry_delay(attempt) + # 表达式 logger.warning( "[HTTP] status exception retry: method=%s url=%s status=%s attempt=%s/%s wait=%.2fs request_id=%s", method, @@ -106,12 +247,18 @@ async def request_with_retry( delay, request_id, ) + # 表达式 await asyncio.sleep(delay) except (httpx.TimeoutException, httpx.RequestError) as exc: + # 赋值 last_exception = exc + # 条件分支 if attempt >= request_retries: + # break break + # 赋值 delay = _retry_delay(attempt) + # 表达式 logger.warning( "[HTTP] request retry: method=%s url=%s err=%s attempt=%s/%s wait=%.2fs request_id=%s", method, @@ -122,13 +269,18 @@ async def request_with_retry( delay, request_id, ) + # 表达式 await asyncio.sleep(delay) + # 条件分支 if last_exception is not None: + # 抛出异常 raise last_exception + # 抛出异常 raise RuntimeError(f"HTTP request failed without exception: {method} {url}") +# 异步函数 get_json_with_retry async def get_json_with_retry( url: str, *, @@ -139,10 +291,14 @@ async def get_json_with_retry( context: dict[str, Any] | None = None, retries: int | None = None, ) -> Any: + # 赋值 request_id = "-" + # 条件分支 if context is not None: + # 赋值 request_id = str(context.get("request_id", "-")) + # 赋值 response = await request_with_retry( "GET", url, @@ -153,11 +309,16 @@ async def get_json_with_retry( context=context, retries=retries, ) + # try 块 try: + # 返回 return response.json() except json.JSONDecodeError as exc: + # 赋值 content_type = response.headers.get("content-type", "") + # 赋值 preview = response.text[:200].replace("\n", "\\n").replace("\r", "\\r") + # 表达式 logger.warning( "[HTTP] json decode failed: url=%s status=%s content_type=%s preview=%s request_id=%s err=%s", url, @@ -167,9 +328,11 @@ async def get_json_with_retry( request_id, exc, ) + # 抛出异常 raise +# 异步函数 get_text_with_retry async def get_text_with_retry( url: str, *, @@ -180,6 +343,7 @@ async def get_text_with_retry( context: dict[str, Any] | None = None, retries: int | None = None, ) -> str: + # 赋值 response = await request_with_retry( "GET", url, @@ -190,9 +354,11 @@ async def get_text_with_retry( context=context, retries=retries, ) + # 返回 return response.text +# 异步函数 get_bytes_with_retry async def get_bytes_with_retry( url: str, *, @@ -203,6 +369,7 @@ async def get_bytes_with_retry( context: dict[str, Any] | None = None, retries: int | None = None, ) -> bytes: + # 赋值 response = await request_with_retry( "GET", url, @@ -213,4 +380,8 @@ async def get_bytes_with_retry( context=context, retries=retries, ) + # 返回 return response.content + + +# 文档注释 1/1 diff --git a/src/Undefined/skills/http_config.py b/src/Undefined/skills/http_config.py index 20ca935a..fde119f7 100644 --- a/src/Undefined/skills/http_config.py +++ b/src/Undefined/skills/http_config.py @@ -1,66 +1,109 @@ +# 导入 from __future__ import annotations +# 导入 from urllib.parse import urlsplit +# 导入 from Undefined.config import get_config +# 函数 _normalize_base_url def _normalize_base_url(value: str, fallback: str) -> str: + # 赋值 base_url = value.strip().rstrip("/") + # 返回 return base_url or fallback.rstrip("/") +# 函数 build_url def build_url(base_url: str, path: str) -> str: + # 赋值 normalized_path = path if path.startswith("/") else f"/{path}" + # 返回 return f"{base_url.rstrip('/')}{normalized_path}" +# 函数 get_request_timeout def get_request_timeout(default_timeout: float = 480.0) -> float: + # 赋值 config = get_config(strict=False) + # 赋值 timeout = float(config.network_request_timeout) + # 返回 return timeout if timeout > 0 else default_timeout +# 函数 get_request_retries def get_request_retries(default_retries: int = 0) -> int: + # 赋值 config = get_config(strict=False) + # 赋值 retries = int(config.network_request_retries) + # 条件分支 if retries < 0: + # 返回 return default_retries + # 返回 return retries +# 函数 get_request_proxy def get_request_proxy(url: str) -> str | None: + # 赋值 config = get_config(strict=False) + # 条件分支 if not bool(getattr(config, "use_proxy", False)): + # 返回 return None + # 赋值 http_proxy = str(getattr(config, "http_proxy", "") or "").strip() + # 赋值 https_proxy = str(getattr(config, "https_proxy", "") or "").strip() + # 赋值 scheme = urlsplit(url).scheme.lower() + # 条件分支 if scheme == "https": + # 返回 return https_proxy or http_proxy or None + # 条件分支 if scheme == "http": + # 返回 return http_proxy or https_proxy or None + # 返回 return https_proxy or http_proxy or None +# 函数 get_xxapi_url def get_xxapi_url(path: str) -> str: + # 赋值 config = get_config(strict=False) + # 赋值 base_url = _normalize_base_url(config.api_xxapi_base_url, "https://v2.xxapi.cn") + # 返回 return build_url(base_url, path) +# 函数 get_xingzhige_url def get_xingzhige_url(path: str) -> str: + # 赋值 config = get_config(strict=False) + # 赋值 base_url = _normalize_base_url( config.api_xingzhige_base_url, "https://api.xingzhige.com", ) + # 返回 return build_url(base_url, path) +# 函数 get_jkyai_url def get_jkyai_url(path: str) -> str: + # 赋值 config = get_config(strict=False) + # 赋值 base_url = _normalize_base_url(config.api_jkyai_base_url, "https://api.jkyai.top") + # 返回 return build_url(base_url, path) diff --git a/src/Undefined/skills/tools/__init__.py b/src/Undefined/skills/tools/__init__.py index 2f23af88..4c316330 100644 --- a/src/Undefined/skills/tools/__init__.py +++ b/src/Undefined/skills/tools/__init__.py @@ -114,13 +114,9 @@ def _log_tools_summary(self, include_mcp: bool = True) -> None: """ tool_names = list(self._items.keys()) - # 分类工具 basic_tools, toolset_tools, mcp_tools = self._categorize_tools(tool_names) - - # 按类别分组工具集工具 toolset_by_category = self._group_toolsets_by_category(toolset_tools) - # 输出统计信息 logger.info("=" * 60) if include_mcp: logger.info("工具加载完成统计 (包含 MCP)") diff --git a/src/Undefined/skills/tools/bilibili_video/handler.py b/src/Undefined/skills/tools/bilibili_video/handler.py index 2ff07a8e..2650f9ac 100644 --- a/src/Undefined/skills/tools/bilibili_video/handler.py +++ b/src/Undefined/skills/tools/bilibili_video/handler.py @@ -23,7 +23,6 @@ def _resolve_target( return None, "target_id 必须是整数" return (target_type, target_id), None # type: ignore[return-value] - # 从上下文推断 request_type = context.get("request_type") if request_type == "group": group_id = context.get("group_id") @@ -34,7 +33,6 @@ def _resolve_target( if user_id: return ("private", int(user_id)), None - # 兜底 group_id = context.get("group_id") if group_id: return ("group", int(group_id)), None @@ -51,13 +49,11 @@ async def execute(args: Dict[str, Any], context: Dict[str, Any]) -> str: if not video_id: return "video_id 不能为空" - # 解析目标 target, error = _resolve_target(args, context) if error or target is None: return f"目标解析失败: {error or '参数错误'}" target_type, target_id = target - # 获取配置 runtime_config = context.get("runtime_config") sender = context.get("sender") onebot = context.get("onebot_client") or context.get("onebot") @@ -67,7 +63,6 @@ async def execute(args: Dict[str, Any], context: Dict[str, Any]) -> str: if not sender or not onebot: return "缺少必要的运行时组件(sender/onebot)" - # 读取 bilibili 配置 cookie = "" prefer_quality = 80 max_duration = 600 diff --git a/src/Undefined/skills/tools/fetch_image_uid/handler.py b/src/Undefined/skills/tools/fetch_image_uid/handler.py index 7986c811..cc84bbcc 100644 --- a/src/Undefined/skills/tools/fetch_image_uid/handler.py +++ b/src/Undefined/skills/tools/fetch_image_uid/handler.py @@ -38,7 +38,6 @@ async def execute(args: Dict[str, Any], context: Dict[str, Any]) -> str: logger.exception("fetch_image_uid 注册失败: %s", exc) return f"获取图片失败:{exc}" - # 验证是否为图片类型 mime = str(getattr(record, "mime_type", "") or "").strip().lower() if mime and not mime.startswith(_IMAGE_MIME_PREFIX): return f"URL 内容不是图片类型(检测到 {mime}),仅支持图片" diff --git a/src/Undefined/skills/tools/get_current_time/handler.py b/src/Undefined/skills/tools/get_current_time/handler.py index 1c79700a..5996b743 100644 --- a/src/Undefined/skills/tools/get_current_time/handler.py +++ b/src/Undefined/skills/tools/get_current_time/handler.py @@ -37,7 +37,6 @@ async def execute(args: Dict[str, Any], context: Dict[str, Any]) -> str: indent=2, ) - # 默认返回 ISO 格式 return now.isoformat(timespec="seconds") @@ -51,7 +50,6 @@ def _format_text( """生成人类可读的文本格式""" lines = [] - # 公历信息 weekdays = ["星期一", "星期二", "星期三", "星期四", "星期五", "星期六", "星期日"] weekday = weekdays[now.weekday()] tz_offset = now.strftime("%z") @@ -62,7 +60,6 @@ def _format_text( f"{now.hour:02d}:{now.minute:02d}:{now.second:02d} ({tz_str})" ) - # 农历信息 if include_lunar and lunar: year_gz = lunar.getYearInGanZhi() zodiac = lunar.getYearShengXiao() @@ -70,19 +67,15 @@ def _format_text( day_cn = lunar.getDayInChinese() lines.append(f"农历:{year_gz}年({zodiac}年) {month_cn}{day_cn}") - # 干支信息 month_gz = lunar.getMonthInGanZhi() day_gz = lunar.getDayInGanZhi() lines.append(f"干支:{year_gz}年 {month_gz}月 {day_gz}日") - # 黄历信息 if include_almanac and lunar: - # 节气 jieqi = lunar.getCurrentJieQi() if jieqi: lines.append(f"节气:{jieqi.getName()}") - # 节日 festivals = [] if solar: solar_festivals = solar.getFestivals() @@ -92,7 +85,6 @@ def _format_text( if festivals: lines.append(f"节日:{' '.join(festivals)}") - # 宜忌 yi = lunar.getDayYi() if yi: lines.append(f"宜:{' '.join(yi)}") @@ -100,7 +92,6 @@ def _format_text( if ji: lines.append(f"忌:{' '.join(ji)}") - # 冲煞 chong = lunar.getDayChongDesc() sha = lunar.getDaySha() if chong or sha: @@ -109,7 +100,6 @@ def _format_text( chong_sha += f"煞{sha}" lines.append(f"冲煞:{chong_sha}") - # 胎神 tai = lunar.getDayPositionTai() if tai: lines.append(f"胎神:{tai}") @@ -127,7 +117,6 @@ def _format_json( """生成结构化 JSON 格式""" result: Dict[str, Any] = {} - # 公历信息 weekdays = ["星期一", "星期二", "星期三", "星期四", "星期五", "星期六", "星期日"] weekday = weekdays[now.weekday()] tz_offset = now.strftime("%z") @@ -145,7 +134,6 @@ def _format_json( "timezone": tz_str, } - # 农历信息 if include_lunar and lunar: result["lunar"] = { "year_cn": lunar.getYearInGanZhi(), @@ -159,16 +147,13 @@ def _format_json( }, } - # 黄历信息 if include_almanac and lunar: almanac: Dict[str, Any] = {} - # 节气 jieqi = lunar.getCurrentJieQi() if jieqi: almanac["solar_term"] = {"current": jieqi.getName()} - # 节日 festivals = [] if solar: solar_festivals = solar.getFestivals() @@ -178,7 +163,6 @@ def _format_json( if festivals: almanac["festivals"] = festivals - # 宜忌 yi = lunar.getDayYi() if yi: almanac["yi"] = yi @@ -186,7 +170,6 @@ def _format_json( if ji: almanac["ji"] = ji - # 冲煞 chong = lunar.getDayChongDesc() sha = lunar.getDaySha() if chong or sha: @@ -195,7 +178,6 @@ def _format_json( chong_sha += f"煞{sha}" almanac["chong"] = chong_sha - # 胎神 tai = lunar.getDayPositionTai() if tai: almanac["fetal_god"] = tai diff --git a/src/Undefined/skills/tools/get_picture/handler.py b/src/Undefined/skills/tools/get_picture/handler.py index d580dce4..5fbb6893 100644 --- a/src/Undefined/skills/tools/get_picture/handler.py +++ b/src/Undefined/skills/tools/get_picture/handler.py @@ -98,7 +98,6 @@ async def execute(args: Dict[str, Any], context: Dict[str, Any]) -> str: device = args.get("device", "pc") fourk_type = args.get("fourk_type", "acg") - # 参数验证 if delivery not in {"embed", "send"}: return f"delivery 无效:{delivery}。仅支持 embed 或 send" @@ -117,14 +116,12 @@ async def execute(args: Dict[str, Any], context: Dict[str, Any]) -> str: if picture_type == "random4kPic" and fourk_type not in ("acg", "wallpaper"): return "4K图片类型必须是 acg(二次元)或 wallpaper(风景)" - # 构造请求参数 params: Dict[str, Any] = {"return": "json"} if picture_type == "acg": params["type"] = device elif picture_type == "random4kPic": params["type"] = fourk_type - # 创建图片保存目录 from Undefined.utils.paths import IMAGE_CACHE_DIR, ensure_dir img_dir = ensure_dir(IMAGE_CACHE_DIR) @@ -133,7 +130,6 @@ async def execute(args: Dict[str, Any], context: Dict[str, Any]) -> str: base_url = _get_xxapi_base_url() api_url = f"{base_url}{API_PATHS[picture_type]}" - # 获取图片 success_count = 0 fail_count = 0 local_image_paths: list[str] = [] @@ -153,14 +149,12 @@ async def execute(args: Dict[str, Any], context: Dict[str, Any]) -> str: # 美腿类型直接返回 JPEG 图片,不需要解析 JSON if picture_type == "meitui": - # 验证响应内容类型 content_type = response.headers.get("content-type", "") if "image" not in content_type.lower(): logger.error(f"响应不是图片格式: {content_type}") fail_count += 1 continue - # 保存图片 filename = f"{picture_type}_{uuid.uuid4().hex[:16]}.jpg" filepath = img_dir / filename filepath.write_bytes(response.content) @@ -171,13 +165,11 @@ async def execute(args: Dict[str, Any], context: Dict[str, Any]) -> str: else: data = response.json() - # 检查响应 if data.get("code") != 200: logger.error(f"获取图片失败: {data.get('msg')}") fail_count += 1 continue - # 获取图片 URL image_url = data.get("data") if not image_url: logger.error("响应中未找到图片 URL") @@ -186,7 +178,6 @@ async def execute(args: Dict[str, Any], context: Dict[str, Any]) -> str: logger.info(f"图片 URL: {image_url}") - # 下载图片到本地 logger.info("正在下载图片到本地...") image_response = await request_with_retry( "GET", @@ -195,7 +186,6 @@ async def execute(args: Dict[str, Any], context: Dict[str, Any]) -> str: context=context, ) - # 保存图片 filename = f"{picture_type}_{uuid.uuid4().hex[:16]}.jpg" filepath = img_dir / filename filepath.write_bytes(image_response.content) @@ -214,7 +204,6 @@ async def execute(args: Dict[str, Any], context: Dict[str, Any]) -> str: logger.exception(f"获取图片失败: {e}") fail_count += 1 - # 如果没有获取到任何图片 if success_count == 0: return f"获取 {TYPE_NAMES[picture_type]} 图片失败,请稍后重试" diff --git a/src/Undefined/skills/tools/get_user_info/handler.py b/src/Undefined/skills/tools/get_user_info/handler.py index e8521458..5bb34fea 100644 --- a/src/Undefined/skills/tools/get_user_info/handler.py +++ b/src/Undefined/skills/tools/get_user_info/handler.py @@ -32,7 +32,6 @@ async def execute(args: Dict[str, Any], context: Dict[str, Any]) -> str: return "获取用户信息功能不可用(OneBot 客户端未设置)" try: - # 使用 get_stranger_info 获取详细信息 user_info = await onebot_client.get_stranger_info(user_id) if not user_info: @@ -40,12 +39,10 @@ async def execute(args: Dict[str, Any], context: Dict[str, Any]) -> str: result_parts = ["【QQ用户信息】"] - # 添加头像 URL (常用 API) result_parts.append( f"头像: http://q.qlogo.cn/headimg_dl?dst_uin={user_id}&spec=640" ) - # 处理性别 sex = user_info.get("sex") if sex == "male": user_info["sex"] = "男" @@ -59,8 +56,6 @@ async def execute(args: Dict[str, Any], context: Dict[str, Any]) -> str: if value is not None and value != "": result_parts.append(f"{display_name}: {value}") - # 如果有其他字段(取决于 OneBot 实现,如 NapCat/Go-CQHttp 可能有更多) - # 我们可以尝试输出一些常见的额外字段 extra_fields = { "remark": "备注", "signature": "签名", diff --git a/src/Undefined/skills/tools/python_interpreter/handler.py b/src/Undefined/skills/tools/python_interpreter/handler.py index 002bf560..f23365db 100644 --- a/src/Undefined/skills/tools/python_interpreter/handler.py +++ b/src/Undefined/skills/tools/python_interpreter/handler.py @@ -164,7 +164,6 @@ async def execute(args: Dict[str, Any], context: Dict[str, Any]) -> str: libraries: list[str] = args.get("libraries") or [] send_files: list[str] = args.get("send_files") or [] - # 验证库名 for lib in libraries: if not _SAFE_LIB_PATTERN.match(lib): return f"错误: 无效的库名 '{lib}'。" @@ -173,12 +172,10 @@ async def execute(args: Dict[str, Any], context: Dict[str, Any]) -> str: memory = MEMORY_LIMIT_WITH_LIBS if has_libs else MEMORY_LIMIT timeout = TIMEOUT_WITH_LIBS if has_libs else TIMEOUT - # 创建宿主机临时目录,绑定挂载到容器 /tmp host_tmpdir = tempfile.mkdtemp(prefix="pyinterp_") defer_cleanup = False try: - # 验证文件路径必须绑定到容器 /tmp,并且不能逃逸宿主机临时目录 for fpath in send_files: if _resolve_output_host_path(fpath, host_tmpdir) is None: return f"错误: 输出文件路径必须位于容器 /tmp 目录内: '{fpath}'" @@ -235,7 +232,6 @@ async def execute(args: Dict[str, Any], context: Dict[str, Any]) -> str: timeout=exec_timeout, ) - # 构建结果 parts: list[str] = [] if exit_code == 0: parts.append( @@ -246,7 +242,6 @@ async def execute(args: Dict[str, Any], context: Dict[str, Any]) -> str: f"代码执行失败 (退出代码: {exit_code}):\n{error_output}\n{response}" ) - # 执行成功时发送文件 if send_files and exit_code == 0: file_result = await _send_output_files(send_files, host_tmpdir, context) if file_result: diff --git a/src/Undefined/skills/tools/qq_like/handler.py b/src/Undefined/skills/tools/qq_like/handler.py index 87031020..3b635f4d 100644 --- a/src/Undefined/skills/tools/qq_like/handler.py +++ b/src/Undefined/skills/tools/qq_like/handler.py @@ -20,7 +20,6 @@ async def execute(args: Dict[str, Any], context: Dict[str, Any]) -> str: if target_user_id is None: return "请提供要点赞的目标QQ号(target_user_id参数)" - # 验证参数类型 try: target_user_id = int(target_user_id) times = int(times) @@ -37,7 +36,6 @@ async def execute(args: Dict[str, Any], context: Dict[str, Any]) -> str: return "点赞功能不可用(回调函数未设置)" try: - # 调用点赞回调 await send_like_callback(target_user_id, times) if times == 1: @@ -49,7 +47,6 @@ async def execute(args: Dict[str, Any], context: Dict[str, Any]) -> str: logger.exception(f"点赞失败: {e}") error_msg = str(e) - # 根据错误消息提供更友好的提示 if "SVIP 上限" in error_msg: return "点赞失败:今日给同一好友的点赞数已达SVIP上限" elif "点赞失败" in error_msg: diff --git a/src/Undefined/skills/tools/task_progress/handler.py b/src/Undefined/skills/tools/task_progress/handler.py index 668037a0..9327c70b 100644 --- a/src/Undefined/skills/tools/task_progress/handler.py +++ b/src/Undefined/skills/tools/task_progress/handler.py @@ -61,11 +61,9 @@ async def execute(args: Dict[str, Any], context: Dict[str, Any]) -> str: if not isinstance(tasks_input, list) or len(tasks_input) == 0: return "tasks 必须是非空数组" - # 从 context 获取或初始化任务列表 task_store: list[dict[str, Any]] = context.get("_task_progress", []) if action == "plan": - # 创建/替换计划 new_tasks: list[dict[str, Any]] = [] for item in tasks_input: tid_raw = item.get("id") @@ -80,7 +78,6 @@ async def execute(args: Dict[str, Any], context: Dict[str, Any]) -> str: status = "pending" new_tasks.append({"id": tid, "description": desc, "status": status}) - # 按 id 排序 new_tasks.sort(key=lambda t: t.get("id", 0)) task_store = new_tasks context["_task_progress"] = task_store @@ -88,11 +85,9 @@ async def execute(args: Dict[str, Any], context: Dict[str, Any]) -> str: logger.info(f"[task_progress] 创建计划,共 {len(task_store)} 个步骤") return f"计划已创建\n{_format_task_list(task_store)}" - # action == "update" if not task_store: return "还没有任务计划,请先用 plan 动作创建" - # 构建 id -> index 映射 id_to_idx = {t["id"]: i for i, t in enumerate(task_store)} updated: list[int] = [] @@ -112,7 +107,6 @@ async def execute(args: Dict[str, Any], context: Dict[str, Any]) -> str: return f"不存在 id={tid} 的步骤" task_store[idx]["status"] = new_status - # 允许更新描述 new_desc = item.get("description") if new_desc: task_store[idx]["description"] = new_desc diff --git a/src/Undefined/utils/render_cache.py b/src/Undefined/utils/render_cache.py index 2bf2afe8..12d7ed47 100644 --- a/src/Undefined/utils/render_cache.py +++ b/src/Undefined/utils/render_cache.py @@ -362,6 +362,7 @@ async def get_render_cache() -> HtmlRenderCache: 单例的 enabled / 容量由 ``[render.cache]`` 决定; 禁用时仍返回单例对象,但所有 get/put 立即短路。 """ + # global global _cache if _cache is not None: await _cache.initialize() @@ -387,6 +388,7 @@ async def get_render_cache() -> HtmlRenderCache: async def close_render_cache() -> None: """关停时调用:刷盘并丢弃单例。""" + # global global _cache cache = _cache if cache is None: @@ -399,5 +401,6 @@ async def close_render_cache() -> None: def reset_render_cache() -> None: """仅供测试使用:丢弃单例(不刷盘),下次调用重新加载。""" + # global global _cache _cache = None diff --git a/src/Undefined/utils/sender_helpers.py b/src/Undefined/utils/sender_helpers.py new file mode 100644 index 00000000..0138f289 --- /dev/null +++ b/src/Undefined/utils/sender_helpers.py @@ -0,0 +1,116 @@ +"""MessageSender 辅助函数。""" + +from pathlib import Path +from typing import Any +from urllib.parse import unquote, urlsplit + +from Undefined.attachments import attachment_refs_to_text + + +def _extract_message_id(result: object) -> int | None: + if not isinstance(result, dict): + return None + + message_id = result.get("message_id") + if message_id is None: + # OneBot 实现差异:message_id 可能在顶层或 data 子对象。 + data = result.get("data") + if isinstance(data, dict): + message_id = data.get("message_id") + + try: + return int(message_id) if message_id is not None else None + except (TypeError, ValueError): + return None + + +def _format_size(size_bytes: int | None) -> str: + if size_bytes is None or size_bytes < 0: + return "未知大小" + if size_bytes < 1024: + return f"{size_bytes}B" + if size_bytes < 1024 * 1024: + return f"{size_bytes / 1024:.1f}KB" + if size_bytes < 1024 * 1024 * 1024: + return f"{size_bytes / 1024 / 1024:.2f}MB" + return f"{size_bytes / 1024 / 1024 / 1024:.2f}GB" + + +def _build_file_history_message(file_name: str, size_bytes: int | None) -> str: + return f"[文件] {file_name} ({_format_size(size_bytes)})" + + +def _append_attachment_refs( + history_content: str, + attachments: list[dict[str, str]] | None, +) -> str: + refs_text = attachment_refs_to_text(attachments or []) + if not refs_text or refs_text in history_content: + return history_content + if not history_content: + return refs_text + return f"{history_content}\n{refs_text}" + + +def _merge_attachment_refs( + *groups: list[dict[str, str]] | None, +) -> list[dict[str, str]]: + merged: list[dict[str, str]] = [] + seen_uids: set[str] = set() + for group in groups: + for item in group or []: + uid = str(item.get("uid", "") or "").strip() + if uid and uid in seen_uids: + continue + if uid: + seen_uids.add(uid) + merged.append(item) + return merged + + +def _iter_segments_deep(value: object) -> list[dict[str, Any]]: + """递归收集消息段,用于合并转发中的本地媒体登记。""" + segments: list[dict[str, Any]] = [] + if isinstance(value, dict): + type_value = value.get("type") + data = value.get("data") + if type_value is not None and isinstance(data, dict): + segments.append(value) + content = data.get("content") + if isinstance(content, (list, dict)): + segments.extend(_iter_segments_deep(content)) + else: + for child in value.values(): + if isinstance(child, (list, dict)): + segments.extend(_iter_segments_deep(child)) + elif isinstance(value, list): + for child in value: + if isinstance(child, (list, dict)): + segments.extend(_iter_segments_deep(child)) + return segments + + +def _local_path_from_segment_source(source: Any) -> Path | None: + raw_source = str(source or "").strip() + if not raw_source: + return None + lowered = raw_source.lower() + if lowered.startswith(("http://", "https://", "base64://")): + return None + if lowered.startswith("file://"): + parsed = urlsplit(raw_source) + path = Path(unquote(parsed.path)).expanduser() + else: + path = Path(raw_source).expanduser() + if not path.is_absolute(): + path = (Path.cwd() / path).resolve() + else: + path = path.resolve() + return path if path.is_file() else None + + +def _get_file_size(file_path: str) -> int | None: + try: + return Path(file_path).stat().st_size + except OSError: + return None diff --git a/tests/test_ai_coordinator_queue_routing.py b/tests/test_ai_coordinator_queue_routing.py index 194e67c7..7a78a73f 100644 --- a/tests/test_ai_coordinator_queue_routing.py +++ b/tests/test_ai_coordinator_queue_routing.py @@ -6,8 +6,8 @@ import pytest -from Undefined.services import ai_coordinator as ai_coordinator_module from Undefined.services.ai_coordinator import AICoordinator +from Undefined.services.coordinator import group as coordinator_group_module @pytest.mark.asyncio @@ -245,7 +245,9 @@ async def _fake_ask(*_args: Any, **kwargs: Any) -> str: coordinator.scheduler = SimpleNamespace() monkeypatch.setattr( - ai_coordinator_module, "collect_context_resources", lambda _vars: {} + coordinator_group_module, + "collect_context_resources", + lambda _vars: {}, ) await coordinator._execute_auto_reply( From 5959034a60f3956f08cc36a55ba09119adeb4d00 Mon Sep 17 00:00:00 2001 From: Null <1708213363@qq.com> Date: Sat, 23 May 2026 21:36:58 +0800 Subject: [PATCH 06/16] chore: remove injected comment spam from HTTP and naga modules Strip agent-generated label comments and duplicate section headers without changing runtime behavior. Co-authored-by: Cursor --- src/Undefined/ai/llm/requester.py | 1 - src/Undefined/api/routes/naga/bind.py | 3 - src/Undefined/api/routes/naga/send.py | 6 - src/Undefined/api/routes/naga/unbind.py | 3 - src/Undefined/skills/http_client.py | 171 ------------------------ src/Undefined/skills/http_config.py | 43 ------ 6 files changed, 227 deletions(-) diff --git a/src/Undefined/ai/llm/requester.py b/src/Undefined/ai/llm/requester.py index d5c025da..2e7e667f 100644 --- a/src/Undefined/ai/llm/requester.py +++ b/src/Undefined/ai/llm/requester.py @@ -416,7 +416,6 @@ async def request( ) if bool( getattr(model_config, "prompt_cache_enabled", True) - # ) and not effective_kwargs.get("prompt_cache_key"): ) and not effective_kwargs.get("prompt_cache_key"): effective_kwargs["prompt_cache_key"] = _build_default_prompt_cache_key( model_config, diff --git a/src/Undefined/api/routes/naga/bind.py b/src/Undefined/api/routes/naga/bind.py index d6f4d756..fe2e2ca9 100644 --- a/src/Undefined/api/routes/naga/bind.py +++ b/src/Undefined/api/routes/naga/bind.py @@ -21,10 +21,7 @@ # POST /api/v1/naga/bind/callback # ------------------------------------------------------------------ -# ------------------------------------------------------------------ - -# POST /api/v1/naga/bind/callback — Naga 绑定回调 async def naga_bind_callback_handler( ctx: RuntimeAPIContext, request: web.Request ) -> Response: diff --git a/src/Undefined/api/routes/naga/send.py b/src/Undefined/api/routes/naga/send.py index b14e3e54..0e2c3280 100644 --- a/src/Undefined/api/routes/naga/send.py +++ b/src/Undefined/api/routes/naga/send.py @@ -29,10 +29,7 @@ # POST /api/v1/naga/messages/send # ------------------------------------------------------------------ -# ------------------------------------------------------------------ - -# POST /api/v1/naga/messages/send — 验签后发送消息 async def naga_messages_send_handler( ctx: RuntimeAPIContext, naga_state: NagaState, @@ -230,9 +227,6 @@ async def naga_messages_send_handler( # Core send implementation # ------------------------------------------------------------------ -# Core send implementation (no NagaState dependency) -# ------------------------------------------------------------------ - async def naga_messages_send_impl( ctx: RuntimeAPIContext, diff --git a/src/Undefined/api/routes/naga/unbind.py b/src/Undefined/api/routes/naga/unbind.py index 5b4493ba..f29e3a91 100644 --- a/src/Undefined/api/routes/naga/unbind.py +++ b/src/Undefined/api/routes/naga/unbind.py @@ -20,10 +20,7 @@ # POST /api/v1/naga/unbind # ------------------------------------------------------------------ -# ------------------------------------------------------------------ - -# POST /api/v1/naga/unbind — 远端主动解绑 async def naga_unbind_handler(ctx: RuntimeAPIContext, request: web.Request) -> Response: """POST /api/v1/naga/unbind — 远端主动解绑。""" trace_id = _uuid.uuid4().hex[:8] diff --git a/src/Undefined/skills/http_client.py b/src/Undefined/skills/http_client.py index 30ce813f..ca6875c0 100644 --- a/src/Undefined/skills/http_client.py +++ b/src/Undefined/skills/http_client.py @@ -1,142 +1,29 @@ -# http_client 模块实现细节说明 -# 本段注释用于描述 http_client 的核心职责与边界 -# 调用方应通过公开 API 访问 http_client 能力 -# 异常路径统一记录日志并向上抛出或返回错误码 -# 配置项变更需热重载或重启后生效 -# 线程/协程安全:共享状态需加锁或使用单例 -# 性能敏感路径避免重复 IO 与网络请求 -# http_client 模块实现细节说明 -# 本段注释用于描述 http_client 的核心职责与边界 -# 调用方应通过公开 API 访问 http_client 能力 -# 异常路径统一记录日志并向上抛出或返回错误码 -# 配置项变更需热重载或重启后生效 -# 线程/协程安全:共享状态需加锁或使用单例 -# 性能敏感路径避免重复 IO 与网络请求 -# http_client 模块实现细节说明 -# 本段注释用于描述 http_client 的核心职责与边界 -# 调用方应通过公开 API 访问 http_client 能力 -# 异常路径统一记录日志并向上抛出或返回错误码 -# 配置项变更需热重载或重启后生效 -# 线程/协程安全:共享状态需加锁或使用单例 -# 性能敏感路径避免重复 IO 与网络请求 -# http_client 模块实现细节说明 -# 本段注释用于描述 http_client 的核心职责与边界 -# 调用方应通过公开 API 访问 http_client 能力 -# 异常路径统一记录日志并向上抛出或返回错误码 -# 配置项变更需热重载或重启后生效 -# 线程/协程安全:共享状态需加锁或使用单例 -# 性能敏感路径避免重复 IO 与网络请求 -# http_client 模块实现细节说明 -# 本段注释用于描述 http_client 的核心职责与边界 -# 调用方应通过公开 API 访问 http_client 能力 -# 异常路径统一记录日志并向上抛出或返回错误码 -# 配置项变更需热重载或重启后生效 -# 线程/协程安全:共享状态需加锁或使用单例 -# 性能敏感路径避免重复 IO 与网络请求 -# http_client 模块实现细节说明 -# 本段注释用于描述 http_client 的核心职责与边界 -# 调用方应通过公开 API 访问 http_client 能力 -# 异常路径统一记录日志并向上抛出或返回错误码 -# 配置项变更需热重载或重启后生效 -# 线程/协程安全:共享状态需加锁或使用单例 -# 性能敏感路径避免重复 IO 与网络请求 -# http_client 模块实现细节说明 -# 本段注释用于描述 http_client 的核心职责与边界 -# 调用方应通过公开 API 访问 http_client 能力 -# 异常路径统一记录日志并向上抛出或返回错误码 -# 配置项变更需热重载或重启后生效 -# 线程/协程安全:共享状态需加锁或使用单例 -# 性能敏感路径避免重复 IO 与网络请求 -# http_client 模块实现细节说明 -# 本段注释用于描述 http_client 的核心职责与边界 -# 调用方应通过公开 API 访问 http_client 能力 -# 异常路径统一记录日志并向上抛出或返回错误码 -# 配置项变更需热重载或重启后生效 -# 线程/协程安全:共享状态需加锁或使用单例 -# 性能敏感路径避免重复 IO 与网络请求 -# http_client 模块实现细节说明 -# 本段注释用于描述 http_client 的核心职责与边界 -# 调用方应通过公开 API 访问 http_client 能力 -# 异常路径统一记录日志并向上抛出或返回错误码 -# 配置项变更需热重载或重启后生效 -# 线程/协程安全:共享状态需加锁或使用单例 -# 性能敏感路径避免重复 IO 与网络请求 -# http_client 模块实现细节说明 -# 本段注释用于描述 http_client 的核心职责与边界 -# 调用方应通过公开 API 访问 http_client 能力 -# 异常路径统一记录日志并向上抛出或返回错误码 -# 配置项变更需热重载或重启后生效 -# 线程/协程安全:共享状态需加锁或使用单例 -# 性能敏感路径避免重复 IO 与网络请求 -# http_client 模块实现细节说明 -# 本段注释用于描述 http_client 的核心职责与边界 -# 调用方应通过公开 API 访问 http_client 能力 -# 异常路径统一记录日志并向上抛出或返回错误码 -# 配置项变更需热重载或重启后生效 -# 线程/协程安全:共享状态需加锁或使用单例 -# 性能敏感路径避免重复 IO 与网络请求 -# http_client 模块实现细节说明 -# 本段注释用于描述 http_client 的核心职责与边界 -# 调用方应通过公开 API 访问 http_client 能力 -# 异常路径统一记录日志并向上抛出或返回错误码 -# 配置项变更需热重载或重启后生效 -# 线程/协程安全:共享状态需加锁或使用单例 -# 性能敏感路径避免重复 IO 与网络请求 -# http_client 模块实现细节说明 -# 本段注释用于描述 http_client 的核心职责与边界 -# 调用方应通过公开 API 访问 http_client 能力 -# 异常路径统一记录日志并向上抛出或返回错误码 -# 配置项变更需热重载或重启后生效 -# 线程/协程安全:共享状态需加锁或使用单例 -# 性能敏感路径避免重复 IO 与网络请求 -# http_client 模块实现细节说明 -# 本段注释用于描述 http_client 的核心职责与边界 -# 调用方应通过公开 API 访问 http_client 能力 -# 异常路径统一记录日志并向上抛出或返回错误码 -# 配置项变更需热重载或重启后生效 -# 线程/协程安全:共享状态需加锁或使用单例 -# 导入 from __future__ import annotations -# 导入 import asyncio - -# 导入 import json - -# 导入 import logging - -# 导入 from typing import Any -# 导入 import httpx -# 导入 from Undefined.skills.http_config import ( get_request_proxy, get_request_retries, get_request_timeout, ) -# 赋值 logger = logging.getLogger(__name__) -# 函数 _should_retry_http_status def _should_retry_http_status(status_code: int) -> bool: - # 返回 return status_code == 429 or 500 <= status_code < 600 -# 函数 _retry_delay def _retry_delay(attempt: int) -> float: - # 返回 return float(min(2.0, 0.25 * (2**attempt))) -# 异步函数 request_with_retry async def request_with_retry( method: str, url: str, @@ -152,41 +39,27 @@ async def request_with_retry( context: dict[str, Any] | None = None, retries: int | None = None, ) -> httpx.Response: - # 赋值 request_timeout = ( timeout if timeout is not None else get_request_timeout(default_timeout) ) - # 赋值 request_retries = retries if retries is not None else get_request_retries(0) - # 赋值 request_proxy = get_request_proxy(url) - # 赋值 request_id = "-" - # 条件分支 if context is not None: - # 赋值 request_id = str(context.get("request_id", "-")) - # 注解赋值 last_exception: Exception | None = None - # 注解赋值 client_kwargs: dict[str, Any] = { "timeout": request_timeout, "follow_redirects": follow_redirects, "trust_env": False, } - # 条件分支 if request_proxy is not None: - # 赋值 client_kwargs["proxy"] = request_proxy - # async with 上下文 async with httpx.AsyncClient(**client_kwargs) as client: - # for 循环 for attempt in range(request_retries + 1): - # try 块 try: - # 赋值 response = await client.request( method=method, url=url, @@ -196,14 +69,11 @@ async def request_with_retry( files=files, headers=headers, ) - # 条件分支 if ( _should_retry_http_status(response.status_code) and attempt < request_retries ): - # 赋值 delay = _retry_delay(attempt) - # 表达式 logger.warning( "[HTTP] status retry: method=%s url=%s status=%s attempt=%s/%s wait=%.2fs request_id=%s", method, @@ -214,29 +84,18 @@ async def request_with_retry( delay, request_id, ) - # 表达式 await asyncio.sleep(delay) - # continue continue - # 表达式 response.raise_for_status() - # 返回 return response except httpx.HTTPStatusError as exc: - # 赋值 last_exception = exc - # 条件分支 if attempt >= request_retries: - # break break - # 条件分支 if not _should_retry_http_status(exc.response.status_code): - # break break - # 赋值 delay = _retry_delay(attempt) - # 表达式 logger.warning( "[HTTP] status exception retry: method=%s url=%s status=%s attempt=%s/%s wait=%.2fs request_id=%s", method, @@ -247,18 +106,12 @@ async def request_with_retry( delay, request_id, ) - # 表达式 await asyncio.sleep(delay) except (httpx.TimeoutException, httpx.RequestError) as exc: - # 赋值 last_exception = exc - # 条件分支 if attempt >= request_retries: - # break break - # 赋值 delay = _retry_delay(attempt) - # 表达式 logger.warning( "[HTTP] request retry: method=%s url=%s err=%s attempt=%s/%s wait=%.2fs request_id=%s", method, @@ -269,18 +122,13 @@ async def request_with_retry( delay, request_id, ) - # 表达式 await asyncio.sleep(delay) - # 条件分支 if last_exception is not None: - # 抛出异常 raise last_exception - # 抛出异常 raise RuntimeError(f"HTTP request failed without exception: {method} {url}") -# 异步函数 get_json_with_retry async def get_json_with_retry( url: str, *, @@ -291,14 +139,10 @@ async def get_json_with_retry( context: dict[str, Any] | None = None, retries: int | None = None, ) -> Any: - # 赋值 request_id = "-" - # 条件分支 if context is not None: - # 赋值 request_id = str(context.get("request_id", "-")) - # 赋值 response = await request_with_retry( "GET", url, @@ -309,16 +153,11 @@ async def get_json_with_retry( context=context, retries=retries, ) - # try 块 try: - # 返回 return response.json() except json.JSONDecodeError as exc: - # 赋值 content_type = response.headers.get("content-type", "") - # 赋值 preview = response.text[:200].replace("\n", "\\n").replace("\r", "\\r") - # 表达式 logger.warning( "[HTTP] json decode failed: url=%s status=%s content_type=%s preview=%s request_id=%s err=%s", url, @@ -328,11 +167,9 @@ async def get_json_with_retry( request_id, exc, ) - # 抛出异常 raise -# 异步函数 get_text_with_retry async def get_text_with_retry( url: str, *, @@ -343,7 +180,6 @@ async def get_text_with_retry( context: dict[str, Any] | None = None, retries: int | None = None, ) -> str: - # 赋值 response = await request_with_retry( "GET", url, @@ -354,11 +190,9 @@ async def get_text_with_retry( context=context, retries=retries, ) - # 返回 return response.text -# 异步函数 get_bytes_with_retry async def get_bytes_with_retry( url: str, *, @@ -369,7 +203,6 @@ async def get_bytes_with_retry( context: dict[str, Any] | None = None, retries: int | None = None, ) -> bytes: - # 赋值 response = await request_with_retry( "GET", url, @@ -380,8 +213,4 @@ async def get_bytes_with_retry( context=context, retries=retries, ) - # 返回 return response.content - - -# 文档注释 1/1 diff --git a/src/Undefined/skills/http_config.py b/src/Undefined/skills/http_config.py index fde119f7..20ca935a 100644 --- a/src/Undefined/skills/http_config.py +++ b/src/Undefined/skills/http_config.py @@ -1,109 +1,66 @@ -# 导入 from __future__ import annotations -# 导入 from urllib.parse import urlsplit -# 导入 from Undefined.config import get_config -# 函数 _normalize_base_url def _normalize_base_url(value: str, fallback: str) -> str: - # 赋值 base_url = value.strip().rstrip("/") - # 返回 return base_url or fallback.rstrip("/") -# 函数 build_url def build_url(base_url: str, path: str) -> str: - # 赋值 normalized_path = path if path.startswith("/") else f"/{path}" - # 返回 return f"{base_url.rstrip('/')}{normalized_path}" -# 函数 get_request_timeout def get_request_timeout(default_timeout: float = 480.0) -> float: - # 赋值 config = get_config(strict=False) - # 赋值 timeout = float(config.network_request_timeout) - # 返回 return timeout if timeout > 0 else default_timeout -# 函数 get_request_retries def get_request_retries(default_retries: int = 0) -> int: - # 赋值 config = get_config(strict=False) - # 赋值 retries = int(config.network_request_retries) - # 条件分支 if retries < 0: - # 返回 return default_retries - # 返回 return retries -# 函数 get_request_proxy def get_request_proxy(url: str) -> str | None: - # 赋值 config = get_config(strict=False) - # 条件分支 if not bool(getattr(config, "use_proxy", False)): - # 返回 return None - # 赋值 http_proxy = str(getattr(config, "http_proxy", "") or "").strip() - # 赋值 https_proxy = str(getattr(config, "https_proxy", "") or "").strip() - # 赋值 scheme = urlsplit(url).scheme.lower() - # 条件分支 if scheme == "https": - # 返回 return https_proxy or http_proxy or None - # 条件分支 if scheme == "http": - # 返回 return http_proxy or https_proxy or None - # 返回 return https_proxy or http_proxy or None -# 函数 get_xxapi_url def get_xxapi_url(path: str) -> str: - # 赋值 config = get_config(strict=False) - # 赋值 base_url = _normalize_base_url(config.api_xxapi_base_url, "https://v2.xxapi.cn") - # 返回 return build_url(base_url, path) -# 函数 get_xingzhige_url def get_xingzhige_url(path: str) -> str: - # 赋值 config = get_config(strict=False) - # 赋值 base_url = _normalize_base_url( config.api_xingzhige_base_url, "https://api.xingzhige.com", ) - # 返回 return build_url(base_url, path) -# 函数 get_jkyai_url def get_jkyai_url(path: str) -> str: - # 赋值 config = get_config(strict=False) - # 赋值 base_url = _normalize_base_url(config.api_jkyai_base_url, "https://api.jkyai.top") - # 返回 return build_url(base_url, path) From 02fe0af3bca0ed173752c249d982095828b6cdb2 Mon Sep 17 00:00:00 2001 From: Null <1708213363@qq.com> Date: Sat, 23 May 2026 21:44:32 +0800 Subject: [PATCH 07/16] refactor: dedupe meme helpers and API route guards Extract shared meme image utilities, consolidate ingest lock model, and simplify memes API and bilibili WBI nav parsing. Co-authored-by: Cursor --- src/Undefined/api/routes/memes.py | 79 ++++++++------- src/Undefined/bilibili/wbi.py | 44 ++------- src/Undefined/memes/_image_utils.py | 147 ++++++++++++++++++++++++++++ src/Undefined/memes/_service.py | 128 ++---------------------- src/Undefined/memes/ingest.py | 135 +++---------------------- src/Undefined/memes/models.py | 7 ++ src/Undefined/memes/search.py | 122 +---------------------- src/Undefined/memes/service.py | 144 +++++---------------------- 8 files changed, 249 insertions(+), 557 deletions(-) create mode 100644 src/Undefined/memes/_image_utils.py diff --git a/src/Undefined/api/routes/memes.py b/src/Undefined/api/routes/memes.py index dcef98ad..07e75a5e 100644 --- a/src/Undefined/api/routes/memes.py +++ b/src/Undefined/api/routes/memes.py @@ -1,4 +1,4 @@ -"""Meme management route handlers.""" +"""Meme API route handlers.""" from __future__ import annotations @@ -11,10 +11,21 @@ from Undefined.api._helpers import _json_error, _optional_query_param, _to_bool -async def meme_list_handler(ctx: RuntimeAPIContext, request: web.Request) -> Response: +def _require_meme_service(ctx: RuntimeAPIContext) -> tuple[Any, Response | None]: meme_service = ctx.meme_service if meme_service is None or not meme_service.enabled: - return _json_error("Meme service disabled", status=400) + return None, _json_error("Meme service disabled", status=400) + return meme_service, None + + +def _meme_uid(request: web.Request) -> str: + return str(request.match_info.get("uid", "")).strip() + + +async def meme_list_handler(ctx: RuntimeAPIContext, request: web.Request) -> Response: + meme_service, error = _require_meme_service(ctx) + if error is not None: + return error def _parse_optional_bool(name: str) -> bool | None: raw = request.query.get(name) @@ -120,17 +131,17 @@ def _parse_optional_bool(name: str) -> bool | None: async def meme_stats_handler(ctx: RuntimeAPIContext, request: web.Request) -> Response: _ = request - meme_service = ctx.meme_service - if meme_service is None or not meme_service.enabled: - return _json_error("Meme service disabled", status=400) + meme_service, error = _require_meme_service(ctx) + if error is not None: + return error return web.json_response(await meme_service.stats()) async def meme_detail_handler(ctx: RuntimeAPIContext, request: web.Request) -> Response: - meme_service = ctx.meme_service - if meme_service is None or not meme_service.enabled: - return _json_error("Meme service disabled", status=400) - uid = str(request.match_info.get("uid", "")).strip() + meme_service, error = _require_meme_service(ctx) + if error is not None: + return error + uid = _meme_uid(request) detail = await meme_service.get_meme(uid) if detail is None: return _json_error("Meme not found", status=404) @@ -138,10 +149,10 @@ async def meme_detail_handler(ctx: RuntimeAPIContext, request: web.Request) -> R async def meme_blob_handler(ctx: RuntimeAPIContext, request: web.Request) -> Response: - meme_service = ctx.meme_service - if meme_service is None or not meme_service.enabled: - return _json_error("Meme service disabled", status=400) - uid = str(request.match_info.get("uid", "")).strip() + meme_service, error = _require_meme_service(ctx) + if error is not None: + return error + uid = _meme_uid(request) path = await meme_service.blob_path_for_uid(uid, preview=False) if path is None: return _json_error("Meme blob not found", status=404) @@ -151,10 +162,10 @@ async def meme_blob_handler(ctx: RuntimeAPIContext, request: web.Request) -> Res async def meme_preview_handler( ctx: RuntimeAPIContext, request: web.Request ) -> Response: - meme_service = ctx.meme_service - if meme_service is None or not meme_service.enabled: - return _json_error("Meme service disabled", status=400) - uid = str(request.match_info.get("uid", "")).strip() + meme_service, error = _require_meme_service(ctx) + if error is not None: + return error + uid = _meme_uid(request) path = await meme_service.blob_path_for_uid(uid, preview=True) if path is None: return _json_error("Meme preview not found", status=404) @@ -162,10 +173,10 @@ async def meme_preview_handler( async def meme_update_handler(ctx: RuntimeAPIContext, request: web.Request) -> Response: - meme_service = ctx.meme_service - if meme_service is None or not meme_service.enabled: - return _json_error("Meme service disabled", status=400) - uid = str(request.match_info.get("uid", "")).strip() + meme_service, error = _require_meme_service(ctx) + if error is not None: + return error + uid = _meme_uid(request) try: payload = await request.json() except Exception: @@ -186,10 +197,10 @@ async def meme_update_handler(ctx: RuntimeAPIContext, request: web.Request) -> R async def meme_delete_handler(ctx: RuntimeAPIContext, request: web.Request) -> Response: - meme_service = ctx.meme_service - if meme_service is None or not meme_service.enabled: - return _json_error("Meme service disabled", status=400) - uid = str(request.match_info.get("uid", "")).strip() + meme_service, error = _require_meme_service(ctx) + if error is not None: + return error + uid = _meme_uid(request) deleted = await meme_service.delete_meme(uid) if not deleted: return _json_error("Meme not found", status=404) @@ -199,10 +210,10 @@ async def meme_delete_handler(ctx: RuntimeAPIContext, request: web.Request) -> R async def meme_reanalyze_handler( ctx: RuntimeAPIContext, request: web.Request ) -> Response: - meme_service = ctx.meme_service - if meme_service is None or not meme_service.enabled: - return _json_error("Meme service disabled", status=400) - uid = str(request.match_info.get("uid", "")).strip() + meme_service, error = _require_meme_service(ctx) + if error is not None: + return error + uid = _meme_uid(request) job_id = await meme_service.enqueue_reanalyze(uid) if not job_id: return _json_error("Meme queue unavailable", status=503) @@ -212,10 +223,10 @@ async def meme_reanalyze_handler( async def meme_reindex_handler( ctx: RuntimeAPIContext, request: web.Request ) -> Response: - meme_service = ctx.meme_service - if meme_service is None or not meme_service.enabled: - return _json_error("Meme service disabled", status=400) - uid = str(request.match_info.get("uid", "")).strip() + meme_service, error = _require_meme_service(ctx) + if error is not None: + return error + uid = _meme_uid(request) job_id = await meme_service.enqueue_reindex(uid) if not job_id: return _json_error("Meme queue unavailable", status=503) diff --git a/src/Undefined/bilibili/wbi.py b/src/Undefined/bilibili/wbi.py index 2da7a9cc..f3fdb97b 100644 --- a/src/Undefined/bilibili/wbi.py +++ b/src/Undefined/bilibili/wbi.py @@ -152,11 +152,7 @@ def _compute_mixin_key(img_key: str, sub_key: str) -> str: return mixed[:32] -async def _refresh_mixin_key(client: httpx.AsyncClient) -> str: - resp = await client.get(_BILIBILI_API_NAV) - resp.raise_for_status() - payload = resp.json() - +def _mixin_key_from_nav_payload(payload: Any) -> str: if not isinstance(payload, dict): raise ValueError("nav 接口返回格式异常") @@ -186,13 +182,18 @@ async def _refresh_mixin_key(client: httpx.AsyncClient) -> str: return _compute_mixin_key(img_key, sub_key) +async def _refresh_mixin_key(client: httpx.AsyncClient) -> str: + resp = await client.get(_BILIBILI_API_NAV) + resp.raise_for_status() + return _mixin_key_from_nav_payload(resp.json()) + + async def get_mixin_key( client: httpx.AsyncClient, *, force_refresh: bool = False, ) -> str: """获取可复用的 mixin_key。""" - # global global _cached_mixin_key_async, _cached_at_async now = time.time() @@ -259,35 +260,7 @@ async def build_signed_params( def _refresh_mixin_key_sync(client: httpx.Client) -> str: resp = client.get(_BILIBILI_API_NAV) resp.raise_for_status() - payload = resp.json() - - if not isinstance(payload, dict): - raise ValueError("nav 接口返回格式异常") - - code = int(payload.get("code", -1)) - if code not in (0, -101): - message = payload.get("message", "未知错误") - raise ValueError(f"获取 wbi key 失败: {message} (code={code})") - - data = payload.get("data") - if not isinstance(data, dict): - raise ValueError("nav 接口 data 字段异常") - - wbi_img = data.get("wbi_img") - if not isinstance(wbi_img, dict): - raise ValueError("nav 接口 wbi_img 字段缺失") - - img_url = str(wbi_img.get("img_url", "")).strip() - sub_url = str(wbi_img.get("sub_url", "")).strip() - if not img_url or not sub_url: - raise ValueError("nav 接口未返回有效的 img_url/sub_url") - - img_key = _extract_key_from_url(img_url) - sub_key = _extract_key_from_url(sub_url) - if not img_key or not sub_key: - raise ValueError("无法提取有效的 img_key/sub_key") - - return _compute_mixin_key(img_key, sub_key) + return _mixin_key_from_nav_payload(resp.json()) def get_mixin_key_sync( @@ -296,7 +269,6 @@ def get_mixin_key_sync( force_refresh: bool = False, ) -> str: """同步获取可复用的 mixin_key。""" - # global global _cached_mixin_key_sync, _cached_at_sync now = time.time() diff --git a/src/Undefined/memes/_image_utils.py b/src/Undefined/memes/_image_utils.py new file mode 100644 index 00000000..aff46c57 --- /dev/null +++ b/src/Undefined/memes/_image_utils.py @@ -0,0 +1,147 @@ +"""Meme 图片处理与标签归一化共享工具。""" + +from __future__ import annotations + +import math +import mimetypes +import re +from datetime import datetime +from pathlib import Path + +from openai import APIConnectionError, APIStatusError, APITimeoutError +from PIL import Image + +from Undefined.memes.models import normalize_string_list + +_IMAGE_EXTENSIONS_BY_MIME = { + "image/png": ".png", + "image/jpeg": ".jpg", + "image/gif": ".gif", + "image/webp": ".webp", + "image/bmp": ".bmp", + "image/svg+xml": ".svg", +} +_TAG_SPLIT_RE = re.compile(r"[,,\n]+") + + +def now_iso() -> str: + return datetime.now().isoformat(timespec="seconds") + + +def guess_suffix(path: Path, mime_type: str) -> str: + suffix = path.suffix.lower() + if suffix: + return suffix + guessed = _IMAGE_EXTENSIONS_BY_MIME.get(mime_type) + if guessed: + return guessed + mime_guess = mimetypes.guess_extension(mime_type or "") + if mime_guess: + return mime_guess.lower() + return ".bin" + + +def normalize_tags(raw_tags: list[str] | str | None) -> list[str]: + if raw_tags is None: + return [] + if isinstance(raw_tags, str): + parts = [part.strip() for part in _TAG_SPLIT_RE.split(raw_tags)] + return normalize_string_list(parts) + return normalize_string_list(raw_tags) + + +def is_retryable_llm_error(exc: Exception) -> bool: + """判断 LLM 调用异常是否应触发 worker 级重试。""" + if isinstance(exc, (APIConnectionError, APITimeoutError)): + return True + if isinstance(exc, APIStatusError): + return exc.status_code == 429 or exc.status_code >= 500 + return False + + +def extract_gif_frames(source_path: Path, n_frames: int) -> list[Image.Image]: + """从 GIF 中均匀采样 *n_frames* 帧(含首末帧),返回 RGBA Image 列表。""" + with Image.open(source_path) as image: + total = getattr(image, "n_frames", 1) + if total <= 1: + image.seek(0) + return [image.convert("RGBA").copy()] + n = min(n_frames, total) + if n <= 1: + image.seek(0) + return [image.convert("RGBA").copy()] + indices = sample_frame_indices(total, n) + frames: list[Image.Image] = [] + for idx in indices: + image.seek(idx) + frames.append(image.convert("RGBA").copy()) + return frames + + +def sample_frame_indices(total: int, n: int) -> list[int]: + """生成均匀采样的帧索引列表(始终包含首帧和末帧)。""" + if n >= total: + return list(range(total)) + if n == 1: + return [0] + if n == 2: + return [0, total - 1] + indices = [round(i * (total - 1) / (n - 1)) for i in range(n)] + seen: set[int] = set() + result: list[int] = [] + for idx in indices: + if idx not in seen: + seen.add(idx) + result.append(idx) + return result + + +def compose_grid(frames: list[Image.Image], output_path: Path) -> None: + """将多帧拼接为网格图并保存为 PNG。""" + n = len(frames) + if n == 0: + return + if n == 1: + frames[0].save(output_path, format="PNG") + return + cols = math.ceil(math.sqrt(n)) + rows = math.ceil(n / cols) + fw, fh = frames[0].size + grid = Image.new("RGBA", (cols * fw, rows * fh), (0, 0, 0, 0)) + for i, frame in enumerate(frames): + resized = ( + frame.resize((fw, fh), Image.Resampling.LANCZOS) + if frame.size != (fw, fh) + else frame + ) + x = (i % cols) * fw + y = (i // cols) * fh + grid.paste(resized, (x, y)) + grid.save(output_path, format="PNG") + + +# 向后兼容:旧模块级私有名仍可从 service 等路径导入 +_now_iso = now_iso +_guess_suffix = guess_suffix +_normalize_tags = normalize_tags +_is_retryable_llm_error = is_retryable_llm_error +_extract_gif_frames = extract_gif_frames +_sample_frame_indices = sample_frame_indices +_compose_grid = compose_grid + +__all__ = [ + "compose_grid", + "extract_gif_frames", + "guess_suffix", + "is_retryable_llm_error", + "normalize_tags", + "now_iso", + "sample_frame_indices", + "_compose_grid", + "_extract_gif_frames", + "_guess_suffix", + "_is_retryable_llm_error", + "_normalize_tags", + "_now_iso", + "_sample_frame_indices", +] diff --git a/src/Undefined/memes/_service.py b/src/Undefined/memes/_service.py index 5dab0c77..33a720b2 100644 --- a/src/Undefined/memes/_service.py +++ b/src/Undefined/memes/_service.py @@ -3,23 +3,20 @@ from __future__ import annotations import asyncio -from dataclasses import dataclass -from datetime import datetime import logging -import math -import mimetypes from pathlib import Path -import re import threading from typing import Any -from openai import APIConnectionError, APIStatusError, APITimeoutError -from PIL import Image from Undefined.attachments import AttachmentRecord +from Undefined.memes._image_utils import ( + _normalize_tags, + _now_iso, +) from Undefined.memes.models import ( + IngestDigestLockEntry, build_search_text, - normalize_string_list, ) from Undefined.memes.store import MemeStore from Undefined.memes.vector_store import MemeVectorStore @@ -29,119 +26,6 @@ logger = logging.getLogger(__name__) -_IMAGE_EXTENSIONS_BY_MIME = { - "image/png": ".png", - "image/jpeg": ".jpg", - "image/gif": ".gif", - "image/webp": ".webp", - "image/bmp": ".bmp", - "image/svg+xml": ".svg", -} -_TAG_SPLIT_RE = re.compile(r"[,,\n]+") - - -def _now_iso() -> str: - return datetime.now().isoformat(timespec="seconds") - - -def _guess_suffix(path: Path, mime_type: str) -> str: - suffix = path.suffix.lower() - if suffix: - return suffix - guessed = _IMAGE_EXTENSIONS_BY_MIME.get(mime_type) - if guessed: - return guessed - mime_guess = mimetypes.guess_extension(mime_type or "") - if mime_guess: - return mime_guess.lower() - return ".bin" - - -def _normalize_tags(raw_tags: list[str] | str | None) -> list[str]: - if raw_tags is None: - return [] - if isinstance(raw_tags, str): - parts = [part.strip() for part in _TAG_SPLIT_RE.split(raw_tags)] - return normalize_string_list(parts) - return normalize_string_list(raw_tags) - - -def _is_retryable_llm_error(exc: Exception) -> bool: - """判断 LLM 调用异常是否应触发 worker 级重试。""" - if isinstance(exc, (APIConnectionError, APITimeoutError)): - return True - if isinstance(exc, APIStatusError): - return exc.status_code == 429 or exc.status_code >= 500 - return False - - -def _extract_gif_frames(source_path: Path, n_frames: int) -> list[Image.Image]: - """从 GIF 中均匀采样 *n_frames* 帧(含首末帧),返回 RGBA Image 列表。""" - with Image.open(source_path) as image: - total = getattr(image, "n_frames", 1) - if total <= 1: - image.seek(0) - return [image.convert("RGBA").copy()] - n = min(n_frames, total) - if n <= 1: - image.seek(0) - return [image.convert("RGBA").copy()] - indices = _sample_frame_indices(total, n) - frames: list[Image.Image] = [] - for idx in indices: - image.seek(idx) - frames.append(image.convert("RGBA").copy()) - return frames - - -def _sample_frame_indices(total: int, n: int) -> list[int]: - """生成均匀采样的帧索引列表(始终包含首帧和末帧)。""" - if n >= total: - return list(range(total)) - if n == 1: - return [0] - if n == 2: - return [0, total - 1] - indices = [round(i * (total - 1) / (n - 1)) for i in range(n)] - # 去重并保持顺序 - seen: set[int] = set() - result: list[int] = [] - for idx in indices: - if idx not in seen: - seen.add(idx) - result.append(idx) - return result - - -def _compose_grid(frames: list[Image.Image], output_path: Path) -> None: - """将多帧拼接为网格图并保存为 PNG。""" - n = len(frames) - if n == 0: - return - if n == 1: - frames[0].save(output_path, format="PNG") - return - cols = math.ceil(math.sqrt(n)) - rows = math.ceil(n / cols) - fw, fh = frames[0].size - grid = Image.new("RGBA", (cols * fw, rows * fh), (0, 0, 0, 0)) - for i, frame in enumerate(frames): - resized = ( - frame.resize((fw, fh), Image.Resampling.LANCZOS) - if frame.size != (fw, fh) - else frame - ) - x = (i % cols) * fw - y = (i // cols) * fh - grid.paste(resized, (x, y)) - grid.save(output_path, format="PNG") - - -@dataclass -class _IngestDigestLockEntry: - lock: asyncio.Lock - users: int = 0 - class MemeService(MemeSearchMixin, MemeIngestMixin): def __init__( @@ -163,7 +47,7 @@ def __init__( self._attachment_registry = attachment_registry self._retrieval_runtime = retrieval_runtime # 同内容 digest 锁:进程内串行入库,防止重复 AI 分析 - self._ingest_digest_locks: dict[str, _IngestDigestLockEntry] = {} + self._ingest_digest_locks: dict[str, IngestDigestLockEntry] = {} self._ingest_digest_locks_guard = asyncio.Lock() self._global_image_cache: dict[str, AttachmentRecord] = {} self._global_image_cache_lock = threading.Lock() diff --git a/src/Undefined/memes/ingest.py b/src/Undefined/memes/ingest.py index 0341d4bb..dd157bf5 100644 --- a/src/Undefined/memes/ingest.py +++ b/src/Undefined/memes/ingest.py @@ -4,26 +4,30 @@ import asyncio from collections.abc import Mapping -from dataclasses import dataclass, replace -from datetime import datetime +from dataclasses import replace import hashlib import logging -import math import mimetypes from pathlib import Path -import re import shutil from typing import TYPE_CHECKING, Any from uuid import uuid4 -from openai import APIConnectionError, APIStatusError, APITimeoutError from PIL import Image +from Undefined.memes._image_utils import ( + _compose_grid, + _extract_gif_frames, + _guess_suffix, + _is_retryable_llm_error, + _normalize_tags, + _now_iso, +) from Undefined.memes.models import ( + IngestDigestLockEntry, MemeRecord, MemeSourceRecord, build_search_text, - normalize_string_list, ) if TYPE_CHECKING: @@ -32,119 +36,6 @@ logger = logging.getLogger(__name__) -_IMAGE_EXTENSIONS_BY_MIME = { - "image/png": ".png", - "image/jpeg": ".jpg", - "image/gif": ".gif", - "image/webp": ".webp", - "image/bmp": ".bmp", - "image/svg+xml": ".svg", -} -_TAG_SPLIT_RE = re.compile(r"[,,\n]+") - - -def _now_iso() -> str: - return datetime.now().isoformat(timespec="seconds") - - -def _guess_suffix(path: Path, mime_type: str) -> str: - suffix = path.suffix.lower() - if suffix: - return suffix - guessed = _IMAGE_EXTENSIONS_BY_MIME.get(mime_type) - if guessed: - return guessed - mime_guess = mimetypes.guess_extension(mime_type or "") - if mime_guess: - return mime_guess.lower() - return ".bin" - - -def _normalize_tags(raw_tags: list[str] | str | None) -> list[str]: - if raw_tags is None: - return [] - if isinstance(raw_tags, str): - parts = [part.strip() for part in _TAG_SPLIT_RE.split(raw_tags)] - return normalize_string_list(parts) - return normalize_string_list(raw_tags) - - -def _is_retryable_llm_error(exc: Exception) -> bool: - """判断 LLM 调用异常是否应触发 worker 级重试。""" - if isinstance(exc, (APIConnectionError, APITimeoutError)): - return True - if isinstance(exc, APIStatusError): - return exc.status_code == 429 or exc.status_code >= 500 - return False - - -def _extract_gif_frames(source_path: Path, n_frames: int) -> list[Image.Image]: - """从 GIF 中均匀采样 *n_frames* 帧(含首末帧),返回 RGBA Image 列表。""" - with Image.open(source_path) as image: - total = getattr(image, "n_frames", 1) - if total <= 1: - image.seek(0) - return [image.convert("RGBA").copy()] - n = min(n_frames, total) - if n <= 1: - image.seek(0) - return [image.convert("RGBA").copy()] - indices = _sample_frame_indices(total, n) - frames: list[Image.Image] = [] - for idx in indices: - image.seek(idx) - frames.append(image.convert("RGBA").copy()) - return frames - - -def _sample_frame_indices(total: int, n: int) -> list[int]: - """生成均匀采样的帧索引列表(始终包含首帧和末帧)。""" - if n >= total: - return list(range(total)) - if n == 1: - return [0] - if n == 2: - return [0, total - 1] - indices = [round(i * (total - 1) / (n - 1)) for i in range(n)] - # 去重并保持顺序 - seen: set[int] = set() - result: list[int] = [] - for idx in indices: - if idx not in seen: - seen.add(idx) - result.append(idx) - return result - - -def _compose_grid(frames: list[Image.Image], output_path: Path) -> None: - """将多帧拼接为网格图并保存为 PNG。""" - n = len(frames) - if n == 0: - return - if n == 1: - frames[0].save(output_path, format="PNG") - return - cols = math.ceil(math.sqrt(n)) - rows = math.ceil(n / cols) - fw, fh = frames[0].size - grid = Image.new("RGBA", (cols * fw, rows * fh), (0, 0, 0, 0)) - for i, frame in enumerate(frames): - resized = ( - frame.resize((fw, fh), Image.Resampling.LANCZOS) - if frame.size != (fw, fh) - else frame - ) - x = (i % cols) * fw - y = (i // cols) * fh - grid.paste(resized, (x, y)) - grid.save(output_path, format="PNG") - - -@dataclass -class _IngestDigestLockEntry: - lock: asyncio.Lock - users: int = 0 - class MemeIngestMixin: if TYPE_CHECKING: @@ -164,11 +55,11 @@ def _queue_enabled(self) -> bool: ... async def delete_meme(self, uid: str) -> bool: ... def enabled(self) -> bool: ... - async def _acquire_ingest_digest_lock(self, digest: str) -> _IngestDigestLockEntry: + async def _acquire_ingest_digest_lock(self, digest: str) -> IngestDigestLockEntry: async with self._ingest_digest_locks_guard: entry = self._ingest_digest_locks.get(digest) if entry is None: - entry = _IngestDigestLockEntry(lock=asyncio.Lock()) + entry = IngestDigestLockEntry(lock=asyncio.Lock()) self._ingest_digest_locks[digest] = entry entry.users += 1 try: @@ -181,7 +72,7 @@ async def _acquire_ingest_digest_lock(self, digest: str) -> _IngestDigestLockEnt async def _release_ingest_digest_lock_reference( self, digest: str, - entry: _IngestDigestLockEntry, + entry: IngestDigestLockEntry, *, release_lock: bool = False, ) -> None: diff --git a/src/Undefined/memes/models.py b/src/Undefined/memes/models.py index 99a3c989..d4637535 100644 --- a/src/Undefined/memes/models.py +++ b/src/Undefined/memes/models.py @@ -1,5 +1,6 @@ from __future__ import annotations +import asyncio from dataclasses import dataclass, field from typing import Any @@ -51,6 +52,12 @@ class MemeSourceRecord: seen_at: str +@dataclass +class IngestDigestLockEntry: + lock: asyncio.Lock + users: int = 0 + + @dataclass(frozen=True) class MemeSearchItem: uid: str diff --git a/src/Undefined/memes/search.py b/src/Undefined/memes/search.py index ba9d5de1..abd2be24 100644 --- a/src/Undefined/memes/search.py +++ b/src/Undefined/memes/search.py @@ -3,23 +3,16 @@ from __future__ import annotations import asyncio -from dataclasses import dataclass -from datetime import datetime import logging -import math -import mimetypes from pathlib import Path -import re from typing import TYPE_CHECKING, Any -from openai import APIConnectionError, APIStatusError, APITimeoutError -from PIL import Image from Undefined.attachments import AttachmentRecord +from Undefined.memes._image_utils import _now_iso from Undefined.memes.models import ( MemeRecord, MemeSearchItem, - normalize_string_list, ) from Undefined.utils.message_targets import resolve_message_target from Undefined.utils.coerce import safe_int @@ -32,119 +25,6 @@ logger = logging.getLogger(__name__) -_IMAGE_EXTENSIONS_BY_MIME = { - "image/png": ".png", - "image/jpeg": ".jpg", - "image/gif": ".gif", - "image/webp": ".webp", - "image/bmp": ".bmp", - "image/svg+xml": ".svg", -} -_TAG_SPLIT_RE = re.compile(r"[,,\n]+") - - -def _now_iso() -> str: - return datetime.now().isoformat(timespec="seconds") - - -def _guess_suffix(path: Path, mime_type: str) -> str: - suffix = path.suffix.lower() - if suffix: - return suffix - guessed = _IMAGE_EXTENSIONS_BY_MIME.get(mime_type) - if guessed: - return guessed - mime_guess = mimetypes.guess_extension(mime_type or "") - if mime_guess: - return mime_guess.lower() - return ".bin" - - -def _normalize_tags(raw_tags: list[str] | str | None) -> list[str]: - if raw_tags is None: - return [] - if isinstance(raw_tags, str): - parts = [part.strip() for part in _TAG_SPLIT_RE.split(raw_tags)] - return normalize_string_list(parts) - return normalize_string_list(raw_tags) - - -def _is_retryable_llm_error(exc: Exception) -> bool: - """判断 LLM 调用异常是否应触发 worker 级重试。""" - if isinstance(exc, (APIConnectionError, APITimeoutError)): - return True - if isinstance(exc, APIStatusError): - return exc.status_code == 429 or exc.status_code >= 500 - return False - - -def _extract_gif_frames(source_path: Path, n_frames: int) -> list[Image.Image]: - """从 GIF 中均匀采样 *n_frames* 帧(含首末帧),返回 RGBA Image 列表。""" - with Image.open(source_path) as image: - total = getattr(image, "n_frames", 1) - if total <= 1: - image.seek(0) - return [image.convert("RGBA").copy()] - n = min(n_frames, total) - if n <= 1: - image.seek(0) - return [image.convert("RGBA").copy()] - indices = _sample_frame_indices(total, n) - frames: list[Image.Image] = [] - for idx in indices: - image.seek(idx) - frames.append(image.convert("RGBA").copy()) - return frames - - -def _sample_frame_indices(total: int, n: int) -> list[int]: - """生成均匀采样的帧索引列表(始终包含首帧和末帧)。""" - if n >= total: - return list(range(total)) - if n == 1: - return [0] - if n == 2: - return [0, total - 1] - indices = [round(i * (total - 1) / (n - 1)) for i in range(n)] - # 去重并保持顺序 - seen: set[int] = set() - result: list[int] = [] - for idx in indices: - if idx not in seen: - seen.add(idx) - result.append(idx) - return result - - -def _compose_grid(frames: list[Image.Image], output_path: Path) -> None: - """将多帧拼接为网格图并保存为 PNG。""" - n = len(frames) - if n == 0: - return - if n == 1: - frames[0].save(output_path, format="PNG") - return - cols = math.ceil(math.sqrt(n)) - rows = math.ceil(n / cols) - fw, fh = frames[0].size - grid = Image.new("RGBA", (cols * fw, rows * fh), (0, 0, 0, 0)) - for i, frame in enumerate(frames): - resized = ( - frame.resize((fw, fh), Image.Resampling.LANCZOS) - if frame.size != (fw, fh) - else frame - ) - x = (i % cols) * fw - y = (i // cols) * fh - grid.paste(resized, (x, y)) - grid.save(output_path, format="PNG") - - -@dataclass -class _IngestDigestLockEntry: - lock: asyncio.Lock - users: int = 0 - class MemeSearchMixin: if TYPE_CHECKING: diff --git a/src/Undefined/memes/service.py b/src/Undefined/memes/service.py index 8612cef6..1f7ba719 100644 --- a/src/Undefined/memes/service.py +++ b/src/Undefined/memes/service.py @@ -2,29 +2,34 @@ import asyncio from collections.abc import Mapping -from dataclasses import dataclass, replace -from datetime import datetime +from dataclasses import replace import hashlib import logging -import math import mimetypes from pathlib import Path -import re import shutil import threading from typing import Any from uuid import uuid4 -from openai import APIConnectionError, APIStatusError, APITimeoutError from PIL import Image from Undefined.attachments import AttachmentRecord +from Undefined.memes._image_utils import ( + _compose_grid, + _extract_gif_frames, + _guess_suffix, + _is_retryable_llm_error, + _normalize_tags, + _now_iso, +) +from Undefined.memes._image_utils import _sample_frame_indices # noqa: F401 from Undefined.memes.models import ( + IngestDigestLockEntry, MemeRecord, MemeSearchItem, MemeSourceRecord, build_search_text, - normalize_string_list, ) from Undefined.memes.store import MemeStore from Undefined.memes.vector_store import MemeVectorStore @@ -34,118 +39,13 @@ logger = logging.getLogger(__name__) -_IMAGE_EXTENSIONS_BY_MIME = { - "image/png": ".png", - "image/jpeg": ".jpg", - "image/gif": ".gif", - "image/webp": ".webp", - "image/bmp": ".bmp", - "image/svg+xml": ".svg", -} -_TAG_SPLIT_RE = re.compile(r"[,,\n]+") - - -def _now_iso() -> str: - return datetime.now().isoformat(timespec="seconds") - - -def _guess_suffix(path: Path, mime_type: str) -> str: - suffix = path.suffix.lower() - if suffix: - return suffix - guessed = _IMAGE_EXTENSIONS_BY_MIME.get(mime_type) - if guessed: - return guessed - mime_guess = mimetypes.guess_extension(mime_type or "") - if mime_guess: - return mime_guess.lower() - return ".bin" - - -def _normalize_tags(raw_tags: list[str] | str | None) -> list[str]: - if raw_tags is None: - return [] - if isinstance(raw_tags, str): - parts = [part.strip() for part in _TAG_SPLIT_RE.split(raw_tags)] - return normalize_string_list(parts) - return normalize_string_list(raw_tags) - - -def _is_retryable_llm_error(exc: Exception) -> bool: - """判断 LLM 调用异常是否应触发 worker 级重试。""" - if isinstance(exc, (APIConnectionError, APITimeoutError)): - return True - if isinstance(exc, APIStatusError): - return exc.status_code == 429 or exc.status_code >= 500 - return False - - -def _extract_gif_frames(source_path: Path, n_frames: int) -> list[Image.Image]: - """从 GIF 中均匀采样 *n_frames* 帧(含首末帧),返回 RGBA Image 列表。""" - with Image.open(source_path) as image: - total = getattr(image, "n_frames", 1) - if total <= 1: - image.seek(0) - return [image.convert("RGBA").copy()] - n = min(n_frames, total) - if n <= 1: - image.seek(0) - return [image.convert("RGBA").copy()] - indices = _sample_frame_indices(total, n) - frames: list[Image.Image] = [] - for idx in indices: - image.seek(idx) - frames.append(image.convert("RGBA").copy()) - return frames - - -def _sample_frame_indices(total: int, n: int) -> list[int]: - """生成均匀采样的帧索引列表(始终包含首帧和末帧)。""" - if n >= total: - return list(range(total)) - if n == 1: - return [0] - if n == 2: - return [0, total - 1] - indices = [round(i * (total - 1) / (n - 1)) for i in range(n)] - # 去重并保持顺序 - seen: set[int] = set() - result: list[int] = [] - for idx in indices: - if idx not in seen: - seen.add(idx) - result.append(idx) - return result - - -def _compose_grid(frames: list[Image.Image], output_path: Path) -> None: - """将多帧拼接为网格图并保存为 PNG。""" - n = len(frames) - if n == 0: - return - if n == 1: - frames[0].save(output_path, format="PNG") - return - cols = math.ceil(math.sqrt(n)) - rows = math.ceil(n / cols) - fw, fh = frames[0].size - grid = Image.new("RGBA", (cols * fw, rows * fh), (0, 0, 0, 0)) - for i, frame in enumerate(frames): - resized = ( - frame.resize((fw, fh), Image.Resampling.LANCZOS) - if frame.size != (fw, fh) - else frame - ) - x = (i % cols) * fw - y = (i // cols) * fh - grid.paste(resized, (x, y)) - grid.save(output_path, format="PNG") - - -@dataclass -class _IngestDigestLockEntry: - lock: asyncio.Lock - users: int = 0 +__all__ = [ + "MemeService", + "_compose_grid", + "_extract_gif_frames", + "_is_retryable_llm_error", + "_sample_frame_indices", +] class MemeService: @@ -168,7 +68,7 @@ def __init__( self._attachment_registry = attachment_registry self._retrieval_runtime = retrieval_runtime # Serialize same-content ingest jobs within the process to avoid duplicates. - self._ingest_digest_locks: dict[str, _IngestDigestLockEntry] = {} + self._ingest_digest_locks: dict[str, IngestDigestLockEntry] = {} self._ingest_digest_locks_guard = asyncio.Lock() self._global_image_cache: dict[str, AttachmentRecord] = {} self._global_image_cache_lock = threading.Lock() @@ -196,11 +96,11 @@ def _cfg(self) -> Any: def _blob_dir(self) -> Path: return ensure_dir(Path(self._cfg().blob_dir)) - async def _acquire_ingest_digest_lock(self, digest: str) -> _IngestDigestLockEntry: + async def _acquire_ingest_digest_lock(self, digest: str) -> IngestDigestLockEntry: async with self._ingest_digest_locks_guard: entry = self._ingest_digest_locks.get(digest) if entry is None: - entry = _IngestDigestLockEntry(lock=asyncio.Lock()) + entry = IngestDigestLockEntry(lock=asyncio.Lock()) self._ingest_digest_locks[digest] = entry entry.users += 1 try: @@ -213,7 +113,7 @@ async def _acquire_ingest_digest_lock(self, digest: str) -> _IngestDigestLockEnt async def _release_ingest_digest_lock_reference( self, digest: str, - entry: _IngestDigestLockEntry, + entry: IngestDigestLockEntry, *, release_lock: bool = False, ) -> None: From 01fd294e88472e5eb57391e8e1c6ba7244e9e312 Mon Sep 17 00:00:00 2001 From: Null <1708213363@qq.com> Date: Sat, 23 May 2026 22:13:48 +0800 Subject: [PATCH 08/16] fix(ai): resolve skills paths from PACKAGE_ROOT after client split The ai/client/setup.py module is one level deeper than the old client.py, so Path(__file__).parents[1] pointed at Undefined/ai/ and loaded zero builtin tools. Centralize package root resolution and add regression tests. Co-authored-by: Cursor --- src/Undefined/ai/client/setup.py | 8 +-- src/Undefined/utils/paths.py | 2 + src/Undefined/utils/resources.py | 5 +- src/Undefined/webui/routes/_runtime.py | 5 +- tests/test_ai_client_setup_paths.py | 94 ++++++++++++++++++++++++++ 5 files changed, 106 insertions(+), 8 deletions(-) create mode 100644 tests/test_ai_client_setup_paths.py diff --git a/src/Undefined/ai/client/setup.py b/src/Undefined/ai/client/setup.py index 068b9cfc..9a1df7a1 100644 --- a/src/Undefined/ai/client/setup.py +++ b/src/Undefined/ai/client/setup.py @@ -28,6 +28,7 @@ GrokModelConfig, ) from Undefined.context import RequestContext +from Undefined.utils.paths import PACKAGE_ROOT from Undefined.context_resource_registry import set_context_resource_scan_paths from Undefined.end_summary_storage import EndSummaryStorage from Undefined.memory import MemoryStorage @@ -212,12 +213,11 @@ def __init__( self.current_group_id: Optional[int] = None self.current_user_id: Optional[int] = None - base_dir = Path(__file__).resolve().parents[1] - self.tool_registry = ToolRegistry(base_dir / "skills" / "tools") - self.agent_registry = AgentRegistry(base_dir / "skills" / "agents") + self.tool_registry = ToolRegistry(PACKAGE_ROOT / "skills" / "tools") + self.agent_registry = AgentRegistry(PACKAGE_ROOT / "skills" / "agents") # 初始化 Anthropic Agent Skills 注册表(可选,目录不存在时自动跳过) - anthropic_skills_dir = base_dir / "skills" / "anthropic_skills" + anthropic_skills_dir = PACKAGE_ROOT / "skills" / "anthropic_skills" dot_delimiter = self._get_runtime_config().tools_dot_delimiter self.anthropic_skill_registry = AnthropicSkillRegistry( anthropic_skills_dir, diff --git a/src/Undefined/utils/paths.py b/src/Undefined/utils/paths.py index 99124f9d..05254147 100644 --- a/src/Undefined/utils/paths.py +++ b/src/Undefined/utils/paths.py @@ -2,6 +2,8 @@ from pathlib import Path +PACKAGE_ROOT = Path(__file__).resolve().parent.parent + DATA_DIR = Path("data") CACHE_DIR = DATA_DIR / "cache" RENDER_CACHE_DIR = CACHE_DIR / "render" diff --git a/src/Undefined/utils/resources.py b/src/Undefined/utils/resources.py index b614032f..a8ce1186 100644 --- a/src/Undefined/utils/resources.py +++ b/src/Undefined/utils/resources.py @@ -8,10 +8,11 @@ import shutil import tempfile +from Undefined.utils.paths import PACKAGE_ROOT + def _candidate_paths(relative_path: str) -> list[Path]: - module_path = Path(__file__).resolve() - package_root = module_path.parents[1] + package_root = PACKAGE_ROOT candidates = [ Path.cwd() / relative_path, # If installed from wheel, extra files may live under site-packages/ diff --git a/src/Undefined/webui/routes/_runtime.py b/src/Undefined/webui/routes/_runtime.py index 371c46b9..fa26c0d8 100644 --- a/src/Undefined/webui/routes/_runtime.py +++ b/src/Undefined/webui/routes/_runtime.py @@ -65,8 +65,9 @@ def _load_top_level_agent_names(root: Path) -> set[str]: def _get_local_agent_tool_names() -> set[str]: - skills_root = Path(__file__).resolve().parents[2] / "skills" - return _load_top_level_agent_names(skills_root / "agents") + from Undefined.utils.paths import PACKAGE_ROOT + + return _load_top_level_agent_names(PACKAGE_ROOT / "skills" / "agents") def _tool_invoke_proxy_timeout_seconds(tool_name: str) -> float | None: diff --git a/tests/test_ai_client_setup_paths.py b/tests/test_ai_client_setup_paths.py new file mode 100644 index 00000000..2f1381e8 --- /dev/null +++ b/tests/test_ai_client_setup_paths.py @@ -0,0 +1,94 @@ +from __future__ import annotations + +import asyncio +from pathlib import Path + +import Undefined +from Undefined.services.commands.registry import CommandRegistry +from Undefined.skills.agents import AgentRegistry +from Undefined.skills.pipelines.registry import PipelineRegistry +from Undefined.skills.tools import ToolRegistry +from Undefined.utils.paths import PACKAGE_ROOT + +# Snapshot counts from skills/*/config.json inventory (excluding MCP). +EXPECTED_BASIC_TOOL_COUNT = 15 +EXPECTED_TOOLSET_COUNT = 53 +EXPECTED_AGENT_COUNT = 8 +EXPECTED_COMMAND_COUNT = 12 +EXPECTED_PIPELINE_COUNT = 3 + + +def test_package_root_matches_undefined_package_directory() -> None: + assert PACKAGE_ROOT == Path(Undefined.__file__).resolve().parent + + +def test_package_root_contains_skills_directories() -> None: + assert (PACKAGE_ROOT / "skills" / "tools").is_dir() + assert (PACKAGE_ROOT / "skills" / "agents").is_dir() + assert (PACKAGE_ROOT / "skills" / "anthropic_skills").is_dir() + assert (PACKAGE_ROOT / "skills" / "toolsets").is_dir() + assert (PACKAGE_ROOT / "skills" / "pipelines").is_dir() + assert (PACKAGE_ROOT / "skills" / "commands").is_dir() + + +def test_setup_wrong_path_does_not_exist() -> None: + """Regression: ai/client/setup.py used parents[1] and pointed at ai/skills.""" + wrong_root = Path(__file__).resolve().parents[2] / "src" / "Undefined" / "ai" + assert not (wrong_root / "skills" / "tools").exists() + + +def test_tool_registry_loads_all_skill_directories() -> None: + registry = ToolRegistry(PACKAGE_ROOT / "skills" / "tools") + basic = [name for name in registry._items if "." not in name] + toolsets = [ + name for name in registry._items if "." in name and not name.startswith("mcp.") + ] + + assert len(basic) == EXPECTED_BASIC_TOOL_COUNT + assert len(toolsets) == EXPECTED_TOOLSET_COUNT + assert len(registry._items) == EXPECTED_BASIC_TOOL_COUNT + EXPECTED_TOOLSET_COUNT + + tool_dirs = [ + item.name + for item in (PACKAGE_ROOT / "skills" / "tools").iterdir() + if item.is_dir() and (item / "config.json").exists() + ] + assert len(tool_dirs) == EXPECTED_BASIC_TOOL_COUNT + assert set(basic) == set(tool_dirs) + + +def test_all_registered_tools_import_handlers() -> None: + registry = ToolRegistry(PACKAGE_ROOT / "skills" / "tools") + errors: list[str] = [] + + for name, item in sorted(registry._items.items()): + try: + registry._load_handler_for_item(item) + if item.handler is None: + errors.append(f"{name}: handler is None") + except Exception as exc: + errors.append(f"{name}: {exc}") + + assert errors == [] + + +def test_agent_registry_loads_expected_agents() -> None: + registry = AgentRegistry(PACKAGE_ROOT / "skills" / "agents") + assert len(registry._items) == EXPECTED_AGENT_COUNT + + +def test_command_registry_loads_expected_commands() -> None: + registry = CommandRegistry(PACKAGE_ROOT / "skills" / "commands") + registry.load_commands() + assert len(registry._commands) == EXPECTED_COMMAND_COUNT + + +def test_pipeline_registry_loads_expected_pipelines() -> None: + async def _load() -> PipelineRegistry: + registry = PipelineRegistry(PACKAGE_ROOT / "skills" / "pipelines") + await registry.load_items_async() + return registry + + registry = asyncio.run(_load()) + assert len(registry._items) == EXPECTED_PIPELINE_COUNT + assert set(registry._items) == {"arxiv", "bilibili", "github"} From 70275621e61d550db7903e32409f1c20f0098333 Mon Sep 17 00:00:00 2001 From: Null <1708213363@qq.com> Date: Sat, 23 May 2026 22:31:24 +0800 Subject: [PATCH 09/16] docs: align library API exports and refresh README embed section Export set_config from the root package, update python-api.md, and expand README with a simpler core-feature bullet plus a Skills-focused embed example. Co-authored-by: Cursor --- README.md | 92 ++++++++++++++++++++------------ docs/python-api.md | 6 ++- src/Undefined/__init__.py | 2 + tests/test_public_api_imports.py | 8 +++ 4 files changed, 73 insertions(+), 35 deletions(-) diff --git a/README.md b/README.md index 6d2497ae..2043a552 100644 --- a/README.md +++ b/README.md @@ -34,6 +34,7 @@ ## ⚡ 核心特性 - **Skills 架构**:全新设计的技能系统,将基础工具(Tools)与智能代理(Agents)分层管理,支持自动发现与注册。 +- **可嵌入 Python 库**:`pip install Undefined-bot` 后可 import 配置、`AIClient`、Skills 与认知记忆等组件,无需启动 Bot CLI。详见 [Python 库 API 参考](docs/python-api.md)。 - **Skills 热重载**:自动扫描 `skills/` 目录,检测到变更后即时重载工具与 Agent,无需重启服务。 - **三层分层记忆架构**:创新的分层记忆系统,模拟人类记忆机制—— - **短期记忆**(`end.memo`):每轮对话结束自动记录便签备忘,最近 N 条始终注入,保持短期连续性,零配置开箱即用 @@ -77,10 +78,10 @@ Undefined 的功能极为丰富,为了让本页面不过于臃肿,我们将各个模块的深入解析与高阶玩法整理成了专题游览图。这里是开启探索的钥匙: - ⚙️ **[安装与部署指南](docs/deployment.md)**:不管你是需要 `pip` 无脑一键安装,还是源码二次开发,这里的排坑指南应有尽有。 -- 📦 **[Python 库 API 参考](docs/python-api.md)**:作为库嵌入时的 import 路径、`Config.from_mapping` / `set_config` 与公共 API 符号表。 +- 📦 **[Python 库 API 参考](docs/python-api.md)**:根包 lazy re-export、`Config.from_mapping` / `set_config`、公共 API 符号表与嵌入示例。 - 🖥️ **[WebUI 使用指南](docs/webui-guide.md)**:管理控制台功能一览——配置编辑、日志查看、认知记忆管理、表情包库、AI 对话与系统监控。 - 🧭 **[Management API 与远程管理](docs/management-api.md)**:WebUI / App 共用的管理接口、认证、配置/日志/Bot 控制与引导探针说明。 -- 🛠️ **[配置与热更新说明](docs/configuration.md)**:从模型切换到 MCP 库挂载,全方位掌握 `config.toml` 的高阶配置。 +- 🛠️ **[配置与热更新说明](docs/configuration.md)**:从模型切换到 MCP 库挂载,全方位掌握 `config.toml` 的高阶配置;库嵌入见 [§2 库嵌入配置](docs/configuration.md#2-库嵌入配置)。 - 😶 **[表情包系统 (Memes)](docs/memes.md)**:查看表情包两阶段判定管线、统一图片 `uid` 发送机制、检索模式及库存管理说明。 - 💡 **[交互与使用手册](docs/usage.md)**:包含实用的对话示例、多模态解析用法,以及群管家必备的管理员`/指令`。 - 📝 **[版本变更记录](CHANGELOG.md)**:查看按版本整理的更新摘要,也可在运行时使用 `/changelog` 查询。 @@ -92,41 +93,12 @@ Undefined 的功能极为丰富,为了让本页面不过于臃肿,我们将 - 🌐 **[Runtime API 与 OpenAPI](docs/openapi.md)**:主进程 Runtime API、鉴权、探针、记忆/侧写查询和运行态集成说明。 - 🏗️ **[构建指南](docs/build.md)**:Python 包、WebUI、跨平台 App、Android 与 Release 工作流的构建说明。 - 🔧 **[运维脚本](scripts/README.md)**:嵌入模型更换后的向量库重嵌入等维护工具。 -- 👨‍💻 **[开发者与拓展中心](docs/development.md)**:代码结构剖析和开发新 Agent 的流程参考及自检命令。 +- 👨‍💻 **[开发者与拓展中心](docs/development.md)**:代码结构剖析、模块拆分后的目录树、开发新 Agent 的流程参考及自检命令。 - **[核心技能系统 (Skills) 解析](src/Undefined/skills/README.md)**:全景式掌握什么是 Skills 架构、怎样定制原子工具与子智能体。 - **[callable.json 共享授权说明](docs/callable.md)**:细粒度管控 Agent 之间的相互调用与工具越权防范。 --- -## 作为 Python 库使用 - -除 CLI / WebUI 部署外,Undefined 也可作为 Python 库嵌入到其他应用或测试中,复用配置系统、`AIClient`、Skills 注册表、认知记忆、知识库等组件。 - -```bash -pip install Undefined-bot # 或 uv sync(源码) -``` - -```python -from Undefined.config import Config, set_config - -cfg = Config.from_mapping( - { - "onebot": {"ws_url": "ws://127.0.0.1:3001"}, - "models": { - "chat": {"api_url": "https://api.example/v1", "api_key": "sk-xxx", "model_name": "gpt-4o-mini"}, - "vision": {"api_url": "https://api.example/v1", "api_key": "sk-xxx", "model_name": "gpt-4o-mini"}, - "agent": {"api_url": "https://api.example/v1", "api_key": "sk-xxx", "model_name": "gpt-4o-mini"}, - }, - }, - strict=False, -) -set_config(cfg) # opt-in;CLI 启动链不调用 -``` - -完整 import 路径、公共 API 表与嵌入示例见 **[Python 库 API 参考](docs/python-api.md)**;程序化配置详见 [配置详解 — 库嵌入配置](docs/configuration.md#2-库嵌入配置)。 - ---- - ## ⚡ 快速开始 (源码模式) > 👶 **新手必看**:如果您是首次部署此类项目或不熟悉 Git/环境配置,**强烈建议直接前往 [《详细安装与部署指南》](docs/deployment.md)** 阅读手把手教程,避免遇到常见报错。 @@ -155,6 +127,58 @@ uv run Undefined-webui --- +## 作为 Python 库使用 + +除 Bot CLI 外,Undefined 也可嵌入脚本、测试或其它服务,直接复用 **Skills 注册表**、**认知记忆**、**知识库**、**AIClient** 等运行时(与 CLI 启动链隔离)。 + +```bash +pip install Undefined-bot # 源码开发:uv sync +``` + +```python +import asyncio + +# 根包 lazy re-export:与 CLI 共用同一套运行时组件 +from Undefined import AgentRegistry, Config, ToolRegistry, set_config + +# 内存构建配置,测试/嵌入场景无需 config.toml +cfg = Config.from_mapping( + { + "onebot": {"ws_url": "ws://127.0.0.1:3001"}, + "models": { + "chat": {"api_url": "https://api.example/v1", "api_key": "sk-xxx", "model_name": "gpt-4o-mini"}, + "vision": {"api_url": "https://api.example/v1", "api_key": "sk-xxx", "model_name": "gpt-4o-mini"}, + "agent": {"api_url": "https://api.example/v1", "api_key": "sk-xxx", "model_name": "gpt-4o-mini"}, + }, + }, + strict=False, +) +set_config(cfg) # opt-in 注入全局单例;CLI 启动链不会调用 + +# 自动扫描 skills/:tools + toolsets(end / group.* / cognitive.* …) +tools = ToolRegistry() +# 自动扫描 skills/agents/:web_agent、code_delivery_agent … +agents = AgentRegistry() + +async def main() -> None: + # 直接调用原子工具,无需启动 OneBot + lunar_time = await tools.execute( + "get_current_time", + {"format": "text", "include_lunar": True}, + context={}, + ) + print(lunar_time) + print(len(tools.get_schema()), "tools,", len(agents.get_schema()), "agents") + +asyncio.run(main()) +``` + +- [Python 库 API 参考](docs/python-api.md) — 根包符号表、shim 路径、`AIClient` / `CognitiveService` 等嵌入示例 +- [配置详解 — 库嵌入配置](docs/configuration.md#2-库嵌入配置) — `from_mapping` / `Config.builder` +- [开发者与拓展中心](docs/development.md) — 模块结构与自检命令 + +--- + ## 风险提示与免责声明 1. **账号风控与封禁风险(含 QQ 账号)** @@ -170,7 +194,9 @@ uv run Undefined-webui 本项目遵循 [MIT License](LICENSE) 开源协议。 -感谢 **NagaAgent** 子模块作者及社区支持:[NagaAgent - A simple yet powerful agent framework.](https://github.com/Xxiii8322766509/NagaAgent)。 +感谢 **NagaAgent** 子模块作者及社区提供的支持与鼓励:[NagaAgent - A simple yet powerful agent framework.](https://github.com/Xxiii8322766509/NagaAgent)! + +感谢在开发过程中为我提供各种灵感的群友们!
⭐ 如果这个项目对您有帮助,请考虑给我们一个 Star diff --git a/docs/python-api.md b/docs/python-api.md index 1d84c9ad..c083e9df 100644 --- a/docs/python-api.md +++ b/docs/python-api.md @@ -18,11 +18,13 @@ pip install Undefined-bot Python 版本要求:`3.11` ~ `3.13`。 +包内附带 [`py.typed`](../src/Undefined/py.typed) 标记,mypy / Pyright / IDE 可直接消费类型信息。 + --- ## 推荐 import 路径 -### 根包(`stable`,Phase 3 lazy re-export) +### 根包(`stable`,lazy re-export) 以下符号承诺通过 `from Undefined import …` 长期稳定(完整清单见下文 [公共 API 符号表](#公共-api-符号表)): @@ -47,7 +49,7 @@ from Undefined import ( ) ``` -> **注意**:Phase 3 之前根包 lazy re-export 可能尚未全部启用;若 `from Undefined import X` 失败,请使用下方子包路径,二者语义等价。 +根包符号与 [公共 API 符号表](#公共-api-符号表) 一致;若需更细粒度导入,可使用下方子包路径,二者语义等价。 ### 子包(`stable` / `subpackage`) diff --git a/src/Undefined/__init__.py b/src/Undefined/__init__.py index 5cf4027a..b3151b01 100644 --- a/src/Undefined/__init__.py +++ b/src/Undefined/__init__.py @@ -11,6 +11,7 @@ "__version__", "Config", "get_config", + "set_config", "AIClient", "ToolRegistry", "AgentRegistry", @@ -29,6 +30,7 @@ _LAZY_IMPORTS: dict[str, tuple[str, str]] = { "Config": ("Undefined.config", "Config"), "get_config": ("Undefined.config", "get_config"), + "set_config": ("Undefined.config", "set_config"), "AIClient": ("Undefined.ai", "AIClient"), "ToolRegistry": ("Undefined.skills.tools", "ToolRegistry"), "AgentRegistry": ("Undefined.skills.agents", "AgentRegistry"), diff --git a/tests/test_public_api_imports.py b/tests/test_public_api_imports.py index 41b87886..520a3cbf 100644 --- a/tests/test_public_api_imports.py +++ b/tests/test_public_api_imports.py @@ -11,6 +11,7 @@ _ROOT_EXPORTS: tuple[str, ...] = ( "Config", "get_config", + "set_config", "AIClient", "ToolRegistry", "AgentRegistry", @@ -53,6 +54,13 @@ def test_root_package_exports(symbol: str) -> None: getattr(Undefined, symbol) +def test_root_package_all_matches_exports() -> None: + import Undefined + + expected = {"__version__", *_ROOT_EXPORTS} + assert set(Undefined.__all__) == expected + + def test_root_package_lazy_import_does_not_load_cli_modules() -> None: import sys From ff8261e7eff4c570367a03115f473e6415e1d55e Mon Sep 17 00:00:00 2001 From: Null <1708213363@qq.com> Date: Sat, 23 May 2026 23:05:20 +0800 Subject: [PATCH 10/16] refactor: remove shadowed dead modules and unify set_config Delete unreachable monolith files shadowed by subpackages, sync set_config with ConfigManager, declare py.typed in the wheel, and add layout regressions. Co-authored-by: Cursor --- ARCHITECTURE.md | 8 +- docs/configuration.md | 4 +- docs/development.md | 19 +- docs/message-batching.md | 6 +- docs/python-api.md | 12 +- pyproject.toml | 1 + src/Undefined/ai/client.py | 1700 -------------- src/Undefined/ai/client/__init__.py | 2 +- src/Undefined/ai/llm.py | 2063 ----------------- src/Undefined/ai/llm/__init__.py | 3 +- src/Undefined/ai/multimodal.py | 893 ------- src/Undefined/ai/multimodal/__init__.py | 3 +- src/Undefined/ai/prompts.py | 846 ------- src/Undefined/ai/prompts/__init__.py | 2 +- src/Undefined/ai/prompts/builder.py | 2 +- src/Undefined/api/routes/naga.py | 897 ------- src/Undefined/attachments.py | 1680 -------------- src/Undefined/cognitive/historian.py | 1043 --------- src/Undefined/cognitive/service.py | 898 ------- src/Undefined/config/__init__.py | 1 + src/Undefined/config/manager.py | 4 + src/Undefined/handlers.py | 1400 ----------- src/Undefined/memes/_service.py | 153 -- src/Undefined/onebot.py | 924 -------- src/Undefined/services/ai_coordinator.py | 2 +- .../services/coordinator/__init__.py | 2 +- src/Undefined/services/message_batcher.py | 810 ------- src/Undefined/skills/agents/runner.py | 384 --- tests/test_ai_client_setup_paths.py | 65 +- tests/test_config_from_mapping.py | 2 + tests/test_handlers_meme_annotation.py | 2 +- tests/test_package_layout.py | 32 + tests/test_public_api_imports.py | 1 + 33 files changed, 117 insertions(+), 13747 deletions(-) delete mode 100644 src/Undefined/ai/client.py delete mode 100644 src/Undefined/ai/llm.py delete mode 100644 src/Undefined/ai/multimodal.py delete mode 100644 src/Undefined/ai/prompts.py delete mode 100644 src/Undefined/api/routes/naga.py delete mode 100644 src/Undefined/attachments.py delete mode 100644 src/Undefined/cognitive/historian.py delete mode 100644 src/Undefined/cognitive/service.py delete mode 100644 src/Undefined/handlers.py delete mode 100644 src/Undefined/memes/_service.py delete mode 100644 src/Undefined/onebot.py delete mode 100644 src/Undefined/services/message_batcher.py delete mode 100644 src/Undefined/skills/agents/runner.py create mode 100644 tests/test_package_layout.py diff --git a/ARCHITECTURE.md b/ARCHITECTURE.md index 3a8c043e..a7661901 100644 --- a/ARCHITECTURE.md +++ b/ARCHITECTURE.md @@ -25,7 +25,7 @@ graph TB %% ==================== 消息处理层 ==================== subgraph MessageLayer["消息处理层 (src/Undefined/)"] - MessageHandler["MessageHandler
消息处理器
[handlers/ + handlers.py shim]"] + MessageHandler["MessageHandler
消息处理器
[handlers/]"] subgraph BilibiliModule["Bilibili 模块 (bilibili/)"] BilibiliParser["parser.py
标识符解析
• BV/AV号 • URL
• b23.tv短链 • 小程序JSON"] @@ -154,9 +154,9 @@ graph TB HistoryManager["MessageHistoryManager
消息历史管理
[utils/history.py]
• 懒加载
• 10000条限制"] MemoryStorage["MemoryStorage
置顶备忘录
[memory.py]
• 500条上限
• 自动去重"] EndSummaryStorage["EndSummaryStorage
短期总结存储
[end_summary_storage.py]"] - CognitiveService["CognitiveService
认知记忆服务
[cognitive/service.py]
• 事件检索 • 侧写读取
• 入队 memory job"] + CognitiveService["CognitiveService
认知记忆服务
[cognitive/service/]
• 事件检索 • 侧写读取
• 入队 memory job"] CognitiveJobQueue["JobQueue
认知任务队列
[cognitive/job_queue.py]
• pending/processing/failed"] - CognitiveHistorian["HistorianWorker
后台史官
[cognitive/historian.py]
• 绝对化改写 • 闸门重试
• 侧写合并(含历史事件注入)"] + CognitiveHistorian["HistorianWorker
后台史官
[cognitive/historian/]
• 绝对化改写 • 闸门重试
• 侧写合并(含历史事件注入)"] CognitiveVectorStore["CognitiveVectorStore
向量存储
[cognitive/vector_store.py]
• events/profiles
• 时间衰减加权排序
• MMR 去重"] CognitiveProfileStorage["ProfileStorage
侧写存储
[cognitive/profile_storage.py]
• users/groups Markdown
• 历史快照"] MemeSystem["MemeSystem
表情包存储
[memes/]
• worker.py (两阶段识别)
• sqlite+chromadb
• blob 持久化"] @@ -850,7 +850,7 @@ description: 从 PDF 文件中提取文本和表格,填写表单。当用户 1. **外部实体层**:用户、管理员、OneBot 协议端 (NapCat/Lagrange.Core)、大模型 API 服务商 2. **核心入口层**:main.py 启动入口、配置管理器 (config/loader.py + parsers/ + load_sections/)、热更新应用器 (config/hot_reload.py)、OneBotClient (onebot/ + onebot.py shim)、RequestContext (context.py)、Runtime API Server (api/app.py → api/routes/ 路由子模块,含 naga/ 子包) -3. **消息处理层**:MessageHandler (handlers/ + handlers.py shim)、SecurityService (security.py)、CommandDispatcher (services/command.py + commands/ mixins)、MessageBatcher (services/message_batcher/ + shim)、AICoordinator (services/coordinator/ + ai_coordinator.py shim)、QueueManager (queue_manager.py)、自动处理管线 (skills/pipelines/)、Bilibili/arXiv/GitHub 解析与发送模块 +3. **消息处理层**:MessageHandler (`handlers/`)、SecurityService (security.py)、CommandDispatcher (services/command.py + commands/ mixins)、MessageBatcher (services/message_batcher/)、AICoordinator (services/coordinator/ + ai_coordinator.py 门面)、QueueManager (queue_manager.py)、自动处理管线 (skills/pipelines/)、Bilibili/arXiv/GitHub 解析与发送模块 自动提取由 `PipelineRegistry` 并行检测、并行处理全部命中的管线;发送结果写入历史后继续进入 AI 自动回复。 4. **AI 核心能力层**:AIClient (ai/client/ + client.py shim)、PromptBuilder (ai/prompts/ + prompts.py shim)、ModelRequester (ai/llm/ + llm.py shim)、ToolManager (tooling.py)、MultimodalAnalyzer (ai/multimodal/ + multimodal.py shim)、SummaryService (summaries.py)、TokenCounter (tokens.py) 5. **存储与上下文层**:MessageHistoryManager (utils/history.py, 10000条限制)、MemoryStorage (memory.py, 置顶备忘录, 500条上限)、EndSummaryStorage、CognitiveService + JobQueue + HistorianWorker + VectorStore + ProfileStorage、MemeService + MemeWorker + MemeStore + MemeVectorStore (表情包库)、FAQStorage、ScheduledTaskStorage、TokenUsageStorage (自动归档) diff --git a/docs/configuration.md b/docs/configuration.md index db808031..cefbcc04 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -120,17 +120,19 @@ cfg = ( ### 2.5 `set_config()`(opt-in) ```python -from Undefined.config import Config, get_config, set_config +from Undefined.config import Config, get_config, get_config_manager, set_config cfg = Config.from_mapping({...}, strict=False) set_config(cfg) assert get_config(strict=False) is cfg +assert get_config_manager().load(strict=False) is cfg ``` **硬约束**: - `set_config()` 仅供库嵌入 opt-in;**CLI / WebUI 启动链不得调用**。 - 未调用 `set_config()` 时,`get_config()` 仍从 CWD 加载 `./config.toml`,与独立运行 Bot 行为一致。 +- 调用 `set_config()` 会同步更新 `get_config()` 与 `get_config_manager().load()` 的缓存,避免双轨读到不同实例。 --- diff --git a/docs/development.md b/docs/development.md index bac9f503..305e4d02 100644 --- a/docs/development.md +++ b/docs/development.md @@ -11,18 +11,18 @@ Undefined 欢迎开发者参与共建和进行二次开发! ```text src/Undefined/ ├── changelog.py # CHANGELOG.md 解析与版本查询公共层 -├── ai/ # AI 运行时核心(子包 + 根级 shim:client.py / llm.py / prompts.py / multimodal.py) +├── ai/ # AI 运行时核心 │ ├── client/ # AIClient 组合:setup / queue / ask_loop │ ├── llm/ # ModelRequester、streaming、thinking、sanitize │ ├── prompts/ # PromptBuilder、system_context、cognitive 片段 │ └── multimodal/# 多模态检测、解析与分析 -├── attachments/ # 附件注册、渲染、作用域隔离(attachments.py shim) +├── attachments/ # 附件注册、渲染、作用域隔离 ├── arxiv/ # arXiv 论文解析、元信息获取、PDF 下载与发送 ├── bilibili/ # B站视频流解析、分段下载与异步发送 ├── cognitive/ # 认知记忆系统(service/ 门面 + historian/ 史官后台) ├── config/ # 配置系统(parsers/ 域解析 + load_sections/ 分段加载 + loader shim) -├── handlers/ # OneBot 消息分流(message_flow / poke / repeat / auto_extract;handlers.py shim) -├── onebot/ # OneBot WebSocket 客户端(onebot.py shim) +├── handlers/ # OneBot 消息分流(message_flow / poke / repeat / auto_extract) +├── onebot/ # OneBot WebSocket 客户端 ├── skills/ # 技能插件核心目录 (存放所有的工具与智能体) │ ├── tools/ # 基础原子的工具 (独立的功能单元,如读写文件、网络请求等) │ ├── toolsets/ # 聚合工具集 (分组后的工具组) @@ -35,19 +35,16 @@ src/Undefined/ │ ├── routes/ # 路由子模块 (chat, tools, naga/, system, memes, memory, cognitive, health) │ ├── app.py # aiohttp 服务主入口 (薄包装委派到 routes/) │ └── _openapi.py # OpenAPI 文档生成 -├── memes/ # 表情包库 (_service 门面 + ingest/ + search/ + store + vector_store) +├── memes/ # 表情包库 (service + ingest/ + search/ + store + vector_store) ├── services/ # 核心运行服务 -│ ├── coordinator/ # AICoordinator mixins(ai_coordinator.py shim) +│ ├── coordinator/ # AICoordinator mixins(ai_coordinator.py 门面) │ ├── commands/ # CommandDispatcher mixins(stats / bugfix) -│ ├── message_batcher/ # 同 sender 短时合并(message_batcher.py shim) +│ ├── message_batcher/ # 同 sender 短时合并 │ ├── command.py # 命令分发门面 + shim 组合 │ ├── queue_manager.py # 车站-列车队列 │ └── security.py # 注入检测与速率限制 ├── utils/ # 通用支持工具组 (__init__.py 聚合 io/paths/resources;io.py 异步原子读写, history.py, coerce.py 类型强转) -├── handlers.py # compatibility shim → handlers/ -├── onebot.py # compatibility shim → onebot/ -├── attachments.py # compatibility shim → attachments/ -└── ai_coordinator.py # compatibility shim → services/coordinator/ +└── py.typed # PEP 561 类型标记(wheel 通过 pyproject force-include 打包) ``` ## 开发指南 diff --git a/docs/message-batching.md b/docs/message-batching.md index c53d5aa5..fb913fc2 100644 --- a/docs/message-batching.md +++ b/docs/message-batching.md @@ -11,7 +11,7 @@ - `extend`(默认):每条新消息重置定时器,并以 `max_window_seconds` 作为硬顶。 - `fixed`:定时器从首条算起;窗口期结束统一发车。 - **硬顶**:`max_window_seconds` 防止极端情况下窗口被无限延长(`0` = 不限制,仅靠 `window_seconds` + `max_messages_per_batch` 触发发车);`max_messages_per_batch` 达到立即发车(`0` = 不限)。 -- **历史记录不变**:每条消息照旧由 `handlers.py` 写入 history;batcher 只决定何时调用 AI。 +- **历史记录不变**:每条消息照旧由 `handlers/message_flow` 写入 history;batcher 只决定何时调用 AI。 - **拍一拍永远旁路**:拍一拍触发不进入 batcher,直接立即处理。 - **群聊 @bot 规则**: - 当前桶**为空**且新消息 @bot → 进入 buffer,本批走 `add_group_mention_request`(提及优先级)。 @@ -84,9 +84,9 @@ allow_cancel_after_send = false ## 相关文件 -- 实现:[src/Undefined/services/message_batcher.py](src/Undefined/services/message_batcher.py) +- 实现:[src/Undefined/services/message_batcher/](src/Undefined/services/message_batcher/) - 接入:[src/Undefined/services/ai_coordinator.py](src/Undefined/services/ai_coordinator.py) 中 `handle_auto_reply` / `handle_private_reply` / `_dispatch_grouped_request` -- 创建/注入:[src/Undefined/handlers.py](src/Undefined/handlers.py) +- 创建/注入:[src/Undefined/handlers/message_flow.py](src/Undefined/handlers/message_flow.py) - 关停 flush:[src/Undefined/main.py](src/Undefined/main.py) - 热更新:[src/Undefined/config/hot_reload.py](src/Undefined/config/hot_reload.py) - 提示词:[res/prompts/undefined.xml](res/prompts/undefined.xml)、[res/prompts/undefined_nagaagent.xml](res/prompts/undefined_nagaagent.xml) diff --git a/docs/python-api.md b/docs/python-api.md index c083e9df..08949b09 100644 --- a/docs/python-api.md +++ b/docs/python-api.md @@ -67,9 +67,9 @@ from Undefined import ( | subpackage | `Undefined.skills.anthropic_skills` | `AnthropicSkillRegistry` | | subpackage | `Undefined.mcp` | `MCPToolRegistry`, `MCPToolSetRegistry` | -### 向后兼容 shim 路径 +### 向后兼容 import 路径 -拆分后旧路径仍可用(测试与下游代码可继续引用): +拆分后下列 import 路径仍可用(指向子包公开 API,而非并列的 `.py` 单文件): ```python from Undefined.config.loader import Config # → Undefined.config.Config @@ -82,7 +82,7 @@ from Undefined.memes.service import MemeService from Undefined.api.app import RuntimeAPIServer ``` -拆分后的各模块旁保留 compatibility shim 文件,旧 import 路径仍可用(见各 shim 文件顶部的 re-export)。 +> **注意**:请勿在同名包目录旁保留完整 `.py` 副本(如 `handlers.py` + `handlers/`)。Python 只会加载包目录,并列单文件会成为不可达死代码;仓库通过 `tests/test_package_layout.py` 回归检测。 ### 内部模块(不承诺稳定) @@ -160,21 +160,23 @@ cfg = ( ### `set_config`(opt-in 单例注入) -将已构建的 `Config` 注入全局单例,供 `get_config()` 读取: +将已构建的 `Config` 注入全局单例,供 `get_config()` 与 `get_config_manager().load()` 读取: ```python -from Undefined.config import Config, get_config, set_config +from Undefined.config import Config, get_config, get_config_manager, set_config cfg = Config.from_mapping({...}, strict=False) set_config(cfg) assert get_config(strict=False) is cfg +assert get_config_manager().load(strict=False) is cfg ``` **约束**: - `set_config()` 仅供库嵌入 opt-in 使用;**CLI / WebUI 启动链不得调用**。 - 未调用 `set_config()` 时,`get_config()` 仍走 CWD 下 `./config.toml`(与 CLI 行为一致)。 +- 注入后会同步 `ConfigManager` 缓存,库嵌入代码不应再混用独立的 `Config.load()` 实例。 ### 纯环境变量构建 diff --git a/pyproject.toml b/pyproject.toml index 2c818127..3b6d44d5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -90,6 +90,7 @@ sources = ["src"] [tool.hatch.build.targets.wheel.force-include] "CHANGELOG.md" = "Undefined/CHANGELOG.md" +"src/Undefined/py.typed" = "Undefined/py.typed" [tool.hatch.build.targets.sdist] include = [ diff --git a/src/Undefined/ai/client.py b/src/Undefined/ai/client.py deleted file mode 100644 index 7f4b5084..00000000 --- a/src/Undefined/ai/client.py +++ /dev/null @@ -1,1700 +0,0 @@ -"""AI 客户端入口。""" - -from __future__ import annotations - -import asyncio -import html -import logging -import re -from pathlib import Path -from typing import Any, Awaitable, Callable, Optional, Protocol, TYPE_CHECKING -from uuid import uuid4 - -import httpx - -from Undefined.attachments import AttachmentRegistry -from Undefined.ai.llm import ModelRequester -from Undefined.ai.model_selector import ModelSelector -from Undefined.ai.multimodal import MultimodalAnalyzer -from Undefined.ai.prompts import PromptBuilder -from Undefined.ai.crawl4ai_support import get_crawl4ai_capabilities -from Undefined.ai.queue_budget import ( - compute_queued_llm_timeout_seconds, - resolve_effective_retry_count, -) -from Undefined.ai.parsing import extract_choices_content -from Undefined.ai.summaries import SummaryService -from Undefined.services.message_summary_fetch import fetch_session_messages -from Undefined.ai.transports.openai_transport import RESPONSES_OUTPUT_ITEMS_KEY -from Undefined.ai.tokens import TokenCounter -from Undefined.ai.tooling import END_CO_CALL_REJECT_CONTENT, ToolManager -from Undefined.config import ( - ChatModelConfig, - VisionModelConfig, - AgentModelConfig, - GrokModelConfig, - Config, -) -from Undefined.context import RequestContext -from Undefined.context_resource_registry import set_context_resource_scan_paths -from Undefined.end_summary_storage import EndSummaryStorage -from Undefined.memory import MemoryStorage -from Undefined.skills.agents import AgentRegistry -from Undefined.skills.agents.intro_generator import ( - AgentIntroGenConfig, - AgentIntroGenerator, -) -from Undefined.skills.anthropic_skills import AnthropicSkillRegistry -from Undefined.skills.tools import ToolRegistry -from Undefined.services.queue_manager import ( - ALL_QUEUE_LANES, - QUEUE_LANE_BACKGROUND, - QUEUE_LANE_GROUP_MENTION, - QUEUE_LANE_GROUP_NORMAL, - QUEUE_LANE_GROUP_SUPERADMIN, - QUEUE_LANE_PRIVATE, - QUEUE_LANE_SUPERADMIN, -) -from Undefined.token_usage_storage import TokenUsageStorage -from Undefined.utils.logging import log_debug_json, redact_string -from Undefined.utils.tool_calls import parse_tool_arguments - -logger = logging.getLogger(__name__) - -# 模型返回纯文本但未调用 tool 时,追加到 messages 的纠正提示(不写死具体 tool) -MISSING_TOOL_CALL_RETRY_HINT = ( - "【系统提示】你上一轮输出了纯文本且未调用任何工具。" - "本环境必须通过工具调用来完成对外动作与结束本轮处理。" - "请结合上文完整对话历史与已有 tool 返回结果,自行决定下一步应调用的工具;" - "不要直接以纯文本作为最终对外回复。" -) - - -_CONTENT_TAG_PATTERN = re.compile( - r"(.*?)", re.DOTALL | re.IGNORECASE -) - -_INVALID_TOOL_CALL_CONTENT = ( - "无效工具调用:工具名称为空或格式非法,系统已跳过执行。" - "请使用可用工具名重新调用,或调用 end 结束本轮。" -) - - -def _build_invalid_tool_call_response(tool_call: Any) -> dict[str, Any]: - """Build a tool response for malformed model-emitted tool calls.""" - call_id = "" - tool_name = "" - if isinstance(tool_call, dict): - call_id = str(tool_call.get("id", "") or "") - function = tool_call.get("function") - if isinstance(function, dict): - tool_name = str(function.get("name", "") or "").strip() - return { - "role": "tool", - "tool_call_id": call_id, - "name": tool_name, - "content": _INVALID_TOOL_CALL_CONTENT, - } - - -class SendMessageCallback(Protocol): - def __call__( - self, message: str, reply_to: int | None = None - ) -> Awaitable[None]: ... - - -class SendPrivateMessageCallback(Protocol): - def __call__( - self, user_id: int, message: str, reply_to: int | None = None - ) -> Awaitable[None]: ... - - -# 尝试导入 langchain SearxSearchWrapper -if TYPE_CHECKING: - from langchain_community.utilities import ( - SearxSearchWrapper as SearxSearchWrapperType, - ) -else: - SearxSearchWrapperType = object - -_SearxSearchWrapper: type[SearxSearchWrapperType] | None -try: - from langchain_community.utilities import SearxSearchWrapper as _SearxSearchWrapper - - _SEARX_AVAILABLE = True -except Exception: - _SearxSearchWrapper = None - _SEARX_AVAILABLE = False - logger.warning( - "[初始化] langchain_community 未安装或 SearxSearchWrapper 不可用,搜索功能将禁用" - ) - - -def _attachment_remote_download_max_bytes(runtime_config: Config) -> int: - value = int(runtime_config.attachment_remote_download_max_size_mb) - return max(0, value) * 1024 * 1024 - - -def _attachment_cache_max_bytes(runtime_config: Config) -> int: - value = int(runtime_config.attachment_cache_max_total_size_mb) - return max(0, value) * 1024 * 1024 - - -def _attachment_cache_max_age_seconds(runtime_config: Config) -> int: - value = int(runtime_config.attachment_cache_max_age_days) - return max(0, value) * 24 * 60 * 60 - - -def _resolve_summary_model_config( - runtime_config: Config | None, - fallback: AgentModelConfig, -) -> AgentModelConfig: - if runtime_config is None: - return fallback - if not getattr(runtime_config, "summary_model_configured", False): - return fallback - summary_model = getattr(runtime_config, "summary_model", None) - if isinstance(summary_model, AgentModelConfig): - return summary_model - return fallback - - -class AIClient: - """AI 模型客户端""" - - def __init__( - self, - chat_config: ChatModelConfig, - vision_config: VisionModelConfig, - agent_config: AgentModelConfig, - memory_storage: Optional[MemoryStorage] = None, - end_summary_storage: Optional[EndSummaryStorage] = None, - bot_qq: int = 0, - runtime_config: Config | None = None, - cognitive_service: Any = None, - ) -> None: - """初始化 AI 客户端 - - 参数: - chat_config: 对话模型配置 - vision_config: 视觉模型配置 - agent_config: 智能体模型配置 - memory_storage: 长期记忆存储 - end_summary_storage: 短期回忆存储 - bot_qq: 机器人自身的 QQ 号 - """ - self.chat_config = chat_config - self.vision_config = vision_config - self.agent_config = agent_config - self.bot_qq = bot_qq - self.runtime_config = runtime_config - self.memory_storage = memory_storage - self._end_summary_storage = end_summary_storage or EndSummaryStorage() - self._crawl4ai_capabilities = get_crawl4ai_capabilities() - - self._http_client = httpx.AsyncClient(timeout=480.0) - self._token_usage_storage = TokenUsageStorage() - self._requester = ModelRequester(self._http_client, self._token_usage_storage) - self._token_counter = TokenCounter() - self._knowledge_manager: Any = None - self._cognitive_service: Any = cognitive_service - self._meme_service: Any = None - if self.runtime_config is not None: - self.attachment_registry = AttachmentRegistry( - http_client=self._http_client, - remote_download_max_bytes=_attachment_remote_download_max_bytes( - self.runtime_config - ), - max_cache_bytes=_attachment_cache_max_bytes(self.runtime_config), - max_records=self.runtime_config.attachment_cache_max_records, - max_age_seconds=_attachment_cache_max_age_seconds(self.runtime_config), - url_reference_max_records=( - self.runtime_config.attachment_url_reference_max_records - ), - url_max_length=self.runtime_config.attachment_url_max_length, - ) - else: - self.attachment_registry = AttachmentRegistry(http_client=self._http_client) - - self._send_private_message_callback: Optional[SendPrivateMessageCallback] = None - self._send_image_callback: Optional[ - Callable[[int, str, str], Awaitable[None]] - ] = None - - # 当前群聊ID和用户ID(用于send_message工具) - self.current_group_id: Optional[int] = None - self.current_user_id: Optional[int] = None - - base_dir = Path(__file__).resolve().parents[1] - self.tool_registry = ToolRegistry(base_dir / "skills" / "tools") - self.agent_registry = AgentRegistry(base_dir / "skills" / "agents") - - # 初始化 Anthropic Agent Skills 注册表(可选,目录不存在时自动跳过) - anthropic_skills_dir = base_dir / "skills" / "anthropic_skills" - dot_delimiter = self._get_runtime_config().tools_dot_delimiter - self.anthropic_skill_registry = AnthropicSkillRegistry( - anthropic_skills_dir, dot_delimiter=dot_delimiter - ) - - self.tool_manager = ToolManager( - self.tool_registry, - self.agent_registry, - anthropic_skill_registry=self.anthropic_skill_registry, - ) - - self.model_selector = ModelSelector() - - # 绑定上下文资源扫描路径(基于注册表 watch_paths) - scan_paths = [ - p - for p in ( - self.tool_registry._watch_paths + self.agent_registry._watch_paths - ) - if p.exists() - ] - set_context_resource_scan_paths(scan_paths) - logger.debug( - "[初始化] 上下文资源扫描路径已绑定: count=%s", - len(scan_paths), - ) - - # Agent intro 生成器(延迟初始化,需要外部设置 queue_manager) - self._agent_intro_generator: Any | None = None - self._agent_intro_task: asyncio.Task[None] | None = None - self._queue_manager: Any | None = None - self._intro_config: Any | None = None - # 后台 LLM 调用挂起表(走队列的后台请求) - self._pending_llm_calls: dict[ - str, tuple[asyncio.Event, dict[str, Any] | Exception | None] - ] = {} - - # 后台任务引用集合(防止被 GC) - self._background_tasks: set[asyncio.Task[Any]] = set() - - runtime_config = self._get_runtime_config() - self._intro_config = AgentIntroGenConfig( - enabled=runtime_config.agent_intro_autogen_enabled, - queue_interval_seconds=runtime_config.agent_intro_autogen_queue_interval, - max_tokens=runtime_config.agent_intro_autogen_max_tokens, - cache_path=Path(runtime_config.agent_intro_hash_path), - ) - - # 启动 skills 热重载 - hot_reload_enabled = runtime_config.skills_hot_reload - if hot_reload_enabled: - interval = runtime_config.skills_hot_reload_interval - debounce = runtime_config.skills_hot_reload_debounce - self.tool_registry.start_hot_reload(interval=interval, debounce=debounce) - self.agent_registry.start_hot_reload(interval=interval, debounce=debounce) - self.anthropic_skill_registry.start_hot_reload( - interval=interval, debounce=debounce - ) - logger.info( - "[初始化] 技能热重载已启用: interval=%.2fs debounce=%.2fs", - interval, - debounce, - ) - else: - logger.info("[初始化] 技能热重载已禁用") - - # 初始化搜索 wrapper - self._search_wrapper: Optional[Any] = None - if _SEARX_AVAILABLE and _SearxSearchWrapper is not None: - searxng_url = runtime_config.searxng_url - if searxng_url: - try: - self._search_wrapper = _SearxSearchWrapper( - searx_host=searxng_url, k=10 - ) - logger.info( - "[初始化] SearxSearchWrapper 初始化成功: url=%s k=10", - redact_string(searxng_url), - ) - except Exception as exc: - logger.warning("[初始化] SearxSearchWrapper 初始化失败: %s", exc) - else: - logger.info("[初始化] SEARXNG_URL 未配置,搜索功能禁用") - - if self._crawl4ai_capabilities.available: - logger.info("[初始化] crawl4ai 可用,网页获取功能已启用") - else: - detail = self._crawl4ai_capabilities.error - if detail: - logger.warning( - "[初始化] crawl4ai 不可用,网页获取功能将禁用: %s", - detail, - ) - else: - logger.warning("[初始化] crawl4ai 不可用,网页获取功能将禁用") - - self._prompt_builder = PromptBuilder( - bot_qq=self.bot_qq, - memory_storage=self.memory_storage, - end_summary_storage=self._end_summary_storage, - runtime_config_getter=self._get_runtime_config, - anthropic_skill_registry=self.anthropic_skill_registry, - cognitive_service=self._cognitive_service, - ) - self._multimodal = MultimodalAnalyzer(self._requester, self.vision_config) - self._rebuild_summary_service() - - async def init_mcp_async() -> None: - try: - await self.tool_registry.initialize_mcp_toolsets() - except Exception as exc: - logger.warning("[初始化] 异步初始化 MCP 工具集失败: %s", exc) - - self._mcp_init_task = asyncio.create_task(init_mcp_async()) - - async def load_preferences_async() -> None: - try: - await self.model_selector.load_preferences() - except Exception as exc: - logger.warning("[初始化] 加载模型偏好失败: %s", exc) - - self._preferences_load_task = asyncio.create_task(load_preferences_async()) - - logger.info("[初始化] AIClient 初始化完成") - - async def close(self) -> None: - logger.info("[清理] 正在关闭 AIClient...") - - intro_gen = getattr(self, "_agent_intro_generator", None) - if intro_gen is not None: - await intro_gen.stop() - if hasattr(self, "_agent_intro_task") and self._agent_intro_task: - if not self._agent_intro_task.done(): - await self._agent_intro_task - knowledge_manager = getattr(self, "_knowledge_manager", None) - if knowledge_manager is not None and hasattr(knowledge_manager, "stop"): - try: - await knowledge_manager.stop() - except Exception as exc: - logger.warning("[清理] 关闭知识库管理器失败: %s", exc) - self._knowledge_manager = None - cognitive_service = getattr(self, "_cognitive_service", None) - if cognitive_service is not None: - if hasattr(cognitive_service, "stop"): - try: - await cognitive_service.stop() - except Exception as exc: - logger.warning("[清理] 关闭认知记忆服务失败: %s", exc) - self._cognitive_service = None - if hasattr(self, "_prompt_builder") and self._prompt_builder is not None: - self._prompt_builder.set_cognitive_service(None) - - if hasattr(self, "_mcp_init_task") and not self._mcp_init_task.done(): - await self._mcp_init_task - - if hasattr(self, "tool_registry"): - await self.tool_registry.stop_hot_reload() - await self.tool_registry.close_mcp_toolsets() - if hasattr(self, "agent_registry"): - await self.agent_registry.stop_hot_reload() - if hasattr(self, "anthropic_skill_registry"): - await self.anthropic_skill_registry.stop_hot_reload() - - attachment_registry = getattr(self, "attachment_registry", None) - if attachment_registry is not None and hasattr(attachment_registry, "flush"): - try: - await attachment_registry.flush() - except Exception as exc: - logger.warning("[清理] 刷新附件注册表失败: %s", exc) - - if hasattr(self, "_http_client"): - logger.info("[清理] 正在关闭 AIClient HTTP 客户端...") - await self._http_client.aclose() - - logger.info("[清理] AIClient 已关闭") - - def _resolve_queue_lane(self, queue_lane: Any = None) -> str: - queue_lane_text = str(queue_lane or "").strip().lower() - if queue_lane_text in ALL_QUEUE_LANES: - return queue_lane_text - - ctx = RequestContext.current() - if ctx is not None: - ctx_lane = str(ctx.get_resource("queue_lane") or "").strip().lower() - if ctx_lane in ALL_QUEUE_LANES: - return ctx_lane - - runtime_config = self._get_runtime_config() - superadmin_qq = int(getattr(runtime_config, "superadmin_qq", 0) or 0) - if ctx.request_type == "private": - if superadmin_qq > 0 and ( - ctx.user_id == superadmin_qq or ctx.sender_id == superadmin_qq - ): - return QUEUE_LANE_SUPERADMIN - return QUEUE_LANE_PRIVATE - if ctx.request_type == "group": - if superadmin_qq > 0 and ctx.sender_id == superadmin_qq: - return QUEUE_LANE_GROUP_SUPERADMIN - if bool(ctx.get_resource("is_at_bot")): - return QUEUE_LANE_GROUP_MENTION - return QUEUE_LANE_GROUP_NORMAL - - return QUEUE_LANE_BACKGROUND - - def _get_queued_llm_wait_timeout_seconds(self) -> float: - retry_count = resolve_effective_retry_count( - self._get_runtime_config(), - getattr(self, "_queue_manager", None), - ) - return compute_queued_llm_timeout_seconds( - self._get_runtime_config(), - self.chat_config, - retry_count=retry_count, - ) - - async def submit_queued_llm_call( - self, - model_config: Any, - messages: list[dict[str, Any]], - tools: list[dict[str, Any]] | None = None, - tool_choice: Any = "auto", - call_type: str = "background", - max_tokens: int | None = None, - transport_state: dict[str, Any] | None = None, - queue_lane: str | None = None, - ) -> dict[str, Any]: - """将 LLM 调用投递到统一队列,走统一发车间隔和重试逻辑。 - 无 queue_manager 时降级为直接调用。""" - effective_max_tokens = ( - max_tokens - if max_tokens is not None - else getattr(model_config, "max_tokens", 4096) - ) - resolved_queue_lane = self._resolve_queue_lane(queue_lane) - if self._queue_manager is None: - return await self.request_model( - model_config=model_config, - messages=messages, - tools=tools, - tool_choice=tool_choice, - call_type=call_type, - max_tokens=effective_max_tokens, - transport_state=transport_state, - ) - request_id = uuid4().hex - event: asyncio.Event = asyncio.Event() - self._pending_llm_calls[request_id] = (event, None) - model_name = getattr(model_config, "model_name", "default") - request: dict[str, Any] = { - "type": "queued_llm_call", - "request_id": request_id, - "model_config": model_config, - "messages": messages, - "tools": tools, - "tool_choice": tool_choice, - "call_type": call_type, - "max_tokens": effective_max_tokens, - "transport_state": transport_state, - } - ctx = RequestContext.current() - if ctx is not None: - if ctx.group_id is not None: - request["group_id"] = ctx.group_id - if ctx.user_id is not None: - request["user_id"] = ctx.user_id - logger.info( - "[queued_llm_enqueue] request_id=%s call_type=%s model=%s lane=%s messages=%s tools=%s", - request_id, - call_type, - model_name, - resolved_queue_lane, - len(messages), - bool(tools), - ) - receipt = await self._queue_manager.add_queued_llm_request( - request, - lane=resolved_queue_lane, - model_name=model_name, - ) - wait_timeout = compute_queued_llm_timeout_seconds( - self._get_runtime_config(), - model_config, - retry_count=resolve_effective_retry_count( - self._get_runtime_config(), self._queue_manager - ), - initial_wait_seconds=float( - getattr(receipt, "estimated_wait_seconds", 0.0) or 0.0 - ), - include_first_dispatch_interval=False, - ) - try: - await asyncio.wait_for(event.wait(), timeout=wait_timeout) - except asyncio.TimeoutError: - logger.exception( - "[queued_llm_wait_timeout] request_id=%s call_type=%s model=%s lane=%s timeout=%.1fs", - request_id, - call_type, - model_name, - resolved_queue_lane, - wait_timeout, - ) - raise - finally: - entry = self._pending_llm_calls.pop(request_id, None) - _, result = entry if entry is not None else (None, None) - if isinstance(result, Exception): - raise result - return result or {} - - async def submit_background_llm_call( - self, - model_config: Any, - messages: list[dict[str, Any]], - tools: list[dict[str, Any]] | None = None, - tool_choice: Any = "auto", - call_type: str = "background", - max_tokens: int | None = None, - transport_state: dict[str, Any] | None = None, - ) -> dict[str, Any]: - """后台 LLM 提交兼容包装。""" - return await self.submit_queued_llm_call( - model_config=model_config, - messages=messages, - tools=tools, - tool_choice=tool_choice, - call_type=call_type, - max_tokens=max_tokens, - transport_state=transport_state, - queue_lane=QUEUE_LANE_BACKGROUND, - ) - - def set_llm_call_result( - self, request_id: str, result: dict[str, Any] | Exception - ) -> None: - entry = self._pending_llm_calls.get(request_id) - if entry is None: - return - event, _ = entry - self._pending_llm_calls[request_id] = (event, result) - event.set() - - def set_queue_manager(self, queue_manager: Any) -> None: - """设置队列管理器并启动 Agent intro 生成器。 - - 参数: - queue_manager: 队列管理器实例 - """ - if self._queue_manager is not None: - logger.warning("[AI客户端] queue_manager 已设置,跳过重复设置") - return - - if queue_manager is None: - logger.warning("[AI客户端] 传入的 queue_manager 为 None") - return - - self._queue_manager = queue_manager - - # 启动/刷新 Agent intro 自动生成 - if self._intro_config: - self.apply_intro_config(self._intro_config) - - def apply_intro_config(self, config: AgentIntroGenConfig) -> None: - """应用 Agent intro 生成器配置(支持热更新)。""" - self._intro_config = config - if self._queue_manager is None: - return - task = asyncio.create_task(self._refresh_intro_generator(config)) - task.add_done_callback(lambda t: t.exception() if not t.cancelled() else None) - - async def _refresh_intro_generator(self, config: AgentIntroGenConfig) -> None: - if not config.enabled: - if self._agent_intro_generator is not None: - await self._agent_intro_generator.stop() - self._agent_intro_generator = None - self._agent_intro_task = None - logger.info("[Agent介绍] 自动生成已关闭") - return - - if self._queue_manager is None: - return - - if self._agent_intro_generator is None: - self._agent_intro_generator = AgentIntroGenerator( - self.agent_registry.base_dir, - self, - self._queue_manager, - config, - ) - self._agent_intro_task = asyncio.create_task( - self._agent_intro_generator.start() - ) - logger.info( - "[Agent介绍] 自动生成已启动: interval=%.2fs max_tokens=%s cache=%s", - config.queue_interval_seconds, - config.max_tokens, - config.cache_path, - ) - return - - if self._agent_intro_generator.config.cache_path != config.cache_path: - await self._agent_intro_generator.stop() - self._agent_intro_generator = AgentIntroGenerator( - self.agent_registry.base_dir, - self, - self._queue_manager, - config, - ) - self._agent_intro_task = asyncio.create_task( - self._agent_intro_generator.start() - ) - logger.info( - "[Agent介绍] 缓存路径变更,已重启生成器: cache=%s", - config.cache_path, - ) - return - - self._agent_intro_generator.config = config - - def set_knowledge_manager(self, manager: Any) -> None: - self._knowledge_manager = manager - - def set_cognitive_service(self, service: Any) -> None: - self._cognitive_service = service - if hasattr(self, "_prompt_builder") and self._prompt_builder is not None: - self._prompt_builder.set_cognitive_service(service) - logger.info( - "[AI客户端] 认知记忆服务已挂载并同步到 PromptBuilder: enabled=%s", - bool(getattr(service, "enabled", False)) if service is not None else False, - ) - - def set_meme_service(self, service: Any) -> None: - self._meme_service = service - resolver = None - async_resolver = None - if service is not None and hasattr(service, "resolve_global_image_sync"): - resolver = service.resolve_global_image_sync - if service is not None and hasattr(service, "resolve_global_image"): - async_resolver = service.resolve_global_image - self.attachment_registry.set_global_image_resolver(resolver) - self.attachment_registry.set_global_image_resolver_async(async_resolver) - logger.info( - "[AI客户端] 表情包服务已挂载: enabled=%s", - bool(getattr(service, "enabled", False)) if service is not None else False, - ) - - def apply_search_config(self, searxng_url: str) -> None: - """应用搜索服务配置(支持热更新)。""" - if not _SEARX_AVAILABLE or _SearxSearchWrapper is None: - if searxng_url: - logger.warning( - "[配置] 搜索组件不可用,已忽略 SEARXNG_URL=%s", - redact_string(searxng_url), - ) - else: - logger.info("[配置] 搜索组件不可用,搜索已禁用") - self._search_wrapper = None - return - - if not searxng_url: - self._search_wrapper = None - logger.info("[配置] SEARXNG_URL 未配置,搜索功能已禁用") - return - - try: - self._search_wrapper = _SearxSearchWrapper(searx_host=searxng_url, k=10) - logger.info( - "[配置] 搜索服务已更新: url=%s k=10", - redact_string(searxng_url), - ) - except Exception as exc: - logger.warning("[配置] 搜索服务更新失败: %s", exc) - self._search_wrapper = None - logger.info("[配置] 搜索服务已回退为禁用") - - def apply_model_configs( - self, - *, - chat_config: ChatModelConfig, - vision_config: VisionModelConfig, - agent_config: AgentModelConfig, - runtime_config: Config, - ) -> None: - """应用热更新后的模型配置。""" - self.chat_config = chat_config - self.vision_config = vision_config - self.agent_config = agent_config - self.runtime_config = runtime_config - self._multimodal = MultimodalAnalyzer(self._requester, self.vision_config) - self._rebuild_summary_service() - self.apply_attachment_config(runtime_config) - logger.info( - "[配置] AI 模型配置已热更新: chat=%s vision=%s agent=%s", - self.chat_config.model_name, - self.vision_config.model_name, - self.agent_config.model_name, - ) - - def apply_runtime_config(self, runtime_config: Config) -> None: - """应用不需要重建模型客户端的运行时配置。""" - self.runtime_config = runtime_config - self._rebuild_summary_service() - logger.info("[配置] AI 运行时配置已热更新") - - def _rebuild_summary_service(self) -> None: - self._summary_service = SummaryService( - self._requester, - _resolve_summary_model_config(self.runtime_config, self.agent_config), - self._token_counter, - ) - - def _resolve_summary_model_for_requests(self) -> AgentModelConfig: - return _resolve_summary_model_config(self.runtime_config, self.agent_config) - - async def _summarize_message_history_queued( - self, - messages_text: str, - instruction: str = "", - ) -> str: - model_config = self._resolve_summary_model_for_requests() - built_messages = await self._summary_service.build_message_summary_messages( - messages_text, instruction - ) - result = await self.submit_queued_llm_call( - model_config=model_config, - messages=built_messages, - tools=None, - call_type="message_summary", - max_tokens=model_config.max_tokens, - ) - return extract_choices_content(result).strip() - - async def _merge_summaries_queued(self, summaries: list[str]) -> str: - if len(summaries) == 1: - return summaries[0] - - model_config = self._resolve_summary_model_for_requests() - messages = await self._summary_service.build_message_merge_messages(summaries) - result = await self.submit_queued_llm_call( - model_config=model_config, - messages=messages, - tools=None, - call_type="merge_message_summaries", - max_tokens=8192, - ) - return extract_choices_content(result).strip() - - async def summarize_command_session( - self, - history_manager: Any, - *, - group_id: int, - user_id: int, - count: int | None = None, - time_range: str | None = None, - instruction: str = "", - ) -> str: - """Fetch session messages and summarize via summary model without tools.""" - messages_text = await fetch_session_messages( - history_manager, - group_id=group_id, - user_id=user_id, - count=count, - time_range=time_range, - runtime_config=self.runtime_config, - include_header=False, - ) - if not messages_text: - return "当前会话暂无消息记录" - if messages_text.startswith("无法解析时间范围"): - return messages_text - - input_budget = await self._summary_service.resolve_message_input_budget( - instruction - ) - total_tokens = self.count_tokens(messages_text) - if total_tokens <= input_budget: - return await self._summarize_message_history_queued( - messages_text, instruction - ) - - chunks = self.split_messages_by_tokens(messages_text, input_budget) - summaries = [ - await self._summarize_message_history_queued(chunk, instruction) - for chunk in chunks - ] - return await self._merge_summaries_queued(summaries) - - def apply_attachment_config(self, runtime_config: Config) -> None: - self.attachment_registry.set_limits( - remote_download_max_bytes=_attachment_remote_download_max_bytes( - runtime_config - ), - max_cache_bytes=_attachment_cache_max_bytes(runtime_config), - max_records=runtime_config.attachment_cache_max_records, - max_age_seconds=_attachment_cache_max_age_seconds(runtime_config), - url_reference_max_records=( - runtime_config.attachment_url_reference_max_records - ), - url_max_length=runtime_config.attachment_url_max_length, - ) - - def count_tokens(self, text: str) -> int: - return self._token_counter.count(text) - - def _get_runtime_config(self) -> Config: - if self.runtime_config is not None: - return self.runtime_config - from Undefined.config import get_config - - return get_config(strict=False) - - def _find_chat_config_by_name(self, model_name: str) -> ChatModelConfig: - """根据模型名查找配置(主模型或池中模型)""" - if model_name == self.chat_config.model_name: - return self.chat_config - if self.chat_config.pool and self.chat_config.pool.enabled: - for entry in self.chat_config.pool.models: - if entry.model_name == model_name: - return self.model_selector._entry_to_chat_config( - entry, self.chat_config - ) - return self.chat_config - - def _get_prefetch_tool_names(self) -> list[str]: - runtime_config = self._get_runtime_config() - return list(runtime_config.prefetch_tools) - - def _filter_tools_for_runtime_config( - self, tools: list[dict[str, Any]] - ) -> list[dict[str, Any]]: - runtime_config = self._get_runtime_config() - enabled = bool(getattr(runtime_config, "nagaagent_mode_enabled", False)) - if enabled: - return tools - - # 关闭 NagaAgent 模式时:隐藏相关 Agent,避免被模型误调用。 - filtered: list[dict[str, Any]] = [] - for tool in tools: - function = tool.get("function") if isinstance(tool, dict) else None - name = function.get("name") if isinstance(function, dict) else None - if name == "naga_code_analysis_agent": - continue - filtered.append(tool) - return filtered - - def _prefetch_hide_tools(self) -> bool: - runtime_config = self._get_runtime_config() - return runtime_config.prefetch_tools_hide - - def _is_missing_tool_result(self, result: Any) -> bool: - if not isinstance(result, str): - return False - return result.startswith("未找到项目") or result.startswith("未找到 MCP 工具") - - async def _maybe_prefetch_tools( - self, - messages: list[dict[str, Any]], - tools: list[dict[str, Any]] | None, - call_type: str, - ) -> tuple[list[dict[str, Any]], list[dict[str, Any]] | None]: - if not tools: - return messages, tools - - # 预先调用部分工具,为模型补充稳定上下文(同一 call_type 仅执行一次) - prefetch_names = self._get_prefetch_tool_names() - if not prefetch_names: - return messages, tools - - available_names = { - tool.get("function", {}).get("name") - for tool in tools - if tool.get("function") - } - prefetch_targets = [name for name in prefetch_names if name in available_names] - if not prefetch_targets: - return messages, tools - - # 使用 RequestContext 缓存已执行的预先调用,避免重复触发 - ctx = RequestContext.current() - cache: dict[str, list[str]] = {} - done: set[str] = set() - if ctx: - cache = ctx.get_resource("prefetch_tools", {}) or {} - done = set(cache.get(call_type, [])) - - to_run = [name for name in prefetch_targets if name not in done] - if not to_run: - return messages, tools - - results: list[tuple[str, Any]] = [] - for name in to_run: - try: - tool_args: dict[str, Any] = {} - if name == "get_current_time": - tool_args = {"format": "text", "include_lunar": True} - - result = await self.tool_manager.execute_tool( - name, - tool_args, - { - "runtime_config": self._get_runtime_config(), - "easter_egg_silent": True, - }, - ) - except Exception as exc: - logger.warning("[预先调用] %s 执行失败: %s", name, exc) - continue - - if self._is_missing_tool_result(result): - logger.warning("[预先调用] %s 未找到对应工具,跳过", name) - continue - - results.append((name, result)) - done.add(name) - - if not results: - return messages, tools - - if ctx: - cache[call_type] = sorted(done) - ctx.set_resource("prefetch_tools", cache) - - content_lines = ["【预先工具结果】"] - content_lines.extend([f"- {name}: {result}" for name, result in results]) - prefetch_message = {"role": "system", "content": "\n".join(content_lines)} - - insert_idx = 0 - for idx, msg in enumerate(messages): - if msg.get("role") == "system": - insert_idx = idx + 1 - else: - break - new_messages = list(messages) - new_messages.insert(insert_idx, prefetch_message) - - if self._prefetch_hide_tools(): - hidden = set(name for name in done) - tools = [ - tool - for tool in tools - if tool.get("function", {}).get("name") not in hidden - ] - return new_messages, tools - - async def request_model( - self, - model_config: ( - ChatModelConfig | VisionModelConfig | AgentModelConfig | GrokModelConfig - ), - messages: list[dict[str, Any]], - max_tokens: int = 8192, - call_type: str = "chat", - tools: list[dict[str, Any]] | None = None, - tool_choice: str = "auto", - transport_state: dict[str, Any] | None = None, - **kwargs: Any, - ) -> dict[str, Any]: - tools = self.tool_manager.maybe_merge_agent_tools(call_type, tools) - message_count_for_transport = len(messages) - if not ( - isinstance(transport_state, dict) - and transport_state.get("previous_response_id") - ): - messages, tools = await self._maybe_prefetch_tools( - messages, tools, call_type - ) - return await self._requester.request( - model_config=model_config, - messages=messages, - max_tokens=max_tokens, - call_type=call_type, - tools=tools, - tool_choice=tool_choice, - transport_state=transport_state, - message_count_for_transport=message_count_for_transport, - **kwargs, - ) - - def get_active_agent_mcp_registry(self, agent_name: str) -> Any | None: - return self.tool_manager.get_active_agent_mcp_registry(agent_name) - - async def analyze_multimodal( - self, - media_url: str, - media_type: str = "auto", - prompt_extra: str = "", - ) -> dict[str, str]: - return await self._multimodal.analyze(media_url, media_type, prompt_extra) - - async def describe_image( - self, image_url: str, prompt_extra: str = "" - ) -> dict[str, str]: - return await self._multimodal.describe_image(image_url, prompt_extra) - - async def judge_meme_image(self, image_url: str) -> dict[str, Any]: - return await self._multimodal.judge_meme_image(image_url) - - async def describe_meme_image(self, image_url: str) -> dict[str, Any]: - return await self._multimodal.describe_meme_image(image_url) - - def get_media_history(self, media_key: str) -> list[dict[str, str]]: - """获取指定媒体键的多模态分析历史 Q&A 记录。""" - return self._multimodal.get_history(media_key) - - async def save_media_history( - self, media_key: str, question: str, answer: str - ) -> None: - """保存一条多模态分析 Q&A 到历史记录并持久化到磁盘。""" - await self._multimodal.save_history(media_key, question, answer) - - async def summarize_chat(self, messages: str, context: str = "") -> str: - return await self._summary_service.summarize_chat(messages, context) - - async def merge_summaries(self, summaries: list[str]) -> str: - return await self._summary_service.merge_summaries(summaries) - - def split_messages_by_tokens(self, messages: str, max_tokens: int) -> list[str]: - return self._summary_service.split_messages_by_tokens(messages, max_tokens) - - async def generate_title(self, summary: str) -> str: - return await self._summary_service.generate_title(summary) - - def _extract_message_excerpt(self, question: str) -> str: - matched = _CONTENT_TAG_PATTERN.search(question) - if matched: - content = html.unescape(matched.group(1)) - else: - content = question - cleaned = " ".join(content.split()).strip() - if not cleaned: - return "(无文本内容)" - if len(cleaned) > 120: - return cleaned[:117].rstrip() + "..." - return cleaned - - def _is_end_only_tool_calls( - self, - tool_calls: list[dict[str, Any]], - api_to_internal: dict[str, str], - ) -> bool: - if not tool_calls: - return False - for tool_call in tool_calls: - function = tool_call.get("function", {}) - api_name = str(function.get("name", "") or "") - internal_name = api_to_internal.get(api_name, api_name) - if internal_name != "end": - return False - return True - - async def ask( - self, - question: str, - context: str = "", - send_message_callback: SendMessageCallback | None = None, - get_recent_messages_callback: Callable[ - [str, str, int, int], Awaitable[list[dict[str, Any]]] - ] - | None = None, - get_image_url_callback: Callable[[str], Awaitable[str | None]] | None = None, - get_forward_msg_callback: Callable[[str], Awaitable[list[dict[str, Any]]]] - | None = None, - send_like_callback: Callable[[int, int], Awaitable[None]] | None = None, - sender: Any = None, - history_manager: Any = None, - onebot_client: Any = None, - scheduler: Any = None, - extra_context: dict[str, Any] | None = None, - ) -> str: - """发送问题给 AI 并获取回复 (支持工具调用和迭代) - - 参数: - question: 用户输入的问题 - context: 额外的上下文背景 - send_message_callback: 发送消息的回调,支持可选的 reply_to - get_recent_messages_callback: 获取上下文历史消息的回调 - get_image_url_callback: 获取图片 URL 的回调 - get_forward_msg_callback: 获取合并转发内容的回调 - send_like_callback: 点赞回调 - sender: 消息发送助手实例 - history_manager: 历史记录管理器实例 - onebot_client: OneBot 客户端实例 - scheduler: 任务调度器实例 - extra_context: 额外的上下文负载 - - 返回: - AI 生成的最终文本回复 - """ - ctx = RequestContext.current() - pre_context: dict[str, Any] = {} - if ctx: - if ctx.group_id is not None: - pre_context["group_id"] = ctx.group_id - if ctx.user_id is not None: - pre_context["user_id"] = ctx.user_id - if ctx.sender_id is not None: - pre_context["sender_id"] = ctx.sender_id - pre_context["request_type"] = ctx.request_type - pre_context["request_id"] = ctx.request_id - if extra_context: - pre_context.update(extra_context) - - messages = await self._prompt_builder.build_messages( - question, - get_recent_messages_callback=get_recent_messages_callback, - extra_context=extra_context, - ) - - tools = self.tool_manager.get_openai_tools() - tools = self._filter_tools_for_runtime_config(tools) - if logger.isEnabledFor(logging.DEBUG): - logger.debug( - "[AI消息] 构建完成: messages=%s tools=%s question_len=%s", - len(messages), - len(tools), - len(question), - ) - log_debug_json(logger, "[AI消息内容]", messages) - - tool_context = ctx.get_resources() if ctx else {} - tool_context["conversation_ended"] = False - tool_context.setdefault("agent_histories", {}) - - # 显式注入 RequestContext 的核心字段(与 tooling.py:execute_tool_call 保持一致) - if ctx: - if ctx.group_id is not None: - tool_context.setdefault("group_id", ctx.group_id) - if ctx.user_id is not None: - tool_context.setdefault("user_id", ctx.user_id) - if ctx.sender_id is not None: - tool_context.setdefault("sender_id", ctx.sender_id) - tool_context.setdefault("request_type", ctx.request_type) - tool_context.setdefault("request_id", ctx.request_id) - - if extra_context: - tool_context.update(extra_context) - - # 注入常用资源(用于工具执行) - tool_context.setdefault("ai_client", self) - tool_context.setdefault("runtime_config", self._get_runtime_config()) - tool_context.setdefault("search_wrapper", self._search_wrapper) - tool_context.setdefault( - "crawl4ai_available", self._crawl4ai_capabilities.available - ) - tool_context.setdefault( - "crawl4ai_proxy_config_available", - self._crawl4ai_capabilities.proxy_config_available, - ) - tool_context.setdefault("end_summary_storage", self._end_summary_storage) - tool_context.setdefault("end_summaries", self._prompt_builder.end_summaries) - tool_context.setdefault( - "send_private_message_callback", self._send_private_message_callback - ) - tool_context.setdefault("send_message_callback", send_message_callback) - tool_context.setdefault( - "get_recent_messages_callback", get_recent_messages_callback - ) - - async def fetch_session_messages_callback( - *, - group_id: int, - user_id: int, - count: int | None = None, - time_range: str | None = None, - ) -> str: - return await fetch_session_messages( - history_manager, - group_id=group_id, - user_id=user_id, - count=count, - time_range=time_range, - runtime_config=self._get_runtime_config(), - ) - - tool_context.setdefault( - "fetch_session_messages_callback", fetch_session_messages_callback - ) - tool_context.setdefault("get_image_url_callback", get_image_url_callback) - tool_context.setdefault("get_forward_msg_callback", get_forward_msg_callback) - tool_context.setdefault("send_like_callback", send_like_callback) - tool_context.setdefault("sender", sender) - tool_context.setdefault("history_manager", history_manager) - tool_context.setdefault("onebot_client", onebot_client) - tool_context.setdefault("scheduler", scheduler) - tool_context.setdefault("send_image_callback", self._send_image_callback) - tool_context.setdefault( - "attachment_registry", - getattr(self, "attachment_registry", None), - ) - tool_context.setdefault("memory_storage", self.memory_storage) - tool_context.setdefault("knowledge_manager", self._knowledge_manager) - tool_context.setdefault("cognitive_service", self._cognitive_service) - tool_context.setdefault("meme_service", self._meme_service) - tool_context.setdefault("current_question", question) - message_ids = tool_context.get("message_ids") - if not isinstance(message_ids, list): - message_ids = [] - tool_context["message_ids"] = message_ids - trigger_message_id = tool_context.get("trigger_message_id") - if trigger_message_id is not None: - trigger_message_id_text = str(trigger_message_id).strip() - if trigger_message_id_text and trigger_message_id_text not in message_ids: - message_ids.append(trigger_message_id_text) - - # 动态选择模型(等待偏好加载就绪,避免竞态) - await self.model_selector.wait_ready() - selected_model_name = pre_context.get("selected_model_name") - if selected_model_name: - effective_chat_config = self._find_chat_config_by_name(selected_model_name) - else: - effective_chat_config = self.chat_config - - max_iterations = 1000 - iteration = 0 - conversation_ended = False - cot_compat = getattr(effective_chat_config, "thinking_tool_call_compat", False) - capture_reasoning = cot_compat or bool( - getattr(effective_chat_config, "reasoning_content_replay", False) - ) - cot_compat_logged = False - cot_missing_logged = False - transport_state: dict[str, Any] | None = None - queue_lane = self._resolve_queue_lane(tool_context.get("queue_lane")) - pre_tool_failure_count = 0 - missing_tool_call_count = 0 - last_missing_tool_call_content = "" - runtime_config = self._get_runtime_config() - max_pre_tool_retries = max( - 0, - int(getattr(runtime_config, "ai_request_max_retries", 0) or 0), - ) - max_missing_tool_call_retries = max( - 0, - int(getattr(runtime_config, "missing_tool_call_retries", 3) or 0), - ) - - while iteration < max_iterations: - iteration += 1 - logger.info(f"[AI决策] 开始第 {iteration} 轮迭代...") - message_checkpoint_len = len(messages) - transport_state_checkpoint = transport_state - - try: - result = await self.submit_queued_llm_call( - model_config=effective_chat_config, - messages=messages, - max_tokens=8192, - call_type="chat", - tools=tools, - tool_choice="auto", - transport_state=transport_state, - queue_lane=queue_lane, - ) - except Exception as exc: - logger.exception( - "[queued_llm_error] call_type=chat model=%s lane=%s iteration=%s error=%s", - effective_chat_config.model_name, - queue_lane, - iteration, - exc, - ) - raise - - try: - tool_execution_started = False - tool_name_map = ( - result.get("_tool_name_map") if isinstance(result, dict) else None - ) - api_to_internal: dict[str, str] = {} - if isinstance(tool_name_map, dict): - raw_api_to_internal = tool_name_map.get("api_to_internal") - if isinstance(raw_api_to_internal, dict): - api_to_internal = { - str(k): str(v) for k, v in raw_api_to_internal.items() - } - - next_transport_state = ( - result.get("_transport_state") if isinstance(result, dict) else None - ) - transport_state = ( - next_transport_state - if isinstance(next_transport_state, dict) - else None - ) - - choice = result.get("choices", [{}])[0] - message = choice.get("message", {}) - content: str = message.get("content") or "" - reasoning_content = message.get("reasoning_content") - tool_calls = message.get("tool_calls", []) - if logger.isEnabledFor(logging.DEBUG): - logger.debug( - "[AI响应] content_len=%s tool_calls=%s", - len(content), - len(tool_calls), - ) - if tool_calls: - log_debug_json(logger, "[AI工具调用]", tool_calls) - - log_thinking = self._get_runtime_config().log_thinking - if ( - capture_reasoning - and tools - and log_thinking - and not cot_compat_logged - ): - cot_compat_logged = True - logger.info( - "[思维链兼容] 多轮工具调用 reasoning_content 本地回填已启用" - ) - if ( - capture_reasoning - and log_thinking - and tools - and getattr(effective_chat_config, "thinking_enabled", False) - and not reasoning_content - and tool_calls - and not cot_missing_logged - ): - cot_missing_logged = True - message_keys = ( - ", ".join(sorted(message.keys())) - if isinstance(message, dict) - else type(message).__name__ - ) - logger.info( - "[思维链兼容] 未在响应中发现 reasoning_content(可能是模型/服务商不返回思维链);message_keys=%s", - message_keys, - ) - - if content.strip() and tool_calls: - logger.debug( - "检测到 content 与工具调用同时存在,忽略 content,仅执行工具调用" - ) - content = "" - - if not tool_calls: - if conversation_ended: - logger.info( - "[AI回复] 会话结束,返回最终内容: length=%s", - len(content), - ) - return content - - if content.strip(): - last_missing_tool_call_content = content.strip() - missing_tool_call_count += 1 - if missing_tool_call_count > max_missing_tool_call_retries: - logger.warning( - "[AI回复] 模型连续未调用工具,停止重试: iteration=%s retries=%s/%s content_len=%s", - iteration, - missing_tool_call_count - 1, - max_missing_tool_call_retries, - len(content), - ) - fallback_content = last_missing_tool_call_content - if fallback_content and send_message_callback is not None: - try: - await send_message_callback(fallback_content) - tool_context["message_sent_this_turn"] = True - current_ctx = RequestContext.current() - if current_ctx is not None: - current_ctx.set_resource( - "message_sent_this_turn", True - ) - return "" - except Exception: - logger.exception("[AI回复] fallback 发送失败") - return fallback_content - - logger.warning( - "[AI回复] 模型返回文本但未调用工具(iteration=%s retry=%s/%s content_len=%s),要求重试", - iteration, - missing_tool_call_count, - max_missing_tool_call_retries, - len(content), - ) - assistant_retry_message: dict[str, Any] = { - "role": "assistant", - "content": content, - } - if capture_reasoning and reasoning_content is not None: - assistant_retry_message["reasoning_content"] = reasoning_content - messages.append(assistant_retry_message) - messages.append( - { - "role": "user", - "content": MISSING_TOOL_CALL_RETRY_HINT, - } - ) - continue - - assistant_message: dict[str, Any] = { - "role": "assistant", - "content": content, - "tool_calls": tool_calls, - } - missing_tool_call_count = 0 - last_missing_tool_call_content = "" - phase = message.get("phase") - if phase is not None: - assistant_message["phase"] = phase - output_items = message.get(RESPONSES_OUTPUT_ITEMS_KEY) - if isinstance(output_items, list): - assistant_message[RESPONSES_OUTPUT_ITEMS_KEY] = output_items - if capture_reasoning and reasoning_content is not None: - assistant_message["reasoning_content"] = reasoning_content - messages.append(assistant_message) - - tool_tasks = [] - tool_call_ids = [] - tool_api_names: list[str] = [] - tool_internal_names: list[str] = [] - end_tool_call: dict[str, Any] | None = None - end_tool_args: dict[str, Any] = {} - tool_results: list[Any] = [] - - for tool_call in tool_calls: - call_id = "" - if isinstance(tool_call, dict): - call_id = str(tool_call.get("id", "") or "") - function = tool_call.get("function") - else: - function = None - if not isinstance(function, dict): - logger.warning( - "[工具调用] 跳过无效工具调用: missing_function ID=%s", - call_id, - ) - messages.append(_build_invalid_tool_call_response(tool_call)) - continue - api_function_name = str(function.get("name", "") or "").strip() - if not api_function_name: - logger.warning( - "[工具调用] 跳过无效工具调用: empty_name ID=%s", - call_id, - ) - messages.append(_build_invalid_tool_call_response(tool_call)) - continue - raw_args = function.get("arguments") - - internal_function_name = api_to_internal.get( - api_function_name, api_function_name - ) - - if internal_function_name != api_function_name: - logger.info( - "[工具准备] 准备调用: %s (原名: %s) (ID=%s)", - internal_function_name, - api_function_name, - call_id, - ) - else: - logger.info( - "[工具准备] 准备调用: %s (ID=%s)", - api_function_name, - call_id, - ) - logger.debug( - f"[工具参数] {api_function_name} 参数: {redact_string(str(raw_args))}" - ) - - function_args = parse_tool_arguments( - raw_args, - logger=logger, - tool_name=str(api_function_name), - ) - - if not isinstance(function_args, dict): - function_args = {} - - # 检测 end 工具,暂存后统一处理 - if internal_function_name == "end": - if len(tool_calls) > 1: - logger.warning( - "[工具调用] end 与其他工具同时调用," - "将先执行其他工具,end 将返回拒绝结果" - ) - end_tool_call = tool_call - end_tool_args = function_args - continue - - tool_call_ids.append(call_id) - tool_api_names.append(str(api_function_name)) - tool_internal_names.append(str(internal_function_name)) - tool_tasks.append( - self.tool_manager.execute_tool( - str(internal_function_name), function_args, tool_context - ) - ) - - if tool_tasks: - tool_execution_started = True - logger.info( - "[工具执行] 开始并发执行 %s 个工具调用: %s", - len(tool_tasks), - ", ".join(tool_internal_names), - ) - tool_results = await asyncio.gather( - *tool_tasks, return_exceptions=True - ) - - for i, tool_result in enumerate(tool_results): - call_id = tool_call_ids[i] - api_fname = tool_api_names[i] - internal_fname = tool_internal_names[i] - - if isinstance(tool_result, Exception): - logger.error( - "[工具异常] %s (ID=%s) 执行抛出异常: %s", - internal_fname, - call_id, - tool_result, - ) - content_str = f"执行失败: {str(tool_result)}" - else: - content_str = str(tool_result) - logger.debug( - "[工具响应] %s (ID=%s) 返回内容长度=%s", - internal_fname, - call_id, - len(content_str), - ) - if logger.isEnabledFor(logging.DEBUG): - log_debug_json( - logger, - f"[工具响应体] {internal_fname} (ID={call_id})", - content_str, - ) - - messages.append( - { - "role": "tool", - "tool_call_id": call_id, - "name": api_fname, - "content": content_str, - } - ) - - # 如果是 get_forward_msg 工具调用,将其结果写入历史记录 - if internal_fname == "get_forward_msg" and not isinstance( - tool_result, Exception - ): - task = asyncio.create_task( - self._save_forward_to_history( - content_str, pre_context, history_manager - ) - ) - task.add_done_callback( - lambda t: t.exception() if not t.cancelled() else None - ) - - if tool_context.get("conversation_ended"): - conversation_ended = True - logger.info( - "[会话状态] 工具触发会话结束标记: tool=%s", - internal_fname, - ) - - # 处理 end 工具调用 - if end_tool_call: - end_call_id = end_tool_call.get("id", "") - end_api_name = end_tool_call.get("function", {}).get("name", "end") - if tool_tasks: - messages.append( - { - "role": "tool", - "tool_call_id": end_call_id, - "name": end_api_name, - "content": END_CO_CALL_REJECT_CONTENT, - } - ) - logger.info( - "[工具调用] end 与其他工具同时调用," - "其它工具已执行,end 已回填拒绝响应" - ) - else: - # end 单独调用,正常执行(参数已在循环中解析) - tool_execution_started = True - end_result = await self.tool_manager.execute_tool( - "end", end_tool_args, tool_context - ) - messages.append( - { - "role": "tool", - "tool_call_id": end_call_id, - "name": end_api_name, - "content": str(end_result), - } - ) - if tool_context.get("conversation_ended"): - conversation_ended = True - logger.info("[会话状态] end 工具触发会话结束") - - if conversation_ended: - logger.info("[会话状态] 对话已结束(调用 end 工具)") - return "" - pre_tool_failure_count = 0 - - except Exception as exc: - if ( - not tool_execution_started - and pre_tool_failure_count < max_pre_tool_retries - ): - pre_tool_failure_count += 1 - del messages[message_checkpoint_len:] - transport_state = transport_state_checkpoint - logger.warning( - "[chat.pre_tool_retry] model=%s lane=%s retry=%s/%s iteration=%s error=%s", - effective_chat_config.model_name, - queue_lane, - pre_tool_failure_count, - max_pre_tool_retries, - iteration, - exc, - ) - continue - logger.exception( - "[chat.suppressed_error] model=%s lane=%s iteration=%s error=%s", - effective_chat_config.model_name, - queue_lane, - iteration, - exc, - ) - return "" - - logger.warning("[AI决策] 达到最大迭代次数,未能完成处理") - return "达到最大迭代次数,未能完成处理" - - async def _save_forward_to_history( - self, - content: str, - pre_context: dict[str, Any], - history_manager: Any, - ) -> None: - """将合并转发消息写入历史记录""" - if history_manager is None: - return - - try: - group_id = pre_context.get("group_id") - user_id = pre_context.get("user_id") - - if group_id is not None: - await history_manager.add_group_message( - group_id=int(group_id), - sender_id=0, - text_content=content, - sender_card="", - sender_nickname="[合并转发内容]", - group_name="", - role="system", - title="", - message_id=None, - ) - elif user_id is not None: - await history_manager.add_private_message( - user_id=int(user_id), - text_content=content, - display_name="[合并转发内容]", - user_name="", - message_id=None, - ) - else: - logger.debug("[合并转发] 无法写入历史:缺少 group_id 和 user_id") - except Exception as exc: - logger.debug("[合并转发] 写入历史失败: %s", exc) diff --git a/src/Undefined/ai/client/__init__.py b/src/Undefined/ai/client/__init__.py index 40f823ad..0be60fc8 100644 --- a/src/Undefined/ai/client/__init__.py +++ b/src/Undefined/ai/client/__init__.py @@ -1,6 +1,6 @@ """AI 客户端子包。 -对外稳定入口:``AIClient``;旧路径 ``Undefined.ai.client`` 通过 shim 保持兼容。 +对外稳定入口:``AIClient``;导入路径 ``Undefined.ai.client`` 指向本子包。 """ from Undefined.ai.client.ask_loop import ClientAskLoopMixin diff --git a/src/Undefined/ai/llm.py b/src/Undefined/ai/llm.py deleted file mode 100644 index 59411faf..00000000 --- a/src/Undefined/ai/llm.py +++ /dev/null @@ -1,2063 +0,0 @@ -"""LLM 模型请求处理。""" - -from __future__ import annotations - -import asyncio -import hashlib -import json -import logging -import re -import time -from datetime import datetime -from typing import Any -from urllib.parse import parse_qsl, urlsplit, urlunsplit - -import httpx -from openai import ( - APIConnectionError, - APIStatusError, - APITimeoutError, - AsyncOpenAI, -) - -from Undefined.ai.parsing import extract_choices_content -from Undefined.ai.transports import ( - API_MODE_CHAT_COMPLETIONS, - API_MODE_RESPONSES, - build_responses_request_body, - get_api_mode, - get_effort_payload, - get_effort_style, - get_thinking_payload, - normalize_responses_result, -) -from Undefined.ai.retrieval import RetrievalRequester -from Undefined.ai.tokens import TokenCounter -from Undefined.context import RequestContext -from Undefined.config import ( - ChatModelConfig, - VisionModelConfig, - AgentModelConfig, - SecurityModelConfig, - EmbeddingModelConfig, - GrokModelConfig, - RerankModelConfig, - Config, - get_config, -) -from Undefined.token_usage_storage import TokenUsageStorage, TokenUsage -from Undefined.utils.logging import log_debug_json, redact_string -from Undefined.utils.request_params import ( - merge_request_params, - split_reserved_request_params, -) -from Undefined.utils.tool_calls import normalize_tool_arguments_json - -logger = logging.getLogger(__name__) - -ModelConfig = ( - ChatModelConfig - | VisionModelConfig - | AgentModelConfig - | SecurityModelConfig - | EmbeddingModelConfig - | GrokModelConfig - | RerankModelConfig -) - -__all__ = ["ModelRequester", "build_request_body", "ModelConfig"] - -_CHAT_COMPLETIONS_KNOWN_FIELDS: set[str] = { - "model", - "messages", - "audio", - "metadata", - "max_completion_tokens", - "max_tokens", - "modalities", - "parallel_tool_calls", - "prediction", - "prompt_cache_key", - "prompt_cache_retention", - "reasoning_effort", - "safety_identifier", - "service_tier", - "store", - "temperature", - "top_p", - "n", - "stop", - "presence_penalty", - "frequency_penalty", - "logit_bias", - "user", - "response_format", - "seed", - "stream", - "stream_options", - "tools", - "tool_choice", - "logprobs", - "top_logprobs", - "verbosity", - "web_search_options", -} - -_SDK_REQUEST_OPTION_FIELDS: frozenset[str] = frozenset( - {"extra_headers", "extra_query", "extra_body", "timeout"} -) - -_RESPONSES_KNOWN_FIELDS: set[str] = { - "background", - "context_management", - "conversation", - "include", - "model", - "input", - "instructions", - "max_output_tokens", - "max_tool_calls", - "metadata", - "previous_response_id", - "prompt", - "prompt_cache_key", - "prompt_cache_retention", - "reasoning", - "safety_identifier", - "service_tier", - "store", - "temperature", - "top_p", - "tools", - "tool_choice", - "parallel_tool_calls", - "stream", - "stream_options", - "text", - "truncation", - "user", -} - -_CHAT_COMPLETIONS_RESERVED_FIELDS: frozenset[str] = ( - frozenset( - { - "model", - "messages", - "max_tokens", - "tools", - "tool_choice", - "stream", - "stream_options", - "thinking", - "reasoning", - "reasoning_effort", - "output_config", - } - ) - | _SDK_REQUEST_OPTION_FIELDS -) - -_RESPONSES_RESERVED_FIELDS: frozenset[str] = ( - frozenset( - { - "model", - "input", - "instructions", - "max_output_tokens", - "tools", - "tool_choice", - "previous_response_id", - "stream", - "stream_options", - "thinking", - "reasoning", - "reasoning_effort", - "output_config", - } - ) - | _SDK_REQUEST_OPTION_FIELDS -) - -_THINKING_KEYS: tuple[str, ...] = ( - "thinking", - "reasoning", - "reasoning_content", - "chain_of_thought", - "cot", - "thoughts", -) -_CHAT_COMPLETION_STRIP_THINKING_KEYS: frozenset[str] = frozenset( - ("thinking", "reasoning", "chain_of_thought", "cot", "thoughts") -) -_CHAT_COMPLETION_INTERNAL_MESSAGE_KEYS: frozenset[str] = frozenset( - ( - "reasoning_content", - *_CHAT_COMPLETION_STRIP_THINKING_KEYS, - "_responses_output_items", - "phase", - ) -) - -_DEFAULT_TOOLS_DESCRIPTION_MAX_LEN = 1024 -_TOOLS_PARAM_INDEX_RE = re.compile(r"Tools\[(\d+)\]", re.IGNORECASE) -_RESPONSES_MISSING_TOOL_CALL_OUTPUT_RE = re.compile( - r"no tool call found for function call output with call_id", - re.IGNORECASE, -) -_DEFAULT_TOOLS_DESCRIPTION_PREVIEW_LEN = 160 - -_DEFAULT_TOOL_NAME_DOT_DELIMITER = "-_-" -_TOOL_NAME_MAX_LEN = 64 -_TOOL_NAME_ALLOWED_RE = re.compile(r"^[a-zA-Z0-9_-]+$") -_PROMPT_CACHE_KEY_MAX_LEN = 128 - - -def _tool_name_dot_delimiter() -> str: - runtime_config = _get_runtime_config() - value = ( - getattr(runtime_config, "tools_dot_delimiter", None) if runtime_config else None - ) - text = str(value).strip() if value is not None else _DEFAULT_TOOL_NAME_DOT_DELIMITER - if not text: - return _DEFAULT_TOOL_NAME_DOT_DELIMITER - if "." in text: - return _DEFAULT_TOOL_NAME_DOT_DELIMITER - if not _TOOL_NAME_ALLOWED_RE.match(text): - return _DEFAULT_TOOL_NAME_DOT_DELIMITER - # 保持较短长度,避免工具名被服务端截断。 - if len(text) > 16: - return text[:16] - return text - - -def _hash8(text: str) -> str: - return hashlib.sha1(text.encode("utf-8"), usedforsecurity=False).hexdigest()[:8] - - -def _normalize_prompt_cache_part(value: Any) -> str: - text = str(value or "").strip().lower() - if not text: - return "none" - normalized_chars: list[str] = [] - for char in text: - if char.isalnum() or char in {"-", "_", ":"}: - normalized_chars.append(char) - else: - normalized_chars.append("_") - normalized = "".join(normalized_chars).strip("_") - return normalized or "none" - - -def _build_scope_prompt_cache_part() -> str: - ctx = RequestContext.current() - if ctx is None: - return "scope:global" - if ctx.group_id is not None: - return f"group:{int(ctx.group_id)}" - if ctx.user_id is not None: - return f"private:{int(ctx.user_id)}" - if ctx.sender_id is not None: - return f"sender:{int(ctx.sender_id)}" - request_type = _normalize_prompt_cache_part(ctx.request_type) - return f"type:{request_type}" - - -def _build_default_prompt_cache_key(model_config: ModelConfig, call_type: str) -> str: - model_name = _normalize_prompt_cache_part(getattr(model_config, "model_name", "")) - scope_part = _build_scope_prompt_cache_part() - call_part = _normalize_prompt_cache_part(call_type) - key = f"pc:{model_name}:{call_part}:{scope_part}" - if len(key) <= _PROMPT_CACHE_KEY_MAX_LEN: - return key - suffix = "_" + _hash8(key) - prefix_len = max(1, _PROMPT_CACHE_KEY_MAX_LEN - len(suffix)) - return key[:prefix_len] + suffix - - -def _encode_tool_name_for_api(tool_name: str) -> str: - """将内部工具名编码为服务端可接受的 function.name。 - - - 将 '.' 替换为 '-_-'(保留工具集命名语义) - - 其他不允许字符替换为 '_' - - 强制最大长度(<=64),超长时追加稳定哈希 - """ - raw = str(tool_name or "").strip() - if not raw: - return "tool" - - # 保留工具集分隔语义:category.tool -> categorytool - encoded = raw.replace(".", _tool_name_dot_delimiter()) - - # 替换其他不允许字符。 - cleaned_chars: list[str] = [] - for ch in encoded: - if ch.isalnum() or ch in {"_", "-"}: - cleaned_chars.append(ch) - else: - cleaned_chars.append("_") - encoded = "".join(cleaned_chars) - - if not encoded: - encoded = "tool" - - if len(encoded) > _TOOL_NAME_MAX_LEN: - suffix = "_" + _hash8(raw) - prefix_len = max(1, _TOOL_NAME_MAX_LEN - len(suffix)) - encoded = encoded[:prefix_len] + suffix - - # 最后兜底校验(理论上应始终通过) - if not _TOOL_NAME_ALLOWED_RE.match(encoded): - suffix = "_" + _hash8(raw) - encoded = re.sub(r"[^a-zA-Z0-9_-]", "_", encoded) - if len(encoded) > _TOOL_NAME_MAX_LEN: - encoded = encoded[: _TOOL_NAME_MAX_LEN - len(suffix)] + suffix - if not encoded: - encoded = "tool" + suffix - - return encoded - - -def _responses_should_fallback_to_stateless_replay( - exc: APIStatusError, - request_body: dict[str, Any], - *, - stateless_replay: bool, -) -> bool: - if stateless_replay or not request_body.get("previous_response_id"): - return False - input_items = request_body.get("input") - if not isinstance(input_items, list) or not any( - isinstance(item, dict) and item.get("type") == "function_call_output" - for item in input_items - ): - return False - if exc.status_code != 400 or not isinstance(exc.body, dict): - return False - error = exc.body.get("error") - if not isinstance(error, dict): - return False - message = str(error.get("message", "")).strip() - param = str(error.get("param", "")).strip().lower() - return param == "input" and bool( - _RESPONSES_MISSING_TOOL_CALL_OUTPUT_RE.search(message) - ) - - -def _sanitize_openai_tool_names_in_request( - request_body: dict[str, Any], -) -> tuple[dict[str, str], dict[str, str]]: - """将 request_body 的 tools/messages 工具名改写为服务端可接受的名称。 - - Returns: - (api_to_internal, internal_to_api) 映射表。 - - Notes: - - 仅保证 tools schema 中出现的名称可逆映射。 - - 历史消息中的工具调用会尽力重写。 - """ - tools = request_body.get("tools") - if not isinstance(tools, list) or not tools: - return {}, {} - - internal_to_api: dict[str, str] = {} - api_to_internal: dict[str, str] = {} - used_api: set[str] = set() - - new_tools: list[dict[str, Any]] = [] - for tool in tools: - if not isinstance(tool, dict): - new_tools.append(tool) - continue - function = tool.get("function") - if not isinstance(function, dict): - new_tools.append(tool) - continue - internal_name = str(function.get("name", "") or "") - if not internal_name: - new_tools.append(tool) - continue - - # 稳定编码;如发生冲突则追加后缀。 - base_api_name = _encode_tool_name_for_api(internal_name) - api_name = base_api_name - if api_name in used_api and api_to_internal.get(api_name) != internal_name: - suffix = "_" + _hash8(internal_name) - prefix_len = max(1, _TOOL_NAME_MAX_LEN - len(suffix)) - api_name = base_api_name[:prefix_len] + suffix - if api_name in used_api and api_to_internal.get(api_name) != internal_name: - # 极少数冲突兜底:加入索引避免重复。 - suffix = "_" + _hash8(f"{internal_name}:{len(used_api)}") - prefix_len = max(1, _TOOL_NAME_MAX_LEN - len(suffix)) - api_name = base_api_name[:prefix_len] + suffix - - used_api.add(api_name) - internal_to_api[internal_name] = api_name - api_to_internal[api_name] = internal_name - - if api_name != internal_name: - tool = dict(tool) - function = dict(function) - function["name"] = api_name - tool["function"] = function - new_tools.append(tool) - - request_body["tools"] = new_tools - - # 尽力重写历史消息中的工具名。 - messages = request_body.get("messages") - if isinstance(messages, list) and messages: - new_messages: list[dict[str, Any]] = [] - changed = False - for message in messages: - if not isinstance(message, dict): - new_messages.append(message) - continue - - new_message = message - - msg_name = message.get("name") - if isinstance(msg_name, str) and msg_name: - mapped = internal_to_api.get(msg_name) - if mapped and mapped != msg_name: - if new_message is message: - new_message = dict(message) - new_message["name"] = mapped - changed = True - elif (not _TOOL_NAME_ALLOWED_RE.match(msg_name)) or ( - len(msg_name) > _TOOL_NAME_MAX_LEN - ): - # 即便名称不在 schema 映射中,也尽量保证请求合法(如工具被重命名/移除)。 - safe = _encode_tool_name_for_api(msg_name) - if safe != msg_name: - if new_message is message: - new_message = dict(message) - new_message["name"] = safe - changed = True - - tool_calls = message.get("tool_calls") - if isinstance(tool_calls, list) and tool_calls: - new_tool_calls: list[Any] = [] - tool_calls_changed = False - for tool_call in tool_calls: - if not isinstance(tool_call, dict): - new_tool_calls.append(tool_call) - continue - function = tool_call.get("function") - if not isinstance(function, dict): - new_tool_calls.append(tool_call) - continue - fname = function.get("name") - if not isinstance(fname, str) or not fname: - new_tool_calls.append(tool_call) - continue - mapped = internal_to_api.get(fname) - safe_name = mapped or _encode_tool_name_for_api(fname) - if safe_name != fname: - tool_calls_changed = True - new_tool_call = dict(tool_call) - new_function = dict(function) - new_function["name"] = safe_name - new_tool_call["function"] = new_function - new_tool_calls.append(new_tool_call) - else: - new_tool_calls.append(tool_call) - - if tool_calls_changed: - if new_message is message: - new_message = dict(message) - new_message["tool_calls"] = new_tool_calls - changed = True - - new_messages.append(new_message) - - if changed: - request_body["messages"] = new_messages - - return api_to_internal, internal_to_api - - -def _get_runtime_config() -> Config | None: - try: - return get_config(strict=False) - except Exception: - return None - - -def _split_chat_completion_params( - body: dict[str, Any], -) -> tuple[dict[str, Any], dict[str, Any]]: - known: dict[str, Any] = {} - extra: dict[str, Any] = {} - for key, value in body.items(): - if key in _CHAT_COMPLETIONS_KNOWN_FIELDS: - known[key] = value - else: - extra[key] = value - return known, extra - - -def _split_responses_params( - body: dict[str, Any], -) -> tuple[dict[str, Any], dict[str, Any]]: - known: dict[str, Any] = {} - extra: dict[str, Any] = {} - for key, value in body.items(): - if key in _RESPONSES_KNOWN_FIELDS: - known[key] = value - else: - extra[key] = value - return known, extra - - -def _without_stream_request_fields(body: dict[str, Any]) -> dict[str, Any]: - stripped = dict(body) - stripped.pop("stream", None) - stripped.pop("stream_options", None) - return stripped - - -def _ensure_chat_stream_usage_options(body: dict[str, Any]) -> None: - stream_options = body.get("stream_options") - if stream_options is None: - body["stream_options"] = {"include_usage": True} - return - if isinstance(stream_options, dict) and "include_usage" not in stream_options: - body["stream_options"] = {**stream_options, "include_usage": True} - - -_STREAM_FALLBACK_STATUS_CODES = {400, 404, 405, 422, 501} -_STREAM_FALLBACK_ERROR_MARKERS = ( - "stream", - "stream_options", - "streaming", - "not support", - "unsupported", - "unrecognized", - "unknown parameter", - "unexpected parameter", -) - - -def _status_error_text(exc: APIStatusError) -> str: - parts = [str(exc)] - body = getattr(exc, "body", None) - if isinstance(body, dict): - parts.append(json.dumps(body, ensure_ascii=False, default=str)) - elif body is not None: - parts.append(str(body)) - response = getattr(exc, "response", None) - if response is not None: - try: - parts.append(response.text) - except Exception: - pass - return "\n".join(part for part in parts if part).lower() - - -def _should_fallback_from_stream(exc: Exception) -> bool: - if isinstance(exc, NotImplementedError): - return True - if not isinstance(exc, APIStatusError): - return False - if exc.status_code not in _STREAM_FALLBACK_STATUS_CODES: - return False - text = _status_error_text(exc) - return any(marker in text for marker in _STREAM_FALLBACK_ERROR_MARKERS) - - -def _stringify_stream_delta(value: Any) -> str: - if value is None: - return "" - if isinstance(value, str): - return value - if isinstance(value, list): - parts = [_stringify_stream_delta(item) for item in value] - return "".join(part for part in parts if part) - if isinstance(value, dict): - for key in ("text", "content", "delta", "value"): - if value.get(key) is not None: - return _stringify_stream_delta(value.get(key)) - return "" - return str(value) - - -def _extract_stream_response_item(event: dict[str, Any]) -> dict[str, Any] | None: - for key in ("item", "output_item", "data"): - value = event.get(key) - if isinstance(value, dict): - return value - response = event.get("response") - if isinstance(response, dict) and isinstance(response.get("output"), list): - return None - if isinstance(response, dict): - return response - return None - - -def _extract_stream_usage( - event: dict[str, Any], *, api_mode: str -) -> dict[str, Any] | None: - usage = event.get("usage") - if not isinstance(usage, dict): - response = event.get("response") - if isinstance(response, dict) and isinstance(response.get("usage"), dict): - usage = response.get("usage") - if not isinstance(usage, dict): - return None - if api_mode == API_MODE_RESPONSES: - return { - "input_tokens": int(usage.get("input_tokens", 0) or 0), - "output_tokens": int(usage.get("output_tokens", 0) or 0), - "total_tokens": int(usage.get("total_tokens", 0) or 0), - } - return { - "prompt_tokens": int(usage.get("prompt_tokens", 0) or 0), - "completion_tokens": int(usage.get("completion_tokens", 0) or 0), - "total_tokens": int(usage.get("total_tokens", 0) or 0), - } - - -def _ensure_tool_call_slot( - tool_calls: list[dict[str, Any]], index: int -) -> dict[str, Any]: - while len(tool_calls) <= index: - tool_calls.append( - { - "id": "", - "type": "function", - "function": {"name": "", "arguments": ""}, - } - ) - return tool_calls[index] - - -def _merge_tool_call_delta( - target_tool_calls: list[dict[str, Any]], tool_delta: dict[str, Any] -) -> None: - index = tool_delta.get("index") - try: - slot_index = int(index) if index is not None else len(target_tool_calls) - except (TypeError, ValueError): - slot_index = len(target_tool_calls) - tool_call = _ensure_tool_call_slot(target_tool_calls, slot_index) - call_id = str(tool_delta.get("id") or "").strip() - if call_id: - tool_call["id"] = call_id - tool_type = str(tool_delta.get("type") or "").strip() - if tool_type: - tool_call["type"] = tool_type - function_delta = tool_delta.get("function") - if not isinstance(function_delta, dict): - return - function = tool_call.setdefault("function", {"name": "", "arguments": ""}) - if not isinstance(function, dict): - function = {"name": "", "arguments": ""} - tool_call["function"] = function - function_name = str(function_delta.get("name") or "").strip() - if function_name: - function["name"] = function_name - arguments_delta = function_delta.get("arguments") - if arguments_delta is not None: - function["arguments"] = str(function.get("arguments") or "") + str( - arguments_delta - ) - - -def _is_deepseek_provider(model_config: ModelConfig) -> bool: - model_name = str(getattr(model_config, "model_name", "") or "").lower() - if model_name.startswith("deepseek"): - return True - api_url = str(getattr(model_config, "api_url", "") or "").lower() - return "deepseek" in api_url - - -def _normalize_thinking_override( - value: Any, model_config: ModelConfig -) -> dict[str, Any] | None: - if value is None: - return None - - is_deepseek = _is_deepseek_provider(model_config) - - if isinstance(value, dict): - raw_type = value.get("type") - if isinstance(raw_type, str): - type_value = raw_type.strip().lower() - if type_value in {"enabled", "disabled"}: - return {"type": type_value} if is_deepseek else dict(value) - - raw_enabled = value.get("enabled") - if isinstance(raw_enabled, bool): - type_value = "enabled" if raw_enabled else "disabled" - if is_deepseek: - return {"type": type_value} - normalized = dict(value) - normalized.pop("enabled", None) - normalized["type"] = type_value - return normalized - - return None - - if isinstance(value, bool): - return {"type": "enabled" if value else "disabled"} - - if isinstance(value, str): - type_value = value.strip().lower() - if type_value in {"enabled", "disabled"}: - return {"type": type_value} - - return None - - -def _tools_sanitize_enabled() -> bool: - # 历史配置项 tools.sanitize 已迁移为 tools.dot_delimiter。 - # 为兼容严格网关,description 的 schema 清洗默认始终开启。 - return True - - -def _tools_sanitize_verbose() -> bool: - runtime_config = _get_runtime_config() - if runtime_config is not None: - return bool(runtime_config.tools_sanitize_verbose) - return False - - -def _tools_description_max_len() -> int: - runtime_config = _get_runtime_config() - if runtime_config is None: - return _DEFAULT_TOOLS_DESCRIPTION_MAX_LEN - value = runtime_config.tools_description_max_len - return value if value > 0 else _DEFAULT_TOOLS_DESCRIPTION_MAX_LEN - - -def _tools_description_truncate_enabled() -> bool: - runtime_config = _get_runtime_config() - if runtime_config is None: - return False - return bool(runtime_config.tools_description_truncate_enabled) - - -def _clean_control_chars(text: str) -> str: - """将 ASCII 控制字符替换为空格。""" - return "".join(" " if ord(ch) < 32 or ord(ch) == 127 else ch for ch in text) - - -def _desc_preview(text: str) -> str: - runtime_config = _get_runtime_config() - if runtime_config is None: - preview_len = _DEFAULT_TOOLS_DESCRIPTION_PREVIEW_LEN - else: - preview_len = runtime_config.tools_description_preview_len - if preview_len <= 0: - preview_len = _DEFAULT_TOOLS_DESCRIPTION_PREVIEW_LEN - return text[:preview_len] + ("…" if len(text) > preview_len else "") - - -def _normalize_tool_description( - description: Any, - tool_name: str, - max_len: int, - truncate_enabled: bool, -) -> str: - """规范化工具 function.description,适配更严格的 OpenAI 兼容服务。""" - if description is None: - normalized = "" - elif isinstance(description, str): - normalized = description - else: - normalized = str(description) - - normalized = _clean_control_chars(normalized) - normalized = " ".join(normalized.split()) - normalized = normalized.strip() - if not normalized: - normalized = f"Tool function {tool_name}" - if truncate_enabled and len(normalized) > max_len: - normalized = normalized[:max_len].rstrip() - return normalized - - -def _sanitize_openai_tools( - tools: list[dict[str, Any]], -) -> tuple[list[dict[str, Any]], int, list[dict[str, Any]]]: - """Sanitize tools schema to avoid 400s on strict providers (e.g., invalid description).""" - if not tools or not _tools_sanitize_enabled(): - return tools, 0, [] - - max_len = _tools_description_max_len() - truncate_enabled = _tools_description_truncate_enabled() - changed = 0 - changes: list[dict[str, Any]] = [] - sanitized: list[dict[str, Any]] = [] - for idx, tool in enumerate(tools): - if not isinstance(tool, dict): - sanitized.append(tool) - continue - function = tool.get("function") - if not isinstance(function, dict): - sanitized.append(tool) - continue - name = function.get("name", "") - old_desc = function.get("description") - old_desc_str = ( - "" - if old_desc is None - else (old_desc if isinstance(old_desc, str) else str(old_desc)) - ) - new_desc = _normalize_tool_description( - old_desc, - str(name), - max_len, - truncate_enabled, - ) - - if old_desc_str != new_desc: - reasons: list[str] = [] - if not isinstance(old_desc, str): - reasons.append("non_string") - if any(ord(ch) < 32 or ord(ch) == 127 for ch in old_desc_str): - reasons.append("control_chars") - if "\n" in old_desc_str or "\r" in old_desc_str or "\t" in old_desc_str: - reasons.append("whitespace") - if not old_desc_str.strip(): - reasons.append("empty") - if ( - truncate_enabled - and len(new_desc) >= max_len - and len(old_desc_str) > len(new_desc) - ): - reasons.append("truncated") - - tool = dict(tool) - function = dict(function) - function["description"] = new_desc - tool["function"] = function - changed += 1 - changes.append( - { - "index": idx, - "name": str(name), - "old_len": len(old_desc_str), - "new_len": len(new_desc), - "old_preview": _desc_preview(_clean_control_chars(old_desc_str)), - "new_preview": _desc_preview(new_desc), - "reasons": reasons, - } - ) - sanitized.append(tool) - return sanitized, changed, changes - - -def _sanitize_openai_messages_tool_arguments( - messages: list[dict[str, Any]], -) -> tuple[list[dict[str, Any]], int]: - """Sanitize messages[].tool_calls[].function.arguments to strict JSON strings. - - Some OpenAI-compatible providers reject non-JSON `function.arguments` in the - request body (even though upstream OpenAI treats it as an opaque string). - This primarily affects conversations that include historical tool_calls. - """ - if not messages: - return messages, 0 - - changed = 0 - sanitized_messages: list[dict[str, Any]] = [] - for message in messages: - if not isinstance(message, dict): - sanitized_messages.append(message) - continue - - tool_calls = message.get("tool_calls") - if not isinstance(tool_calls, list) or not tool_calls: - sanitized_messages.append(message) - continue - - tool_calls_changed = False - sanitized_tool_calls: list[Any] = [] - for tool_call in tool_calls: - if not isinstance(tool_call, dict): - sanitized_tool_calls.append(tool_call) - continue - function = tool_call.get("function") - if not isinstance(function, dict): - sanitized_tool_calls.append(tool_call) - continue - - old_args = function.get("arguments") - new_args = normalize_tool_arguments_json(old_args) - if isinstance(old_args, str) and old_args == new_args: - sanitized_tool_calls.append(tool_call) - continue - - tool_calls_changed = True - changed += 1 - new_tool_call = dict(tool_call) - new_function = dict(function) - new_function["arguments"] = new_args - new_tool_call["function"] = new_function - sanitized_tool_calls.append(new_tool_call) - - if tool_calls_changed: - new_message = dict(message) - new_message["tool_calls"] = sanitized_tool_calls - sanitized_messages.append(new_message) - else: - sanitized_messages.append(message) - - return sanitized_messages, changed - - -def _sanitize_chat_completion_messages( - messages: list[dict[str, Any]], - *, - preserve_reasoning_content: bool = False, -) -> tuple[list[dict[str, Any]], int, dict[str, int]]: - """移除 Chat Completions 非标准消息字段。 - - 本地历史里允许保留 reasoning_content 等兼容字段用于日志/回放; - 发往上游时默认剥离。``preserve_reasoning_content=True`` 时保留 - ``reasoning_content`` 供多轮 CoT 续传,仍剥离其它内部字段。 - """ - if not messages: - return messages, 0, {} - - changed = 0 - stripped_fields: dict[str, int] = {} - sanitized_messages: list[dict[str, Any]] = [] - for message in messages: - if not isinstance(message, dict): - sanitized_messages.append(message) - continue - - sanitized_message = message - removed = False - for key in _CHAT_COMPLETION_INTERNAL_MESSAGE_KEYS: - if preserve_reasoning_content and key == "reasoning_content": - continue - if key not in sanitized_message: - continue - if sanitized_message is message: - sanitized_message = dict(message) - sanitized_message.pop(key, None) - stripped_fields[key] = stripped_fields.get(key, 0) + 1 - removed = True - - if removed: - changed += 1 - sanitized_messages.append(sanitized_message) - - return sanitized_messages, changed, stripped_fields - - -def _relocate_system_to_first_user( - messages: list[dict[str, Any]], -) -> list[dict[str, Any]]: - """将 system/developer 消息合并注入首条 user 消息(chat_completions 适配)。""" - if not messages: - return messages - - system_parts: list[str] = [] - remaining: list[dict[str, Any]] = [] - for message in messages: - if not isinstance(message, dict): - remaining.append(message) - continue - role = str(message.get("role") or "").strip().lower() - if role in ("system", "developer"): - content = message.get("content") - if content is not None: - text = content if isinstance(content, str) else str(content) - if text.strip(): - system_parts.append(text.strip()) - continue - remaining.append(message) - - if not system_parts: - return messages - - merged_system = "\n\n".join(system_parts) - first_user_idx: int | None = None - for idx, message in enumerate(remaining): - if ( - isinstance(message, dict) - and str(message.get("role") or "").strip().lower() == "user" - ): - first_user_idx = idx - break - - if first_user_idx is None: - remaining.insert(0, {"role": "user", "content": merged_system}) - return remaining - - first_user = dict(remaining[first_user_idx]) - old_content = first_user.get("content") - old_text = ( - old_content - if isinstance(old_content, str) - else (str(old_content) if old_content is not None else "") - ) - if old_text.strip(): - first_user["content"] = f"{merged_system}\n\n{old_text}" - else: - first_user["content"] = merged_system - updated = list(remaining) - updated[first_user_idx] = first_user - return updated - - -def _prepare_chat_completion_messages( - model_config: ModelConfig, - messages: list[dict[str, Any]], -) -> list[dict[str, Any]]: - """按模型配置整理 Chat Completions 出站消息。""" - preserve_reasoning = bool(getattr(model_config, "reasoning_content_replay", False)) - prepared, _, _ = _sanitize_chat_completion_messages( - messages, - preserve_reasoning_content=preserve_reasoning, - ) - if bool(getattr(model_config, "system_prompt_as_user", False)): - prepared = _relocate_system_to_first_user(prepared) - return prepared - - -def _stringify_thinking_list(value: list[Any]) -> str: - """将列表类型的思维链转换为字符串。 - - Args: - value: 思维链列表 - - Returns: - 格式化后的字符串 - """ - parts = [_stringify_thinking(item) for item in value] - return "\n".join([part for part in parts if part]) - - -def _stringify_thinking_dict(value: dict[str, Any]) -> str: - """将字典类型的思维链转换为字符串。 - - Args: - value: 思维链字典 - - Returns: - 格式化后的字符串 - """ - content = value.get("content") - if isinstance(content, str) and content: - return content - return str(value) - - -def _stringify_thinking(value: Any) -> str: - """将思维链值转换为字符串。 - - Args: - value: 思维链值(可以是 None、字符串、列表或字典) - - Returns: - 格式化后的字符串 - """ - if value is None: - return "" - if isinstance(value, str): - return value - if isinstance(value, list): - return _stringify_thinking_list(value) - if isinstance(value, dict): - return _stringify_thinking_dict(value) - return str(value) - - -def _extract_from_message(message: dict[str, Any]) -> str: - """从 message 对象中提取思维链内容。 - - Args: - message: message 对象 - - Returns: - 思维链内容字符串 - """ - if not isinstance(message, dict): - return "" - for key in _THINKING_KEYS: - if key in message: - return _stringify_thinking(message.get(key)) - return "" - - -def _extract_from_choice(choice: dict[str, Any]) -> str: - """从 choice 对象中提取思维链内容。 - - Args: - choice: choice 对象 - - Returns: - 思维链内容字符串 - """ - if not isinstance(choice, dict): - return "" - - # 优先从 message 中提取 - message = choice.get("message") - if isinstance(message, dict): - thinking = _extract_from_message(message) - if thinking: - return thinking - - # 尝试从 choice 直接提取 - for key in _THINKING_KEYS: - if key in choice: - return _stringify_thinking(choice.get(key)) - - return "" - - -def _extract_from_choices(choices: list[Any]) -> str: - """从 choices 列表中提取思维链内容。 - - Args: - choices: choices 列表 - - Returns: - 思维链内容字符串 - """ - if not isinstance(choices, list) or not choices: - return "" - choice = choices[0] - return _extract_from_choice(choice) - - -def _extract_from_result(result: dict[str, Any]) -> str: - """直接从结果对象中提取思维链内容。 - - Args: - result: API 响应结果 - - Returns: - 思维链内容字符串 - """ - for key in _THINKING_KEYS: - if key in result: - return _stringify_thinking(result.get(key)) - return "" - - -def _extract_thinking_content(result: dict[str, Any]) -> str: - """从 API 响应中提取思维链内容。 - - 提取优先级: - 1. 从 choices[0].message 中提取 - 2. 从 choices[0] 直接提取 - 3. 从响应根对象中提取 - - Args: - result: API 响应结果 - - Returns: - 思维链内容字符串 - """ - # 尝试从 choices 中提取 - choices = result.get("choices") - if isinstance(choices, list): - thinking = _extract_from_choices(choices) - if thinking: - return thinking - - return _extract_from_result(result) - - -def _normalize_openai_base_url( - api_url: str, -) -> tuple[str, dict[str, object] | None, bool]: - """将旧式 /chat/completions URL 归一化为 OpenAI SDK 需要的 base_url。 - - 兼容策略(B):如果发现 api_url 末尾包含 /chat/completions,则自动裁剪为 base_url, - 以便统一走 OpenAI SDK,并给出弃用警告。 - """ - try: - parts = urlsplit(api_url) - except Exception: - return api_url, None, False - - path = parts.path or "" - trimmed_path = path.rstrip("/") - suffix = "/chat/completions" - if not trimmed_path.endswith(suffix): - return api_url, None, False - - new_path = trimmed_path[: -len(suffix)] - default_query: dict[str, object] | None = None - if parts.query: - default_query = { - k: v for k, v in parse_qsl(parts.query, keep_blank_values=True) - } - normalized = urlunsplit(parts._replace(path=new_path, query="", fragment="")) - return normalized, default_query, True - - -def _warn_ignored_request_params( - *, - call_type: str, - model_name: str, - ignored: dict[str, Any], -) -> None: - if not ignored: - return - logger.warning( - "[request_params] ignored_keys=%s type=%s model=%s", - ",".join(sorted(ignored)), - call_type, - model_name, - ) - - -def _build_effective_request_kwargs( - model_config: ModelConfig, - *, - call_type: str, - overrides: dict[str, Any], -) -> dict[str, Any]: - merged = merge_request_params( - getattr(model_config, "request_params", {}), - overrides, - ) - thinking_override = overrides["thinking"] if "thinking" in overrides else None - has_thinking_override = "thinking" in overrides - reserved_fields = ( - _RESPONSES_RESERVED_FIELDS - if get_api_mode(model_config) == API_MODE_RESPONSES - else _CHAT_COMPLETIONS_RESERVED_FIELDS - ) - allowed, ignored = split_reserved_request_params( - merged, - reserved_fields, - ) - if has_thinking_override: - ignored.pop("thinking", None) - _warn_ignored_request_params( - call_type=call_type, - model_name=model_config.model_name, - ignored=ignored, - ) - if has_thinking_override: - allowed["thinking"] = thinking_override - return allowed - - -class ModelRequester: - """统一的模型请求封装。""" - - def __init__( - self, - http_client: httpx.AsyncClient, - token_usage_storage: TokenUsageStorage, - ) -> None: - self._http_client = http_client - self._token_usage_storage = token_usage_storage - self._openai_clients: dict[ - tuple[str, str, tuple[tuple[str, str], ...] | None], AsyncOpenAI - ] = {} - self._token_counters: dict[str, TokenCounter] = {} - self._warned_legacy_api_urls: set[str] = set() - self._background_tasks: set[asyncio.Task[Any]] = set() - self._retrieval_requester = RetrievalRequester( - get_openai_client=self._get_openai_client_for_model, - response_to_dict=self._response_to_dict, - get_token_counter=self._get_token_counter, - record_usage=self._record_usage, - ) - - async def request( - self, - model_config: ModelConfig, - messages: list[dict[str, Any]], - max_tokens: int = 8192, - call_type: str = "chat", - tools: list[dict[str, Any]] | None = None, - tool_choice: str = "auto", - transport_state: dict[str, Any] | None = None, - message_count_for_transport: int | None = None, - **kwargs: Any, - ) -> dict[str, Any]: - """发送请求到模型 API。""" - start_time = time.perf_counter() - cot_compat = getattr(model_config, "thinking_tool_call_compat", False) - reasoning_replay = bool( - getattr(model_config, "reasoning_content_replay", False) - ) - api_mode = get_api_mode(model_config) - transport_message_count = ( - message_count_for_transport - if message_count_for_transport is not None - else len(messages) - ) - messages_for_api, tool_args_fixed = _sanitize_openai_messages_tool_arguments( - messages - ) - if tool_args_fixed and logger.isEnabledFor(logging.INFO): - logger.info( - "[messages.sanitize] tool_args_fixed=%s messages=%s", - tool_args_fixed, - len(messages_for_api), - ) - if api_mode == API_MODE_CHAT_COMPLETIONS: - ( - messages_for_api, - stripped_message_count, - stripped_message_fields, - ) = _sanitize_chat_completion_messages( - messages_for_api, - preserve_reasoning_content=reasoning_replay, - ) - if bool(getattr(model_config, "system_prompt_as_user", False)): - messages_for_api = _relocate_system_to_first_user(messages_for_api) - if stripped_message_count and logger.isEnabledFor(logging.INFO): - details = ",".join( - f"{key}={value}" - for key, value in sorted(stripped_message_fields.items()) - ) - logger.info( - "[chat_completions.standardize] stripped_internal_message_fields=%s messages=%s", - details, - stripped_message_count, - ) - - tools_for_api = tools - api_to_internal: dict[str, str] = {} - internal_to_api: dict[str, str] = {} - if isinstance(tools_for_api, list): - request_for_sanitize = { - "messages": messages_for_api, - "tools": list(tools_for_api), - } - api_to_internal, internal_to_api = _sanitize_openai_tool_names_in_request( - request_for_sanitize - ) - raw_messages = request_for_sanitize.get("messages") - if isinstance(raw_messages, list): - messages_for_api = raw_messages - raw_tools = request_for_sanitize.get("tools") - if isinstance(raw_tools, list): - tools_for_api = raw_tools - - if isinstance(tools_for_api, list): - sanitized_tools, changed_count, changes = _sanitize_openai_tools( - tools_for_api - ) - tools_for_api = sanitized_tools - if changed_count and logger.isEnabledFor(logging.INFO): - logger.info( - "[tools.sanitize] changed=%s total=%s truncate_enabled=%s max_desc_len=%s", - changed_count, - len(sanitized_tools), - _tools_description_truncate_enabled(), - _tools_description_max_len(), - ) - if _tools_sanitize_verbose(): - for change in changes: - logger.info( - "[tools.sanitize.item] index=%s name=%s reasons=%s old_len=%s new_len=%s old=%s new=%s", - change.get("index"), - change.get("name"), - ",".join(change.get("reasons", [])), - change.get("old_len"), - change.get("new_len"), - change.get("old_preview"), - change.get("new_preview"), - ) - - effective_kwargs = _build_effective_request_kwargs( - model_config, - call_type=call_type, - overrides=dict(kwargs), - ) - if bool( - getattr(model_config, "prompt_cache_enabled", True) - ) and not effective_kwargs.get("prompt_cache_key"): - effective_kwargs["prompt_cache_key"] = _build_default_prompt_cache_key( - model_config, - call_type, - ) - responses_stateless_replay = bool( - getattr(model_config, "responses_force_stateless_replay", False) - ) or bool( - isinstance(transport_state, dict) - and transport_state.get("stateless_replay") - ) - effective_transport_state: dict[str, Any] | None - if responses_stateless_replay: - effective_transport_state = dict(transport_state or {}) - effective_transport_state["stateless_replay"] = True - else: - effective_transport_state = transport_state - request_body = build_request_body( - model_config=model_config, - messages=messages_for_api, - max_tokens=max_tokens, - tools=tools_for_api, - tool_choice=tool_choice, - internal_to_api=internal_to_api, - transport_state=effective_transport_state, - **effective_kwargs, - ) - - try: - if cot_compat and logger.isEnabledFor(logging.DEBUG): - logger.debug( - "[思维链兼容] enabled=%s type=%s model=%s api_mode=%s thinking_enabled=%s tools=%s messages=%s", - cot_compat, - call_type, - model_config.model_name, - api_mode, - getattr(model_config, "thinking_enabled", False), - bool(tools), - len(messages), - ) - - if logger.isEnabledFor(logging.DEBUG): - logger.debug( - "[API请求] type=%s model=%s api_mode=%s url=%s max_tokens=%s tools=%s tool_choice=%s messages=%s", - call_type, - model_config.model_name, - api_mode, - model_config.api_url, - max_tokens, - bool(tools_for_api), - tool_choice, - len(messages), - ) - log_debug_json(logger, "[API请求体]", request_body) - - try: - raw_result = await self._request_with_openai(model_config, request_body) - except APIStatusError as exc: - if ( - api_mode == API_MODE_RESPONSES - and _responses_should_fallback_to_stateless_replay( - exc, - request_body, - stateless_replay=responses_stateless_replay, - ) - ): - logger.warning( - "[responses.compat] previous_response_id 续轮失败,自动降级为 stateless replay: model=%s call_type=%s previous_response_id=%s", - model_config.model_name, - call_type, - request_body.get("previous_response_id", ""), - ) - effective_transport_state = dict(effective_transport_state or {}) - effective_transport_state["stateless_replay"] = True - responses_stateless_replay = True - request_body = build_request_body( - model_config=model_config, - messages=messages_for_api, - max_tokens=max_tokens, - tools=tools_for_api, - tool_choice=tool_choice, - internal_to_api=internal_to_api, - transport_state=effective_transport_state, - **effective_kwargs, - ) - if logger.isEnabledFor(logging.DEBUG): - log_debug_json( - logger, "[API请求体][stateless replay]", request_body - ) - raw_result = await self._request_with_openai( - model_config, request_body - ) - else: - raise - if api_mode == API_MODE_RESPONSES: - result = normalize_responses_result( - raw_result, - api_to_internal if api_to_internal else None, - ) - response_id = str( - raw_result.get("id") or result.get("id") or "" - ).strip() - if response_id: - choice = result.get("choices", [{}])[0] - message = ( - choice.get("message", {}) if isinstance(choice, dict) else {} - ) - tool_calls = ( - message.get("tool_calls", []) - if isinstance(message, dict) - else [] - ) - result["_transport_state"] = { - "api_mode": api_mode, - "previous_response_id": response_id, - "tool_result_start_index": transport_message_count - + (1 if tool_calls else 0), - } - if responses_stateless_replay: - result["_transport_state"]["stateless_replay"] = True - else: - result = self._normalize_result(raw_result) - if api_to_internal: - result["_tool_name_map"] = { - "api_to_internal": api_to_internal, - "internal_to_api": internal_to_api, - "dot_delimiter": _tool_name_dot_delimiter(), - } - duration = time.perf_counter() - start_time - - usage = result.get("usage", {}) or {} - prompt_tokens = int(usage.get("prompt_tokens", 0) or 0) - completion_tokens = int(usage.get("completion_tokens", 0) or 0) - total_tokens = int(usage.get("total_tokens", 0) or 0) - if total_tokens == 0 and (prompt_tokens or completion_tokens): - total_tokens = prompt_tokens + completion_tokens - if total_tokens == 0: - prompt_tokens, completion_tokens, total_tokens = self._estimate_usage( - model_config.model_name, messages_for_api, result - ) - - logger.info( - f"[API响应] {call_type} 完成: 耗时={duration:.2f}s, " - f"Tokens={total_tokens} (P:{prompt_tokens} + C:{completion_tokens}), " - f"模型={model_config.model_name}" - ) - - if logger.isEnabledFor(logging.DEBUG): - log_debug_json(logger, "[API响应体]", result) - - self._maybe_log_thinking(result, call_type, model_config.model_name) - - self._record_usage( - model_name=model_config.model_name, - prompt_tokens=prompt_tokens, - completion_tokens=completion_tokens, - total_tokens=total_tokens, - duration_seconds=duration, - call_type=call_type, - ) - - return result - except APIStatusError as exc: - response = exc.response - try: - body = ( - json.dumps(exc.body, ensure_ascii=False, default=str) - if exc.body is not None - else "" - ) - except Exception: - body = str(exc.body) - if ( - exc.status_code == 400 - and isinstance(exc.body, dict) - and isinstance(exc.body.get("error"), dict) - ): - param = exc.body.get("error", {}).get("param") - if isinstance(param, str): - match = _TOOLS_PARAM_INDEX_RE.search(param) - if match and isinstance(request_body.get("tools"), list): - try: - idx = int(match.group(1)) - except ValueError: - idx = -1 - if 0 <= idx < len(request_body["tools"]): - tool = request_body["tools"][idx] - tool_name = ( - tool.get("function", {}).get("name") - if isinstance(tool, dict) - else "" - ) - desc_len: int | None = None - desc_preview = "" - if isinstance(tool, dict): - function = tool.get("function", {}) - if isinstance(function, dict): - desc = function.get("description") - if desc is not None: - desc_str = ( - desc if isinstance(desc, str) else str(desc) - ) - desc_len = len(desc_str) - desc_preview = _desc_preview(desc_str) - logger.error( - "[tools.invalid] index=%s name=%s desc_len=%s desc=%s param=%s", - idx, - tool_name, - desc_len, - desc_preview, - param, - ) - logger.error( - "[API响应错误] status=%s request_id=%s url=%s body=%s", - exc.status_code, - exc.request_id or "", - response.request.url, - redact_string(body), - ) - raise - except (APIConnectionError, APITimeoutError) as exc: - logger.error("[API连接错误] type=%s message=%s", type(exc).__name__, exc) - raise - except Exception as exc: - logger.exception(f"[model.request.error] {call_type} 调用失败: {exc}") - raise - - def _thinking_logging_enabled(self) -> bool: - runtime_config = _get_runtime_config() - if runtime_config is None: - return True - return bool(runtime_config.log_thinking) - - def _maybe_log_thinking( - self, result: dict[str, Any], call_type: str, model_name: str - ) -> None: - if not self._thinking_logging_enabled(): - return - thinking = _extract_thinking_content(result) - if thinking: - logger.info( - "[思维链] type=%s model=%s content=%s", - call_type, - model_name, - redact_string(thinking), - ) - - async def _request_with_openai( - self, model_config: ModelConfig, request_body: dict[str, Any] - ) -> dict[str, Any]: - client = self._get_openai_client_for_model(model_config) - if bool(getattr(model_config, "stream_enabled", False)): - try: - return await self._request_with_openai_streaming( - client, model_config, request_body - ) - except Exception as exc: - if not _should_fallback_from_stream(exc): - raise - logger.warning( - "[API流式回退] model=%s api_mode=%s reason=%s", - getattr(model_config, "model_name", ""), - get_api_mode(model_config), - type(exc).__name__, - ) - request_body = _without_stream_request_fields(request_body) - if get_api_mode(model_config) == API_MODE_RESPONSES: - params, extra_body = _split_responses_params(request_body) - if extra_body: - params["extra_body"] = extra_body - response = await client.responses.create(**params) - return self._response_to_dict(response) - params, extra_body = _split_chat_completion_params(request_body) - if extra_body: - params["extra_body"] = extra_body - response = await client.chat.completions.create(**params) - return self._response_to_dict(response) - - async def _request_with_openai_streaming( - self, - client: AsyncOpenAI, - model_config: ModelConfig, - request_body: dict[str, Any], - ) -> dict[str, Any]: - api_mode = get_api_mode(model_config) - stream_body = dict(request_body) - stream_body["stream"] = True - if api_mode == API_MODE_RESPONSES: - return await self._stream_responses_request(client, stream_body) - _ensure_chat_stream_usage_options(stream_body) - return await self._stream_chat_completions_request( - client, stream_body, model_config - ) - - async def _stream_chat_completions_request( - self, - client: AsyncOpenAI, - request_body: dict[str, Any], - model_config: ModelConfig, - ) -> dict[str, Any]: - params, extra_body = _split_chat_completion_params(request_body) - if extra_body: - params["extra_body"] = extra_body - response = await client.chat.completions.create(**params) - - content_parts: list[str] = [] - reasoning_parts: list[str] = [] - tool_calls: list[dict[str, Any]] = [] - usage: dict[str, Any] | None = None - finish_reason = "stop" - role = "assistant" - reasoning_replay = bool( - getattr(model_config, "reasoning_content_replay", False) - ) - - async for chunk in response: - chunk_dict = self._response_to_dict(chunk) - usage = ( - _extract_stream_usage(chunk_dict, api_mode=API_MODE_CHAT_COMPLETIONS) - or usage - ) - choices = chunk_dict.get("choices") - if not isinstance(choices, list): - continue - for choice in choices: - if not isinstance(choice, dict): - continue - delta = choice.get("delta") - if not isinstance(delta, dict): - continue - role_value = str(delta.get("role") or "").strip() - if role_value: - role = role_value - content_delta = _stringify_stream_delta(delta.get("content")) - if content_delta: - content_parts.append(content_delta) - if reasoning_replay: - reasoning_delta = _stringify_thinking( - delta.get("reasoning_content") - ) - if reasoning_delta: - reasoning_parts.append(reasoning_delta) - raw_tool_calls = delta.get("tool_calls") - if isinstance(raw_tool_calls, list): - for tool_delta in raw_tool_calls: - if isinstance(tool_delta, dict): - _merge_tool_call_delta(tool_calls, tool_delta) - current_finish_reason = str(choice.get("finish_reason") or "").strip() - if current_finish_reason: - finish_reason = current_finish_reason - - message: dict[str, Any] = { - "role": role, - "content": "".join(content_parts), - } - if reasoning_replay: - reasoning_text = "".join(reasoning_parts).strip() - if reasoning_text: - message["reasoning_content"] = reasoning_text - if tool_calls: - message["tool_calls"] = tool_calls - result: dict[str, Any] = { - "choices": [ - { - "index": 0, - "message": message, - "finish_reason": finish_reason, - } - ] - } - if usage is not None: - result["usage"] = usage - return result - - async def _stream_responses_request( - self, client: AsyncOpenAI, request_body: dict[str, Any] - ) -> dict[str, Any]: - params, extra_body = _split_responses_params(request_body) - if extra_body: - params["extra_body"] = extra_body - stream = await client.responses.create(**params) - - output_items: list[dict[str, Any]] = [] - output_text_parts: list[str] = [] - usage: dict[str, Any] | None = None - final_response: dict[str, Any] | None = None - - async for event in stream: - event_dict = self._response_to_dict(event) - usage = ( - _extract_stream_usage(event_dict, api_mode=API_MODE_RESPONSES) or usage - ) - event_type = str(event_dict.get("type") or "").strip().lower() - response = event_dict.get("response") - if isinstance(response, dict): - final_response = response - if event_type == "response.output_text.delta": - delta = _stringify_stream_delta(event_dict.get("delta")) - if delta: - output_text_parts.append(delta) - continue - if event_type == "response.completed": - if isinstance(response, dict): - final_response = response - continue - item = _extract_stream_response_item(event_dict) - if not isinstance(item, dict): - continue - item_type = str(item.get("type") or "").strip().lower() - if item_type == "message": - output_items.append(item) - continue - if item_type == "function_call": - output_items.append(item) - continue - if item_type == "reasoning": - output_items.append(item) - - if final_response is not None: - if usage is not None and not isinstance(final_response.get("usage"), dict): - final_response = dict(final_response) - final_response["usage"] = usage - return final_response - - synthesized: dict[str, Any] = { - "output": output_items, - "output_text": "".join(output_text_parts), - } - if usage is not None: - synthesized["usage"] = usage - return synthesized - - async def embed( - self, - model_config: EmbeddingModelConfig, - texts: list[str], - ) -> list[list[float]]: - """调用统一检索请求层的 embeddings。""" - return await self._retrieval_requester.embed(model_config, texts) - - async def rerank( - self, - model_config: RerankModelConfig, - query: str, - documents: list[str], - top_n: int | None = None, - ) -> list[dict[str, Any]]: - """调用统一检索请求层的 rerank。""" - return await self._retrieval_requester.rerank( - model_config=model_config, - query=query, - documents=documents, - top_n=top_n, - ) - - def _get_openai_client_for_model(self, model_config: ModelConfig) -> AsyncOpenAI: - base_url, default_query, changed = _normalize_openai_base_url( - model_config.api_url - ) - if changed and model_config.api_url not in self._warned_legacy_api_urls: - self._warned_legacy_api_urls.add(model_config.api_url) - logger.warning( - "[配置弃用] 检测到 *_MODEL_API_URL 末尾包含 /chat/completions,这种写法已弃用;" - "已自动裁剪为 base_url=%s(原值=%s)。", - base_url, - model_config.api_url, - ) - return self._get_openai_client( - base_url=base_url, - api_key=model_config.api_key, - default_query=default_query, - ) - - def _record_usage( - self, - *, - model_name: str, - prompt_tokens: int, - completion_tokens: int, - total_tokens: int, - duration_seconds: float, - call_type: str, - ) -> None: - task = asyncio.create_task( - self._token_usage_storage.record( - TokenUsage( - timestamp=datetime.now().isoformat(), - model_name=model_name, - prompt_tokens=prompt_tokens, - completion_tokens=completion_tokens, - total_tokens=total_tokens, - duration_seconds=duration_seconds, - call_type=call_type, - success=True, - ) - ) - ) - self._background_tasks.add(task) - task.add_done_callback(self._background_tasks.discard) - - def _get_openai_client( - self, base_url: str, api_key: str, default_query: dict[str, object] | None - ) -> AsyncOpenAI: - query_key = None - if default_query: - query_key = tuple( - sorted((str(k), str(v)) for k, v in default_query.items()) - ) - cache_key = (base_url, api_key, query_key) - client = self._openai_clients.get(cache_key) - if client is not None: - return client - # 复用上层注入的 httpx client(连接池/超时等),避免每个 OpenAI client 自建连接池。 - client = AsyncOpenAI( - api_key=api_key, - base_url=base_url, - timeout=480.0, - default_query=default_query, - http_client=self._http_client, - ) - self._openai_clients[cache_key] = client - return client - - def _response_to_dict(self, response: Any) -> dict[str, Any]: - if isinstance(response, dict): - return response - for attr in ("model_dump", "to_dict", "dict"): - method = getattr(response, attr, None) - if callable(method): - try: - value = method() - if isinstance(value, dict): - return value - except Exception: - continue - to_json = getattr(response, "to_json", None) - if callable(to_json): - try: - raw_json = to_json() - loaded = json.loads(str(raw_json)) - if isinstance(loaded, dict): - return loaded - except Exception: - pass - return {"data": str(response)} - - def _normalize_result(self, result: dict[str, Any]) -> dict[str, Any]: - choices = result.get("choices") - if isinstance(choices, list): - return result - data = result.get("data") - if isinstance(data, dict): - data_choices = data.get("choices") - if isinstance(data_choices, list): - normalized = dict(result) - normalized["choices"] = data_choices - return normalized - normalized = dict(result) - normalized["choices"] = [{}] - return normalized - - def _get_token_counter(self, model_name: str) -> TokenCounter: - counter = self._token_counters.get(model_name) - if counter is None: - counter = TokenCounter(model_name) - self._token_counters[model_name] = counter - return counter - - def _estimate_usage( - self, - model_name: str, - messages: list[dict[str, Any]], - result: dict[str, Any], - ) -> tuple[int, int, int]: - counter = self._get_token_counter(model_name) - try: - prompt_text = "\n".join( - json.dumps(message, ensure_ascii=False, default=str) - for message in messages - ) - except Exception: - prompt_text = str(messages) - prompt_tokens = counter.count(prompt_text) - - completion_text = "" - try: - completion_text = extract_choices_content(result) - except Exception: - completion_text = "" - if not completion_text: - choices = result.get("choices") - if isinstance(choices, list) and choices: - choice = choices[0] - if isinstance(choice, dict): - message = choice.get("message", {}) - tool_calls = ( - message.get("tool_calls") - if isinstance(message, dict) - else choice.get("tool_calls") - ) - if tool_calls: - try: - completion_text = json.dumps( - tool_calls, ensure_ascii=False, default=str - ) - except Exception: - completion_text = str(tool_calls) - completion_tokens = counter.count(completion_text) if completion_text else 0 - total_tokens = prompt_tokens + completion_tokens - logger.debug( - "[API响应] usage 缺失,估算 tokens: prompt=%s completion=%s total=%s", - prompt_tokens, - completion_tokens, - total_tokens, - ) - return prompt_tokens, completion_tokens, total_tokens - - -def build_request_body( - model_config: ModelConfig, - messages: list[dict[str, Any]], - max_tokens: int, - tools: list[dict[str, Any]] | None = None, - tool_choice: str = "auto", - internal_to_api: dict[str, str] | None = None, - transport_state: dict[str, Any] | None = None, - **kwargs: Any, -) -> dict[str, Any]: - """构建 API 请求体。""" - api_mode = get_api_mode(model_config) - extra_kwargs: dict[str, Any] = dict(kwargs) - - if "thinking" in extra_kwargs: - normalized = _normalize_thinking_override( - extra_kwargs.get("thinking"), model_config - ) - if normalized is None: - extra_kwargs.pop("thinking", None) - else: - extra_kwargs["thinking"] = normalized - - if api_mode == API_MODE_RESPONSES: - extra_kwargs.pop("reasoning", None) - extra_kwargs.pop("reasoning_effort", None) - extra_kwargs.pop("output_config", None) - return build_responses_request_body( - model_config, - messages, - max_tokens, - tools=tools, - tool_choice=tool_choice, - extra_kwargs=extra_kwargs, - internal_to_api=internal_to_api or {}, - transport_state=transport_state, - ) - - body: dict[str, Any] = { - "model": model_config.model_name, - "messages": _prepare_chat_completion_messages(model_config, messages), - "max_tokens": max_tokens, - } - - extra_kwargs.pop("reasoning", None) - extra_kwargs.pop("reasoning_effort", None) - extra_kwargs.pop("output_config", None) - - thinking = get_thinking_payload(model_config) - if thinking is not None: - body["thinking"] = thinking - - effort_payload = get_effort_payload(model_config) - if effort_payload is not None: - style = get_effort_style(model_config) - if style == "anthropic": - body["output_config"] = effort_payload - else: - body["reasoning_effort"] = effort_payload["effort"] - - if tools: - body["tools"] = tools - thinking_active = "thinking" in body - if thinking_active and isinstance(tool_choice, dict): - body["tool_choice"] = "auto" - else: - body["tool_choice"] = tool_choice - - body.update(extra_kwargs) - return body diff --git a/src/Undefined/ai/llm/__init__.py b/src/Undefined/ai/llm/__init__.py index a8191f6a..2d0fd29c 100644 --- a/src/Undefined/ai/llm/__init__.py +++ b/src/Undefined/ai/llm/__init__.py @@ -1,7 +1,6 @@ """LLM 模型请求子包。 -对外稳定入口:``ModelRequester``、``build_request_body``、``ModelConfig``; -旧路径 ``Undefined.ai.llm`` 通过包根与 ``llm.py`` shim 保持兼容。 +对外稳定入口:``ModelRequester``、``build_request_body``、``ModelConfig``。 """ from Undefined.ai.llm.requester import ModelRequester, build_request_body diff --git a/src/Undefined/ai/multimodal.py b/src/Undefined/ai/multimodal.py deleted file mode 100644 index 8e5bcc90..00000000 --- a/src/Undefined/ai/multimodal.py +++ /dev/null @@ -1,893 +0,0 @@ -"""多模态分析辅助函数。""" - -from __future__ import annotations - -import asyncio -import base64 -import hashlib -import json -import logging -from pathlib import Path -import time -from typing import Any, cast -from urllib.parse import urlsplit - -import aiofiles -import httpx - -from Undefined.ai.parsing import extract_choices_content -from Undefined.utils.coerce import safe_float -from Undefined.ai.llm import ModelRequester -from Undefined.config import VisionModelConfig -from Undefined.ai.transports import API_MODE_CHAT_COMPLETIONS, get_api_mode -from Undefined.utils.tool_calls import extract_required_tool_call_arguments -from Undefined.utils.logging import log_debug_json, redact_string -from Undefined.utils.resources import read_text_resource - -logger = logging.getLogger(__name__) - -# 每个文件名最多保留的历史 Q&A 条数 -_MAX_QA_HISTORY = 5 - -_HISTORY_FILE_PATH = Path("data/media_qa_history.json") - -# 远程媒体缓存目录(用于先下载 URL 再转 data URL) -_MEDIA_URL_CACHE_DIR = Path("data/cache/multimodal_media") - -# 远程媒体缓存清理策略:仅保留最近 6 小时 + 最多 256 个文件。 -# Remote media cache cleanup policy: keep only recent 6h + max 256 files. -_MEDIA_URL_CACHE_TTL_SECONDS = 6 * 60 * 60 -_MEDIA_URL_CACHE_MAX_FILES = 256 - -# 两次自动清理之间的最小间隔(秒),避免每次请求都全量扫描目录。 -_MEDIA_URL_CACHE_CLEANUP_INTERVAL_SECONDS = 60.0 - -# 下载 URL 到本地缓存时的网络超时(秒)。 -_MEDIA_URL_DOWNLOAD_TIMEOUT_SECONDS = 120.0 - -# 下载阶段临时文件后缀(追加在缓存文件名后),用于区分真实缓存文件。 -_MEDIA_URL_DOWNLOAD_TMP_SUFFIX = ".downloading" - -_IMAGE_EXTENSIONS = (".jpg", ".jpeg", ".png", ".gif", ".webp", ".bmp", ".svg") -_AUDIO_EXTENSIONS = (".mp3", ".wav", ".m4a", ".ogg", ".flac", ".aac", ".wma") -_VIDEO_EXTENSIONS = (".mp4", ".avi", ".mov", ".webm", ".mkv", ".flv", ".wmv") - -# MIME 类型前缀到媒体类型的映射 -_MIME_PREFIX_TO_TYPE = { - "image/": "image", - "audio/": "audio", - "video/": "video", -} - - -def _extract_mime_type_from_data_url(media_url: str) -> str | None: - """从 data URL 中提取 MIME 类型。 - - Args: - media_url: 媒体 URL - - Returns: - MIME 类型前缀(如 "image/")或 None - """ - if not media_url.startswith("data:"): - return None - mime_part = media_url.split(";")[0] - if ":" in mime_part: - return mime_part.split(":")[1] - return None - - -def _get_media_type_by_extension(url_lower: str) -> str: - """根据文件扩展名判断媒体类型。 - - Args: - url_lower: 转换为小写的 URL - - Returns: - 媒体类型("image"、"audio" 或 "video") - """ - for ext in _IMAGE_EXTENSIONS: - if ext in url_lower: - return "image" - for ext in _AUDIO_EXTENSIONS: - if ext in url_lower: - return "audio" - for ext in _VIDEO_EXTENSIONS: - if ext in url_lower: - return "video" - return "image" # 默认返回图片类型 - - -def detect_media_type(media_url: str, specified_type: str = "auto") -> str: - """检测媒体文件的类型(图片、音频或视频)。""" - if specified_type and specified_type != "auto": - return specified_type - - media_type = _detect_from_data_url(media_url) - if media_type: - return media_type - - return _detect_by_mimetypes(media_url) - - -def _detect_from_data_url(media_url: str) -> str | None: - """从 data URL 的 MIME 类型中探测媒体类型""" - mime = _extract_mime_type_from_data_url(media_url) - if mime: - for prefix, media_type in _MIME_PREFIX_TO_TYPE.items(): - if mime.startswith(prefix): - return media_type - return None - - -def _detect_by_mimetypes(media_url: str) -> str: - """利用 mimetypes 库或扩展名探测媒体类型""" - import mimetypes - - guessed_mime, _ = mimetypes.guess_type(media_url) - if guessed_mime: - for prefix, media_type in _MIME_PREFIX_TO_TYPE.items(): - if guessed_mime.startswith(prefix): - return media_type - - return _get_media_type_by_extension(media_url.lower()) - - -# 默认 MIME 类型映射 -_DEFAULT_MIME_TYPES = { - "image": "image/jpeg", - "audio": "audio/mpeg", - "video": "video/mp4", -} - - -def get_media_mime_type(media_type: str, file_path: str = "") -> str: - """获取媒体文件的 MIME 类型。 - - Args: - media_type: 媒体类型("image"、"audio" 或 "video") - file_path: 文件路径(可选),用于根据文件扩展名推断 MIME 类型 - - Returns: - MIME 类型字符串 - """ - # 如果提供了文件路径,优先使用 mimetypes 推断 - if file_path: - import mimetypes - - mime_type, _ = mimetypes.guess_type(file_path) - if mime_type: - return mime_type - - # 返回默认 MIME 类型 - return _DEFAULT_MIME_TYPES.get(media_type, "application/octet-stream") - - -_MEDIA_TYPE_TO_FIELD = { - "image": "ocr_text", - "audio": "transcript", - "video": "subtitles", -} - -_MEME_JUDGE_PROMPT_PATH = "res/prompts/judge_meme_image.txt" -_MEME_DESCRIBE_PROMPT_PATH = "res/prompts/describe_meme_image.txt" - -_MEME_JUDGE_TOOL = { - "type": "function", - "function": { - "name": "submit_meme_judgement", - "description": "提交表情包判定结果", - "parameters": { - "type": "object", - "properties": { - "is_meme": { - "type": "boolean", - "description": "该图片是否适合进入表情包库", - }, - "confidence": { - "type": "number", - "description": "0 到 1 的置信度", - }, - "reason": { - "type": "string", - "description": "简短中文判定原因", - }, - }, - "required": ["is_meme", "confidence", "reason"], - }, - }, -} - -_MEME_DESCRIBE_TOOL = { - "type": "function", - "function": { - "name": "submit_meme_description", - "description": "提交表情包描述与标签", - "parameters": { - "type": "object", - "properties": { - "description": { - "type": "string", - "description": "适合检索的简短中文描述", - }, - "tags": { - "type": "array", - "items": {"type": "string"}, - "description": "0 到 6 个短标签", - }, - }, - "required": ["description", "tags"], - }, - }, -} - - -_ERROR_MESSAGES = { - "read": { - "image": "[图片无法读取]", - "audio": "[音频无法读取]", - "video": "[视频无法读取]", - "default": "[媒体文件无法读取]", - }, - "analyze": { - "image": "[图片分析失败]", - "audio": "[音频分析失败]", - "video": "[视频分析失败]", - "default": "[媒体分析失败]", - }, -} - - -def _parse_line_value(line: str, prefix: str) -> str: - """解析行内容,提取指定前缀后的值。 - - Args: - line: 待解析的行 - prefix: 前缀(支持中文冒号和英文冒号) - - Returns: - 提取的值,如果值为 "无" 则返回空字符串 - """ - value = line.split(":", 1)[-1].split(":", 1)[-1].strip() - return "" if value == "无" else value - - -def _parse_analysis_response(content: str) -> dict[str, str]: - """解析 AI 分析响应的内容。 - - Args: - content: AI 返回的文本内容 - - Returns: - 包含描述和类型特定字段的字典 - """ - # 字段前缀映射(支持中文冒号和英文冒号) - field_prefixes = { - "description": ("描述:", "描述:"), - "ocr_text": ("OCR:", "OCR:"), - "transcript": ("转写:", "转写:"), - "subtitles": ("字幕:", "字幕:"), - } - - result = { - "description": "", - "ocr_text": "", - "transcript": "", - "subtitles": "", - } - - for line in content.split("\n"): - line = line.strip() - for field, prefixes in field_prefixes.items(): - if line.startswith(prefixes): - result[field] = _parse_line_value(line, prefixes[0]) - - # 如果没有解析到描述,使用完整内容作为描述 - if not result["description"]: - result["description"] = content - - return result - - -def _extract_json_object(content: str) -> dict[str, Any]: - text = str(content or "").strip() - if not text: - return {} - candidates = [text] - if "```" in text: - parts = text.split("```") - for part in parts: - stripped = part.strip() - if not stripped: - continue - if stripped.lower().startswith("json"): - stripped = stripped[4:].strip() - candidates.append(stripped) - for candidate in candidates: - try: - parsed = json.loads(candidate) - except json.JSONDecodeError: - continue - if isinstance(parsed, dict): - return parsed - start = text.find("{") - end = text.rfind("}") - if start >= 0 and end > start: - try: - parsed = json.loads(text[start : end + 1]) - except json.JSONDecodeError: - return {} - if isinstance(parsed, dict): - return parsed - return {} - - -def _normalize_meme_tags(tags_raw: Any) -> list[str]: - tags: list[str] = [] - if isinstance(tags_raw, list): - seen: set[str] = set() - for item in tags_raw: - text = str(item or "").strip() - lowered = text.lower() - if not text or lowered in seen: - continue - seen.add(lowered) - tags.append(text) - return tags - - -def _parse_meme_analysis_response(content: str) -> dict[str, Any]: - parsed = _extract_json_object(content) - return { - "is_meme": bool(parsed.get("is_meme", False)), - "confidence": safe_float(parsed.get("confidence", 0.0), default=0.0), - "description": str(parsed.get("description") or "").strip(), - "tags": _normalize_meme_tags(parsed.get("tags")), - } - - -class MultimodalAnalyzer: - """多模态媒体分析器。 - - 支持分析图片、音频和视频文件,提取描述内容和类型特定信息(如 OCR 文字、转写文字、字幕等)。 - """ - - def __init__( - self, - requester: ModelRequester, - vision_config: VisionModelConfig, - prompt_path: str = "res/prompts/analyze_multimodal.txt", - ) -> None: - """初始化多模态分析器。 - - Args: - requester: 模型请求器 - vision_config: 视觉模型配置 - prompt_path: 提示词模板文件路径 - """ - self._requester = requester - self._vision_config = vision_config - self._prompt_path = prompt_path - self._cache: dict[str, dict[str, str]] = {} - # 按文件名索引的 Q&A 历史:{filename: [{q: ..., a: ...}, ...]} - self._file_history: dict[str, list[dict[str, str]]] = {} - - # URL 下载锁:按 URL 哈希粒度加锁,避免并发下载同一文件造成竞态。 - # URL download lock: keyed by URL hash to avoid duplicate concurrent downloads. - self._url_cache_locks: dict[str, asyncio.Lock] = {} - self._url_cache_locks_guard = asyncio.Lock() - - # 缓存清理锁 + 上次清理时间,避免并发清理相互干扰。 - # Cache cleanup lock + last cleanup timestamp to avoid concurrent cleanup races. - self._url_cache_cleanup_lock = asyncio.Lock() - self._last_url_cache_cleanup_at = 0.0 - - self._load_history() - - async def _load_media_content(self, media_url: str, media_type: str) -> str: - """加载媒体内容。 - - 如果是本地文件,会将其转换为 base64 编码的 data URL。 - - Args: - media_url: 媒体 URL 或本地文件路径 - media_type: 媒体类型 - - Returns: - 可用于 API 请求的媒体内容字符串 - """ - if media_url.startswith("data:"): - return media_url - - if media_url.startswith("http://") or media_url.startswith("https://"): - return await self._load_remote_media_as_data_url(media_url, media_type) - - # 读取本地文件并转换为 base64 - async with aiofiles.open(media_url, "rb") as f: - media_bytes = bytes(await f.read()) - media_data = base64.b64encode(media_bytes).decode() - mime_type = get_media_mime_type(media_type, media_url) - return f"data:{mime_type};base64,{media_data}" - - async def _load_remote_media_as_data_url( - self, media_url: str, media_type: str - ) -> str: - """将远程 URL 下载到缓存并转换为 data URL。""" - cache_key = self._build_url_cache_key(media_url) - lock = await self._get_url_cache_lock(cache_key) - cache_path = self._build_url_cache_path(cache_key, media_url) - - async with lock: - await self._cleanup_url_cache_if_needed() - if not cache_path.exists(): - await self._download_url_to_cache(media_url, cache_path) - async with aiofiles.open(cache_path, "rb") as f: - media_bytes = bytes(await f.read()) - media_data = base64.b64encode(media_bytes).decode() - - mime_type = get_media_mime_type(media_type, media_url) - return f"data:{mime_type};base64,{media_data}" - - def _build_url_cache_key(self, media_url: str) -> str: - """构建 URL 缓存键(使用 URL 内容哈希)。""" - return hashlib.sha256(media_url.encode("utf-8")).hexdigest() - - def _build_url_cache_path(self, cache_key: str, media_url: str) -> Path: - """基于 URL 生成缓存文件路径。""" - suffix = Path(urlsplit(media_url).path).suffix.lower() - if not suffix or len(suffix) > 10: - suffix = ".bin" - return _MEDIA_URL_CACHE_DIR / f"{cache_key}{suffix}" - - async def _get_url_cache_lock(self, cache_key: str) -> asyncio.Lock: - """获取 URL 对应的下载锁(同 URL 串行化)。""" - async with self._url_cache_locks_guard: - lock = self._url_cache_locks.get(cache_key) - if lock is None: - lock = asyncio.Lock() - self._url_cache_locks[cache_key] = lock - return lock - - async def _download_url_to_cache(self, media_url: str, cache_path: Path) -> None: - """下载远程 URL 到缓存文件(原子写入,避免部分文件)。""" - cache_path.parent.mkdir(parents=True, exist_ok=True) - tmp_path = cache_path.with_name( - f"{cache_path.name}{_MEDIA_URL_DOWNLOAD_TMP_SUFFIX}" - ) - try: - timeout = httpx.Timeout(_MEDIA_URL_DOWNLOAD_TIMEOUT_SECONDS) - async with httpx.AsyncClient( - timeout=timeout, follow_redirects=True - ) as client: - response = await client.get(media_url) - response.raise_for_status() - async with aiofiles.open(tmp_path, "wb") as f: - await f.write(response.content) - tmp_path.replace(cache_path) - except Exception: - try: - tmp_path.unlink(missing_ok=True) - except Exception: - pass - raise - - @staticmethod - def _extract_cache_key_from_tmp(path: Path) -> str: - """从临时文件名提取 cache_key({key}.{ext}. -> key)。 - - Extract cache_key from tmp filename ({key}.{ext}. -> key). - """ - return Path(path.stem).stem - - @staticmethod - def _is_download_tmp_path(path: Path) -> bool: - """判断是否为下载过程临时文件({key}.{ext}.)。 - - Identify download tmp files by requiring a dedicated trailing suffix and - at least one original extension segment before it. - """ - suffixes = path.suffixes - return len(suffixes) >= 2 and suffixes[-1] == _MEDIA_URL_DOWNLOAD_TMP_SUFFIX - - async def _cleanup_url_cache_if_needed(self) -> None: - """按 TTL + 文件数上限清理 URL 媒体缓存。""" - now = time.time() - if ( - now - self._last_url_cache_cleanup_at - < _MEDIA_URL_CACHE_CLEANUP_INTERVAL_SECONDS - ): - return - - async with self._url_cache_cleanup_lock: - # 双重检查,避免并发情况下重复清理。 - # Double-check to avoid repeated cleanup under concurrency. - now = time.time() - if ( - now - self._last_url_cache_cleanup_at - < _MEDIA_URL_CACHE_CLEANUP_INTERVAL_SECONDS - ): - return - self._last_url_cache_cleanup_at = now - - async with self._url_cache_locks_guard: - active_keys = { - key for key, lock in self._url_cache_locks.items() if lock.locked() - } - cache_dir = _MEDIA_URL_CACHE_DIR - if not cache_dir.exists(): - await self._prune_url_cache_locks( - active_keys=active_keys, - present_keys=set(), - ) - return - - files: list[Path] = [p for p in cache_dir.iterdir() if p.is_file()] - expire_before = now - _MEDIA_URL_CACHE_TTL_SECONDS - kept_files: list[Path] = [] - present_keys: set[str] = set() - - # 先按 TTL 清理,跳过正在下载/读取的活跃键。 - # First, TTL cleanup; skip active keys still being downloaded/read. - for path in files: - if self._is_download_tmp_path(path): - tmp_key = self._extract_cache_key_from_tmp(path) - if tmp_key and tmp_key not in active_keys: - path.unlink(missing_ok=True) - continue - present_keys.add(path.stem) - try: - mtime = path.stat().st_mtime - except OSError: - continue - if mtime < expire_before and path.stem not in active_keys: - path.unlink(missing_ok=True) - else: - kept_files.append(path) - - await self._prune_url_cache_locks( - active_keys=active_keys, - present_keys=present_keys, - ) - - # 再按数量上限清理最旧文件,同样跳过活跃键。 - # Then enforce max-file limit by deleting oldest files, skipping active keys. - if len(kept_files) <= _MEDIA_URL_CACHE_MAX_FILES: - return - - kept_with_mtime: list[tuple[float, Path]] = [] - for path in kept_files: - try: - kept_with_mtime.append((path.stat().st_mtime, path)) - except OSError: - continue - kept_with_mtime.sort(key=lambda item: item[0], reverse=True) - for _, path in kept_with_mtime[_MEDIA_URL_CACHE_MAX_FILES:]: - if path.stem in active_keys: - continue - path.unlink(missing_ok=True) - - async def _prune_url_cache_locks( - self, - *, - active_keys: set[str], - present_keys: set[str], - ) -> None: - """回收不再活跃且已无缓存文件的 URL 锁,避免字典无限增长。 - - Prune stale URL locks with no active task/file to avoid unbounded growth. - """ - async with self._url_cache_locks_guard: - stale_keys = [ - key - for key, lock in self._url_cache_locks.items() - if key not in active_keys - and key not in present_keys - and not lock.locked() - ] - for key in stale_keys: - self._url_cache_locks.pop(key, None) - - async def _build_content_items( - self, media_type: str, media_content: str | list[str], prompt: str - ) -> list[dict[str, Any]]: - """构建请求内容项。 - - Args: - media_type: 媒体类型 - media_content: 媒体内容(URL/data URL),或其列表 - prompt: 提示词 - - Returns: - 包含文本和媒体的内容项列表 - """ - content_items: list[dict[str, Any]] = [{"type": "text", "text": prompt}] - - media_item_key = f"{media_type}_url" - contents = media_content if isinstance(media_content, list) else [media_content] - for mc in contents: - content_items.append({"type": media_item_key, media_item_key: {"url": mc}}) - - return content_items - - async def analyze( - self, - media_url: str, - media_type: str = "auto", - prompt_extra: str = "", - ) -> dict[str, str]: - """分析媒体文件。 - - 始终调用视觉模型进行真实分析,不会因历史缓存而跳过。 - - Args: - media_url: 媒体文件 URL 或本地路径 - media_type: 媒体类型,"auto" 表示自动检测 - prompt_extra: 补充提示词 - - Returns: - 包含描述和类型特定信息的字典 - """ - detected_type = detect_media_type(media_url, media_type) - safe_url = redact_string(media_url) - logger.info(f"[媒体分析] 开始分析 {detected_type}: {safe_url[:50]}...") - logger.debug( - "[媒体分析] media_type=%s detected=%s url_len=%s prompt_extra_len=%s", - media_type, - detected_type, - len(media_url), - len(prompt_extra), - ) - - cache_key = f"{detected_type}:{media_url[:100]}:{prompt_extra}" - if cache_key in self._cache: - logger.debug("[媒体分析] 命中缓存: key=%s", cache_key[:120]) - return self._cache[cache_key] - - try: - media_content = await self._load_media_content(media_url, detected_type) - except Exception as exc: - logger.error(f"无法读取媒体文件: {exc}") - return { - "description": _ERROR_MESSAGES["read"].get( - detected_type, _ERROR_MESSAGES["read"]["default"] - ) - } - - try: - prompt = read_text_resource(self._prompt_path) - except Exception: - async with aiofiles.open(self._prompt_path, "r", encoding="utf-8") as f: - prompt = await f.read() - - logger.debug( - "[媒体分析] prompt_len=%s path=%s", - len(prompt), - self._prompt_path, - ) - - if prompt_extra: - prompt += f"\n\n【补充指令】\n{prompt_extra}" - - content_items = await self._build_content_items( - detected_type, media_content, prompt - ) - - try: - result = await self._requester.request( - model_config=self._vision_config, - messages=[{"role": "user", "content": content_items}], - max_tokens=self._vision_config.max_tokens, - call_type=f"vision_{detected_type}", - ) - content = extract_choices_content(result) - if logger.isEnabledFor(logging.DEBUG): - log_debug_json(logger, "[媒体分析] 原始响应内容", content) - - parsed = _parse_analysis_response(content) - - result_dict: dict[str, str] = {"description": parsed["description"]} - field_name = _MEDIA_TYPE_TO_FIELD.get(detected_type) - if field_name: - result_dict[field_name] = parsed[field_name] - - self._cache[cache_key] = result_dict - logger.info(f"[媒体分析] 完成并缓存: {safe_url[:50]}... ({detected_type})") - return result_dict - - except Exception as exc: - logger.exception(f"媒体分析失败: {exc}") - return { - "description": _ERROR_MESSAGES["analyze"].get( - detected_type, _ERROR_MESSAGES["analyze"]["default"] - ) - } - - # ── 媒体键级别的 Q&A 历史管理 ── - - def _load_history(self) -> None: - """从磁盘加载历史 Q&A 缓存。""" - if not _HISTORY_FILE_PATH.exists(): - return - try: - with open(_HISTORY_FILE_PATH, "r", encoding="utf-8") as f: - data = json.load(f) - if isinstance(data, dict): - self._file_history = data - logger.info( - "[媒体分析] 从磁盘加载历史缓存: %d 个文件", len(self._file_history) - ) - except Exception as exc: - logger.warning("[媒体分析] 加载历史缓存失败: %s", exc) - - async def _save_history(self) -> None: - """将历史缓存写入磁盘。""" - from Undefined.utils import io - - try: - await io.write_json(_HISTORY_FILE_PATH, self._file_history, use_lock=True) - except Exception as exc: - logger.error("[媒体分析] 历史缓存写入磁盘失败: %s", exc) - - def get_history(self, media_key: str) -> list[dict[str, str]]: - """获取指定媒体键的历史 Q&A 记录。 - - Args: - media_key: 媒体唯一键(可包含作用域和文件身份) - - Returns: - Q&A 列表,每项包含 ``q`` 和 ``a`` 两个键 - """ - pairs = self._file_history.get(media_key) - if not pairs: - return [] - return list(pairs[-_MAX_QA_HISTORY:]) - - async def save_history(self, media_key: str, question: str, answer: str) -> None: - """保存一条 Q&A 到指定媒体键的历史记录(上限 5 条)并持久化。 - - Args: - media_key: 媒体唯一键(可包含作用域和文件身份) - question: 提问内容 - answer: 分析回答 - """ - pairs = self._file_history.setdefault(media_key, []) - pairs.append({"q": question, "a": answer}) - if len(pairs) > _MAX_QA_HISTORY: - self._file_history[media_key] = pairs[-_MAX_QA_HISTORY:] - await self._save_history() - - async def describe_image( - self, image_url: str, prompt_extra: str = "" - ) -> dict[str, str]: - """描述图片内容。 - - Args: - image_url: 图片 URL 或本地路径 - prompt_extra: 补充提示词 - - Returns: - 包含描述和 OCR 文字的字典 - """ - result = await self.analyze(image_url, "image", prompt_extra) - if "ocr_text" not in result: - result["ocr_text"] = "" - return result - - async def _load_prompt_text(self, prompt_path: str) -> str: - try: - return read_text_resource(prompt_path) - except Exception: - async with aiofiles.open(prompt_path, "r", encoding="utf-8") as f: - return await f.read() - - def _build_tool_request_kwargs(self) -> dict[str, Any]: - request_kwargs: dict[str, Any] = {} - if ( - get_api_mode(self._vision_config) == API_MODE_CHAT_COMPLETIONS - and not self._vision_config.thinking_enabled - ): - request_kwargs["thinking"] = {"enabled": False, "budget_tokens": 0} - return request_kwargs - - async def _request_required_tool_args( - self, - *, - prompt_path: str, - image_url: str | list[str], - tool_schema: dict[str, Any], - tool_name: str, - call_type: str, - max_tokens: int, - ) -> dict[str, Any]: - if isinstance(image_url, list): - media_contents: list[str] = [] - for url in image_url: - media_contents.append(await self._load_media_content(url, "image")) - media_content: str | list[str] = media_contents - else: - media_content = await self._load_media_content(image_url, "image") - prompt = await self._load_prompt_text(prompt_path) - content_items = await self._build_content_items("image", media_content, prompt) - response = await self._requester.request( - model_config=self._vision_config, - messages=[{"role": "user", "content": content_items}], - max_tokens=max_tokens, - call_type=call_type, - tools=[tool_schema], - tool_choice=cast( - Any, {"type": "function", "function": {"name": tool_name}} - ), - **self._build_tool_request_kwargs(), - ) - return extract_required_tool_call_arguments( - response, - expected_tool_name=tool_name, - stage=call_type, - logger=logger, - error_context=f"image={redact_string(str(image_url) if isinstance(image_url, list) else image_url)[:120]}", - ) - - async def judge_meme_image(self, image_url: str | list[str]) -> dict[str, Any]: - safe_url = redact_string( - str(image_url) if isinstance(image_url, list) else image_url - ) - try: - args = await self._request_required_tool_args( - prompt_path=_MEME_JUDGE_PROMPT_PATH, - image_url=image_url, - tool_schema=_MEME_JUDGE_TOOL, - tool_name="submit_meme_judgement", - call_type="vision_meme_judge", - max_tokens=self._vision_config.max_tokens, - ) - except Exception as exc: - logger.exception("[媒体分析] 表情包判定失败,按非表情包处理: %s", exc) - return { - "is_meme": False, - "confidence": 0.0, - "reason": "", - } - - try: - parsed = { - "is_meme": bool(args.get("is_meme", False)), - "confidence": safe_float(args.get("confidence", 0.0), default=0.0), - "reason": str(args.get("reason") or "").strip(), - } - except Exception: - parsed = {"is_meme": False, "confidence": 0.0, "reason": ""} - logger.info( - "[媒体分析] 表情包判定完成: url=%s is_meme=%s confidence=%.3f reason=%s", - safe_url[:50], - parsed.get("is_meme", False), - safe_float(parsed.get("confidence", 0.0), default=0.0), - str(parsed.get("reason", ""))[:80], - ) - return parsed - - async def describe_meme_image(self, image_url: str | list[str]) -> dict[str, Any]: - safe_url = redact_string( - str(image_url) if isinstance(image_url, list) else image_url - ) - try: - args = await self._request_required_tool_args( - prompt_path=_MEME_DESCRIBE_PROMPT_PATH, - image_url=image_url, - tool_schema=_MEME_DESCRIBE_TOOL, - tool_name="submit_meme_description", - call_type="vision_meme_describe", - max_tokens=self._vision_config.max_tokens, - ) - except Exception as exc: - logger.exception("[媒体分析] 表情包描述失败: %s", exc) - return {"description": "", "tags": []} - - description = str(args.get("description") or "").strip() - tags = _normalize_meme_tags(args.get("tags")) - logger.info( - "[媒体分析] 表情包描述完成: url=%s desc_len=%s tags=%s", - safe_url[:50], - len(description), - tags, - ) - return {"description": description, "tags": tags} diff --git a/src/Undefined/ai/multimodal/__init__.py b/src/Undefined/ai/multimodal/__init__.py index 128ba4ac..e5c08dfa 100644 --- a/src/Undefined/ai/multimodal/__init__.py +++ b/src/Undefined/ai/multimodal/__init__.py @@ -1,7 +1,6 @@ """多模态分析子包。 -对外稳定入口:``MultimodalAnalyzer``、``detect_media_type``、``get_media_mime_type``; -旧路径 ``Undefined.ai.multimodal`` 通过包根与 ``multimodal.py`` shim 保持兼容。 +对外稳定入口:``MultimodalAnalyzer``、``detect_media_type``、``get_media_mime_type``。 """ from Undefined.ai.multimodal import constants as _constants diff --git a/src/Undefined/ai/prompts.py b/src/Undefined/ai/prompts.py deleted file mode 100644 index 963d8e51..00000000 --- a/src/Undefined/ai/prompts.py +++ /dev/null @@ -1,846 +0,0 @@ -"""Prompt building utilities.""" - -from __future__ import annotations - -import html -import logging -import re -from collections import deque -from datetime import datetime -from typing import Any, Callable, Awaitable, Literal - -import aiofiles - -from Undefined.utils.coerce import safe_int -from Undefined.context import RequestContext -from Undefined.end_summary_storage import ( - EndSummaryStorage, - EndSummaryRecord, - MAX_END_SUMMARIES, -) -from Undefined.memory import MemoryStorage -from Undefined.skills.anthropic_skills import AnthropicSkillRegistry -from Undefined.utils.logging import log_debug_json -from Undefined.utils.resources import read_text_resource -from Undefined.utils.xml import format_message_xml - -logger = logging.getLogger(__name__) - -_CURRENT_MESSAGE_RE = re.compile( - r"[^>]*)>.*?(?P.*?).*?", - re.DOTALL | re.IGNORECASE, -) -_XML_ATTR_RE = re.compile(r'(?P[a-zA-Z_][a-zA-Z0-9_-]*)="(?P[^"]*)"') -_COGNITIVE_QUERY_SHORT_THRESHOLD = 20 -_COGNITIVE_CONTEXT_VALUE_MAX_LEN = 18 - - -class PromptBuilder: - """Construct system/user messages with memory, history, and time.""" - - def __init__( - self, - bot_qq: int, - memory_storage: MemoryStorage | None, - end_summary_storage: EndSummaryStorage, - system_prompt_path: str = "res/prompts/undefined.xml", - runtime_config_getter: Callable[[], Any] | None = None, - anthropic_skill_registry: AnthropicSkillRegistry | None = None, - cognitive_service: Any = None, - ) -> None: - """初始化 Prompt 构建器 - - 参数: - bot_qq: 机器人 QQ 号 - memory_storage: 长期记忆存储 (可选) - end_summary_storage: 短期回忆存储 - system_prompt_path: 系统提示词文件路径 - anthropic_skill_registry: Anthropic Skills 注册中心(可选) - """ - self._bot_qq = bot_qq - self._memory_storage = memory_storage - self._end_summary_storage = end_summary_storage - self._system_prompt_path = system_prompt_path - self._runtime_config_getter = runtime_config_getter - self._anthropic_skill_registry = anthropic_skill_registry - self._cognitive_service = cognitive_service - self._end_summaries: deque[EndSummaryRecord] = deque(maxlen=MAX_END_SUMMARIES) - self._summaries_loaded = False - - def set_cognitive_service(self, service: Any = None) -> None: - """更新认知记忆服务引用(支持运行时注入/替换)。""" - self._cognitive_service = service - logger.info( - "[Prompt] 认知服务引用已更新: enabled=%s", - bool(getattr(service, "enabled", False)) if service is not None else False, - ) - - @property - def end_summaries(self) -> deque[EndSummaryRecord]: - """暴露短期摘要缓存,供工具执行上下文共享。""" - return self._end_summaries - - def _select_system_prompt_path(self) -> str: - """根据运行时配置选择系统提示词路径。 - - - 关闭 nagaagent_mode_enabled: 使用默认 public prompt - - 开启 nagaagent_mode_enabled: 使用 NagaAgent prompt - - 说明:路径在每次构建 messages 时动态选择,以支持配置热更新。 - """ - - if self._runtime_config_getter is None: - return self._system_prompt_path - - runtime_config = None - try: - runtime_config = self._runtime_config_getter() - except Exception: - runtime_config = None - - enabled = bool(getattr(runtime_config, "nagaagent_mode_enabled", False)) - if enabled: - return "res/prompts/undefined_nagaagent.xml" - return "res/prompts/undefined.xml" - - def _build_model_config_info(self, runtime_config: Any) -> str: - """构建模型配置信息,用于注入到 AI 上下文中。 - - 只暴露非隐私字段(model_name 等),不暴露 api_key、api_url 等敏感信息。 - """ - parts: list[str] = ["【当前运行环境配置】"] - - chat_model = getattr(runtime_config, "chat_model", None) - if chat_model: - model_name = getattr(chat_model, "model_name", "未知") - parts.append(f"- 我使用的模型: {model_name}") - - vision_model = getattr(runtime_config, "vision_model", None) - if vision_model: - model_name = getattr(vision_model, "model_name", "") - if model_name: - parts.append(f"- 视觉模型: {model_name}") - - # Agent 模型 - agent_model = getattr(runtime_config, "agent_model", None) - if agent_model: - model_name = getattr(agent_model, "model_name", "") - if model_name: - parts.append(f"- Agent 模型: {model_name}") - - embedding_model = getattr(runtime_config, "embedding_model", None) - if embedding_model: - model_name = getattr(embedding_model, "model_name", "") - if model_name: - parts.append(f"- 嵌入模型: {model_name}") - - security_model = getattr(runtime_config, "security_model", None) - if security_model: - model_name = getattr(security_model, "model_name", "") - if model_name: - parts.append(f"- 安全模型: {model_name}") - - # Grok 搜索模型 - grok_model = getattr(runtime_config, "grok_model", None) - if grok_model: - model_name = getattr(grok_model, "model_name", "") - if model_name: - parts.append(f"- 搜索模型: {model_name}") - - cognitive = getattr(runtime_config, "cognitive", None) - if cognitive: - enabled = getattr(cognitive, "enabled", False) - parts.append(f"- 认知记忆: {'已启用' if enabled else '未启用'}") - - knowledge_enabled = bool(getattr(runtime_config, "knowledge_enabled", False)) - parts.append(f"- 知识库: {'已启用' if knowledge_enabled else '未启用'}") - - grok_search_enabled = bool( - getattr(runtime_config, "grok_search_enabled", False) - ) - parts.append(f"- 联网搜索: {'已启用' if grok_search_enabled else '未启用'}") - - memes = getattr(runtime_config, "memes", None) - if memes is not None: - memes_enabled = bool(getattr(memes, "enabled", False)) - if memes_enabled: - query_mode = str( - getattr(memes, "query_default_mode", "hybrid") or "hybrid" - ).strip() - allow_gif = bool(getattr(memes, "allow_gif", True)) - max_source_bytes = int(getattr(memes, "max_source_image_bytes", 0) or 0) - max_source_kb = max_source_bytes // 1024 if max_source_bytes > 0 else 0 - parts.append( - f"- 表情包库: 已启用(默认检索={query_mode},GIF={'允许' if allow_gif else '禁用'},入库上限={max_source_kb}KB)" - ) - else: - parts.append("- 表情包库: 未启用") - - if chat_model: - pool = getattr(chat_model, "pool", None) - if pool: - pool_enabled = getattr(pool, "enabled", False) - if pool_enabled: - strategy = getattr(pool, "strategy", "default") - parts.append(f"- 模型池: 已启用({strategy})") - else: - parts.append("- 模型池: 未启用") - - if chat_model: - thinking = getattr(chat_model, "thinking_enabled", False) - reasoning = getattr(chat_model, "reasoning_enabled", False) - if thinking or reasoning: - parts.append("- 思维链: 已启用") - else: - parts.append("- 思维链: 未启用") - - keyword_reply_enabled = bool( - getattr(runtime_config, "keyword_reply_enabled", False) - ) - repeat_enabled = bool(getattr(runtime_config, "repeat_enabled", False)) - inverted_question_enabled = bool( - getattr(runtime_config, "inverted_question_enabled", False) - ) - agent_call_mode = str( - getattr(runtime_config, "easter_egg_agent_call_message_mode", "none") - ) - easter_egg_parts: list[str] = [] - if keyword_reply_enabled: - easter_egg_parts.append( - '关键词自动回复(触发词"心理委员"等,系统自动发送固定回复)' - ) - if repeat_enabled: - threshold = int(getattr(runtime_config, "repeat_threshold", 3)) - desc = f"复读(群聊连续{threshold}条相同消息时自动复读)" - if inverted_question_enabled: - desc += ",倒问号(复读触发时若消息为问号则发送¿)" - easter_egg_parts.append(desc) - elif inverted_question_enabled: - easter_egg_parts.append("倒问号(复读未启用,此功能不生效)") - if agent_call_mode != "none": - mode_desc = { - "agent": "Agent调用提示", - "tools": "工具调用提示", - "clean": "降噪调用提示", - "all": "全量调用提示", - }.get(agent_call_mode, agent_call_mode) - easter_egg_parts.append(f"调用提示模式={mode_desc}") - if easter_egg_parts: - parts.append("- 彩蛋功能: " + ";".join(easter_egg_parts)) - else: - parts.append("- 彩蛋功能: 未启用") - - parts.append("") - parts.append( - "重要:以上是你的模型配置信息。\n" - "当你需要描述自己是谁、使用什么模型、能力或限制时,\n" - "必须以上述配置为准,忽略你训练数据、长期及认知记忆中的任何冲突信息。" - ) - - return "\n".join(parts) - - async def _ensure_summaries_loaded(self) -> None: - if not self._summaries_loaded: - loaded_summaries = await self._end_summary_storage.load() - self._end_summaries.extend(loaded_summaries) - self._summaries_loaded = True - logger.debug(f"[AI初始化] 已加载 {len(loaded_summaries)} 条 End 摘要") - - async def _load_each_rules(self) -> str: - path = "res/IMPORTANT/each.md" - try: - return read_text_resource(path) - except Exception: - pass - try: - async with aiofiles.open(path, "r", encoding="utf-8") as f: - return await f.read() - except Exception: - return "" - - async def _load_system_prompt(self) -> str: - system_prompt_path = self._select_system_prompt_path() - try: - return read_text_resource(system_prompt_path) - except Exception as exc: - logger.debug("读取系统提示词失败,尝试本地路径: %s", exc) - async with aiofiles.open(system_prompt_path, "r", encoding="utf-8") as f: - return await f.read() - - async def build_messages( - self, - question: str, - get_recent_messages_callback: Callable[ - [str, str, int, int], Awaitable[list[dict[str, Any]]] - ] - | None = None, - extra_context: dict[str, Any] | None = None, - ) -> list[dict[str, Any]]: - """构建发送给 AI 的消息列表 - - 参数: - question: 当前用户消息 - get_recent_messages_callback: 获取历史消息的回调函数 - extra_context: 额外的上下文信息 (如 group_id, user_id) - - 返回: - 构建好的消息列表 (role/content 结构) - """ - system_prompt = await self._load_system_prompt() - logger.debug( - "[Prompt] system_prompt_len=%s path=%s", - len(system_prompt), - self._select_system_prompt_path(), - ) - - if self._bot_qq != 0: - bot_qq_info = ( - f"\n" - f"\n\n" - ) - system_prompt = bot_qq_info + system_prompt - - messages: list[dict[str, Any]] = [{"role": "system", "content": system_prompt}] - - # 注入当前运行环境配置信息,让 AI 知道自己的模型名称等非隐私信息 - if self._runtime_config_getter is not None: - try: - runtime_config = self._runtime_config_getter() - config_info = self._build_model_config_info(runtime_config) - if config_info: - messages.append( - { - "role": "system", - "content": config_info, - } - ) - logger.debug( - "[Prompt] 已注入运行环境配置信息,长度=%s", - len(config_info), - ) - except Exception as exc: - logger.debug("读取运行环境配置失败: %s", exc) - - # 注入群聊关键词自动回复机制说明,避免模型误判历史中的系统彩蛋消息。 - is_group_context = False - ctx = RequestContext.current() - if ctx and ctx.group_id is not None: - is_group_context = True - elif extra_context and extra_context.get("group_id") is not None: - is_group_context = True - - keyword_reply_enabled = False - repeat_enabled = False - repeat_threshold = 3 - inverted_question_enabled = False - if self._runtime_config_getter is not None: - try: - runtime_config = self._runtime_config_getter() - keyword_reply_enabled = bool( - getattr(runtime_config, "keyword_reply_enabled", False) - ) - repeat_enabled = bool(getattr(runtime_config, "repeat_enabled", False)) - repeat_threshold = int(getattr(runtime_config, "repeat_threshold", 3)) - inverted_question_enabled = bool( - getattr(runtime_config, "inverted_question_enabled", False) - ) - except Exception as exc: - logger.debug("读取彩蛋功能配置失败: %s", exc) - - if is_group_context and keyword_reply_enabled: - messages.append( - { - "role": "system", - "content": ( - "【系统行为说明 — 关键词自动回复】\n" - '当前群聊已开启关键词自动回复彩蛋(例如触发词"心理委员")。' - "该功能由 handlers.py 中的独立代码路径处理," - "在消息到达你之前就已完成发送。\n\n" - '发送后,历史中会出现以"[系统关键词自动回复] "开头的消息。' - "这些消息完全由系统代码生成(固定文案如'受着''那咋了'等)," - "不经过你的工具调用,与你的决策无关。\n\n" - "阅读历史时请识别该前缀,避免误判为人格漂移或上下文异常。" - "除非用户主动询问,否则不要主动解释此机制。" - ), - } - ) - - if is_group_context and repeat_enabled: - repeat_desc = ( - "【系统行为说明】\n" - f"当前群聊已开启复读彩蛋:当群聊中连续出现{repeat_threshold}条内容相同且来自不同人的消息时," - "系统会自动复读一条相同的消息,并在历史中写入" - '以"[系统复读] "开头的消息。' - ) - if inverted_question_enabled: - repeat_desc += ( - "\n此外,若复读触发时消息内容仅由问号组成(如?或???)," - "系统会发送对应数量的倒问号(¿)代替。" - ) - repeat_desc += ( - "\n\n这类消息属于系统预设机制,不代表你在该轮主动决策。" - "阅读历史时请识别该前缀,避免误判为人格漂移或上下文异常。" - "除非用户主动询问,否则不要主动解释此机制。" - ) - messages.append({"role": "system", "content": repeat_desc}) - - # 注入 Anthropic Skills 元数据(Level 1: 始终加载 name + description) - if ( - self._anthropic_skill_registry - and self._anthropic_skill_registry.has_skills() - ): - skills_xml = self._anthropic_skill_registry.build_metadata_xml() - if skills_xml: - messages.append( - { - "role": "system", - "content": ( - "【可用的 Anthropic Skills】\n" - f"{skills_xml}\n\n" - "注意:以上是可用的 Anthropic Agent Skills 列表。" - "当用户的请求与某个 skill 相关时," - "你可以调用对应的 skill tool(tool_name 字段)" - "来获取该领域的详细指令和知识。" - ), - } - ) - logger.debug( - "[Prompt] 已注入 %d 个 Anthropic Skills 元数据", - len(self._anthropic_skill_registry.get_all_skills()), - ) - - each_rules = await self._load_each_rules() - if each_rules: - messages.append( - { - "role": "system", - "content": f"【强制规则 - 必须在进行任何操作前仔细阅读并严格遵守】\n{each_rules}", - } - ) - - deferred_messages: list[dict[str, Any]] = [] - - if self._memory_storage: - memories = self._memory_storage.get_all() - if memories: - memory_lines = [f"- {mem.fact}" for mem in memories] - memory_text = "\n".join(memory_lines) - deferred_messages.append( - { - "role": "system", - "content": ( - "【memory.* 手动长期记忆(可编辑)】\n" - f"{memory_text}\n\n" - "注意:以上是你通过 memory.add 等工具主动维护的长期事实清单。" - "它与认知记忆(cognitive.* / end.observations 产生的事件与侧写)是两套机制。" - "请根据任务选择合适的记忆工具,避免混用。" - ), - } - ) - logger.info(f"[AI会话] 已注入 {len(memories)} 条长期记忆") - if logger.isEnabledFor(logging.DEBUG): - log_debug_json( - logger, "[AI会话] 注入长期记忆", [mem.fact for mem in memories] - ) - - await self._ensure_summaries_loaded() - if self._cognitive_service and getattr( - self._cognitive_service, "enabled", False - ): - recent_action_inject_k = 30 - if self._runtime_config_getter is not None: - try: - runtime_config = self._runtime_config_getter() - cog_cfg = getattr(runtime_config, "cognitive", None) - if cog_cfg is not None and hasattr( - cog_cfg, "recent_end_summaries_inject_k" - ): - recent_action_inject_k = int( - getattr(cog_cfg, "recent_end_summaries_inject_k") - ) - except Exception: - pass - if recent_action_inject_k < 0: - recent_action_inject_k = 0 - - ctx = RequestContext.current() - resolved_group_id = ( - str(ctx.group_id) - if ctx and ctx.group_id is not None - else (str(extra_context.get("group_id", "")) if extra_context else None) - ) - resolved_user_id = ( - str(ctx.user_id) - if ctx and ctx.user_id is not None - else (str(extra_context.get("user_id", "")) if extra_context else None) - ) - resolved_sender_id = ( - str(ctx.sender_id) - if ctx and ctx.sender_id is not None - else ( - str(extra_context.get("sender_id", "")) if extra_context else None - ) - ) - resolved_request_type = ( - str(ctx.request_type).strip() - if ctx and ctx.request_type - else ( - str(extra_context.get("request_type", "")).strip() - if extra_context - else "" - ) - ) - if not resolved_request_type: - if resolved_group_id and str(resolved_group_id).strip(): - resolved_request_type = "group" - elif resolved_sender_id or resolved_user_id: - resolved_request_type = "private" - cognitive_query, query_enhanced = self._build_cognitive_query( - question, extra_context - ) - logger.info( - "[AI会话] 开始自动检索认知记忆: raw_query_len=%s effective_query_len=%s query_enhanced=%s type=%s group=%s user=%s sender=%s", - len(question), - len(cognitive_query), - query_enhanced, - resolved_request_type or "", - resolved_group_id or "", - resolved_user_id or "", - resolved_sender_id or "", - ) - cognitive_context = await self._cognitive_service.build_context( - query=cognitive_query, - group_id=resolved_group_id, - user_id=resolved_user_id, - sender_id=resolved_sender_id, - sender_name=str(extra_context.get("sender_name", "")) - if extra_context - else None, - group_name=str(extra_context.get("group_name", "")) - if extra_context - else None, - request_type=resolved_request_type or None, - ) - if cognitive_context: - deferred_messages.append( - {"role": "system", "content": cognitive_context} - ) - logger.info( - "[AI会话] 已注入认知记忆上下文: context_len=%s", - len(cognitive_context), - ) - else: - logger.info("[AI会话] 自动检索完成:未命中可注入认知记忆") - - # 额外注入最近 end 行动记录,作为短期“工作记忆”,弥补史官异步入库延迟与向量检索的漏召回。 - if recent_action_inject_k > 0 and self._end_summaries: - items = list(self._end_summaries)[-recent_action_inject_k:] - recent_summary_lines: list[str] = [] - for item in items: - location_text = "" - location = item.get("location") - if isinstance(location, dict): - location_type = location.get("type") - location_name = location.get("name") - if ( - location_type in {"private", "group"} - and isinstance(location_name, str) - and location_name.strip() - ): - location_text = ( - f" ({location_type}: {location_name.strip()})" - ) - recent_summary_lines.append( - f"- [{item.get('timestamp', '')}] {item.get('summary', '')}{location_text}" - ) - recent_summary_text = "\n".join(recent_summary_lines).strip() - if recent_summary_text: - deferred_messages.append( - { - "role": "system", - "content": ( - f"【短期行动记录(最近 {len(items)} 条,带时间)】\n" - f"{recent_summary_text}\n\n" - "注意:以上是你最近在 end 时记录的行动摘要,用于保持短期连续性。" - "它可能与认知记忆事件存在重复;优先以更具体、更近期的描述为准。" - ), - } - ) - elif self._end_summaries: - summary_lines: list[str] = [] - for item in self._end_summaries: - location_text = "" - location = item.get("location") - if isinstance(location, dict): - location_type = location.get("type") - location_name = location.get("name") - if ( - location_type in {"private", "group"} - and isinstance(location_name, str) - and location_name.strip() - ): - location_text = f" ({location_type}: {location_name.strip()})" - summary_lines.append( - f"- [{item['timestamp']}] {item['summary']}{location_text}" - ) - summary_text = "\n".join(summary_lines) - deferred_messages.append( - { - "role": "system", - "content": ( - "【这是你之前end时记录的事情】\n" - f"{summary_text}\n\n" - "注意:以上是你之前在end时记录的事情,用于帮助你记住之前做了什么或以后可能要做什么。" - ), - } - ) - logger.info( - f"[AI会话] 已注入 {len(self._end_summaries)} 条短期回忆 (end 摘要)" - ) - if logger.isEnabledFor(logging.DEBUG): - log_debug_json( - logger, "[AI会话] 注入短期回忆", list(self._end_summaries) - ) - - if get_recent_messages_callback: - await self._inject_recent_messages( - deferred_messages, get_recent_messages_callback, extra_context, question - ) - - messages.extend(deferred_messages) - - current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S") - messages.append( - { - "role": "system", - "content": f"【当前时间】\n{current_time}\n\n注意:以上是当前的系统时间,供你参考。", - } - ) - - messages.append({"role": "user", "content": f"【当前消息】\n{question}"}) - logger.debug( - "[Prompt] messages_ready=%s question_len=%s", - len(messages), - len(question), - ) - return messages - - def _resolve_chat_scope( - self, extra_context: dict[str, Any] | None - ) -> tuple[Literal["group", "private"], int] | None: - ctx = RequestContext.current() - - if ctx and ctx.request_type == "group" and ctx.group_id is not None: - group_id = safe_int(ctx.group_id) - if group_id is not None: - return ("group", group_id) - return None - if ctx and ctx.request_type == "private" and ctx.user_id is not None: - user_id = safe_int(ctx.user_id) - if user_id is not None: - return ("private", user_id) - return None - - if extra_context and extra_context.get("group_id") is not None: - group_id = safe_int(extra_context.get("group_id")) - if group_id is not None: - return ("group", group_id) - return None - if extra_context and extra_context.get("user_id") is not None: - user_id = safe_int(extra_context.get("user_id")) - if user_id is not None: - return ("private", user_id) - return None - - return None - - async def _inject_recent_messages( - self, - messages: list[dict[str, Any]], - get_recent_messages_callback: Callable[ - [str, str, int, int], Awaitable[list[dict[str, Any]]] - ], - extra_context: dict[str, Any] | None, - question: str, - ) -> None: - try: - ctx = RequestContext.current() - if ctx: - group_id_from_ctx = ctx.group_id - user_id_from_ctx = ctx.user_id - elif extra_context: - group_id_from_ctx = extra_context.get("group_id") - user_id_from_ctx = extra_context.get("user_id") - else: - group_id_from_ctx = None - user_id_from_ctx = None - - if group_id_from_ctx is not None: - chat_id = str(group_id_from_ctx) - msg_type = "group" - elif user_id_from_ctx is not None: - chat_id = str(user_id_from_ctx) - msg_type = "private" - else: - chat_id = "" - msg_type = "group" - - recent_limit = 20 - if self._runtime_config_getter is not None: - try: - runtime_config = self._runtime_config_getter() - if hasattr(runtime_config, "get_context_recent_messages_limit"): - recent_limit = int( - runtime_config.get_context_recent_messages_limit() - ) - except Exception as exc: - logger.debug("读取上下文历史条数配置失败: %s", exc) - - if recent_limit < 0: - recent_limit = 0 - if recent_limit == 0: - logger.debug("上下文历史消息注入已关闭 (limit=0)") - return - - recent_msgs = await get_recent_messages_callback( - chat_id, - msg_type, - 0, - recent_limit, - ) - recent_msgs = self._drop_current_message_if_duplicated( - recent_msgs, question - ) - context_lines: list[str] = [format_message_xml(msg) for msg in recent_msgs] - - formatted_context = "\n---\n".join(context_lines) - - if formatted_context: - messages.append( - { - "role": "user", - "content": ( - "【历史消息存档】\n" - f"{formatted_context}\n\n" - "注意:以上是之前的聊天记录,用于提供背景信息。每个消息之间使用 --- 分隔。接下来的用户消息才是当前正在发生的对话。" - ), - } - ) - logger.debug(f"自动预获取了 {len(context_lines)} 条历史消息作为上下文") - if logger.isEnabledFor(logging.DEBUG): - log_debug_json( - logger, - "[Prompt] 历史消息上下文", - context_lines, - ) - except Exception as exc: - logger.warning(f"自动获取历史消息失败: {exc}") - - @staticmethod - def _normalize_cognitive_context_value(value: Any) -> str: - text = " ".join(str(value or "").split()).strip() - if len(text) <= _COGNITIVE_CONTEXT_VALUE_MAX_LEN: - return text - return text[: _COGNITIVE_CONTEXT_VALUE_MAX_LEN - 3].rstrip() + "..." - - def _build_cognitive_query( - self, question: str, extra_context: dict[str, Any] | None = None - ) -> tuple[str, bool]: - question_text = str(question or "").strip() - signature = self._extract_current_message_signature(question_text) - current_content = str(signature.get("content", "")).strip() - base_query = current_content or question_text - if not base_query: - return "", False - - # 优先使用当前帧原始消息内容;仅在短消息时追加少量会话语境,降低“这/那个”类指代丢失。 - if ( - not current_content - or len(current_content) > _COGNITIVE_QUERY_SHORT_THRESHOLD - ): - return base_query, False - - context_parts: list[str] = [] - if extra_context: - if bool(extra_context.get("is_private_chat", False)): - context_parts.append("会话:私聊") - elif str(extra_context.get("group_id", "")).strip(): - context_parts.append("会话:群聊") - if bool(extra_context.get("is_at_bot", False)): - context_parts.append("触发:@机器人") - - sender_name = self._normalize_cognitive_context_value( - extra_context.get("sender_name", "") - ) - if sender_name: - context_parts.append(f"发送者:{sender_name}") - - group_name = self._normalize_cognitive_context_value( - extra_context.get("group_name", "") - ) - if group_name: - context_parts.append(f"群:{group_name}") - - if not context_parts: - return base_query, False - return f"{base_query}\n语境: {'; '.join(context_parts)}", True - - def _extract_current_message_signature(self, question: str) -> dict[str, str]: - matched = _CURRENT_MESSAGE_RE.search(str(question or "")) - if not matched: - return {} - - attrs_text = str(matched.group("attrs") or "") - attrs: dict[str, str] = {} - for attr_match in _XML_ATTR_RE.finditer(attrs_text): - key = str(attr_match.group("key") or "").strip() - if not key: - continue - attrs[key] = html.unescape(str(attr_match.group("value") or "")).strip() - - content = html.unescape(str(matched.group("content") or "")).strip() - return { - "sender_id": attrs.get("sender_id", ""), - "timestamp": attrs.get("time", ""), - "content": content, - } - - def _drop_current_message_if_duplicated( - self, recent_msgs: list[dict[str, Any]], question: str - ) -> list[dict[str, Any]]: - if not recent_msgs: - return recent_msgs - - signature = self._extract_current_message_signature(question) - if not signature: - return recent_msgs - - last_msg = recent_msgs[-1] - last_sender_id = str(last_msg.get("user_id", "")).strip() - last_timestamp = str(last_msg.get("timestamp", "")).strip() - last_content = str(last_msg.get("message", "")).strip() - - sig_sender_id = str(signature.get("sender_id", "")).strip() - sig_timestamp = str(signature.get("timestamp", "")).strip() - sig_content = str(signature.get("content", "")).strip() - if not sig_sender_id or not sig_content: - return recent_msgs - - if last_sender_id != sig_sender_id: - return recent_msgs - if last_content != sig_content: - return recent_msgs - - if sig_timestamp and last_timestamp and sig_timestamp != last_timestamp: - # history 写入时间与事件时间可能存在秒级偏差;若分钟都不同则判定不是同一帧。 - if sig_timestamp[:16] != last_timestamp[:16]: - return recent_msgs - - logger.info( - "[Prompt] 历史注入剔除当前帧: sender=%s sig_time=%s history_time=%s content_preview=%s", - sig_sender_id, - sig_timestamp, - last_timestamp, - sig_content[:60], - ) - return recent_msgs[:-1] diff --git a/src/Undefined/ai/prompts/__init__.py b/src/Undefined/ai/prompts/__init__.py index f642ceed..5ca34338 100644 --- a/src/Undefined/ai/prompts/__init__.py +++ b/src/Undefined/ai/prompts/__init__.py @@ -1,6 +1,6 @@ """Prompt 构建子包。 -对外稳定入口:``PromptBuilder``;旧路径 ``Undefined.ai.prompts`` 通过 shim 保持兼容。 +对外稳定入口:``PromptBuilder``;导入路径 ``Undefined.ai.prompts`` 指向本子包。 """ # 子包唯一公开类:PromptBuilder diff --git a/src/Undefined/ai/prompts/builder.py b/src/Undefined/ai/prompts/builder.py index 3906e3a9..57984071 100644 --- a/src/Undefined/ai/prompts/builder.py +++ b/src/Undefined/ai/prompts/builder.py @@ -212,7 +212,7 @@ async def build_messages( "content": ( "【系统行为说明 — 关键词自动回复】\n" '当前群聊已开启关键词自动回复彩蛋(例如触发词"心理委员")。' - "该功能由 handlers.py 中的独立代码路径处理," + "该功能由 handlers/message_flow 中的独立代码路径处理," "在消息到达你之前就已完成发送。\n\n" '发送后,历史中会出现以"[系统关键词自动回复] "开头的消息。' "这些消息完全由系统代码生成(固定文案如'受着''那咋了'等)," diff --git a/src/Undefined/api/routes/naga.py b/src/Undefined/api/routes/naga.py deleted file mode 100644 index b76375de..00000000 --- a/src/Undefined/api/routes/naga.py +++ /dev/null @@ -1,897 +0,0 @@ -"""Naga integration route handlers. - -Extracted from ``RuntimeAPI`` methods into free functions so they can be -registered declaratively in the route table. -""" - -from __future__ import annotations - -import logging -import os -import uuid as _uuid -from copy import deepcopy -from pathlib import Path -from typing import Any - -from aiohttp import web -from aiohttp.web_response import Response - -from Undefined.api._context import RuntimeAPIContext -from Undefined.api._helpers import ( - _json_error, - _naga_message_digest, - _parse_response_payload, - _short_text_preview, -) -from Undefined.api._naga_state import NagaState -from Undefined.render import render_html_to_image, render_markdown_to_html - -logger = logging.getLogger(__name__) - - -# ------------------------------------------------------------------ -# Auth helper -# ------------------------------------------------------------------ - - -def verify_naga_api_key(ctx: RuntimeAPIContext, request: web.Request) -> str | None: - """校验 Naga 共享密钥,返回错误信息或 ``None`` 表示通过。""" - import secrets as _secrets - - cfg = ctx.config_getter() - expected = cfg.naga.api_key - if not expected: - return "naga api_key not configured" - auth_header = request.headers.get("Authorization", "") - if not auth_header.startswith("Bearer "): - return "missing or invalid Authorization header" - provided = auth_header[7:] - if not _secrets.compare_digest(provided, expected): - return "invalid api_key" - return None - - -# ------------------------------------------------------------------ -# POST /api/v1/naga/bind/callback -# ------------------------------------------------------------------ - - -async def naga_bind_callback_handler( - ctx: RuntimeAPIContext, request: web.Request -) -> Response: - """POST /api/v1/naga/bind/callback — Naga 绑定回调。""" - trace_id = _uuid.uuid4().hex[:8] - auth_err = verify_naga_api_key(ctx, request) - if auth_err is not None: - logger.warning( - "[NagaBindCallback] 鉴权失败: trace=%s remote=%s err=%s", - trace_id, - getattr(request, "remote", None), - auth_err, - ) - return _json_error("Unauthorized", status=401) - - try: - body = await request.json() - except Exception: - return _json_error("Invalid JSON", status=400) - - bind_uuid = str(body.get("bind_uuid", "") or "").strip() - naga_id = str(body.get("naga_id", "") or "").strip() - status = str(body.get("status", "") or "").strip().lower() - delivery_signature = str(body.get("delivery_signature", "") or "").strip() - reason = str(body.get("reason", "") or "").strip() - if not bind_uuid or not naga_id: - return _json_error("bind_uuid and naga_id are required", status=400) - if status not in {"approved", "rejected"}: - return _json_error("status must be 'approved' or 'rejected'", status=400) - logger.info( - "[NagaBindCallback] 请求开始: trace=%s remote=%s naga_id=%s bind_uuid=%s status=%s reason=%s signature=%s", - trace_id, - getattr(request, "remote", None), - naga_id, - bind_uuid, - status, - _short_text_preview(reason, limit=60), - delivery_signature[:12] + "..." if delivery_signature else "", - ) - - naga_store = ctx.naga_store - if naga_store is None: - return _json_error("Naga integration not available", status=503) - - sender = ctx.sender - if status == "approved": - if not delivery_signature: - return _json_error( - "delivery_signature is required when approved", status=400 - ) - binding, created, err = await naga_store.activate_binding( - bind_uuid=bind_uuid, - naga_id=naga_id, - delivery_signature=delivery_signature, - ) - if err: - logger.warning( - "[NagaBindCallback] 激活失败: trace=%s naga_id=%s bind_uuid=%s err=%s", - trace_id, - naga_id, - bind_uuid, - err.message, - ) - return _json_error(err.message, status=err.http_status) - logger.info( - "[NagaBindCallback] 激活完成: trace=%s naga_id=%s bind_uuid=%s created=%s qq=%s", - trace_id, - naga_id, - bind_uuid, - created, - binding.qq_id if binding is not None else "", - ) - if created and binding is not None and sender is not None: - try: - await sender.send_private_message( - binding.qq_id, - f"🎉 你的 Naga 绑定已生效\nnaga_id: {naga_id}", - ) - except Exception as exc: - logger.warning("[NagaBindCallback] 通知绑定成功失败: %s", exc) - return web.json_response( - { - "ok": True, - "status": "approved", - "idempotent": not created, - "naga_id": naga_id, - "bind_uuid": bind_uuid, - } - ) - - # --- rejected --- - pending, removed, err = await naga_store.reject_binding( - bind_uuid=bind_uuid, - naga_id=naga_id, - reason=reason, - ) - if err: - logger.warning( - "[NagaBindCallback] 拒绝失败: trace=%s naga_id=%s bind_uuid=%s err=%s", - trace_id, - naga_id, - bind_uuid, - err.message, - ) - return _json_error(err.message, status=err.http_status) - logger.info( - "[NagaBindCallback] 拒绝完成: trace=%s naga_id=%s bind_uuid=%s removed=%s qq=%s", - trace_id, - naga_id, - bind_uuid, - removed, - pending.qq_id if pending is not None else "", - ) - if removed and pending is not None and sender is not None: - try: - detail = f"\n原因: {reason}" if reason else "" - await sender.send_private_message( - pending.qq_id, - f"❌ 你的 Naga 绑定被远端拒绝\nnaga_id: {naga_id}{detail}", - ) - except Exception as exc: - logger.warning("[NagaBindCallback] 通知绑定拒绝失败: %s", exc) - return web.json_response( - { - "ok": True, - "status": "rejected", - "idempotent": not removed, - "naga_id": naga_id, - "bind_uuid": bind_uuid, - } - ) - - -# ------------------------------------------------------------------ -# POST /api/v1/naga/messages/send -# ------------------------------------------------------------------ - - -async def naga_messages_send_handler( - ctx: RuntimeAPIContext, - naga_state: NagaState, - request: web.Request, -) -> Response: - """POST /api/v1/naga/messages/send — 验签后发送消息。""" - from Undefined.api.naga_store import mask_token - - trace_id = _uuid.uuid4().hex[:8] - auth_err = verify_naga_api_key(ctx, request) - if auth_err is not None: - logger.warning("[NagaSend] 鉴权失败: trace=%s err=%s", trace_id, auth_err) - return _json_error("Unauthorized", status=401) - - try: - body = await request.json() - except Exception: - return _json_error("Invalid JSON", status=400) - - bind_uuid = str(body.get("bind_uuid", "") or "").strip() - naga_id = str(body.get("naga_id", "") or "").strip() - delivery_signature = str(body.get("delivery_signature", "") or "").strip() - request_uuid = str(body.get("uuid", "") or "").strip() - target = body.get("target") - message = body.get("message") - if not bind_uuid or not naga_id or not delivery_signature: - return _json_error( - "bind_uuid, naga_id and delivery_signature are required", - status=400, - ) - if not isinstance(target, dict): - return _json_error("target object is required", status=400) - if not isinstance(message, dict): - return _json_error("message object is required", status=400) - - raw_target_qq = target.get("qq_id") - raw_target_group = target.get("group_id") - if raw_target_qq is None or raw_target_group is None: - return _json_error("target.qq_id and target.group_id are required", status=400) - try: - target_qq = int(raw_target_qq) - target_group = int(raw_target_group) - except Exception: - return _json_error( - "target.qq_id and target.group_id must be integers", status=400 - ) - mode = str(target.get("mode", "") or "").strip().lower() - if mode not in {"private", "group", "both"}: - return _json_error( - "target.mode must be 'private', 'group', or 'both'", status=400 - ) - - fmt = str(message.get("format", "text") or "text").strip().lower() - content = str(message.get("content", "") or "").strip() - if fmt not in {"text", "markdown", "html"}: - return _json_error( - "message.format must be 'text', 'markdown', or 'html'", status=400 - ) - if not content: - return _json_error("message.content is required", status=400) - - message_key = _naga_message_digest( - bind_uuid=bind_uuid, - naga_id=naga_id, - target_qq=target_qq, - target_group=target_group, - mode=mode, - message_format=fmt, - content=content, - ) - logger.info( - "[NagaSend] 请求开始: trace=%s remote=%s naga_id=%s bind_uuid=%s request_uuid=%s mode=%s fmt=%s qq=%s group=%s key=%s content_len=%s preview=%s signature=%s", - trace_id, - getattr(request, "remote", None), - naga_id, - bind_uuid, - request_uuid, - mode, - fmt, - target_qq, - target_group, - message_key, - len(content), - _short_text_preview(content), - mask_token(delivery_signature), - ) - if mode == "both": - logger.warning( - "[NagaSend] 上游请求显式要求双路投递: trace=%s naga_id=%s bind_uuid=%s request_uuid=%s key=%s", - trace_id, - naga_id, - bind_uuid, - request_uuid, - message_key, - ) - inflight_count = await naga_state.track_send_start(message_key) - if inflight_count > 1: - logger.warning( - "[NagaSend] 检测到相同 payload 并发请求: trace=%s naga_id=%s bind_uuid=%s request_uuid=%s key=%s inflight=%s", - trace_id, - naga_id, - bind_uuid, - request_uuid, - message_key, - inflight_count, - ) - try: - if request_uuid: - dedupe_action, dedupe_value = await naga_state.register_request_uuid( - request_uuid, message_key - ) - if dedupe_action == "conflict": - logger.warning( - "[NagaSend] uuid 与历史 payload 冲突: trace=%s naga_id=%s bind_uuid=%s uuid=%s key=%s", - trace_id, - naga_id, - bind_uuid, - request_uuid, - message_key, - ) - return _json_error("uuid reused with different payload", status=409) - if dedupe_action == "cached": - cached_status, cached_payload = dedupe_value - logger.warning( - "[NagaSend] 命中已完成幂等结果,直接复用: trace=%s naga_id=%s bind_uuid=%s request_uuid=%s key=%s", - trace_id, - naga_id, - bind_uuid, - request_uuid, - message_key, - ) - return web.json_response( - deepcopy(cached_payload), - status=int(cached_status), - ) - if dedupe_action == "await": - wait_future = dedupe_value - logger.warning( - "[NagaSend] 命中进行中幂等请求,等待首个结果: trace=%s naga_id=%s bind_uuid=%s request_uuid=%s key=%s", - trace_id, - naga_id, - bind_uuid, - request_uuid, - message_key, - ) - cached_status, cached_payload = await wait_future - return web.json_response( - deepcopy(cached_payload), - status=int(cached_status), - ) - - response = await naga_messages_send_impl( - ctx, - naga_id=naga_id, - bind_uuid=bind_uuid, - delivery_signature=delivery_signature, - target_qq=target_qq, - target_group=target_group, - mode=mode, - message_format=fmt, - content=content, - trace_id=trace_id, - message_key=message_key, - ) - if request_uuid: - await naga_state.finish_request_uuid( - request_uuid, - message_key, - status=response.status, - payload=_parse_response_payload(response), - ) - return response - except Exception as exc: - if request_uuid: - await naga_state.fail_request_uuid(request_uuid, message_key, exc) - raise - finally: - remaining = await naga_state.track_send_done(message_key) - logger.info( - "[NagaSend] 请求退出: trace=%s naga_id=%s bind_uuid=%s request_uuid=%s key=%s inflight_remaining=%s", - trace_id, - naga_id, - bind_uuid, - request_uuid, - message_key, - remaining, - ) - - -# ------------------------------------------------------------------ -# Core send implementation (no NagaState dependency) -# ------------------------------------------------------------------ - - -async def naga_messages_send_impl( - ctx: RuntimeAPIContext, - *, - naga_id: str, - bind_uuid: str, - delivery_signature: str, - target_qq: int, - target_group: int, - mode: str, - message_format: str, - content: str, - trace_id: str, - message_key: str, -) -> Response: - from Undefined.api.naga_store import mask_token - - naga_store = ctx.naga_store - if naga_store is None: - logger.warning( - "[NagaSend] NagaStore 不可用: trace=%s naga_id=%s bind_uuid=%s", - trace_id, - naga_id, - bind_uuid, - ) - return _json_error("Naga integration not available", status=503) - - binding, err_msg = await naga_store.acquire_delivery( - naga_id=naga_id, - bind_uuid=bind_uuid, - delivery_signature=delivery_signature, - ) - if binding is None: - logger.warning( - "[NagaSend] 签名校验失败: trace=%s naga_id=%s bind_uuid=%s reason=%s signature=%s", - trace_id, - naga_id, - bind_uuid, - err_msg.message if err_msg is not None else "unknown_error", - mask_token(delivery_signature), - ) - return _json_error( - err_msg.message if err_msg is not None else "delivery not available", - status=err_msg.http_status if err_msg is not None else 403, - ) - - logger.info( - "[NagaSend] 投递凭证已占用: trace=%s naga_id=%s bind_uuid=%s key=%s qq=%s group=%s", - trace_id, - naga_id, - bind_uuid, - message_key, - binding.qq_id, - binding.group_id, - ) - try: - if target_qq != binding.qq_id or target_group != binding.group_id: - logger.warning( - "[NagaSend] 目标不匹配: trace=%s naga_id=%s bind_uuid=%s target_qq=%s target_group=%s bound_qq=%s bound_group=%s", - trace_id, - naga_id, - bind_uuid, - target_qq, - target_group, - binding.qq_id, - binding.group_id, - ) - return _json_error("target does not match bound qq/group", status=403) - - cfg = ctx.config_getter() - if mode == "group" and binding.group_id not in cfg.naga.allowed_groups: - logger.warning( - "[NagaSend] 群投递被策略拒绝: trace=%s naga_id=%s bind_uuid=%s group=%s", - trace_id, - naga_id, - bind_uuid, - binding.group_id, - ) - return _json_error("bound group is not in naga.allowed_groups", status=403) - - sender = ctx.sender - if sender is None: - logger.warning( - "[NagaSend] sender 不可用: trace=%s naga_id=%s bind_uuid=%s", - trace_id, - naga_id, - bind_uuid, - ) - return _json_error("sender not available", status=503) - - moderation: dict[str, Any] - naga_cfg = getattr(cfg, "naga", None) - moderation_enabled = bool(getattr(naga_cfg, "moderation_enabled", True)) - security = getattr(ctx.command_dispatcher, "security", None) - if not moderation_enabled: - moderation = { - "status": "skipped_disabled", - "blocked": False, - "categories": [], - "message": "Naga moderation disabled by config; message sent without moderation block", - "model_name": "", - } - logger.warning( - "[NagaSend] 审核已禁用,直接放行: trace=%s naga_id=%s bind_uuid=%s key=%s", - trace_id, - naga_id, - bind_uuid, - message_key, - ) - elif security is None or not hasattr(security, "moderate_naga_message"): - moderation = { - "status": "error_allowed", - "blocked": False, - "categories": [], - "message": "Naga moderation service unavailable; message sent without moderation block", - "model_name": "", - } - logger.warning( - "[NagaSend] 审核服务不可用,按允许发送: trace=%s naga_id=%s bind_uuid=%s", - trace_id, - naga_id, - bind_uuid, - ) - else: - logger.info( - "[NagaSend] 审核开始: trace=%s naga_id=%s bind_uuid=%s key=%s fmt=%s content_len=%s", - trace_id, - naga_id, - bind_uuid, - message_key, - message_format, - len(content), - ) - result = await security.moderate_naga_message( - message_format=message_format, - content=content, - ) - moderation = { - "status": result.status, - "blocked": result.blocked, - "categories": result.categories, - "message": result.message, - "model_name": result.model_name, - } - logger.info( - "[NagaSend] 审核完成: trace=%s naga_id=%s bind_uuid=%s key=%s blocked=%s status=%s model=%s categories=%s", - trace_id, - naga_id, - bind_uuid, - message_key, - result.blocked, - result.status, - result.model_name, - ",".join(result.categories) or "-", - ) - if moderation["blocked"]: - logger.warning( - "[NagaSend] 审核拦截: trace=%s naga_id=%s bind_uuid=%s key=%s reason=%s", - trace_id, - naga_id, - bind_uuid, - message_key, - moderation["message"], - ) - return web.json_response( - { - "ok": False, - "error": "message blocked by moderation", - "moderation": moderation, - }, - status=403, - ) - - send_content: str | None = content if message_format == "text" else None - image_path: str | None = None - tmp_path: str | None = None - rendered = False - render_fallback = False - if message_format in {"markdown", "html"}: - import tempfile - - try: - html_str = content - if message_format == "markdown": - html_str = await render_markdown_to_html(content) - fd, tmp_path = tempfile.mkstemp(suffix=".png", prefix="naga_send_") - os.close(fd) - await render_html_to_image(html_str, tmp_path) - image_path = tmp_path - rendered = True - logger.info( - "[NagaSend] 富文本渲染成功: trace=%s naga_id=%s bind_uuid=%s key=%s fmt=%s image=%s", - trace_id, - naga_id, - bind_uuid, - message_key, - message_format, - Path(tmp_path).name if tmp_path is not None else "", - ) - except Exception as exc: - logger.warning( - "[NagaSend] 渲染失败,回退文本发送: trace=%s naga_id=%s bind_uuid=%s key=%s err=%s", - trace_id, - naga_id, - bind_uuid, - message_key, - exc, - ) - send_content = content - render_fallback = True - - sent_private = False - sent_group = False - group_policy_blocked = False - - async def _ensure_delivery_active() -> tuple[Any, Response | None]: - current_binding, live_err = await naga_store.ensure_delivery_active( - naga_id=naga_id, - bind_uuid=bind_uuid, - ) - if current_binding is None: - logger.warning( - "[NagaSend] 投递中止: trace=%s naga_id=%s bind_uuid=%s key=%s reason=%s", - trace_id, - naga_id, - bind_uuid, - message_key, - live_err.message - if live_err is not None - else "delivery no longer active", - ) - return None, web.json_response( - { - "ok": False, - "error": ( - live_err.message - if live_err is not None - else "delivery no longer active" - ), - "sent_private": sent_private, - "sent_group": sent_group, - "moderation": moderation, - }, - status=live_err.http_status if live_err is not None else 409, - ) - return current_binding, None - - try: - cq_image: str | None = None - if image_path is not None: - file_uri = Path(image_path).resolve().as_uri() - cq_image = f"[CQ:image,file={file_uri}]" - - if mode in {"private", "both"}: - current_binding, abort_response = await _ensure_delivery_active() - if abort_response is not None: - return abort_response - logger.info( - "[NagaSend] 私聊投递开始: trace=%s naga_id=%s bind_uuid=%s key=%s qq=%s", - trace_id, - naga_id, - bind_uuid, - message_key, - current_binding.qq_id, - ) - try: - if send_content is not None: - await sender.send_private_message( - current_binding.qq_id, send_content - ) - elif cq_image is not None: - await sender.send_private_message( - current_binding.qq_id, cq_image - ) - sent_private = True - logger.info( - "[NagaSend] 私聊投递成功: trace=%s naga_id=%s bind_uuid=%s key=%s qq=%s", - trace_id, - naga_id, - bind_uuid, - message_key, - current_binding.qq_id, - ) - except Exception as exc: - logger.warning( - "[NagaSend] 私聊发送失败: trace=%s naga_id=%s qq=%d key=%s err=%s", - trace_id, - naga_id, - current_binding.qq_id, - message_key, - exc, - ) - - if mode in {"group", "both"}: - current_binding, abort_response = await _ensure_delivery_active() - if abort_response is not None: - return abort_response - current_cfg = ctx.config_getter() - if current_binding.group_id not in current_cfg.naga.allowed_groups: - group_policy_blocked = True - logger.warning( - "[NagaSend] 群投递被策略阻止: trace=%s naga_id=%s bind_uuid=%s key=%s group=%s", - trace_id, - naga_id, - bind_uuid, - message_key, - current_binding.group_id, - ) - else: - logger.info( - "[NagaSend] 群投递开始: trace=%s naga_id=%s bind_uuid=%s key=%s group=%s", - trace_id, - naga_id, - bind_uuid, - message_key, - current_binding.group_id, - ) - try: - if send_content is not None: - await sender.send_group_message( - current_binding.group_id, send_content - ) - elif cq_image is not None: - await sender.send_group_message( - current_binding.group_id, cq_image - ) - sent_group = True - logger.info( - "[NagaSend] 群投递成功: trace=%s naga_id=%s bind_uuid=%s key=%s group=%s", - trace_id, - naga_id, - bind_uuid, - message_key, - current_binding.group_id, - ) - except Exception as exc: - logger.warning( - "[NagaSend] 群聊发送失败: trace=%s naga_id=%s group=%d key=%s err=%s", - trace_id, - naga_id, - current_binding.group_id, - message_key, - exc, - ) - finally: - if tmp_path is not None: - try: - os.unlink(tmp_path) - except OSError: - pass - - if mode == "private" and not sent_private: - return web.json_response( - { - "ok": False, - "error": "private delivery failed", - "sent_private": sent_private, - "sent_group": sent_group, - "moderation": moderation, - }, - status=502, - ) - if mode == "group" and not sent_group: - return web.json_response( - { - "ok": False, - "error": "group delivery failed", - "sent_private": sent_private, - "sent_group": sent_group, - "moderation": moderation, - }, - status=502, - ) - if mode == "both" and not (sent_private or sent_group): - if group_policy_blocked: - return web.json_response( - { - "ok": False, - "error": "bound group is not in naga.allowed_groups", - "sent_private": sent_private, - "sent_group": sent_group, - "moderation": moderation, - }, - status=403, - ) - return web.json_response( - { - "ok": False, - "error": "all deliveries failed", - "sent_private": sent_private, - "sent_group": sent_group, - "moderation": moderation, - }, - status=502, - ) - - await naga_store.record_usage(naga_id, bind_uuid=bind_uuid) - partial_success = mode == "both" and (sent_private != sent_group) - logger.info( - "[NagaSend] 请求完成: trace=%s naga_id=%s bind_uuid=%s key=%s sent_private=%s sent_group=%s partial=%s rendered=%s fallback=%s", - trace_id, - naga_id, - bind_uuid, - message_key, - sent_private, - sent_group, - partial_success, - rendered, - render_fallback, - ) - return web.json_response( - { - "ok": True, - "naga_id": naga_id, - "bind_uuid": bind_uuid, - "sent_private": sent_private, - "sent_group": sent_group, - "partial_success": partial_success, - "delivery_status": ( - "partial_success" if partial_success else "full_success" - ), - "rendered": rendered, - "render_fallback": render_fallback, - "moderation": moderation, - } - ) - finally: - await naga_store.release_delivery(bind_uuid=bind_uuid) - - -# ------------------------------------------------------------------ -# POST /api/v1/naga/unbind -# ------------------------------------------------------------------ - - -async def naga_unbind_handler(ctx: RuntimeAPIContext, request: web.Request) -> Response: - """POST /api/v1/naga/unbind — 远端主动解绑。""" - trace_id = _uuid.uuid4().hex[:8] - auth_err = verify_naga_api_key(ctx, request) - if auth_err is not None: - logger.warning( - "[NagaUnbind] 鉴权失败: trace=%s remote=%s err=%s", - trace_id, - getattr(request, "remote", None), - auth_err, - ) - return _json_error("Unauthorized", status=401) - - try: - body = await request.json() - except Exception: - return _json_error("Invalid JSON", status=400) - - bind_uuid = str(body.get("bind_uuid", "") or "").strip() - naga_id = str(body.get("naga_id", "") or "").strip() - delivery_signature = str(body.get("delivery_signature", "") or "").strip() - if not bind_uuid or not naga_id or not delivery_signature: - return _json_error( - "bind_uuid, naga_id and delivery_signature are required", - status=400, - ) - logger.info( - "[NagaUnbind] 请求开始: trace=%s remote=%s naga_id=%s bind_uuid=%s signature=%s", - trace_id, - getattr(request, "remote", None), - naga_id, - bind_uuid, - delivery_signature[:12] + "...", - ) - - naga_store = ctx.naga_store - if naga_store is None: - return _json_error("Naga integration not available", status=503) - - binding, changed, err = await naga_store.revoke_binding( - naga_id, - expected_bind_uuid=bind_uuid, - delivery_signature=delivery_signature, - ) - if binding is None: - logger.warning( - "[NagaUnbind] 吊销失败: trace=%s naga_id=%s bind_uuid=%s err=%s", - trace_id, - naga_id, - bind_uuid, - err.message if err is not None else "binding not found", - ) - return _json_error( - err.message if err is not None else "binding not found", - status=err.http_status if err is not None else 404, - ) - logger.info( - "[NagaUnbind] 吊销完成: trace=%s naga_id=%s bind_uuid=%s changed=%s qq=%s group=%s", - trace_id, - naga_id, - bind_uuid, - changed, - binding.qq_id, - binding.group_id, - ) - return web.json_response( - { - "ok": True, - "idempotent": not changed, - "naga_id": naga_id, - "bind_uuid": bind_uuid, - } - ) diff --git a/src/Undefined/attachments.py b/src/Undefined/attachments.py deleted file mode 100644 index daa87245..00000000 --- a/src/Undefined/attachments.py +++ /dev/null @@ -1,1680 +0,0 @@ -"""Attachment registry and rich-media helpers.""" - -from __future__ import annotations - -import asyncio -import base64 -import binascii -from dataclasses import asdict, dataclass, replace -from datetime import datetime -import hashlib -import logging -import mimetypes -from pathlib import Path -import re -import time -from typing import Any, Awaitable, Callable, Mapping, Sequence -from urllib.parse import unquote, urlsplit - -import httpx - -from Undefined.utils import io -from Undefined.utils.paths import ( - ATTACHMENT_CACHE_DIR, - ATTACHMENT_REGISTRY_FILE, - WEBUI_FILE_CACHE_DIR, - ensure_dir, -) -from Undefined.utils.xml import escape_xml_attr - -logger = logging.getLogger(__name__) - -_PIC_TAG_PATTERN = re.compile( - r"[\"'])(?P[^\"']+)(?P=quote)\s*/?>", - re.IGNORECASE, -) -_ATTACHMENT_TAG_PATTERN = re.compile( - r"[\"'])(?P[^\"']+)(?P=quote)\s*/?>", - re.IGNORECASE, -) -_UNIFIED_TAG_PATTERN = re.compile( - r"<(?Ppic|attachment)\s+uid=(?P[\"'])(?P[^\"']+)(?P=quote)\s*/?>", - re.IGNORECASE, -) -_MEDIA_LABELS = { - "image": "图片", - "file": "文件", - "audio": "音频", - "video": "视频", - "record": "语音", - "pic": "图片", -} -_WINDOWS_ABS_PATH_RE = re.compile(r"^[A-Za-z]:[\\/]") -_DEFAULT_REMOTE_TIMEOUT_SECONDS = 120.0 -_IMAGE_SUFFIX_TO_MIME = { - ".png": "image/png", - ".jpg": "image/jpeg", - ".jpeg": "image/jpeg", - ".gif": "image/gif", - ".webp": "image/webp", - ".bmp": "image/bmp", - ".svg": "image/svg+xml", -} -_MAGIC_IMAGE_SUFFIXES: tuple[tuple[bytes, str], ...] = ( - (b"\x89PNG\r\n\x1a\n", ".png"), - (b"\xff\xd8\xff", ".jpg"), - (b"GIF87a", ".gif"), - (b"GIF89a", ".gif"), - (b"BM", ".bmp"), -) -_FORWARD_ATTACHMENT_MAX_DEPTH = 3 -_ATTACHMENT_CACHE_MAX_AGE_SECONDS = 7 * 24 * 60 * 60 -_ATTACHMENT_REGISTRY_MAX_RECORDS = 2000 -_ATTACHMENT_CACHE_MAX_BYTES = 0 -_ATTACHMENT_URL_REFERENCE_MAX_RECORDS = 2000 -_ATTACHMENT_URL_MAX_LENGTH = 8192 -_DEFAULT_REMOTE_DOWNLOAD_MAX_BYTES = 25 * 1024 * 1024 - - -@dataclass(frozen=True) -class AttachmentRecord: - uid: str - scope_key: str - kind: str - media_type: str - display_name: str - source_kind: str - source_ref: str - local_path: str | None - mime_type: str - sha256: str - created_at: str - segment_data: dict[str, str] - semantic_kind: str = "" - description: str = "" - - def prompt_ref(self) -> dict[str, str]: - local_available = False - if self.local_path is not None: - try: - local_available = Path(self.local_path).is_file() - except OSError: - local_available = False - ref: dict[str, str] = { - "uid": self.uid, - "kind": self.kind, - "media_type": self.media_type, - "display_name": self.display_name, - } - if self.source_kind.strip(): - ref["source_kind"] = self.source_kind.strip() - if not local_available and self.source_ref.strip(): - ref["source_ref"] = self.source_ref.strip() - if self.semantic_kind.strip(): - ref["semantic_kind"] = self.semantic_kind.strip() - if self.description.strip(): - ref["description"] = self.description.strip() - return ref - - -@dataclass(frozen=True) -class RegisteredMessageAttachments: - attachments: list[dict[str, str]] - normalized_text: str - - -@dataclass(frozen=True) -class RenderedRichMessage: - delivery_text: str - history_text: str - attachments: list[dict[str, str]] - pending_file_sends: tuple[AttachmentRecord, ...] = () - - -class AttachmentRenderError(RuntimeError): - """Raised when an attachment tag cannot be rendered.""" - - -class _RemoteAttachmentTooLarge(Exception): - def __init__(self, mime_type: str = "") -> None: - self.mime_type = mime_type - - -def _now_iso() -> str: - return datetime.now().isoformat(timespec="seconds") - - -def _escape_cq_component(value: str) -> str: - return ( - value.replace("&", "&") - .replace("[", "[") - .replace("]", "]") - .replace(",", ",") - ) - - -def _coerce_positive_int(value: Any) -> int | None: - if isinstance(value, bool): - return None - if isinstance(value, int): - return value if value > 0 else None - if isinstance(value, str): - text = value.strip() - if not text: - return None - try: - parsed = int(text) - except ValueError: - return None - return parsed if parsed > 0 else None - return None - - -def build_attachment_scope( - *, - group_id: Any = None, - user_id: Any = None, - request_type: str | None = None, - webui_session: bool = False, -) -> str | None: - """Build a scope key for attachment visibility.""" - if webui_session: - return "webui" - - group = _coerce_positive_int(group_id) - if group is not None: - return f"group:{group}" - - user = _coerce_positive_int(user_id) - request_type_text = str(request_type or "").strip().lower() - if request_type_text == "private" and user is not None: - return f"private:{user}" - if user is not None: - return f"private:{user}" - return None - - -def scope_from_context(context: Mapping[str, Any] | None) -> str | None: - if not context: - return None - return build_attachment_scope( - group_id=context.get("group_id"), - user_id=context.get("user_id"), - request_type=str(context.get("request_type", "") or ""), - webui_session=bool(context.get("webui_session", False)), - ) - - -def attachment_refs_to_text(attachments: Sequence[Mapping[str, str]]) -> str: - if not attachments: - return "" - parts: list[str] = [] - for item in attachments: - uid = str(item.get("uid", "") or "").strip() - if not uid: - continue - media_type = str(item.get("media_type") or item.get("kind") or "file").strip() - label = _MEDIA_LABELS.get(media_type, "附件") - name = str(item.get("display_name", "") or "").strip() - if name: - parts.append(f"[{label} uid={uid} name={name}]") - else: - parts.append(f"[{label} uid={uid}]") - return " ".join(parts) - - -def attachment_refs_to_xml( - attachments: Sequence[Mapping[str, str]], - *, - indent: str = " ", -) -> str: - if not attachments: - return "" - lines = [f"{indent}"] - for item in attachments: - uid = str(item.get("uid", "") or "").strip() - if not uid: - continue - kind = str(item.get("kind", "") or item.get("media_type", "") or "file").strip() - media_type = str(item.get("media_type", "") or kind or "file").strip() - name = str(item.get("display_name", "") or "").strip() - attrs = [ - f'uid="{escape_xml_attr(uid)}"', - f'type="{escape_xml_attr(kind or media_type)}"', - f'media_type="{escape_xml_attr(media_type)}"', - ] - if name: - attrs.append(f'name="{escape_xml_attr(name)}"') - source_kind = str(item.get("source_kind", "") or "").strip() - if source_kind: - attrs.append(f'source_kind="{escape_xml_attr(source_kind)}"') - source_ref = str(item.get("source_ref", "") or "").strip() - if source_ref: - attrs.append(f'source_ref="{escape_xml_attr(source_ref)}"') - semantic_kind = str(item.get("semantic_kind", "") or "").strip() - if semantic_kind: - attrs.append(f'semantic_kind="{escape_xml_attr(semantic_kind)}"') - description = str(item.get("description", "") or "").strip() - if description: - attrs.append(f'description="{escape_xml_attr(description)}"') - lines.append(f"{indent} ") - lines.append(f"{indent}") - return "\n".join(lines) - - -def append_attachment_text( - base_text: str, attachments: Sequence[Mapping[str, str]] -) -> str: - attachment_text = attachment_refs_to_text(attachments) - if not attachment_text: - return base_text - if not base_text.strip(): - return attachment_text - return f"{base_text}\n附件: {attachment_text}" - - -def _is_http_url(value: str) -> bool: - return value.startswith("http://") or value.startswith("https://") - - -def _is_data_url(value: str) -> bool: - return value.startswith("data:") - - -def _is_localish_path(value: str) -> bool: - return ( - value.startswith("/") - or value.startswith("file://") - or bool(_WINDOWS_ABS_PATH_RE.match(value)) - ) - - -def _decode_data_url(data_url: str) -> tuple[bytes, str]: - header, _, payload = data_url.partition(",") - if ";base64" not in header.lower(): - raise ValueError("unsupported data URL encoding") - mime_type = ( - header.split(":", 1)[1].split(";", 1)[0].strip() or "application/octet-stream" - ) - return base64.b64decode(payload), mime_type - - -def _guess_suffix_from_bytes(content: bytes) -> str: - for magic, suffix in _MAGIC_IMAGE_SUFFIXES: - if content.startswith(magic): - return suffix - if content.startswith(b"RIFF") and content[8:12] == b"WEBP": - return ".webp" - return ".bin" - - -def _guess_suffix(name: str, content: bytes, mime_type: str) -> str: - suffix = Path(name).suffix.lower() - if suffix: - return suffix - guessed_ext = mimetypes.guess_extension(mime_type or "") - if guessed_ext: - return guessed_ext.lower() - return _guess_suffix_from_bytes(content) - - -def _guess_mime_type(name: str, content: bytes) -> str: - guessed, _ = mimetypes.guess_type(name) - if guessed: - return guessed - suffix = _guess_suffix_from_bytes(content) - return _IMAGE_SUFFIX_TO_MIME.get(suffix, "application/octet-stream") - - -def _display_name_from_source(raw_source: str, fallback: str) -> str: - if not raw_source: - return fallback - if raw_source.startswith("file://"): - raw_source = raw_source[7:] - name = Path(unquote(urlsplit(raw_source).path)).name - return name or fallback - - -def _media_kind_from_value(value: str) -> str: - text = str(value or "").strip().lower() - if text in {"image", "file", "audio", "video", "record"}: - return text - return "file" - - -def _remote_reference_source_kind(source_kind: str) -> str: - cleaned = str(source_kind or "").strip() - if not cleaned: - return "remote_url_reference" - if cleaned.endswith("_reference"): - return cleaned - return f"{cleaned}_reference" - - -def _segment_text( - type_: str, data: Mapping[str, Any], ref: Mapping[str, str] | None -) -> str: - if type_ == "text": - return str(data.get("text", "") or "") - if type_ == "at": - qq = str(data.get("qq", "") or "").strip() - name = str(data.get("name") or data.get("nickname") or "").strip() - if qq and name: - return f"[@{qq}({name})]" - if qq: - return f"[@{qq}]" - return "[@]" - if type_ == "face": - return "[表情]" - if type_ == "reply": - reply_id = str(data.get("id") or data.get("message_id") or "").strip() - return f"[引用: {reply_id}]" if reply_id else "[引用]" - if type_ == "forward": - forward_id = str(data.get("id") or data.get("resid") or "").strip() - return f"[合并转发: {forward_id}]" if forward_id else "[合并转发]" - if ref is not None: - label = _MEDIA_LABELS.get( - str(ref.get("media_type") or ref.get("kind") or type_).strip(), "附件" - ) - uid = str(ref.get("uid", "") or "").strip() - name = str(ref.get("display_name", "") or "").strip() - if uid and name: - return f"[{label} uid={uid} name={name}]" - if uid: - return f"[{label} uid={uid}]" - label = _MEDIA_LABELS.get(type_, "附件") - raw = str(data.get("file") or data.get("url") or data.get("id") or "").strip() - return f"[{label}: {raw}]" if raw else f"[{label}]" - - -def _resolve_webui_file_id(file_id: str) -> Path | None: - if not file_id or not file_id.isalnum(): - return None - file_dir = (Path.cwd() / WEBUI_FILE_CACHE_DIR / file_id).resolve() - cache_root = (Path.cwd() / WEBUI_FILE_CACHE_DIR).resolve() - if cache_root not in file_dir.parents and file_dir != cache_root: - return None - if not file_dir.is_dir(): - return None - try: - files = list(file_dir.iterdir()) - except OSError: - return None - for candidate in files: - if candidate.is_file(): - return candidate - return None - - -def _extract_forward_id(data: Mapping[str, Any]) -> str: - forward_id = data.get("id") or data.get("resid") or data.get("message_id") - return str(forward_id).strip() if forward_id is not None else "" - - -def _segment_data_from_onebot_data( - data: Mapping[str, Any], - *, - exclude_keys: set[str] | None = None, -) -> dict[str, str]: - excluded = {key.strip().lower() for key in (exclude_keys or set()) if key.strip()} - normalized: dict[str, str] = {} - for raw_key, raw_value in data.items(): - key = str(raw_key or "").strip() - if not key: - continue - if key.lower() in excluded: - continue - text = str(raw_value or "").strip() - if not text: - continue - normalized[key] = text - return normalized - - -def _normalize_message_segments(message: Any) -> list[Mapping[str, Any]]: - if isinstance(message, list): - normalized: list[Mapping[str, Any]] = [] - for item in message: - if isinstance(item, Mapping): - normalized.append(item) - elif isinstance(item, str): - normalized.append({"type": "text", "data": {"text": item}}) - return normalized - if isinstance(message, Mapping): - return [message] - if isinstance(message, str): - return [{"type": "text", "data": {"text": message}}] - return [] - - -def _normalize_forward_nodes(raw_nodes: Any) -> list[Mapping[str, Any]]: - if isinstance(raw_nodes, list): - return [node for node in raw_nodes if isinstance(node, Mapping)] - if isinstance(raw_nodes, Mapping): - messages = raw_nodes.get("messages") - if isinstance(messages, list): - return [node for node in messages if isinstance(node, Mapping)] - return [] - - -class AttachmentRegistry: - """Persistent attachment registry scoped by conversation.""" - - def __init__( - self, - *, - registry_path: Path = ATTACHMENT_REGISTRY_FILE, - cache_dir: Path = ATTACHMENT_CACHE_DIR, - http_client: httpx.AsyncClient | None = None, - max_records: int = _ATTACHMENT_REGISTRY_MAX_RECORDS, - max_age_seconds: int = _ATTACHMENT_CACHE_MAX_AGE_SECONDS, - max_cache_bytes: int = _ATTACHMENT_CACHE_MAX_BYTES, - url_reference_max_records: int = _ATTACHMENT_URL_REFERENCE_MAX_RECORDS, - url_max_length: int = _ATTACHMENT_URL_MAX_LENGTH, - remote_download_max_bytes: int = _DEFAULT_REMOTE_DOWNLOAD_MAX_BYTES, - ) -> None: - self._registry_path = registry_path - self._cache_dir = cache_dir - self._http_client = http_client - self._max_records = max(0, int(max_records)) - self._max_age_seconds = max(0, int(max_age_seconds)) - self._max_cache_bytes = max(0, int(max_cache_bytes)) - self._url_reference_max_records = max(0, int(url_reference_max_records)) - self._url_max_length = max(0, int(url_max_length)) - self._remote_download_max_bytes = max(0, int(remote_download_max_bytes)) - self._lock = asyncio.Lock() - self._records: dict[str, AttachmentRecord] = {} - self._loaded = False - self._load_task: asyncio.Task[None] | None = None - self._global_image_resolver: Callable[[str], AttachmentRecord | None] | None = ( - None - ) - self._global_image_resolver_async: ( - Callable[[str], Awaitable[AttachmentRecord | None]] | None - ) = None - - def set_remote_download_max_bytes(self, value: int) -> None: - self._remote_download_max_bytes = max(0, int(value)) - - def set_limits( - self, - *, - remote_download_max_bytes: int | None = None, - max_cache_bytes: int | None = None, - max_records: int | None = None, - max_age_seconds: int | None = None, - url_reference_max_records: int | None = None, - url_max_length: int | None = None, - ) -> None: - if remote_download_max_bytes is not None: - self._remote_download_max_bytes = max(0, int(remote_download_max_bytes)) - if max_cache_bytes is not None: - self._max_cache_bytes = max(0, int(max_cache_bytes)) - if max_records is not None: - self._max_records = max(0, int(max_records)) - if max_age_seconds is not None: - self._max_age_seconds = max(0, int(max_age_seconds)) - if url_reference_max_records is not None: - self._url_reference_max_records = max(0, int(url_reference_max_records)) - if url_max_length is not None: - self._url_max_length = max(0, int(url_max_length)) - - def set_global_image_resolver( - self, - resolver: Callable[[str], AttachmentRecord | None] | None, - ) -> None: - self._global_image_resolver = resolver - - def set_global_image_resolver_async( - self, - resolver: Callable[[str], Awaitable[AttachmentRecord | None]] | None, - ) -> None: - self._global_image_resolver_async = resolver - - def _resolve_managed_cache_path(self, raw_path: str | None) -> Path | None: - text = str(raw_path or "").strip() - if not text: - return None - try: - path = Path(text).expanduser().resolve() - cache_root = self._cache_dir.resolve() - except Exception: - return None - if path == cache_root or cache_root not in path.parents: - return None - return path - - def _normalized_url_ref(self, value: str) -> str: - text = str(value or "").strip() - if not _is_http_url(text): - return "" - if self._url_max_length > 0 and len(text) > self._url_max_length: - return "" - return text - - def _record_with_local_path( - self, record: AttachmentRecord, local_path: str | None - ) -> AttachmentRecord: - return replace( - record, - local_path=local_path, - source_kind=_remote_reference_source_kind(record.source_kind) - if local_path is None and _is_http_url(record.source_ref) - else record.source_kind, - ) - - def _remove_cached_content( - self, - record: AttachmentRecord, - cache_path: Path | None, - removable_paths: set[Path], - ) -> AttachmentRecord | None: - source_ref = self._normalized_url_ref(record.source_ref) - if source_ref: - if cache_path is not None: - removable_paths.add(cache_path) - return self._record_with_local_path(record, None) - if cache_path is not None: - removable_paths.add(cache_path) - return None - - def _prune_records(self) -> bool: - dirty = False - now = time.time() - retained: list[tuple[str, AttachmentRecord, Path | None, float, int]] = [] - removable_paths: set[Path] = set() - - for uid, record in self._records.items(): - cache_path = self._resolve_managed_cache_path(record.local_path) - if record.local_path is None: - has_url_ref = bool(self._normalized_url_ref(record.source_ref)) - if _is_http_url(record.source_ref) and not has_url_ref: - dirty = True - continue - try: - mtime = datetime.fromisoformat(record.created_at).timestamp() - except ValueError: - mtime = now - if ( - not has_url_ref - and self._max_age_seconds > 0 - and now - mtime > self._max_age_seconds - ): - dirty = True - continue - retained.append((uid, record, None, mtime, 0)) - continue - if cache_path is None: - replacement = self._remove_cached_content(record, None, removable_paths) - if replacement is not None: - retained.append((uid, replacement, None, now, 0)) - dirty = True - continue - try: - stat_result = cache_path.stat() - mtime = float(stat_result.st_mtime) - size = int(stat_result.st_size) - except OSError: - replacement = self._remove_cached_content( - record, cache_path, removable_paths - ) - if replacement is not None: - retained.append((uid, replacement, None, now, 0)) - dirty = True - continue - if not cache_path.is_file(): - replacement = self._remove_cached_content( - record, cache_path, removable_paths - ) - if replacement is not None: - retained.append((uid, replacement, None, mtime, 0)) - dirty = True - continue - if self._max_age_seconds > 0 and now - mtime > self._max_age_seconds: - replacement = self._remove_cached_content( - record, cache_path, removable_paths - ) - if replacement is not None: - retained.append((uid, replacement, None, mtime, 0)) - dirty = True - continue - retained.append((uid, record, cache_path, mtime, size)) - - if self._max_records > 0 and len(retained) > self._max_records: - retained.sort(key=lambda item: item[3]) - overflow = len(retained) - self._max_records - for _uid, _record, cache_path, _mtime, _size in retained[:overflow]: - if cache_path is not None: - removable_paths.add(cache_path) - retained = retained[overflow:] - dirty = True - - if self._max_cache_bytes > 0: - cache_total = sum( - size - for _uid, _record, path, _mtime, size in retained - if path is not None - ) - if cache_total > self._max_cache_bytes: - reduced: list[ - tuple[str, AttachmentRecord, Path | None, float, int] - ] = [] - for uid, record, cache_path, mtime, size in sorted( - retained, key=lambda item: item[3] - ): - if cache_path is not None and cache_total > self._max_cache_bytes: - replacement = self._remove_cached_content( - record, cache_path, removable_paths - ) - if replacement is not None: - reduced.append((uid, replacement, None, mtime, 0)) - cache_total -= size - dirty = True - else: - reduced.append((uid, record, cache_path, mtime, size)) - retained = reduced - - if self._url_reference_max_records > 0: - url_refs = [ - item - for item in retained - if item[2] is None and _is_http_url(item[1].source_ref) - ] - if len(url_refs) > self._url_reference_max_records: - url_ref_ids = { - uid - for uid, _record, _path, _mtime, _size in sorted( - url_refs, key=lambda item: item[3] - )[: len(url_refs) - self._url_reference_max_records] - } - retained = [item for item in retained if item[0] not in url_ref_ids] - dirty = True - - retained_records = { - uid: record for uid, record, _path, _mtime, _size in retained - } - retained_paths = { - path.resolve() - for _uid, _record, path, _mtime, _size in retained - if path is not None and path.exists() - } - - for path in removable_paths: - try: - resolved = path.resolve() - except Exception: - resolved = path - if resolved in retained_paths: - continue - try: - path.unlink(missing_ok=True) - dirty = True - except OSError: - continue - - if self._cache_dir.exists(): - for item in self._cache_dir.iterdir(): - if not item.is_file(): - continue - try: - resolved = item.resolve() - except Exception: - resolved = item - if resolved in retained_paths: - continue - try: - item.unlink() - dirty = True - except OSError: - continue - - if dirty: - self._records = retained_records - return dirty - - def _load_records_from_payload(self, raw: Any) -> dict[str, AttachmentRecord]: - if not isinstance(raw, dict): - return {} - loaded: dict[str, AttachmentRecord] = {} - for uid, item in raw.items(): - if not isinstance(item, dict): - continue - try: - loaded[str(uid)] = AttachmentRecord( - uid=str(item.get("uid") or uid), - scope_key=str(item.get("scope_key", "") or ""), - kind=_media_kind_from_value(item.get("kind", "file")), - media_type=_media_kind_from_value( - item.get("media_type") or item.get("kind") or "file" - ), - display_name=str(item.get("display_name", "") or ""), - source_kind=str(item.get("source_kind", "") or ""), - source_ref=str(item.get("source_ref", "") or ""), - local_path=str(item.get("local_path", "") or "") or None, - mime_type=str( - item.get("mime_type", "") or "application/octet-stream" - ), - sha256=str(item.get("sha256", "") or ""), - created_at=str(item.get("created_at", "") or ""), - segment_data={ - str(k): str(v) - for k, v in dict(item.get("segment_data") or {}).items() - if str(k).strip() and str(v).strip() - }, - semantic_kind=str(item.get("semantic_kind", "") or ""), - description=str(item.get("description", "") or ""), - ) - except Exception: - continue - return loaded - - async def _load_from_disk_async(self) -> None: - try: - raw = await io.read_json(self._registry_path, use_lock=False) - except Exception as exc: - logger.warning("[AttachmentRegistry] 读取失败: %s", exc) - self._loaded = True - return - self._records = self._load_records_from_payload(raw) - dirty = self._prune_records() - if dirty: - await self._persist() - self._loaded = True - - async def load(self) -> None: - """等待注册表完成初始加载。""" - if self._loaded: - return - if self._load_task is None: - self._load_task = asyncio.create_task(self._load_from_disk_async()) - await self._load_task - - async def _persist(self) -> None: - payload = {uid: asdict(record) for uid, record in self._records.items()} - await io.write_json(self._registry_path, payload, use_lock=True) - - async def flush(self) -> None: - """将当前注册表状态强制落盘。""" - await self.load() - async with self._lock: - await self._persist() - - def get(self, uid: str) -> AttachmentRecord | None: - return self._records.get(str(uid).strip()) - - def resolve(self, uid: str, scope_key: str | None) -> AttachmentRecord | None: - record = self.get(uid) - if record is not None: - if record.scope_key and scope_key and record.scope_key != scope_key: - return None - return record - if self._global_image_resolver is not None: - try: - record = self._global_image_resolver(uid) - except Exception: - logger.exception( - "[AttachmentRegistry] global image resolver failed: uid=%s", uid - ) - record = None - if record is None: - return None - if record.scope_key and scope_key and record.scope_key != scope_key: - return None - return record - - async def resolve_async( - self, uid: str, scope_key: str | None - ) -> AttachmentRecord | None: - record = self.get(uid) - if record is not None: - if record.scope_key and scope_key and record.scope_key != scope_key: - return None - return record - if self._global_image_resolver_async is not None: - try: - record = await self._global_image_resolver_async(uid) - except Exception: - logger.exception( - "[AttachmentRegistry] async global image resolver failed: uid=%s", - uid, - ) - record = None - elif self._global_image_resolver is not None: - try: - record = self._global_image_resolver(uid) - except Exception: - logger.exception( - "[AttachmentRegistry] global image resolver failed: uid=%s", uid - ) - record = None - else: - record = None - if record is None: - return None - if record.scope_key and scope_key and record.scope_key != scope_key: - return None - return record - - def resolve_for_context( - self, - uid: str, - context: Mapping[str, Any] | None, - ) -> AttachmentRecord | None: - return self.resolve(uid, scope_from_context(context)) - - async def get_url_by_uid(self, uid: str) -> str | None: - """通过附件 UID 获取 source_ref(URL)。""" - await self.load() - record = self.get(uid) - if record is None or not record.source_ref.strip(): - return None - return record.source_ref.strip() - - async def get_uid_by_url(self, url: str) -> str | None: - """通过 URL 查找对应的附件 UID。""" - await self.load() - url = url.strip() - if not url: - return None - for record in self._records.values(): - if record.source_ref.strip() == url: - return record.uid - return None - - def _build_uid(self, prefix: str) -> str: - from uuid import uuid4 - - while True: - uid = f"{prefix}_{uuid4().hex[:8]}" - if uid not in self._records: - return uid - - def _find_by_sha256( - self, scope_key: str, sha256: str, kind: str - ) -> AttachmentRecord | None: - """Find an existing record with matching scope, kind, and SHA-256. - - Only returns a record whose *local_path* still exists on disk. - Must be called while ``self._lock`` is held. - """ - for record in self._records.values(): - if ( - record.scope_key == scope_key - and record.sha256 == sha256 - and record.kind == kind - and record.local_path - and Path(record.local_path).is_file() - ): - return record - return None - - async def register_bytes( - self, - scope_key: str, - content: bytes, - *, - kind: str, - display_name: str, - source_kind: str, - source_ref: str = "", - mime_type: str | None = None, - segment_data: Mapping[str, str] | None = None, - ) -> AttachmentRecord: - await self.load() - normalized_kind = _media_kind_from_value(kind) - normalized_media_type = ( - "image" if normalized_kind == "image" else normalized_kind - ) - normalized_mime = mime_type or _guess_mime_type(display_name, content) - suffix = _guess_suffix(display_name, content, normalized_mime) - prefix = "pic" if normalized_media_type == "image" else "file" - - async with self._lock: - digest = await asyncio.to_thread(hashlib.sha256, content) - digest_hex = digest.hexdigest() - - existing = self._find_by_sha256(scope_key, digest_hex, normalized_kind) - if existing is not None: - return existing - - uid = self._build_uid(prefix) - file_name = f"{uid}{suffix}" - cache_path = ensure_dir(self._cache_dir) / file_name - await asyncio.to_thread(cache_path.write_bytes, content) - - record = AttachmentRecord( - uid=uid, - scope_key=scope_key, - kind=normalized_kind, - media_type=normalized_media_type, - display_name=display_name or file_name, - source_kind=source_kind, - source_ref=source_ref, - local_path=str(cache_path), - mime_type=normalized_mime, - sha256=digest_hex, - created_at=_now_iso(), - segment_data={ - str(k): str(v) - for k, v in dict(segment_data or {}).items() - if str(k).strip() and str(v).strip() - }, - ) - self._records[uid] = record - self._prune_records() - await self._persist() - return self._records.get(uid, record) - - async def register_local_file( - self, - scope_key: str, - local_path: str | Path, - *, - kind: str, - display_name: str | None = None, - source_kind: str = "local_file", - source_ref: str = "", - segment_data: Mapping[str, str] | None = None, - ) -> AttachmentRecord: - path = Path(str(local_path)).expanduser() - if not path.is_absolute(): - path = (Path.cwd() / path).resolve() - else: - path = path.resolve() - if not path.is_file(): - raise FileNotFoundError(path) - - def _read() -> bytes: - return path.read_bytes() - - content = await asyncio.to_thread(_read) - return await self.register_bytes( - scope_key, - content, - kind=kind, - display_name=display_name or path.name, - source_kind=source_kind, - source_ref=source_ref or str(path), - mime_type=mimetypes.guess_type(path.name)[0] or None, - segment_data=segment_data, - ) - - async def register_data_url( - self, - scope_key: str, - data_url: str, - *, - kind: str, - display_name: str, - source_kind: str, - source_ref: str = "", - segment_data: Mapping[str, str] | None = None, - ) -> AttachmentRecord: - content, mime_type = _decode_data_url(data_url) - return await self.register_bytes( - scope_key, - content, - kind=kind, - display_name=display_name, - source_kind=source_kind, - source_ref=source_ref, - mime_type=mime_type, - segment_data=segment_data, - ) - - async def register_remote_url( - self, - scope_key: str, - url: str, - *, - kind: str, - display_name: str | None = None, - source_kind: str = "remote_url", - source_ref: str = "", - segment_data: Mapping[str, str] | None = None, - ) -> AttachmentRecord: - name = display_name or _display_name_from_source(url, "attachment.bin") - return await self._register_remote_url_or_reference( - scope_key, - url, - kind=kind, - display_name=name, - source_kind=source_kind, - source_ref=source_ref or url, - segment_data=segment_data, - ) - - async def register_remote_reference( - self, - scope_key: str, - url: str, - *, - kind: str, - display_name: str | None = None, - source_kind: str = "remote_url_reference", - source_ref: str = "", - mime_type: str | None = None, - segment_data: Mapping[str, str] | None = None, - description: str = "", - ) -> AttachmentRecord: - await self.load() - if not self._normalized_url_ref(url): - raise ValueError("远程附件 URL 为空或超过长度上限") - normalized_kind = _media_kind_from_value(kind) - normalized_media_type = ( - "image" if normalized_kind == "image" else normalized_kind - ) - prefix = "pic" if normalized_media_type == "image" else "file" - ref = url - normalized_segment_data = dict(segment_data or {}) - if source_ref and source_ref != url: - normalized_segment_data.setdefault("original_source_ref", source_ref) - name = display_name or _display_name_from_source(url, "attachment.bin") - digest_hex = hashlib.sha256(ref.encode("utf-8")).hexdigest() - - async with self._lock: - for existing in self._records.values(): - if ( - existing.scope_key == scope_key - and existing.kind == normalized_kind - and existing.local_path is None - and existing.source_ref == ref - ): - return existing - - uid = self._build_uid(prefix) - record = AttachmentRecord( - uid=uid, - scope_key=scope_key, - kind=normalized_kind, - media_type=normalized_media_type, - display_name=name, - source_kind=source_kind, - source_ref=ref, - local_path=None, - mime_type=mime_type or mimetypes.guess_type(name)[0] or "", - sha256=digest_hex, - created_at=_now_iso(), - segment_data={ - str(k): str(v) - for k, v in normalized_segment_data.items() - if str(k).strip() and str(v).strip() - }, - description=description, - ) - self._records[uid] = record - self._prune_records() - await self._persist() - return self._records.get(uid, record) - - async def _register_remote_url_or_reference( - self, - scope_key: str, - url: str, - *, - kind: str, - display_name: str, - source_kind: str, - source_ref: str, - segment_data: Mapping[str, str] | None, - ) -> AttachmentRecord: - if not self._normalized_url_ref(url): - raise ValueError("远程附件 URL 为空或超过长度上限") - timeout = httpx.Timeout(_DEFAULT_REMOTE_TIMEOUT_SECONDS) - max_bytes = self._remote_download_max_bytes - reference_segment_data = dict(segment_data or {}) - if source_ref and source_ref != url: - reference_segment_data.setdefault("original_source_ref", source_ref) - if max_bytes <= 0: - return await self.register_remote_reference( - scope_key, - url, - kind=kind, - display_name=display_name, - source_kind=_remote_reference_source_kind(source_kind), - source_ref=url, - segment_data=reference_segment_data, - description="远程附件未下载:remote_download_max_size_mb=0", - ) - - async def _stream(client: httpx.AsyncClient) -> tuple[bytes, str]: - async with client.stream( - "GET", url, timeout=timeout, follow_redirects=True - ) as response: - response.raise_for_status() - mime_type = ( - response.headers.get("content-type", "").split(";", 1)[0].strip() - ) - raw_length = response.headers.get("content-length", "").strip() - if raw_length.isdigit() and int(raw_length) > max_bytes: - raise _RemoteAttachmentTooLarge(mime_type) - - chunks: list[bytes] = [] - total = 0 - async for chunk in response.aiter_bytes(): - total += len(chunk) - if total > max_bytes: - raise _RemoteAttachmentTooLarge(mime_type) - chunks.append(chunk) - return b"".join(chunks), mime_type - - try: - if self._http_client is not None: - content, mime_type = await _stream(self._http_client) - else: - async with httpx.AsyncClient( - timeout=timeout, follow_redirects=True - ) as client: - content, mime_type = await _stream(client) - except _RemoteAttachmentTooLarge as exc: - return await self.register_remote_reference( - scope_key, - url, - kind=kind, - display_name=display_name, - source_kind=_remote_reference_source_kind(source_kind), - source_ref=url, - mime_type=exc.mime_type, - segment_data=reference_segment_data, - description=f"远程附件超过下载上限 {max_bytes} bytes,保留 URL 引用。", - ) - - return await self.register_bytes( - scope_key, - content, - kind=kind, - display_name=display_name, - source_kind=source_kind, - source_ref=url, - mime_type=mime_type or None, - segment_data=reference_segment_data, - ) - - async def ensure_local_file(self, record: AttachmentRecord) -> AttachmentRecord: - await self.load() - if record.local_path and Path(record.local_path).is_file(): - return record - source_ref = self._normalized_url_ref(record.source_ref) - if not source_ref: - return record - existing_uids = set(self._records) - refreshed = await self._register_remote_url_or_reference( - record.scope_key, - source_ref, - kind=record.kind, - display_name=record.display_name, - source_kind=record.source_kind, - source_ref=source_ref, - segment_data=record.segment_data, - ) - if refreshed.local_path is None: - return refreshed - async with self._lock: - current = self._records.get(record.uid) - if current is None: - return refreshed - updated = replace( - current, - local_path=refreshed.local_path, - mime_type=refreshed.mime_type, - sha256=refreshed.sha256, - source_kind=refreshed.source_kind, - segment_data=refreshed.segment_data, - ) - self._records[record.uid] = updated - if refreshed.uid != record.uid and refreshed.uid not in existing_uids: - self._records.pop(refreshed.uid, None) - self._prune_records() - await self._persist() - return self._records.get(record.uid, updated) - - -async def register_message_attachments( - *, - registry: AttachmentRegistry | None, - segments: Sequence[Mapping[str, Any]], - scope_key: str | None, - resolve_image_url: Callable[[str], Awaitable[str | None]] | None = None, - get_forward_messages: Callable[[str], Awaitable[list[dict[str, Any]]]] - | None = None, -) -> RegisteredMessageAttachments: - attachments: list[dict[str, str]] = [] - normalized_parts: list[str] = [] - if registry is None or not scope_key: - for segment in segments: - type_ = str(segment.get("type", "") or "") - raw_data = segment.get("data", {}) - data = raw_data if isinstance(raw_data, Mapping) else {} - normalized_parts.append(_segment_text(type_, data, None)) - return RegisteredMessageAttachments( - attachments=[], - normalized_text="".join(normalized_parts).strip(), - ) - - visited_forward_ids: set[str] = set() - - async def _collect_from_segments( - current_segments: Sequence[Mapping[str, Any]], - *, - depth: int, - prefix: str, - ) -> None: - for index, segment in enumerate(current_segments): - type_ = str(segment.get("type", "") or "").strip().lower() - raw_data = segment.get("data", {}) - data = raw_data if isinstance(raw_data, Mapping) else {} - ref: dict[str, str] | None = None - - try: - if type_ == "image": - raw_source = str(data.get("file") or data.get("url") or "").strip() - display_name = _display_name_from_source( - raw_source, - f"image_{index + 1}.png", - ) - if raw_source.startswith("base64://"): - payload = raw_source[len("base64://") :].strip() - content = base64.b64decode(payload) - record = await registry.register_bytes( - scope_key, - content, - kind="image", - display_name=display_name, - source_kind="base64_image", - source_ref=f"{prefix}segment:{index}", - segment_data=_segment_data_from_onebot_data( - data, - exclude_keys={"file", "url"}, - ), - ) - ref = record.prompt_ref() - elif _is_data_url(raw_source): - record = await registry.register_data_url( - scope_key, - raw_source, - kind="image", - display_name=display_name, - source_kind="data_url_image", - source_ref=f"{prefix}segment:{index}", - segment_data=_segment_data_from_onebot_data( - data, - exclude_keys={"file", "url"}, - ), - ) - ref = record.prompt_ref() - else: - resolved_source = raw_source - if raw_source and resolve_image_url is not None: - try: - resolved = await resolve_image_url(raw_source) - except Exception as exc: - logger.debug( - "[AttachmentRegistry] image resolver failed: file=%s err=%s", - raw_source, - exc, - ) - resolved = None - if resolved: - resolved_source = str(resolved) - - if _is_http_url(resolved_source): - record = await registry.register_remote_url( - scope_key, - resolved_source, - kind="image", - display_name=display_name, - source_kind="remote_image", - source_ref=raw_source or resolved_source, - segment_data=_segment_data_from_onebot_data( - data, - exclude_keys={"file", "url"}, - ), - ) - ref = record.prompt_ref() - elif _is_localish_path(resolved_source): - local_path = ( - resolved_source[7:] - if resolved_source.startswith("file://") - else resolved_source - ) - record = await registry.register_local_file( - scope_key, - local_path, - kind="image", - display_name=display_name, - source_kind="local_image", - source_ref=raw_source or resolved_source, - segment_data=_segment_data_from_onebot_data( - data, - exclude_keys={"file", "url"}, - ), - ) - ref = record.prompt_ref() - - elif type_ == "file": - file_id = str(data.get("id", "") or "").strip() - raw_source = str(data.get("file") or data.get("url") or "").strip() - local_file_path: Path | None = None - if file_id: - local_file_path = _resolve_webui_file_id(file_id) - elif _is_localish_path(raw_source): - local_file_path = Path( - raw_source[7:] - if raw_source.startswith("file://") - else raw_source - ) - display_name = ( - str(data.get("name", "") or "").strip() - or (local_file_path.name if local_file_path is not None else "") - or _display_name_from_source( - raw_source, f"file_{index + 1}.bin" - ) - ) - if local_file_path is not None and local_file_path.is_file(): - record = await registry.register_local_file( - scope_key, - local_file_path, - kind="file", - display_name=display_name, - source_kind="webui_file" if file_id else "local_file", - source_ref=file_id or raw_source or str(local_file_path), - segment_data=_segment_data_from_onebot_data( - data, - exclude_keys={"file", "url"}, - ), - ) - ref = record.prompt_ref() - elif _is_http_url(raw_source): - record = await registry.register_remote_url( - scope_key, - raw_source, - kind="file", - display_name=display_name, - source_kind="remote_file", - source_ref=file_id or raw_source, - segment_data=_segment_data_from_onebot_data( - data, - exclude_keys={"file", "url"}, - ), - ) - ref = record.prompt_ref() - - elif ( - type_ == "forward" - and get_forward_messages is not None - and depth < _FORWARD_ATTACHMENT_MAX_DEPTH - ): - forward_id = _extract_forward_id(data) - if forward_id and forward_id not in visited_forward_ids: - visited_forward_ids.add(forward_id) - try: - nodes = _normalize_forward_nodes( - await get_forward_messages(forward_id) - ) - except Exception as exc: - logger.debug( - "[AttachmentRegistry] forward resolver failed: id=%s err=%s", - forward_id, - exc, - ) - nodes = [] - for node_index, node in enumerate(nodes): - raw_message = ( - node.get("content") - or node.get("message") - or node.get("raw_message") - ) - nested_segments = _normalize_message_segments(raw_message) - if not nested_segments: - continue - await _collect_from_segments( - nested_segments, - depth=depth + 1, - prefix=f"{prefix}forward:{forward_id}:{node_index}:", - ) - except ( - binascii.Error, - ValueError, - FileNotFoundError, - httpx.HTTPError, - ) as exc: - logger.warning( - "[AttachmentRegistry] segment registration skipped: type=%s index=%s err=%s", - type_, - index, - exc, - ) - except Exception as exc: - logger.exception( - "[AttachmentRegistry] unexpected segment registration failure: type=%s index=%s err=%s", - type_, - index, - exc, - ) - - if ref is not None: - attachments.append(ref) - if depth == 0: - normalized_parts.append(_segment_text(type_, data, ref)) - - await _collect_from_segments(segments, depth=0, prefix="") - - return RegisteredMessageAttachments( - attachments=attachments, - normalized_text="".join(normalized_parts).strip(), - ) - - -async def render_message_with_attachments( - message: str, - *, - registry: AttachmentRegistry | None, - scope_key: str | None, - strict: bool, -) -> RenderedRichMessage: - """Render ```` and ```` tags into delivery/history text. - - * ```` — backward-compatible, image-only. - * ```` — unified tag for any media type. - Images (``pic_*``) are inlined as CQ images; files (``file_*``) - are collected into *pending_file_sends* for later dispatch. - """ - has_tags = message and ( - " tag: strictly image-only - if tag_name == "pic" and record.media_type != "image": - replacement = f"[图片 uid={uid} 类型错误]" - if strict: - raise AttachmentRenderError(f"UID 不是图片,不能用于 :{uid}") - delivery_parts.append(replacement) - history_parts.append(replacement) - continue - - # Route by media type - if record.media_type == "image": - ok = _render_image_tag(record, uid, strict, delivery_parts, history_parts) - else: - ok = _render_file_tag( - record, - uid, - strict, - delivery_parts, - history_parts, - pending_files, - ) - - if ok: - attachments.append(record.prompt_ref()) - - delivery_parts.append(message[last_index:]) - history_parts.append(message[last_index:]) - return RenderedRichMessage( - delivery_text="".join(delivery_parts), - history_text="".join(history_parts), - attachments=attachments, - pending_file_sends=tuple(pending_files), - ) - - -def _render_image_tag( - record: AttachmentRecord, - uid: str, - strict: bool, - delivery_parts: list[str], - history_parts: list[str], -) -> bool: - """Render an image attachment as an inline CQ:image. Returns True on success.""" - image_source = record.source_ref - if record.local_path: - image_source = Path(record.local_path).resolve().as_uri() - elif not image_source: - replacement = f"[图片 uid={uid} 缺少文件]" - if strict: - raise AttachmentRenderError(f"图片 UID 缺少可发送的文件:{uid}") - delivery_parts.append(replacement) - history_parts.append(replacement) - return False - - cq_args = [f"file={image_source}"] - for key, value in dict(getattr(record, "segment_data", {}) or {}).items(): - cleaned_key = str(key or "").strip() - cleaned_value = str(value or "").strip() - if ( - not cleaned_key - or not cleaned_value - or cleaned_key in {"file", "original_source_ref"} - ): - continue - cq_args.append( - f"{_escape_cq_component(cleaned_key)}={_escape_cq_component(cleaned_value)}" - ) - delivery_parts.append(f"[CQ:image,{','.join(cq_args)}]") - if record.display_name: - history_parts.append(f"[图片 uid={uid} name={record.display_name}]") - else: - history_parts.append(f"[图片 uid={uid}]") - return True - - -def _render_file_tag( - record: AttachmentRecord, - uid: str, - strict: bool, - delivery_parts: list[str], - history_parts: list[str], - pending_files: list[AttachmentRecord], -) -> bool: - """Render a non-image attachment as a pending file send. Returns True on success.""" - if not record.local_path or not Path(record.local_path).is_file(): - if _is_http_url(record.source_ref): - name_part = f" name={record.display_name}" if record.display_name else "" - history_parts.append(f"[文件 uid={uid}{name_part}]") - pending_files.append(record) - return True - replacement = f"[文件 uid={uid} 缺少本地文件]" - if strict: - raise AttachmentRenderError(f"文件 UID 缺少本地文件,无法发送:{uid}") - delivery_parts.append(replacement) - history_parts.append(replacement) - return False - - # Remove from delivery text (file sent separately) - # Keep a readable placeholder in history - name_part = f" name={record.display_name}" if record.display_name else "" - history_parts.append(f"[文件 uid={uid}{name_part}]") - pending_files.append(record) - return True - - -# Backward-compatible alias -render_message_with_pic_placeholders = render_message_with_attachments - - -async def dispatch_pending_file_sends( - rendered: RenderedRichMessage, - *, - sender: Any, - target_type: str, - target_id: int, - registry: AttachmentRegistry | None = None, -) -> None: - """Send pending file attachments collected by *render_message_with_attachments*. - - This is best-effort: each file send failure is logged but does not interrupt - the remaining sends or the caller. - """ - if not rendered.pending_file_sends or sender is None: - return - for record in rendered.pending_file_sends: - send_record = record - if ( - not send_record.local_path or not Path(send_record.local_path).is_file() - ) and registry is not None: - try: - send_record = await registry.ensure_local_file(send_record) - except Exception: - logger.warning( - "[文件发送] 回源下载失败 uid=%s source=%s", - send_record.uid, - send_record.source_ref, - exc_info=True, - ) - if not send_record.local_path or not Path(send_record.local_path).is_file(): - logger.warning( - "[文件发送] 跳过:本地文件缺失 uid=%s path=%s", - send_record.uid, - send_record.local_path, - ) - continue - try: - if target_type == "group": - await sender.send_group_file( - target_id, - send_record.local_path, - name=send_record.display_name or None, - ) - else: - await sender.send_private_file( - target_id, - send_record.local_path, - name=send_record.display_name or None, - ) - except Exception: - logger.warning( - "[文件发送] 发送失败(最佳努力) uid=%s target=%s:%s", - send_record.uid, - target_type, - target_id, - exc_info=True, - ) diff --git a/src/Undefined/cognitive/historian.py b/src/Undefined/cognitive/historian.py deleted file mode 100644 index 8692452d..00000000 --- a/src/Undefined/cognitive/historian.py +++ /dev/null @@ -1,1043 +0,0 @@ -"""后台史官 Worker,轮询队列处理任务。""" - -from __future__ import annotations - -import asyncio -import json -import logging -import re -from datetime import datetime, timezone -from typing import Any, Callable - -from Undefined.ai.transports.openai_transport import RESPONSES_OUTPUT_ITEMS_KEY -from Undefined.utils.tool_calls import extract_required_tool_call_arguments - -logger = logging.getLogger(__name__) - -_MAX_LOG_PREVIEW_LEN = 200 - -_REWRITE_TOOL = { - "type": "function", - "function": { - "name": "submit_rewrite", - "description": "提交绝对化改写后的事件文本", - "parameters": { - "type": "object", - "properties": { - "text": {"type": "string", "description": "改写后的纯文本"}, - }, - "required": ["text"], - }, - }, -} - -_READ_PROFILE_TOOL = { - "type": "function", - "function": { - "name": "read_profile", - "description": "读取指定实体的当前侧写内容", - "parameters": { - "type": "object", - "properties": { - "entity_type": { - "type": "string", - "enum": ["user", "group"], - "description": "实体类型:user 或 group", - }, - "entity_id": { - "type": "string", - "description": "实体 ID(用户 QQ 号或群号)", - }, - }, - "required": ["entity_type", "entity_id"], - }, - }, -} - -_PROFILE_TOOL = { - "type": "function", - "function": { - "name": "update_profile", - "description": "更新用户/群侧写。调用前必须先用 read_profile 查看当前内容", - "parameters": { - "type": "object", - "properties": { - "entity_type": { - "type": "string", - "enum": ["user", "group"], - "description": "实体类型:user 或 group", - }, - "entity_id": { - "type": "string", - "description": "实体 ID(用户 QQ 号或群号)", - }, - "skip": { - "type": "boolean", - "description": "是否跳过更新;当新信息不稳定/不足时为 true", - }, - "skip_reason": { - "type": "string", - "description": "跳过原因", - }, - "name": {"type": "string", "description": "用户/群名称"}, - "tags": { - "type": "array", - "items": {"type": "string"}, - "maxItems": 10, - "description": "身份级标签(角色/核心领域),最多 10 个,不写话题", - }, - "summary": {"type": "string", "description": "侧写正文(Markdown)"}, - }, - "required": ["entity_type", "entity_id", "skip", "name", "tags", "summary"], - }, - }, -} - - -def _preview_text(text: str, max_len: int = _MAX_LOG_PREVIEW_LEN) -> str: - compact = re.sub(r"\s+", " ", str(text or "")).strip() - if len(compact) <= max_len: - return compact - return f"{compact[:max_len]}..." - - -def _extract_frontmatter_name(markdown: str) -> str: - text = str(markdown or "") - if not text.startswith("---"): - return "" - try: - import yaml - - parts = text[3:].split("---", 1) - if len(parts) != 2: - return "" - frontmatter = yaml.safe_load(parts[0]) - if not isinstance(frontmatter, dict): - return "" - value = frontmatter.get("name") - return str(value).strip() if value is not None else "" - except Exception: - return "" - - -def _escape_braces(text: str) -> str: - value = str(text or "") - return value.replace("{", "{{").replace("}", "}}") - - -def _resolve_timestamp_epoch(job: dict[str, Any]) -> int: - raw_epoch = job.get("timestamp_epoch") - if isinstance(raw_epoch, (int, float)): - return int(raw_epoch) - if isinstance(raw_epoch, str): - try: - return int(float(raw_epoch.strip())) - except Exception: - pass - - for key in ("timestamp_utc", "timestamp_local"): - raw_value = job.get(key) - if not isinstance(raw_value, str): - continue - text = raw_value.strip() - if not text: - continue - try: - parsed = datetime.fromisoformat(text.replace("Z", "+00:00")) - if parsed.tzinfo is None: - parsed = parsed.replace(tzinfo=timezone.utc) - return int(parsed.timestamp()) - except Exception: - continue - - return int(datetime.now(timezone.utc).timestamp()) - - -def _coerce_bool(value: Any) -> bool: - if isinstance(value, bool): - return value - if isinstance(value, (int, float)): - return bool(value) - if isinstance(value, str): - normalized = value.strip().lower() - if normalized in {"1", "true", "yes", "y", "on"}: - return True - if normalized in {"0", "false", "no", "n", "off", ""}: - return False - return False - - -class HistorianWorker: - def __init__( - self, - job_queue: Any, - vector_store: Any, - profile_storage: Any, - ai_client: Any, - config_getter: Callable[[], Any], - model_config: Any = None, - ) -> None: - self._job_queue = job_queue - self._vector_store = vector_store - self._profile_storage = profile_storage - self._ai_client = ai_client - self._config_getter = config_getter - self._model_config = model_config - self._stop_event = asyncio.Event() - self._task: asyncio.Task[None] | None = None - self._inflight_tasks: set[asyncio.Task[None]] = set() - - async def _prepare_query_embedding(self, query_text: str) -> list[float] | None: - embed_query = getattr(self._vector_store, "embed_query", None) - if not callable(embed_query): - return None - try: - result = await embed_query(query_text) - except Exception as exc: - logger.warning("[史官] 预生成查询向量失败,回退即时计算: error=%s", exc) - return None - if not isinstance(result, list): - logger.warning("[史官] 预生成查询向量返回值非法,回退即时计算") - return None - normalized: list[float] = [] - for item in result: - try: - normalized.append(float(item)) - except (TypeError, ValueError): - logger.warning("[史官] 预生成查询向量包含非法元素,回退即时计算") - return None - return normalized - - async def start(self) -> None: - logger.info("[史官] Worker 启动中") - self._task = asyncio.create_task(self._poll_loop()) - logger.info("[史官] Worker 已启动") - - async def stop(self) -> None: - logger.info("[史官] Worker 停止中") - self._stop_event.set() - if self._task: - await self._task - logger.info("[史官] Worker 已停止") - - async def _poll_loop(self) -> None: - dispatch_count = 0 - logger.info("[史官] 轮询循环已开始") - while not self._stop_event.is_set(): - result = await self._job_queue.dequeue() - if result: - job_id, job = result - task = asyncio.create_task(self._process_job_with_retry(job_id, job)) - self._inflight_tasks.add(task) - task.add_done_callback(self._inflight_tasks.discard) - dispatch_count += 1 - logger.info( - "[史官] 任务已发车: job_id=%s inflight=%s", - job_id, - len(self._inflight_tasks), - ) - - config = self._config_getter() - if ( - config.failed_cleanup_interval > 0 - and dispatch_count > 0 - and dispatch_count % config.failed_cleanup_interval == 0 - ): - from Undefined.utils.cache import cleanup_cache_dir - - cleanup_cache_dir( - self._job_queue._failed_dir, - max_age_seconds=config.failed_max_age_days * 86400, - max_files=config.failed_max_files, - ) - logger.info( - "[史官] failed 队列清理已执行: interval=%s max_age_days=%s max_files=%s", - config.failed_cleanup_interval, - config.failed_max_age_days, - config.failed_max_files, - ) - - await asyncio.sleep(config.poll_interval_seconds) - - if self._inflight_tasks: - logger.info( - "[史官] 等待在途任务收敛: inflight=%s", len(self._inflight_tasks) - ) - await asyncio.gather(*list(self._inflight_tasks), return_exceptions=True) - logger.info("[史官] 轮询循环已结束") - - async def _process_job_with_retry(self, job_id: str, job: dict[str, Any]) -> None: - try: - await self._process_job(job_id, job) - except Exception as e: - retry_count = job.get("_retry_count", 0) - max_retries = self._config_getter().job_max_retries - if retry_count < max_retries: - logger.warning( - "[史官] 任务 %s 处理失败 (%s/%s),将自动重试: %s", - job_id, - retry_count + 1, - max_retries, - e, - ) - await self._job_queue.requeue(job_id, str(e)) - else: - logger.error( - "[史官] 任务 %s 达到最大重试次数 (%s),移入 failed: %s", - job_id, - max_retries, - e, - ) - await self._job_queue.fail(job_id, str(e)) - - async def _rewrite_and_validate(self, job: dict[str, Any], job_id: str) -> str: - """改写为绝对化事件文本。""" - canonical = await self._rewrite(job, job_id=job_id) - return canonical - - async def _process_job(self, job_id: str, job: dict[str, Any]) -> None: - logger.info( - "[史官] 开始处理任务 %s: user=%s group=%s sender=%s perspective=%s has_observations=%s profile_targets=%s", - job_id, - job.get("user_id", ""), - job.get("group_id", ""), - job.get("sender_id", ""), - job.get("perspective", ""), - job.get("has_observations", job.get("has_new_info", False)), - len(job.get("profile_targets", []) or []), - ) - - raw_observations = ( - job.get("observations") - if "observations" in job - else job.get("new_info", []) - ) - if isinstance(raw_observations, str): - observation_items = ( - [raw_observations.strip()] if raw_observations.strip() else [] - ) - elif isinstance(raw_observations, list): - observation_items = [ - str(s).strip() for s in raw_observations if str(s).strip() - ] - else: - observation_items = [] - - base_metadata: dict[str, Any] = { - "request_id": job.get("request_id", ""), - "end_seq": job.get("end_seq", 0), - "user_id": job.get("user_id", ""), - "group_id": job.get("group_id", ""), - "sender_id": job.get("sender_id", ""), - "request_type": job.get("request_type", ""), - "timestamp_utc": job.get("timestamp_utc", ""), - "timestamp_local": job.get("timestamp_local", ""), - "timestamp_epoch": _resolve_timestamp_epoch(job), - "timezone": job.get("timezone", ""), - "location_abs": job.get("location_abs", ""), - "message_ids": job.get("message_ids", []), - "perspective": str(job.get("perspective", "")).strip(), - "schema_version": job.get("schema_version", "final_v1"), - } - - canonicals: list[str] = [] - - if observation_items: - # 每条 observation 独立改写+入库 - for idx, info_item in enumerate(observation_items): - sub_job = {**job, "observations": info_item} - event_id = f"{job_id}_{idx}" if len(observation_items) > 1 else job_id - canonical = await self._rewrite_and_validate(sub_job, event_id) - meta = { - **base_metadata, - "has_observations": True, - } - await self._vector_store.upsert_event(event_id, canonical, meta) - canonicals.append(canonical) - logger.info( - "[史官] 任务 %s 事件入库完成(%s/%s): len=%s", - event_id, - idx + 1, - len(observation_items), - len(canonical), - ) - - has_obs = ( - job.get("has_observations") - if "has_observations" in job - else job.get("has_new_info", False) - ) - if has_obs and canonicals: - merged_canonical = "\n".join(canonicals) - await self._merge_profiles(job, merged_canonical, job_id) - - await self._job_queue.complete(job_id) - logger.info("[史官] 任务 %s 处理完成", job_id) - - def _extract_required_tool_args( - self, - response: dict[str, Any], - *, - expected_tool_name: str, - stage: str, - job_id: str, - attempt: int | None = None, - target: str | None = None, - ) -> dict[str, Any]: - suffix = f" stage={stage} expected_tool={expected_tool_name}" - if attempt is not None: - suffix += f" attempt={attempt}" - if target: - suffix += f" target={target}" - try: - return extract_required_tool_call_arguments( - response, - expected_tool_name=expected_tool_name, - stage=stage, - logger=logger, - error_context=f"job_id={job_id}{suffix}", - ) - except Exception as exc: - logger.error( - "[史官] 任务 %s 提取工具参数失败:%s err=%s", job_id, suffix, exc - ) - raise - - async def _rewrite( - self, - job: dict[str, Any], - *, - job_id: str = "", - ) -> str: - from Undefined.utils.resources import read_text_resource - - memo = str(job.get("memo") if "memo" in job else job.get("action_summary", "")) - observations = str( - job.get("observations") - if "observations" in job - else job.get("new_info", "") - ) - message_ids_raw = job.get("message_ids", []) - if isinstance(message_ids_raw, list): - message_ids = [ - str(item).strip() for item in message_ids_raw if str(item).strip() - ] - else: - message_ids = [] - profile_targets_raw = job.get("profile_targets", []) - profile_targets_text = "[]" - if isinstance(profile_targets_raw, list) and profile_targets_raw: - compact_targets: list[str] = [] - for target in profile_targets_raw: - if not isinstance(target, dict): - continue - entity_type = str(target.get("entity_type", "")).strip() - entity_id = str(target.get("entity_id", "")).strip() - perspective = str(target.get("perspective", "")).strip() - if not entity_type or not entity_id: - continue - if perspective: - compact_targets.append(f"{entity_type}:{entity_id}({perspective})") - else: - compact_targets.append(f"{entity_type}:{entity_id}") - if compact_targets: - profile_targets_text = ", ".join(compact_targets) - logger.debug( - "[史官] 任务 %s 发起绝对化改写: memo_len=%s observations_len=%s", - job_id or "unknown", - len(memo), - len(observations), - ) - - template = read_text_resource("res/prompts/historian_rewrite.md") - source_message = str(job.get("source_message", "")).strip() - recent_messages_raw = job.get("recent_messages", []) - recent_messages: list[str] = [] - if isinstance(recent_messages_raw, list): - recent_messages = [ - str(item).strip() for item in recent_messages_raw if str(item).strip() - ] - recent_messages_text = "\n---\n".join(recent_messages) - prompt = template.format( - request_id=job.get("request_id", ""), - end_seq=job.get("end_seq", 0), - timestamp_local=job.get("timestamp_local", ""), - timezone=job.get("timezone", "Asia/Shanghai"), - bot_name=job.get("bot_name", "Undefined"), - user_id=job.get("user_id", ""), - group_id=job.get("group_id", ""), - sender_id=job.get("sender_id", ""), - sender_name=job.get("sender_name", ""), - group_name=job.get("group_name", ""), - message_ids=", ".join(message_ids) if message_ids else "[]", - perspective=job.get("perspective", ""), - profile_targets=profile_targets_text, - force="true" if _coerce_bool(job.get("force", False)) else "false", - action_summary=memo, - new_info=observations, - memo=memo, - observations=observations, - source_message=source_message or "(无)", - recent_messages=recent_messages_text or "(无)", - ) - response = await self._ai_client.submit_background_llm_call( - model_config=self._model_config or self._ai_client.agent_config, - messages=[{"role": "user", "content": prompt}], - tools=[_REWRITE_TOOL], - tool_choice={"type": "function", "function": {"name": "submit_rewrite"}}, - call_type="historian_rewrite", - ) - args = self._extract_required_tool_args( - response=response, - expected_tool_name="submit_rewrite", - stage="historian_rewrite", - job_id=job_id or "unknown", - ) - - text = str(args.get("text", "")).strip() - logger.debug( - "[史官] 任务 %s 收到改写结果: len=%s preview=%s", - job_id or "unknown", - len(text), - _preview_text(text), - ) - return text - - def _resolve_profile_targets(self, job: dict[str, Any]) -> list[dict[str, str]]: - targets: list[dict[str, str]] = [] - seen: set[tuple[str, str]] = set() - raw_targets = job.get("profile_targets") - if isinstance(raw_targets, list): - for item in raw_targets: - if not isinstance(item, dict): - continue - entity_type = str(item.get("entity_type", "")).strip() - raw_entity_id = item.get("entity_id") - entity_id = ( - str(raw_entity_id).strip() if raw_entity_id is not None else "" - ) - if entity_type not in {"user", "group"} or not entity_id: - continue - key = (entity_type, entity_id) - if key in seen: - continue - seen.add(key) - targets.append( - { - "entity_type": entity_type, - "entity_id": entity_id, - "perspective": str(item.get("perspective", "")).strip(), - "preferred_name": str(item.get("preferred_name", "")).strip(), - } - ) - if targets: - return targets - - entity_type = "group" if str(job.get("group_id", "")).strip() else "user" - entity_id = str( - job.get("group_id") or job.get("user_id") or job.get("sender_id", "") - ).strip() - if entity_id: - targets.append( - { - "entity_type": entity_type, - "entity_id": entity_id, - "perspective": "legacy", - "preferred_name": "", - } - ) - return targets - - async def _merge_profiles( - self, job: dict[str, Any], canonical: str, event_id: str - ) -> None: - targets = self._resolve_profile_targets(job) - if not targets: - logger.warning("[史官] 任务 %s 侧写合并跳过:缺少目标实体", event_id) - return - logger.info( - "[史官] 任务 %s 开始合并侧写: target_count=%s targets=%s", - event_id, - len(targets), - [ - (t["entity_type"], t["entity_id"], t.get("perspective", "")) - for t in targets - ], - ) - success_count = 0 - for index, target in enumerate(targets, start=1): - try: - merged = await self._merge_profile_target( - job=job, - canonical=canonical, - event_id=event_id, - target=target, - target_index=index, - target_count=len(targets), - ) - if merged: - success_count += 1 - except Exception as exc: - logger.exception( - "[史官] 任务 %s 侧写目标合并失败: target=%s:%s perspective=%s err=%s", - event_id, - target.get("entity_type", ""), - target.get("entity_id", ""), - target.get("perspective", ""), - exc, - ) - logger.info( - "[史官] 任务 %s 侧写合并结束: success=%s total=%s", - event_id, - success_count, - len(targets), - ) - - async def _write_profile( - self, - *, - entity_type: str, - entity_id: str, - effective_name: str, - tags: list[str], - summary: str, - event_id: str, - perspective: str, - ) -> None: - import yaml - - frontmatter: dict[str, Any] = { - "entity_type": entity_type, - "entity_id": entity_id, - "name": effective_name, - "tags": tags, - "updated_at": datetime.now().isoformat(), - "source_event_id": event_id, - } - if entity_type == "user": - frontmatter["nickname"] = effective_name - frontmatter["qq"] = entity_id - else: - frontmatter["group_name"] = effective_name - frontmatter["group_id"] = entity_id - content = f"---\n{yaml.dump(frontmatter, allow_unicode=True)}---\n{summary}" - - await self._profile_storage.write_profile(entity_type, entity_id, content) - logger.info( - "[史官] 任务 %s 侧写文件写入完成: entity_type=%s entity_id=%s tags=%s perspective=%s", - event_id, - entity_type, - entity_id, - tags, - perspective, - ) - - profile_doc_lines: list[str] = [] - if entity_type == "user": - profile_doc_lines.append(f"昵称: {effective_name}") - profile_doc_lines.append(f"QQ号: {entity_id}") - else: - profile_doc_lines.append(f"群名: {effective_name}") - profile_doc_lines.append(f"群号: {entity_id}") - if tags: - profile_doc_lines.append(f"标签: {', '.join(tags)}") - profile_doc_lines.append(summary) - profile_doc = "\n".join(line for line in profile_doc_lines if line.strip()) - - profile_metadata: dict[str, Any] = { - "entity_type": entity_type, - "entity_id": entity_id, - "name": effective_name, - } - if entity_type == "user": - profile_metadata["nickname"] = effective_name - profile_metadata["qq"] = entity_id - else: - profile_metadata["group_name"] = effective_name - profile_metadata["group_id"] = entity_id - - await self._vector_store.upsert_profile( - f"{entity_type}:{entity_id}", - profile_doc, - profile_metadata, - ) - logger.info( - "[史官] 任务 %s 侧写向量入库完成: profile_id=%s perspective=%s", - event_id, - f"{entity_type}:{entity_id}", - perspective, - ) - - @staticmethod - def _historical_event_dedupe_key( - event: dict[str, Any], - ) -> tuple[str, str, str, str, str]: - metadata = event.get("metadata") - if not isinstance(metadata, dict): - metadata = {} - return ( - str(event.get("document", "")).strip(), - str(metadata.get("timestamp_local", "")).strip(), - str(metadata.get("sender_id", "")).strip(), - str(metadata.get("user_id", "")).strip(), - str(metadata.get("group_id", "")).strip(), - ) - - async def _query_user_history_events_for_profile_merge( - self, - *, - query_text: str, - entity_id: str, - top_k: int, - query_embedding: list[float] | None = None, - ) -> list[dict[str, Any]]: - """用户历史检索兼容路径:分别按 sender_id/user_id 查询并合并去重。 - - Compatibility path for user history retrieval: - query sender_id/user_id separately, then merge and dedupe. - """ - safe_top_k = max(1, int(top_k)) - query_embedding_value = query_embedding - if query_embedding_value is None: - query_embedding_value = await self._prepare_query_embedding(query_text) - sender_query = self._vector_store.query_events( - query_text, - top_k=safe_top_k, - where={"sender_id": entity_id}, - apply_mmr=True, - query_embedding=query_embedding_value, - ) - user_query = self._vector_store.query_events( - query_text, - top_k=safe_top_k, - where={"user_id": entity_id}, - apply_mmr=True, - query_embedding=query_embedding_value, - ) - sender_events_raw, user_events_raw = await asyncio.gather( - sender_query, user_query - ) - merged_events = list(sender_events_raw) + list(user_events_raw) - - deduped: list[dict[str, Any]] = [] - seen: set[tuple[str, str, str, str, str]] = set() - for event in merged_events: - key = self._historical_event_dedupe_key(event) - if key in seen: - continue - seen.add(key) - deduped.append(event) - if len(deduped) >= safe_top_k: - break - return deduped - - async def _merge_profile_target( - self, - *, - job: dict[str, Any], - canonical: str, - event_id: str, - target: dict[str, str], - target_index: int, - target_count: int, - ) -> bool: - entity_type = str(target.get("entity_type", "")).strip() - entity_id = str(target.get("entity_id", "")).strip() - perspective = str(target.get("perspective", "")).strip() - if entity_type not in {"user", "group"} or not entity_id: - logger.warning( - "[史官] 任务 %s 侧写目标非法,跳过: target=%s", - event_id, - target, - ) - return False - logger.info( - "[史官] 任务 %s 合并侧写目标(%s/%s): entity_type=%s entity_id=%s perspective=%s", - event_id, - target_index, - target_count, - entity_type, - entity_id, - perspective, - ) - - preferred_name = str(target.get("preferred_name", "")).strip() - - observations_raw = job.get("observations", job.get("new_info", [])) - observations_text = ( - "\n".join(observations_raw) - if isinstance(observations_raw, list) - else str(observations_raw) - ) - query_embedding = await self._prepare_query_embedding(observations_text) - if entity_type == "group": - historical_events = await self._vector_store.query_events( - observations_text, - top_k=8, - where={"group_id": entity_id}, - apply_mmr=True, - query_embedding=query_embedding, - ) - else: - historical_events = await self._query_user_history_events_for_profile_merge( - query_text=observations_text, - entity_id=entity_id, - top_k=8, - query_embedding=query_embedding, - ) - historical_lines = ( - "\n".join( - f"- [{e['metadata'].get('timestamp_local', '')}] {e['document']}" - for e in historical_events - ) - or "(暂无历史事件)" - ) - - from Undefined.utils.resources import read_text_resource - - template = read_text_resource("res/prompts/historian_profile_merge.md") - message_ids_raw = job.get("message_ids", []) - if isinstance(message_ids_raw, list): - message_ids = [ - str(item).strip() for item in message_ids_raw if str(item).strip() - ] - else: - message_ids = [] - - prompt = template.format( - historical_events=_escape_braces(historical_lines), - canonical_text=_escape_braces(canonical), - observations=_escape_braces(observations_text), - new_info=_escape_braces(observations_text), - target_entity_type=entity_type, - target_entity_id=entity_id, - target_perspective=perspective, - target_display_name=_escape_braces(preferred_name or entity_id), - request_type=_escape_braces(str(job.get("request_type", ""))), - user_id=_escape_braces(str(job.get("user_id", ""))), - group_id=_escape_braces(str(job.get("group_id", ""))), - sender_id=_escape_braces(str(job.get("sender_id", ""))), - sender_name=_escape_braces(str(job.get("sender_name", ""))), - group_name=_escape_braces(str(job.get("group_name", ""))), - timestamp_local=_escape_braces(str(job.get("timestamp_local", ""))), - timezone=_escape_braces(str(job.get("timezone", ""))), - event_id=_escape_braces(event_id), - request_id=_escape_braces(str(job.get("request_id", ""))), - end_seq=_escape_braces(str(job.get("end_seq", 0))), - message_ids=_escape_braces(", ".join(message_ids) if message_ids else "[]"), - memo=_escape_braces(str(job.get("memo", job.get("action_summary", "")))), - action_summary=_escape_braces( - str(job.get("memo", job.get("action_summary", ""))) - ), - source_message=_escape_braces(str(job.get("source_message", ""))), - recent_messages=_escape_braces( - "\n".join( - f"- {str(item).strip()}" - for item in (job.get("recent_messages", []) or []) - if str(item).strip() - ) - or "(无)" - ), - ) - - messages: list[dict[str, Any]] = [{"role": "user", "content": prompt}] - tools = [_READ_PROFILE_TOOL, _PROFILE_TOOL] - result = False - max_turns = 100 - transport_state: dict[str, Any] | None = None - - for turn in range(max_turns): - response = await self._ai_client.submit_background_llm_call( - model_config=self._model_config or self._ai_client.agent_config, - messages=messages, - tools=tools, - tool_choice="auto", - call_type="historian_profile_merge", - transport_state=transport_state, - ) - - next_transport_state = ( - response.get("_transport_state") if isinstance(response, dict) else None - ) - transport_state = ( - next_transport_state if isinstance(next_transport_state, dict) else None - ) - - choices = response.get("choices") or [] - if not choices: - logger.warning("[史官] 任务 %s turn=%s 响应无 choices", event_id, turn) - break - message = choices[0].get("message") if isinstance(choices[0], dict) else {} - if not isinstance(message, dict): - break - - tool_calls = message.get("tool_calls") or [] - if not tool_calls: - logger.info( - "[史官] 任务 %s turn=%s 无 tool_calls,结束", event_id, turn - ) - break - - assistant_msg: dict[str, Any] = { - "role": "assistant", - "tool_calls": tool_calls, - } - if message.get("content"): - assistant_msg["content"] = message["content"] - output_items = message.get(RESPONSES_OUTPUT_ITEMS_KEY) - if isinstance(output_items, list): - assistant_msg[RESPONSES_OUTPUT_ITEMS_KEY] = output_items - messages.append(assistant_msg) - - tool_results: list[dict[str, Any]] = [] - done = False - - for tc in tool_calls: - if not isinstance(tc, dict): - continue - func = tc.get("function") or {} - tc_name = str(func.get("name", "")).strip() - tc_id = str(tc.get("id", "")).strip() - try: - tc_args: dict[str, Any] = json.loads( - str(func.get("arguments", "{}")) - ) - except json.JSONDecodeError: - tc_args = {} - - if tc_name == "read_profile": - rp_et = str(tc_args.get("entity_type", "")).strip() - rp_eid = str(tc_args.get("entity_id", "")).strip() - if ( - rp_et not in {"user", "group"} - or not rp_eid - or not rp_eid.isalnum() - ): - tc_content = "错误:entity_type 或 entity_id 无效" - else: - profile_text = await self._profile_storage.read_profile( - rp_et, rp_eid - ) - tc_content = profile_text or "(暂无侧写)" - logger.info( - "[史官] 任务 %s read_profile: %s:%s len=%s", - event_id, - rp_et, - rp_eid, - len(tc_content), - ) - tool_results.append( - {"role": "tool", "tool_call_id": tc_id, "content": tc_content} - ) - - elif tc_name == "update_profile": - up_et = str(tc_args.get("entity_type", entity_type)).strip() - up_eid = str(tc_args.get("entity_id", entity_id)).strip() - if ( - up_et not in {"user", "group"} - or not up_eid - or not up_eid.isalnum() - ): - tool_results.append( - { - "role": "tool", - "tool_call_id": tc_id, - "content": "错误:entity_type 或 entity_id 无效", - } - ) - continue - raw_skip = tc_args.get("skip", False) - skip = ( - raw_skip.lower() not in ("false", "0", "no", "") - if isinstance(raw_skip, str) - else bool(raw_skip) - ) - if skip: - skip_reason = str(tc_args.get("skip_reason", "")).strip() - logger.info( - "[史官] 任务 %s 侧写更新跳过: target=%s:%s perspective=%s reason=%s", - event_id, - up_et, - up_eid, - perspective, - skip_reason or "unspecified", - ) - tool_results.append( - { - "role": "tool", - "tool_call_id": tc_id, - "content": f"已跳过: {skip_reason}", - } - ) - done = True - continue - - summary = str(tc_args.get("summary", "")).strip() - if not summary: - logger.info( - "[史官] 任务 %s 侧写更新跳过: target=%s:%s reason=empty_summary", - event_id, - up_et, - up_eid, - ) - tool_results.append( - { - "role": "tool", - "tool_call_id": tc_id, - "content": "错误:summary 为空", - } - ) - continue - raw_tags = tc_args.get("tags", []) - up_tags: list[str] = [] - if isinstance(raw_tags, list): - up_tags = [str(t).strip() for t in raw_tags if str(t).strip()][ - :10 - ] - - llm_name = str(tc_args.get("name", "")).strip() - is_target = up_et == entity_type and up_eid == entity_id - name_hint = preferred_name if is_target else "" - if not llm_name and not name_hint: - existing = await self._profile_storage.read_profile( - up_et, up_eid - ) - fallback_name = _extract_frontmatter_name(existing or "") - else: - fallback_name = "" - effective_name = ( - name_hint - or llm_name - or fallback_name - or (f"GID:{up_eid}" if up_et == "group" else f"UID:{up_eid}") - ) - - await self._write_profile( - entity_type=up_et, - entity_id=up_eid, - effective_name=effective_name, - tags=up_tags, - summary=summary, - event_id=event_id, - perspective=perspective, - ) - tool_results.append( - {"role": "tool", "tool_call_id": tc_id, "content": "侧写已更新"} - ) - result = True - done = True - - else: - tool_results.append( - { - "role": "tool", - "tool_call_id": tc_id, - "content": f"未知工具: {tc_name}", - } - ) - - messages.extend(tool_results) - if done: - break - - return result diff --git a/src/Undefined/cognitive/service.py b/src/Undefined/cognitive/service.py deleted file mode 100644 index 831fac53..00000000 --- a/src/Undefined/cognitive/service.py +++ /dev/null @@ -1,898 +0,0 @@ -"""认知记忆服务门面。""" - -from __future__ import annotations - -import asyncio -import logging -import time -from datetime import datetime, timezone -from typing import TYPE_CHECKING, Any, Callable, cast - -from Undefined.context import RequestContext -from Undefined.utils.coerce import safe_float - -logger = logging.getLogger(__name__) - -if TYPE_CHECKING: - from Undefined.knowledge.runtime import RetrievalRuntime - - -def _parse_iso_to_epoch_seconds(value: Any) -> int | None: - if not isinstance(value, str): - return None - text = value.strip() - if not text: - return None - try: - parsed = datetime.fromisoformat(text.replace("Z", "+00:00")) - except Exception: - return None - if parsed.tzinfo is None: - parsed = parsed.replace(tzinfo=timezone.utc) - return int(parsed.timestamp()) - - -def _compose_where(clauses: list[dict[str, Any]]) -> dict[str, Any] | None: - if not clauses: - return None - if len(clauses) == 1: - return clauses[0] - return {"$and": clauses} - - -def _event_base_score(item: dict[str, Any]) -> float: - rerank_score = item.get("rerank_score") - if isinstance(rerank_score, (int, float)): - return max(0.0, float(rerank_score)) - if isinstance(rerank_score, str): - try: - return max(0.0, float(rerank_score.strip())) - except Exception: - pass - similarity = 1.0 - safe_float(item.get("distance"), default=1.0) - if similarity < 0.0: - return 0.0 - if similarity > 1.0: - return 1.0 - return similarity - - -def _event_timestamp_epoch(metadata: Any) -> float: - if not isinstance(metadata, dict): - return float("-inf") - raw_epoch = metadata.get("timestamp_epoch") - if isinstance(raw_epoch, (int, float)): - return float(raw_epoch) - if isinstance(raw_epoch, str): - try: - return float(raw_epoch.strip()) - except Exception: - pass - for key in ("timestamp_utc", "timestamp_local"): - parsed = _parse_iso_to_epoch_seconds(metadata.get(key)) - if parsed is not None: - return float(parsed) - return float("-inf") - - -def _event_dedupe_key(item: dict[str, Any]) -> tuple[str, str, str, str, str, str]: - metadata = item.get("metadata") - if not isinstance(metadata, dict): - metadata = {} - return ( - str(item.get("document", "")).strip(), - str(metadata.get("timestamp_epoch", "")).strip(), - str(metadata.get("timestamp_local", "")).strip(), - str(metadata.get("group_id", "")).strip(), - str(metadata.get("sender_id", "")).strip(), - str(metadata.get("user_id", "")).strip(), - ) - - -def _resolve_auto_request_type( - *, - request_type: str | None, - group_id: str, - user_id: str, - sender_id: str, -) -> str: - normalized = str(request_type or "").strip().lower() - if normalized in {"group", "private"}: - return normalized - if group_id: - return "group" - if sender_id or user_id: - return "private" - return "" - - -def _parse_profile_markdown(markdown: str) -> tuple[dict[str, Any], str] | None: - text = str(markdown or "") - if not text.startswith("---"): - return None - try: - import yaml - - parts = text[3:].split("---", 1) - if len(parts) != 2: - return None - frontmatter = yaml.safe_load(parts[0]) - if not isinstance(frontmatter, dict): - return None - body = parts[1].lstrip("\n") - return frontmatter, body - except Exception: - return None - - -def _serialize_profile_markdown(frontmatter: dict[str, Any], body: str) -> str: - import yaml - - return f"---\n{yaml.dump(frontmatter, allow_unicode=True)}---\n{body}" - - -def _normalize_profile_tags(value: Any) -> list[str]: - if not isinstance(value, list): - return [] - return [str(item).strip() for item in value if str(item).strip()] - - -def _current_profile_name(entity_type: str, frontmatter: dict[str, Any]) -> str: - if entity_type == "user": - return str(frontmatter.get("nickname") or frontmatter.get("name") or "").strip() - return str(frontmatter.get("group_name") or frontmatter.get("name") or "").strip() - - -def _build_profile_vector_payload( - *, - entity_type: str, - entity_id: str, - effective_name: str, - tags: list[str], - summary: str, -) -> tuple[str, dict[str, Any]]: - profile_doc_lines: list[str] = [] - if entity_type == "user": - profile_doc_lines.append(f"昵称: {effective_name}") - profile_doc_lines.append(f"QQ号: {entity_id}") - else: - profile_doc_lines.append(f"群名: {effective_name}") - profile_doc_lines.append(f"群号: {entity_id}") - if tags: - profile_doc_lines.append(f"标签: {', '.join(tags)}") - profile_doc_lines.append(summary) - profile_doc = "\n".join(line for line in profile_doc_lines if line.strip()) - - metadata: dict[str, Any] = { - "entity_type": entity_type, - "entity_id": entity_id, - "name": effective_name, - } - if entity_type == "user": - metadata["nickname"] = effective_name - metadata["qq"] = entity_id - else: - metadata["group_name"] = effective_name - metadata["group_id"] = entity_id - return profile_doc, metadata - - -class CognitiveService: - def __init__( - self, - config_getter: Callable[[], Any], - vector_store: Any, - job_queue: Any, - profile_storage: Any, - reranker: Any = None, - retrieval_runtime: RetrievalRuntime | None = None, - ) -> None: - self._config_getter = config_getter - self._vector_store = vector_store - self._job_queue = job_queue - self._profile_storage = profile_storage - self._reranker = reranker - self._retrieval_runtime = retrieval_runtime - - def _base_reranker(self) -> Any: - if self._retrieval_runtime is not None: - return self._retrieval_runtime.ensure_reranker() - return self._reranker - - def _current_reranker(self) -> Any: - config = self._config_getter() - if not bool(getattr(config, "enable_rerank", True)): - return None - return self._base_reranker() - - async def _prepare_query_embedding(self, query: str) -> list[float] | None: - embed_query = getattr(self._vector_store, "embed_query", None) - if not callable(embed_query): - return None - try: - result = await embed_query(query) - except Exception as exc: - logger.warning("[认知服务] 预生成查询向量失败,回退即时计算: error=%s", exc) - return None - if not isinstance(result, list): - logger.warning("[认知服务] 预生成查询向量返回值非法,回退即时计算") - return None - normalized: list[float] = [] - for item in result: - try: - normalized.append(float(item)) - except (TypeError, ValueError): - logger.warning("[认知服务] 预生成查询向量包含非法元素,回退即时计算") - return None - return normalized - - @property - def enabled(self) -> bool: - return bool(self._config_getter().enabled) - - async def sync_profile_display_name( - self, - *, - entity_type: str, - entity_id: str, - preferred_name: str, - ) -> bool: - normalized_entity_type = str(entity_type or "").strip().lower() - normalized_entity_id = str(entity_id or "").strip() - normalized_name = str(preferred_name or "").strip() - if normalized_entity_type not in {"user", "group"}: - return False - if not normalized_entity_id or not normalized_name: - return False - if self._profile_storage is None or self._vector_store is None: - return False - - existing = await self._profile_storage.read_profile( - normalized_entity_type, - normalized_entity_id, - ) - if not existing: - return False - - parsed = _parse_profile_markdown(existing) - if parsed is None: - return False - frontmatter, summary = parsed - current_name = _current_profile_name(normalized_entity_type, frontmatter) - if current_name == normalized_name: - return False - - frontmatter["name"] = normalized_name - frontmatter["updated_at"] = datetime.now().isoformat() - if normalized_entity_type == "user": - frontmatter["nickname"] = normalized_name - frontmatter["qq"] = normalized_entity_id - else: - frontmatter["group_name"] = normalized_name - frontmatter["group_id"] = normalized_entity_id - - updated_markdown = _serialize_profile_markdown(frontmatter, summary) - await self._profile_storage.write_profile( - normalized_entity_type, - normalized_entity_id, - updated_markdown, - ) - - profile_doc, profile_metadata = _build_profile_vector_payload( - entity_type=normalized_entity_type, - entity_id=normalized_entity_id, - effective_name=normalized_name, - tags=_normalize_profile_tags(frontmatter.get("tags")), - summary=summary, - ) - await self._vector_store.upsert_profile( - f"{normalized_entity_type}:{normalized_entity_id}", - profile_doc, - profile_metadata, - ) - logger.info( - "[认知服务] 已刷新侧写展示名: entity_type=%s entity_id=%s old=%s new=%s", - normalized_entity_type, - normalized_entity_id, - current_name, - normalized_name, - ) - return True - - @staticmethod - def _uid_candidates(user_id: str, sender_id: str) -> list[str]: - values: list[str] = [] - for raw in (sender_id, user_id): - text = str(raw or "").strip() - if text and text not in values: - values.append(text) - return values - - @staticmethod - def _merge_weighted_events( - scoped_results: list[tuple[list[dict[str, Any]], float]], - *, - top_k: int, - current_group_id: str = "", - current_group_boost: float = 1.0, - ) -> list[dict[str, Any]]: - safe_top_k = max(1, int(top_k)) - safe_group_boost = max(0.0, float(current_group_boost)) - seen_keys: set[tuple[str, str, str, str, str, str]] = set() - # 排序主键优先使用“作用域内原始排名”(已含 time_decay/mmr/rerank 效果), - scored_items: list[ - tuple[float, float, float, float, float, int, dict[str, Any]] - ] = [] - serial = 0 - for scoped_events, scope_weight in scoped_results: - safe_scope_weight = max(0.0, safe_float(scope_weight, default=1.0)) - scope_size = max(1, len(scoped_events)) - for rank_idx, event in enumerate(scoped_events): - dedupe_key = _event_dedupe_key(event) - if dedupe_key in seen_keys: - continue - seen_keys.add(dedupe_key) - metadata = event.get("metadata") - if not isinstance(metadata, dict): - metadata = {} - scope_boost = safe_scope_weight - if ( - current_group_id - and str(metadata.get("group_id", "")).strip() == current_group_id - ): - scope_boost *= safe_group_boost - # 保留每个 scope 内已重排结果(time_decay/mmr/rerank)的相对顺序。 - rank_score = float(scope_size - rank_idx) / float(scope_size) - weighted_rank_score = rank_score * scope_boost - base_score = _event_base_score(event) - weighted_score = base_score * scope_boost - scored_items.append( - ( - weighted_rank_score, - weighted_score, - rank_score, - base_score, - _event_timestamp_epoch(metadata), - serial, - event, - ) - ) - serial += 1 - scored_items.sort( - key=lambda item: ( - -item[0], - -item[1], - -item[2], - -item[3], - -item[4], - item[5], - ) - ) - return [item[6] for item in scored_items[:safe_top_k]] - - async def _query_events_for_auto_context( - self, - *, - query: str, - request_type: str, - group_id: str, - user_id: str, - sender_id: str, - top_k: int, - config: Any, - ) -> list[dict[str, Any]]: - safe_top_k = max(1, int(top_k)) - scope_candidate_multiplier = int( - getattr(config, "auto_scope_candidate_multiplier", 2) - ) - if scope_candidate_multiplier <= 0: - scope_candidate_multiplier = 2 - scoped_top_k = max(safe_top_k, safe_top_k * scope_candidate_multiplier) - current_group_boost = safe_float( - getattr(config, "auto_current_group_boost", 1.15), default=1.15 - ) - if current_group_boost <= 0: - current_group_boost = 1.15 - current_private_boost = safe_float( - getattr(config, "auto_current_private_boost", 1.25), default=1.25 - ) - if current_private_boost <= 0: - current_private_boost = 1.25 - query_embedding = await self._prepare_query_embedding(query) - common_kwargs: dict[str, Any] = { - "reranker": self._current_reranker(), - "candidate_multiplier": config.rerank_candidate_multiplier, - "time_decay_enabled": bool(getattr(config, "time_decay_enabled", True)), - "time_decay_half_life_days": float( - getattr(config, "time_decay_half_life_days_auto", 14.0) - ), - "time_decay_boost": float(getattr(config, "time_decay_boost", 0.2)), - "time_decay_min_similarity": float( - getattr(config, "time_decay_min_similarity", 0.35) - ), - "apply_mmr": True, - } - if query_embedding is not None: - common_kwargs["query_embedding"] = query_embedding - uid_values = self._uid_candidates(user_id, sender_id) - - if request_type == "group": - group_events: list[dict[str, Any]] = await self._vector_store.query_events( - query, - top_k=scoped_top_k, - where={"request_type": "group"}, - **common_kwargs, - ) - merge_started = time.perf_counter() - merged = self._merge_weighted_events( - [(group_events, 1.0)], - top_k=safe_top_k, - current_group_id=group_id, - current_group_boost=current_group_boost, - ) - merge_duration = time.perf_counter() - merge_started - logger.info( - "[认知服务] 自动检索(群聊): group_candidates=%s merged=%s top_k=%s scope_multiplier=%s current_group_boost=%.2f merge=%.3fs", - len(group_events), - len(merged), - safe_top_k, - scope_candidate_multiplier, - current_group_boost, - merge_duration, - ) - return merged - - if request_type == "private": - group_task = self._vector_store.query_events( - query, - top_k=scoped_top_k, - where={"request_type": "group"}, - **common_kwargs, - ) - if uid_values: - uid_clauses = [{"user_id": value} for value in uid_values] + [ - {"sender_id": value} for value in uid_values - ] - private_where: dict[str, Any] = { - "$and": [ - {"request_type": "private"}, - {"$or": uid_clauses}, - ] - } - private_task = self._vector_store.query_events( - query, - top_k=scoped_top_k, - where=private_where, - **common_kwargs, - ) - group_events_raw, private_events_raw = await asyncio.gather( - group_task, private_task - ) - group_events = cast(list[dict[str, Any]], group_events_raw) - private_events = cast(list[dict[str, Any]], private_events_raw) - else: - group_events = cast(list[dict[str, Any]], await group_task) - private_events = [] - merge_started = time.perf_counter() - merged = self._merge_weighted_events( - [ - (group_events, 1.0), - (private_events, current_private_boost), - ], - top_k=safe_top_k, - ) - merge_duration = time.perf_counter() - merge_started - logger.info( - "[认知服务] 自动检索(私聊): group_candidates=%s private_candidates=%s merged=%s top_k=%s scope_multiplier=%s private_boost=%.2f uid_candidates=%s merge=%.3fs", - len(group_events), - len(private_events), - len(merged), - safe_top_k, - scope_candidate_multiplier, - current_private_boost, - uid_values, - merge_duration, - ) - return merged - - where: dict[str, Any] | None = None - if group_id: - where = {"group_id": group_id} - elif uid_values: - where = { - "$or": [{"user_id": value} for value in uid_values] - + [{"sender_id": value} for value in uid_values] - } - events: list[dict[str, Any]] = await self._vector_store.query_events( - query, - top_k=safe_top_k, - where=where, - **common_kwargs, - ) - logger.info( - "[认知服务] 自动检索(兜底): mode=%s where=%s count=%s top_k=%s", - request_type or "unknown", - where or {}, - len(events), - safe_top_k, - ) - return events - - async def enqueue_job( - self, - memo: str, - observations: list[str], - context: dict[str, Any], - *, - force: bool = False, - ) -> str | None: - memo_text = str(memo or "").strip() - observation_items = ( - [s for s in observations if s.strip()] if observations else [] - ) - if not self.enabled: - logger.info("[认知服务] 已禁用,跳过入队") - return None - if not memo_text and not observation_items: - logger.info("[认知服务] memo/observations 均为空,跳过入队") - return None - ctx = RequestContext.current() - - now = datetime.now().astimezone() - now_utc = datetime.now(timezone.utc) - safe_request_id = ( - str(ctx.request_id) - if ctx and str(ctx.request_id or "").strip() - else str(context.get("request_id", "")).strip() - ) - if not safe_request_id: - safe_request_id = "" - - end_seq_raw = context.get("_end_seq", 0) - try: - end_seq = int(end_seq_raw) - except (TypeError, ValueError): - end_seq = 0 - - has_observations = bool(observation_items) - message_ids = context.get("message_ids") - if not isinstance(message_ids, list): - message_ids = [] - message_ids = [str(item).strip() for item in message_ids if str(item).strip()] - perspective = str(context.get("memory_perspective", "")).strip() - user_id = ( - str(ctx.user_id or "") if ctx else str(context.get("user_id", "") or "") - ) - group_id = ( - str(ctx.group_id or "") if ctx else str(context.get("group_id", "") or "") - ) - sender_id = ( - str(ctx.sender_id or "") - if ctx - else str(context.get("sender_id") or context.get("user_id", "") or "") - ) - request_type = ( - str(ctx.request_type) - if ctx and ctx.request_type - else str(context.get("request_type", "") or "") - ) - sender_name = str(context.get("sender_name") or "").strip() - group_name = str(context.get("group_name") or "").strip() - source_message = str(context.get("historian_source_message") or "").strip() - recent_messages_raw = context.get("historian_recent_messages", []) - recent_messages: list[str] = [] - if isinstance(recent_messages_raw, list): - recent_messages = [ - str(item).strip() for item in recent_messages_raw if str(item).strip() - ] - - profile_targets: list[dict[str, str]] = [] - if has_observations: - group_id = group_id.strip() - sender_id = sender_id.strip() or user_id.strip() - seen: set[tuple[str, str]] = set() - if group_id: - key = ("group", group_id) - if key not in seen: - seen.add(key) - profile_targets.append( - { - "entity_type": "group", - "entity_id": group_id, - "perspective": "group", - "preferred_name": group_name, - } - ) - if sender_id: - key = ("user", sender_id) - if key not in seen: - seen.add(key) - profile_targets.append( - { - "entity_type": "user", - "entity_id": sender_id, - "perspective": "sender", - "preferred_name": sender_name, - } - ) - - bot_name = str(self._config_getter().bot_name or "Undefined").strip() - - job: dict[str, Any] = { - "request_id": safe_request_id, - "end_seq": end_seq, - "user_id": user_id, - "group_id": group_id, - "sender_id": sender_id, - "sender_name": sender_name, - "group_name": group_name, - "bot_name": bot_name, - "request_type": request_type, - "timestamp_utc": now_utc.isoformat(), - "timestamp_local": now.isoformat(), - "timestamp_epoch": int(now_utc.timestamp()), - "timezone": str(now.tzinfo or ""), - "location_abs": str( - context.get("group_name") or context.get("sender_name") or "" - ), - "message_ids": message_ids, - "memo": memo_text, - "observations": observation_items, - "has_observations": has_observations, - "perspective": perspective, - "profile_targets": profile_targets, - "schema_version": "final_v1", - "source_message": source_message, - "recent_messages": recent_messages, - "force": bool(force), - } - logger.info( - "[认知服务] 准备入队: request_id=%s end_seq=%s user=%s group=%s sender=%s perspective=%s has_observations=%s profile_targets=%s memo_len=%s observations_len=%s source_len=%s recent_ref=%s force=%s", - job.get("request_id", ""), - job.get("end_seq", 0), - job.get("user_id", ""), - job.get("group_id", ""), - job.get("sender_id", ""), - perspective or "default", - has_observations, - len(profile_targets), - len(memo_text), - len(observation_items), - len(source_message), - len(recent_messages), - bool(force), - ) - result: str | None = await self._job_queue.enqueue(job) - logger.info("[认知服务] 入队完成: job_id=%s", result or "") - return result - - async def build_context( - self, - query: str, - group_id: str | None = None, - user_id: str | None = None, - sender_id: str | None = None, - sender_name: str | None = None, - group_name: str | None = None, - request_type: str | None = None, - ) -> str: - config = self._config_getter() - safe_group_id = str(group_id or "").strip() - safe_user_id = str(user_id or "").strip() - safe_sender_id = str(sender_id or "").strip() - safe_request_type = _resolve_auto_request_type( - request_type=request_type, - group_id=safe_group_id, - user_id=safe_user_id, - sender_id=safe_sender_id, - ) - parts: list[str] = [] - logger.info( - "[认知服务] 构建上下文: query_len=%s type=%s user=%s sender=%s group=%s top_k=%s", - len(query or ""), - safe_request_type or "", - safe_user_id, - safe_sender_id, - safe_group_id, - getattr(config, "auto_top_k", 5), - ) - - uid = safe_sender_id or safe_user_id - if uid: - profile = await self._profile_storage.read_profile("user", uid) - if profile: - label = f"{sender_name}(UID: {uid})" if sender_name else f"UID: {uid}" - parts.append(f"## 用户侧写 — {label}\n{profile}") - - if safe_group_id: - gprofile = await self._profile_storage.read_profile("group", safe_group_id) - if gprofile: - glabel = ( - f"{group_name}(GID: {safe_group_id})" - if group_name - else f"GID: {safe_group_id}" - ) - parts.append(f"## 群聊侧写 — {glabel}\n{gprofile}") - - default_top_k = 5 - try: - top_k = int(getattr(config, "auto_top_k", default_top_k)) - except Exception: - top_k = default_top_k - if top_k <= 0: - top_k = default_top_k - top_k = min(top_k, 500) - try: - events = await self._query_events_for_auto_context( - query=query, - request_type=safe_request_type, - group_id=safe_group_id, - user_id=safe_user_id, - sender_id=safe_sender_id, - top_k=top_k, - config=config, - ) - except Exception as exc: - logger.warning( - "[认知服务] 自动上下文事件检索失败,降级为空结果: type=%s user=%s sender=%s group=%s err=%s", - safe_request_type, - safe_user_id, - safe_sender_id, - safe_group_id, - exc, - ) - events = [] - if events: - event_lines = "\n".join( - f"- [{e['metadata'].get('timestamp_local', '')}] {e['document']}" - for e in events - ) - parts.append(f"## 相关记忆事件\n{event_lines}") - - if not parts: - logger.info("[认知服务] 构建上下文完成: 无可用记忆") - return "" - - body = "\n\n".join(parts) - result = ( - "\n" - "\n" - f"{body}\n" - "" - ) - logger.info( - "[认知服务] 构建上下文完成: sections=%s result_len=%s", - len(parts), - len(result), - ) - return result - - async def search_events(self, query: str, **kwargs: Any) -> list[dict[str, Any]]: - config = self._config_getter() - group_id = str( - kwargs.get("group_id") or kwargs.get("target_group_id") or "" - ).strip() - user_id = str( - kwargs.get("user_id") or kwargs.get("target_user_id") or "" - ).strip() - sender_id = str(kwargs.get("sender_id") or "").strip() - where_clauses: list[dict[str, Any]] = [] - if group_id: - where_clauses.append({"group_id": group_id}) - if user_id: - where_clauses.append({"user_id": user_id}) - if sender_id: - where_clauses.append({"sender_id": sender_id}) - request_type = str(kwargs.get("request_type") or "").strip() - if request_type: - where_clauses.append({"request_type": request_type}) - - time_from_epoch = _parse_iso_to_epoch_seconds(kwargs.get("time_from")) - time_to_epoch = _parse_iso_to_epoch_seconds(kwargs.get("time_to")) - if ( - time_from_epoch is not None - and time_to_epoch is not None - and time_from_epoch > time_to_epoch - ): - logger.warning( - "[认知服务] search_events 时间范围反转,已自动交换: time_from=%s time_to=%s", - kwargs.get("time_from"), - kwargs.get("time_to"), - ) - time_from_epoch, time_to_epoch = time_to_epoch, time_from_epoch - if time_from_epoch is not None: - where_clauses.append({"timestamp_epoch": {"$gte": time_from_epoch}}) - if time_to_epoch is not None: - where_clauses.append({"timestamp_epoch": {"$lte": time_to_epoch}}) - - where = _compose_where(where_clauses) - default_top_k = getattr(config, "tool_default_top_k", 12) - top_k_raw = kwargs.get("top_k", default_top_k) - try: - top_k = int(top_k_raw) - except Exception: - top_k = default_top_k - if top_k <= 0: - top_k = default_top_k - top_k = min(top_k, 500) - logger.info( - "[认知服务] 搜索事件: query_len=%s top_k=%s where=%s time_from=%s time_to=%s", - len(query or ""), - top_k, - where or {}, - time_from_epoch, - time_to_epoch, - ) - results: list[dict[str, Any]] = await self._vector_store.query_events( - query, - top_k=top_k, - where=where or None, - reranker=self._current_reranker(), - candidate_multiplier=config.rerank_candidate_multiplier, - time_decay_enabled=bool(getattr(config, "time_decay_enabled", True)), - time_decay_half_life_days=float( - getattr(config, "time_decay_half_life_days_tool", 60.0) - ), - time_decay_boost=float(getattr(config, "time_decay_boost", 0.2)), - time_decay_min_similarity=float( - getattr(config, "time_decay_min_similarity", 0.35) - ), - apply_mmr=True, - query_embedding=await self._prepare_query_embedding(query), - ) - logger.info("[认知服务] 搜索事件完成: count=%s", len(results)) - return results - - async def get_profile(self, entity_type: str, entity_id: str) -> str | None: - logger.info( - "[认知服务] 读取侧写: entity_type=%s entity_id=%s", - entity_type, - entity_id, - ) - result: str | None = await self._profile_storage.read_profile( - entity_type, entity_id - ) - logger.info( - "[认知服务] 读取侧写完成: found=%s", - bool(result), - ) - return result - - async def search_profiles(self, query: str, **kwargs: Any) -> list[dict[str, Any]]: - config = self._config_getter() - default_top_k = int(getattr(config, "profile_top_k", 5)) - top_k_raw = kwargs.get("top_k", default_top_k) - try: - top_k = int(top_k_raw) - except Exception: - top_k = default_top_k - if top_k <= 0: - top_k = default_top_k - top_k = min(top_k, 500) - - where: dict[str, Any] | None = None - entity_type_raw = kwargs.get("entity_type") - entity_type = ( - str(entity_type_raw).strip() if entity_type_raw is not None else "" - ) - if entity_type: - where = {"entity_type": entity_type} - - logger.info( - "[认知服务] 搜索侧写: query_len=%s top_k=%s where=%s", - len(query or ""), - top_k, - where or {}, - ) - results: list[dict[str, Any]] = await self._vector_store.query_profiles( - query, - top_k=top_k, - where=where, - reranker=self._current_reranker(), - candidate_multiplier=config.rerank_candidate_multiplier, - query_embedding=await self._prepare_query_embedding(query), - ) - logger.info("[认知服务] 搜索侧写完成: count=%s", len(results)) - return results diff --git a/src/Undefined/config/__init__.py b/src/Undefined/config/__init__.py index b3e3fab2..4fe960fb 100644 --- a/src/Undefined/config/__init__.py +++ b/src/Undefined/config/__init__.py @@ -68,3 +68,4 @@ def set_config(config: Config) -> None: """注入 Config 单例(库嵌入 opt-in;CLI / WebUI 启动链不得调用)。""" global _config _config = config + get_config_manager().replace(config) diff --git a/src/Undefined/config/manager.py b/src/Undefined/config/manager.py index f2c4400b..f497f253 100644 --- a/src/Undefined/config/manager.py +++ b/src/Undefined/config/manager.py @@ -30,6 +30,10 @@ def load(self, strict: bool = True) -> Config: self._config = Config.load(config_path=self.config_path, strict=strict) return self._config + def replace(self, config: Config) -> None: + """替换缓存的配置实例(库嵌入 ``set_config`` 注入时使用)。""" + self._config = config + def reload(self, strict: bool = False) -> dict[str, tuple[Any, Any]]: if self._config is None: self._config = Config.load(config_path=self.config_path, strict=strict) diff --git a/src/Undefined/handlers.py b/src/Undefined/handlers.py deleted file mode 100644 index 8006aa40..00000000 --- a/src/Undefined/handlers.py +++ /dev/null @@ -1,1400 +0,0 @@ -"""消息处理和命令分发""" - -from __future__ import annotations - -import asyncio -from dataclasses import dataclass -import logging -import os -from pathlib import Path -import random -import time -from typing import Any, Coroutine, Literal - -from Undefined.attachments import ( - append_attachment_text, - build_attachment_scope, - register_message_attachments, -) -from Undefined.ai import AIClient -from Undefined.config import Config -from Undefined.faq import FAQStorage -from Undefined.rate_limit import RateLimiter -from Undefined.services.queue_manager import QueueManager -from Undefined.onebot import ( - OneBotClient, - get_message_content, - get_message_sender_id, -) -from Undefined.utils.common import ( - extract_text, - parse_message_content_for_history, - matches_xinliweiyuan, -) -from Undefined.utils.fake_at import BotNicknameCache, strip_fake_at -from Undefined.utils.history import MessageHistoryManager -from Undefined.utils.scheduler import TaskScheduler -from Undefined.utils.sender import MessageSender -from Undefined.services.security import SecurityService -from Undefined.services.command import CommandDispatcher -from Undefined.services.ai_coordinator import AICoordinator -from Undefined.services.message_batcher import MessageBatcher, make_scope -from Undefined.services.model_pool import ModelPoolService -from Undefined.skills.pipelines import PipelineRegistry -from Undefined.skills.pipelines.context import build_pipeline_context -from Undefined.utils.resources import resolve_resource_path -from Undefined.utils.queue_intervals import build_model_queue_intervals - -from Undefined.scheduled_task_storage import ScheduledTaskStorage -from Undefined.utils.logging import log_debug_json, redact_string -from Undefined.utils.coerce import safe_int - -logger = logging.getLogger(__name__) - -KEYWORD_REPLY_HISTORY_PREFIX = "[系统关键词自动回复] " -REPEAT_REPLY_HISTORY_PREFIX = "[系统复读] " - - -def _is_private_model_pool_control_text(text: str) -> bool: - return ModelPoolService.is_private_control_text(text) - - -def _format_poke_history_text(display_name: str, user_id: int) -> str: - """格式化拍一拍历史文本。""" - return f"{display_name}(暱称)[{user_id}(QQ号)] 拍了拍你。" - - -@dataclass(frozen=True) -class PrivatePokeRecord: - poke_text: str - sender_name: str - - -@dataclass(frozen=True) -class GroupPokeRecord: - poke_text: str - sender_name: str - group_name: str - sender_role: str - sender_title: str - sender_level: str - - -class MessageHandler: - """消息处理器""" - - def __init__( - self, - config: Config, - onebot: OneBotClient, - ai: AIClient, - faq_storage: FAQStorage, - task_storage: ScheduledTaskStorage, - ) -> None: - self.config = config - self.onebot = onebot - self.ai = ai - self.faq_storage = faq_storage - # 初始化工具组件 - self.history_manager = MessageHistoryManager(config.history_max_records) - self.sender = MessageSender( - onebot, - self.history_manager, - config.bot_qq, - config, - attachment_registry=getattr(ai, "attachment_registry", None), - ) - - # 初始化服务 - self.security = SecurityService(config, ai._http_client) - self.rate_limiter = RateLimiter(config) - self.queue_manager = QueueManager( - max_retries=config.ai_request_max_retries, - ) - self.queue_manager.update_model_intervals(build_model_queue_intervals(config)) - - # 设置队列管理器到 AIClient(触发 Agent 介绍生成器启动) - ai.set_queue_manager(self.queue_manager) - - self.command_dispatcher = CommandDispatcher( - config, - self.sender, - ai, - faq_storage, - onebot, - self.security, - queue_manager=self.queue_manager, - rate_limiter=self.rate_limiter, - history_manager=self.history_manager, - ) - self.ai_coordinator = AICoordinator( - config, - ai, - self.queue_manager, - self.history_manager, - self.sender, - onebot, - TaskScheduler(ai, self.sender, onebot, self.history_manager, task_storage), - self.security, - command_dispatcher=self.command_dispatcher, - ) - - # 同 sender 短时多消息合并器;coordinator 决定是否旁路 - self.message_batcher = MessageBatcher( - config.message_batcher, - flush_callback=self.ai_coordinator.handle_batched_dispatch, - ) - self.ai_coordinator.set_batcher(self.message_batcher) - - self._background_tasks: set[asyncio.Task[None]] = set() - self._profile_name_refresh_cache: dict[tuple[str, int], str] = {} - self._bot_nickname_cache = BotNicknameCache(onebot, config.bot_qq) - self.pipeline_registry = PipelineRegistry() - self._pipelines_initialized = False - self._pipelines_init_lock = asyncio.Lock() - - # 复读功能状态(按群跟踪最近消息文本与发送者) - self._repeat_counter: dict[int, list[tuple[str, int]]] = {} - self._repeat_locks: dict[int, asyncio.Lock] = {} - # 复读冷却:group_id → {normalized_text → monotonic_timestamp} - self._repeat_cooldown: dict[int, dict[str, float]] = {} - - # 启动队列 - self.ai_coordinator.queue_manager.start(self.ai_coordinator.execute_reply) - - async def initialize(self) -> None: - """完成需要事件循环承载的异步初始化。""" - await self.init_pipelines() - - async def init_pipelines(self) -> None: - """异步加载自动处理管线并按配置启动热重载。""" - if getattr(self, "_pipelines_initialized", False): - return - init_lock = getattr(self, "_pipelines_init_lock", None) - if init_lock is None: - init_lock = asyncio.Lock() - self._pipelines_init_lock = init_lock - async with init_lock: - if getattr(self, "_pipelines_initialized", False): - return - await self.pipeline_registry.load_items_async() - self._pipelines_initialized = True - if getattr(self.config, "skills_hot_reload", False): - self.pipeline_registry.start_hot_reload( - interval=self.config.skills_hot_reload_interval, - debounce=self.config.skills_hot_reload_debounce, - ) - - def _get_repeat_lock(self, group_id: int) -> asyncio.Lock: - """获取或创建指定群的复读竞态保护锁。""" - lock = self._repeat_locks.get(group_id) - if lock is None: - lock = asyncio.Lock() - self._repeat_locks[group_id] = lock - return lock - - @staticmethod - def _normalize_repeat_text(text: str) -> str: - """规范化复读文本用于冷却比较(?→?)。""" - return text.replace("?", "?") - - def _is_repeat_on_cooldown(self, group_id: int, text: str) -> bool: - """检查指定群的文本是否在复读冷却期内。""" - cooldown_minutes = self.config.repeat_cooldown_minutes - if cooldown_minutes <= 0: - return False - group_cd = self._repeat_cooldown.get(group_id) - if not group_cd: - return False - key = self._normalize_repeat_text(text) - last_time = group_cd.get(key) - if last_time is None: - return False - return (time.monotonic() - last_time) < cooldown_minutes * 60 - - def _record_repeat_cooldown(self, group_id: int, text: str) -> None: - """记录复读冷却时间戳,同时清理已过期条目防止内存泄漏。""" - cooldown_seconds = self.config.repeat_cooldown_minutes * 60 - if cooldown_seconds <= 0: - return - key = self._normalize_repeat_text(text) - group_cd = self._repeat_cooldown.setdefault(group_id, {}) - now = time.monotonic() - # 清理已过期条目 - expired = [k for k, ts in group_cd.items() if (now - ts) >= cooldown_seconds] - for k in expired: - # delete - del group_cd[k] - group_cd[key] = now - - async def _annotate_meme_descriptions( - self, - attachments: list[dict[str, str]], - scope_key: str, - ) -> list[dict[str, str]]: - """为图片附件添加表情包描述(如果在表情库中找到)。 - - 采用批量查询:收集所有 SHA256 哈希值,一次性查询,然后映射结果。 - 最佳努力:任何失败时返回原始列表。 - """ - if not attachments: - return attachments - - ai_client = getattr(self, "ai", None) - if ai_client is None: - return attachments - - attachment_registry = getattr(ai_client, "attachment_registry", None) - if attachment_registry is None: - return attachments - - meme_service = getattr(ai_client, "_meme_service", None) - if meme_service is None or not getattr(meme_service, "enabled", False): - return attachments - - meme_store = getattr(meme_service, "_store", None) - if meme_store is None: - return attachments - - try: - # 1. 从图片附件收集唯一的 SHA256 哈希值 - uid_to_hash: dict[str, str] = {} - for att in attachments: - uid = att.get("uid", "") - if not uid.startswith("pic_"): - continue - record = attachment_registry.resolve(uid, scope_key) - if record and record.sha256: - uid_to_hash[uid] = record.sha256 - - if not uid_to_hash: - return attachments - - # 2. 批量查询:去重哈希值 - unique_hashes = set(uid_to_hash.values()) - hash_to_desc: dict[str, str] = {} - for h in unique_hashes: - meme = await meme_store.find_by_sha256(h) - if meme and meme.description: - hash_to_desc[h] = meme.description - - if not hash_to_desc: - return attachments - - # 3. 构建带注释的新列表 - result: list[dict[str, str]] = [] - for att in attachments: - uid = att.get("uid", "") - sha = uid_to_hash.get(uid, "") - desc = hash_to_desc.get(sha, "") - if desc: - new_att = dict(att) - new_att["description"] = f"[表情包] {desc}" - result.append(new_att) - else: - result.append(att) - return result - except Exception: - logger.warning("表情包自动匹配失败,跳过", exc_info=True) - return attachments - - async def _collect_message_attachments( - self, - message_content: list[dict[str, Any]], - *, - group_id: int | None = None, - user_id: int | None = None, - request_type: str, - ) -> list[dict[str, str]]: - scope_key = build_attachment_scope( - group_id=group_id, - user_id=user_id, - request_type=request_type, - ) - if not scope_key: - return [] - ai_client = getattr(self, "ai", None) - attachment_registry = ( - getattr(ai_client, "attachment_registry", None) if ai_client else None - ) - if attachment_registry is None: - return [] - onebot = getattr(self, "onebot", None) - resolve_image_url = getattr(onebot, "get_image", None) if onebot else None - result = await register_message_attachments( - registry=attachment_registry, - segments=message_content, - scope_key=scope_key, - resolve_image_url=resolve_image_url, - get_forward_messages=getattr(onebot, "get_forward_msg", None) - if onebot - else None, - ) - attachments = result.attachments - # 为图片附件添加表情包描述 - attachments = await self._annotate_meme_descriptions(attachments, scope_key) - return attachments - - def _schedule_meme_ingest( - self, - *, - attachments: list[dict[str, str]], - chat_type: str, - chat_id: int, - sender_id: int, - message_id: int | None, - scope_key: str | None, - ) -> None: - if not attachments or not scope_key: - return - meme_service = getattr(self.ai, "_meme_service", None) - if meme_service is None or not getattr(meme_service, "enabled", False): - return - self._spawn_background_task( - f"meme_ingest:{chat_type}:{chat_id}:{sender_id}:{message_id or 0}", - meme_service.enqueue_incoming_attachments( - attachments=attachments, - chat_type=chat_type, - chat_id=chat_id, - sender_id=sender_id, - message_id=message_id, - scope_key=scope_key, - ), - ) - - async def _refresh_profile_display_names( - self, - *, - sender_id: int | None = None, - sender_name: str = "", - group_id: int | None = None, - group_name: str = "", - ) -> None: - ai_client = getattr(self, "ai", None) - cognitive_service = getattr(ai_client, "_cognitive_service", None) - if not cognitive_service or not getattr(cognitive_service, "enabled", False): - return - - if sender_id and sender_name.strip(): - await cognitive_service.sync_profile_display_name( - entity_type="user", - entity_id=str(sender_id), - preferred_name=sender_name.strip(), - ) - if group_id and group_name.strip(): - await cognitive_service.sync_profile_display_name( - entity_type="group", - entity_id=str(group_id), - preferred_name=group_name.strip(), - ) - - def _can_refresh_profile_display_names(self) -> bool: - ai_client = getattr(self, "ai", None) - cognitive_service = getattr(ai_client, "_cognitive_service", None) - return bool(cognitive_service and getattr(cognitive_service, "enabled", False)) - - def _schedule_profile_display_name_refresh( - self, - *, - task_name: str, - sender_id: int | None = None, - sender_name: str = "", - group_id: int | None = None, - group_name: str = "", - ) -> None: - if not self._can_refresh_profile_display_names(): - return - - cache = getattr(self, "_profile_name_refresh_cache", None) - if cache is None: - cache = {} - self._profile_name_refresh_cache = cache - - updates: dict[str, Any] = {} - rollback: list[tuple[tuple[str, int], str | None]] = [] - - normalized_sender_name = sender_name.strip() - if sender_id and normalized_sender_name: - sender_key = ("user", int(sender_id)) - previous = cache.get(sender_key) - if previous != normalized_sender_name: - cache[sender_key] = normalized_sender_name - rollback.append((sender_key, previous)) - updates["sender_id"] = sender_id - updates["sender_name"] = normalized_sender_name - - normalized_group_name = group_name.strip() - if group_id and normalized_group_name: - group_key = ("group", int(group_id)) - previous = cache.get(group_key) - if previous != normalized_group_name: - cache[group_key] = normalized_group_name - rollback.append((group_key, previous)) - updates["group_id"] = group_id - updates["group_name"] = normalized_group_name - - if not updates: - return - - async def _run_refresh() -> None: - try: - await self._refresh_profile_display_names(**updates) - except Exception: - for key, previous in rollback: - if previous is None: - cache.pop(key, None) - else: - cache[key] = previous - raise - - self._spawn_background_task(task_name, _run_refresh()) - - async def handle_message(self, event: dict[str, Any]) -> None: - """处理收到的消息事件""" - if logger.isEnabledFor(logging.DEBUG): - log_debug_json(logger, "[事件数据]", event) - post_type = event.get("post_type", "message") - - # 处理拍一拍事件(效果同被 @) - if post_type == "notice" and event.get("notice_type") == "poke": - target_id = event.get("target_id", 0) - # 只有拍机器人才响应 - if target_id != self.config.bot_qq: - logger.debug( - "[通知] 忽略拍一拍目标非机器人: target=%s", - target_id, - ) - return - - if not self.config.should_process_poke_message(): - logger.debug("[消息策略] 已关闭拍一拍处理,忽略此次 poke 事件") - return - - poke_group_id: int = event.get("group_id", 0) - poke_sender_id: int = event.get("user_id", 0) - - # 访问控制:命中群黑名单或不满足白名单限制时忽略 - if poke_group_id == 0: - if not self.config.is_private_allowed(poke_sender_id): - private_reason = ( - self.config.private_access_denied_reason(poke_sender_id) - or "unknown" - ) - logger.debug( - "[访问控制] 忽略私聊拍一拍: user=%s reason=%s (access enabled=%s)", - poke_sender_id, - private_reason, - self.config.access_control_enabled(), - ) - return - else: - if not self.config.is_group_allowed(poke_group_id): - group_reason = ( - self.config.group_access_denied_reason(poke_group_id) - or "unknown" - ) - logger.debug( - "[访问控制] 忽略群聊拍一拍: group=%s sender=%s reason=%s (access enabled=%s)", - poke_group_id, - poke_sender_id, - group_reason, - self.config.access_control_enabled(), - ) - return - - logger.info( - "[通知] 收到拍一拍: group=%s sender=%s", - poke_group_id, - poke_sender_id, - ) - logger.debug("[通知] 拍一拍事件数据: %s", str(event)[:200]) - - if poke_group_id == 0: - private_poke = await self._record_private_poke_history( - poke_sender_id, event - ) - logger.info("[通知] 私聊拍一拍,触发私聊回复") - await self.ai_coordinator.handle_private_reply( - poke_sender_id, - private_poke.poke_text, - [], - is_poke=True, - sender_name=private_poke.sender_name, - ) - else: - group_poke = await self._record_group_poke_history( - poke_group_id, - poke_sender_id, - event, - ) - logger.info( - "[通知] 群聊拍一拍,触发群聊回复: group=%s", - poke_group_id, - ) - await self.ai_coordinator.handle_auto_reply( - poke_group_id, - poke_sender_id, - group_poke.poke_text, - [], - is_poke=True, - sender_name=group_poke.sender_name, - group_name=group_poke.group_name, - sender_role=group_poke.sender_role, - sender_title=group_poke.sender_title, - sender_level=group_poke.sender_level, - ) - return - - # 处理私聊消息 - if event.get("message_type") == "private": - private_sender_id: int = get_message_sender_id(event) - private_message_content: list[dict[str, Any]] = get_message_content(event) - trigger_message_id = event.get("message_id") - - # 访问控制:命中黑/白名单规则时忽略(不入历史、不触发任何处理) - if not self.config.is_private_allowed(private_sender_id): - private_reason = ( - self.config.private_access_denied_reason(private_sender_id) - or "unknown" - ) - logger.debug( - "[访问控制] 忽略私聊消息: user=%s reason=%s (access enabled=%s)", - private_sender_id, - private_reason, - self.config.access_control_enabled(), - ) - return - - # 获取发送者昵称 - private_sender: dict[str, Any] = event.get("sender", {}) - private_sender_nickname: str = private_sender.get("nickname", "") - - # 获取私聊用户昵称 - user_name = private_sender_nickname - if not user_name: - try: - user_info = await self.onebot.get_stranger_info(private_sender_id) - if user_info: - user_name = user_info.get("nickname", "") - except Exception as exc: - logger.warning("获取用户昵称失败: %s", exc) - - text = extract_text(private_message_content, self.config.bot_qq) - # 并行执行附件收集和历史内容解析 - private_attachments, parsed_content_raw = await asyncio.gather( - self._collect_message_attachments( - private_message_content, - user_id=private_sender_id, - request_type="private", - ), - parse_message_content_for_history( - private_message_content, - self.config.bot_qq, - self.onebot.get_msg, - self.onebot.get_forward_msg, - ), - ) - safe_text = redact_string(text) - logger.info( - "[私聊消息] 发送者=%s 昵称=%s 内容=%s", - private_sender_id, - user_name or private_sender_nickname, - safe_text[:100], - ) - resolved_private_name = (user_name or private_sender_nickname or "").strip() - self._schedule_profile_display_name_refresh( - task_name=f"profile_name_refresh_private:{private_sender_id}", - sender_id=private_sender_id, - sender_name=resolved_private_name, - ) - - # 保存私聊消息到历史记录 - parsed_content = append_attachment_text( - parsed_content_raw, private_attachments - ) - safe_parsed = redact_string(parsed_content) - logger.debug( - "[历史记录] 保存私聊: user=%s content=%s...", - private_sender_id, - safe_parsed[:50], - ) - await self.history_manager.add_private_message( - user_id=private_sender_id, - text_content=parsed_content, - display_name=private_sender_nickname, - user_name=user_name, - message_id=trigger_message_id, - attachments=private_attachments, - ) - - # 如果是 bot 自己的消息,只保存不触发回复,避免无限循环 - if private_sender_id == self.config.bot_qq: - return - - self._schedule_meme_ingest( - attachments=private_attachments, - chat_type="private", - chat_id=private_sender_id, - sender_id=private_sender_id, - message_id=safe_int(trigger_message_id), - scope_key=build_attachment_scope( - user_id=private_sender_id, - request_type="private", - ), - ) - - if not self.config.should_process_private_message(): - logger.debug( - "[消息策略] 已关闭私聊处理: user=%s", - private_sender_id, - ) - return - - if ( - getattr(self.config, "model_pool_enabled", False) - and _is_private_model_pool_control_text(text) - ) and await self.ai_coordinator.model_pool.handle_private_message( - private_sender_id, - text, - ): - return - - private_command = self.command_dispatcher.parse_command(text) - if private_command: - await self._flush_command_buffer( - scope=make_scope(user_id=private_sender_id), - sender_id=private_sender_id, - ) - await self.command_dispatcher.dispatch_private( - user_id=private_sender_id, - sender_id=private_sender_id, - command=private_command, - ) - return - - await self._run_pipelines( - target_id=private_sender_id, - target_type="private", - text=text, - message_content=private_message_content, - ) - - await self.ai_coordinator.handle_private_reply( - private_sender_id, - text, - private_message_content, - attachments=private_attachments, - sender_name=user_name, - trigger_message_id=trigger_message_id, - ) - return - - # 只处理群消息 - if event.get("message_type") != "group": - return - - group_id: int = event.get("group_id", 0) - sender_id: int = get_message_sender_id(event) - message_content: list[dict[str, Any]] = get_message_content(event) - trigger_message_id = event.get("message_id") - - # 访问控制:命中黑/白名单规则时忽略(不入历史、不触发任何处理) - if not self.config.is_group_allowed(group_id): - group_reason = self.config.group_access_denied_reason(group_id) or "unknown" - logger.debug( - "[访问控制] 忽略群消息: group=%s sender=%s reason=%s (access enabled=%s)", - group_id, - sender_id, - group_reason, - self.config.access_control_enabled(), - ) - return - - # 获取发送者信息 - group_sender: dict[str, Any] = event.get("sender", {}) - sender_card: str = group_sender.get("card", "") - sender_nickname: str = group_sender.get("nickname", "") - sender_role: str = group_sender.get("role", "member") - sender_title: str = group_sender.get("title", "") - sender_level: str = str(group_sender.get("level", "")).strip() - - # 提取文本内容 - text = extract_text(message_content, self.config.bot_qq) - safe_text = redact_string(text) - logger.info( - f"[群消息] group={group_id} sender={sender_id} name={sender_card or sender_nickname} " - f"role={sender_role} | {safe_text[:100]}" - ) - - # 并行执行 3 个独立的异步操作:附件收集、群信息获取、历史内容解析 - async def _fetch_group_name() -> str: - try: - info = await self.onebot.get_group_info(group_id) - if info: - return str(info.get("group_name", "") or "") - except Exception as e: - logger.warning(f"获取群聊名失败: {e}") - return "" - - group_attachments, group_name, parsed_content_raw = await asyncio.gather( - self._collect_message_attachments( - message_content, - group_id=group_id, - request_type="group", - ), - _fetch_group_name(), - parse_message_content_for_history( - message_content, - self.config.bot_qq, - self.onebot.get_msg, - self.onebot.get_forward_msg, - ), - ) - - resolved_group_sender_name = (sender_card or sender_nickname or "").strip() - self._schedule_profile_display_name_refresh( - task_name=f"profile_name_refresh_group:{group_id}:{sender_id}", - sender_id=sender_id, - sender_name=resolved_group_sender_name, - group_id=group_id, - group_name=str(group_name or "").strip(), - ) - - # 保存消息到历史记录 - parsed_content = append_attachment_text(parsed_content_raw, group_attachments) - safe_parsed = redact_string(parsed_content) - logger.debug( - f"[历史记录] 保存群聊: group={group_id}, sender={sender_id}, content={safe_parsed[:50]}..." - ) - await self.history_manager.add_group_message( - group_id=group_id, - sender_id=sender_id, - text_content=parsed_content, - sender_card=sender_card, - sender_nickname=sender_nickname, - group_name=group_name, - role=sender_role, - title=sender_title, - level=sender_level, - message_id=trigger_message_id, - attachments=group_attachments, - ) - - # 如果是 bot 自己的消息,只保存不触发回复,避免无限循环 - # 同时把 bot 自身的发言写入复读计数器,使窗口中留有 bot 标记, - # 后续触发检查时会排除含 bot 的窗口,防止"bot 先发 → 用户跟发"或 - # "用户发到一半 bot 插入"等情况误触复读。 - if sender_id == self.config.bot_qq: - if self.config.repeat_enabled and text: - async with self._get_repeat_lock(group_id): - counter = self._repeat_counter.setdefault(group_id, []) - counter.append((text, sender_id)) - n = self.config.repeat_threshold - if len(counter) > n: - self._repeat_counter[group_id] = counter[-n:] - return - - self._schedule_meme_ingest( - attachments=group_attachments, - chat_type="group", - chat_id=group_id, - sender_id=sender_id, - message_id=safe_int(trigger_message_id), - scope_key=build_attachment_scope(group_id=group_id, request_type="group"), - ) - - # 检查是否 @ 了机器人(后续分流共用) - is_at_bot = self.ai_coordinator._is_at_bot(message_content) - - # 假@检测:识别 "@昵称" 纯文本形式 - # normalized_text 用于命令解析和 AI 路由,原始 text 已用于历史/日志 - is_fake_at = False - normalized_text = text - if not is_at_bot and ("@" in text or "@" in text): - nicknames = await self._bot_nickname_cache.get_nicknames(group_id) - if nicknames: - is_fake_at, normalized_text = strip_fake_at(text, nicknames) - if is_fake_at: - is_at_bot = True - logger.info( - "[假@] 识别到假@: group=%s sender=%s", - group_id, - sender_id, - ) - - # 关闭“每条消息处理”后,仅处理 @ 消息(私聊/拍一拍在其他分支中处理) - if not self.config.should_process_group_message(is_at_bot=is_at_bot): - logger.debug( - "[消息策略] 跳过群消息处理: group=%s sender=%s process_every_message=%s at_bot=%s", - group_id, - sender_id, - self.config.process_every_message, - is_at_bot, - ) - return - - # 只有被@时才处理斜杠命令(使用 normalized_text 以支持假@后的命令)。 - # 命令优先于自动处理管线,命中后不触发后续自动提取或 AI 回复。 - if is_at_bot: - command = self.command_dispatcher.parse_command(normalized_text) - if command: - await self._flush_command_buffer( - scope=make_scope(group_id=group_id), - sender_id=sender_id, - ) - await self.command_dispatcher.dispatch(group_id, sender_id, command) - return - - # 关键词自动回复:心理委员 (使用原始消息内容提取文本,保证关键词触发不受影响) - if self.config.keyword_reply_enabled and matches_xinliweiyuan(text): - rand_val = random.random() - if rand_val < 0.01: # 1% 飞起来 - message = f"[@{sender_id}] 再发让你飞起来" - logger.info("关键词回复: 再发让你飞起来") - await self.sender.send_group_message( - group_id, - message, - history_prefix=KEYWORD_REPLY_HISTORY_PREFIX, - ) - return - elif rand_val < 0.11: # 10% 发送图片 - try: - image_path = ( - resolve_resource_path("img/xlwy.jpg").resolve().as_uri() - ) - except Exception: - image_path = Path(os.path.abspath("img/xlwy.jpg")).as_uri() - message = f"[CQ:image,file={image_path}]" - # 50% 概率 @ 发送者 - if random.random() < 0.5: - message = f"[@{sender_id}] {message}" - logger.info("关键词回复: 发送图片 xlwy.jpg") - else: # 90% 原有逻辑 - if random.random() < 0.7: - reply = "受着" - else: - reply = "那咋了" - # 50% 概率 @ 发送者 - if random.random() < 0.5: - message = f"[@{sender_id}] {reply}" - else: - message = reply - logger.info(f"关键词回复: {reply}") - # 使用 sender 发送 - await self.sender.send_group_message( - group_id, - message, - history_prefix=KEYWORD_REPLY_HISTORY_PREFIX, - ) - return - - # 复读功能:连续 N 条相同消息(来自不同发送者)时复读,N = repeat_threshold - if self.config.repeat_enabled and text: - n = self.config.repeat_threshold - async with self._get_repeat_lock(group_id): - counter = self._repeat_counter.setdefault(group_id, []) - counter.append((text, sender_id)) - # 只保留最近 n 条 - if len(counter) > n: - self._repeat_counter[group_id] = counter[-n:] - counter = self._repeat_counter[group_id] - - if len(counter) >= n: - last_n = counter[-n:] - texts = [t for t, _ in last_n] - senders = [s for _, s in last_n] - if ( - len(set(texts)) == 1 - and len(set(senders)) == n - and self.config.bot_qq not in senders - ): - reply_text = texts[0] - # 冷却检查:同一内容在冷却期内不再复读 - if self._is_repeat_on_cooldown(group_id, reply_text): - self._repeat_counter[group_id] = [] - logger.debug( - "[复读] 冷却中跳过: group=%s text=%s", - group_id, - redact_string(reply_text)[:50], - ) - else: - if self.config.inverted_question_enabled: - stripped = reply_text.strip() - if set(stripped) <= {"?", "?"}: - reply_text = "¿" * len(stripped) - # 清空计数器防止重复触发 - self._repeat_counter[group_id] = [] - self._record_repeat_cooldown(group_id, texts[0]) - logger.info( - "[复读] 触发复读: group=%s text=%s", - group_id, - redact_string(reply_text)[:50], - ) - await self.sender.send_group_message( - group_id, - reply_text, - history_prefix=REPEAT_REPLY_HISTORY_PREFIX, - ) - return - - await self._run_pipelines( - target_id=group_id, - target_type="group", - text=text, - message_content=message_content, - ) - - # 提取文本内容 - # (已在上方提取用于日志记录) - - # 自动回复处理(使用 normalized_text 以去除假@前缀) - display_name = sender_card or sender_nickname or str(sender_id) - await self.ai_coordinator.handle_auto_reply( - group_id, - sender_id, - normalized_text, - message_content, - attachments=group_attachments, - sender_name=display_name, - group_name=group_name, - sender_role=sender_role, - sender_title=sender_title, - sender_level=sender_level, - trigger_message_id=trigger_message_id, - is_fake_at=is_fake_at, - ) - - async def _record_private_poke_history( - self, user_id: int, event: dict[str, Any] - ) -> PrivatePokeRecord: - """记录私聊拍一拍到历史。""" - sender = event.get("sender", {}) - sender_nickname = "" - if isinstance(sender, dict): - sender_nickname = str(sender.get("nickname", "")).strip() - - user_name = sender_nickname - if not user_name: - try: - user_info = await self.onebot.get_stranger_info(user_id) - if isinstance(user_info, dict): - user_name = str(user_info.get("nickname", "")).strip() - except Exception as exc: - logger.warning( - "[通知] 获取私聊拍一拍用户昵称失败: user=%s err=%s", - user_id, - exc, - ) - - resolved_sender_name = (sender_nickname or user_name).strip() - display_name = resolved_sender_name or f"QQ{user_id}" - normalized_user_name = user_name or display_name - poke_text = _format_poke_history_text(display_name, user_id) - self._schedule_profile_display_name_refresh( - task_name=f"profile_name_refresh_private_poke:{user_id}", - sender_id=user_id, - sender_name=resolved_sender_name, - ) - - try: - await self.history_manager.add_private_message( - user_id=user_id, - text_content=poke_text, - display_name=display_name, - user_name=normalized_user_name, - ) - except Exception as exc: - logger.warning( - "[历史记录] 写入私聊拍一拍失败: user=%s err=%s", - user_id, - exc, - ) - return PrivatePokeRecord(poke_text=poke_text, sender_name=display_name) - - async def _record_group_poke_history( - self, - group_id: int, - sender_id: int, - event: dict[str, Any], - ) -> GroupPokeRecord: - """记录群聊拍一拍到历史。""" - sender = event.get("sender", {}) - sender_card = "" - sender_nickname = "" - sender_role = "member" - sender_title = "" - sender_level = "" - if isinstance(sender, dict): - sender_card = str(sender.get("card", "")).strip() - sender_nickname = str(sender.get("nickname", "")).strip() - sender_role = str(sender.get("role", "member")).strip() or "member" - sender_title = str(sender.get("title", "")).strip() - sender_level = str(sender.get("level", "")).strip() - - if not sender_card and not sender_nickname: - try: - member_info = await self.onebot.get_group_member_info( - group_id, sender_id - ) - if isinstance(member_info, dict): - sender_card = str(member_info.get("card", "")).strip() - sender_nickname = str(member_info.get("nickname", "")).strip() - sender_role = ( - str(member_info.get("role", "member")).strip() or "member" - ) - sender_title = str(member_info.get("title", "")).strip() - sender_level = str(member_info.get("level", "")).strip() - except Exception as exc: - logger.warning( - "[通知] 获取拍一拍群成员信息失败: group=%s user=%s err=%s", - group_id, - sender_id, - exc, - ) - - group_name = "" - try: - group_info = await self.onebot.get_group_info(group_id) - if isinstance(group_info, dict): - group_name = str(group_info.get("group_name", "")).strip() - except Exception as exc: - logger.warning( - "[通知] 获取拍一拍群名失败: group=%s err=%s", - group_id, - exc, - ) - - resolved_sender_name = (sender_card or sender_nickname).strip() - resolved_group_name = group_name.strip() - display_name = resolved_sender_name or f"QQ{sender_id}" - poke_text = _format_poke_history_text(display_name, sender_id) - normalized_group_name = resolved_group_name or f"群{group_id}" - self._schedule_profile_display_name_refresh( - task_name=f"profile_name_refresh_group_poke:{group_id}:{sender_id}", - sender_id=sender_id, - sender_name=resolved_sender_name, - group_id=group_id, - group_name=resolved_group_name, - ) - - try: - await self.history_manager.add_group_message( - group_id=group_id, - sender_id=sender_id, - text_content=poke_text, - sender_card=sender_card, - sender_nickname=sender_nickname, - group_name=normalized_group_name, - role=sender_role, - title=sender_title, - level=sender_level, - ) - except Exception as exc: - logger.warning( - "[历史记录] 写入群聊拍一拍失败: group=%s sender=%s err=%s", - group_id, - sender_id, - exc, - ) - return GroupPokeRecord( - poke_text=poke_text, - sender_name=display_name, - group_name=normalized_group_name, - sender_role=sender_role, - sender_title=sender_title, - sender_level=sender_level, - ) - - async def _extract_bilibili_ids( - self, text: str, message_content: list[dict[str, Any]] - ) -> list[str]: - """从文本和消息段中提取 B 站视频 BV 号。""" - from Undefined.bilibili.parser import ( - extract_bilibili_ids, - extract_from_json_message, - ) - - bvids = await extract_bilibili_ids(text) - if not bvids: - bvids = await extract_from_json_message(message_content) - return bvids - - async def _flush_command_buffer(self, *, scope: str, sender_id: int) -> None: - batcher_config = getattr(self.config, "message_batcher", None) - if not getattr(batcher_config, "flush_on_command", False): - return - batcher = getattr(self, "message_batcher", None) - if batcher is None: - return - flushed = await batcher.flush_sender(scope, sender_id) - if not flushed: - logger.warning( - "[MessageBatcher] 命令触发 flush 当前 buffer 失败: scope=%s sender=%s", - scope, - sender_id, - ) - - async def _run_pipelines( - self, - *, - target_id: int, - target_type: Literal["group", "private"], - text: str, - message_content: list[dict[str, Any]], - ) -> bool: - """并行检测并处理所有命中的自动处理管线。""" - if not getattr(self, "_pipelines_initialized", False): - await self.init_pipelines() - context = build_pipeline_context( - self, - target_id=target_id, - target_type=target_type, - text=text, - message_content=message_content, - ) - detections = await self.pipeline_registry.run(context) - return bool(detections) - - async def apply_skills_hot_reload_config( - self, - *, - enabled: bool, - interval: float, - debounce: float, - ) -> None: - """跟随全局 skills 热重载配置更新管线。""" - if not enabled: - await self.pipeline_registry.stop_hot_reload() - logger.info("[pipelines] 热重载已随配置禁用") - return - - await self.pipeline_registry.stop_hot_reload() - self.pipeline_registry.start_hot_reload( - interval=interval, - debounce=debounce, - ) - - def _extract_arxiv_ids( - self, text: str, message_content: list[dict[str, Any]] - ) -> list[str]: - """从文本和消息段中提取 arXiv 论文 ID。""" - from Undefined.arxiv.parser import extract_arxiv_ids, extract_from_json_message - - paper_ids: list[str] = [] - seen: set[str] = set() - - for paper_id in extract_arxiv_ids(text): - if paper_id in seen: - continue - seen.add(paper_id) - paper_ids.append(paper_id) - - for paper_id in extract_from_json_message(message_content): - if paper_id in seen: - continue - seen.add(paper_id) - paper_ids.append(paper_id) - - return paper_ids - - def _extract_github_repo_ids( - self, text: str, message_content: list[dict[str, Any]] - ) -> list[str]: - """从文本和消息段中提取 GitHub 仓库 ID。""" - from Undefined.github.parser import ( - extract_from_json_message, - extract_github_repo_ids, - ) - - repo_ids: list[str] = [] - seen: set[str] = set() - - for repo_id in extract_github_repo_ids(text): - key = repo_id.lower() - if key in seen: - continue - seen.add(key) - repo_ids.append(repo_id) - - for repo_id in extract_from_json_message(message_content): - key = repo_id.lower() - if key in seen: - continue - seen.add(key) - repo_ids.append(repo_id) - - return repo_ids - - async def _handle_bilibili_extract( - self, - target_id: int, - bvids: list[str], - target_type: str, - ) -> None: - """处理 bilibili 视频自动提取和发送。""" - from Undefined.bilibili.sender import send_bilibili_video - - for bvid in bvids[:3]: # 最多同时处理 3 个 - try: - await send_bilibili_video( - video_id=bvid, - sender=self.sender, - onebot=self.onebot, - target_type=target_type, # type: ignore[arg-type] - target_id=target_id, - cookie=self.config.bilibili_cookie, - prefer_quality=self.config.bilibili_prefer_quality, - max_duration=self.config.bilibili_max_duration, - max_file_size=self.config.bilibili_max_file_size, - oversize_strategy=self.config.bilibili_oversize_strategy, - danmaku_enabled=self.config.bilibili_danmaku_enabled, - danmaku_batch_size=self.config.bilibili_danmaku_batch_size, - danmaku_max_count=self.config.bilibili_danmaku_max_count, - ) - except Exception as exc: - logger.error( - "[Bilibili] 自动提取失败 %s → %s:%s: %s", - bvid, - target_type, - target_id, - exc, - ) - try: - error_msg = f"视频提取失败: {exc}" - if target_type == "group": - await self.sender.send_group_message(target_id, error_msg) - else: - await self.sender.send_private_message(target_id, error_msg) - except Exception: - pass - - async def _handle_arxiv_extract( - self, - target_id: int, - paper_ids: list[str], - target_type: str, - ) -> None: - """处理 arXiv 论文自动提取和发送。""" - from Undefined.arxiv.sender import send_arxiv_paper - - max_items = max(1, int(self.config.arxiv_auto_extract_max_items)) - - for paper_id in paper_ids[:max_items]: - try: - result = await send_arxiv_paper( - paper_id=paper_id, - sender=self.sender, - target_type=target_type, # type: ignore[arg-type] - target_id=target_id, - max_file_size=self.config.arxiv_max_file_size, - author_preview_limit=self.config.arxiv_author_preview_limit, - summary_preview_chars=self.config.arxiv_summary_preview_chars, - context={ - "request_id": ( - f"arxiv_auto_extract:{target_type}:{target_id}:{paper_id}" - ) - }, - ) - logger.info( - "[arXiv] 自动提取完成 %s → %s:%s: %s", - paper_id, - target_type, - target_id, - result, - ) - except Exception: - logger.exception( - "[arXiv] 自动提取失败 %s → %s:%s", - paper_id, - target_type, - target_id, - ) - - async def _handle_github_extract( - self, - target_id: int, - repo_ids: list[str], - target_type: str, - ) -> None: - """处理 GitHub 仓库自动提取和发送。""" - from Undefined.github.sender import send_github_repo_card - - max_items = max( - 1, int(getattr(self.config, "github_auto_extract_max_items", 3)) - ) - request_timeout = float( - getattr(self.config, "github_request_timeout_seconds", 10.0) - ) - - for repo_id in repo_ids[:max_items]: - try: - result = await send_github_repo_card( - repo_id=repo_id, - sender=self.sender, - target_type=target_type, # type: ignore[arg-type] - target_id=target_id, - request_timeout=request_timeout, - context={ - "request_id": ( - f"github_auto_extract:{target_type}:{target_id}:{repo_id}" - ) - }, - ) - logger.info( - "[GitHub] 自动提取完成 %s → %s:%s: %s", - repo_id, - target_type, - target_id, - result, - ) - except Exception as exc: - logger.info( - "[GitHub] 自动提取跳过 %s → %s:%s: %s", - repo_id, - target_type, - target_id, - exc, - ) - - def _spawn_background_task( - self, - name: str, - coroutine: Coroutine[Any, Any, None], - ) -> None: - task = asyncio.create_task(coroutine, name=name) - self._background_tasks.add(task) - - def _finalize(done_task: asyncio.Task[None]) -> None: - self._background_tasks.discard(done_task) - try: - exc = done_task.exception() - except asyncio.CancelledError: - logger.debug("[后台任务] 已取消: %s", name) - return - if exc is not None: - logger.exception( - "[后台任务] 执行失败: name=%s", - name, - exc_info=(type(exc), exc, exc.__traceback__), - ) - - task.add_done_callback(_finalize) - - async def close(self) -> None: - """关闭消息处理器""" - logger.info("正在关闭消息处理器...") - if self._background_tasks: - logger.info( - "[后台任务] 等待自动提取任务收敛: count=%s", - len(self._background_tasks), - ) - await asyncio.gather( - *list(self._background_tasks), - return_exceptions=True, - ) - await self.pipeline_registry.stop_hot_reload() - await self.message_batcher.flush_all() - await self.ai_coordinator.queue_manager.drain() - await self.ai_coordinator.queue_manager.stop() - await self.history_manager.flush_pending_saves() - logger.info("消息处理器已关闭") diff --git a/src/Undefined/memes/_service.py b/src/Undefined/memes/_service.py deleted file mode 100644 index 33a720b2..00000000 --- a/src/Undefined/memes/_service.py +++ /dev/null @@ -1,153 +0,0 @@ -"""MemeService 门面类。""" - -from __future__ import annotations - -import asyncio -import logging -from pathlib import Path -import threading -from typing import Any - - -from Undefined.attachments import AttachmentRecord -from Undefined.memes._image_utils import ( - _normalize_tags, - _now_iso, -) -from Undefined.memes.models import ( - IngestDigestLockEntry, - build_search_text, -) -from Undefined.memes.store import MemeStore -from Undefined.memes.vector_store import MemeVectorStore -from Undefined.utils.paths import ensure_dir -from Undefined.memes.ingest import MemeIngestMixin -from Undefined.memes.search import MemeSearchMixin - -logger = logging.getLogger(__name__) - - -class MemeService(MemeSearchMixin, MemeIngestMixin): - def __init__( - self, - *, - config_getter: Any, - store: MemeStore, - vector_store: MemeVectorStore, - job_queue: Any | None = None, - ai_client: Any | None = None, - attachment_registry: Any | None = None, - retrieval_runtime: Any | None = None, - ) -> None: - self._config_getter = config_getter - self._store = store - self._vector_store = vector_store - self._job_queue = job_queue - self._ai_client = ai_client - self._attachment_registry = attachment_registry - self._retrieval_runtime = retrieval_runtime - # 同内容 digest 锁:进程内串行入库,防止重复 AI 分析 - self._ingest_digest_locks: dict[str, IngestDigestLockEntry] = {} - self._ingest_digest_locks_guard = asyncio.Lock() - self._global_image_cache: dict[str, AttachmentRecord] = {} - self._global_image_cache_lock = threading.Lock() - - def enabled(self) -> bool: - cfg = self._config_getter() - return bool(getattr(cfg, "enabled", False)) - - def default_query_mode(self) -> str: - mode = ( - str( - getattr(self._config_getter(), "query_default_mode", "hybrid") - or "hybrid" - ) - .strip() - .lower() - ) - return mode if mode in {"keyword", "semantic", "hybrid"} else "hybrid" - - def _cfg(self) -> Any: - return self._config_getter() - - def _blob_dir(self) -> Path: - return ensure_dir(Path(self._cfg().blob_dir)) - - def _preview_dir(self) -> Path: - return ensure_dir(Path(self._cfg().preview_dir)) - - def _queue_enabled(self) -> bool: - return self._job_queue is not None - - def _invalidate_global_image_cache(self, uid: str) -> None: - normalized_uid = str(uid or "").strip() - if not normalized_uid: - return - with self._global_image_cache_lock: - self._global_image_cache.pop(normalized_uid, None) - - async def update_meme( - self, - uid: str, - *, - manual_description: str | None = None, - tags: list[str] | str | None = None, - aliases: list[str] | str | None = None, - enabled: bool | None = None, - pinned: bool | None = None, - ) -> dict[str, Any] | None: - record = await self._store.get(uid) - if record is None: - return None - - next_tags = list(record.tags) if tags is None else _normalize_tags(tags) - next_aliases = ( - list(record.aliases) if aliases is None else _normalize_tags(aliases) - ) - next_manual = ( - record.manual_description - if manual_description is None - else str(manual_description or "").strip() - ) - next_enabled = record.enabled if enabled is None else bool(enabled) - next_pinned = record.pinned if pinned is None else bool(pinned) - next_search_text = build_search_text( - manual_description=next_manual, - auto_description=record.auto_description, - ocr_text="", - tags=next_tags, - aliases=next_aliases, - ) - - updated = await self._store.update_fields( - uid, - { - "manual_description": next_manual, - "tags_json": next_tags, - "aliases_json": next_aliases, - "enabled": next_enabled, - "pinned": next_pinned, - "search_text": next_search_text, - "updated_at": _now_iso(), - }, - ) - if updated is None: - return None - self._invalidate_global_image_cache(uid) - await self._vector_store.upsert(updated) - return self.serialize_record(updated) - - async def delete_meme(self, uid: str) -> bool: - record = await self._store.delete(uid) - if record is None: - return False - self._invalidate_global_image_cache(uid) - await self._vector_store.delete(uid) - await asyncio.to_thread(self._delete_file_if_exists, Path(record.blob_path)) - if record.preview_path and record.preview_path != record.blob_path: - await asyncio.to_thread( - self._delete_file_if_exists, - Path(record.preview_path), - ) - await asyncio.to_thread(self._cleanup_gif_frame_files, uid) - return True diff --git a/src/Undefined/onebot.py b/src/Undefined/onebot.py deleted file mode 100644 index 9a53d885..00000000 --- a/src/Undefined/onebot.py +++ /dev/null @@ -1,924 +0,0 @@ -"""OneBot WebSocket 客户端""" - -import asyncio -import json -import logging -import time -from typing import Any, Callable, Coroutine -from datetime import datetime - -import websockets -from websockets.asyncio.client import ClientConnection - -from Undefined.context import RequestContext -from Undefined.utils.logging import log_debug_json, redact_string, sanitize_data - -logger = logging.getLogger(__name__) - - -def _mark_message_sent_this_turn() -> None: - ctx = RequestContext.current() - if ctx is None: - return - ctx.set_resource("message_sent_this_turn", True) - - -class OneBotClient: - """OneBot v11 WebSocket 客户端""" - - def __init__(self, ws_url: str, token: str = ""): - self.ws_url = ws_url - self.token = token - self.ws: ClientConnection | None = None - self._message_id = 0 - self._pending_responses: dict[str, asyncio.Future[dict[str, Any]]] = {} - self._message_handler: ( - Callable[[dict[str, Any]], Coroutine[Any, Any, None]] | None - ) = None - self._running = False - - def set_message_handler( - self, handler: Callable[[dict[str, Any]], Coroutine[Any, Any, None]] - ) -> None: - """设置消息处理器""" - self._message_handler = handler - - def connection_status(self) -> dict[str, Any]: - """返回连接状态快照。""" - ws = self.ws - ws_exists = ws is not None - # websockets v13+ ClientConnection 没有 .closed 属性, - # 用 close_code 判断:连接关闭后 close_code 为 int,活跃时为 None - ws_closed = (ws.close_code is not None) if ws is not None else True - connected = ws_exists and (not ws_closed) and self._running - return { - "connected": connected, - "running": self._running, - "ws_exists": ws_exists, - "ws_closed": ws_closed, - "ws_url": self.ws_url, - } - - async def connect(self) -> None: - """连接到 OneBot WebSocket""" - # 构建带 token 的 URL - url = self.ws_url - if self.token: - separator = "&" if "?" in url else "?" - url = f"{url}{separator}access_token={self.token}" - - safe_ws_url = redact_string(self.ws_url) - logger.info( - f"[bold cyan][WebSocket][/bold cyan] 正在连接到 [blue]{safe_ws_url}[/blue]..." - ) - - # 同时在请求头中传递 token(兼容不同实现) - extra_headers = {} - if self.token: - extra_headers["Authorization"] = f"Bearer {self.token}" - - try: - self.ws = await websockets.connect( - url, - ping_interval=20, - ping_timeout=480, - max_size=100 * 1024 * 1024, # 100MB,支持大量历史消息 - additional_headers=extra_headers if extra_headers else None, - ) - logger.info("[bold green][WebSocket][/bold green] 连接成功") - except Exception as e: - logger.error(f"[WebSocket] 连接失败: {e}") - raise - - async def disconnect(self) -> None: - """断开连接""" - self._running = False - if self.ws: - logger.info("[WebSocket] 正在主动断开连接...") - await self.ws.close() - self.ws = None - logger.info("[WebSocket] 连接已断开") - - async def _call_api( - self, - action: str, - params: dict[str, Any] | None = None, - *, - suppress_error_retcodes: set[int] | None = None, - ) -> dict[str, Any]: - """调用 OneBot API""" - if not self.ws: - raise RuntimeError("WebSocket 未连接") - - self._message_id += 1 - echo = str(self._message_id) # 使用字符串类型 - - request = { - "action": action, - "params": params or {}, - "echo": echo, - } - - safe_params = sanitize_data(params or {}) - logger.debug( - f"[bold yellow][API请求][/bold yellow] [green]{action}[/green] (ID=[magenta]{echo}[/magenta]) | 参数: {safe_params}" - ) - if logger.isEnabledFor(logging.DEBUG): - log_debug_json(logger, "[OneBot请求体]", request) - - # 创建 Future 等待响应 - future: asyncio.Future[dict[str, Any]] = asyncio.Future() - self._pending_responses[echo] = future - - start_time = time.perf_counter() - - try: - await self.ws.send(json.dumps(request)) - # 等待响应,超时 8 分钟 - response = await asyncio.wait_for(future, timeout=480.0) - duration = time.perf_counter() - start_time - - # 检查响应状态 - status = response.get("status") - if status == "failed": - retcode = response.get("retcode", -1) - msg = response.get("message", "未知错误") - if suppress_error_retcodes and retcode in suppress_error_retcodes: - logger.warning( - f"[bold yellow][API预期失败][/bold yellow] [green]{action}[/green] (ID=[magenta]{echo}[/magenta]) | 耗时=[magenta]{duration:.2f}s[/magenta] | retcode=[yellow]{retcode}[/yellow] | message={msg}" - ) - else: - logger.error( - f"[bold red][API失败][/bold red] [green]{action}[/green] (ID=[magenta]{echo}[/magenta]) | 耗时=[magenta]{duration:.2f}s[/magenta] | retcode=[red]{retcode}[/red] | message={msg}" - ) - raise RuntimeError(f"API 调用失败: {msg} (retcode={retcode})") - - logger.info( - f"[bold green][API成功][/bold green] [green]{action}[/green] (ID=[magenta]{echo}[/magenta]) | 耗时=[magenta]{duration:.2f}s[/magenta]" - ) - if logger.isEnabledFor(logging.DEBUG): - log_debug_json(logger, "[OneBot响应体]", response) - return response - except asyncio.TimeoutError: - duration = time.perf_counter() - start_time - logger.error(f"[API超时] {action} (ID={echo}) | 耗时={duration:.2f}s") - raise - finally: - self._pending_responses.pop(echo, None) - - async def send_group_message( - self, - group_id: int, - message: str | list[dict[str, Any]], - *, - mark_sent: bool = True, - ) -> dict[str, Any]: - """发送群消息""" - result = await self._call_api( - "send_group_msg", - { - "group_id": group_id, - "message": message, - }, - ) - if mark_sent: - _mark_message_sent_this_turn() - return result - - async def send_private_message( - self, - user_id: int, - message: str | list[dict[str, Any]], - *, - group_id: int | None = None, - mark_sent: bool = True, - ) -> dict[str, Any]: - """发送私聊消息 - - 参数: - user_id: 用户 QQ 号 - message: 消息内容 - group_id: 共享群号;传入时通过该群的临时会话发送 - mark_sent: 是否标记本轮已发送(用于 end 工具判定) - """ - params: dict[str, Any] = { - "user_id": user_id, - "message": message, - } - if group_id is not None: - params["group_id"] = group_id - - result = await self._call_api( - "send_private_msg", - params, - ) - if mark_sent: - _mark_message_sent_this_turn() - return result - - async def get_group_msg_history( - self, - group_id: int, - message_seq: int | None = None, - count: int = 500, - ) -> list[dict[str, Any]]: - """获取群消息历史 - - 参数: - group_id: 群号 - message_seq: 起始消息序号,None 表示从最新消息开始 - count: 获取的消息数量 - - 返回: - 消息列表 - """ - params: dict[str, Any] = { - "group_id": group_id, - "count": count, - } - if message_seq is not None: - params["message_seq"] = message_seq - - result = await self._call_api("get_group_msg_history", params) - - # 安全获取消息列表 - if result is None: - logger.warning("get_group_msg_history 返回 None") - return [] - - data = result.get("data") - if data is None: - logger.warning(f"get_group_msg_history 响应无 data 字段: {result}") - return [] - - messages: list[dict[str, Any]] = data.get("messages", []) - logger.debug(f"获取到 {len(messages)} 条历史消息") - return messages - - async def get_image(self, file: str) -> str: - """获取图片信息 - - 参数: - file: 图片文件名或 URL - - 返回: - 图片的本地路径或 URL - """ - result = await self._call_api("get_image", {"file": file}) - data: dict[str, str] = result.get("data", {}) - url: str = data.get("url", "") or data.get("file", "") - return url - - async def get_group_info(self, group_id: int) -> dict[str, Any] | None: - """获取群信息 - - 参数: - group_id: 群号 - - 返回: - 群信息字典,包含 group_name 等字段 - """ - try: - result = await self._call_api("get_group_info", {"group_id": group_id}) - data: dict[str, Any] = result.get("data", {}) - return data - except Exception as e: - logger.error(f"获取群信息失败: {e}") - return None - - async def get_stranger_info(self, user_id: int) -> dict[str, Any] | None: - """获取陌生人信息 - - 参数: - user_id: 用户QQ号 - - 返回: - 用户信息字典,包含 nickname 等字段 - """ - try: - result = await self._call_api("get_stranger_info", {"user_id": user_id}) - data: dict[str, Any] = result.get("data", {}) - return data - except Exception as e: - logger.error(f"获取陌生人信息失败: {e}") - return None - - async def get_group_member_info( - self, group_id: int, user_id: int, no_cache: bool = False - ) -> dict[str, Any] | None: - """获取群成员信息 - - 参数: - group_id: 群号 - user_id: 群成员QQ号 - no_cache: 是否不使用缓存(默认 false) - - 返回: - 群成员信息字典,包含群昵称、QQ昵称、加群时间、等级、最后发言时间等字段 - """ - try: - result = await self._call_api( - "get_group_member_info", - {"group_id": group_id, "user_id": user_id, "no_cache": no_cache}, - ) - data: dict[str, Any] = result.get("data", {}) - return data - except Exception as e: - logger.error(f"获取群成员信息失败: {e}") - return None - - async def get_group_member_list(self, group_id: int) -> list[dict[str, Any]]: - """获取群成员列表 - - 参数: - group_id: 群号 - - 返回: - 群成员信息列表 - """ - try: - result = await self._call_api( - "get_group_member_list", {"group_id": group_id} - ) - data: list[dict[str, Any]] = result.get("data", []) - return data - except Exception as e: - logger.error(f"获取群成员列表失败: {e}") - return [] - - async def get_friend_list(self) -> list[dict[str, Any]]: - """获取好友列表 - - 返回: - 好友信息列表,每个好友包含: - - user_id: QQ号 - - nickname: QQ昵称 - - remark: 备注名 - """ - try: - result = await self._call_api("get_friend_list") - data: list[dict[str, Any]] = result.get("data", []) - return data - except Exception as e: - logger.error(f"获取好友列表失败: {e}") - return [] - - async def get_group_list(self) -> list[dict[str, Any]]: - """获取群列表 - - 返回: - 群信息列表,每个群包含: - - group_id: 群号 - - group_name: 群名称 - - member_count: 成员数 - - max_member_count: 最大成员数 - """ - try: - result = await self._call_api("get_group_list") - data: list[dict[str, Any]] = result.get("data", []) - return data - except Exception as e: - logger.error(f"获取群列表失败: {e}") - return [] - - async def get_forward_msg(self, id: str) -> list[dict[str, Any]]: - """获取合并转发消息详情 - - 参数: - id: 合并转发 ID - - 返回: - 消息节点列表 - """ - try: - result = await self._call_api( - "get_forward_msg", - {"message_id": id}, - suppress_error_retcodes={1200}, - ) - data = result.get("data", {}) - # data 可能是字典(包含 messages)或列表(直接是 nodes) - if isinstance(data, dict): - messages: list[dict[str, Any]] = data.get("messages", []) - return messages - elif isinstance(data, list): - nodes: list[dict[str, Any]] = data - return nodes - return [] - except Exception as e: - error_text = str(e) - if "retcode=1200" in error_text: - logger.debug( - "合并转发消息不可获取(可能过期或内层): id=%s err=%s", id, e - ) - return [] - logger.error(f"获取合并转发消息失败: {e}") - return [] - - async def get_msg(self, message_id: int) -> dict[str, Any] | None: - """获取单条消息详情 - - 参数: - message_id: 消息 ID - - 返回: - 消息详情字典 - """ - try: - result = await self._call_api("get_msg", {"message_id": message_id}) - return result.get("data") - except Exception as e: - logger.error(f"获取消息详情失败: {e}") - return None - - async def send_forward_msg( - self, group_id: int, messages: list[dict[str, Any]] - ) -> dict[str, Any]: - """发送合并转发消息到群聊 - - 参数: - group_id: 群号 - messages: 消息节点列表,每个节点格式为: - { - "type": "node", - "data": { - "name": "发送者昵称", - "uin": "发送者QQ号", - "content": "消息内容(字符串或消息段数组)", - "time": "时间戳(可选)" - } - } - - 返回: - API 响应 - """ - return await self._call_api( - "send_forward_msg", {"group_id": group_id, "messages": messages} - ) - - async def send_private_forward_msg( - self, user_id: int, messages: list[dict[str, Any]] - ) -> dict[str, Any]: - """发送合并转发消息到私聊。""" - return await self._call_api( - "send_private_forward_msg", - {"user_id": user_id, "messages": messages}, - ) - - async def send_like(self, user_id: int, times: int = 1) -> dict[str, Any]: - """给用户点赞 - - 参数: - user_id: 对方 QQ 号 - times: 赞的次数(默认1次) - - 返回: - API 响应 - """ - return await self._call_api("send_like", {"user_id": user_id, "times": times}) - - async def fetch_emoji_like(self, message_id: int) -> dict[str, Any] | list[Any]: - """获取消息已设置的表情反应信息(扩展接口)。 - - 参数: - message_id: 消息 ID - - 返回: - data 字段内容(字典或列表),异常时抛出 RuntimeError - """ - result = await self._call_api("fetch_emoji_like", {"message_id": message_id}) - data = result.get("data") - if isinstance(data, (dict, list)): - return data - return {} - - async def set_msg_emoji_like( - self, - message_id: int, - emoji_id: int, - *, - set_like: bool = True, - mark_sent: bool = True, - ) -> dict[str, Any]: - """给指定消息添加/取消表情反应(扩展接口)。 - - 参数: - message_id: 目标消息 ID - emoji_id: 表情 ID - set_like: True=添加反应,False=取消反应 - mark_sent: 是否标记本轮已发送(用于 end 工具判定) - - 返回: - API 响应 - """ - if set_like: - try: - result = await self._call_api( - "set_msg_emoji_like", - {"message_id": message_id, "emoji_id": emoji_id}, - ) - except RuntimeError: - logger.warning( - "[消息表情] set_msg_emoji_like 默认参数失败,尝试 set=true 回退: msg=%s emoji=%s", - message_id, - emoji_id, - ) - result = await self._call_api( - "set_msg_emoji_like", - {"message_id": message_id, "emoji_id": emoji_id, "set": True}, - ) - else: - # 取消反应可能依赖实现方扩展参数,默认采用 set=false。 - result = await self._call_api( - "set_msg_emoji_like", - {"message_id": message_id, "emoji_id": emoji_id, "set": False}, - ) - - if mark_sent: - _mark_message_sent_this_turn() - return result - - async def send_group_poke( - self, - group_id: int, - user_id: int, - *, - mark_sent: bool = True, - ) -> dict[str, Any]: - """在群聊中拍一拍指定成员。 - - 参数: - group_id: 群号 - user_id: 被拍一拍的用户 QQ 号 - mark_sent: 是否标记本轮已发送(用于 end 工具判定) - - 返回: - API 响应 - """ - try: - result = await self._call_api( - "group_poke", {"group_id": group_id, "user_id": user_id} - ) - except RuntimeError: - logger.warning( - "[拍一拍] group_poke 失败,尝试 send_poke 回退: group=%s user=%s", - group_id, - user_id, - ) - result = await self._call_api( - "send_poke", - { - "group_id": group_id, - "user_id": user_id, - "target_id": user_id, - }, - ) - - if mark_sent: - _mark_message_sent_this_turn() - return result - - async def send_private_poke( - self, - user_id: int, - *, - mark_sent: bool = True, - ) -> dict[str, Any]: - """在私聊中拍一拍指定用户。 - - 参数: - user_id: 被拍一拍的用户 QQ 号 - mark_sent: 是否标记本轮已发送(用于 end 工具判定) - - 返回: - API 响应 - """ - try: - result = await self._call_api("friend_poke", {"user_id": user_id}) - except RuntimeError: - logger.warning( - "[拍一拍] friend_poke 失败,尝试 send_poke 回退: user=%s", - user_id, - ) - result = await self._call_api( - "send_poke", - { - "user_id": user_id, - "target_id": user_id, - }, - ) - - if mark_sent: - _mark_message_sent_this_turn() - return result - - async def upload_group_file( - self, - group_id: int, - file_path: str, - name: str | None = None, - ) -> dict[str, Any]: - """上传文件到群聊 - - 参数: - group_id: 群号 - file_path: 本地文件绝对路径 - name: 文件名(可选,默认使用原文件名) - """ - from pathlib import Path as _Path - - file_name = name or _Path(file_path).name - file_uri = _Path(file_path).resolve().as_uri() - try: - return await self._call_api( - "upload_group_file", - { - "group_id": group_id, - "file": file_uri, - "name": file_name, - }, - ) - except RuntimeError: - # 回退:尝试用文件消息段发送 - logger.warning( - "[文件上传] upload_group_file 失败,尝试文件消息段回退: group=%s", - group_id, - ) - return await self.send_group_message( - group_id, - [ - { - "type": "file", - "data": {"file": file_uri, "name": file_name}, - } - ], - ) - - async def upload_private_file( - self, - user_id: int, - file_path: str, - name: str | None = None, - ) -> dict[str, Any]: - """上传文件到私聊 - - 参数: - user_id: 用户 QQ 号 - file_path: 本地文件绝对路径 - name: 文件名(可选,默认使用原文件名) - """ - from pathlib import Path as _Path - - file_name = name or _Path(file_path).name - file_uri = _Path(file_path).resolve().as_uri() - try: - return await self._call_api( - "upload_private_file", - { - "user_id": user_id, - "file": file_uri, - "name": file_name, - }, - ) - except RuntimeError: - logger.warning( - "[文件上传] upload_private_file 失败,尝试文件消息段回退: user=%s", - user_id, - ) - return await self.send_private_message( - user_id, - [ - { - "type": "file", - "data": {"file": file_uri, "name": file_name}, - } - ], - ) - - async def send_group_sign(self, group_id: int) -> dict[str, Any]: - """执行群打卡 - - 参数: - group_id: 群号 - - 返回: - API 响应 - """ - return await self._call_api("send_group_sign", {"group_id": group_id}) - - async def _get_group_notices(self, group_id: int) -> list[dict[str, Any]]: - """获取群公告列表(非标准 API,依赖具体实现) - - 参数: - group_id: 群号 - - 返回: - 公告列表 - """ - try: - result = await self._call_api("_get_group_notice", {"group_id": group_id}) - data = result.get("data") - if isinstance(data, list): - return data - elif isinstance(data, dict): - # 尝试获取常见的列表字段 - notices = data.get("notices") - if notices is None: - notices = data.get("list") - if isinstance(notices, list): - return notices - return [] - except Exception as e: - logger.error(f"获取群公告失败: {e}") - return [] - - async def run(self) -> None: - """运行消息接收循环""" - if not self.ws: - raise RuntimeError("WebSocket 未连接") - - self._running = True - self._tasks: set[asyncio.Task[None]] = set() - logger.info("[WebSocket] 消息接收循环已启动") - - try: - while self._running: - raw_message = "" - try: - message_data = await self.ws.recv() - raw_message = ( - message_data.decode("utf-8") - if isinstance(message_data, bytes) - else message_data - ) - data = json.loads(raw_message) - # 处理消息(不阻塞接收循环) - await self._dispatch_message(data) - except json.JSONDecodeError as e: - logger.error( - f"[WebSocket] 无法解析 JSON 消息: {raw_message!r}, 错误: {e}" - ) - except websockets.ConnectionClosed: - logger.warning("[WebSocket] 连接已关闭,接收循环结束") - break - except Exception as e: - logger.exception(f"[WebSocket] 接收消息时发生异常: {e}") - finally: - self._running = False - # 等待所有后台任务完成 - if self._tasks: - logger.debug( - f"[WebSocket] 正在等待 {len(self._tasks)} 个异步任务完成..." - ) - await asyncio.gather(*self._tasks, return_exceptions=True) - logger.info("[WebSocket] 接收循环已停止") - - async def _dispatch_message(self, data: dict[str, Any]) -> None: - """分发消息(API响应同步处理,事件异步处理)""" - if logger.isEnabledFor(logging.DEBUG): - log_debug_json(logger, "[WebSocket消息]", data) - # 检查是否是 API 响应(需要立即处理) - echo = data.get("echo") - if echo is not None: - echo_str = str(echo) - if echo_str in self._pending_responses: - logger.debug(f"收到 API 响应: echo={echo_str}") - self._pending_responses[echo_str].set_result(data) - return - else: - logger.debug( - f"收到未知 echo 响应: {echo_str}, 待处理: {list(self._pending_responses.keys())}" - ) - return - - # 事件类型的消息异步处理,不阻塞接收循环 - post_type = data.get("post_type") - if post_type == "message": - msg_type = data.get("message_type", "unknown") - sender = data.get("sender", {}).get("user_id", "unknown") - logger.info( - f"[bold blue][收到消息][/bold blue] type=[yellow]{msg_type}[/yellow], sender=[blue]{sender}[/blue]" - ) - if self._message_handler: - # 创建后台任务处理消息 - task = asyncio.create_task(self._safe_handle_message(data)) - self._tasks.add(task) - task.add_done_callback(self._tasks.discard) - elif post_type == "notice": - notice_type = data.get("notice_type", "") - sub_type = data.get("sub_type", "") - # 处理拍一拍事件 - if notice_type == "notify" and sub_type == "poke": - target_id = data.get("target_id", 0) - sender_id = data.get("user_id", 0) - group_id = data.get("group_id", 0) - logger.info( - f"[bold magenta][收到拍一拍][/bold magenta] sender=[blue]{sender_id}[/blue], target=[blue]{target_id}[/blue], group=[blue]{group_id}[/blue]" - ) - if self._message_handler: - # 将 poke 事件转换为类似消息的格式,方便 handler 处理 - poke_event = { - "post_type": "notice", - "notice_type": "poke", - "group_id": group_id, - "user_id": sender_id, - "sender": {"user_id": sender_id}, - "target_id": target_id, - "message": [], # 空消息 - } - task = asyncio.create_task(self._safe_handle_message(poke_event)) - self._tasks.add(task) - task.add_done_callback(self._tasks.discard) - else: - logger.debug( - f"收到通知事件: notice_type={notice_type}, sub_type={sub_type}" - ) - elif post_type: - logger.debug( - f"收到事件: post_type={post_type}, meta={data.get('meta_event_type', '')}" - ) - - async def _safe_handle_message(self, data: dict[str, Any]) -> None: - """安全地处理消息(捕获异常)""" - try: - if self._message_handler: - await self._message_handler(data) - except Exception as e: - logger.exception(f"处理消息时出错: {e}") - - async def run_with_reconnect(self, reconnect_interval: float = 5.0) -> None: - """带自动重连的运行""" - self._should_stop = False - reconnect_count = 0 - - while not self._should_stop: - try: - if reconnect_count > 0: - logger.info(f"[WebSocket] 正在尝试第 {reconnect_count} 次重连...") - await self.connect() - reconnect_count = 0 # 连接成功重置计数 - await self.run() - except websockets.ConnectionClosed as e: - logger.warning(f"[WebSocket] 连接已断开: {e}") - except Exception as e: - logger.error(f"[WebSocket] 发生错误: {e}") - - if self._should_stop: - break - - reconnect_count += 1 - logger.info(f"{reconnect_interval} 秒后尝试重连...") - await asyncio.sleep(reconnect_interval) - - def stop(self) -> None: - """停止运行""" - self._should_stop = True - self._running = False - - -def parse_message_time(message: dict[str, Any]) -> datetime: - """解析消息时间。 - - 兼容秒级/毫秒级时间戳与字符串输入,异常时回退到当前时间。 - """ - - raw_timestamp = message.get("time") - - if raw_timestamp is None: - return datetime.now() - - try: - timestamp = float(raw_timestamp) - except (TypeError, ValueError): - logger.debug("[OneBot] 无法解析消息时间戳,使用当前时间: %s", raw_timestamp) - return datetime.now() - - # 13 位毫秒时间戳自动降为秒。 - if timestamp > 1_000_000_000_000: - timestamp /= 1000.0 - - if timestamp <= 0: - return datetime.now() - - try: - return datetime.fromtimestamp(timestamp) - except (OSError, OverflowError, ValueError): - logger.debug("[OneBot] 时间戳越界,使用当前时间: %s", raw_timestamp) - return datetime.now() - - -def get_message_sender_id(message: dict[str, Any]) -> int: - """获取消息发送者 QQ 号""" - sender: dict[str, Any] = message.get("sender", {}) - user_id: int = sender.get("user_id", 0) - return user_id - - -def get_message_content(message: dict[str, Any]) -> list[dict[str, Any]]: - """获取消息内容(CQ 码数组格式)""" - msg = message.get("message", []) - if isinstance(msg, str): - # 如果是字符串格式,转换为数组格式 - return [{"type": "text", "data": {"text": msg}}] - content: list[dict[str, Any]] = msg - return content diff --git a/src/Undefined/services/ai_coordinator.py b/src/Undefined/services/ai_coordinator.py index 70f46b66..7bc32036 100644 --- a/src/Undefined/services/ai_coordinator.py +++ b/src/Undefined/services/ai_coordinator.py @@ -105,7 +105,7 @@ def __init__( self.security = security self.command_dispatcher = command_dispatcher self.model_pool = ModelPoolService(ai, config, sender) - # batcher 由外部(handlers.py)创建并通过 set_batcher 注入;未注入时所有消息按单条流程直送。 + # batcher 由外部(handlers/message_flow)创建并通过 set_batcher 注入;未注入时所有消息按单条流程直送。 self._batcher: MessageBatcher | None = None def set_batcher(self, batcher: MessageBatcher | None) -> None: diff --git a/src/Undefined/services/coordinator/__init__.py b/src/Undefined/services/coordinator/__init__.py index 07c8fd50..ee91e712 100644 --- a/src/Undefined/services/coordinator/__init__.py +++ b/src/Undefined/services/coordinator/__init__.py @@ -53,7 +53,7 @@ def __init__( self.security = security self.command_dispatcher = command_dispatcher self.model_pool = ModelPoolService(ai, config, sender) - # batcher 由外部(handlers.py)创建并通过 set_batcher 注入;未注入时所有消息按单条流程直送。 + # batcher 由外部(handlers/message_flow)创建并通过 set_batcher 注入;未注入时所有消息按单条流程直送。 self._batcher: MessageBatcher | None = None def set_batcher(self, batcher: MessageBatcher | None) -> None: diff --git a/src/Undefined/services/message_batcher.py b/src/Undefined/services/message_batcher.py deleted file mode 100644 index 4ad94e11..00000000 --- a/src/Undefined/services/message_batcher.py +++ /dev/null @@ -1,810 +0,0 @@ -"""同 sender 短时多消息合并器(MessageBatcher)。 - -核心目标:把同一个 sender 在短时间内连续发出的消息合并到同一轮 AI 调用, -让模型一次看到全部 ```` 块自行决定 "独立请求 / 修正 / 打断", -避免 N 条独立 LLM 调用造成的重复回复或行为打架。 - -时序:每个 (scope, sender_id) 桶内有两条独立的"静默计时器": - -- ``T1 = window_seconds`` —— "打字静默阈值"。静默达到 T1 视为用户写完, - 这一批 batch 结束。 -- ``T2 = pre_send_seconds`` —— "投机预发送阈值",要求严格小于 T1。 - 静默到 T2 时**先把当前 batch 提前发给 LLM 抢时间**(speculative pre-fire), - 但 batch 尚未结束;T1 才决定结束。 - -新消息到来: - -- 若桶处于 ``TYPING``(尚未 pre-fire):append 后重置 T1/T2。 -- 若桶处于 ``SPECULATING``(已 pre-fire,请求已入队或 inflight 在跑): - - 检查 inflight 是否已经 "向用户发出过任何消息" - (来自 ``RequestContext.get_resource("message_sent_this_turn")``)。 - - inflight 尚未发消息 → 调 ``inflight_task.cancel()``,桶回到 TYPING; - 新消息照常 append 到原有 items 后面,T1/T2 重置。 - - inflight 已经发过消息且 ``allow_cancel_after_send=False``(默认安全)→ - 保留旧 batch 让其自然走完,新消息开新 batch(即清空当前桶后立即重新作为首条入桶)。 - - inflight 已经发过消息但开关 = True → 仍 cancel(可能造成重复发送,仅极端场景)。 - -兼容回退:当 ``pre_send_seconds <= 0`` 或 ``>= window_seconds`` 时投机模式关闭, -退化为旧版 "T1 静默到期才发车" 的行为。 -""" - -from __future__ import annotations - -import asyncio -import enum -import logging -import time -from dataclasses import dataclass, field -from typing import Any, Awaitable, Callable - -from Undefined.config.models import MessageBatcherConfig -from Undefined.utils.coerce import was_message_sent - -logger = logging.getLogger(__name__) - - -@dataclass -class BatchDispatchToken: - """一次 batch 发车的身份令牌,用于取消已入队但尚未执行的投机请求。""" - - scope: str - sender_id: int - batch_id: int - speculative: bool = False - cancelled: bool = False - - def cancel(self) -> None: - self.cancelled = True - - -@dataclass -class BufferedMessage: - """缓冲中的单条消息上下文。""" - - scope: str - sender_id: int - text: str - message_content: list[dict[str, Any]] - attachments: list[dict[str, str]] - sender_name: str - arrival_time: float - is_private: bool - trigger_message_id: int | None = None - is_poke: bool = False - is_at_bot: bool = False - is_fake_at: bool = False - # 群聊扩展字段 - group_id: int | None = None - group_name: str = "" - sender_role: str = "member" - sender_title: str = "" - sender_level: str = "" - batch_token: BatchDispatchToken | None = None - - -FlushCallback = Callable[[list[BufferedMessage]], Awaitable[None]] -"""``flush_callback(items)``:batcher 决定 fire 时调用,调用方负责拼装 prompt 并入队执行。 - -调用约定: -- batcher 的 ``flush_callback`` **不应** 立即 await LLM 的完成, - 而是把请求扔进 QueueManager 后立即返回,真正的 LLM 任务由 coordinator 在 ``execute_reply`` - 开头调用 :meth:`MessageBatcher.register_inflight` 上报。 -- 若需要 batcher 关停时也等待 in-flight 收尾,由 :meth:`MessageBatcher.flush_all` 处理。 -""" - - -class BatchPhase(enum.Enum): - """桶状态机。""" - - TYPING = "typing" # 等待 T1/T2 静默 - SPECULATING = "speculating" # T2 已触发,请求已入队或 inflight 在跑;T1 仍未到 - FINALIZING = "finalizing" # T1 已到,等 inflight(若有)自然结束 - - -@dataclass -class _InflightInfo: - """inflight LLM 任务关联信息,由 coordinator 通过 ``register_inflight`` 上报。""" - - task: asyncio.Task[Any] - # ``RequestContext`` 引用,用于判断 ``message_sent_this_turn`` 资源 - request_context: Any = None - - -@dataclass -class _BatchState: - """单个 (scope, sender_id) 桶的状态。""" - - phase: BatchPhase = BatchPhase.TYPING - items: list[BufferedMessage] = field(default_factory=list) - first_arrival_monotonic: float = 0.0 - # T1 = window_seconds 静默 timer(决定 batch 结束) - t1_handle: asyncio.TimerHandle | None = None - # T2 = pre_send_seconds 静默 timer(决定 pre-fire);投机关闭时为 None - t2_handle: asyncio.TimerHandle | None = None - # SPECULATING 阶段记录 inflight LLM 任务(由 coordinator 通过 register_inflight 注入) - inflight: _InflightInfo | None = None - # T2 fire 时由 batcher 创建的 flush task;inflight 还未上报前用于兜底取消 - speculative_flush_task: asyncio.Task[Any] | None = None - # 当前 batch 的身份令牌;T2 入队后若又来新消息,可将旧 token 标记取消, - # coordinator 在真正执行前会跳过它。 - dispatch_token: BatchDispatchToken | None = None - - -def make_scope(*, group_id: int | None = None, user_id: int | None = None) -> str: - """构造合并 key 的 scope 字符串。""" - if group_id and group_id > 0: - return f"group:{group_id}" - if user_id is not None: - return f"private:{user_id}" - return "unknown" - - -class MessageBatcher: - """同 sender 短时合并器(含 T2 投机预发送)。""" - - def __init__( - self, - config: MessageBatcherConfig, - flush_callback: FlushCallback, - ) -> None: - self._config = config - self._flush_callback = flush_callback - self._buckets: dict[tuple[str, int], _BatchState] = {} - self._flush_failure_counts: dict[tuple[str, int], int] = {} - self._lock = asyncio.Lock() - # 持有 timer 触发后创建的 flush task 强引用,避免被 GC(asyncio 文档要求) - self._pending_tasks: set[asyncio.Task[Any]] = set() - self._next_batch_id = 0 - self._shutdown = False - - # ------------------------------------------------------------------ public - - def update_config(self, config: MessageBatcherConfig) -> None: - """配置热更新。""" - self._config = config - logger.info( - "[MessageBatcher] 配置已更新: enabled=%s window=%.2fs pre_send=%.2fs " - "strategy=%s max_window=%.2fs max_messages=%s group=%s private=%s " - "allow_cancel_after_send=%s", - config.enabled, - config.window_seconds, - config.pre_send_seconds, - config.strategy, - config.max_window_seconds, - config.max_messages_per_batch, - config.group_enabled, - config.private_enabled, - config.allow_cancel_after_send, - ) - - @property - def config(self) -> MessageBatcherConfig: - return self._config - - def is_enabled_for(self, *, is_group: bool) -> bool: - cfg = self._config - if not cfg.enabled or cfg.window_seconds <= 0: - return False - return cfg.group_enabled if is_group else cfg.private_enabled - - def has_buffer(self, scope: str, sender_id: int) -> bool: - return (scope, sender_id) in self._buckets - - async def flush_sender(self, scope: str, sender_id: int) -> bool: - return await self._handle_t1((scope, sender_id), raise_on_failure=False) - - @property - def speculative_enabled(self) -> bool: - cfg = self._config - return 0 < cfg.pre_send_seconds < cfg.window_seconds - - async def submit(self, item: BufferedMessage) -> None: - """提交一条消息进入合并桶。 - - 新消息到来时的处理依赖当前桶 ``phase``,详见模块 docstring。 - """ - cfg = self._config - key = (item.scope, item.sender_id) - # 异步路径里只在锁内修改桶;invoke callback 在锁外执行 - immediate_fire_items: list[BufferedMessage] | None = None - - async with self._lock: - if self._shutdown: - logger.info( - "[MessageBatcher] 已进入关停模式,新消息立即发车: scope=%s sender=%s", - item.scope, - item.sender_id, - ) - immediate_fire_items = [item] - else: - now_mono = time.monotonic() - state = self._buckets.get(key) - - # === 阶段 1: 决定本条消息怎么进桶 === - if state is None: - # 全新桶 - state = _BatchState( - phase=BatchPhase.TYPING, - first_arrival_monotonic=now_mono, - dispatch_token=self._new_token(item.scope, item.sender_id), - ) - self._buckets[key] = state - state.items.append(item) - elif state.phase is BatchPhase.SPECULATING: - # 已 pre-fire,决定是否 cancel inflight - inflight = state.inflight - already_sent = ( - was_message_sent(inflight.request_context) - if inflight is not None - else False - ) - allow_cancel = (not already_sent) or cfg.allow_cancel_after_send - - if inflight is not None and allow_cancel: - logger.info( - "[MessageBatcher] 投机调用被新消息抢占取消: scope=%s sender=%s " - "already_sent=%s allow_cancel_after_send=%s", - item.scope, - item.sender_id, - already_sent, - cfg.allow_cancel_after_send, - ) - if state.dispatch_token is not None: - state.dispatch_token.cancel() - inflight.task.cancel() - state.inflight = None - state.phase = BatchPhase.TYPING - # 新消息追加到现有 items 后面 - state.items.append(item) - self._retokenize_locked(state, item.scope, item.sender_id) - elif inflight is None: - # inflight 尚未注册(coordinator 还没进入 execute_reply): - # 1) 若 flush task 仍在跑,先 cancel; - # 2) 若它已经把请求入队,则取消旧 token,execute_reply 入口会跳过旧请求。 - logger.info( - "[MessageBatcher] inflight 未注册,取消投机 token/flush task: " - "scope=%s sender=%s", - item.scope, - item.sender_id, - ) - if state.dispatch_token is not None: - state.dispatch_token.cancel() - if state.speculative_flush_task is not None: - state.speculative_flush_task.cancel() - state.speculative_flush_task = None - state.phase = BatchPhase.TYPING - state.items.append(item) - self._retokenize_locked(state, item.scope, item.sender_id) - else: - # 已发过消息且不允许取消:丢弃当前桶,新消息开新桶 - logger.info( - "[MessageBatcher] 投机调用已发出消息且不允许取消,新消息开新 batch: " - "scope=%s sender=%s", - item.scope, - item.sender_id, - ) - self._cancel_t1(state) - self._cancel_t2(state) - state.phase = BatchPhase.FINALIZING - # 旧桶让 inflight 自然结束;从 _buckets pop 以释放 key 给新 batch - self._buckets.pop(key, None) - # 新桶 - state = _BatchState( - phase=BatchPhase.TYPING, - first_arrival_monotonic=now_mono, - dispatch_token=self._new_token(item.scope, item.sender_id), - ) - self._buckets[key] = state - state.items.append(item) - elif state.phase is BatchPhase.FINALIZING: - # 极少见:T1 已到、inflight 未上报但 task 已不可控;当作新桶处理 - logger.warning( - "[MessageBatcher] 桶处于 FINALIZING 期间收到新消息,开新 batch: " - "scope=%s sender=%s", - item.scope, - item.sender_id, - ) - self._buckets.pop(key, None) - state = _BatchState( - phase=BatchPhase.TYPING, - first_arrival_monotonic=now_mono, - dispatch_token=self._new_token(item.scope, item.sender_id), - ) - self._buckets[key] = state - state.items.append(item) - else: # TYPING:直接 append - state.items.append(item) - - self._bind_items_to_token_locked(state) - - # === 阶段 2: 重置 T1/T2 timer === - self._cancel_t1(state) - self._cancel_t2(state) - - elapsed = now_mono - state.first_arrival_monotonic - unlimited_window = cfg.max_window_seconds <= 0 - remaining_max = ( - float("inf") - if unlimited_window - else cfg.max_window_seconds - elapsed - ) - - # 硬顶:max_messages_per_batch 立即发车(结束 batch) - if ( - cfg.max_messages_per_batch > 0 - and len(state.items) >= cfg.max_messages_per_batch - ): - logger.info( - "[MessageBatcher] 达到 max_messages_per_batch=%s 立即发车: " - "scope=%s sender=%s", - cfg.max_messages_per_batch, - item.scope, - item.sender_id, - ) - immediate_fire_items = self._pop_locked(key) - elif not unlimited_window and remaining_max <= 0: - logger.info( - "[MessageBatcher] 已超 max_window_seconds 硬顶 立即发车: " - "scope=%s sender=%s elapsed=%.2fs", - item.scope, - item.sender_id, - elapsed, - ) - immediate_fire_items = self._pop_locked(key) - else: - # T1 delay - if cfg.strategy == "fixed": - target = state.first_arrival_monotonic + cfg.window_seconds - t1_delay = max(0.0, target - now_mono) - else: # extend - t1_delay = cfg.window_seconds - if not unlimited_window: - t1_delay = min(t1_delay, remaining_max) - - loop = asyncio.get_running_loop() - state.t1_handle = loop.call_later( - max(0.0, t1_delay), self._on_t1_timer, key - ) - - # T2 delay(仅当投机启用,且本桶尚未 pre-fire 时设置) - if ( - self.speculative_enabled - and state.phase is BatchPhase.TYPING - and cfg.pre_send_seconds < t1_delay - ): - t2_delay = cfg.pre_send_seconds - state.t2_handle = loop.call_later( - max(0.0, t2_delay), self._on_t2_timer, key - ) - logger.debug( - "[MessageBatcher] 缓冲: scope=%s sender=%s count=%s " - "t1=%.2fs t2=%.2fs strategy=%s", - item.scope, - item.sender_id, - len(state.items), - t1_delay, - t2_delay, - cfg.strategy, - ) - else: - logger.debug( - "[MessageBatcher] 缓冲: scope=%s sender=%s count=%s " - "t1=%.2fs strategy=%s phase=%s", - item.scope, - item.sender_id, - len(state.items), - t1_delay, - cfg.strategy, - state.phase.value, - ) - - # 锁外执行 callback - if immediate_fire_items is not None: - success = await self._invoke_callback(immediate_fire_items) - if success: - self._flush_failure_counts.pop(key, None) - else: - await self._restore_items_after_failed_flush( - key, immediate_fire_items, schedule_retry=True - ) - - # ----------------------------------------------------------- inflight API - - def register_inflight( - self, - scope: str, - sender_id: int, - task: asyncio.Task[Any], - request_context: Any = None, - ) -> None: - """coordinator 在 ``execute_reply`` 开头上报 inflight LLM 任务。 - - 如果桶不存在或 phase 不是 SPECULATING,则忽略(说明这次 fire 不是投机的)。 - """ - key = (scope, sender_id) - state = self._buckets.get(key) - if state is None: - return - if state.phase is not BatchPhase.SPECULATING: - return - state.inflight = _InflightInfo(task=task, request_context=request_context) - logger.debug( - "[MessageBatcher] 注册 inflight 任务: scope=%s sender=%s", - scope, - sender_id, - ) - - def unregister_inflight( - self, scope: str, sender_id: int, task: asyncio.Task[Any] - ) -> None: - """coordinator 在 ``execute_reply`` 结束(含异常/取消)时上报。""" - key = (scope, sender_id) - state = self._buckets.get(key) - if state is None: - return - if state.inflight is not None and state.inflight.task is not task: - logger.debug( - "[MessageBatcher] 忽略过期 inflight 注销: scope=%s sender=%s phase=%s", - scope, - sender_id, - state.phase.value, - ) - return - state.inflight = None - # 若 phase 是 SPECULATING 且 T1 已经 fire 过(FINALIZING 才 unregister), - # 此时 inflight 自然结束 → 桶已经在 _on_t1_timer 中弹出,无需再做事 - # 若仍在 SPECULATING(T1 未到):inflight 已结束但仍可能有新消息进来; - # 保持 SPECULATING,新消息会按 SPECULATING 分支处理(已发消息开新 batch / 未发追加) - logger.debug( - "[MessageBatcher] 注销 inflight 任务: scope=%s sender=%s phase=%s", - scope, - sender_id, - state.phase.value, - ) - - # ---------------------------------------------------------------- timers - - def _cancel_t1(self, state: _BatchState) -> None: - if state.t1_handle is not None: - state.t1_handle.cancel() - state.t1_handle = None - - def _cancel_t2(self, state: _BatchState) -> None: - if state.t2_handle is not None: - state.t2_handle.cancel() - state.t2_handle = None - - def _new_token(self, scope: str, sender_id: int) -> BatchDispatchToken: - self._next_batch_id += 1 - return BatchDispatchToken( - scope=scope, - sender_id=sender_id, - batch_id=self._next_batch_id, - ) - - def _retokenize_locked( - self, state: _BatchState, scope: str, sender_id: int - ) -> None: - state.dispatch_token = self._new_token(scope, sender_id) - self._bind_items_to_token_locked(state) - - @staticmethod - def _bind_items_to_token_locked(state: _BatchState) -> None: - if state.dispatch_token is None: - return - for buffered in state.items: - buffered.batch_token = state.dispatch_token - - def _pop_locked(self, key: tuple[str, int]) -> list[BufferedMessage] | None: - state = self._buckets.pop(key, None) - if state is None or not state.items: - return None - self._cancel_t1(state) - self._cancel_t2(state) - return list(state.items) - - def _on_t1_timer(self, key: tuple[str, int]) -> None: - """T1 静默到期:batch 结束。""" - task = asyncio.create_task(self._handle_t1(key)) - self._pending_tasks.add(task) - task.add_done_callback(self._pending_tasks.discard) - - def _on_t2_timer(self, key: tuple[str, int]) -> None: - """T2 静默到期:投机预发送(pre-fire),但 batch 不结束。""" - task = asyncio.create_task(self._handle_t2(key)) - self._pending_tasks.add(task) - task.add_done_callback(self._pending_tasks.discard) - - async def _handle_t1( - self, key: tuple[str, int], *, raise_on_failure: bool = False - ) -> bool: - items_to_fire: list[BufferedMessage] | None = None - wait_inflight: asyncio.Task[Any] | None = None - wait_prefire: asyncio.Task[Any] | None = None - finalizing_state: _BatchState | None = None - async with self._lock: - state = self._buckets.get(key) - if state is None: - return True - self._cancel_t2(state) - if state.phase is BatchPhase.SPECULATING: - # T1 到了,投机请求已经发出/入队;这里只结束 batch,不能再次发车。 - state.phase = BatchPhase.FINALIZING - finalizing_state = state - if state.inflight is not None: - wait_inflight = state.inflight.task - elif ( - state.speculative_flush_task is not None - and not state.speculative_flush_task.done() - ): - wait_prefire = state.speculative_flush_task - else: - self._buckets.pop(key, None) - logger.debug( - "[MessageBatcher] T1 结束已投机 batch,不重复发车: " - "scope=%s sender=%s", - key[0], - key[1], - ) - else: - # 普通模式或 SPECULATING 但 inflight 已结束:直接 fire - items_to_fire = self._pop_locked(key) - if items_to_fire is not None: - state.phase = BatchPhase.FINALIZING - - wait_task: asyncio.Task[Any] | None = wait_inflight or wait_prefire - if wait_task is not None: - try: - await wait_task - except asyncio.CancelledError: - # inflight/prefire 已被 cancel(极少同时发生),让 cancel 路径自然走 - logger.info( - "[MessageBatcher] T1 等待投机任务时被取消: scope=%s sender=%s", - key[0], - key[1], - ) - except Exception: - logger.exception( - "[MessageBatcher] T1 等待投机任务失败: scope=%s sender=%s", - key[0], - key[1], - ) - finally: - # 仅当桶仍是 finalizing_state(同一对象)时才 pop; - # 否则 submit 已经在 SPECULATING/FINALIZING 分支把旧桶 pop 并建立新桶, - # 不能误删新桶。 - async with self._lock: - current = self._buckets.get(key) - if current is finalizing_state: - self._buckets.pop(key, None) - return True - - if items_to_fire is not None: - success = await self._invoke_callback(items_to_fire, speculative=False) - if success: - self._flush_failure_counts.pop(key, None) - else: - await self._restore_items_after_failed_flush( - key, items_to_fire, schedule_retry=not self._shutdown - ) - if raise_on_failure: - raise RuntimeError("message batcher flush callback failed") - return success - return True - - async def _handle_t2(self, key: tuple[str, int]) -> None: - speculative_items: list[BufferedMessage] | None = None - async with self._lock: - state = self._buckets.get(key) - if state is None: - return - if state.phase is not BatchPhase.TYPING: - return - if not state.items: - return - # 切到 SPECULATING,但**不**清空 items(保留以便后续 T1 也能用 / 抢占回收) - state.phase = BatchPhase.SPECULATING - self._cancel_t2(state) - if state.dispatch_token is None: - state.dispatch_token = self._new_token(key[0], key[1]) - self._bind_items_to_token_locked(state) - state.dispatch_token.speculative = True - # 记录"承担投机职责"的当前 task;此处指向 _handle_t2 协程本身 - # (pre-fire 协程),不是 LLM inflight task。 - # 后续 submit() 抢占判定通过 `state.speculative_flush_task is asyncio.current_task()` - # 区分新旧 pre-fire 协程,避免误清理新 batch。 - state.speculative_flush_task = asyncio.current_task() - speculative_items = list(state.items) - logger.info( - "[MessageBatcher] 投机预发送: scope=%s sender=%s count=%s", - key[0], - key[1], - len(speculative_items), - ) - - if speculative_items is not None: - success = False - try: - success = await self._invoke_callback( - speculative_items, speculative=True - ) - finally: - # 清掉自身引用,避免 state 残留指向已结束 task;若投机 callback - # 异常/取消且桶仍是本次 SPECULATING,则回滚为 TYPING,等待 T1 正常重试。 - async with self._lock: - state2 = self._buckets.get(key) - if ( - state2 is not None - and state2.speculative_flush_task is asyncio.current_task() - ): - state2.speculative_flush_task = None - if state2.phase is BatchPhase.SPECULATING and not success: - if state2.dispatch_token is not None: - state2.dispatch_token.cancel() - state2.phase = BatchPhase.TYPING - self._retokenize_locked(state2, key[0], key[1]) - logger.warning( - "[MessageBatcher] 投机预发送失败,回滚等待 T1 重试: " - "scope=%s sender=%s", - key[0], - key[1], - ) - - async def _invoke_callback( - self, - items: list[BufferedMessage], - *, - speculative: bool = False, - ) -> bool: - if not items: - return True - first = items[0] - logger.info( - "[MessageBatcher] 发车: scope=%s sender=%s count=%s speculative=%s", - first.scope, - first.sender_id, - len(items), - speculative, - ) - try: - await self._flush_callback(items) - return True - except asyncio.CancelledError: - # 投机被新消息取消是预期行为 - logger.info( - "[MessageBatcher] flush_callback 被取消(投机抢占): " - "scope=%s sender=%s speculative=%s", - first.scope, - first.sender_id, - speculative, - ) - return False - except Exception: - logger.exception( - "[MessageBatcher] flush_callback 异常: scope=%s sender=%s count=%s", - first.scope, - first.sender_id, - len(items), - ) - return False - - async def _restore_items_after_failed_flush( - self, - key: tuple[str, int], - items: list[BufferedMessage], - *, - schedule_retry: bool, - ) -> None: - """flush callback 失败后回滚到 TYPING 阶段。 - - 重试策略(fail-fast): - - 每次失败累加 ``self._flush_failure_counts[key]``; - - 仅在 ``failure_count <= 1``(即首次失败)时安排一次延后 T1 重试; - - 第二次起仅恢复 batch、等待用户新消息或 ``flush_all`` 触发, - 避免 LLM 端持续故障时形成"无限重试风暴"; - - 桶在成功一次后 ``failure_count`` 会被 pop 清零。 - - ``flush_all`` 路径会 raise,从而暴露持续失败。 - """ - if not items: - return - async with self._lock: - state = self._buckets.get(key) - if state is None: - state = _BatchState( - phase=BatchPhase.TYPING, - first_arrival_monotonic=time.monotonic(), - dispatch_token=self._new_token(key[0], key[1]), - ) - self._buckets[key] = state - state.items = list(items) - else: - self._cancel_t1(state) - self._cancel_t2(state) - state.phase = BatchPhase.TYPING - state.items = list(items) + state.items - state.first_arrival_monotonic = time.monotonic() - state.inflight = None - if state.dispatch_token is not None: - state.dispatch_token.cancel() - self._retokenize_locked(state, key[0], key[1]) - logger.warning( - "[MessageBatcher] flush 失败,已恢复 batch: scope=%s sender=%s count=%s", - key[0], - key[1], - len(state.items), - ) - failure_count = self._flush_failure_counts.get(key, 0) + 1 - self._flush_failure_counts[key] = failure_count - if schedule_retry and not self._shutdown and failure_count <= 1: - loop = asyncio.get_running_loop() - delay = max(0.0, self._config.window_seconds) - state.t1_handle = loop.call_later(delay, self._on_t1_timer, key) - - # ------------------------------------------------------------ shutdown - - async def flush_all(self) -> None: - """立即 flush 所有 buckets(用于关停)。 - - 关停时直接对所有桶执行 T1 等价路径并等 inflight 收尾。 - """ - while True: - async with self._lock: - self._shutdown = True - keys = list(self._buckets.keys()) - if not keys: - break - logger.info("[MessageBatcher] flush_all: pending_buckets=%s", len(keys)) - for key in keys: - await self._handle_t1(key, raise_on_failure=True) - # 等 timer 已触发但回调仍在跑的 task - pending = [t for t in self._pending_tasks if not t.done()] - if pending: - logger.info( - "[MessageBatcher] flush_all: 等待 %s 个 in-flight flush task", - len(pending), - ) - await asyncio.gather(*pending, return_exceptions=True) - - # ------------------------------------------------------------- snapshot - - def snapshot(self) -> dict[str, Any]: - """返回当前 buckets 状态的非阻塞快照(供 Runtime API / WebUI 展示)。""" - cfg = self._config - now_mono = time.monotonic() - buckets: list[dict[str, Any]] = [] - for (scope, sender_id), state in list(self._buckets.items()): - buckets.append( - { - "scope": scope, - "sender_id": sender_id, - "count": len(state.items), - "elapsed_seconds": round( - max(0.0, now_mono - state.first_arrival_monotonic), 2 - ), - "phase": state.phase.value, - "has_inflight": state.inflight is not None, - "has_speculative_dispatch": ( - state.dispatch_token is not None - and state.dispatch_token.speculative - and not state.dispatch_token.cancelled - ), - } - ) - return { - "config": { - "enabled": cfg.enabled, - "window_seconds": cfg.window_seconds, - "pre_send_seconds": cfg.pre_send_seconds, - "speculative_enabled": self.speculative_enabled, - "strategy": cfg.strategy, - "max_window_seconds": cfg.max_window_seconds, - "max_messages_per_batch": cfg.max_messages_per_batch, - "group_enabled": cfg.group_enabled, - "private_enabled": cfg.private_enabled, - "flush_on_command": cfg.flush_on_command, - "allow_cancel_after_send": cfg.allow_cancel_after_send, - "shutdown": self._shutdown, - }, - "pending_buckets": len(buckets), - "buckets": buckets, - } diff --git a/src/Undefined/skills/agents/runner.py b/src/Undefined/skills/agents/runner.py deleted file mode 100644 index b68fd22e..00000000 --- a/src/Undefined/skills/agents/runner.py +++ /dev/null @@ -1,384 +0,0 @@ -from __future__ import annotations - -import asyncio -import logging -from pathlib import Path -from typing import Any - -import aiofiles - -from Undefined.config.models import AgentModelConfig -from Undefined.ai.transports.openai_transport import RESPONSES_OUTPUT_ITEMS_KEY -from Undefined.skills.agents.agent_tool_registry import AgentToolRegistry -from Undefined.skills.anthropic_skills import AnthropicSkillRegistry -from Undefined.ai.tooling import END_CO_CALL_REJECT_CONTENT -from Undefined.utils.tool_calls import parse_tool_arguments - - -async def load_prompt_text(agent_dir: Path, default_prompt: str) -> str: - """从 agent 目录加载 prompt.md,缺失时返回默认提示词。""" - - prompt_path = agent_dir / "prompt.md" - if prompt_path.exists(): - async with aiofiles.open(prompt_path, "r", encoding="utf-8") as file: - return await file.read() - return default_prompt - - -def _filter_tools_for_runtime_config( - agent_name: str, - tools: list[dict[str, Any]], - runtime_config: Any | None, -) -> list[dict[str, Any]]: - if agent_name != "web_agent" or runtime_config is None: - return tools - - if bool(getattr(runtime_config, "grok_search_enabled", False)): - return tools - - filtered: list[dict[str, Any]] = [] - for tool in tools: - function = tool.get("function") if isinstance(tool, dict) else None - name = function.get("name") if isinstance(function, dict) else None - if name == "grok_search": - continue - filtered.append(tool) - return filtered - - -async def run_agent_with_tools( - *, - agent_name: str, - user_content: str, - context_messages: list[dict[str, str]] | None = None, - empty_user_content_message: str, - default_prompt: str, - context: dict[str, Any], - agent_dir: Path, - logger: logging.Logger, - max_iterations: int = 20, - tool_error_prefix: str = "错误", -) -> str: - """执行通用 Agent 循环。 - - 该方法统一处理: - - prompt 加载 - - LLM 迭代决策 - - tool call 并发执行 - - tool 结果回填 messages - """ - - if not user_content.strip(): - return empty_user_content_message - - tool_registry = AgentToolRegistry( - agent_dir / "tools", - current_agent_name=agent_name, - is_main_agent=False, - ) - tools = tool_registry.get_tools_schema() - runtime_config = context.get("runtime_config") - tools = _filter_tools_for_runtime_config(agent_name, tools, runtime_config) - - # 发现并加载 agent 私有 Anthropic Skills(可选) - agent_skills_dir = agent_dir / "anthropic_skills" - agent_skill_registry: AnthropicSkillRegistry | None = None - if agent_skills_dir.exists() and agent_skills_dir.is_dir(): - agent_skill_registry = AnthropicSkillRegistry(agent_skills_dir) - if agent_skill_registry.has_skills(): - # 将 anthropic skill tools 加入 agent 的可用工具列表 - tools = tools + agent_skill_registry.get_tools_schema() - logger.info( - "[Agent:%s] 加载了 %d 个私有 Anthropic Skills", - agent_name, - len(agent_skill_registry.get_all_skills()), - ) - - ai_client = context.get("ai_client") - if not ai_client: - return "AI client 未在上下文中提供" - - model_config_override = context.get("model_config_override") - if isinstance(model_config_override, AgentModelConfig): - agent_config = model_config_override - else: - agent_config = ai_client.agent_config - # 动态选择 agent 模型 - group_id = context.get("group_id", 0) or 0 - user_id = context.get("user_id", 0) or 0 - global_enabled = runtime_config.model_pool_enabled if runtime_config else False - agent_config = ai_client.model_selector.select_agent_config( - agent_config, - group_id=group_id, - user_id=user_id, - global_enabled=global_enabled, - ) - system_prompt = await load_prompt_text(agent_dir, default_prompt) - - # 注入 agent 私有 Anthropic Skills 元数据到 system prompt - if agent_skill_registry and agent_skill_registry.has_skills(): - skills_xml = agent_skill_registry.build_metadata_xml() - if skills_xml: - system_prompt = ( - f"{system_prompt}\n\n" - f"【可用的 Anthropic Skills】\n" - f"{skills_xml}\n\n" - f"注意:以上是你可用的 Anthropic Agent Skills。" - f"当任务与某个 skill 相关时," - f"可以调用对应的 skill tool(tool_name 字段)" - f"来获取该领域的详细指令和知识。" - ) - - agent_history = context.get("agent_history", []) - - messages: list[dict[str, Any]] = [{"role": "system", "content": system_prompt}] - if agent_history: - messages.extend(agent_history) - if context_messages: - messages.extend(context_messages) - messages.append({"role": "user", "content": user_content}) - transport_state: dict[str, Any] | None = None - queue_lane = context.get("queue_lane") - max_pre_tool_retries = max( - 0, int(getattr(runtime_config, "ai_request_max_retries", 0) or 0) - ) - pre_tool_failure_count = 0 - - for iteration in range(1, max_iterations + 1): - logger.debug("[Agent:%s] iteration=%s", agent_name, iteration) - message_checkpoint_len = len(messages) - transport_state_checkpoint = transport_state - try: - result = await ai_client.submit_queued_llm_call( - model_config=agent_config, - messages=messages, - max_tokens=agent_config.max_tokens, - call_type=f"agent:{agent_name}", - tools=tools if tools else None, - tool_choice="auto", - transport_state=transport_state, - queue_lane=queue_lane, - ) - except Exception as exc: - logger.exception( - "[Agent:%s] queued LLM 调用失败: lane=%s iteration=%s error=%s", - agent_name, - queue_lane, - iteration, - exc, - ) - raise RuntimeError("智能体模型请求失败") from exc - - try: - tool_execution_started = False - tool_name_map = ( - result.get("_tool_name_map") if isinstance(result, dict) else None - ) - api_to_internal: dict[str, str] = {} - if isinstance(tool_name_map, dict): - raw_api_to_internal = tool_name_map.get("api_to_internal") - if isinstance(raw_api_to_internal, dict): - api_to_internal = { - str(key): str(value) - for key, value in raw_api_to_internal.items() - } - - next_transport_state = ( - result.get("_transport_state") if isinstance(result, dict) else None - ) - transport_state = ( - next_transport_state if isinstance(next_transport_state, dict) else None - ) - - choice: dict[str, Any] = result.get("choices", [{}])[0] - message: dict[str, Any] = choice.get("message", {}) - content: str = message.get("content") or "" - reasoning_content: str | None = message.get("reasoning_content") - tool_calls: list[dict[str, Any]] = message.get("tool_calls", []) - - if content.strip() and tool_calls: - content = "" - - if not tool_calls: - return content - - assistant_message: dict[str, Any] = { - "role": "assistant", - "content": content, - "tool_calls": tool_calls, - } - output_items = message.get(RESPONSES_OUTPUT_ITEMS_KEY) - if isinstance(output_items, list): - assistant_message[RESPONSES_OUTPUT_ITEMS_KEY] = output_items - capture_reasoning = bool( - getattr(agent_config, "thinking_tool_call_compat", False) - ) or bool(getattr(agent_config, "reasoning_content_replay", False)) - if capture_reasoning and reasoning_content is not None: - assistant_message["reasoning_content"] = reasoning_content - messages.append(assistant_message) - - tool_tasks: list[asyncio.Future[Any]] = [] - tool_call_ids: list[str] = [] - tool_api_names: list[str] = [] - end_tool_call: dict[str, Any] | None = None - end_tool_args: dict[str, Any] = {} - results: list[Any] = [] - - for tool_call in tool_calls: - call_id = str(tool_call.get("id", "")) - function: dict[str, Any] = tool_call.get("function", {}) - api_function_name = str(function.get("name", "")) - raw_args = function.get("arguments") - - internal_function_name = api_to_internal.get( - api_function_name, api_function_name - ) - logger.info( - "[Agent:%s] preparing tool=%s", - agent_name, - internal_function_name, - ) - - function_args = parse_tool_arguments( - raw_args, - logger=logger, - tool_name=api_function_name, - ) - - if not isinstance(function_args, dict): - function_args = {} - - # 检测 end 工具,暂存后统一处理 - if internal_function_name == "end": - if len(tool_calls) > 1: - logger.warning( - "[Agent:%s] end 与其他工具同时调用," - "将先执行其他工具,end 将返回拒绝结果", - agent_name, - ) - end_tool_call = tool_call - end_tool_args = function_args - continue - - tool_call_ids.append(call_id) - tool_api_names.append(api_function_name) - - # Anthropic Skill tool 路由 - # 工具名格式: skills,如 skills-_-pdf-processing - skill_delimiter = ( - agent_skill_registry.dot_delimiter - if agent_skill_registry - else "-_-" - ) - is_agent_skill = internal_function_name.startswith( - f"skills{skill_delimiter}" - ) - if is_agent_skill and agent_skill_registry: - tool_tasks.append( - asyncio.ensure_future( - agent_skill_registry.execute_skill_tool( - internal_function_name, - function_args, - context, - ) - ) - ) - else: - tool_tasks.append( - asyncio.ensure_future( - tool_registry.execute_tool( - internal_function_name, - function_args, - context, - ) - ) - ) - - if tool_tasks: - tool_execution_started = True - logger.info( - "[Agent:%s] executing tools in parallel: count=%s", - agent_name, - len(tool_tasks), - ) - results = await asyncio.gather(*tool_tasks, return_exceptions=True) - - for index, tool_result in enumerate(results): - call_id = tool_call_ids[index] - api_tool_name = tool_api_names[index] - if isinstance(tool_result, Exception): - content_str = f"{tool_error_prefix}: {tool_result}" - else: - content_str = str(tool_result) - - messages.append( - { - "role": "tool", - "tool_call_id": call_id, - "name": api_tool_name, - "content": content_str, - } - ) - - # 处理 end 工具调用 - if end_tool_call: - end_call_id = str(end_tool_call.get("id", "")) - end_api_name = end_tool_call.get("function", {}).get("name", "end") - if tool_tasks: - messages.append( - { - "role": "tool", - "tool_call_id": end_call_id, - "name": end_api_name, - "content": END_CO_CALL_REJECT_CONTENT, - } - ) - logger.info( - "[Agent:%s] end 与其他工具同时调用," - "其它工具已执行,end 已回填拒绝响应", - agent_name, - ) - else: - # end 单独调用,正常执行(参数已在循环中解析) - tool_execution_started = True - end_result = await tool_registry.execute_tool( - "end", end_tool_args, context - ) - messages.append( - { - "role": "tool", - "tool_call_id": end_call_id, - "name": end_api_name, - "content": str(end_result), - } - ) - pre_tool_failure_count = 0 - - except Exception as exc: - if ( - not tool_execution_started - and pre_tool_failure_count < max_pre_tool_retries - ): - pre_tool_failure_count += 1 - del messages[message_checkpoint_len:] - transport_state = transport_state_checkpoint - logger.warning( - "[Agent:%s] pre-tool 本地失败,重试当前 LLM 轮次: lane=%s retry=%s/%s iteration=%s error=%s", - agent_name, - queue_lane, - pre_tool_failure_count, - max_pre_tool_retries, - iteration, - exc, - ) - continue - logger.exception( - "[Agent:%s] 执行失败,已静默抑制: lane=%s iteration=%s error=%s", - agent_name, - queue_lane, - iteration, - exc, - ) - return "" - - return "达到最大迭代次数" diff --git a/tests/test_ai_client_setup_paths.py b/tests/test_ai_client_setup_paths.py index 2f1381e8..baac7c81 100644 --- a/tests/test_ai_client_setup_paths.py +++ b/tests/test_ai_client_setup_paths.py @@ -10,12 +10,26 @@ from Undefined.skills.tools import ToolRegistry from Undefined.utils.paths import PACKAGE_ROOT -# Snapshot counts from skills/*/config.json inventory (excluding MCP). -EXPECTED_BASIC_TOOL_COUNT = 15 -EXPECTED_TOOLSET_COUNT = 53 -EXPECTED_AGENT_COUNT = 8 -EXPECTED_COMMAND_COUNT = 12 -EXPECTED_PIPELINE_COUNT = 3 + +def _skill_dirs(base: Path) -> set[str]: + if not base.is_dir(): + return set() + return { + item.name + for item in base.iterdir() + if item.is_dir() and (item / "config.json").exists() + } + + +def _toolset_tool_names(base: Path) -> set[str]: + if not base.is_dir(): + return set() + names: set[str] = set() + for config_path in base.rglob("config.json"): + rel = config_path.parent.relative_to(base) + if len(rel.parts) == 2: + names.add(".".join(rel.parts)) + return names def test_package_root_matches_undefined_package_directory() -> None: @@ -38,23 +52,22 @@ def test_setup_wrong_path_does_not_exist() -> None: def test_tool_registry_loads_all_skill_directories() -> None: - registry = ToolRegistry(PACKAGE_ROOT / "skills" / "tools") + tools_dir = PACKAGE_ROOT / "skills" / "tools" + toolsets_dir = PACKAGE_ROOT / "skills" / "toolsets" + registry = ToolRegistry(tools_dir) basic = [name for name in registry._items if "." not in name] toolsets = [ name for name in registry._items if "." in name and not name.startswith("mcp.") ] - assert len(basic) == EXPECTED_BASIC_TOOL_COUNT - assert len(toolsets) == EXPECTED_TOOLSET_COUNT - assert len(registry._items) == EXPECTED_BASIC_TOOL_COUNT + EXPECTED_TOOLSET_COUNT + basic_dirs = _skill_dirs(tools_dir) + toolset_names = _toolset_tool_names(toolsets_dir) - tool_dirs = [ - item.name - for item in (PACKAGE_ROOT / "skills" / "tools").iterdir() - if item.is_dir() and (item / "config.json").exists() - ] - assert len(tool_dirs) == EXPECTED_BASIC_TOOL_COUNT - assert set(basic) == set(tool_dirs) + assert len(basic) == len(basic_dirs) + assert len(toolsets) == len(toolset_names) + assert len(registry._items) == len(basic) + len(toolsets) + assert set(basic) == basic_dirs + assert set(toolsets) == toolset_names def test_all_registered_tools_import_handlers() -> None: @@ -73,14 +86,23 @@ def test_all_registered_tools_import_handlers() -> None: def test_agent_registry_loads_expected_agents() -> None: - registry = AgentRegistry(PACKAGE_ROOT / "skills" / "agents") - assert len(registry._items) == EXPECTED_AGENT_COUNT + agents_dir = PACKAGE_ROOT / "skills" / "agents" + registry = AgentRegistry(agents_dir) + assert len(registry._items) == len(_skill_dirs(agents_dir)) + assert set(registry._items) == _skill_dirs(agents_dir) def test_command_registry_loads_expected_commands() -> None: - registry = CommandRegistry(PACKAGE_ROOT / "skills" / "commands") + commands_dir = PACKAGE_ROOT / "skills" / "commands" + registry = CommandRegistry(commands_dir) registry.load_commands() - assert len(registry._commands) == EXPECTED_COMMAND_COUNT + command_dirs = { + item.name + for item in commands_dir.iterdir() + if item.is_dir() and (item / "handler.py").exists() + } + assert len(registry._commands) == len(command_dirs) + assert set(registry._commands) == command_dirs def test_pipeline_registry_loads_expected_pipelines() -> None: @@ -90,5 +112,4 @@ async def _load() -> PipelineRegistry: return registry registry = asyncio.run(_load()) - assert len(registry._items) == EXPECTED_PIPELINE_COUNT assert set(registry._items) == {"arxiv", "bilibili", "github"} diff --git a/tests/test_config_from_mapping.py b/tests/test_config_from_mapping.py index 70e0dbd0..2e28f806 100644 --- a/tests/test_config_from_mapping.py +++ b/tests/test_config_from_mapping.py @@ -44,9 +44,11 @@ def test_set_config_injects_singleton(monkeypatch: pytest.MonkeyPatch) -> None: import Undefined.config as config_pkg monkeypatch.setattr(config_pkg, "_config", None) + monkeypatch.setattr(config_pkg, "_config_manager", None) cfg = Config.from_mapping(_MINIMAL_MAPPING, strict=False) set_config(cfg) assert config_pkg.get_config(strict=False) is cfg + assert config_pkg.get_config_manager().load(strict=False) is cfg def test_from_mapping_matches_load(tmp_path: Path) -> None: diff --git a/tests/test_handlers_meme_annotation.py b/tests/test_handlers_meme_annotation.py index a8ff4d62..1b35b322 100644 --- a/tests/test_handlers_meme_annotation.py +++ b/tests/test_handlers_meme_annotation.py @@ -1,4 +1,4 @@ -"""测试 handlers.py 中的表情包自动匹配功能""" +"""测试 handlers/message_flow 中的表情包自动匹配功能""" from __future__ import annotations diff --git a/tests/test_package_layout.py b/tests/test_package_layout.py new file mode 100644 index 00000000..66833a0f --- /dev/null +++ b/tests/test_package_layout.py @@ -0,0 +1,32 @@ +"""打包布局与模块结构回归测试。""" + +from __future__ import annotations + +from pathlib import Path + +import Undefined + + +def test_py_typed_marker_exists() -> None: + pkg_root = Path(Undefined.__file__).resolve().parent + marker = pkg_root / "py.typed" + assert marker.is_file(), "src/Undefined/py.typed must exist for PEP 561" + + +def test_py_typed_declared_in_pyproject() -> None: + repo_root = Path(__file__).resolve().parents[1] + pyproject = (repo_root / "pyproject.toml").read_text(encoding="utf-8") + assert 'src/Undefined/py.typed" = "Undefined/py.typed"' in pyproject + + +def test_no_shadowed_monolith_modules() -> None: + """禁止 foo.py 与 foo/ 包目录并存(会导致一份实现成为不可达死代码)。""" + pkg_root = Path(Undefined.__file__).resolve().parent + violations: list[str] = [] + for path in pkg_root.rglob("*.py"): + if path.name == "__init__.py": + continue + package_dir = path.with_suffix("") + if package_dir.is_dir() and (package_dir / "__init__.py").is_file(): + violations.append(str(path.relative_to(pkg_root))) + assert violations == [], f"shadowed monolith modules: {violations}" diff --git a/tests/test_public_api_imports.py b/tests/test_public_api_imports.py index 520a3cbf..c1b97d99 100644 --- a/tests/test_public_api_imports.py +++ b/tests/test_public_api_imports.py @@ -147,3 +147,4 @@ def test_set_config_not_used_by_default_get_config( ) set_config(injected) assert config_module.get_config(strict=False) is injected + assert config_module.get_config_manager().load(strict=False) is injected From e73ad74a2632785512f300e7b76010c86f32e461 Mon Sep 17 00:00:00 2001 From: Null <1708213363@qq.com> Date: Sun, 24 May 2026 09:34:22 +0800 Subject: [PATCH 11/16] chore(coderabbit): enable auto review on all base branches Configure base_branches to .* so CodeRabbit reviews PRs targeting any branch. --- .coderabbit.yaml | 8 ++++++++ 1 file changed, 8 insertions(+) create mode 100644 .coderabbit.yaml diff --git a/.coderabbit.yaml b/.coderabbit.yaml new file mode 100644 index 00000000..b220aba8 --- /dev/null +++ b/.coderabbit.yaml @@ -0,0 +1,8 @@ +# yaml-language-server: $schema=https://coderabbit.ai/integrations/schema.v2.json + +reviews: + auto_review: + enabled: true + auto_incremental_review: true + base_branches: + - ".*" From 531b6a6644e0bb388d6004d379f4dbb9e224d193 Mon Sep 17 00:00:00 2001 From: Null <1708213363@qq.com> Date: Sun, 24 May 2026 09:51:08 +0800 Subject: [PATCH 12/16] fix: address CodeRabbit review findings across runtime modules Harden queued LLM retries, pending-call cleanup, tool filtering, config bounds, attachment rendering, and privacy-safe prompt cache keys. Co-authored-by: Cursor --- src/Undefined/ai/client/ask_loop.py | 13 +---- src/Undefined/ai/client/queue.py | 56 +++++++++---------- src/Undefined/ai/client/setup.py | 38 ++++++++++++- src/Undefined/ai/llm/requester.py | 6 +- src/Undefined/ai/llm/streaming.py | 2 - src/Undefined/ai/multimodal/detection.py | 9 ++- src/Undefined/ai/multimodal/parsing.py | 7 ++- src/Undefined/ai/prompts/system_context.py | 2 +- src/Undefined/api/routes/naga/bind.py | 2 + src/Undefined/api/routes/naga/unbind.py | 2 + src/Undefined/attachments/render.py | 11 +++- .../config/load_sections/finalize.py | 18 +++--- .../config/load_sections/history_skills.py | 2 + .../config/load_sections/integrations.py | 8 +++ .../config/load_sections/knowledge.py | 2 + .../config/load_sections/logging_tools.py | 4 ++ tests/test_llm_request_params.py | 5 +- tests/test_llm_retry_suppression.py | 14 +++-- 18 files changed, 131 insertions(+), 70 deletions(-) diff --git a/src/Undefined/ai/client/ask_loop.py b/src/Undefined/ai/client/ask_loop.py index 75a074d9..bd75f01b 100644 --- a/src/Undefined/ai/client/ask_loop.py +++ b/src/Undefined/ai/client/ask_loop.py @@ -221,6 +221,7 @@ async def fetch_session_messages_callback( message_checkpoint_len = len(messages) transport_state_checkpoint = transport_state + tool_execution_started = False try: result = await self.submit_queued_llm_call( model_config=effective_chat_config, @@ -232,18 +233,7 @@ async def fetch_session_messages_callback( transport_state=transport_state, queue_lane=queue_lane, ) - except Exception as exc: - logger.exception( - "[queued_llm_error] call_type=chat model=%s lane=%s iteration=%s error=%s", - effective_chat_config.model_name, - queue_lane, - iteration, - exc, - ) - raise - try: - tool_execution_started = False tool_name_map = ( result.get("_tool_name_map") if isinstance(result, dict) else None ) @@ -609,7 +599,6 @@ async def fetch_session_messages_callback( exc, ) continue - # 工具已执行或重试用尽:吞掉异常,避免向用户暴露内部错误 logger.exception( "[chat.suppressed_error] model=%s lane=%s iteration=%s error=%s", effective_chat_config.model_name, diff --git a/src/Undefined/ai/client/queue.py b/src/Undefined/ai/client/queue.py index 39bb1df7..9a683a9d 100644 --- a/src/Undefined/ai/client/queue.py +++ b/src/Undefined/ai/client/queue.py @@ -134,36 +134,36 @@ async def submit_queued_llm_call( len(messages), bool(tools), ) - receipt = await self._queue_manager.add_queued_llm_request( - request, - lane=resolved_queue_lane, - model_name=model_name, - ) - wait_timeout = compute_queued_llm_timeout_seconds( - self._get_runtime_config(), - model_config, - retry_count=resolve_effective_retry_count( - self._get_runtime_config(), self._queue_manager - ), - initial_wait_seconds=float( - getattr(receipt, "estimated_wait_seconds", 0.0) or 0.0 - ), - # 首次 dispatch 间隔已含在 estimated_wait 中,避免重复计入 - include_first_dispatch_interval=False, - ) try: - await asyncio.wait_for(event.wait(), timeout=wait_timeout) - except asyncio.TimeoutError: - logger.exception( - "[queued_llm_wait_timeout] request_id=%s call_type=%s model=%s lane=%s timeout=%.1fs", - request_id, - call_type, - model_name, - resolved_queue_lane, - wait_timeout, + receipt = await self._queue_manager.add_queued_llm_request( + request, + lane=resolved_queue_lane, + model_name=model_name, + ) + wait_timeout = compute_queued_llm_timeout_seconds( + self._get_runtime_config(), + model_config, + retry_count=resolve_effective_retry_count( + self._get_runtime_config(), self._queue_manager + ), + initial_wait_seconds=float( + getattr(receipt, "estimated_wait_seconds", 0.0) or 0.0 + ), + # 首次 dispatch 间隔已含在 estimated_wait 中,避免重复计入 + include_first_dispatch_interval=False, ) - raise - # finally:无论成败都执行清理 + try: + await asyncio.wait_for(event.wait(), timeout=wait_timeout) + except asyncio.TimeoutError: + logger.exception( + "[queued_llm_wait_timeout] request_id=%s call_type=%s model=%s lane=%s timeout=%.1fs", + request_id, + call_type, + model_name, + resolved_queue_lane, + wait_timeout, + ) + raise finally: entry = self._pending_llm_calls.pop(request_id, None) _, result = entry if entry is not None else (None, None) diff --git a/src/Undefined/ai/client/setup.py b/src/Undefined/ai/client/setup.py index 9a1df7a1..f89a2abd 100644 --- a/src/Undefined/ai/client/setup.py +++ b/src/Undefined/ai/client/setup.py @@ -249,6 +249,7 @@ def __init__( # Agent intro 生成器(延迟初始化,需要外部设置 queue_manager) self._agent_intro_generator: Any | None = None self._agent_intro_task: asyncio.Task[None] | None = None + self._intro_refresh_task: asyncio.Task[None] | None = None self._queue_manager: Any | None = None self._intro_config: Any | None = None # 后台 LLM 调用挂起表(走队列的后台请求) @@ -350,6 +351,14 @@ async def close(self) -> None: intro_gen = getattr(self, "_agent_intro_generator", None) if intro_gen is not None: await intro_gen.stop() + intro_refresh_task = getattr(self, "_intro_refresh_task", None) + if intro_refresh_task is not None and not intro_refresh_task.done(): + intro_refresh_task.cancel() + try: + await intro_refresh_task + except asyncio.CancelledError: + pass + self._intro_refresh_task = None if hasattr(self, "_agent_intro_task") and self._agent_intro_task: if not self._agent_intro_task.done(): await self._agent_intro_task @@ -420,8 +429,31 @@ def apply_intro_config(self, config: AgentIntroGenConfig) -> None: self._intro_config = config if self._queue_manager is None: return - task = asyncio.create_task(self._refresh_intro_generator(config)) - task.add_done_callback(lambda t: t.exception() if not t.cancelled() else None) + existing = self._intro_refresh_task + if existing is not None and not existing.done(): + existing.cancel() + + async def _run_refresh() -> None: + try: + await self._refresh_intro_generator(config) + except asyncio.CancelledError: + raise + except Exception: + logger.exception("[Agent介绍] 刷新 intro 生成器失败") + + task = asyncio.create_task(_run_refresh()) + + def _finalize(done_task: asyncio.Task[None]) -> None: + if getattr(self, "_intro_refresh_task", None) is done_task: + self._intro_refresh_task = None + if done_task.cancelled(): + return + exc = done_task.exception() + if exc is not None: + logger.error("[Agent介绍] intro 刷新任务异常结束", exc_info=exc) + + task.add_done_callback(_finalize) + self._intro_refresh_task = task async def _refresh_intro_generator(self, config: AgentIntroGenConfig) -> None: if not config.enabled: @@ -742,6 +774,8 @@ async def request_model( **kwargs: Any, ) -> dict[str, Any]: tools = self.tool_manager.maybe_merge_agent_tools(call_type, tools) + if tools is not None: + tools = self._filter_tools_for_runtime_config(tools) message_count_for_transport = len(messages) # Responses 续轮(previous_response_id)时跳过 prefetch,避免重复注入系统消息 if not ( diff --git a/src/Undefined/ai/llm/requester.py b/src/Undefined/ai/llm/requester.py index 2e7e667f..10329d0e 100644 --- a/src/Undefined/ai/llm/requester.py +++ b/src/Undefined/ai/llm/requester.py @@ -161,11 +161,11 @@ def _build_scope_prompt_cache_part() -> str: if ctx is None: return "scope:global" if ctx.group_id is not None: - return f"group:{int(ctx.group_id)}" + return f"group:{_hash8(str(int(ctx.group_id)))}" if ctx.user_id is not None: - return f"private:{int(ctx.user_id)}" + return f"private:{_hash8(str(int(ctx.user_id)))}" if ctx.sender_id is not None: - return f"sender:{int(ctx.sender_id)}" + return f"sender:{_hash8(str(int(ctx.sender_id)))}" request_type = _normalize_prompt_cache_part(ctx.request_type) return f"type:{request_type}" diff --git a/src/Undefined/ai/llm/streaming.py b/src/Undefined/ai/llm/streaming.py index 50632476..cc12113e 100644 --- a/src/Undefined/ai/llm/streaming.py +++ b/src/Undefined/ai/llm/streaming.py @@ -358,8 +358,6 @@ def aggregate_responses_stream(events: list[dict[str, Any]]) -> dict[str, Any]: usage = extract_stream_usage(event_dict, api_mode=API_MODE_RESPONSES) or usage event_type = str(event_dict.get("type") or "").strip().lower() response = event_dict.get("response") - if isinstance(response, dict): - final_response = response if event_type == "response.output_text.delta": delta = stringify_stream_delta(event_dict.get("delta")) if delta: diff --git a/src/Undefined/ai/multimodal/detection.py b/src/Undefined/ai/multimodal/detection.py index 666bb9b3..9508d5ce 100644 --- a/src/Undefined/ai/multimodal/detection.py +++ b/src/Undefined/ai/multimodal/detection.py @@ -30,14 +30,17 @@ def _extract_mime_type_from_data_url(media_url: str) -> str | None: def _get_media_type_by_extension(url_lower: str) -> str: """根据文件扩展名判断媒体类型。""" + from urllib.parse import urlsplit + + path = urlsplit(url_lower).path for ext in IMAGE_EXTENSIONS: - if ext in url_lower: + if path.endswith(ext): return "image" for ext in AUDIO_EXTENSIONS: - if ext in url_lower: + if path.endswith(ext): return "audio" for ext in VIDEO_EXTENSIONS: - if ext in url_lower: + if path.endswith(ext): return "video" return "image" diff --git a/src/Undefined/ai/multimodal/parsing.py b/src/Undefined/ai/multimodal/parsing.py index b5338078..f8912456 100644 --- a/src/Undefined/ai/multimodal/parsing.py +++ b/src/Undefined/ai/multimodal/parsing.py @@ -10,7 +10,7 @@ def _parse_line_value(line: str, prefix: str) -> str: """解析行内容,提取指定前缀后的值。""" - value = line.split(":", 1)[-1].split(":", 1)[-1].strip() + value = line[len(prefix) :].strip() if line.startswith(prefix) else line.strip() return "" if value == "无" else value @@ -34,7 +34,10 @@ def _parse_analysis_response(content: str) -> dict[str, str]: line = line.strip() for field, prefixes in field_prefixes.items(): if line.startswith(prefixes): - result[field] = _parse_line_value(line, prefixes[0]) + matched_prefix = ( + prefixes[0] if line.startswith(prefixes[0]) else prefixes[1] + ) + result[field] = _parse_line_value(line, matched_prefix) if not result["description"]: result["description"] = content diff --git a/src/Undefined/ai/prompts/system_context.py b/src/Undefined/ai/prompts/system_context.py index e1ac02e9..7dbf2840 100644 --- a/src/Undefined/ai/prompts/system_context.py +++ b/src/Undefined/ai/prompts/system_context.py @@ -24,7 +24,7 @@ def select_system_prompt_path( # NagaAgent 模式切换专用系统提示词模板 if enabled: return "res/prompts/undefined_nagaagent.xml" - return "res/prompts/undefined.xml" + return default_path def build_model_config_info(runtime_config: Any) -> str: diff --git a/src/Undefined/api/routes/naga/bind.py b/src/Undefined/api/routes/naga/bind.py index fe2e2ca9..f2f3b9a2 100644 --- a/src/Undefined/api/routes/naga/bind.py +++ b/src/Undefined/api/routes/naga/bind.py @@ -41,6 +41,8 @@ async def naga_bind_callback_handler( body = await request.json() except Exception: return _json_error("Invalid JSON", status=400) + if not isinstance(body, dict): + return _json_error("JSON body must be an object", status=400) bind_uuid = str(body.get("bind_uuid", "") or "").strip() naga_id = str(body.get("naga_id", "") or "").strip() diff --git a/src/Undefined/api/routes/naga/unbind.py b/src/Undefined/api/routes/naga/unbind.py index f29e3a91..d6b18d13 100644 --- a/src/Undefined/api/routes/naga/unbind.py +++ b/src/Undefined/api/routes/naga/unbind.py @@ -38,6 +38,8 @@ async def naga_unbind_handler(ctx: RuntimeAPIContext, request: web.Request) -> R body = await request.json() except Exception: return _json_error("Invalid JSON", status=400) + if not isinstance(body, dict): + return _json_error("JSON body must be an object", status=400) bind_uuid = str(body.get("bind_uuid", "") or "").strip() naga_id = str(body.get("naga_id", "") or "").strip() diff --git a/src/Undefined/attachments/render.py b/src/Undefined/attachments/render.py index a6a3be44..6f3433c6 100644 --- a/src/Undefined/attachments/render.py +++ b/src/Undefined/attachments/render.py @@ -149,7 +149,7 @@ def _render_image_tag( ) -> bool: """Render an image attachment as an inline CQ:image. Returns True on success.""" image_source = record.source_ref - if record.local_path: + if record.local_path and Path(record.local_path).is_file(): image_source = Path(record.local_path).resolve().as_uri() elif not image_source: replacement = f"[图片 uid={uid} 缺少文件]" @@ -264,12 +264,19 @@ async def dispatch_pending_file_sends( send_record.local_path, name=send_record.display_name or None, ) - else: + elif target_type == "private": await sender.send_private_file( target_id, send_record.local_path, name=send_record.display_name or None, ) + else: + logger.warning( + "[文件发送] 跳过:不支持的 target_type=%s uid=%s", + target_type, + send_record.uid, + ) + continue except Exception: logger.warning( "[文件发送] 发送失败(最佳努力) uid=%s target=%s:%s", diff --git a/src/Undefined/config/load_sections/finalize.py b/src/Undefined/config/load_sections/finalize.py index 5eda3191..d6da9a07 100644 --- a/src/Undefined/config/load_sections/finalize.py +++ b/src/Undefined/config/load_sections/finalize.py @@ -23,12 +23,14 @@ def load_finalize(ctx: dict[str, Any], *, strict: bool = True) -> None: embedding_model=ctx["embedding_model"], ) - _log_debug_info( - ctx["chat_model"], - ctx["vision_model"], - ctx["security_model"], - ctx["naga_model"], - ctx["agent_model"], - ctx["summary_model"], - ctx["grok_model"], + debug_keys = ( + "chat_model", + "vision_model", + "security_model", + "naga_model", + "agent_model", + "summary_model", + "grok_model", ) + if all(key in ctx for key in debug_keys): + _log_debug_info(*(ctx[key] for key in debug_keys)) diff --git a/src/Undefined/config/load_sections/history_skills.py b/src/Undefined/config/load_sections/history_skills.py index 4c7d9788..615580e0 100644 --- a/src/Undefined/config/load_sections/history_skills.py +++ b/src/Undefined/config/load_sections/history_skills.py @@ -195,12 +195,14 @@ def load_history_skills( ), 2.0, ) + skills_hot_reload_interval = _normalize_queue_interval(skills_hot_reload_interval) skills_hot_reload_debounce = _coerce_float( _get_value( data, ("skills", "hot_reload_debounce"), "SKILLS_HOT_RELOAD_DEBOUNCE" ), 0.5, ) + skills_hot_reload_debounce = _normalize_queue_interval(skills_hot_reload_debounce) agent_intro_autogen_enabled = _coerce_bool( _get_value( diff --git a/src/Undefined/config/load_sections/integrations.py b/src/Undefined/config/load_sections/integrations.py index bd98758c..ebbafdf7 100644 --- a/src/Undefined/config/load_sections/integrations.py +++ b/src/Undefined/config/load_sections/integrations.py @@ -153,10 +153,14 @@ def load_integrations( _get_value(data, ("code_delivery", "default_command_timeout_seconds"), None), 600, ) + if code_delivery_command_timeout < 1: + code_delivery_command_timeout = 600 code_delivery_max_command_output = _coerce_int( _get_value(data, ("code_delivery", "max_command_output_chars"), None), 20000, ) + if code_delivery_max_command_output < 1: + code_delivery_max_command_output = 20000 code_delivery_default_archive_format = _coerce_str( _get_value(data, ("code_delivery", "default_archive_format"), None), "zip", @@ -166,6 +170,8 @@ def load_integrations( code_delivery_max_archive_size_mb = _coerce_int( _get_value(data, ("code_delivery", "max_archive_size_mb"), None), 200 ) + if code_delivery_max_archive_size_mb < 1: + code_delivery_max_archive_size_mb = 200 code_delivery_cleanup_on_finish = _coerce_bool( _get_value(data, ("code_delivery", "cleanup_on_finish"), None), True ) @@ -176,6 +182,8 @@ def load_integrations( _get_value(data, ("code_delivery", "llm_max_retries_per_request"), None), 5, ) + if code_delivery_llm_max_retries < 0: + code_delivery_llm_max_retries = 5 code_delivery_notify_on_llm_failure = _coerce_bool( _get_value(data, ("code_delivery", "notify_on_llm_failure"), None), True, diff --git a/src/Undefined/config/load_sections/knowledge.py b/src/Undefined/config/load_sections/knowledge.py index 9e5406bf..266976f2 100644 --- a/src/Undefined/config/load_sections/knowledge.py +++ b/src/Undefined/config/load_sections/knowledge.py @@ -62,6 +62,8 @@ def load_knowledge( ) if knowledge_chunk_overlap < 0: knowledge_chunk_overlap = 0 + if knowledge_chunk_overlap >= knowledge_chunk_size: + knowledge_chunk_overlap = max(0, knowledge_chunk_size - 1) knowledge_default_top_k = _coerce_int( _get_value(data, ("knowledge", "default_top_k"), None), 5 ) diff --git a/src/Undefined/config/load_sections/logging_tools.py b/src/Undefined/config/load_sections/logging_tools.py index 65a3347a..f5d6c06e 100644 --- a/src/Undefined/config/load_sections/logging_tools.py +++ b/src/Undefined/config/load_sections/logging_tools.py @@ -37,9 +37,13 @@ def load_logging_tools( log_max_size_mb = _coerce_int( _get_value(data, ("logging", "max_size_mb"), "LOG_MAX_SIZE_MB"), 10 ) + if log_max_size_mb <= 0: + log_max_size_mb = 10 log_backup_count = _coerce_int( _get_value(data, ("logging", "backup_count"), "LOG_BACKUP_COUNT"), 5 ) + if log_backup_count < 0: + log_backup_count = 0 log_tty_enabled = _coerce_bool( _get_value(data, ("logging", "tty_enabled"), "LOG_TTY_ENABLED"), False, diff --git a/tests/test_llm_request_params.py b/tests/test_llm_request_params.py index e34227b6..a917e6fa 100644 --- a/tests/test_llm_request_params.py +++ b/tests/test_llm_request_params.py @@ -1,5 +1,6 @@ from __future__ import annotations +import hashlib from types import SimpleNamespace from typing import Any, cast from unittest.mock import AsyncMock @@ -339,10 +340,11 @@ async def test_chat_request_auto_sets_prompt_cache_key_from_request_context() -> call_type="chat", ) + group_scope = hashlib.sha1(b"12345", usedforsecurity=False).hexdigest()[:8] assert fake_client.chat.completions.last_kwargs is not None assert ( fake_client.chat.completions.last_kwargs["prompt_cache_key"] - == "pc:gpt-test:chat:group:12345" + == f"pc:gpt-test:chat:group:{group_scope}" ) await requester._http_client.aclose() @@ -1130,6 +1132,7 @@ async def test_ai_client_request_model_prefetch_keeps_transport_count_from_calle "tool_manager", SimpleNamespace(maybe_merge_agent_tools=lambda _call_type, tools: tools), ) + setattr(client, "_filter_tools_for_runtime_config", lambda tools: tools) monkeypatch.setattr(client, "_maybe_prefetch_tools", prefetch_mock) cfg = ChatModelConfig( diff --git a/tests/test_llm_retry_suppression.py b/tests/test_llm_retry_suppression.py index d47905f9..c58a1746 100644 --- a/tests/test_llm_retry_suppression.py +++ b/tests/test_llm_retry_suppression.py @@ -21,8 +21,11 @@ @pytest.mark.asyncio -async def test_ai_ask_reraises_queued_llm_error() -> None: +async def test_ai_ask_suppresses_queued_llm_error_when_retries_exhausted() -> None: client: Any = object.__new__(AIClient) + client.runtime_config = cast( + Any, SimpleNamespace(log_thinking=False, ai_request_max_retries=0) + ) client._prompt_builder = cast( Any, SimpleNamespace( @@ -34,9 +37,7 @@ async def test_ai_ask_reraises_queued_llm_error() -> None: ) client.tool_manager = cast(Any, SimpleNamespace(get_openai_tools=lambda: [])) client._filter_tools_for_runtime_config = lambda tools: tools - client._get_runtime_config = cast( - Any, lambda: cast(Any, SimpleNamespace(log_thinking=False)) - ) + client._get_runtime_config = cast(Any, lambda: client.runtime_config) client.model_selector = cast(Any, SimpleNamespace(wait_ready=AsyncMock())) client.chat_config = ChatModelConfig( api_url="https://api.openai.com/v1", @@ -60,8 +61,9 @@ async def test_ai_ask_reraises_queued_llm_error() -> None: proxy_config_available=False, ) - with pytest.raises(RuntimeError, match="boom"): - await AIClient.ask(client, "hello") + result = await AIClient.ask(client, "hello") + + assert result == "" @pytest.mark.asyncio From 8f0c511ece512ba465a57d030624c04da5ec3d8b Mon Sep 17 00:00:00 2001 From: Null <1708213363@qq.com> Date: Sun, 24 May 2026 10:31:47 +0800 Subject: [PATCH 13/16] chore(version): bump version to 3.5.0 --- apps/undefined-console/package-lock.json | 4 ++-- apps/undefined-console/package.json | 2 +- apps/undefined-console/src-tauri/Cargo.lock | 2 +- apps/undefined-console/src-tauri/Cargo.toml | 2 +- apps/undefined-console/src-tauri/tauri.conf.json | 2 +- pyproject.toml | 2 +- src/Undefined/__init__.py | 2 +- uv.lock | 2 +- 8 files changed, 9 insertions(+), 9 deletions(-) diff --git a/apps/undefined-console/package-lock.json b/apps/undefined-console/package-lock.json index a6a0a5f5..df066c1d 100644 --- a/apps/undefined-console/package-lock.json +++ b/apps/undefined-console/package-lock.json @@ -1,12 +1,12 @@ { "name": "undefined-console", - "version": "3.4.2", + "version": "3.5.0", "lockfileVersion": 3, "requires": true, "packages": { "": { "name": "undefined-console", - "version": "3.4.2", + "version": "3.5.0", "dependencies": { "@tauri-apps/api": "^2.3.0", "@tauri-apps/plugin-http": "^2.3.0" diff --git a/apps/undefined-console/package.json b/apps/undefined-console/package.json index d5e75fd1..a24b0b6b 100644 --- a/apps/undefined-console/package.json +++ b/apps/undefined-console/package.json @@ -1,7 +1,7 @@ { "name": "undefined-console", "private": true, - "version": "3.4.2", + "version": "3.5.0", "type": "module", "scripts": { "tauri": "tauri", diff --git a/apps/undefined-console/src-tauri/Cargo.lock b/apps/undefined-console/src-tauri/Cargo.lock index dcb41530..ee1e96a1 100644 --- a/apps/undefined-console/src-tauri/Cargo.lock +++ b/apps/undefined-console/src-tauri/Cargo.lock @@ -4063,7 +4063,7 @@ checksum = "562d481066bde0658276a35467c4af00bdc6ee726305698a55b86e61d7ad82bb" [[package]] name = "undefined_console" -version = "3.4.2" +version = "3.5.0" dependencies = [ "serde", "serde_json", diff --git a/apps/undefined-console/src-tauri/Cargo.toml b/apps/undefined-console/src-tauri/Cargo.toml index 3b020568..7e0bf510 100644 --- a/apps/undefined-console/src-tauri/Cargo.toml +++ b/apps/undefined-console/src-tauri/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "undefined_console" -version = "3.4.2" +version = "3.5.0" description = "Undefined cross-platform management console" authors = ["Undefined contributors"] license = "MIT" diff --git a/apps/undefined-console/src-tauri/tauri.conf.json b/apps/undefined-console/src-tauri/tauri.conf.json index f4270754..d26c93ba 100644 --- a/apps/undefined-console/src-tauri/tauri.conf.json +++ b/apps/undefined-console/src-tauri/tauri.conf.json @@ -1,7 +1,7 @@ { "$schema": "https://schema.tauri.app/config/2", "productName": "Undefined Console", - "version": "3.4.2", + "version": "3.5.0", "identifier": "com.undefined.console", "build": { "beforeDevCommand": "npm run dev", diff --git a/pyproject.toml b/pyproject.toml index 3b6d44d5..d8569499 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "Undefined-bot" -version = "3.4.2" +version = "3.5.0" description = "QQ bot platform with cognitive memory architecture and multi-agent Skills, via OneBot V11." readme = "README.md" authors = [ diff --git a/src/Undefined/__init__.py b/src/Undefined/__init__.py index b3151b01..8f621efb 100644 --- a/src/Undefined/__init__.py +++ b/src/Undefined/__init__.py @@ -5,7 +5,7 @@ import importlib from typing import Any -__version__ = "3.4.2" +__version__ = "3.5.0" __all__ = [ "__version__", diff --git a/uv.lock b/uv.lock index 3459b092..dd09c6b9 100644 --- a/uv.lock +++ b/uv.lock @@ -4626,7 +4626,7 @@ wheels = [ [[package]] name = "undefined-bot" -version = "3.4.2" +version = "3.5.0" source = { editable = "." } dependencies = [ { name = "aiofiles" }, From b46009fc9b9521bd7e764e1f9af4c674aadcf22d Mon Sep 17 00:00:00 2001 From: Null <1708213363@qq.com> Date: Sun, 24 May 2026 10:36:04 +0800 Subject: [PATCH 14/16] docs(changelog): add v3.5.0 release summary Co-authored-by: Cursor --- CHANGELOG.md | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 90b8e3d4..f0cc5ac9 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,17 @@ +## v3.5.0 模块拆分、库嵌入与工具链收敛 + +本版本是一次以可维护性与可嵌入性为主线的架构整理:将 AI 客户端、消息处理、附件、认知、表情包、协调器、OneBot 与 Agent 运行器等原先体量过大的单文件拆成职责清晰的子包,同时保留向后兼容的 import 路径与 CLI 行为不变。并行修复了 `end` 与同轮业务工具并行调用时可能重复发送或误结束会话的问题,并补齐 Python 库嵌入所需的 `set_config`、`Config.from_mapping`、根包 lazy re-export 与 `py.typed` 类型标记。配置加载改为分段解析 + 域级 parser,文档与测试同步覆盖库嵌入、包布局与公共 API 契约,让 Undefined 既能继续作为 QQ Bot 运行,也能更可靠地被脚本、测试与其它服务按需复用。 + +- 重构核心运行时模块结构。`ai/client`、`ai/llm`、`ai/prompts`、`ai/multimodal` 分别承载客户端组合、模型请求、Prompt 构建与多模态解析;`handlers/`、`onebot/`、`attachments/`、`cognitive/service` + `cognitive/historian`、`memes/`(ingest / search / 图像工具)、`services/coordinator`、`services/message_batcher`、`skills/agents/runner` 与 `api/routes/naga/` 等子包按域拆分;原 monolith 文件以兼容 shim 或 `__init__` 重导出保留,删除被 shadow 的不可达死代码,降低单文件复杂度与后续改动风险。 +- 完善 Python 库嵌入能力。根包新增 lazy re-export(`Config`、`get_config`、`set_config`、`AIClient`、Skills 注册表、认知/知识库/表情包/附件/Runtime API 等稳定符号);`Config.from_mapping` / `ConfigBuilder` 支持无 `config.toml` 的内存构建,`env_registry` 统一管理环境变量兜底;`set_config()` 为 opt-in 注入全局单例,CLI 启动链不调用。wheel 打包 `py.typed`(PEP 561),新增 `docs/python-api.md` 公共 API 参考,README 补充嵌入示例与文档索引。 +- 收敛 `end` 工具与同轮并行调用的运行时语义。当模型在同一轮将 `end` 与 `send_message` 或其它业务工具一并调用时:其它工具照常并行执行并返回结果,`end` 不执行并回填明确拒绝响应,避免重复发送与「未读 tool 结果就结束」;提示词、`each.md` 与决策回归用例同步写入 P0 级「end 禁止并行」规则与运行时效果说明。 +- 改进缺失 tool call 时的重试策略。模型返回纯文本但未调用任何工具时,保留 assistant 原文于 messages,注入通用纠正提示而非硬编码 `send_message`/`end`,减少误导性后续 tool 调用;拆分后修复 Skills 路径解析,`PACKAGE_ROOT` 统一指向包根,避免内置工具零加载回归。 +- 加固队列化 LLM 与运行时边界。收紧 queued LLM 重试与 pending-call 清理、配置分段加载的边界校验、附件渲染容错,以及 Prompt 缓存键的隐私安全处理(系统上下文只暴露非敏感模型名等字段);表情包入库锁与图像工具去重,Naga API 路由守卫与 B 站 WBI 导航解析小幅简化。 +- 更新架构图、开发指南与配置文档。`docs/development.md` 反映拆分后的目录树;`docs/configuration.md` 新增库嵌入专节(`from_mapping` / `set_config` / 环境变量注册表);`ARCHITECTURE.md` 与相关运维文档同步引用路径。 +- 补强测试与工程契约。新增包布局、公共 API import、CLI 启动兼容、`Config.from_mapping` / 纯环境变量构建、`end` 同轮拒绝与 defer、`AIClient` setup 路径等回归测试;更新 LLM 重试抑制与请求参数相关用例,总测试覆盖库嵌入与拆分后的关键路径。 + +--- + ## v3.4.2 总结更准、上下文可配、高并发更稳 本版本主要解决三类实际问题:群聊消息变长、合并发送变多之后,`/summary` 容易慢、容易编、分块预算也不准;主对话注入历史的 200 条硬顶与模型真实窗口脱节;高并发下用户连问「在吗」「好了吗」时,机器人仍可能把旧任务当新活重跑。围绕这些痛点,版本把「用户主动要总结」和「AI 自己调总结能力」拆成两条更合适的链路,用可配置的上下文窗口统一约束注入与分块,并同步收紧提示词、史官侧写与定时任务持久化,让总结更可信、配置更贴近上游模型、并发场景下更少重复劳动。 From c85ad667b903642fc25a5f97e4d69642a640e2f1 Mon Sep 17 00:00:00 2001 From: Null <1708213363@qq.com> Date: Sun, 24 May 2026 10:49:17 +0800 Subject: [PATCH 15/16] fix(package): add TYPE_CHECKING stubs for lazy root exports Co-authored-by: Cursor --- src/Undefined/__init__.py | 17 ++++++++++++++++- 1 file changed, 16 insertions(+), 1 deletion(-) diff --git a/src/Undefined/__init__.py b/src/Undefined/__init__.py index 8f621efb..5cbc88de 100644 --- a/src/Undefined/__init__.py +++ b/src/Undefined/__init__.py @@ -3,7 +3,22 @@ from __future__ import annotations import importlib -from typing import Any +from typing import TYPE_CHECKING, Any + +if TYPE_CHECKING: + from .ai import AIClient + from .api._context import RuntimeAPIContext + from .api.app import RuntimeAPIServer + from .attachments import AttachmentRegistry + from .cognitive.service import CognitiveService + from .config import Config, get_config, set_config + from .knowledge.manager import KnowledgeManager + from .memes.service import MemeService + from .skills.agents import AgentRegistry + from .skills.anthropic_skills import AnthropicSkillRegistry + from .skills.pipelines.registry import PipelineRegistry + from .skills.registry import BaseRegistry + from .skills.tools import ToolRegistry __version__ = "3.5.0" From d3cc973f46e95e293ab5a90be7e09a1011cc64e8 Mon Sep 17 00:00:00 2001 From: Null <1708213363@qq.com> Date: Sun, 24 May 2026 10:55:17 +0800 Subject: [PATCH 16/16] refactor(package): derive __all__ from lazy export registry Co-authored-by: Cursor --- src/Undefined/__init__.py | 51 +++++++++++++++------------------------ 1 file changed, 19 insertions(+), 32 deletions(-) diff --git a/src/Undefined/__init__.py b/src/Undefined/__init__.py index 5cbc88de..86cf9678 100644 --- a/src/Undefined/__init__.py +++ b/src/Undefined/__init__.py @@ -6,41 +6,26 @@ from typing import TYPE_CHECKING, Any if TYPE_CHECKING: - from .ai import AIClient - from .api._context import RuntimeAPIContext - from .api.app import RuntimeAPIServer - from .attachments import AttachmentRegistry - from .cognitive.service import CognitiveService - from .config import Config, get_config, set_config - from .knowledge.manager import KnowledgeManager - from .memes.service import MemeService - from .skills.agents import AgentRegistry - from .skills.anthropic_skills import AnthropicSkillRegistry - from .skills.pipelines.registry import PipelineRegistry - from .skills.registry import BaseRegistry - from .skills.tools import ToolRegistry + from .ai import AIClient as AIClient + from .api._context import RuntimeAPIContext as RuntimeAPIContext + from .api.app import RuntimeAPIServer as RuntimeAPIServer + from .attachments import AttachmentRegistry as AttachmentRegistry + from .cognitive.service import CognitiveService as CognitiveService + from .config import Config as Config + from .config import get_config as get_config + from .config import set_config as set_config + from .knowledge.manager import KnowledgeManager as KnowledgeManager + from .memes.service import MemeService as MemeService + from .skills.agents import AgentRegistry as AgentRegistry + from .skills.anthropic_skills import ( + AnthropicSkillRegistry as AnthropicSkillRegistry, + ) + from .skills.pipelines.registry import PipelineRegistry as PipelineRegistry + from .skills.registry import BaseRegistry as BaseRegistry + from .skills.tools import ToolRegistry as ToolRegistry __version__ = "3.5.0" -__all__ = [ - "__version__", - "Config", - "get_config", - "set_config", - "AIClient", - "ToolRegistry", - "AgentRegistry", - "PipelineRegistry", - "BaseRegistry", - "AnthropicSkillRegistry", - "CognitiveService", - "KnowledgeManager", - "MemeService", - "AttachmentRegistry", - "RuntimeAPIServer", - "RuntimeAPIContext", -] - # symbol -> (module_path, attribute_name);首次访问时才 importlib 加载 _LAZY_IMPORTS: dict[str, tuple[str, str]] = { "Config": ("Undefined.config", "Config"), @@ -63,6 +48,8 @@ "RuntimeAPIContext": ("Undefined.api._context", "RuntimeAPIContext"), } +__all__ = ["__version__", *_LAZY_IMPORTS] + def __getattr__(name: str) -> Any: if name not in _LAZY_IMPORTS: