diff --git a/.coderabbit.yaml b/.coderabbit.yaml
new file mode 100644
index 00000000..b220aba8
--- /dev/null
+++ b/.coderabbit.yaml
@@ -0,0 +1,8 @@
+# yaml-language-server: $schema=https://coderabbit.ai/integrations/schema.v2.json
+
+reviews:
+  auto_review:
+    enabled: true
+    auto_incremental_review: true
+    base_branches:
+      - ".*"
diff --git a/ARCHITECTURE.md b/ARCHITECTURE.md
index 2a48c448..a7661901 100644
--- a/ARCHITECTURE.md
+++ b/ARCHITECTURE.md
@@ -18,14 +18,14 @@ graph TB
ConfigLoader["ConfigManager<br/>配置管理器<br/>[config/manager.py + loader.py]"]
ConfigHotReload["ConfigHotReload<br/>热更新应用器<br/>[config/hot_reload.py]"]
ConfigModels["配置模型<br/>[config/models.py]<br/>ChatModelConfig<br/>VisionModelConfig<br/>SecurityModelConfig<br/>AgentModelConfig"]
- OneBotClient["OneBotClient<br/>WebSocket 客户端<br/>[onebot.py]"]
+ OneBotClient["OneBotClient<br/>WebSocket 客户端<br/>[onebot/ + onebot.py shim]"]
Context["RequestContext<br/>请求上下文<br/>[context.py]"]
WebUI["webui.py<br/>配置控制台<br/>[src/Undefined/webui.py]"]
end
%% ==================== 消息处理层 ====================
subgraph MessageLayer["消息处理层 (src/Undefined/)"]
- MessageHandler["MessageHandler<br/>消息处理器<br/>[handlers.py]"]
+ MessageHandler["MessageHandler<br/>消息处理器<br/>[handlers/]"]
subgraph BilibiliModule["Bilibili 模块 (bilibili/)"]
BilibiliParser["parser.py<br/>标识符解析<br/>• BV/AV号 • URL<br/>• b23.tv短链 • 小程序JSON"]
@@ -46,10 +46,10 @@ graph TB
CommandDispatcher["CommandDispatcher<br/>命令分发器<br/>• /help /stats /admin<br/>• /bugfix /faq<br/>[services/command.py]"]
- MessageBatcher["MessageBatcher<br/>同 sender 短时合并<br/>• 按 (scope, sender_id) 分桶<br/>• T1=window_seconds 结束 batch<br/>• T2=pre_send_seconds 投机预发送<br/>• 拍一拍/buffer 内 @bot 旁路<br/>• 首条 @bot 整批走 mention 队列<br/>[services/message_batcher.py]"]
+ MessageBatcher["MessageBatcher<br/>同 sender 短时合并<br/>• 按 (scope, sender_id) 分桶<br/>• T1=window_seconds 结束 batch<br/>• T2=pre_send_seconds 投机预发送<br/>• 拍一拍/buffer 内 @bot 旁路<br/>• 首条 @bot 整批走 mention 队列<br/>[services/message_batcher/ + shim]"]
subgraph QueueSystem["车站-列车 队列系统 (services/)"]
- AICoordinator["AICoordinator<br/>AI 协调器<br/>• Prompt 构建<br/>• 队列管理<br/>• 回复执行<br/>[ai_coordinator.py]"]
+ AICoordinator["AICoordinator<br/>AI 协调器<br/>• Prompt 构建<br/>• 队列管理<br/>• 回复执行<br/>[services/coordinator/ + ai_coordinator.py shim]"]
QueueManager["QueueManager<br/>队列管理器<br/>[queue_manager.py]"]
subgraph ModelQueues["ModelQueue 队列组 (按模型隔离)"]
@@ -65,13 +65,13 @@ graph TB
%% ==================== AI 核心能力层 ====================
subgraph AILayer["AI 核心能力层 (src/Undefined/ai/)"]
- AIClient["AIClient<br/>AI 客户端主入口<br/>[client.py]<br/>• 技能热重载 • MCP 初始化<br/>• Agent intro 生成"]
+ AIClient["AIClient<br/>AI 客户端主入口<br/>[ai/client/ + client.py shim]<br/>• 技能热重载 • MCP 初始化<br/>• Agent intro 生成"]
subgraph AIComponents["AI 组件"]
- PromptBuilder["PromptBuilder<br/>提示词构建器<br/>[prompts.py]"]
- ModelRequester["ModelRequester<br/>模型请求器<br/>[llm.py]<br/>• OpenAI SDK • 工具清理<br/>• Thinking 提取"]
+ PromptBuilder["PromptBuilder<br/>提示词构建器<br/>[ai/prompts/ + prompts.py shim]"]
+ ModelRequester["ModelRequester<br/>模型请求器<br/>[ai/llm/ + llm.py shim]<br/>• OpenAI SDK • 工具清理<br/>• Thinking 提取"]
ToolManager["ToolManager<br/>工具管理器<br/>[tooling.py]<br/>• 工具执行 • Agent 工具合并<br/>• MCP 工具注入"]
- MultimodalAnalyzer["MultimodalAnalyzer<br/>多模态分析器<br/>[multimodal.py]<br/>• 图片/音频/视频"]
+ MultimodalAnalyzer["MultimodalAnalyzer<br/>多模态分析器<br/>[ai/multimodal/ + multimodal.py shim]<br/>• 图片/音频/视频"]
SummaryService["SummaryService<br/>总结服务<br/>[summaries.py]<br/>• 聊天记录总结<br/>• 标题生成"]
TokenCounter["TokenCounter<br/>Token 统计<br/>[tokens.py]"]
Parsing["Parsing<br/>响应解析<br/>[parsing.py]"]
@@ -154,9 +154,9 @@ graph TB
HistoryManager["MessageHistoryManager<br/>消息历史管理<br/>[utils/history.py]<br/>• 懒加载<br/>• 10000条限制"]
MemoryStorage["MemoryStorage<br/>置顶备忘录<br/>[memory.py]<br/>• 500条上限<br/>• 自动去重"]
EndSummaryStorage["EndSummaryStorage<br/>短期总结存储<br/>[end_summary_storage.py]"]
- CognitiveService["CognitiveService<br/>认知记忆服务<br/>[cognitive/service.py]<br/>• 事件检索 • 侧写读取<br/>• 入队 memory job"]
+ CognitiveService["CognitiveService<br/>认知记忆服务<br/>[cognitive/service/]<br/>• 事件检索 • 侧写读取<br/>• 入队 memory job"]
CognitiveJobQueue["JobQueue<br/>认知任务队列<br/>[cognitive/job_queue.py]<br/>• pending/processing/failed"]
- CognitiveHistorian["HistorianWorker<br/>后台史官<br/>[cognitive/historian.py]<br/>• 绝对化改写 • 闸门重试<br/>• 侧写合并(含历史事件注入)"]
+ CognitiveHistorian["HistorianWorker<br/>后台史官<br/>[cognitive/historian/]<br/>• 绝对化改写 • 闸门重试<br/>• 侧写合并(含历史事件注入)"]
CognitiveVectorStore["CognitiveVectorStore<br/>向量存储<br/>[cognitive/vector_store.py]<br/>• events/profiles<br/>• 时间衰减加权排序<br/>• MMR 去重"]
CognitiveProfileStorage["ProfileStorage<br/>侧写存储<br/>[cognitive/profile_storage.py]<br/>• users/groups Markdown<br/>• 历史快照"]
MemeSystem["MemeSystem<br/>表情包存储<br/>[memes/]<br/>• worker.py (两阶段识别)<br/>• sqlite+chromadb<br/>• blob 持久化"]
@@ -849,10 +849,10 @@ description: 从 PDF 文件中提取文本和表格,填写表单。当用户
### 8层架构分层
1. **外部实体层**:用户、管理员、OneBot 协议端 (NapCat/Lagrange.Core)、大模型 API 服务商
-2. **核心入口层**:main.py 启动入口、配置管理器 (config/loader.py)、热更新应用器 (config/hot_reload.py)、OneBotClient (onebot.py)、RequestContext (context.py)、Runtime API Server (api/app.py → api/routes/ 路由子模块)
-3. **消息处理层**:MessageHandler (handlers.py)、SecurityService (security.py)、CommandDispatcher (services/command.py)、MessageBatcher (services/message_batcher.py)、AICoordinator (ai_coordinator.py)、QueueManager (queue_manager.py)、自动处理管线 (skills/pipelines/)、Bilibili/arXiv/GitHub 解析与发送模块
+2. **核心入口层**:main.py 启动入口、配置管理器 (config/loader.py + parsers/ + load_sections/)、热更新应用器 (config/hot_reload.py)、OneBotClient (onebot/ + onebot.py shim)、RequestContext (context.py)、Runtime API Server (api/app.py → api/routes/ 路由子模块,含 naga/ 子包)
+3. **消息处理层**:MessageHandler (handlers/)、SecurityService (security.py)、CommandDispatcher (services/command.py + commands/ mixins)、MessageBatcher (services/message_batcher/ + shim)、AICoordinator (services/coordinator/ + ai_coordinator.py shim)、QueueManager (queue_manager.py)、自动处理管线 (skills/pipelines/)、Bilibili/arXiv/GitHub 解析与发送模块
自动提取由 `PipelineRegistry` 并行检测、并行处理全部命中的管线;发送结果写入历史后继续进入 AI 自动回复。
-4. **AI 核心能力层**:AIClient (client.py)、PromptBuilder (prompts.py)、ModelRequester (llm.py)、ToolManager (tooling.py)、MultimodalAnalyzer (multimodal.py)、SummaryService (summaries.py)、TokenCounter (tokens.py)
+4. **AI 核心能力层**:AIClient (ai/client/ + client.py shim)、PromptBuilder (ai/prompts/ + prompts.py shim)、ModelRequester (ai/llm/ + llm.py shim)、ToolManager (tooling.py)、MultimodalAnalyzer (ai/multimodal/ + multimodal.py shim)、SummaryService (summaries.py)、TokenCounter (tokens.py)
5. **存储与上下文层**:MessageHistoryManager (utils/history.py, 10000条限制)、MemoryStorage (memory.py, 置顶备忘录, 500条上限)、EndSummaryStorage、CognitiveService + JobQueue + HistorianWorker + VectorStore + ProfileStorage、MemeService + MemeWorker + MemeStore + MemeVectorStore (表情包库)、FAQStorage、ScheduledTaskStorage、TokenUsageStorage (自动归档)
6. **技能系统层**:ToolRegistry (registry.py)、AgentRegistry、6个 Agents、11类 Toolsets
7. **异步 IO 层**:统一 IO 工具 (utils/io.py),包含 write_json、read_json、append_line、跨平台文件锁 (flock/msvcrt)
diff --git a/CHANGELOG.md b/CHANGELOG.md
index 90b8e3d4..f0cc5ac9 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -1,3 +1,17 @@
+## v3.5.0 模块拆分、库嵌入与工具链收敛
+
+本版本是一次以可维护性与可嵌入性为主线的架构整理:将 AI 客户端、消息处理、附件、认知、表情包、协调器、OneBot 与 Agent 运行器等原先体量过大的单文件拆成职责清晰的子包,同时保留向后兼容的 import 路径与 CLI 行为不变。此外,本版本修复了 `end` 与同轮业务工具并行调用时可能重复发送或误结束会话的问题,并补齐 Python 库嵌入所需的 `set_config`、`Config.from_mapping`、根包 lazy re-export 与 `py.typed` 类型标记。配置加载改为分段解析 + 域级 parser,文档与测试同步覆盖库嵌入、包布局与公共 API 契约,让 Undefined 既能继续作为 QQ Bot 运行,也能更可靠地被脚本、测试与其它服务按需复用。
+
+- 重构核心运行时模块结构。`ai/client`、`ai/llm`、`ai/prompts`、`ai/multimodal` 分别承载客户端组合、模型请求、Prompt 构建与多模态解析;`handlers/`、`onebot/`、`attachments/`、`cognitive/service` + `cognitive/historian`、`memes/`(ingest / search / 图像工具)、`services/coordinator`、`services/message_batcher`、`skills/agents/runner` 与 `api/routes/naga/` 等子包按域拆分;原 monolith 文件以兼容 shim 或 `__init__` 重导出保留,删除被 shadow 的不可达死代码,降低单文件复杂度与后续改动风险。
+- 完善 Python 库嵌入能力。根包新增 lazy re-export(`Config`、`get_config`、`set_config`、`AIClient`、Skills 注册表、认知/知识库/表情包/附件/Runtime API 等稳定符号);`Config.from_mapping` / `ConfigBuilder` 支持无 `config.toml` 的内存构建,`env_registry` 统一管理环境变量兜底;`set_config()` 为 opt-in 注入全局单例,CLI 启动链不调用。wheel 打包 `py.typed`(PEP 561),新增 `docs/python-api.md` 公共 API 参考,README 补充嵌入示例与文档索引。
+- 收敛 `end` 工具与同轮并行调用的运行时语义。当模型在同一轮将 `end` 与 `send_message` 或其它业务工具一并调用时:其它工具照常并行执行并返回结果,`end` 不执行并回填明确拒绝响应,避免重复发送与「未读 tool 结果就结束」;提示词、`each.md` 与决策回归用例同步写入 P0 级「end 禁止并行」规则与运行时效果说明。
+- 改进缺失 tool call 时的重试策略。模型返回纯文本但未调用任何工具时,保留 assistant 原文于 messages,注入通用纠正提示而非硬编码 `send_message`/`end`,减少误导性后续 tool 调用;拆分后修复 Skills 路径解析,`PACKAGE_ROOT` 统一指向包根,避免内置工具零加载回归。
+- 加固队列化 LLM 与运行时边界。收紧 queued LLM 重试与 pending-call 清理、配置分段加载的边界校验、附件渲染容错,以及 Prompt 缓存键的隐私安全处理(系统上下文只暴露非敏感模型名等字段);表情包入库锁与图像工具去重,Naga API 路由守卫与 B 站 WBI 导航解析小幅简化。
+- 更新架构图、开发指南与配置文档。`docs/development.md` 反映拆分后的目录树;`docs/configuration.md` 新增库嵌入专节(`from_mapping` / `set_config` / 环境变量注册表);`ARCHITECTURE.md` 与相关运维文档同步引用路径。
+- 补强测试与工程契约。新增包布局、公共 API import、CLI 启动兼容、`Config.from_mapping` / 纯环境变量构建、`end` 同轮拒绝与 defer、`AIClient` setup 路径等回归测试;更新 LLM 重试抑制与请求参数相关用例,使测试整体覆盖库嵌入与拆分后的关键路径。
+
+---
+
## v3.4.2 总结更准、上下文可配、高并发更稳
本版本主要解决三类实际问题:群聊消息变长、合并发送变多之后,`/summary` 容易慢、容易编、分块预算也不准;主对话注入历史的 200 条硬顶与模型真实窗口脱节;高并发下用户连问「在吗」「好了吗」时,机器人仍可能把旧任务当新活重跑。围绕这些痛点,版本把「用户主动要总结」和「AI 自己调总结能力」拆成两条更合适的链路,用可配置的上下文窗口统一约束注入与分块,并同步收紧提示词、史官侧写与定时任务持久化,让总结更可信、配置更贴近上游模型、并发场景下更少重复劳动。
diff --git a/README.md b/README.md
index b626b06b..2043a552 100644
--- a/README.md
+++ b/README.md
@@ -34,6 +34,7 @@
## ⚡ 核心特性
- **Skills 架构**:全新设计的技能系统,将基础工具(Tools)与智能代理(Agents)分层管理,支持自动发现与注册。
+- **可嵌入 Python 库**:`pip install Undefined-bot` 后可 import 配置、`AIClient`、Skills 与认知记忆等组件,无需启动 Bot CLI。详见 [Python 库 API 参考](docs/python-api.md)。
- **Skills 热重载**:自动扫描 `skills/` 目录,检测到变更后即时重载工具与 Agent,无需重启服务。
- **三层分层记忆架构**:创新的分层记忆系统,模拟人类记忆机制——
- **短期记忆**(`end.memo`):每轮对话结束自动记录便签备忘,最近 N 条始终注入,保持短期连续性,零配置开箱即用
@@ -77,9 +78,10 @@
Undefined 的功能极为丰富,为了让本页面不过于臃肿,我们将各个模块的深入解析与高阶玩法整理成了专题游览图。这里是开启探索的钥匙:
- ⚙️ **[安装与部署指南](docs/deployment.md)**:不管你是需要 `pip` 无脑一键安装,还是源码二次开发,这里的排坑指南应有尽有。
+- 📦 **[Python 库 API 参考](docs/python-api.md)**:根包 lazy re-export、`Config.from_mapping` / `set_config`、公共 API 符号表与嵌入示例。
- 🖥️ **[WebUI 使用指南](docs/webui-guide.md)**:管理控制台功能一览——配置编辑、日志查看、认知记忆管理、表情包库、AI 对话与系统监控。
- 🧭 **[Management API 与远程管理](docs/management-api.md)**:WebUI / App 共用的管理接口、认证、配置/日志/Bot 控制与引导探针说明。
-- 🛠️ **[配置与热更新说明](docs/configuration.md)**:从模型切换到 MCP 库挂载,全方位掌握 `config.toml` 的高阶配置。
+- 🛠️ **[配置与热更新说明](docs/configuration.md)**:从模型切换到 MCP 库挂载,全方位掌握 `config.toml` 的高阶配置;库嵌入见 [§2 库嵌入配置](docs/configuration.md#2-库嵌入配置)。
- 😶 **[表情包系统 (Memes)](docs/memes.md)**:查看表情包两阶段判定管线、统一图片 `uid` 发送机制、检索模式及库存管理说明。
- 💡 **[交互与使用手册](docs/usage.md)**:包含实用的对话示例、多模态解析用法,以及群管家必备的管理员`/指令`。
- 📝 **[版本变更记录](CHANGELOG.md)**:查看按版本整理的更新摘要,也可在运行时使用 `/changelog` 查询。
@@ -91,7 +93,7 @@ Undefined 的功能极为丰富,为了让本页面不过于臃肿,我们将
- 🌐 **[Runtime API 与 OpenAPI](docs/openapi.md)**:主进程 Runtime API、鉴权、探针、记忆/侧写查询和运行态集成说明。
- 🏗️ **[构建指南](docs/build.md)**:Python 包、WebUI、跨平台 App、Android 与 Release 工作流的构建说明。
- 🔧 **[运维脚本](scripts/README.md)**:嵌入模型更换后的向量库重嵌入等维护工具。
-- 👨💻 **[开发者与拓展中心](docs/development.md)**:代码结构剖析和开发新 Agent 的流程参考及自检命令。
+- 👨💻 **[开发者与拓展中心](docs/development.md)**:代码结构剖析、模块拆分后的目录树、开发新 Agent 的流程参考及自检命令。
- **[核心技能系统 (Skills) 解析](src/Undefined/skills/README.md)**:全景式掌握什么是 Skills 架构、怎样定制原子工具与子智能体。
- **[callable.json 共享授权说明](docs/callable.md)**:细粒度管控 Agent 之间的相互调用与工具越权防范。
@@ -125,6 +127,58 @@ uv run Undefined-webui
---
+## 作为 Python 库使用
+
+除 Bot CLI 外,Undefined 也可嵌入脚本、测试或其它服务,直接复用 **Skills 注册表**、**认知记忆**、**知识库**、**AIClient** 等运行时(与 CLI 启动链隔离)。
+
+```bash
+pip install Undefined-bot # 源码开发:uv sync
+```
+
+```python
+import asyncio
+
+# 根包 lazy re-export:与 CLI 共用同一套运行时组件
+from Undefined import AgentRegistry, Config, ToolRegistry, set_config
+
+# 内存构建配置,测试/嵌入场景无需 config.toml
+cfg = Config.from_mapping(
+    {
+        "onebot": {"ws_url": "ws://127.0.0.1:3001"},
+        "models": {
+            "chat": {"api_url": "https://api.example/v1", "api_key": "sk-xxx", "model_name": "gpt-4o-mini"},
+            "vision": {"api_url": "https://api.example/v1", "api_key": "sk-xxx", "model_name": "gpt-4o-mini"},
+            "agent": {"api_url": "https://api.example/v1", "api_key": "sk-xxx", "model_name": "gpt-4o-mini"},
+        },
+    },
+    strict=False,
+)
+set_config(cfg)  # opt-in 注入全局单例;CLI 启动链不会调用
+
+# 自动扫描 skills/:tools + toolsets(end / group.* / cognitive.* …)
+tools = ToolRegistry()
+# 自动扫描 skills/agents/:web_agent、code_delivery_agent …
+agents = AgentRegistry()
+
+async def main() -> None:
+    # 直接调用原子工具,无需启动 OneBot
+    lunar_time = await tools.execute(
+        "get_current_time",
+        {"format": "text", "include_lunar": True},
+        context={},
+    )
+    print(lunar_time)
+    print(len(tools.get_schema()), "tools,", len(agents.get_schema()), "agents")
+
+asyncio.run(main())
+```
+
+- [Python 库 API 参考](docs/python-api.md) — 根包符号表、shim 路径、`AIClient` / `CognitiveService` 等嵌入示例
+- [配置详解 — 库嵌入配置](docs/configuration.md#2-库嵌入配置) — `from_mapping` / `Config.builder`
+- [开发者与拓展中心](docs/development.md) — 模块结构与自检命令
+
+---
+
## 风险提示与免责声明
1. **账号风控与封禁风险(含 QQ 账号)**
@@ -140,7 +194,9 @@ uv run Undefined-webui
本项目遵循 [MIT License](LICENSE) 开源协议。
-感谢 **NagaAgent** 子模块作者及社区支持:[NagaAgent - A simple yet powerful agent framework.](https://github.com/Xxiii8322766509/NagaAgent)。
+感谢 **NagaAgent** 子模块作者及社区提供的支持与鼓励:[NagaAgent - A simple yet powerful agent framework.](https://github.com/Xxiii8322766509/NagaAgent)!
+
+感谢在开发过程中为我提供各种灵感的群友们!
[\"'])(?P[^\"']+)(?P=quote)\s*/?>", - re.IGNORECASE, -) -_MEDIA_LABELS = { - "image": "图片", - "file": "文件", - "audio": "音频", - "video": "视频", - "record": "语音", - "pic": "图片", -} -_WINDOWS_ABS_PATH_RE = re.compile(r"^[A-Za-z]:[\\/]") _DEFAULT_REMOTE_TIMEOUT_SECONDS = 120.0 _IMAGE_SUFFIX_TO_MIME = { ".png": "image/png", @@ -67,7 +52,6 @@ (b"GIF89a", ".gif"), (b"BM", ".bmp"), ) -_FORWARD_ATTACHMENT_MAX_DEPTH = 3 _ATTACHMENT_CACHE_MAX_AGE_SECONDS = 7 * 24 * 60 * 60 _ATTACHMENT_REGISTRY_MAX_RECORDS = 2000 _ATTACHMENT_CACHE_MAX_BYTES = 0 @@ -76,219 +60,10 @@ _DEFAULT_REMOTE_DOWNLOAD_MAX_BYTES = 25 * 1024 * 1024 -@dataclass(frozen=True) -class AttachmentRecord: - uid: str - scope_key: str - kind: str - media_type: str - display_name: str - source_kind: str - source_ref: str - local_path: str | None - mime_type: str - sha256: str - created_at: str - segment_data: dict[str, str] - semantic_kind: str = "" - description: str = "" - - def prompt_ref(self) -> dict[str, str]: - local_available = False - if self.local_path is not None: - try: - local_available = Path(self.local_path).is_file() - except OSError: - local_available = False - ref: dict[str, str] = { - "uid": self.uid, - "kind": self.kind, - "media_type": self.media_type, - "display_name": self.display_name, - } - if self.source_kind.strip(): - ref["source_kind"] = self.source_kind.strip() - if not local_available and self.source_ref.strip(): - ref["source_ref"] = self.source_ref.strip() - if self.semantic_kind.strip(): - ref["semantic_kind"] = self.semantic_kind.strip() - if self.description.strip(): - ref["description"] = self.description.strip() - return ref - - -@dataclass(frozen=True) -class RegisteredMessageAttachments: - attachments: list[dict[str, str]] - normalized_text: str - - -@dataclass(frozen=True) -class RenderedRichMessage: - delivery_text: str - history_text: str - attachments: list[dict[str, str]] - pending_file_sends: tuple[AttachmentRecord, ...] 
= () - - -class AttachmentRenderError(RuntimeError): - """Raised when an attachment tag cannot be rendered.""" - - -class _RemoteAttachmentTooLarge(Exception): - def __init__(self, mime_type: str = "") -> None: - self.mime_type = mime_type - - def _now_iso() -> str: return datetime.now().isoformat(timespec="seconds") -def _escape_cq_component(value: str) -> str: - return ( - value.replace("&", "&") - .replace("[", "[") - .replace("]", "]") - .replace(",", ",") - ) - - -def _coerce_positive_int(value: Any) -> int | None: - if isinstance(value, bool): - return None - if isinstance(value, int): - return value if value > 0 else None - if isinstance(value, str): - text = value.strip() - if not text: - return None - try: - parsed = int(text) - except ValueError: - return None - return parsed if parsed > 0 else None - return None - - -def build_attachment_scope( - *, - group_id: Any = None, - user_id: Any = None, - request_type: str | None = None, - webui_session: bool = False, -) -> str | None: - """Build a scope key for attachment visibility.""" - if webui_session: - return "webui" - - group = _coerce_positive_int(group_id) - if group is not None: - return f"group:{group}" - - user = _coerce_positive_int(user_id) - request_type_text = str(request_type or "").strip().lower() - if request_type_text == "private" and user is not None: - return f"private:{user}" - if user is not None: - return f"private:{user}" - return None - - -def scope_from_context(context: Mapping[str, Any] | None) -> str | None: - if not context: - return None - return build_attachment_scope( - group_id=context.get("group_id"), - user_id=context.get("user_id"), - request_type=str(context.get("request_type", "") or ""), - webui_session=bool(context.get("webui_session", False)), - ) - - -def attachment_refs_to_text(attachments: Sequence[Mapping[str, str]]) -> str: - if not attachments: - return "" - parts: list[str] = [] - for item in attachments: - uid = str(item.get("uid", "") or "").strip() - if not 
uid: - continue - media_type = str(item.get("media_type") or item.get("kind") or "file").strip() - label = _MEDIA_LABELS.get(media_type, "附件") - name = str(item.get("display_name", "") or "").strip() - if name: - parts.append(f"[{label} uid={uid} name={name}]") - else: - parts.append(f"[{label} uid={uid}]") - return " ".join(parts) - - -def attachment_refs_to_xml( - attachments: Sequence[Mapping[str, str]], - *, - indent: str = " ", -) -> str: - if not attachments: - return "" - lines = [f"{indent} "] - for item in attachments: - uid = str(item.get("uid", "") or "").strip() - if not uid: - continue - kind = str(item.get("kind", "") or item.get("media_type", "") or "file").strip() - media_type = str(item.get("media_type", "") or kind or "file").strip() - name = str(item.get("display_name", "") or "").strip() - attrs = [ - f'uid="{escape_xml_attr(uid)}"', - f'type="{escape_xml_attr(kind or media_type)}"', - f'media_type="{escape_xml_attr(media_type)}"', - ] - if name: - attrs.append(f'name="{escape_xml_attr(name)}"') - source_kind = str(item.get("source_kind", "") or "").strip() - if source_kind: - attrs.append(f'source_kind="{escape_xml_attr(source_kind)}"') - source_ref = str(item.get("source_ref", "") or "").strip() - if source_ref: - attrs.append(f'source_ref="{escape_xml_attr(source_ref)}"') - semantic_kind = str(item.get("semantic_kind", "") or "").strip() - if semantic_kind: - attrs.append(f'semantic_kind="{escape_xml_attr(semantic_kind)}"') - description = str(item.get("description", "") or "").strip() - if description: - attrs.append(f'description="{escape_xml_attr(description)}"') - lines.append(f"{indent} ") - return "\n".join(lines) - - -def append_attachment_text( - base_text: str, attachments: Sequence[Mapping[str, str]] -) -> str: - attachment_text = attachment_refs_to_text(attachments) - if not attachment_text: - return base_text - if not base_text.strip(): - return attachment_text - return f"{base_text}\n附件: {attachment_text}" - - -def 
_is_http_url(value: str) -> bool: - return value.startswith("http://") or value.startswith("https://") - - -def _is_data_url(value: str) -> bool: - return value.startswith("data:") - - -def _is_localish_path(value: str) -> bool: - return ( - value.startswith("/") - or value.startswith("file://") - or bool(_WINDOWS_ABS_PATH_RE.match(value)) - ) - - def _decode_data_url(data_url: str) -> tuple[bytes, str]: header, _, payload = data_url.partition(",") if ";base64" not in header.lower(): @@ -326,22 +101,6 @@ def _guess_mime_type(name: str, content: bytes) -> str: return _IMAGE_SUFFIX_TO_MIME.get(suffix, "application/octet-stream") -def _display_name_from_source(raw_source: str, fallback: str) -> str: - if not raw_source: - return fallback - if raw_source.startswith("file://"): - raw_source = raw_source[7:] - name = Path(unquote(urlsplit(raw_source).path)).name - return name or fallback - - -def _media_kind_from_value(value: str) -> str: - text = str(value or "").strip().lower() - if text in {"image", "file", "audio", "video", "record"}: - return text - return "file" - - def _remote_reference_source_kind(source_kind: str) -> str: cleaned = str(source_kind or "").strip() if not cleaned: @@ -351,114 +110,11 @@ def _remote_reference_source_kind(source_kind: str) -> str: return f"{cleaned}_reference" -def _segment_text( - type_: str, data: Mapping[str, Any], ref: Mapping[str, str] | None -) -> str: - if type_ == "text": - return str(data.get("text", "") or "") - if type_ == "at": - qq = str(data.get("qq", "") or "").strip() - name = str(data.get("name") or data.get("nickname") or "").strip() - if qq and name: - return f"[@{qq}({name})]" - if qq: - return f"[@{qq}]" - return "[@]" - if type_ == "face": - return "[表情]" - if type_ == "reply": - reply_id = str(data.get("id") or data.get("message_id") or "").strip() - return f"[引用: {reply_id}]" if reply_id else "[引用]" - if type_ == "forward": - forward_id = str(data.get("id") or data.get("resid") or "").strip() - return f"[合并转发: 
{forward_id}]" if forward_id else "[合并转发]" - if ref is not None: - label = _MEDIA_LABELS.get( - str(ref.get("media_type") or ref.get("kind") or type_).strip(), "附件" - ) - uid = str(ref.get("uid", "") or "").strip() - name = str(ref.get("display_name", "") or "").strip() - if uid and name: - return f"[{label} uid={uid} name={name}]" - if uid: - return f"[{label} uid={uid}]" - label = _MEDIA_LABELS.get(type_, "附件") - raw = str(data.get("file") or data.get("url") or data.get("id") or "").strip() - return f"[{label}: {raw}]" if raw else f"[{label}]" - - -def _resolve_webui_file_id(file_id: str) -> Path | None: - if not file_id or not file_id.isalnum(): - return None - file_dir = (Path.cwd() / WEBUI_FILE_CACHE_DIR / file_id).resolve() - cache_root = (Path.cwd() / WEBUI_FILE_CACHE_DIR).resolve() - if cache_root not in file_dir.parents and file_dir != cache_root: - return None - if not file_dir.is_dir(): - return None - try: - files = list(file_dir.iterdir()) - except OSError: - return None - for candidate in files: - if candidate.is_file(): - return candidate - return None - - -def _extract_forward_id(data: Mapping[str, Any]) -> str: - forward_id = data.get("id") or data.get("resid") or data.get("message_id") - return str(forward_id).strip() if forward_id is not None else "" - - -def _segment_data_from_onebot_data( - data: Mapping[str, Any], - *, - exclude_keys: set[str] | None = None, -) -> dict[str, str]: - excluded = {key.strip().lower() for key in (exclude_keys or set()) if key.strip()} - normalized: dict[str, str] = {} - for raw_key, raw_value in data.items(): - key = str(raw_key or "").strip() - if not key: - continue - if key.lower() in excluded: - continue - text = str(raw_value or "").strip() - if not text: - continue - normalized[key] = text - return normalized - - -def _normalize_message_segments(message: Any) -> list[Mapping[str, Any]]: - if isinstance(message, list): - normalized: list[Mapping[str, Any]] = [] - for item in message: - if isinstance(item, 
Mapping): - normalized.append(item) - elif isinstance(item, str): - normalized.append({"type": "text", "data": {"text": item}}) - return normalized - if isinstance(message, Mapping): - return [message] - if isinstance(message, str): - return [{"type": "text", "data": {"text": message}}] - return [] - - -def _normalize_forward_nodes(raw_nodes: Any) -> list[Mapping[str, Any]]: - if isinstance(raw_nodes, list): - return [node for node in raw_nodes if isinstance(node, Mapping)] - if isinstance(raw_nodes, Mapping): - messages = raw_nodes.get("messages") - if isinstance(messages, list): - return [node for node in messages if isinstance(node, Mapping)] - return [] - - class AttachmentRegistry: - """Persistent attachment registry scoped by conversation.""" + """按会话作用域持久化的附件注册表。 + + 写入 JSON 注册表与本地缓存目录,支持远程 URL 引用与按需回源下载。 + """ def __init__( self, @@ -494,6 +150,7 @@ def __init__( ) = None def set_remote_download_max_bytes(self, value: int) -> None: + """设置单次远程下载字节上限。""" self._remote_download_max_bytes = max(0, int(value)) def set_limits( @@ -506,6 +163,7 @@ def set_limits( url_reference_max_records: int | None = None, url_max_length: int | None = None, ) -> None: + """批量更新注册表容量与 TTL 限制。""" if remote_download_max_bytes is not None: self._remote_download_max_bytes = max(0, int(remote_download_max_bytes)) if max_cache_bytes is not None: @@ -523,12 +181,14 @@ def set_global_image_resolver( self, resolver: Callable[[str], AttachmentRecord | None] | None, ) -> None: + """注册同步全局图片 UID 回退解析器。""" self._global_image_resolver = resolver def set_global_image_resolver_async( self, resolver: Callable[[str], Awaitable[AttachmentRecord | None]] | None, ) -> None: + """注册异步全局图片 UID 回退解析器。""" self._global_image_resolver_async = resolver def _resolve_managed_cache_path(self, raw_path: str | None) -> Path | None: @@ -546,7 +206,7 @@ def _resolve_managed_cache_path(self, raw_path: str | None) -> Path | None: def _normalized_url_ref(self, value: str) -> str: text = str(value or "").strip() - if 
not _is_http_url(text): + if not is_http_url(text): return "" if self._url_max_length > 0 and len(text) > self._url_max_length: return "" @@ -559,7 +219,7 @@ def _record_with_local_path( record, local_path=local_path, source_kind=_remote_reference_source_kind(record.source_kind) - if local_path is None and _is_http_url(record.source_ref) + if local_path is None and is_http_url(record.source_ref) else record.source_kind, ) @@ -588,7 +248,7 @@ def _prune_records(self) -> bool: cache_path = self._resolve_managed_cache_path(record.local_path) if record.local_path is None: has_url_ref = bool(self._normalized_url_ref(record.source_ref)) - if _is_http_url(record.source_ref) and not has_url_ref: + if is_http_url(record.source_ref) and not has_url_ref: dirty = True continue try: @@ -641,6 +301,7 @@ def _prune_records(self) -> bool: retained.append((uid, record, cache_path, mtime, size)) if self._max_records > 0 and len(retained) > self._max_records: + # 超出记录上限时按 mtime 淘汰最旧条目 retained.sort(key=lambda item: item[3]) overflow = len(retained) - self._max_records for _uid, _record, cache_path, _mtime, _size in retained[:overflow]: @@ -678,9 +339,10 @@ def _prune_records(self) -> bool: url_refs = [ item for item in retained - if item[2] is None and _is_http_url(item[1].source_ref) + if item[2] is None and is_http_url(item[1].source_ref) ] if len(url_refs) > self._url_reference_max_records: + # 仅 URL 引用(未下载)单独计数上限 url_ref_ids = { uid for uid, _record, _path, _mtime, _size in sorted( @@ -743,8 +405,8 @@ def _load_records_from_payload(self, raw: Any) -> dict[str, AttachmentRecord]: loaded[str(uid)] = AttachmentRecord( uid=str(item.get("uid") or uid), scope_key=str(item.get("scope_key", "") or ""), - kind=_media_kind_from_value(item.get("kind", "file")), - media_type=_media_kind_from_value( + kind=media_kind_from_value(item.get("kind", "file")), + media_type=media_kind_from_value( item.get("media_type") or item.get("kind") or "file" ), display_name=str(item.get("display_name", "") or 
""), @@ -800,11 +462,14 @@ async def flush(self) -> None: await self._persist() def get(self, uid: str) -> AttachmentRecord | None: + """按 UID 读取内存中的附件记录(不触发磁盘加载)。""" return self._records.get(str(uid).strip()) def resolve(self, uid: str, scope_key: str | None) -> AttachmentRecord | None: + """同步解析 UID,含 scope 校验与全局图片回退。""" record = self.get(uid) if record is not None: + # scope 不匹配时拒绝跨会话引用 if record.scope_key and scope_key and record.scope_key != scope_key: return None return record @@ -825,6 +490,7 @@ def resolve(self, uid: str, scope_key: str | None) -> AttachmentRecord | None: async def resolve_async( self, uid: str, scope_key: str | None ) -> AttachmentRecord | None: + """异步解析 UID,优先异步全局回退解析器。""" record = self.get(uid) if record is not None: if record.scope_key and scope_key and record.scope_key != scope_key: @@ -860,6 +526,7 @@ def resolve_for_context( uid: str, context: Mapping[str, Any] | None, ) -> AttachmentRecord | None: + """从请求上下文推断 scope 后解析 UID。""" return self.resolve(uid, scope_from_context(context)) async def get_url_by_uid(self, uid: str) -> str | None: @@ -882,8 +549,6 @@ async def get_uid_by_url(self, url: str) -> str | None: return None def _build_uid(self, prefix: str) -> str: - from uuid import uuid4 - while True: uid = f"{prefix}_{uuid4().hex[:8]}" if uid not in self._records: @@ -920,8 +585,9 @@ async def register_bytes( mime_type: str | None = None, segment_data: Mapping[str, str] | None = None, ) -> AttachmentRecord: + """将字节内容写入缓存并注册新附件(含 SHA-256 去重)。""" await self.load() - normalized_kind = _media_kind_from_value(kind) + normalized_kind = media_kind_from_value(kind) normalized_media_type = ( "image" if normalized_kind == "image" else normalized_kind ) @@ -935,6 +601,7 @@ async def register_bytes( existing = self._find_by_sha256(scope_key, digest_hex, normalized_kind) if existing is not None: + # 同 scope+SHA256 去重,复用已有 UID return existing uid = self._build_uid(prefix) @@ -976,6 +643,7 @@ async def register_local_file( source_ref: str = 
"", segment_data: Mapping[str, str] | None = None, ) -> AttachmentRecord: + """读取本地文件并注册为附件。""" path = Path(str(local_path)).expanduser() if not path.is_absolute(): path = (Path.cwd() / path).resolve() @@ -1010,6 +678,7 @@ async def register_data_url( source_ref: str = "", segment_data: Mapping[str, str] | None = None, ) -> AttachmentRecord: + """解码 ``data:`` URL 并注册附件。""" content, mime_type = _decode_data_url(data_url) return await self.register_bytes( scope_key, @@ -1033,7 +702,8 @@ async def register_remote_url( source_ref: str = "", segment_data: Mapping[str, str] | None = None, ) -> AttachmentRecord: - name = display_name or _display_name_from_source(url, "attachment.bin") + """下载远程 URL 或在上限时降级为 URL 引用。""" + name = display_name or display_name_from_source(url, "attachment.bin") return await self._register_remote_url_or_reference( scope_key, url, @@ -1057,10 +727,11 @@ async def register_remote_reference( segment_data: Mapping[str, str] | None = None, description: str = "", ) -> AttachmentRecord: + """仅登记远程 URL 引用,不下载内容。""" await self.load() if not self._normalized_url_ref(url): raise ValueError("远程附件 URL 为空或超过长度上限") - normalized_kind = _media_kind_from_value(kind) + normalized_kind = media_kind_from_value(kind) normalized_media_type = ( "image" if normalized_kind == "image" else normalized_kind ) @@ -1069,7 +740,7 @@ async def register_remote_reference( normalized_segment_data = dict(segment_data or {}) if source_ref and source_ref != url: normalized_segment_data.setdefault("original_source_ref", source_ref) - name = display_name or _display_name_from_source(url, "attachment.bin") + name = display_name or display_name_from_source(url, "attachment.bin") digest_hex = hashlib.sha256(ref.encode("utf-8")).hexdigest() async with self._lock: @@ -1126,6 +797,7 @@ async def _register_remote_url_or_reference( if source_ref and source_ref != url: reference_segment_data.setdefault("original_source_ref", source_ref) if max_bytes <= 0: + # 配置为 0 时一律只登记 URL 引用,不下载 return 
await self.register_remote_reference( scope_key, url, @@ -1153,6 +825,7 @@ async def _stream(client: httpx.AsyncClient) -> tuple[bytes, str]: total = 0 async for chunk in response.aiter_bytes(): total += len(chunk) + # 流式累计超限则降级为 URL 引用 if total > max_bytes: raise _RemoteAttachmentTooLarge(mime_type) chunks.append(chunk) @@ -1191,6 +864,7 @@ async def _stream(client: httpx.AsyncClient) -> tuple[bytes, str]: ) async def ensure_local_file(self, record: AttachmentRecord) -> AttachmentRecord: + """若记录仅有 URL 引用则尝试回源下载到本地缓存。""" await self.load() if record.local_path and Path(record.local_path).is_file(): return record @@ -1227,454 +901,3 @@ async def ensure_local_file(self, record: AttachmentRecord) -> AttachmentRecord: self._prune_records() await self._persist() return self._records.get(record.uid, updated) - - -async def register_message_attachments( - *, - registry: AttachmentRegistry | None, - segments: Sequence[Mapping[str, Any]], - scope_key: str | None, - resolve_image_url: Callable[[str], Awaitable[str | None]] | None = None, - get_forward_messages: Callable[[str], Awaitable[list[dict[str, Any]]]] - | None = None, -) -> RegisteredMessageAttachments: - attachments: list[dict[str, str]] = [] - normalized_parts: list[str] = [] - if registry is None or not scope_key: - for segment in segments: - type_ = str(segment.get("type", "") or "") - raw_data = segment.get("data", {}) - data = raw_data if isinstance(raw_data, Mapping) else {} - normalized_parts.append(_segment_text(type_, data, None)) - return RegisteredMessageAttachments( - attachments=[], - normalized_text="".join(normalized_parts).strip(), - ) - - visited_forward_ids: set[str] = set() - - async def _collect_from_segments( - current_segments: Sequence[Mapping[str, Any]], - *, - depth: int, - prefix: str, - ) -> None: - for index, segment in enumerate(current_segments): - type_ = str(segment.get("type", "") or "").strip().lower() - raw_data = segment.get("data", {}) - data = raw_data if isinstance(raw_data, 
Mapping) else {} - ref: dict[str, str] | None = None - - try: - if type_ == "image": - raw_source = str(data.get("file") or data.get("url") or "").strip() - display_name = _display_name_from_source( - raw_source, - f"image_{index + 1}.png", - ) - if raw_source.startswith("base64://"): - payload = raw_source[len("base64://") :].strip() - content = base64.b64decode(payload) - record = await registry.register_bytes( - scope_key, - content, - kind="image", - display_name=display_name, - source_kind="base64_image", - source_ref=f"{prefix}segment:{index}", - segment_data=_segment_data_from_onebot_data( - data, - exclude_keys={"file", "url"}, - ), - ) - ref = record.prompt_ref() - elif _is_data_url(raw_source): - record = await registry.register_data_url( - scope_key, - raw_source, - kind="image", - display_name=display_name, - source_kind="data_url_image", - source_ref=f"{prefix}segment:{index}", - segment_data=_segment_data_from_onebot_data( - data, - exclude_keys={"file", "url"}, - ), - ) - ref = record.prompt_ref() - else: - resolved_source = raw_source - if raw_source and resolve_image_url is not None: - try: - resolved = await resolve_image_url(raw_source) - except Exception as exc: - logger.debug( - "[AttachmentRegistry] image resolver failed: file=%s err=%s", - raw_source, - exc, - ) - resolved = None - if resolved: - resolved_source = str(resolved) - - if _is_http_url(resolved_source): - record = await registry.register_remote_url( - scope_key, - resolved_source, - kind="image", - display_name=display_name, - source_kind="remote_image", - source_ref=raw_source or resolved_source, - segment_data=_segment_data_from_onebot_data( - data, - exclude_keys={"file", "url"}, - ), - ) - ref = record.prompt_ref() - elif _is_localish_path(resolved_source): - local_path = ( - resolved_source[7:] - if resolved_source.startswith("file://") - else resolved_source - ) - record = await registry.register_local_file( - scope_key, - local_path, - kind="image", - 
display_name=display_name, - source_kind="local_image", - source_ref=raw_source or resolved_source, - segment_data=_segment_data_from_onebot_data( - data, - exclude_keys={"file", "url"}, - ), - ) - ref = record.prompt_ref() - - elif type_ == "file": - file_id = str(data.get("id", "") or "").strip() - raw_source = str(data.get("file") or data.get("url") or "").strip() - local_file_path: Path | None = None - if file_id: - local_file_path = _resolve_webui_file_id(file_id) - elif _is_localish_path(raw_source): - local_file_path = Path( - raw_source[7:] - if raw_source.startswith("file://") - else raw_source - ) - display_name = ( - str(data.get("name", "") or "").strip() - or (local_file_path.name if local_file_path is not None else "") - or _display_name_from_source( - raw_source, f"file_{index + 1}.bin" - ) - ) - if local_file_path is not None and local_file_path.is_file(): - record = await registry.register_local_file( - scope_key, - local_file_path, - kind="file", - display_name=display_name, - source_kind="webui_file" if file_id else "local_file", - source_ref=file_id or raw_source or str(local_file_path), - segment_data=_segment_data_from_onebot_data( - data, - exclude_keys={"file", "url"}, - ), - ) - ref = record.prompt_ref() - elif _is_http_url(raw_source): - record = await registry.register_remote_url( - scope_key, - raw_source, - kind="file", - display_name=display_name, - source_kind="remote_file", - source_ref=file_id or raw_source, - segment_data=_segment_data_from_onebot_data( - data, - exclude_keys={"file", "url"}, - ), - ) - ref = record.prompt_ref() - - elif ( - type_ == "forward" - and get_forward_messages is not None - and depth < _FORWARD_ATTACHMENT_MAX_DEPTH - ): - forward_id = _extract_forward_id(data) - if forward_id and forward_id not in visited_forward_ids: - visited_forward_ids.add(forward_id) - try: - nodes = _normalize_forward_nodes( - await get_forward_messages(forward_id) - ) - except Exception as exc: - logger.debug( - 
"[AttachmentRegistry] forward resolver failed: id=%s err=%s", - forward_id, - exc, - ) - nodes = [] - for node_index, node in enumerate(nodes): - raw_message = ( - node.get("content") - or node.get("message") - or node.get("raw_message") - ) - nested_segments = _normalize_message_segments(raw_message) - if not nested_segments: - continue - await _collect_from_segments( - nested_segments, - depth=depth + 1, - prefix=f"{prefix}forward:{forward_id}:{node_index}:", - ) - except ( - binascii.Error, - ValueError, - FileNotFoundError, - httpx.HTTPError, - ) as exc: - logger.warning( - "[AttachmentRegistry] segment registration skipped: type=%s index=%s err=%s", - type_, - index, - exc, - ) - except Exception as exc: - logger.exception( - "[AttachmentRegistry] unexpected segment registration failure: type=%s index=%s err=%s", - type_, - index, - exc, - ) - - if ref is not None: - attachments.append(ref) - if depth == 0: - normalized_parts.append(_segment_text(type_, data, ref)) - - await _collect_from_segments(segments, depth=0, prefix="") - - return RegisteredMessageAttachments( - attachments=attachments, - normalized_text="".join(normalized_parts).strip(), - ) - - -async def render_message_with_attachments( - message: str, - *, - registry: AttachmentRegistry | None, - scope_key: str | None, - strict: bool, -) -> RenderedRichMessage: - """Render ``") - lines.append(f"{indent} `` and `` `` tags into delivery/history text. - - * `` `` — backward-compatible, image-only. - * `` `` — unified tag for any media type. - Images (``pic_*``) are inlined as CQ images; files (``file_*``) - are collected into *pending_file_sends* for later dispatch. 
- """ - has_tags = message and ( - " tag: strictly image-only - if tag_name == "pic" and record.media_type != "image": - replacement = f"[图片 uid={uid} 类型错误]" - if strict: - raise AttachmentRenderError(f"UID 不是图片,不能用于 :{uid}") - delivery_parts.append(replacement) - history_parts.append(replacement) - continue - - # Route by media type - if record.media_type == "image": - ok = _render_image_tag(record, uid, strict, delivery_parts, history_parts) - else: - ok = _render_file_tag( - record, - uid, - strict, - delivery_parts, - history_parts, - pending_files, - ) - - if ok: - attachments.append(record.prompt_ref()) - - delivery_parts.append(message[last_index:]) - history_parts.append(message[last_index:]) - return RenderedRichMessage( - delivery_text="".join(delivery_parts), - history_text="".join(history_parts), - attachments=attachments, - pending_file_sends=tuple(pending_files), - ) - - -def _render_image_tag( - record: AttachmentRecord, - uid: str, - strict: bool, - delivery_parts: list[str], - history_parts: list[str], -) -> bool: - """Render an image attachment as an inline CQ:image. 
Returns True on success.""" - image_source = record.source_ref - if record.local_path: - image_source = Path(record.local_path).resolve().as_uri() - elif not image_source: - replacement = f"[图片 uid={uid} 缺少文件]" - if strict: - raise AttachmentRenderError(f"图片 UID 缺少可发送的文件:{uid}") - delivery_parts.append(replacement) - history_parts.append(replacement) - return False - - cq_args = [f"file={image_source}"] - for key, value in dict(getattr(record, "segment_data", {}) or {}).items(): - cleaned_key = str(key or "").strip() - cleaned_value = str(value or "").strip() - if ( - not cleaned_key - or not cleaned_value - or cleaned_key in {"file", "original_source_ref"} - ): - continue - cq_args.append( - f"{_escape_cq_component(cleaned_key)}={_escape_cq_component(cleaned_value)}" - ) - delivery_parts.append(f"[CQ:image,{','.join(cq_args)}]") - if record.display_name: - history_parts.append(f"[图片 uid={uid} name={record.display_name}]") - else: - history_parts.append(f"[图片 uid={uid}]") - return True - - -def _render_file_tag( - record: AttachmentRecord, - uid: str, - strict: bool, - delivery_parts: list[str], - history_parts: list[str], - pending_files: list[AttachmentRecord], -) -> bool: - """Render a non-image attachment as a pending file send. 
Returns True on success.""" - if not record.local_path or not Path(record.local_path).is_file(): - if _is_http_url(record.source_ref): - name_part = f" name={record.display_name}" if record.display_name else "" - history_parts.append(f"[文件 uid={uid}{name_part}]") - pending_files.append(record) - return True - replacement = f"[文件 uid={uid} 缺少本地文件]" - if strict: - raise AttachmentRenderError(f"文件 UID 缺少本地文件,无法发送:{uid}") - delivery_parts.append(replacement) - history_parts.append(replacement) - return False - - # Remove from delivery text (file sent separately) - # Keep a readable placeholder in history - name_part = f" name={record.display_name}" if record.display_name else "" - history_parts.append(f"[文件 uid={uid}{name_part}]") - pending_files.append(record) - return True - - -# Backward-compatible alias -render_message_with_pic_placeholders = render_message_with_attachments - - -async def dispatch_pending_file_sends( - rendered: RenderedRichMessage, - *, - sender: Any, - target_type: str, - target_id: int, - registry: AttachmentRegistry | None = None, -) -> None: - """Send pending file attachments collected by *render_message_with_attachments*. - - This is best-effort: each file send failure is logged but does not interrupt - the remaining sends or the caller. 
- """ - if not rendered.pending_file_sends or sender is None: - return - for record in rendered.pending_file_sends: - send_record = record - if ( - not send_record.local_path or not Path(send_record.local_path).is_file() - ) and registry is not None: - try: - send_record = await registry.ensure_local_file(send_record) - except Exception: - logger.warning( - "[文件发送] 回源下载失败 uid=%s source=%s", - send_record.uid, - send_record.source_ref, - exc_info=True, - ) - if not send_record.local_path or not Path(send_record.local_path).is_file(): - logger.warning( - "[文件发送] 跳过:本地文件缺失 uid=%s path=%s", - send_record.uid, - send_record.local_path, - ) - continue - try: - if target_type == "group": - await sender.send_group_file( - target_id, - send_record.local_path, - name=send_record.display_name or None, - ) - else: - await sender.send_private_file( - target_id, - send_record.local_path, - name=send_record.display_name or None, - ) - except Exception: - logger.warning( - "[文件发送] 发送失败(最佳努力) uid=%s target=%s:%s", - send_record.uid, - target_type, - target_id, - exc_info=True, - ) diff --git a/src/Undefined/attachments/render.py b/src/Undefined/attachments/render.py new file mode 100644 index 00000000..6f3433c6 --- /dev/null +++ b/src/Undefined/attachments/render.py @@ -0,0 +1,287 @@ +"""富媒体标签渲染与待发送文件派发。 + +将 `` `` / `` `` 占位符转为 CQ 段或历史可读文本; +不修改注册表结构。 +""" + +from __future__ import annotations + +import logging +import re +from pathlib import Path +from typing import TYPE_CHECKING, Any + +from Undefined.attachments.models import ( + AttachmentRecord, + AttachmentRenderError, + RenderedRichMessage, +) +from Undefined.attachments.segments import _MEDIA_LABELS, is_http_url + +if TYPE_CHECKING: + from Undefined.attachments.registry import AttachmentRegistry + +logger = logging.getLogger(__name__) + +_PIC_TAG_PATTERN = re.compile( + r" [\"'])(?P [^\"']+)(?P=quote)\s*/?>", + re.IGNORECASE, +) +_ATTACHMENT_TAG_PATTERN = re.compile( + r" [\"'])(?P [^\"']+)(?P=quote)\s*/?>", + 
re.IGNORECASE, +) +_UNIFIED_TAG_PATTERN = re.compile( + r"<(?P pic|attachment)\s+uid=(?P [\"'])(?P[^\"']+)(?P=quote)\s*/?>", + re.IGNORECASE, +) + + +def _escape_cq_component(value: str) -> str: + return ( + value.replace("&", "&") + .replace("[", "[") + .replace("]", "]") + .replace(",", ",") + ) + + +async def render_message_with_attachments( + message: str, + *, + registry: AttachmentRegistry | None, + scope_key: str | None, + strict: bool, +) -> RenderedRichMessage: + """Render `` `` and `` `` tags into delivery/history text. + + * `` `` — backward-compatible, image-only. + * `` `` — unified tag for any media type. + Images (``pic_*``) are inlined as CQ images; files (``file_*``) + are collected into *pending_file_sends* for later dispatch. + + Args: + message: 含占位标签的原始消息文本。 + registry: 附件注册表。 + scope_key: 当前会话作用域键。 + strict: 为 True 时 UID 不可用或类型不匹配则抛出 ``AttachmentRenderError``。 + + Returns: + 投递文本、历史文本、附件引用及待发送文件列表。 + + Raises: + AttachmentRenderError: ``strict=True`` 且标签无法解析时。 + """ + has_tags = message and ( + " tag: strictly image-only + if tag_name == "pic" and record.media_type != "image": + replacement = f"[图片 uid={uid} 类型错误]" + if strict: + raise AttachmentRenderError(f"UID 不是图片,不能用于 :{uid}") + delivery_parts.append(replacement) + history_parts.append(replacement) + continue + + # 仅允许图片; 按 media_type 分流 + if record.media_type == "image": + ok = _render_image_tag(record, uid, strict, delivery_parts, history_parts) + else: + ok = _render_file_tag( + record, + uid, + strict, + delivery_parts, + history_parts, + pending_files, + ) + + if ok: + attachments.append(record.prompt_ref()) + + delivery_parts.append(message[last_index:]) + history_parts.append(message[last_index:]) + return RenderedRichMessage( + delivery_text="".join(delivery_parts), + history_text="".join(history_parts), + attachments=attachments, + pending_file_sends=tuple(pending_files), + ) + + +def _render_image_tag( + record: AttachmentRecord, + uid: str, + strict: bool, + delivery_parts: 
list[str], + history_parts: list[str], +) -> bool: + """Render an image attachment as an inline CQ:image. Returns True on success.""" + image_source = record.source_ref + if record.local_path and Path(record.local_path).is_file(): + image_source = Path(record.local_path).resolve().as_uri() + elif not image_source: + replacement = f"[图片 uid={uid} 缺少文件]" + if strict: + raise AttachmentRenderError(f"图片 UID 缺少可发送的文件:{uid}") + delivery_parts.append(replacement) + history_parts.append(replacement) + return False + + cq_args = [f"file={image_source}"] + for key, value in dict(getattr(record, "segment_data", {}) or {}).items(): + cleaned_key = str(key or "").strip() + cleaned_value = str(value or "").strip() + if ( + not cleaned_key + or not cleaned_value + or cleaned_key in {"file", "original_source_ref"} + ): + continue + cq_args.append( + f"{_escape_cq_component(cleaned_key)}={_escape_cq_component(cleaned_value)}" + ) + delivery_parts.append(f"[CQ:image,{','.join(cq_args)}]") + if record.display_name: + history_parts.append(f"[图片 uid={uid} name={record.display_name}]") + else: + history_parts.append(f"[图片 uid={uid}]") + return True + + +def _render_file_tag( + record: AttachmentRecord, + uid: str, + strict: bool, + delivery_parts: list[str], + history_parts: list[str], + pending_files: list[AttachmentRecord], +) -> bool: + """Render a non-image attachment as a pending file send. 
Returns True on success.""" + if not record.local_path or not Path(record.local_path).is_file(): + if is_http_url(record.source_ref): + # 仅有远程 URL 时先入 pending,发送前尝试回源下载 + name_part = f" name={record.display_name}" if record.display_name else "" + history_parts.append(f"[文件 uid={uid}{name_part}]") + pending_files.append(record) + return True + replacement = f"[文件 uid={uid} 缺少本地文件]" + if strict: + raise AttachmentRenderError(f"文件 UID 缺少本地文件,无法发送:{uid}") + delivery_parts.append(replacement) + history_parts.append(replacement) + return False + + # 文件不在 CQ 文本中内联,单独走 send_group/private_file + # Keep a readable placeholder in history + name_part = f" name={record.display_name}" if record.display_name else "" + history_parts.append(f"[文件 uid={uid}{name_part}]") + pending_files.append(record) + return True + + +render_message_with_pic_placeholders = render_message_with_attachments + + +async def dispatch_pending_file_sends( + rendered: RenderedRichMessage, + *, + sender: Any, + target_type: str, + target_id: int, + registry: AttachmentRegistry | None = None, +) -> None: + """Send pending file attachments collected by *render_message_with_attachments*. + + This is best-effort: each file send failure is logged but does not interrupt + the remaining sends or the caller. 
+ + Args: + rendered: ``render_message_with_attachments`` 的返回值。 + sender: 实现群/私聊文件发送的 OneBot 客户端。 + target_type: ``group`` 或 ``private``。 + target_id: 目标群号或 QQ 号。 + registry: 可选,用于发送前回源下载仅有 URL 的附件。 + """ + if not rendered.pending_file_sends or sender is None: + return + for record in rendered.pending_file_sends: + send_record = record + if ( + not send_record.local_path or not Path(send_record.local_path).is_file() + ) and registry is not None: + try: + send_record = await registry.ensure_local_file(send_record) + except Exception: + logger.warning( + "[文件发送] 回源下载失败 uid=%s source=%s", + send_record.uid, + send_record.source_ref, + exc_info=True, + ) + if not send_record.local_path or not Path(send_record.local_path).is_file(): + logger.warning( + "[文件发送] 跳过:本地文件缺失 uid=%s path=%s", + send_record.uid, + send_record.local_path, + ) + continue + try: + if target_type == "group": + await sender.send_group_file( + target_id, + send_record.local_path, + name=send_record.display_name or None, + ) + elif target_type == "private": + await sender.send_private_file( + target_id, + send_record.local_path, + name=send_record.display_name or None, + ) + else: + logger.warning( + "[文件发送] 跳过:不支持的 target_type=%s uid=%s", + target_type, + send_record.uid, + ) + continue + except Exception: + logger.warning( + "[文件发送] 发送失败(最佳努力) uid=%s target=%s:%s", + send_record.uid, + target_type, + target_id, + exc_info=True, + ) diff --git a/src/Undefined/attachments/segments.py b/src/Undefined/attachments/segments.py new file mode 100644 index 00000000..d78b458f --- /dev/null +++ b/src/Undefined/attachments/segments.py @@ -0,0 +1,566 @@ +"""OneBot 消息段解析与会话作用域辅助。 + +负责 scope 键构建、附件引用文本/XML 序列化,以及从消息段批量注册附件; +不处理磁盘持久化或 CQ 标签渲染。 +""" + +from __future__ import annotations + +import base64 +import binascii +import logging +import re +from pathlib import Path +from typing import TYPE_CHECKING, Any, Awaitable, Callable, Mapping, Sequence +from urllib.parse import unquote, urlsplit + +import httpx + 
+from Undefined.attachments.models import RegisteredMessageAttachments +from Undefined.utils.paths import WEBUI_FILE_CACHE_DIR +from Undefined.utils.xml import escape_xml_attr + +if TYPE_CHECKING: + from Undefined.attachments.registry import AttachmentRegistry + +logger = logging.getLogger(__name__) + +_MEDIA_LABELS = { + "image": "图片", + "file": "文件", + "audio": "音频", + "video": "视频", + "record": "语音", + "pic": "图片", +} +_WINDOWS_ABS_PATH_RE = re.compile(r"^[A-Za-z]:[\\/]") +_FORWARD_ATTACHMENT_MAX_DEPTH = 3 + + +def _coerce_positive_int(value: Any) -> int | None: + if isinstance(value, bool): + return None + if isinstance(value, int): + return value if value > 0 else None + if isinstance(value, str): + text = value.strip() + if not text: + return None + try: + parsed = int(text) + except ValueError: + return None + return parsed if parsed > 0 else None + return None + + +def build_attachment_scope( + *, + group_id: Any = None, + user_id: Any = None, + request_type: str | None = None, + webui_session: bool = False, +) -> str | None: + """构建附件可见性作用域键。 + + Args: + group_id: 群号;优先于私聊作用域。 + user_id: 用户 QQ 号。 + request_type: 请求类型(``private`` 等)。 + webui_session: WebUI 会话时使用固定 ``webui`` 作用域。 + + Returns: + 形如 ``group:123`` / ``private:456`` / ``webui`` 的键,无法推断时返回 ``None``。 + """ + if webui_session: + return "webui" + + group = _coerce_positive_int(group_id) + if group is not None: + # 群聊作用域优先于私聊,避免同会话附件串读 + return f"group:{group}" + + user = _coerce_positive_int(user_id) + request_type_text = str(request_type or "").strip().lower() + if request_type_text == "private" and user is not None: + return f"private:{user}" + if user is not None: + return f"private:{user}" + return None + + +def scope_from_context(context: Mapping[str, Any] | None) -> str | None: + """从请求上下文字典提取附件作用域键。""" + if not context: + return None + return build_attachment_scope( + group_id=context.get("group_id"), + user_id=context.get("user_id"), + request_type=str(context.get("request_type", "") or ""), 
+ webui_session=bool(context.get("webui_session", False)), + ) + + +def attachment_refs_to_text(attachments: Sequence[Mapping[str, str]]) -> str: + """将附件引用列表转为可读占位文本。""" + if not attachments: + return "" + parts: list[str] = [] + for item in attachments: + uid = str(item.get("uid", "") or "").strip() + if not uid: + continue + media_type = str(item.get("media_type") or item.get("kind") or "file").strip() + label = _MEDIA_LABELS.get(media_type, "附件") + name = str(item.get("display_name", "") or "").strip() + if name: + parts.append(f"[{label} uid={uid} name={name}]") + else: + parts.append(f"[{label} uid={uid}]") + return " ".join(parts) + + +def attachment_refs_to_xml( + attachments: Sequence[Mapping[str, str]], + *, + indent: str = " ", +) -> str: + """将附件引用列表序列化为 XML `` `` 片段。""" + if not attachments: + return "" + lines = [f"{indent} "] + for item in attachments: + uid = str(item.get("uid", "") or "").strip() + if not uid: + continue + kind = str(item.get("kind", "") or item.get("media_type", "") or "file").strip() + media_type = str(item.get("media_type", "") or kind or "file").strip() + name = str(item.get("display_name", "") or "").strip() + attrs = [ + f'uid="{escape_xml_attr(uid)}"', + f'type="{escape_xml_attr(kind or media_type)}"', + f'media_type="{escape_xml_attr(media_type)}"', + ] + if name: + attrs.append(f'name="{escape_xml_attr(name)}"') + source_kind = str(item.get("source_kind", "") or "").strip() + if source_kind: + attrs.append(f'source_kind="{escape_xml_attr(source_kind)}"') + source_ref = str(item.get("source_ref", "") or "").strip() + if source_ref: + attrs.append(f'source_ref="{escape_xml_attr(source_ref)}"') + semantic_kind = str(item.get("semantic_kind", "") or "").strip() + if semantic_kind: + attrs.append(f'semantic_kind="{escape_xml_attr(semantic_kind)}"') + description = str(item.get("description", "") or "").strip() + if description: + attrs.append(f'description="{escape_xml_attr(description)}"') + lines.append(f"{indent} ") + return 
"\n".join(lines) + + +def append_attachment_text( + base_text: str, attachments: Sequence[Mapping[str, str]] +) -> str: + """在基础文本后追加附件占位摘要行。""" + attachment_text = attachment_refs_to_text(attachments) + if not attachment_text: + return base_text + if not base_text.strip(): + return attachment_text + return f"{base_text}\n附件: {attachment_text}" + + +def is_http_url(value: str) -> bool: + """判断字符串是否为 HTTP(S) URL。""" + return value.startswith("http://") or value.startswith("https://") + + +def is_data_url(value: str) -> bool: + """判断字符串是否为 ``data:`` URL。""" + return value.startswith("data:") + + +def is_localish_path(value: str) -> bool: + """判断字符串是否像本地绝对路径或 ``file://`` URI。""" + return ( + value.startswith("/") + or value.startswith("file://") + or bool(_WINDOWS_ABS_PATH_RE.match(value)) + ) + + +def display_name_from_source(raw_source: str, fallback: str) -> str: + """从 URL 或路径推断展示文件名。""" + if not raw_source: + return fallback + if raw_source.startswith("file://"): + raw_source = raw_source[7:] + name = Path(unquote(urlsplit(raw_source).path)).name + return name or fallback + + +def media_kind_from_value(value: str) -> str: + """将任意媒体类型字符串规范为 registry 支持的 kind。""" + text = str(value or "").strip().lower() + if text in {"image", "file", "audio", "video", "record"}: + return text + return "file" + + +def segment_text( + type_: str, data: Mapping[str, Any], ref: Mapping[str, str] | None +) -> str: + """将单条 OneBot 消息段转为可读占位文本。""" + if type_ == "text": + return str(data.get("text", "") or "") + if type_ == "at": + qq = str(data.get("qq", "") or "").strip() + name = str(data.get("name") or data.get("nickname") or "").strip() + if qq and name: + return f"[@{qq}({name})]" + if qq: + return f"[@{qq}]" + return "[@]" + if type_ == "face": + return "[表情]" + if type_ == "reply": + reply_id = str(data.get("id") or data.get("message_id") or "").strip() + return f"[引用: {reply_id}]" if reply_id else "[引用]" + if type_ == "forward": + forward_id = str(data.get("id") or 
data.get("resid") or "").strip() + return f"[合并转发: {forward_id}]" if forward_id else "[合并转发]" + if ref is not None: + label = _MEDIA_LABELS.get( + str(ref.get("media_type") or ref.get("kind") or type_).strip(), "附件" + ) + uid = str(ref.get("uid", "") or "").strip() + name = str(ref.get("display_name", "") or "").strip() + if uid and name: + return f"[{label} uid={uid} name={name}]" + if uid: + return f"[{label} uid={uid}]" + label = _MEDIA_LABELS.get(type_, "附件") + raw = str(data.get("file") or data.get("url") or data.get("id") or "").strip() + return f"[{label}: {raw}]" if raw else f"[{label}]" + + +def _resolve_webui_file_id(file_id: str) -> Path | None: + if not file_id or not file_id.isalnum(): + return None + file_dir = (Path.cwd() / WEBUI_FILE_CACHE_DIR / file_id).resolve() + cache_root = (Path.cwd() / WEBUI_FILE_CACHE_DIR).resolve() + if cache_root not in file_dir.parents and file_dir != cache_root: + return None + if not file_dir.is_dir(): + return None + try: + files = list(file_dir.iterdir()) + except OSError: + return None + for candidate in files: + if candidate.is_file(): + return candidate + return None + + +def _extract_forward_id(data: Mapping[str, Any]) -> str: + forward_id = data.get("id") or data.get("resid") or data.get("message_id") + return str(forward_id).strip() if forward_id is not None else "" + + +def segment_data_from_onebot_data( + data: Mapping[str, Any], + *, + exclude_keys: set[str] | None = None, +) -> dict[str, str]: + """提取 OneBot 段 ``data`` 中需保留的字符串键值。""" + excluded = {key.strip().lower() for key in (exclude_keys or set()) if key.strip()} + normalized: dict[str, str] = {} + for raw_key, raw_value in data.items(): + key = str(raw_key or "").strip() + if not key: + continue + if key.lower() in excluded: + continue + text = str(raw_value or "").strip() + if not text: + continue + normalized[key] = text + return normalized + + +def normalize_message_segments(message: Any) -> list[Mapping[str, Any]]: + """将多种消息表示统一为 OneBot 段列表。""" + 
if isinstance(message, list): + normalized: list[Mapping[str, Any]] = [] + for item in message: + if isinstance(item, Mapping): + normalized.append(item) + elif isinstance(item, str): + normalized.append({"type": "text", "data": {"text": item}}) + return normalized + if isinstance(message, Mapping): + return [message] + if isinstance(message, str): + return [{"type": "text", "data": {"text": message}}] + return [] + + +def _normalize_forward_nodes(raw_nodes: Any) -> list[Mapping[str, Any]]: + if isinstance(raw_nodes, list): + return [node for node in raw_nodes if isinstance(node, Mapping)] + if isinstance(raw_nodes, Mapping): + messages = raw_nodes.get("messages") + if isinstance(messages, list): + return [node for node in messages if isinstance(node, Mapping)] + return [] + + +async def register_message_attachments( + *, + registry: AttachmentRegistry | None, + segments: Sequence[Mapping[str, Any]], + scope_key: str | None, + resolve_image_url: Callable[[str], Awaitable[str | None]] | None = None, + get_forward_messages: Callable[[str], Awaitable[list[dict[str, Any]]]] + | None = None, +) -> RegisteredMessageAttachments: + """扫描消息段并将图片/文件注册到 ``AttachmentRegistry``。 + + Args: + registry: 附件注册表;为 ``None`` 时仅归一化文本。 + segments: OneBot 消息段序列。 + scope_key: 会话作用域键。 + resolve_image_url: 可选,将 ``file`` 字段解析为可下载 URL。 + get_forward_messages: 可选,拉取合并转发子消息。 + + Returns: + 已注册附件引用与归一化纯文本。 + """ + attachments: list[dict[str, str]] = [] + normalized_parts: list[str] = [] + if registry is None or not scope_key: + for segment in segments: + type_ = str(segment.get("type", "") or "") + raw_data = segment.get("data", {}) + data = raw_data if isinstance(raw_data, Mapping) else {} + normalized_parts.append(segment_text(type_, data, None)) + return RegisteredMessageAttachments( + attachments=[], + normalized_text="".join(normalized_parts).strip(), + ) + + visited_forward_ids: set[str] = set() + + async def _collect_from_segments( + current_segments: Sequence[Mapping[str, Any]], + *, + 
depth: int, + prefix: str, + ) -> None: + for index, segment in enumerate(current_segments): + type_ = str(segment.get("type", "") or "").strip().lower() + raw_data = segment.get("data", {}) + data = raw_data if isinstance(raw_data, Mapping) else {} + ref: dict[str, str] | None = None + + try: + if type_ == "image": + raw_source = str(data.get("file") or data.get("url") or "").strip() + display_name = display_name_from_source( + raw_source, + f"image_{index + 1}.png", + ) + if raw_source.startswith("base64://"): + payload = raw_source[len("base64://") :].strip() + content = base64.b64decode(payload) + record = await registry.register_bytes( + scope_key, + content, + kind="image", + display_name=display_name, + source_kind="base64_image", + source_ref=f"{prefix}segment:{index}", + segment_data=segment_data_from_onebot_data( + data, + exclude_keys={"file", "url"}, + ), + ) + ref = record.prompt_ref() + elif is_data_url(raw_source): + record = await registry.register_data_url( + scope_key, + raw_source, + kind="image", + display_name=display_name, + source_kind="data_url_image", + source_ref=f"{prefix}segment:{index}", + segment_data=segment_data_from_onebot_data( + data, + exclude_keys={"file", "url"}, + ), + ) + ref = record.prompt_ref() + else: + resolved_source = raw_source + if raw_source and resolve_image_url is not None: + try: + # NapCat file id 需经 get_image 解析为可下载 URL + resolved = await resolve_image_url(raw_source) + except Exception as exc: + logger.debug( + "[AttachmentRegistry] image resolver failed: file=%s err=%s", + raw_source, + exc, + ) + resolved = None + if resolved: + resolved_source = str(resolved) + + if is_http_url(resolved_source): + record = await registry.register_remote_url( + scope_key, + resolved_source, + kind="image", + display_name=display_name, + source_kind="remote_image", + source_ref=raw_source or resolved_source, + segment_data=segment_data_from_onebot_data( + data, + exclude_keys={"file", "url"}, + ), + ) + ref = 
record.prompt_ref() + elif is_localish_path(resolved_source): + local_path = ( + resolved_source[7:] + if resolved_source.startswith("file://") + else resolved_source + ) + record = await registry.register_local_file( + scope_key, + local_path, + kind="image", + display_name=display_name, + source_kind="local_image", + source_ref=raw_source or resolved_source, + segment_data=segment_data_from_onebot_data( + data, + exclude_keys={"file", "url"}, + ), + ) + ref = record.prompt_ref() + + elif type_ == "file": + file_id = str(data.get("id", "") or "").strip() + raw_source = str(data.get("file") or data.get("url") or "").strip() + local_file_path: Path | None = None + if file_id: + local_file_path = _resolve_webui_file_id(file_id) + elif is_localish_path(raw_source): + local_file_path = Path( + raw_source[7:] + if raw_source.startswith("file://") + else raw_source + ) + display_name = ( + str(data.get("name", "") or "").strip() + or (local_file_path.name if local_file_path is not None else "") + or display_name_from_source(raw_source, f"file_{index + 1}.bin") + ) + if local_file_path is not None and local_file_path.is_file(): + record = await registry.register_local_file( + scope_key, + local_file_path, + kind="file", + display_name=display_name, + source_kind="webui_file" if file_id else "local_file", + source_ref=file_id or raw_source or str(local_file_path), + segment_data=segment_data_from_onebot_data( + data, + exclude_keys={"file", "url"}, + ), + ) + ref = record.prompt_ref() + elif is_http_url(raw_source): + record = await registry.register_remote_url( + scope_key, + raw_source, + kind="file", + display_name=display_name, + source_kind="remote_file", + source_ref=file_id or raw_source, + segment_data=segment_data_from_onebot_data( + data, + exclude_keys={"file", "url"}, + ), + ) + ref = record.prompt_ref() + + elif ( + type_ == "forward" + and get_forward_messages is not None + and depth < _FORWARD_ATTACHMENT_MAX_DEPTH + ): + # 合并转发递归展开,深度上限防止无限嵌套 + forward_id = 
_extract_forward_id(data) + if forward_id and forward_id not in visited_forward_ids: + visited_forward_ids.add(forward_id) + try: + nodes = _normalize_forward_nodes( + await get_forward_messages(forward_id) + ) + except Exception as exc: + logger.debug( + "[AttachmentRegistry] forward resolver failed: id=%s err=%s", + forward_id, + exc, + ) + nodes = [] + for node_index, node in enumerate(nodes): + raw_message = ( + node.get("content") + or node.get("message") + or node.get("raw_message") + ) + nested_segments = normalize_message_segments(raw_message) + if not nested_segments: + continue + await _collect_from_segments( + nested_segments, + depth=depth + 1, + prefix=f"{prefix}forward:{forward_id}:{node_index}:", + ) + except ( + binascii.Error, + ValueError, + FileNotFoundError, + httpx.HTTPError, + ) as exc: + logger.warning( + "[AttachmentRegistry] segment registration skipped: type=%s index=%s err=%s", + type_, + index, + exc, + ) + except Exception as exc: + logger.exception( + "[AttachmentRegistry] unexpected segment registration failure: type=%s index=%s err=%s", + type_, + index, + exc, + ) + + if ref is not None: + attachments.append(ref) + if depth == 0: + normalized_parts.append(segment_text(type_, data, ref)) + + await _collect_from_segments(segments, depth=0, prefix="") + + return RegisteredMessageAttachments( + attachments=attachments, + normalized_text="".join(normalized_parts).strip(), + ) diff --git a/src/Undefined/bilibili/wbi.py b/src/Undefined/bilibili/wbi.py index f898af05..f3fdb97b 100644 --- a/src/Undefined/bilibili/wbi.py +++ b/src/Undefined/bilibili/wbi.py @@ -152,11 +152,7 @@ def _compute_mixin_key(img_key: str, sub_key: str) -> str: return mixed[:32] -async def _refresh_mixin_key(client: httpx.AsyncClient) -> str: - resp = await client.get(_BILIBILI_API_NAV) - resp.raise_for_status() - payload = resp.json() - +def _mixin_key_from_nav_payload(payload: Any) -> str: if not isinstance(payload, dict): raise ValueError("nav 接口返回格式异常") @@ -186,6 
+182,12 @@ async def _refresh_mixin_key(client: httpx.AsyncClient) -> str: return _compute_mixin_key(img_key, sub_key) +async def _refresh_mixin_key(client: httpx.AsyncClient) -> str: + resp = await client.get(_BILIBILI_API_NAV) + resp.raise_for_status() + return _mixin_key_from_nav_payload(resp.json()) + + async def get_mixin_key( client: httpx.AsyncClient, *, @@ -258,35 +260,7 @@ async def build_signed_params( def _refresh_mixin_key_sync(client: httpx.Client) -> str: resp = client.get(_BILIBILI_API_NAV) resp.raise_for_status() - payload = resp.json() - - if not isinstance(payload, dict): - raise ValueError("nav 接口返回格式异常") - - code = int(payload.get("code", -1)) - if code not in (0, -101): - message = payload.get("message", "未知错误") - raise ValueError(f"获取 wbi key 失败: {message} (code={code})") - - data = payload.get("data") - if not isinstance(data, dict): - raise ValueError("nav 接口 data 字段异常") - - wbi_img = data.get("wbi_img") - if not isinstance(wbi_img, dict): - raise ValueError("nav 接口 wbi_img 字段缺失") - - img_url = str(wbi_img.get("img_url", "")).strip() - sub_url = str(wbi_img.get("sub_url", "")).strip() - if not img_url or not sub_url: - raise ValueError("nav 接口未返回有效的 img_url/sub_url") - - img_key = _extract_key_from_url(img_url) - sub_key = _extract_key_from_url(sub_url) - if not img_key or not sub_key: - raise ValueError("无法提取有效的 img_key/sub_key") - - return _compute_mixin_key(img_key, sub_key) + return _mixin_key_from_nav_payload(resp.json()) def get_mixin_key_sync( diff --git a/src/Undefined/cognitive/historian/__init__.py b/src/Undefined/cognitive/historian/__init__.py new file mode 100644 index 00000000..0e1fce98 --- /dev/null +++ b/src/Undefined/cognitive/historian/__init__.py @@ -0,0 +1,5 @@ +"""认知史官 Worker 包。""" + +from Undefined.cognitive.historian.worker import HistorianWorker + +__all__ = ["HistorianWorker"] diff --git a/src/Undefined/cognitive/historian/helpers.py b/src/Undefined/cognitive/historian/helpers.py new file mode 100644 index 
00000000..e1bda7da --- /dev/null +++ b/src/Undefined/cognitive/historian/helpers.py @@ -0,0 +1,85 @@ +"""Historian 辅助函数。""" + +from __future__ import annotations + +import logging +import re +from datetime import datetime, timezone +from typing import Any + +logger = logging.getLogger(__name__) + +_MAX_LOG_PREVIEW_LEN = 200 + + +def _preview_text(text: str, max_len: int = _MAX_LOG_PREVIEW_LEN) -> str: + compact = re.sub(r"\s+", " ", str(text or "")).strip() + if len(compact) <= max_len: + return compact + return f"{compact[:max_len]}..." + + +def _extract_frontmatter_name(markdown: str) -> str: + text = str(markdown or "") + if not text.startswith("---"): + return "" + try: + import yaml + + parts = text[3:].split("---", 1) + if len(parts) != 2: + return "" + frontmatter = yaml.safe_load(parts[0]) + if not isinstance(frontmatter, dict): + return "" + value = frontmatter.get("name") + return str(value).strip() if value is not None else "" + except Exception: + return "" + + +def _escape_braces(text: str) -> str: + value = str(text or "") + return value.replace("{", "{{").replace("}", "}}") + + +def _resolve_timestamp_epoch(job: dict[str, Any]) -> int: + raw_epoch = job.get("timestamp_epoch") + if isinstance(raw_epoch, (int, float)): + return int(raw_epoch) + if isinstance(raw_epoch, str): + try: + return int(float(raw_epoch.strip())) + except Exception: + pass + + for key in ("timestamp_utc", "timestamp_local"): + raw_value = job.get(key) + if not isinstance(raw_value, str): + continue + text = raw_value.strip() + if not text: + continue + try: + parsed = datetime.fromisoformat(text.replace("Z", "+00:00")) + if parsed.tzinfo is None: + parsed = parsed.replace(tzinfo=timezone.utc) + return int(parsed.timestamp()) + except Exception: + continue + + return int(datetime.now(timezone.utc).timestamp()) + + +def _coerce_bool(value: Any) -> bool: + if isinstance(value, bool): + return value + if isinstance(value, (int, float)): + return bool(value) + if isinstance(value, 
str): + normalized = value.strip().lower() + if normalized in {"1", "true", "yes", "y", "on"}: + return True + if normalized in {"0", "false", "no", "n", "off", ""}: + return False + return False diff --git a/src/Undefined/cognitive/historian/tools.py b/src/Undefined/cognitive/historian/tools.py new file mode 100644 index 00000000..83cc3018 --- /dev/null +++ b/src/Undefined/cognitive/historian/tools.py @@ -0,0 +1,78 @@ +"""Historian LLM 工具定义。""" + +from __future__ import annotations + +_REWRITE_TOOL = { + "type": "function", + "function": { + "name": "submit_rewrite", + "description": "提交绝对化改写后的事件文本", + "parameters": { + "type": "object", + "properties": { + "text": {"type": "string", "description": "改写后的纯文本"}, + }, + "required": ["text"], + }, + }, +} +_READ_PROFILE_TOOL = { + "type": "function", + "function": { + "name": "read_profile", + "description": "读取指定实体的当前侧写内容", + "parameters": { + "type": "object", + "properties": { + "entity_type": { + "type": "string", + "enum": ["user", "group"], + "description": "实体类型:user 或 group", + }, + "entity_id": { + "type": "string", + "description": "实体 ID(用户 QQ 号或群号)", + }, + }, + "required": ["entity_type", "entity_id"], + }, + }, +} +_PROFILE_TOOL = { + "type": "function", + "function": { + "name": "update_profile", + "description": "更新用户/群侧写。调用前必须先用 read_profile 查看当前内容", + "parameters": { + "type": "object", + "properties": { + "entity_type": { + "type": "string", + "enum": ["user", "group"], + "description": "实体类型:user 或 group", + }, + "entity_id": { + "type": "string", + "description": "实体 ID(用户 QQ 号或群号)", + }, + "skip": { + "type": "boolean", + "description": "是否跳过更新;当新信息不稳定/不足时为 true", + }, + "skip_reason": { + "type": "string", + "description": "跳过原因", + }, + "name": {"type": "string", "description": "用户/群名称"}, + "tags": { + "type": "array", + "items": {"type": "string"}, + "maxItems": 10, + "description": "身份级标签(角色/核心领域),最多 10 个,不写话题", + }, + "summary": {"type": "string", "description": "侧写正文(Markdown)"}, + }, + 
"required": ["entity_type", "entity_id", "skip", "name", "tags", "summary"], + }, + }, +} diff --git a/src/Undefined/cognitive/historian.py b/src/Undefined/cognitive/historian/worker.py similarity index 86% rename from src/Undefined/cognitive/historian.py rename to src/Undefined/cognitive/historian/worker.py index fda333c3..30ecf43c 100644 --- a/src/Undefined/cognitive/historian.py +++ b/src/Undefined/cognitive/historian/worker.py @@ -1,170 +1,30 @@ -"""后台史官 Worker,轮询队列处理任务。""" +"""HistorianWorker 实现。""" from __future__ import annotations import asyncio import json import logging -import re -from datetime import datetime, timezone +from datetime import datetime from typing import Any, Callable from Undefined.ai.transports.openai_transport import RESPONSES_OUTPUT_ITEMS_KEY from Undefined.utils.tool_calls import extract_required_tool_call_arguments -logger = logging.getLogger(__name__) - -_MAX_LOG_PREVIEW_LEN = 200 - -_REWRITE_TOOL = { - "type": "function", - "function": { - "name": "submit_rewrite", - "description": "提交绝对化改写后的事件文本", - "parameters": { - "type": "object", - "properties": { - "text": {"type": "string", "description": "改写后的纯文本"}, - }, - "required": ["text"], - }, - }, -} - -_READ_PROFILE_TOOL = { - "type": "function", - "function": { - "name": "read_profile", - "description": "读取指定实体的当前侧写内容", - "parameters": { - "type": "object", - "properties": { - "entity_type": { - "type": "string", - "enum": ["user", "group"], - "description": "实体类型:user 或 group", - }, - "entity_id": { - "type": "string", - "description": "实体 ID(用户 QQ 号或群号)", - }, - }, - "required": ["entity_type", "entity_id"], - }, - }, -} - -_PROFILE_TOOL = { - "type": "function", - "function": { - "name": "update_profile", - "description": "更新用户/群侧写。调用前必须先用 read_profile 查看当前内容", - "parameters": { - "type": "object", - "properties": { - "entity_type": { - "type": "string", - "enum": ["user", "group"], - "description": "实体类型:user 或 group", - }, - "entity_id": { - "type": "string", - "description": 
"实体 ID(用户 QQ 号或群号)", - }, - "skip": { - "type": "boolean", - "description": "是否跳过更新;当新信息不稳定/不足时为 true", - }, - "skip_reason": { - "type": "string", - "description": "跳过原因", - }, - "name": {"type": "string", "description": "用户/群名称"}, - "tags": { - "type": "array", - "items": {"type": "string"}, - "maxItems": 10, - "description": "身份级标签(角色/核心领域),最多 10 个,不写话题", - }, - "summary": {"type": "string", "description": "侧写正文(Markdown)"}, - }, - "required": ["entity_type", "entity_id", "skip", "name", "tags", "summary"], - }, - }, -} - - -def _preview_text(text: str, max_len: int = _MAX_LOG_PREVIEW_LEN) -> str: - compact = re.sub(r"\s+", " ", str(text or "")).strip() - if len(compact) <= max_len: - return compact - return f"{compact[:max_len]}..." - - -def _extract_frontmatter_name(markdown: str) -> str: - text = str(markdown or "") - if not text.startswith("---"): - return "" - try: - import yaml - - parts = text[3:].split("---", 1) - if len(parts) != 2: - return "" - frontmatter = yaml.safe_load(parts[0]) - if not isinstance(frontmatter, dict): - return "" - value = frontmatter.get("name") - return str(value).strip() if value is not None else "" - except Exception: - return "" - - -def _escape_braces(text: str) -> str: - value = str(text or "") - return value.replace("{", "{{").replace("}", "}}") +from Undefined.cognitive.historian.helpers import ( + _coerce_bool, + _escape_braces, + _extract_frontmatter_name, + _preview_text, + _resolve_timestamp_epoch, +) +from Undefined.cognitive.historian.tools import ( + _PROFILE_TOOL, + _READ_PROFILE_TOOL, + _REWRITE_TOOL, +) - -def _resolve_timestamp_epoch(job: dict[str, Any]) -> int: - raw_epoch = job.get("timestamp_epoch") - if isinstance(raw_epoch, (int, float)): - return int(raw_epoch) - if isinstance(raw_epoch, str): - try: - return int(float(raw_epoch.strip())) - except Exception: - pass - - for key in ("timestamp_utc", "timestamp_local"): - raw_value = job.get(key) - if not isinstance(raw_value, str): - continue - text = 
raw_value.strip() - if not text: - continue - try: - parsed = datetime.fromisoformat(text.replace("Z", "+00:00")) - if parsed.tzinfo is None: - parsed = parsed.replace(tzinfo=timezone.utc) - return int(parsed.timestamp()) - except Exception: - continue - - return int(datetime.now(timezone.utc).timestamp()) - - -def _coerce_bool(value: Any) -> bool: - if isinstance(value, bool): - return value - if isinstance(value, (int, float)): - return bool(value) - if isinstance(value, str): - normalized = value.strip().lower() - if normalized in {"1", "true", "yes", "y", "on"}: - return True - if normalized in {"0", "false", "no", "n", "off", ""}: - return False - return False +logger = logging.getLogger(__name__) class HistorianWorker: @@ -307,7 +167,6 @@ async def _process_job(self, job_id: str, job: dict[str, Any]) -> None: len(job.get("profile_targets", []) or []), ) - # 兼容旧版:优先 observations,fallback new_info raw_observations = ( job.get("observations") if "observations" in job @@ -363,7 +222,6 @@ async def _process_job(self, job_id: str, job: dict[str, Any]) -> None: len(canonical), ) - # 侧写合并:传入所有 canonical 文本 has_obs = ( job.get("has_observations") if "has_observations" in job @@ -413,9 +271,7 @@ async def _rewrite( ) -> str: from Undefined.utils.resources import read_text_resource - # 向后兼容:优先 memo,fallback action_summary memo = str(job.get("memo") if "memo" in job else job.get("action_summary", "")) - # 向后兼容:优先 observations,fallback new_info observations = str( job.get("observations") if "observations" in job @@ -537,7 +393,6 @@ def _resolve_profile_targets(self, job: dict[str, Any]) -> list[dict[str, str]]: if targets: return targets - # 向后兼容旧任务:沿用单目标策略。 entity_type = "group" if str(job.get("group_id", "")).strip() else "user" entity_id = str( job.get("group_id") or job.get("user_id") or job.get("sender_id", "") @@ -768,7 +623,6 @@ async def _merge_profile_target( preferred_name = str(target.get("preferred_name", "")).strip() - # 检索该实体的历史事件作为 merge 参考 observations_raw 
= job.get("observations", job.get("new_info", [])) observations_text = ( "\n".join(observations_raw) @@ -884,7 +738,6 @@ async def _merge_profile_target( ) break - # 追加 assistant 轮次 assistant_msg: dict[str, Any] = { "role": "assistant", "tool_calls": tool_calls, diff --git a/src/Undefined/cognitive/job_queue.py b/src/Undefined/cognitive/job_queue.py index 7c5eba61..e1b036d2 100644 --- a/src/Undefined/cognitive/job_queue.py +++ b/src/Undefined/cognitive/job_queue.py @@ -23,7 +23,6 @@ def __init__(self, base_path: str | Path) -> None: self._failed_dir = base / "failed" for d in (self._pending_dir, self._processing_dir, self._failed_dir): d.mkdir(parents=True, exist_ok=True) - # 启动时清理所有遗留的 lock 文件 stale_lock_count = 0 for d in (self._pending_dir, self._processing_dir, self._failed_dir): for lock_file in d.glob("*.lock"): @@ -65,7 +64,6 @@ def _pick() -> tuple[str, dict[str, Any]] | None: with open(dst, "r", encoding="utf-8") as fh: data = json.load(fh) - # 清理遗留的 lock 文件 lock_file = f.with_name(f"{f.name}.lock") lock_file.unlink(missing_ok=True) return f.stem, data @@ -123,7 +121,6 @@ async def requeue(self, job_id: str, error: str) -> None: data = await read_json(src) or {} data["_retry_count"] = data.get("_retry_count", 0) + 1 data["_last_error"] = error - # 先原子更新 processing 内容,再原子移动到 pending await write_json(src, data) await asyncio.to_thread(lambda: os.replace(src, dst)) logger.info( diff --git a/src/Undefined/cognitive/profile_storage.py b/src/Undefined/cognitive/profile_storage.py index 90585fae..d1f61215 100644 --- a/src/Undefined/cognitive/profile_storage.py +++ b/src/Undefined/cognitive/profile_storage.py @@ -74,7 +74,6 @@ def _write() -> None: p.parent.mkdir(parents=True, exist_ok=True) hist_dir.mkdir(parents=True, exist_ok=True) - # 备份现有版本 if p.exists(): ts = datetime.now().strftime("%Y%m%d%H%M%S%f") (hist_dir / f"{ts}.md").write_text( @@ -87,7 +86,6 @@ def _write() -> None: ts, ) - # 原子写入 fd, tmp = tempfile.mkstemp( prefix=f".{p.name}.", suffix=".tmp", 
dir=str(p.parent) ) @@ -102,7 +100,6 @@ def _write() -> None: pass raise - # 清理旧快照 snapshots = sorted(hist_dir.glob("*.md")) for old in snapshots[: max(0, len(snapshots) - self._revision_keep)]: try: @@ -140,7 +137,6 @@ def _list() -> list[str]: def _sanitize_profile(content: str, entity_type: str, entity_id: str) -> str: import yaml - # 剥离 ```markdown / ``` 包裹 stripped = content.strip() if stripped.startswith("```"): lines = stripped.splitlines() @@ -151,7 +147,6 @@ def _sanitize_profile(content: str, entity_type: str, entity_id: str) -> str: if end: stripped = "\n".join(lines[1:end]) - # 解析 frontmatter if stripped.startswith("---"): parts = stripped[3:].split("---", 1) if len(parts) == 2: diff --git a/src/Undefined/cognitive/service/__init__.py b/src/Undefined/cognitive/service/__init__.py new file mode 100644 index 00000000..45878197 --- /dev/null +++ b/src/Undefined/cognitive/service/__init__.py @@ -0,0 +1,5 @@ +"""认知记忆服务包。""" + +from Undefined.cognitive.service.service import CognitiveService + +__all__ = ["CognitiveService"] diff --git a/src/Undefined/cognitive/service/helpers.py b/src/Undefined/cognitive/service/helpers.py new file mode 100644 index 00000000..74752c48 --- /dev/null +++ b/src/Undefined/cognitive/service/helpers.py @@ -0,0 +1,169 @@ +"""认知服务辅助函数。""" + +from __future__ import annotations + +from datetime import datetime, timezone +from typing import Any + +from Undefined.utils.coerce import safe_float + + +def _parse_iso_to_epoch_seconds(value: Any) -> int | None: + if not isinstance(value, str): + return None + text = value.strip() + if not text: + return None + try: + parsed = datetime.fromisoformat(text.replace("Z", "+00:00")) + except Exception: + return None + if parsed.tzinfo is None: + parsed = parsed.replace(tzinfo=timezone.utc) + return int(parsed.timestamp()) + + +def _compose_where(clauses: list[dict[str, Any]]) -> dict[str, Any] | None: + if not clauses: + return None + if len(clauses) == 1: + return clauses[0] + return {"$and": 
clauses} + + +def _event_base_score(item: dict[str, Any]) -> float: + # 优先 rerank 分,否则用 1-distance 作为相似度 + rerank_score = item.get("rerank_score") + if isinstance(rerank_score, (int, float)): + return max(0.0, float(rerank_score)) + if isinstance(rerank_score, str): + try: + return max(0.0, float(rerank_score.strip())) + except Exception: + pass + similarity = 1.0 - safe_float(item.get("distance"), default=1.0) + if similarity < 0.0: + return 0.0 + if similarity > 1.0: + return 1.0 + return similarity + + +def _event_timestamp_epoch(metadata: Any) -> float: + if not isinstance(metadata, dict): + return float("-inf") + raw_epoch = metadata.get("timestamp_epoch") + if isinstance(raw_epoch, (int, float)): + return float(raw_epoch) + if isinstance(raw_epoch, str): + try: + return float(raw_epoch.strip()) + except Exception: + pass + for key in ("timestamp_utc", "timestamp_local"): + parsed = _parse_iso_to_epoch_seconds(metadata.get(key)) + if parsed is not None: + return float(parsed) + return float("-inf") + + +def _event_dedupe_key(item: dict[str, Any]) -> tuple[str, str, str, str, str, str]: + metadata = item.get("metadata") + if not isinstance(metadata, dict): + metadata = {} + return ( + str(item.get("document", "")).strip(), + str(metadata.get("timestamp_epoch", "")).strip(), + str(metadata.get("timestamp_local", "")).strip(), + str(metadata.get("group_id", "")).strip(), + str(metadata.get("sender_id", "")).strip(), + str(metadata.get("user_id", "")).strip(), + ) + + +def _resolve_auto_request_type( + *, + request_type: str | None, + group_id: str, + user_id: str, + sender_id: str, +) -> str: + normalized = str(request_type or "").strip().lower() + if normalized in {"group", "private"}: + return normalized + if group_id: + return "group" + if sender_id or user_id: + return "private" + return "" + + +def _parse_profile_markdown(markdown: str) -> tuple[dict[str, Any], str] | None: + text = str(markdown or "") + if not text.startswith("---"): + return None + try: + 
import yaml + + parts = text[3:].split("---", 1) + if len(parts) != 2: + return None + frontmatter = yaml.safe_load(parts[0]) + if not isinstance(frontmatter, dict): + return None + body = parts[1].lstrip("\n") + return frontmatter, body + except Exception: + return None + + +def _serialize_profile_markdown(frontmatter: dict[str, Any], body: str) -> str: + import yaml + + return f"---\n{yaml.dump(frontmatter, allow_unicode=True)}---\n{body}" + + +def _normalize_profile_tags(value: Any) -> list[str]: + if not isinstance(value, list): + return [] + return [str(item).strip() for item in value if str(item).strip()] + + +def _current_profile_name(entity_type: str, frontmatter: dict[str, Any]) -> str: + if entity_type == "user": + return str(frontmatter.get("nickname") or frontmatter.get("name") or "").strip() + return str(frontmatter.get("group_name") or frontmatter.get("name") or "").strip() + + +def _build_profile_vector_payload( + *, + entity_type: str, + entity_id: str, + effective_name: str, + tags: list[str], + summary: str, +) -> tuple[str, dict[str, Any]]: + profile_doc_lines: list[str] = [] + if entity_type == "user": + profile_doc_lines.append(f"昵称: {effective_name}") + profile_doc_lines.append(f"QQ号: {entity_id}") + else: + profile_doc_lines.append(f"群名: {effective_name}") + profile_doc_lines.append(f"群号: {entity_id}") + if tags: + profile_doc_lines.append(f"标签: {', '.join(tags)}") + profile_doc_lines.append(summary) + profile_doc = "\n".join(line for line in profile_doc_lines if line.strip()) + + metadata: dict[str, Any] = { + "entity_type": entity_type, + "entity_id": entity_id, + "name": effective_name, + } + if entity_type == "user": + metadata["nickname"] = effective_name + metadata["qq"] = entity_id + else: + metadata["group_name"] = effective_name + metadata["group_id"] = entity_id + return profile_doc, metadata diff --git a/src/Undefined/cognitive/service.py b/src/Undefined/cognitive/service/service.py similarity index 84% rename from 
src/Undefined/cognitive/service.py rename to src/Undefined/cognitive/service/service.py index 13cf3f1c..c7a69e9d 100644 --- a/src/Undefined/cognitive/service.py +++ b/src/Undefined/cognitive/service/service.py @@ -1,4 +1,4 @@ -"""认知记忆服务门面。""" +"""认知记忆服务实现。""" from __future__ import annotations @@ -10,171 +10,24 @@ from Undefined.context import RequestContext from Undefined.utils.coerce import safe_float - -logger = logging.getLogger(__name__) +from Undefined.cognitive.service.helpers import ( + _build_profile_vector_payload, + _compose_where, + _current_profile_name, + _event_base_score, + _event_dedupe_key, + _event_timestamp_epoch, + _normalize_profile_tags, + _parse_iso_to_epoch_seconds, + _parse_profile_markdown, + _resolve_auto_request_type, + _serialize_profile_markdown, +) if TYPE_CHECKING: from Undefined.knowledge.runtime import RetrievalRuntime - -def _parse_iso_to_epoch_seconds(value: Any) -> int | None: - if not isinstance(value, str): - return None - text = value.strip() - if not text: - return None - try: - parsed = datetime.fromisoformat(text.replace("Z", "+00:00")) - except Exception: - return None - if parsed.tzinfo is None: - parsed = parsed.replace(tzinfo=timezone.utc) - return int(parsed.timestamp()) - - -def _compose_where(clauses: list[dict[str, Any]]) -> dict[str, Any] | None: - if not clauses: - return None - if len(clauses) == 1: - return clauses[0] - return {"$and": clauses} - - -def _event_base_score(item: dict[str, Any]) -> float: - rerank_score = item.get("rerank_score") - if isinstance(rerank_score, (int, float)): - return max(0.0, float(rerank_score)) - if isinstance(rerank_score, str): - try: - return max(0.0, float(rerank_score.strip())) - except Exception: - pass - similarity = 1.0 - safe_float(item.get("distance"), default=1.0) - if similarity < 0.0: - return 0.0 - if similarity > 1.0: - return 1.0 - return similarity - - -def _event_timestamp_epoch(metadata: Any) -> float: - if not isinstance(metadata, dict): - return 
float("-inf") - raw_epoch = metadata.get("timestamp_epoch") - if isinstance(raw_epoch, (int, float)): - return float(raw_epoch) - if isinstance(raw_epoch, str): - try: - return float(raw_epoch.strip()) - except Exception: - pass - for key in ("timestamp_utc", "timestamp_local"): - parsed = _parse_iso_to_epoch_seconds(metadata.get(key)) - if parsed is not None: - return float(parsed) - return float("-inf") - - -def _event_dedupe_key(item: dict[str, Any]) -> tuple[str, str, str, str, str, str]: - metadata = item.get("metadata") - if not isinstance(metadata, dict): - metadata = {} - return ( - str(item.get("document", "")).strip(), - str(metadata.get("timestamp_epoch", "")).strip(), - str(metadata.get("timestamp_local", "")).strip(), - str(metadata.get("group_id", "")).strip(), - str(metadata.get("sender_id", "")).strip(), - str(metadata.get("user_id", "")).strip(), - ) - - -def _resolve_auto_request_type( - *, - request_type: str | None, - group_id: str, - user_id: str, - sender_id: str, -) -> str: - normalized = str(request_type or "").strip().lower() - if normalized in {"group", "private"}: - return normalized - if group_id: - return "group" - if sender_id or user_id: - return "private" - return "" - - -def _parse_profile_markdown(markdown: str) -> tuple[dict[str, Any], str] | None: - text = str(markdown or "") - if not text.startswith("---"): - return None - try: - import yaml - - parts = text[3:].split("---", 1) - if len(parts) != 2: - return None - frontmatter = yaml.safe_load(parts[0]) - if not isinstance(frontmatter, dict): - return None - body = parts[1].lstrip("\n") - return frontmatter, body - except Exception: - return None - - -def _serialize_profile_markdown(frontmatter: dict[str, Any], body: str) -> str: - import yaml - - return f"---\n{yaml.dump(frontmatter, allow_unicode=True)}---\n{body}" - - -def _normalize_profile_tags(value: Any) -> list[str]: - if not isinstance(value, list): - return [] - return [str(item).strip() for item in value if 
str(item).strip()] - - -def _current_profile_name(entity_type: str, frontmatter: dict[str, Any]) -> str: - if entity_type == "user": - return str(frontmatter.get("nickname") or frontmatter.get("name") or "").strip() - return str(frontmatter.get("group_name") or frontmatter.get("name") or "").strip() - - -def _build_profile_vector_payload( - *, - entity_type: str, - entity_id: str, - effective_name: str, - tags: list[str], - summary: str, -) -> tuple[str, dict[str, Any]]: - profile_doc_lines: list[str] = [] - if entity_type == "user": - profile_doc_lines.append(f"昵称: {effective_name}") - profile_doc_lines.append(f"QQ号: {entity_id}") - else: - profile_doc_lines.append(f"群名: {effective_name}") - profile_doc_lines.append(f"群号: {entity_id}") - if tags: - profile_doc_lines.append(f"标签: {', '.join(tags)}") - profile_doc_lines.append(summary) - profile_doc = "\n".join(line for line in profile_doc_lines if line.strip()) - - metadata: dict[str, Any] = { - "entity_type": entity_type, - "entity_id": entity_id, - "name": effective_name, - } - if entity_type == "user": - metadata["nickname"] = effective_name - metadata["qq"] = entity_id - else: - metadata["group_name"] = effective_name - metadata["group_id"] = entity_id - return profile_doc, metadata +logger = logging.getLogger(__name__) class CognitiveService: @@ -320,7 +173,6 @@ def _merge_weighted_events( safe_group_boost = max(0.0, float(current_group_boost)) seen_keys: set[tuple[str, str, str, str, str, str]] = set() # 排序主键优先使用“作用域内原始排名”(已含 time_decay/mmr/rerank 效果), - # 再使用相似度分值兜底,避免跨 scope 合并时打乱衰减后的顺序。 scored_items: list[ tuple[float, float, float, float, float, int, dict[str, Any]] ] = [] @@ -547,7 +399,6 @@ async def enqueue_job( else str(context.get("request_id", "")).strip() ) if not safe_request_id: - # 最终兜底由 JobQueue 生成 request_id。 safe_request_id = "" end_seq_raw = context.get("_end_seq", 0) @@ -699,7 +550,6 @@ async def build_context( getattr(config, "auto_top_k", 5), ) - # 用户侧写(优先 sender_id,与 enqueue_job 写入侧一致) 
uid = safe_sender_id or safe_user_id if uid: profile = await self._profile_storage.read_profile("user", uid) @@ -707,7 +557,6 @@ async def build_context( label = f"{sender_name}(UID: {uid})" if sender_name else f"UID: {uid}" parts.append(f"## 用户侧写 — {label}\n{profile}") - # 群聊侧写 if safe_group_id: gprofile = await self._profile_storage.read_profile("group", safe_group_id) if gprofile: diff --git a/src/Undefined/cognitive/vector_store.py b/src/Undefined/cognitive/vector_store.py index 48b80eb1..076b8a9f 100644 --- a/src/Undefined/cognitive/vector_store.py +++ b/src/Undefined/cognitive/vector_store.py @@ -129,7 +129,6 @@ def _mmr_select( if n <= top_k: return np.arange(n) - # 预计算 query-doc 相关性(cosine similarity) query_norm = np.sqrt(np.sum(query_embedding * query_embedding)) relevance = np.empty(n, dtype=np.float64) norms = np.empty(n, dtype=np.float64) @@ -161,7 +160,6 @@ def _mmr_select( return selected[:step] selected[step] = best_idx chosen[best_idx] = True - # 更新 max_sim_to_selected:新选中项与所有未选中项的相似度 if norms[best_idx] > 0.0: for j in range(n): if not chosen[j] and norms[j] > 0.0: @@ -446,7 +444,6 @@ async def _query( query_embedding=query_embedding, ) embed_duration = time.perf_counter() - embed_started - # 重排要求候选数 > 最终返回数,否则重排无意义 use_reranker = bool(reranker) and safe_multiplier >= 2 if reranker and safe_multiplier < 2: logger.warning( diff --git a/src/Undefined/config/__init__.py b/src/Undefined/config/__init__.py index 7242f3cf..4fe960fb 100644 --- a/src/Undefined/config/__init__.py +++ b/src/Undefined/config/__init__.py @@ -2,7 +2,7 @@ from typing import Optional -from .loader import Config, WebUISettings, load_webui_settings +from .config_class import Config, ConfigBuilder from .manager import ConfigManager from .models import ( APIConfig, @@ -19,9 +19,11 @@ SecurityModelConfig, VisionModelConfig, ) +from .webui_settings import WebUISettings, load_webui_settings __all__ = [ "Config", + "ConfigBuilder", "ChatModelConfig", "VisionModelConfig", 
"SecurityModelConfig", @@ -37,11 +39,11 @@ "RenderCacheConfig", "get_config", "get_config_manager", + "set_config", "load_webui_settings", "WebUISettings", ] -# 全局配置实例 _config: Optional[Config] = None _config_manager: Optional[ConfigManager] = None @@ -60,3 +62,10 @@ def get_config(strict: bool = True) -> Config: if _config is None: _config = get_config_manager().load(strict=strict) return _config + + +def set_config(config: Config) -> None: + """注入 Config 单例(库嵌入 opt-in;CLI / WebUI 启动链不得调用)。""" + global _config + _config = config + get_config_manager().replace(config) diff --git a/src/Undefined/config/build_config.py b/src/Undefined/config/build_config.py new file mode 100644 index 00000000..72021e21 --- /dev/null +++ b/src/Undefined/config/build_config.py @@ -0,0 +1,57 @@ +"""Build Config from parsed TOML mapping.""" + +from __future__ import annotations + + +from pathlib import Path +from typing import TYPE_CHECKING, Any, Optional + +if TYPE_CHECKING: + from .config_class import Config + +from .load_sections import ( + load_access, + load_core, + load_domains, + load_finalize, + load_history_skills, + load_integrations, + load_knowledge, + load_logging_tools, + load_models, + load_network, +) + + +# 从中间态构建最终对象 +def build_config( + data: dict[str, Any], + *, + strict: bool = True, + config_path: Optional[Path] = None, +) -> "Config": + """从已解析的 TOML mapping 构建 Config。""" + from .config_class import Config + + # 按依赖顺序分阶段加载:core/knowledge/models 在前,access 等依赖 admin 合并 + ctx: dict[str, Any] = {} + ctx.update(load_core(data)) + ctx.update(load_knowledge(data)) + ctx.update(load_models(data)) + from .model_parsers import _merge_admins + + # 合并 config.toml 与本地 admins.json,超管始终纳入 admin 列表 + superadmin_qq, admin_qqs = _merge_admins( + superadmin_qq=ctx["superadmin_qq"], admin_qqs=ctx["admin_qqs"] + ) + ctx["superadmin_qq"] = superadmin_qq + ctx["admin_qqs"] = admin_qqs + ctx.update(load_access(data)) + ctx.update(load_logging_tools(data)) + 
ctx.update(load_history_skills(data)) + ctx.update(load_network(data)) + ctx.update(load_integrations(data)) + # domains 含 WebUI/API/认知/合并器等子域,放最后以便前面模型段已就绪 + ctx.update(load_domains(data, config_path=config_path)) + load_finalize(ctx, strict=strict) # strict 时校验必填项并打 debug 摘要 + return Config(**ctx) diff --git a/src/Undefined/config/config_class.py b/src/Undefined/config/config_class.py new file mode 100644 index 00000000..b29a786e --- /dev/null +++ b/src/Undefined/config/config_class.py @@ -0,0 +1,583 @@ +"""Config dataclass and instance methods.""" + +from __future__ import annotations + + +from dataclasses import dataclass, field as dataclass_field, fields +from pathlib import Path +from typing import Any, Optional + +from .admin import load_local_admins, save_local_admins +from .domain_parsers import _update_dataclass +from .models import ( + AgentModelConfig, + APIConfig, + ChatModelConfig, + CognitiveConfig, + EmbeddingModelConfig, + GrokModelConfig, + ImageGenConfig, + ImageGenModelConfig, + MemeConfig, + MessageBatcherConfig, + NagaConfig, + RenderCacheConfig, + RerankModelConfig, + SecurityModelConfig, + VisionModelConfig, +) +from .toml_io import _load_env, load_toml_data + + +@dataclass +class Config: + """应用配置""" + + bot_qq: int + superadmin_qq: int + admin_qqs: list[int] + # 访问控制模式:off / blacklist / allowlist + access_mode: str + # 访问控制(会话白名单 + 黑名单) + allowed_group_ids: list[int] + blocked_group_ids: list[int] + allowed_private_ids: list[int] + blocked_private_ids: list[int] + # 是否允许超级管理员在私聊中绕过 allowed_private_ids(仅私聊收发) + superadmin_bypass_allowlist: bool + # 是否允许超级管理员在私聊中绕过 blocked_private_ids(仅私聊收发) + superadmin_bypass_private_blacklist: bool + forward_proxy_qq: int | None + process_every_message: bool + process_private_message: bool + process_poke_message: bool + keyword_reply_enabled: bool + repeat_enabled: bool + repeat_threshold: int + repeat_cooldown_minutes: int + inverted_question_enabled: bool + context_recent_messages_limit: int + 
ai_request_max_retries: int + missing_tool_call_retries: int + nagaagent_mode_enabled: bool + onebot_ws_url: str + onebot_token: str + chat_model: ChatModelConfig + vision_model: VisionModelConfig + security_model_enabled: bool + security_model: SecurityModelConfig + naga_model: SecurityModelConfig + agent_model: AgentModelConfig + historian_model: AgentModelConfig + summary_model: AgentModelConfig + summary_model_configured: bool + grok_model: GrokModelConfig + model_pool_enabled: bool + log_level: str + log_file_path: str + log_max_size: int + log_backup_count: int + log_tty_enabled: bool + log_thinking: bool + tools_dot_delimiter: str + tools_description_truncate_enabled: bool + tools_description_max_len: int + tools_sanitize_verbose: bool + tools_description_preview_len: int + easter_egg_agent_call_message_mode: str + token_usage_max_size_mb: int + token_usage_max_archives: int + token_usage_max_total_mb: int + token_usage_archive_prune_mode: str + history_max_records: int + history_filtered_result_limit: int + history_search_scan_limit: int + history_summary_fetch_limit: int + history_summary_time_fetch_limit: int + history_onebot_fetch_limit: int + history_group_analysis_limit: int + attachment_remote_download_max_size_mb: int + attachment_cache_max_total_size_mb: int + attachment_cache_max_records: int + attachment_cache_max_age_days: int + attachment_url_reference_max_records: int + attachment_url_max_length: int + skills_hot_reload: bool + skills_hot_reload_interval: float + skills_hot_reload_debounce: float + agent_intro_autogen_enabled: bool + agent_intro_autogen_queue_interval: float + agent_intro_autogen_max_tokens: int + agent_intro_hash_path: str + searxng_url: str + grok_search_enabled: bool + use_proxy: bool + http_proxy: str + https_proxy: str + network_request_timeout: float + network_request_retries: int + render_browser_max_concurrency: int + api_xxapi_base_url: str + api_xingzhige_base_url: str + api_jkyai_base_url: str + 
api_seniverse_base_url: str + weather_api_key: str + xxapi_api_token: str + mcp_config_path: str + prefetch_tools: list[str] + prefetch_tools_hide: bool + webui_url: str + webui_port: int + webui_password: str + api: APIConfig + # Code Delivery Agent + code_delivery_enabled: bool + code_delivery_task_root: str + code_delivery_docker_image: str + code_delivery_container_name_prefix: str + code_delivery_container_name_suffix: str + code_delivery_command_timeout: int + code_delivery_max_command_output: int + code_delivery_default_archive_format: str + code_delivery_max_archive_size_mb: int + code_delivery_cleanup_on_finish: bool + code_delivery_cleanup_on_start: bool + code_delivery_llm_max_retries: int + code_delivery_notify_on_llm_failure: bool + code_delivery_container_memory_limit: str + code_delivery_container_cpu_limit: str + code_delivery_command_blacklist: list[str] + # messages 工具集 + messages_send_text_file_max_size_kb: int + messages_send_url_file_max_size_mb: int + # 嵌入模型 + embedding_model: EmbeddingModelConfig + rerank_model: RerankModelConfig + # 知识库 + knowledge_enabled: bool + knowledge_base_dir: str + knowledge_auto_scan: bool + knowledge_auto_embed: bool + knowledge_scan_interval: float + knowledge_embed_batch_size: int + knowledge_chunk_size: int + knowledge_chunk_overlap: int + knowledge_default_top_k: int + knowledge_enable_rerank: bool + knowledge_rerank_top_k: int + # Bilibili 视频提取 + bilibili_auto_extract_enabled: bool + bilibili_cookie: str + bilibili_prefer_quality: int + bilibili_max_duration: int + bilibili_max_file_size: int + bilibili_oversize_strategy: str + bilibili_danmaku_enabled: bool + bilibili_danmaku_batch_size: int + bilibili_danmaku_max_count: int + bilibili_auto_extract_group_ids: list[int] + bilibili_auto_extract_private_ids: list[int] + # arXiv 论文提取 + arxiv_auto_extract_enabled: bool + arxiv_max_file_size: int + arxiv_auto_extract_group_ids: list[int] + arxiv_auto_extract_private_ids: list[int] + arxiv_auto_extract_max_items: 
int + arxiv_author_preview_limit: int + arxiv_summary_preview_chars: int + # GitHub 仓库自动提取 + github_auto_extract_enabled: bool + github_request_timeout_seconds: float + github_auto_extract_group_ids: list[int] + github_auto_extract_private_ids: list[int] + github_auto_extract_max_items: int + # 认知记忆 + cognitive: CognitiveConfig + # 表情包库 + memes: MemeConfig + # 同 sender 短时多消息合并器 + message_batcher: MessageBatcherConfig + # HTML 渲染结果缓存 + render_cache: RenderCacheConfig + # Naga 集成 + naga: NagaConfig + # 生图工具配置 + image_gen: ImageGenConfig + models_image_gen: ImageGenModelConfig + models_image_edit: ImageGenModelConfig + _allowed_group_ids_set: set[int] = dataclass_field( + default_factory=set, + init=False, + repr=False, + ) + _blocked_group_ids_set: set[int] = dataclass_field( + default_factory=set, + init=False, + repr=False, + ) + _allowed_private_ids_set: set[int] = dataclass_field( + default_factory=set, + init=False, + repr=False, + ) + _blocked_private_ids_set: set[int] = dataclass_field( + default_factory=set, + init=False, + repr=False, + ) + _bilibili_group_ids_set: set[int] = dataclass_field( + default_factory=set, + init=False, + repr=False, + ) + _bilibili_private_ids_set: set[int] = dataclass_field( + default_factory=set, + init=False, + repr=False, + ) + _arxiv_group_ids_set: set[int] = dataclass_field( + default_factory=set, + init=False, + repr=False, + ) + _arxiv_private_ids_set: set[int] = dataclass_field( + default_factory=set, + init=False, + repr=False, + ) + _github_group_ids_set: set[int] = dataclass_field( + default_factory=set, + init=False, + repr=False, + ) + _github_private_ids_set: set[int] = dataclass_field( + default_factory=set, + init=False, + repr=False, + ) + + def __post_init__(self) -> None: + # 访问控制属于高频热路径,启动后缓存为 set 降低重复构建开销。 + normalized_mode = str(self.access_mode).strip().lower() + if normalized_mode not in {"off", "blacklist", "allowlist", "legacy"}: + normalized_mode = "off" + self.access_mode = normalized_mode + 
self._allowed_group_ids_set = {int(item) for item in self.allowed_group_ids} + self._blocked_group_ids_set = {int(item) for item in self.blocked_group_ids} + self._allowed_private_ids_set = {int(item) for item in self.allowed_private_ids} + self._blocked_private_ids_set = {int(item) for item in self.blocked_private_ids} + self._bilibili_group_ids_set = { + int(item) for item in self.bilibili_auto_extract_group_ids + } + self._bilibili_private_ids_set = { + int(item) for item in self.bilibili_auto_extract_private_ids + } + self._arxiv_group_ids_set = { + int(item) for item in self.arxiv_auto_extract_group_ids + } + self._arxiv_private_ids_set = { + int(item) for item in self.arxiv_auto_extract_private_ids + } + self._github_group_ids_set = { + int(item) for item in self.github_auto_extract_group_ids + } + self._github_private_ids_set = { + int(item) for item in self.github_auto_extract_private_ids + } + + @classmethod + def load(cls, config_path: Optional[Path] = None, strict: bool = True) -> "Config": + """从 config.toml 和本地配置加载配置""" + from .build_config import build_config + + _load_env() # 先加载 .env,供 _get_value 环境变量回退 + data = load_toml_data(config_path, strict=strict) + return build_config(data, strict=strict, config_path=config_path) + + @classmethod + def from_mapping(cls, data: dict[str, Any], *, strict: bool = True) -> "Config": + """从内存 mapping 构建配置(无 TOML 文件)。""" + from .build_config import build_config + + return build_config(data, strict=strict, config_path=None) + + @classmethod + def builder(cls) -> "ConfigBuilder": + """返回可链式覆盖字段的配置构建器。""" + return ConfigBuilder() + + @property + def bilibili_sessdata(self) -> str: + """兼容旧字段名,等价于 bilibili_cookie。""" + return self.bilibili_cookie + + def allowlist_mode_enabled(self) -> bool: + """是否启用白名单限制模式。""" + + return self.access_mode in {"allowlist", "legacy"} and ( + bool(self.allowed_group_ids) or bool(self.allowed_private_ids) + ) + + def group_allowlist_enabled(self) -> bool: + """群聊白名单是否生效(显式 allowlist 
模式按维度独立控制)。""" + + return bool(self.allowed_group_ids) + + def private_allowlist_enabled(self) -> bool: + """私聊白名单是否生效(显式 allowlist 模式按维度独立控制)。""" + + return bool(self.allowed_private_ids) + + def blacklist_mode_enabled(self) -> bool: + """是否启用黑名单限制模式。""" + + return self.access_mode in {"blacklist", "legacy"} and ( + bool(self.blocked_group_ids) or bool(self.blocked_private_ids) + ) + + def access_control_enabled(self) -> bool: + """是否启用访问控制。""" + + return self.allowlist_mode_enabled() or self.blacklist_mode_enabled() + + def group_access_denied_reason(self, group_id: int) -> str | None: + """群聊访问被拒绝原因。 + + 返回: + - "blacklist": 命中 access.blocked_group_ids + - "allowlist": allowlist 模式下不在 access.allowed_group_ids + - None: 允许访问 + """ + + normalized_group_id = int(group_id) + if self.access_mode == "off": + return None + if self.access_mode == "blacklist": + if normalized_group_id in self._blocked_group_ids_set: + return "blacklist" + return None + if self.access_mode == "legacy": + if normalized_group_id in self._blocked_group_ids_set: + return "blacklist" + if not self.allowlist_mode_enabled(): + return None + if normalized_group_id not in self._allowed_group_ids_set: + return "allowlist" + return None + if not self.group_allowlist_enabled(): + return None + if normalized_group_id not in self._allowed_group_ids_set: + return "allowlist" + return None + + def is_group_allowed(self, group_id: int) -> bool: + """群聊是否允许收发消息。""" + + return self.group_access_denied_reason(group_id) is None + + def private_access_denied_reason(self, user_id: int) -> str | None: + """私聊访问被拒绝原因。""" + + normalized_user_id = int(user_id) + if self.access_mode == "off": + return None + if self.access_mode == "blacklist": + if normalized_user_id not in self._blocked_private_ids_set: + return None + if ( + self.superadmin_bypass_private_blacklist + and normalized_user_id == int(self.superadmin_qq) + and self.superadmin_qq > 0 + ): + return None + return "blacklist" + if self.access_mode == 
"legacy": + if normalized_user_id in self._blocked_private_ids_set: + if ( + self.superadmin_bypass_private_blacklist + and normalized_user_id == int(self.superadmin_qq) + and self.superadmin_qq > 0 + ): + return None + return "blacklist" + if not self.allowlist_mode_enabled(): + return None + if ( + self.superadmin_bypass_allowlist + and normalized_user_id == int(self.superadmin_qq) + and self.superadmin_qq > 0 + ): + return None + if normalized_user_id not in self._allowed_private_ids_set: + return "allowlist" + return None + if not self.private_allowlist_enabled(): + return None + if ( + self.superadmin_bypass_allowlist + and normalized_user_id == int(self.superadmin_qq) + and self.superadmin_qq > 0 + ): + return None + if normalized_user_id not in self._allowed_private_ids_set: + return "allowlist" + return None + + def is_private_allowed(self, user_id: int) -> bool: + """私聊是否允许收发消息。""" + + return self.private_access_denied_reason(user_id) is None + + def is_bilibili_auto_extract_allowed_group(self, group_id: int) -> bool: + """群聊是否允许 bilibili 自动提取。""" + if self._bilibili_group_ids_set: + return int(group_id) in self._bilibili_group_ids_set + # 功能白名单为空时跟随全局 access 控制 + return self.is_group_allowed(group_id) + + def is_bilibili_auto_extract_allowed_private(self, user_id: int) -> bool: + """私聊是否允许 bilibili 自动提取。""" + if self._bilibili_private_ids_set: + return int(user_id) in self._bilibili_private_ids_set + # 功能白名单为空时跟随全局 access 控制 + return self.is_private_allowed(user_id) + + def is_arxiv_auto_extract_allowed_group(self, group_id: int) -> bool: + """群聊是否允许 arXiv 自动提取。""" + if self._arxiv_group_ids_set: + return int(group_id) in self._arxiv_group_ids_set + return self.is_group_allowed(group_id) + + def is_arxiv_auto_extract_allowed_private(self, user_id: int) -> bool: + """私聊是否允许 arXiv 自动提取。""" + if self._arxiv_private_ids_set: + return int(user_id) in self._arxiv_private_ids_set + return self.is_private_allowed(user_id) + + def 
is_github_auto_extract_allowed_group(self, group_id: int) -> bool: + """群聊是否允许 GitHub 仓库自动提取。""" + if self._github_group_ids_set: + return int(group_id) in self._github_group_ids_set + return self.is_group_allowed(group_id) + + def is_github_auto_extract_allowed_private(self, user_id: int) -> bool: + """私聊是否允许 GitHub 仓库自动提取。""" + if self._github_private_ids_set: + return int(user_id) in self._github_private_ids_set + return self.is_private_allowed(user_id) + + def should_process_group_message(self, is_at_bot: bool) -> bool: + """是否处理该条群消息。""" + + if self.process_every_message: + return True + return bool(is_at_bot) + + def should_process_private_message(self) -> bool: + """是否处理私聊消息回复。""" + + return bool(self.process_private_message) + + def should_process_poke_message(self) -> bool: + """是否处理拍一拍触发。""" + + return bool(self.process_poke_message) + + def get_context_recent_messages_limit(self) -> int: + """获取上下文最近历史消息条数上限。""" + + limit = int(self.context_recent_messages_limit) + if limit < 0: + return 0 + return limit + + def security_check_enabled(self) -> bool: + """是否启用安全模型检查。""" + # 热更新运行时参数 + + return bool(self.security_model_enabled) + + # 热更新运行时参数 + def update_from(self, new_config: "Config") -> dict[str, tuple[Any, Any]]: + # 逐字段 diff;嵌套模型配置用 _update_dataclass 展开为 chat_model.api_url 等键 + changes: dict[str, tuple[Any, Any]] = {} + for field in fields(self): + name = field.name + old_value = getattr(self, name) + new_value = getattr(new_config, name) + if isinstance( + old_value, + ( + ChatModelConfig, + VisionModelConfig, + SecurityModelConfig, + AgentModelConfig, + GrokModelConfig, + ), + ): + changes.update(_update_dataclass(old_value, new_value, prefix=name)) + continue + if old_value != new_value: + setattr(self, name, new_value) + changes[name] = (old_value, new_value) + return changes + + def reload(self, strict: bool = False) -> dict[str, tuple[Any, Any]]: + # 对外入队 API + new_config = Config.load(strict=strict) + return self.update_from(new_config) + + # 
对外入队 API + def add_admin(self, qq: int) -> bool: + if qq in self.admin_qqs: + return False + self.admin_qqs.append(qq) + local_admins = load_local_admins() + if qq not in local_admins: + local_admins.append(qq) + save_local_admins(local_admins) + return True + + def remove_admin(self, qq: int) -> bool: + if qq == self.superadmin_qq or qq not in self.admin_qqs: + return False + self.admin_qqs.remove(qq) + local_admins = load_local_admins() + if qq in local_admins: + local_admins.remove(qq) + save_local_admins(local_admins) + return True + + def is_superadmin(self, qq: int) -> bool: + return qq == self.superadmin_qq + + def is_admin(self, qq: int) -> bool: + return qq in self.admin_qqs + + +class ConfigBuilder: + """链式构建 Config;未设置的字段使用 build_config 默认值。""" + + def __init__(self) -> None: + self._overrides: dict[str, Any] = {} + + def with_mapping(self, data: dict[str, Any]) -> "ConfigBuilder": + self._overrides["_base_mapping"] = data + return self + + def override(self, **kwargs: Any) -> "ConfigBuilder": + self._overrides.update(kwargs) + return self + + # 从中间态构建最终对象 + def build(self, *, strict: bool = True) -> Config: + from .build_config import build_config + + base = self._overrides.pop("_base_mapping", {}) + if not isinstance(base, dict): + base = {} + data = dict(base) + # TOML-style nested overrides for simple top-level keys only + for key, value in self._overrides.items(): + data[key] = value + return build_config(data, strict=strict, config_path=None) diff --git a/src/Undefined/config/domain_parsers.py b/src/Undefined/config/domain_parsers.py index 986c5cb1..c0de5fa0 100644 --- a/src/Undefined/config/domain_parsers.py +++ b/src/Undefined/config/domain_parsers.py @@ -191,7 +191,6 @@ def _parse_message_batcher_config(data: dict[str, Any]) -> MessageBatcherConfig: window_seconds = _coerce_float(section.get("window_seconds"), 5.0) if window_seconds < 0: window_seconds = 0.0 - # max_window_seconds <= 0 视为不限制(仅靠 window_seconds + max_messages_per_batch 触发) 
max_window_seconds = _coerce_float(section.get("max_window_seconds"), 30.0) if max_window_seconds < 0: max_window_seconds = 0.0 diff --git a/src/Undefined/config/env_registry.py b/src/Undefined/config/env_registry.py new file mode 100644 index 00000000..99d10ea5 --- /dev/null +++ b/src/Undefined/config/env_registry.py @@ -0,0 +1,192 @@ +"""TOML path to environment variable mapping for configuration.""" + +from __future__ import annotations + + +from typing import Final + +# TOML 路径 → 环境变量名;供 _get_value 在 TOML 缺省时回退,以及文档/工具生成 +ENV_REGISTRY: Final[dict[tuple[str, ...], str]] = { + ("access", "allowed_group_ids"): "ALLOWED_GROUP_IDS", + ("access", "allowed_private_ids"): "ALLOWED_PRIVATE_IDS", + ("access", "blocked_group_ids"): "BLOCKED_GROUP_IDS", + ("access", "blocked_private_ids"): "BLOCKED_PRIVATE_IDS", + ("access", "mode"): "ACCESS_MODE", + ("api_endpoints", "jkyai_base_url"): "JKYAI_BASE_URL", + ("api_endpoints", "xxapi_base_url"): "XXAPI_BASE_URL", + ("core", "admin_qq"): "ADMIN_QQ", + ("core", "bot_qq"): "BOT_QQ", + ("core", "forward_proxy_qq"): "FORWARD_PROXY_QQ", + ("core", "superadmin_qq"): "SUPERADMIN_QQ", + ("features", "pool_enabled"): "MODEL_POOL_ENABLED", + ("history", "max_records"): "HISTORY_MAX_RECORDS", + ("image_gen", "provider"): "IMAGE_GEN_PROVIDER", + ("logging", "backup_count"): "LOG_BACKUP_COUNT", + ("logging", "file_path"): "LOG_FILE_PATH", + ("logging", "level"): "LOG_LEVEL", + ("logging", "log_thinking"): "LOG_THINKING", + ("logging", "max_size_mb"): "LOG_MAX_SIZE_MB", + ("logging", "tty_enabled"): "LOG_TTY_ENABLED", + ("mcp", "config_path"): "MCP_CONFIG_PATH", + ("models", "agent", "api_key"): "AGENT_MODEL_API_KEY", + ("models", "agent", "api_mode"): "AGENT_MODEL_API_MODE", + ("models", "agent", "api_url"): "AGENT_MODEL_API_URL", + ("models", "agent", "context_window_tokens"): "AGENT_MODEL_CONTEXT_WINDOW_TOKENS", + ("models", "agent", "model_name"): "AGENT_MODEL_NAME", + ( + "models", + "agent", + "reasoning_content_replay", + ): 
"AGENT_MODEL_REASONING_CONTENT_REPLAY", + ( + "models", + "agent", + "responses_force_stateless_replay", + ): "AGENT_MODEL_RESPONSES_FORCE_STATELESS_REPLAY", + ( + "models", + "agent", + "responses_tool_choice_compat", + ): "AGENT_MODEL_RESPONSES_TOOL_CHOICE_COMPAT", + ("models", "agent", "system_prompt_as_user"): "AGENT_MODEL_SYSTEM_PROMPT_AS_USER", + ("models", "chat", "api_key"): "CHAT_MODEL_API_KEY", + ("models", "chat", "api_mode"): "CHAT_MODEL_API_MODE", + ("models", "chat", "api_url"): "CHAT_MODEL_API_URL", + ("models", "chat", "context_window_tokens"): "CHAT_MODEL_CONTEXT_WINDOW_TOKENS", + ("models", "chat", "max_tokens"): "CHAT_MODEL_MAX_TOKENS", + ("models", "chat", "model_name"): "CHAT_MODEL_NAME", + ( + "models", + "chat", + "reasoning_content_replay", + ): "CHAT_MODEL_REASONING_CONTENT_REPLAY", + ( + "models", + "chat", + "responses_force_stateless_replay", + ): "CHAT_MODEL_RESPONSES_FORCE_STATELESS_REPLAY", + ( + "models", + "chat", + "responses_tool_choice_compat", + ): "CHAT_MODEL_RESPONSES_TOOL_CHOICE_COMPAT", + ("models", "chat", "system_prompt_as_user"): "CHAT_MODEL_SYSTEM_PROMPT_AS_USER", + ( + "models", + "embedding", + "context_window_tokens", + ): "EMBEDDING_MODEL_CONTEXT_WINDOW_TOKENS", + ("models", "grok", "api_key"): "GROK_MODEL_API_KEY", + ("models", "grok", "api_url"): "GROK_MODEL_API_URL", + ("models", "grok", "context_window_tokens"): "GROK_MODEL_CONTEXT_WINDOW_TOKENS", + ("models", "grok", "max_tokens"): "GROK_MODEL_MAX_TOKENS", + ("models", "grok", "model_name"): "GROK_MODEL_NAME", + ("models", "naga", "api_key"): "NAGA_MODEL_API_KEY", + ("models", "naga", "api_mode"): "NAGA_MODEL_API_MODE", + ("models", "naga", "api_url"): "NAGA_MODEL_API_URL", + ("models", "naga", "context_window_tokens"): "NAGA_MODEL_CONTEXT_WINDOW_TOKENS", + ("models", "naga", "model_name"): "NAGA_MODEL_NAME", + ( + "models", + "naga", + "reasoning_content_replay", + ): "NAGA_MODEL_REASONING_CONTENT_REPLAY", + ( + "models", + "naga", + 
"responses_force_stateless_replay", + ): "NAGA_MODEL_RESPONSES_FORCE_STATELESS_REPLAY", + ( + "models", + "naga", + "responses_tool_choice_compat", + ): "NAGA_MODEL_RESPONSES_TOOL_CHOICE_COMPAT", + ("models", "naga", "system_prompt_as_user"): "NAGA_MODEL_SYSTEM_PROMPT_AS_USER", + ("models", "rerank", "api_key"): "RERANK_MODEL_API_KEY", + ("models", "rerank", "api_url"): "RERANK_MODEL_API_URL", + ("models", "rerank", "context_window_tokens"): "RERANK_MODEL_CONTEXT_WINDOW_TOKENS", + ("models", "rerank", "model_name"): "RERANK_MODEL_NAME", + ("models", "security", "api_key"): "SECURITY_MODEL_API_KEY", + ("models", "security", "api_mode"): "SECURITY_MODEL_API_MODE", + ("models", "security", "api_url"): "SECURITY_MODEL_API_URL", + ( + "models", + "security", + "context_window_tokens", + ): "SECURITY_MODEL_CONTEXT_WINDOW_TOKENS", + ("models", "security", "model_name"): "SECURITY_MODEL_NAME", + ( + "models", + "security", + "reasoning_content_replay", + ): "SECURITY_MODEL_REASONING_CONTENT_REPLAY", + ( + "models", + "security", + "responses_force_stateless_replay", + ): "SECURITY_MODEL_RESPONSES_FORCE_STATELESS_REPLAY", + ( + "models", + "security", + "responses_tool_choice_compat", + ): "SECURITY_MODEL_RESPONSES_TOOL_CHOICE_COMPAT", + ( + "models", + "security", + "system_prompt_as_user", + ): "SECURITY_MODEL_SYSTEM_PROMPT_AS_USER", + ("models", "vision", "api_key"): "VISION_MODEL_API_KEY", + ("models", "vision", "api_mode"): "VISION_MODEL_API_MODE", + ("models", "vision", "api_url"): "VISION_MODEL_API_URL", + ("models", "vision", "context_window_tokens"): "VISION_MODEL_CONTEXT_WINDOW_TOKENS", + ("models", "vision", "model_name"): "VISION_MODEL_NAME", + ( + "models", + "vision", + "reasoning_content_replay", + ): "VISION_MODEL_REASONING_CONTENT_REPLAY", + ( + "models", + "vision", + "responses_force_stateless_replay", + ): "VISION_MODEL_RESPONSES_FORCE_STATELESS_REPLAY", + ( + "models", + "vision", + "responses_tool_choice_compat", + ): 
"VISION_MODEL_RESPONSES_TOOL_CHOICE_COMPAT", + ("models", "vision", "system_prompt_as_user"): "VISION_MODEL_SYSTEM_PROMPT_AS_USER", + ("onebot", "token"): "ONEBOT_TOKEN", + ("onebot", "ws_url"): "ONEBOT_WS_URL", + ("proxy", "use_proxy"): "USE_PROXY", + ("search", "searxng_url"): "SEARXNG_URL", + ("skills", "hot_reload"): "SKILLS_HOT_RELOAD", + ("skills", "intro_hash_path"): "AGENT_INTRO_HASH_PATH", + ("skills", "prefetch_tools_hide"): "PREFETCH_TOOLS_HIDE", + ("token_usage", "max_archives"): "TOKEN_USAGE_MAX_ARCHIVES", + ("token_usage", "max_size_mb"): "TOKEN_USAGE_MAX_SIZE_MB", + ("token_usage", "max_total_mb"): "TOKEN_USAGE_MAX_TOTAL_MB", + ("tools", "description_max_len"): "TOOLS_DESCRIPTION_MAX_LEN", + ("tools", "dot_delimiter"): "TOOLS_DOT_DELIMITER", + ("tools", "sanitize_verbose"): "TOOLS_SANITIZE_VERBOSE", + ("weather", "api_key"): "WEATHER_API_KEY", + ("xxapi", "api_token"): "XXAPI_API_TOKEN", +} + +# 历史/别名环境变量:不经过 _get_value 统一路径,单独在 domain_parsers 等处读取 +ENV_ALTERNATES: Final[dict[str, tuple[str, ...]]] = { + "EASTER_EGG_AGENT_CALL_MESSAGE_MODE": ("easter_egg", "agent_call_message_enabled"), + "EASTER_EGG_CALL_MESSAGE_MODE": ("easter_egg", "agent_call_message_enabled"), + "HTTP_PROXY": ("proxy", "http_proxy"), + "HTTPS_PROXY": ("proxy", "https_proxy"), +} + + +def env_key_for_path(path: tuple[str, ...]) -> str | None: + """Return primary env var for a TOML path, if registered.""" + return ENV_REGISTRY.get(path) + + +def all_env_mappings() -> dict[tuple[str, ...], str]: + """Return a copy of the primary env registry.""" + return dict(ENV_REGISTRY) diff --git a/src/Undefined/config/load_sections/__init__.py b/src/Undefined/config/load_sections/__init__.py new file mode 100644 index 00000000..c195d38c --- /dev/null +++ b/src/Undefined/config/load_sections/__init__.py @@ -0,0 +1,26 @@ +"""Config load section parsers.""" + +# 配置分段加载:按 table 解析 TOML → ctx 字段 dict +from .access import load_access +from .core import load_core +from .domains import load_domains 
+from .finalize import load_finalize +from .history_skills import load_history_skills +from .integrations import load_integrations +from .knowledge import load_knowledge +from .logging_tools import load_logging_tools +from .models import load_models +from .network import load_network + +__all__ = [ + "load_access", + "load_core", + "load_domains", + "load_finalize", + "load_history_skills", + "load_integrations", + "load_knowledge", + "load_logging_tools", + "load_models", + "load_network", +] diff --git a/src/Undefined/config/load_sections/access.py b/src/Undefined/config/load_sections/access.py new file mode 100644 index 00000000..e3ea89c3 --- /dev/null +++ b/src/Undefined/config/load_sections/access.py @@ -0,0 +1,86 @@ +"""Load access config section.""" + +from __future__ import annotations + +# 配置分段加载:按 table 解析 TOML → ctx 字段 dict + +import logging +from pathlib import Path +from typing import Any, Optional + +from ..coercers import ( + _coerce_bool, + _coerce_int_list, + _coerce_str, + _get_value, +) + +logger = logging.getLogger(__name__) + + +def load_access( + data: dict[str, Any], *, config_path: Optional[Path] = None +) -> dict[str, Any]: + access_mode_raw = _get_value(data, ("access", "mode"), "ACCESS_MODE") + allowed_group_ids = _coerce_int_list( + _get_value(data, ("access", "allowed_group_ids"), "ALLOWED_GROUP_IDS") + ) + blocked_group_ids = _coerce_int_list( + _get_value(data, ("access", "blocked_group_ids"), "BLOCKED_GROUP_IDS") + ) + allowed_private_ids = _coerce_int_list( + _get_value(data, ("access", "allowed_private_ids"), "ALLOWED_PRIVATE_IDS") + ) + blocked_private_ids = _coerce_int_list( + _get_value(data, ("access", "blocked_private_ids"), "BLOCKED_PRIVATE_IDS") + ) + superadmin_bypass_allowlist = _coerce_bool( + _get_value( + data, + ("access", "superadmin_bypass_allowlist"), + "SUPERADMIN_BYPASS_ALLOWLIST", + ), + True, + ) + superadmin_bypass_private_blacklist = _coerce_bool( + _get_value( + data, + ("access", 
"superadmin_bypass_private_blacklist"), + "SUPERADMIN_BYPASS_PRIVATE_BLACKLIST", + ), + False, + ) + if access_mode_raw is None: + # 兼容旧配置:未配置 mode 时沿用历史行为(群黑名单 + 白名单联动)。 + if ( + allowed_group_ids + or blocked_group_ids + or allowed_private_ids + or blocked_private_ids + ): + access_mode = "legacy" + logger.warning( + "[配置] access.mode 未设置,已启用兼容模式(legacy)。建议显式设置为 off/blacklist/allowlist。" + ) + # 否则分支 + else: + access_mode = "off" + # 否则分支 + else: + access_mode = _coerce_str(access_mode_raw, "off").lower() + if access_mode not in {"off", "blacklist", "allowlist"}: + logger.warning( + "[配置] access.mode 非法(仅支持 off/blacklist/allowlist),已回退为 off: %s", + access_mode, + ) + access_mode = "off" + + return { + "allowed_group_ids": allowed_group_ids, + "blocked_group_ids": blocked_group_ids, + "allowed_private_ids": allowed_private_ids, + "blocked_private_ids": blocked_private_ids, + "superadmin_bypass_allowlist": superadmin_bypass_allowlist, + "superadmin_bypass_private_blacklist": superadmin_bypass_private_blacklist, + "access_mode": access_mode, + } diff --git a/src/Undefined/config/load_sections/core.py b/src/Undefined/config/load_sections/core.py new file mode 100644 index 00000000..28a8f832 --- /dev/null +++ b/src/Undefined/config/load_sections/core.py @@ -0,0 +1,183 @@ +"""Load core config section.""" + +from __future__ import annotations + +# 配置分段加载:按 table 解析 TOML → ctx 字段 dict + +import logging +from pathlib import Path +from typing import Any, Optional + +from ..coercers import ( + _coerce_bool, + _coerce_int, + _coerce_int_list, + _coerce_str, + _get_value, +) + +logger = logging.getLogger(__name__) + + +# 加载 [core] table:机器人身份、消息开关、彩蛋、OneBot 连接 +def load_core( + data: dict[str, Any], *, config_path: Optional[Path] = None +) -> dict[str, Any]: + # [core] 机器人 QQ 与管理员 + bot_qq = _coerce_int(_get_value(data, ("core", "bot_qq"), "BOT_QQ"), 0) + superadmin_qq = _coerce_int( + _get_value(data, ("core", "superadmin_qq"), "SUPERADMIN_QQ"), 0 + ) + admin_qqs = 
_coerce_int_list(_get_value(data, ("core", "admin_qq"), "ADMIN_QQ")) + forward_proxy = _coerce_int( + _get_value(data, ("core", "forward_proxy_qq"), "FORWARD_PROXY_QQ"), + 0, + ) + forward_proxy_qq = forward_proxy if forward_proxy > 0 else None + # [core] 群聊/私聊/拍一拍处理开关 + process_every_message = _coerce_bool( + _get_value( + data, + ("core", "process_every_message"), + "PROCESS_EVERY_MESSAGE", + ), + True, + ) + process_private_message = _coerce_bool( + _get_value( + data, + ("core", "process_private_message"), + "PROCESS_PRIVATE_MESSAGE", + ), + True, + ) + process_poke_message = _coerce_bool( + _get_value( + data, + ("core", "process_poke_message"), + "PROCESS_POKE_MESSAGE", + ), + True, + ) + # [easter_egg] 关键词回复与复读彩蛋 + keyword_reply_raw = _get_value( + data, + ("easter_egg", "keyword_reply_enabled"), + "KEYWORD_REPLY_ENABLED", + ) + if keyword_reply_raw is None: + # 兼容旧配置:历史上放在 [core].keyword_reply_enabled + keyword_reply_raw = _get_value( + data, + ("core", "keyword_reply_enabled"), + None, + ) + keyword_reply_enabled = _coerce_bool(keyword_reply_raw, False) + repeat_enabled = _coerce_bool( + _get_value( + data, + ("easter_egg", "repeat_enabled"), + "EASTER_EGG_REPEAT_ENABLED", + ), + False, + ) + inverted_question_enabled = _coerce_bool( + _get_value( + data, + ("easter_egg", "inverted_question_enabled"), + "EASTER_EGG_INVERTED_QUESTION_ENABLED", + ), + False, + ) + repeat_threshold = _coerce_int( + _get_value( + data, + ("easter_egg", "repeat_threshold"), + "EASTER_EGG_REPEAT_THRESHOLD", + ), + 3, + ) + if repeat_threshold < 2: + repeat_threshold = 2 + if repeat_threshold > 20: + repeat_threshold = 20 + repeat_cooldown_minutes = _coerce_int( + _get_value( + data, + ("easter_egg", "repeat_cooldown_minutes"), + "EASTER_EGG_REPEAT_COOLDOWN_MINUTES", + ), + 60, + ) + if repeat_cooldown_minutes < 0: + repeat_cooldown_minutes = 0 + context_recent_messages_limit = _coerce_int( + _get_value( + data, + ("core", "context_recent_messages_limit"), + 
"CONTEXT_RECENT_MESSAGES_LIMIT", + ), + 20, + ) + if context_recent_messages_limit < 0: + context_recent_messages_limit = 0 + + # [core] AI 请求与 tool_call 重试上限 + ai_request_max_retries = _coerce_int( + _get_value( + data, + ("core", "ai_request_max_retries"), + "AI_REQUEST_MAX_RETRIES", + ), + 2, + ) + if ai_request_max_retries < 0: + ai_request_max_retries = 0 + + missing_tool_call_retries = _coerce_int( + _get_value( + data, + ("core", "missing_tool_call_retries"), + "MISSING_TOOL_CALL_RETRIES", + ), + 3, + ) + if missing_tool_call_retries < 0: + missing_tool_call_retries = 0 + + nagaagent_mode_enabled = _coerce_bool( + _get_value( + data, + ("features", "nagaagent_mode_enabled"), + "NAGAAGENT_MODE_ENABLED", + ), + False, + ) + # [onebot] WebSocket 连接 + onebot_ws_url = _coerce_str( + _get_value(data, ("onebot", "ws_url"), "ONEBOT_WS_URL"), "" + ) + onebot_token = _coerce_str( + _get_value(data, ("onebot", "token"), "ONEBOT_TOKEN"), "" + ) + + return { + "bot_qq": bot_qq, + "superadmin_qq": superadmin_qq, + "admin_qqs": admin_qqs, + "forward_proxy_qq": forward_proxy_qq, + "process_every_message": process_every_message, + "process_private_message": process_private_message, + "process_poke_message": process_poke_message, + "keyword_reply_enabled": keyword_reply_enabled, + "repeat_enabled": repeat_enabled, + "inverted_question_enabled": inverted_question_enabled, + "repeat_threshold": repeat_threshold, + "repeat_cooldown_minutes": repeat_cooldown_minutes, + "context_recent_messages_limit": context_recent_messages_limit, + "ai_request_max_retries": ai_request_max_retries, + "missing_tool_call_retries": missing_tool_call_retries, + "nagaagent_mode_enabled": nagaagent_mode_enabled, + "onebot_ws_url": onebot_ws_url, + "onebot_token": onebot_token, + } diff --git a/src/Undefined/config/load_sections/domains.py b/src/Undefined/config/load_sections/domains.py new file mode 100644 index 00000000..6ada6453 --- /dev/null +++ b/src/Undefined/config/load_sections/domains.py @@ 
-0,0 +1,58 @@ +"""Load domains config section.""" + +from __future__ import annotations + +# 配置分段加载:按 table 解析 TOML → ctx 字段 dict + +import logging +from pathlib import Path +from typing import Any, Optional + +from ..domain_parsers import ( + _parse_api_config, + _parse_cognitive_config, + _parse_memes_config, + _parse_message_batcher_config, + _parse_naga_config, + _parse_render_cache_config, +) +from ..model_parsers import ( + _parse_image_edit_model_config, + _parse_image_gen_config, + _parse_image_gen_model_config, +) +from ..webui_settings import load_webui_settings + +logger = logging.getLogger(__name__) + + +def load_domains( + data: dict[str, Any], *, config_path: Optional[Path] = None +) -> dict[str, Any]: + # 子域配置:WebUI/API/认知/表情包/合并器/Naga/生图等,与 core/models 段解耦 + webui_settings = load_webui_settings(config_path) + api_config = _parse_api_config(data) + + cognitive = _parse_cognitive_config(data) + memes = _parse_memes_config(data) + message_batcher = _parse_message_batcher_config(data) + render_cache = _parse_render_cache_config(data) + naga = _parse_naga_config(data) + models_image_gen = _parse_image_gen_model_config(data) + models_image_edit = _parse_image_edit_model_config(data) + image_gen = _parse_image_gen_config(data) + + return { + "webui_url": webui_settings.url, + "webui_port": webui_settings.port, + "webui_password": webui_settings.password, + "api": api_config, + "cognitive": cognitive, + "memes": memes, + "message_batcher": message_batcher, + "render_cache": render_cache, + "naga": naga, + "image_gen": image_gen, + "models_image_gen": models_image_gen, + "models_image_edit": models_image_edit, + } diff --git a/src/Undefined/config/load_sections/finalize.py b/src/Undefined/config/load_sections/finalize.py new file mode 100644 index 00000000..d6da9a07 --- /dev/null +++ b/src/Undefined/config/load_sections/finalize.py @@ -0,0 +1,36 @@ +"""Finalize: validation and debug logging.""" + +from __future__ import annotations + +# 配置分段加载:按 table 解析 
TOML → ctx 字段 dict + +from typing import Any + +from ..model_parsers import _log_debug_info, _verify_required_fields + + +def load_finalize(ctx: dict[str, Any], *, strict: bool = True) -> None: + # strict=True(首次启动)校验必填;热重载传 strict=False 跳过以免半写文件误杀 + if strict: + _verify_required_fields( + bot_qq=ctx["bot_qq"], + superadmin_qq=ctx["superadmin_qq"], + onebot_ws_url=ctx["onebot_ws_url"], + chat_model=ctx["chat_model"], + vision_model=ctx["vision_model"], + agent_model=ctx["agent_model"], + knowledge_enabled=ctx["knowledge_enabled"], + embedding_model=ctx["embedding_model"], + ) + + debug_keys = ( + "chat_model", + "vision_model", + "security_model", + "naga_model", + "agent_model", + "summary_model", + "grok_model", + ) + if all(key in ctx for key in debug_keys): + _log_debug_info(*(ctx[key] for key in debug_keys)) diff --git a/src/Undefined/config/load_sections/history_skills.py b/src/Undefined/config/load_sections/history_skills.py new file mode 100644 index 00000000..615580e0 --- /dev/null +++ b/src/Undefined/config/load_sections/history_skills.py @@ -0,0 +1,277 @@ +"""Load history_skills config section.""" + +from __future__ import annotations + +# 配置分段加载:按 table 解析 TOML → ctx 字段 dict + +import logging +from pathlib import Path +from typing import Any, Optional + +from ..coercers import ( + _coerce_bool, + _coerce_float, + _coerce_int, + _coerce_str, + _coerce_str_list, + _get_value, + _normalize_queue_interval, +) + +logger = logging.getLogger(__name__) + + +def load_history_skills( + data: dict[str, Any], *, config_path: Optional[Path] = None +) -> dict[str, Any]: + token_usage_max_size_mb = _coerce_int( + _get_value(data, ("token_usage", "max_size_mb"), "TOKEN_USAGE_MAX_SIZE_MB"), + 5, + ) + token_usage_max_archives = _coerce_int( + _get_value(data, ("token_usage", "max_archives"), "TOKEN_USAGE_MAX_ARCHIVES"), + 30, + ) + token_usage_max_total_mb = _coerce_int( + _get_value(data, ("token_usage", "max_total_mb"), "TOKEN_USAGE_MAX_TOTAL_MB"), + 0, + ) + 
token_usage_archive_prune_mode = _coerce_str( + _get_value( + data, + ("token_usage", "archive_prune_mode"), + "TOKEN_USAGE_ARCHIVE_PRUNE_MODE", + ), + "delete", + ) + + history_max_records = max( + 0, + _coerce_int( + _get_value(data, ("history", "max_records"), "HISTORY_MAX_RECORDS"), + 10000, + ), + ) + history_filtered_result_limit = max( + 1, + _coerce_int( + _get_value( + data, + ("history", "filtered_result_limit"), + "HISTORY_FILTERED_RESULT_LIMIT", + ), + 200, + ), + ) + history_search_scan_limit = max( + 1, + _coerce_int( + _get_value( + data, + ("history", "search_scan_limit"), + "HISTORY_SEARCH_SCAN_LIMIT", + ), + 10000, + ), + ) + history_summary_fetch_limit = max( + 1, + _coerce_int( + _get_value( + data, + ("history", "summary_fetch_limit"), + "HISTORY_SUMMARY_FETCH_LIMIT", + ), + 1000, + ), + ) + history_summary_time_fetch_limit = max( + 1, + _coerce_int( + _get_value( + data, + ("history", "summary_time_fetch_limit"), + "HISTORY_SUMMARY_TIME_FETCH_LIMIT", + ), + 5000, + ), + ) + history_onebot_fetch_limit = max( + 1, + _coerce_int( + _get_value( + data, + ("history", "onebot_fetch_limit"), + "HISTORY_ONEBOT_FETCH_LIMIT", + ), + 10000, + ), + ) + history_group_analysis_limit = max( + 1, + _coerce_int( + _get_value( + data, + ("history", "group_analysis_limit"), + "HISTORY_GROUP_ANALYSIS_LIMIT", + ), + 500, + ), + ) + attachment_remote_download_max_size_mb = max( + 0, + _coerce_int( + _get_value( + data, + ("attachments", "remote_download_max_size_mb"), + "ATTACHMENTS_REMOTE_DOWNLOAD_MAX_SIZE_MB", + ), + 25, + ), + ) + attachment_cache_max_total_size_mb = max( + 0, + _coerce_int( + _get_value( + data, + ("attachments", "cache_max_total_size_mb"), + "ATTACHMENTS_CACHE_MAX_TOTAL_SIZE_MB", + ), + 0, + ), + ) + attachment_cache_max_records = max( + 0, + _coerce_int( + _get_value( + data, + ("attachments", "cache_max_records"), + "ATTACHMENTS_CACHE_MAX_RECORDS", + ), + 2000, + ), + ) + attachment_cache_max_age_days = max( + 0, + _coerce_int( + _get_value( 
+ data, + ("attachments", "cache_max_age_days"), + "ATTACHMENTS_CACHE_MAX_AGE_DAYS", + ), + 7, + ), + ) + attachment_url_reference_max_records = max( + 0, + _coerce_int( + _get_value( + data, + ("attachments", "url_reference_max_records"), + "ATTACHMENTS_URL_REFERENCE_MAX_RECORDS", + ), + 2000, + ), + ) + attachment_url_max_length = max( + 0, + _coerce_int( + _get_value( + data, + ("attachments", "url_max_length"), + "ATTACHMENTS_URL_MAX_LENGTH", + ), + 8192, + ), + ) + + skills_hot_reload = _coerce_bool( + _get_value(data, ("skills", "hot_reload"), "SKILLS_HOT_RELOAD"), True + ) + # interval/debounce 同时驱动 skills 目录扫描与 config.toml 热重载 watcher + skills_hot_reload_interval = _coerce_float( + _get_value( + data, ("skills", "hot_reload_interval"), "SKILLS_HOT_RELOAD_INTERVAL" + ), + 2.0, + ) + skills_hot_reload_interval = _normalize_queue_interval(skills_hot_reload_interval) + skills_hot_reload_debounce = _coerce_float( + _get_value( + data, ("skills", "hot_reload_debounce"), "SKILLS_HOT_RELOAD_DEBOUNCE" + ), + 0.5, + ) + skills_hot_reload_debounce = _normalize_queue_interval(skills_hot_reload_debounce) + + agent_intro_autogen_enabled = _coerce_bool( + _get_value( + data, + ("skills", "intro_autogen_enabled"), + "AGENT_INTRO_AUTOGEN_ENABLED", + ), + True, + ) + agent_intro_autogen_queue_interval = _coerce_float( + _get_value( + data, + ("skills", "intro_autogen_queue_interval"), + "AGENT_INTRO_AUTOGEN_QUEUE_INTERVAL", + ), + 1.0, + ) + agent_intro_autogen_queue_interval = _normalize_queue_interval( + agent_intro_autogen_queue_interval + ) + agent_intro_autogen_max_tokens = _coerce_int( + _get_value( + data, + ("skills", "intro_autogen_max_tokens"), + "AGENT_INTRO_AUTOGEN_MAX_TOKENS", + ), + 8192, + ) + agent_intro_hash_path = _coerce_str( + _get_value(data, ("skills", "intro_hash_path"), "AGENT_INTRO_HASH_PATH"), + ".cache/agent_intro_hashes.json", + ) + + prefetch_tools_raw = _get_value( + data, ("skills", "prefetch_tools"), "PREFETCH_TOOLS" + ) + prefetch_tools = 
_coerce_str_list(prefetch_tools_raw) + if not prefetch_tools and prefetch_tools_raw is None: + prefetch_tools = ["get_current_time"] + prefetch_tools_hide = _coerce_bool( + _get_value(data, ("skills", "prefetch_tools_hide"), "PREFETCH_TOOLS_HIDE"), + True, + ) + + return { + "token_usage_max_size_mb": token_usage_max_size_mb, + "token_usage_max_archives": token_usage_max_archives, + "token_usage_max_total_mb": token_usage_max_total_mb, + "token_usage_archive_prune_mode": token_usage_archive_prune_mode, + "history_max_records": history_max_records, + "history_filtered_result_limit": history_filtered_result_limit, + "history_search_scan_limit": history_search_scan_limit, + "history_summary_fetch_limit": history_summary_fetch_limit, + "history_summary_time_fetch_limit": history_summary_time_fetch_limit, + "history_onebot_fetch_limit": history_onebot_fetch_limit, + "history_group_analysis_limit": history_group_analysis_limit, + "attachment_remote_download_max_size_mb": attachment_remote_download_max_size_mb, + "attachment_cache_max_total_size_mb": attachment_cache_max_total_size_mb, + "attachment_cache_max_records": attachment_cache_max_records, + "attachment_cache_max_age_days": attachment_cache_max_age_days, + "attachment_url_reference_max_records": attachment_url_reference_max_records, + "attachment_url_max_length": attachment_url_max_length, + "skills_hot_reload": skills_hot_reload, + "skills_hot_reload_interval": skills_hot_reload_interval, + "skills_hot_reload_debounce": skills_hot_reload_debounce, + "agent_intro_autogen_enabled": agent_intro_autogen_enabled, + "agent_intro_autogen_queue_interval": agent_intro_autogen_queue_interval, + "agent_intro_autogen_max_tokens": agent_intro_autogen_max_tokens, + "agent_intro_hash_path": agent_intro_hash_path, + "prefetch_tools": prefetch_tools, + "prefetch_tools_hide": prefetch_tools_hide, + } diff --git a/src/Undefined/config/load_sections/integrations.py b/src/Undefined/config/load_sections/integrations.py new file mode 
100644 index 00000000..ebbafdf7 --- /dev/null +++ b/src/Undefined/config/load_sections/integrations.py @@ -0,0 +1,275 @@ +"""Load integrations config section.""" + +from __future__ import annotations + +# 配置分段加载:按 table 解析 TOML → ctx 字段 dict + +import logging +from pathlib import Path +from typing import Any, Optional + +from ..coercers import ( + _coerce_bool, + _coerce_float, + _coerce_int, + _coerce_int_list, + _coerce_str, + _get_value, +) + +logger = logging.getLogger(__name__) + + +def load_integrations( + data: dict[str, Any], *, config_path: Optional[Path] = None +) -> dict[str, Any]: + bilibili_auto_extract_enabled = _coerce_bool( + _get_value(data, ("bilibili", "auto_extract_enabled"), None), False + ) + # 功能级白名单为空时,运行时回退到全局 access 控制(见 Config.is_*_allowed) + bilibili_cookie = _coerce_str(_get_value(data, ("bilibili", "cookie"), None), "") + if not bilibili_cookie: + # 兼容旧配置项:bilibili.sessdata + bilibili_cookie = _coerce_str( + _get_value(data, ("bilibili", "sessdata"), None), "" + ) + bilibili_prefer_quality = _coerce_int( + _get_value(data, ("bilibili", "prefer_quality"), None), 80 + ) + bilibili_max_duration = _coerce_int( + _get_value(data, ("bilibili", "max_duration"), None), 600 + ) + bilibili_max_file_size = _coerce_int( + _get_value(data, ("bilibili", "max_file_size"), None), 100 + ) + bilibili_oversize_strategy = _coerce_str( + _get_value(data, ("bilibili", "oversize_strategy"), None), "downgrade" + ) + if bilibili_oversize_strategy not in ("downgrade", "info"): + bilibili_oversize_strategy = "downgrade" + bilibili_danmaku_enabled = _coerce_bool( + _get_value(data, ("bilibili", "danmaku_enabled"), None), True + ) + bilibili_danmaku_batch_size = _coerce_int( + _get_value(data, ("bilibili", "danmaku_batch_size"), None), 100 + ) + if bilibili_danmaku_batch_size <= 0: + bilibili_danmaku_batch_size = 100 + bilibili_danmaku_max_count = _coerce_int( + _get_value(data, ("bilibili", "danmaku_max_count"), None), 0 + ) + if bilibili_danmaku_max_count < 0: + 
bilibili_danmaku_max_count = 0 + bilibili_auto_extract_group_ids = _coerce_int_list( + _get_value(data, ("bilibili", "auto_extract_group_ids"), None) + ) + bilibili_auto_extract_private_ids = _coerce_int_list( + _get_value(data, ("bilibili", "auto_extract_private_ids"), None) + ) + + # arXiv 配置 + arxiv_auto_extract_enabled = _coerce_bool( + _get_value(data, ("arxiv", "auto_extract_enabled"), None), False + ) + arxiv_max_file_size = _coerce_int( + _get_value(data, ("arxiv", "max_file_size"), None), 100 + ) + if arxiv_max_file_size < 0: + arxiv_max_file_size = 100 + arxiv_auto_extract_group_ids = _coerce_int_list( + _get_value(data, ("arxiv", "auto_extract_group_ids"), None) + ) + arxiv_auto_extract_private_ids = _coerce_int_list( + _get_value(data, ("arxiv", "auto_extract_private_ids"), None) + ) + arxiv_auto_extract_max_items = _coerce_int( + _get_value(data, ("arxiv", "auto_extract_max_items"), None), 5 + ) + if arxiv_auto_extract_max_items <= 0: + arxiv_auto_extract_max_items = 5 + if arxiv_auto_extract_max_items > 20: + arxiv_auto_extract_max_items = 20 + arxiv_author_preview_limit = _coerce_int( + _get_value(data, ("arxiv", "author_preview_limit"), None), 20 + ) + if arxiv_author_preview_limit <= 0: + arxiv_author_preview_limit = 20 + if arxiv_author_preview_limit > 100: + arxiv_author_preview_limit = 100 + arxiv_summary_preview_chars = _coerce_int( + _get_value(data, ("arxiv", "summary_preview_chars"), None), 1000 + ) + if arxiv_summary_preview_chars <= 0: + arxiv_summary_preview_chars = 1000 + if arxiv_summary_preview_chars > 8000: + arxiv_summary_preview_chars = 8000 + + # GitHub 配置 + github_auto_extract_enabled = _coerce_bool( + _get_value(data, ("github", "auto_extract_enabled"), None), False + ) + github_request_timeout_seconds = _coerce_float( + _get_value(data, ("github", "request_timeout_seconds"), None), 10.0 + ) + if github_request_timeout_seconds <= 0: + github_request_timeout_seconds = 10.0 + if github_request_timeout_seconds > 60.0: + 
github_request_timeout_seconds = 60.0 + github_auto_extract_group_ids = _coerce_int_list( + _get_value(data, ("github", "auto_extract_group_ids"), None) + ) + github_auto_extract_private_ids = _coerce_int_list( + _get_value(data, ("github", "auto_extract_private_ids"), None) + ) + github_auto_extract_max_items = _coerce_int( + _get_value(data, ("github", "auto_extract_max_items"), None), 3 + ) + if github_auto_extract_max_items <= 0: + github_auto_extract_max_items = 3 + if github_auto_extract_max_items > 10: + github_auto_extract_max_items = 10 + + # Code Delivery Agent 配置 + code_delivery_enabled = _coerce_bool( + _get_value(data, ("code_delivery", "enabled"), None), True + ) + code_delivery_task_root = _coerce_str( + _get_value(data, ("code_delivery", "task_root"), None), + "data/code_delivery", + ) + code_delivery_docker_image = _coerce_str( + _get_value(data, ("code_delivery", "docker_image"), None), + "ubuntu:24.04", + ) + code_delivery_container_name_prefix = _coerce_str( + _get_value(data, ("code_delivery", "container_name_prefix"), None), + "code_delivery_", + ) + code_delivery_container_name_suffix = _coerce_str( + _get_value(data, ("code_delivery", "container_name_suffix"), None), + "_runner", + ) + code_delivery_command_timeout = _coerce_int( + _get_value(data, ("code_delivery", "default_command_timeout_seconds"), None), + 600, + ) + if code_delivery_command_timeout < 1: + code_delivery_command_timeout = 600 + code_delivery_max_command_output = _coerce_int( + _get_value(data, ("code_delivery", "max_command_output_chars"), None), + 20000, + ) + if code_delivery_max_command_output < 1: + code_delivery_max_command_output = 20000 + code_delivery_default_archive_format = _coerce_str( + _get_value(data, ("code_delivery", "default_archive_format"), None), + "zip", + ) + if code_delivery_default_archive_format not in ("zip", "tar.gz"): + code_delivery_default_archive_format = "zip" + code_delivery_max_archive_size_mb = _coerce_int( + _get_value(data, 
("code_delivery", "max_archive_size_mb"), None), 200 + ) + if code_delivery_max_archive_size_mb < 1: + code_delivery_max_archive_size_mb = 200 + code_delivery_cleanup_on_finish = _coerce_bool( + _get_value(data, ("code_delivery", "cleanup_on_finish"), None), True + ) + code_delivery_cleanup_on_start = _coerce_bool( + _get_value(data, ("code_delivery", "cleanup_on_start"), None), True + ) + code_delivery_llm_max_retries = _coerce_int( + _get_value(data, ("code_delivery", "llm_max_retries_per_request"), None), + 5, + ) + if code_delivery_llm_max_retries < 0: + code_delivery_llm_max_retries = 5 + code_delivery_notify_on_llm_failure = _coerce_bool( + _get_value(data, ("code_delivery", "notify_on_llm_failure"), None), + True, + ) + code_delivery_container_memory_limit = _coerce_str( + _get_value(data, ("code_delivery", "container_memory_limit"), None), + "", + ) + code_delivery_container_cpu_limit = _coerce_str( + _get_value(data, ("code_delivery", "container_cpu_limit"), None), + "", + ) + code_delivery_command_blacklist_raw = _get_value( + data, ("code_delivery", "command_blacklist"), None + ) + if isinstance(code_delivery_command_blacklist_raw, list): + code_delivery_command_blacklist = [ + str(x) for x in code_delivery_command_blacklist_raw + ] + # 否则分支 + else: + code_delivery_command_blacklist = [] + + # messages 工具集配置 + messages_send_text_file_max_size_kb = _coerce_int( + _get_value( + data, + ("messages", "send_text_file_max_size_kb"), + "MESSAGES_SEND_TEXT_FILE_MAX_SIZE_KB", + ), + 512, + ) + if messages_send_text_file_max_size_kb <= 0: + messages_send_text_file_max_size_kb = 512 + + messages_send_url_file_max_size_mb = _coerce_int( + _get_value( + data, + ("messages", "send_url_file_max_size_mb"), + "MESSAGES_SEND_URL_FILE_MAX_SIZE_MB", + ), + 100, + ) + if messages_send_url_file_max_size_mb <= 0: + messages_send_url_file_max_size_mb = 100 + + return { + "bilibili_auto_extract_enabled": bilibili_auto_extract_enabled, + "bilibili_cookie": bilibili_cookie, + 
"bilibili_prefer_quality": bilibili_prefer_quality, + "bilibili_max_duration": bilibili_max_duration, + "bilibili_max_file_size": bilibili_max_file_size, + "bilibili_oversize_strategy": bilibili_oversize_strategy, + "bilibili_danmaku_enabled": bilibili_danmaku_enabled, + "bilibili_danmaku_batch_size": bilibili_danmaku_batch_size, + "bilibili_danmaku_max_count": bilibili_danmaku_max_count, + "bilibili_auto_extract_group_ids": bilibili_auto_extract_group_ids, + "bilibili_auto_extract_private_ids": bilibili_auto_extract_private_ids, + "arxiv_auto_extract_enabled": arxiv_auto_extract_enabled, + "arxiv_max_file_size": arxiv_max_file_size, + "arxiv_auto_extract_group_ids": arxiv_auto_extract_group_ids, + "arxiv_auto_extract_private_ids": arxiv_auto_extract_private_ids, + "arxiv_auto_extract_max_items": arxiv_auto_extract_max_items, + "arxiv_author_preview_limit": arxiv_author_preview_limit, + "arxiv_summary_preview_chars": arxiv_summary_preview_chars, + "github_auto_extract_enabled": github_auto_extract_enabled, + "github_request_timeout_seconds": github_request_timeout_seconds, + "github_auto_extract_group_ids": github_auto_extract_group_ids, + "github_auto_extract_private_ids": github_auto_extract_private_ids, + "github_auto_extract_max_items": github_auto_extract_max_items, + "code_delivery_enabled": code_delivery_enabled, + "code_delivery_task_root": code_delivery_task_root, + "code_delivery_docker_image": code_delivery_docker_image, + "code_delivery_container_name_prefix": code_delivery_container_name_prefix, + "code_delivery_container_name_suffix": code_delivery_container_name_suffix, + "code_delivery_command_timeout": code_delivery_command_timeout, + "code_delivery_max_command_output": code_delivery_max_command_output, + "code_delivery_default_archive_format": code_delivery_default_archive_format, + "code_delivery_max_archive_size_mb": code_delivery_max_archive_size_mb, + "code_delivery_cleanup_on_finish": code_delivery_cleanup_on_finish, + 
"code_delivery_cleanup_on_start": code_delivery_cleanup_on_start, + "code_delivery_llm_max_retries": code_delivery_llm_max_retries, + "code_delivery_notify_on_llm_failure": code_delivery_notify_on_llm_failure, + "code_delivery_container_memory_limit": code_delivery_container_memory_limit, + "code_delivery_container_cpu_limit": code_delivery_container_cpu_limit, + "code_delivery_command_blacklist": code_delivery_command_blacklist, + "messages_send_text_file_max_size_kb": messages_send_text_file_max_size_kb, + "messages_send_url_file_max_size_mb": messages_send_url_file_max_size_mb, + } diff --git a/src/Undefined/config/load_sections/knowledge.py b/src/Undefined/config/load_sections/knowledge.py new file mode 100644 index 00000000..266976f2 --- /dev/null +++ b/src/Undefined/config/load_sections/knowledge.py @@ -0,0 +1,122 @@ +"""Load knowledge config section.""" + +from __future__ import annotations + +# 配置分段加载:按 table 解析 TOML → ctx 字段 dict + +import logging +from pathlib import Path +from typing import Any, Optional + +from ..coercers import ( + _coerce_bool, + _coerce_float, + _coerce_int, + _coerce_str, + _get_value, +) +from ..model_parsers import ( + _parse_embedding_model_config, + _parse_rerank_model_config, +) + +logger = logging.getLogger(__name__) + + +def load_knowledge( + data: dict[str, Any], *, config_path: Optional[Path] = None +) -> dict[str, Any]: + # 知识库段多数项仅读 TOML(env_key=None),避免与 embedding 模型 env 混淆 + embedding_model = _parse_embedding_model_config(data) + rerank_model = _parse_rerank_model_config(data) + + knowledge_enabled = _coerce_bool( + _get_value(data, ("knowledge", "enabled"), None), False + ) + knowledge_base_dir = _coerce_str( + _get_value(data, ("knowledge", "base_dir"), None), "knowledge" + ) + knowledge_auto_scan = _coerce_bool( + _get_value(data, ("knowledge", "auto_scan"), None), False + ) + knowledge_auto_embed = _coerce_bool( + _get_value(data, ("knowledge", "auto_embed"), None), False + ) + knowledge_scan_interval = 
_coerce_float( + _get_value(data, ("knowledge", "scan_interval"), None), 60.0 + ) + if knowledge_scan_interval <= 0: + knowledge_scan_interval = 60.0 + knowledge_embed_batch_size = _coerce_int( + _get_value(data, ("knowledge", "embed_batch_size"), None), 64 + ) + if knowledge_embed_batch_size <= 0: + knowledge_embed_batch_size = 64 + knowledge_chunk_size = _coerce_int( + _get_value(data, ("knowledge", "chunk_size"), None), 10 + ) + if knowledge_chunk_size <= 0: + knowledge_chunk_size = 10 + knowledge_chunk_overlap = _coerce_int( + _get_value(data, ("knowledge", "chunk_overlap"), None), 2 + ) + if knowledge_chunk_overlap < 0: + knowledge_chunk_overlap = 0 + if knowledge_chunk_overlap >= knowledge_chunk_size: + knowledge_chunk_overlap = max(0, knowledge_chunk_size - 1) + knowledge_default_top_k = _coerce_int( + _get_value(data, ("knowledge", "default_top_k"), None), 5 + ) + if knowledge_default_top_k <= 0: + knowledge_default_top_k = 5 + knowledge_enable_rerank = _coerce_bool( + _get_value(data, ("knowledge", "enable_rerank"), None), False + ) + knowledge_rerank_top_k = _coerce_int( + _get_value(data, ("knowledge", "rerank_top_k"), None), 3 + ) + if knowledge_rerank_top_k <= 0: + knowledge_rerank_top_k = 3 + if knowledge_default_top_k <= 1 and knowledge_enable_rerank: + logger.warning( + "[配置] knowledge.default_top_k=%s,无法满足 rerank_top_k < default_top_k," + "已自动禁用重排", + knowledge_default_top_k, + ) + knowledge_enable_rerank = False + if knowledge_rerank_top_k >= knowledge_default_top_k: + fallback = knowledge_default_top_k - 1 + if fallback <= 0: + fallback = 1 + knowledge_enable_rerank = False + logger.warning( + "[配置] knowledge.rerank_top_k 需小于 knowledge.default_top_k," + "且当前 default_top_k=%s 无法满足约束,已自动禁用重排", + knowledge_default_top_k, + ) + # 否则分支 + else: + logger.warning( + "[配置] knowledge.rerank_top_k 需小于 knowledge.default_top_k," + "已回退: rerank_top_k=%s -> %s (default_top_k=%s)", + knowledge_rerank_top_k, + fallback, + knowledge_default_top_k, + ) + 
knowledge_rerank_top_k = fallback + + return { + "embedding_model": embedding_model, + "rerank_model": rerank_model, + "knowledge_enabled": knowledge_enabled, + "knowledge_base_dir": knowledge_base_dir, + "knowledge_auto_scan": knowledge_auto_scan, + "knowledge_auto_embed": knowledge_auto_embed, + "knowledge_scan_interval": knowledge_scan_interval, + "knowledge_embed_batch_size": knowledge_embed_batch_size, + "knowledge_chunk_size": knowledge_chunk_size, + "knowledge_chunk_overlap": knowledge_chunk_overlap, + "knowledge_default_top_k": knowledge_default_top_k, + "knowledge_enable_rerank": knowledge_enable_rerank, + "knowledge_rerank_top_k": knowledge_rerank_top_k, + } diff --git a/src/Undefined/config/load_sections/logging_tools.py b/src/Undefined/config/load_sections/logging_tools.py new file mode 100644 index 00000000..f5d6c06e --- /dev/null +++ b/src/Undefined/config/load_sections/logging_tools.py @@ -0,0 +1,126 @@ +"""Load logging_tools config section.""" + +from __future__ import annotations + +# 配置分段加载:按 table 解析 TOML → ctx 字段 dict + +import logging +import os +import re +from pathlib import Path +from typing import Any, Optional + +from ..coercers import ( + _coerce_bool, + _coerce_int, + _coerce_str, + _get_value, + _warn_env_fallback, +) +from ..domain_parsers import ( + _parse_easter_egg_call_mode, +) + +logger = logging.getLogger(__name__) + + +def load_logging_tools( + data: dict[str, Any], *, config_path: Optional[Path] = None +) -> dict[str, Any]: + log_level = _coerce_str( + _get_value(data, ("logging", "level"), "LOG_LEVEL"), "INFO" + ).upper() + log_file_path = _coerce_str( + _get_value(data, ("logging", "file_path"), "LOG_FILE_PATH"), + "logs/bot.log", + ) + log_max_size_mb = _coerce_int( + _get_value(data, ("logging", "max_size_mb"), "LOG_MAX_SIZE_MB"), 10 + ) + if log_max_size_mb <= 0: + log_max_size_mb = 10 + log_backup_count = _coerce_int( + _get_value(data, ("logging", "backup_count"), "LOG_BACKUP_COUNT"), 5 + ) + if log_backup_count < 0: + 
log_backup_count = 0 + log_tty_enabled = _coerce_bool( + _get_value(data, ("logging", "tty_enabled"), "LOG_TTY_ENABLED"), + False, + ) + log_thinking = _coerce_bool( + _get_value(data, ("logging", "log_thinking"), "LOG_THINKING"), True + ) + + tools_dot_delimiter = _coerce_str( + _get_value(data, ("tools", "dot_delimiter"), "TOOLS_DOT_DELIMITER"), "-_-" + ).strip() + if not tools_dot_delimiter: + tools_dot_delimiter = "-_-" + # dot_delimiter 必须满足 OpenAI 兼容的 function.name 约束。 + if "." in tools_dot_delimiter or not re.fullmatch( + r"[a-zA-Z0-9_-]+", tools_dot_delimiter + ): + logger.warning( + "[配置] tools.dot_delimiter 非法(仅允许 [a-zA-Z0-9_-] 且不能包含 '.'),已回退默认值: '-_-'(当前=%s)", + tools_dot_delimiter, + ) + tools_dot_delimiter = "-_-" + tools_description_max_len = _coerce_int( + _get_value(data, ("tools", "description_max_len"), "TOOLS_DESCRIPTION_MAX_LEN"), + 1024, + ) + tools_description_truncate_enabled = _coerce_bool( + _get_value( + data, + ("tools", "description_truncate_enabled"), + "TOOLS_DESCRIPTION_TRUNCATE_ENABLED", + ), + False, + ) + tools_sanitize_verbose = _coerce_bool( + _get_value(data, ("tools", "sanitize_verbose"), "TOOLS_SANITIZE_VERBOSE"), + False, + ) + tools_description_preview_len = _coerce_int( + _get_value( + data, + ("tools", "description_preview_len"), + "TOOLS_DESCRIPTION_PREVIEW_LEN", + ), + 160, + ) + + easter_egg_mode_raw = _get_value( + data, + ("easter_egg", "agent_call_message_enabled"), + "EASTER_EGG_AGENT_CALL_MESSAGE_ENABLED", + ) + if easter_egg_mode_raw is None: + easter_egg_mode_raw = os.getenv("EASTER_EGG_AGENT_CALL_MESSAGE_MODE") + if easter_egg_mode_raw is not None: + _warn_env_fallback("EASTER_EGG_AGENT_CALL_MESSAGE_MODE") + # 否则分支 + else: + easter_egg_mode_raw = os.getenv("EASTER_EGG_CALL_MESSAGE_MODE") + if easter_egg_mode_raw is not None: + _warn_env_fallback("EASTER_EGG_CALL_MESSAGE_MODE") + + easter_egg_agent_call_message_mode = _parse_easter_egg_call_mode( + easter_egg_mode_raw + ) + + return { + "log_level": log_level, + 
"log_file_path": log_file_path, + "log_max_size": log_max_size_mb * 1024 * 1024, + "log_backup_count": log_backup_count, + "log_tty_enabled": log_tty_enabled, + "log_thinking": log_thinking, + "tools_dot_delimiter": tools_dot_delimiter, + "tools_description_max_len": tools_description_max_len, + "tools_description_truncate_enabled": tools_description_truncate_enabled, + "tools_sanitize_verbose": tools_sanitize_verbose, + "tools_description_preview_len": tools_description_preview_len, + "easter_egg_agent_call_message_mode": easter_egg_agent_call_message_mode, + } diff --git a/src/Undefined/config/load_sections/models.py b/src/Undefined/config/load_sections/models.py new file mode 100644 index 00000000..1a3fff14 --- /dev/null +++ b/src/Undefined/config/load_sections/models.py @@ -0,0 +1,68 @@ +"""Load models config section.""" + +from __future__ import annotations + +# 配置分段加载:按 table 解析 TOML → ctx 字段 dict + +import logging +from pathlib import Path +from typing import Any, Optional + +from ..coercers import ( + _coerce_bool, + _get_value, +) +from ..model_parsers import ( + _parse_agent_model_config, + _parse_chat_model_config, + _parse_grok_model_config, + _parse_historian_model_config, + _parse_naga_model_config, + _parse_security_model_config, + _parse_summary_model_config, + _parse_vision_model_config, +) + +logger = logging.getLogger(__name__) + + +def load_models( + data: dict[str, Any], *, config_path: Optional[Path] = None +) -> dict[str, Any]: + chat_model = _parse_chat_model_config(data) + vision_model = _parse_vision_model_config(data) + security_model_enabled = _coerce_bool( + _get_value( + data, + ("models", "security", "enabled"), + "SECURITY_MODEL_ENABLED", + ), + True, + ) + # 未单独配置的模型段会回退到 chat/security/agent 等主模型 + security_model = _parse_security_model_config(data, chat_model) + naga_model = _parse_naga_model_config(data, security_model) + agent_model = _parse_agent_model_config(data) + historian_model = _parse_historian_model_config(data, 
agent_model) + summary_model, summary_model_configured = _parse_summary_model_config( + data, agent_model + ) + grok_model = _parse_grok_model_config(data) + + model_pool_enabled = _coerce_bool( + _get_value(data, ("features", "pool_enabled"), "MODEL_POOL_ENABLED"), False + ) + + return { + "chat_model": chat_model, + "vision_model": vision_model, + "security_model_enabled": security_model_enabled, + "security_model": security_model, + "naga_model": naga_model, + "agent_model": agent_model, + "historian_model": historian_model, + "summary_model": summary_model, + "summary_model_configured": summary_model_configured, + "grok_model": grok_model, + "model_pool_enabled": model_pool_enabled, + } diff --git a/src/Undefined/config/load_sections/network.py b/src/Undefined/config/load_sections/network.py new file mode 100644 index 00000000..818b5555 --- /dev/null +++ b/src/Undefined/config/load_sections/network.py @@ -0,0 +1,161 @@ +"""Load network config section.""" + +from __future__ import annotations + +# 配置分段加载:按 table 解析 TOML → ctx 字段 dict + +import logging +import os +from pathlib import Path +from typing import Any, Optional + +from ..coercers import ( + _coerce_bool, + _coerce_float, + _coerce_int, + _coerce_str, + _get_value, + _normalize_base_url, + _warn_env_fallback, +) + +logger = logging.getLogger(__name__) + + +def load_network( + data: dict[str, Any], *, config_path: Optional[Path] = None +) -> dict[str, Any]: + searxng_url = _coerce_str( + _get_value(data, ("search", "searxng_url"), "SEARXNG_URL"), "" + ) + grok_search_enabled = _coerce_bool( + _get_value( + data, + ("search", "grok_search_enabled"), + "GROK_SEARCH_ENABLED", + ), + False, + ) + + use_proxy = _coerce_bool( + _get_value(data, ("proxy", "use_proxy"), "USE_PROXY"), True + ) + http_proxy = _coerce_str( + _get_value(data, ("proxy", "http_proxy"), "http_proxy"), "" + ) + # TOML 未配置时回退标准 HTTP_PROXY 环境变量(小写键名不走 ENV_REGISTRY) + if not http_proxy: + http_proxy = _coerce_str(os.getenv("HTTP_PROXY"), 
"") + if http_proxy: + _warn_env_fallback("HTTP_PROXY") + https_proxy = _coerce_str( + _get_value(data, ("proxy", "https_proxy"), "https_proxy"), "" + ) + if not https_proxy: + https_proxy = _coerce_str(os.getenv("HTTPS_PROXY"), "") + if https_proxy: + _warn_env_fallback("HTTPS_PROXY") + + network_request_timeout = _coerce_float( + _get_value( + data, + ("network", "request_timeout_seconds"), + "NETWORK_REQUEST_TIMEOUT_SECONDS", + ), + 30.0, + ) + if network_request_timeout <= 0: + network_request_timeout = 480.0 + + network_request_retries = _coerce_int( + _get_value( + data, + ("network", "request_retries"), + "NETWORK_REQUEST_RETRIES", + ), + 0, + ) + if network_request_retries < 0: + network_request_retries = 0 + if network_request_retries > 5: + network_request_retries = 5 + + render_browser_max_concurrency = max( + 0, + _coerce_int( + _get_value( + data, + ("render", "browser_max_concurrency"), + "RENDER_BROWSER_MAX_CONCURRENCY", + ), + 0, + ), + ) + + api_xxapi_base_url = _normalize_base_url( + _coerce_str( + _get_value(data, ("api_endpoints", "xxapi_base_url"), "XXAPI_BASE_URL"), + "https://v2.xxapi.cn", + ), + "https://v2.xxapi.cn", + ) + api_xingzhige_base_url = _normalize_base_url( + _coerce_str( + _get_value( + data, + ("api_endpoints", "xingzhige_base_url"), + "XINGZHIGE_BASE_URL", + ), + "https://api.xingzhige.com", + ), + "https://api.xingzhige.com", + ) + api_jkyai_base_url = _normalize_base_url( + _coerce_str( + _get_value(data, ("api_endpoints", "jkyai_base_url"), "JKYAI_BASE_URL"), + "https://api.jkyai.top", + ), + "https://api.jkyai.top", + ) + api_seniverse_base_url = _normalize_base_url( + _coerce_str( + _get_value( + data, + ("api_endpoints", "seniverse_base_url"), + "SENIVERSE_BASE_URL", + ), + "https://api.seniverse.com/v3", + ), + "https://api.seniverse.com/v3", + ) + + weather_api_key = _coerce_str( + _get_value(data, ("weather", "api_key"), "WEATHER_API_KEY"), "" + ) + xxapi_api_token = _coerce_str( + _get_value(data, ("xxapi", 
"api_token"), "XXAPI_API_TOKEN"), "" + ) + + mcp_config_path = _coerce_str( + _get_value(data, ("mcp", "config_path"), "MCP_CONFIG_PATH"), + "config/mcp.json", + ) + + # Bilibili 配置 + return { + "searxng_url": searxng_url, + "grok_search_enabled": grok_search_enabled, + "use_proxy": use_proxy, + "http_proxy": http_proxy, + "https_proxy": https_proxy, + "network_request_timeout": network_request_timeout, + "network_request_retries": network_request_retries, + "render_browser_max_concurrency": render_browser_max_concurrency, + "api_xxapi_base_url": api_xxapi_base_url, + "api_xingzhige_base_url": api_xingzhige_base_url, + "api_jkyai_base_url": api_jkyai_base_url, + "api_seniverse_base_url": api_seniverse_base_url, + "weather_api_key": weather_api_key, + "xxapi_api_token": xxapi_api_token, + "mcp_config_path": mcp_config_path, + } diff --git a/src/Undefined/config/loader.py b/src/Undefined/config/loader.py index dc49fa38..708da82a 100644 --- a/src/Undefined/config/loader.py +++ b/src/Undefined/config/loader.py @@ -1,119 +1,22 @@ -"""配置加载逻辑""" +"""配置加载逻辑(向后兼容 shim;Config 实现见 config_class)。""" from __future__ import annotations -import logging -import os -import re -import tomllib -from dataclasses import dataclass, field as dataclass_field, fields -from pathlib import Path -from typing import Any, Optional, IO - -try: - from dotenv import load_dotenv -except Exception: # pragma: no cover - StrPath = str | os.PathLike[str] - - def load_dotenv( - dotenv_path: StrPath | None = None, - stream: IO[str] | None = None, - verbose: bool = False, - override: bool = False, - interpolate: bool = True, - encoding: str | None = "utf-8", - ) -> bool: - return False - - -from .models import ( - APIConfig, - AgentModelConfig, - ChatModelConfig, - CognitiveConfig, - EmbeddingModelConfig, - GrokModelConfig, - ImageGenConfig, - ImageGenModelConfig, - MemeConfig, - MessageBatcherConfig, - NagaConfig, - RenderCacheConfig, - RerankModelConfig, - SecurityModelConfig, - VisionModelConfig, -) 
-from .coercers import ( # noqa: F401 — re-exported for backward compat - _coerce_bool, - _coerce_float, - _coerce_int, - _coerce_int_list, - _coerce_str, - _coerce_str_list, - _get_model_request_params, - _get_value, - _normalize_base_url, - _normalize_queue_interval, - _normalize_str, - _warn_env_fallback, -) -from .resolvers import ( # noqa: F401 — re-exported for backward compat - _resolve_api_mode, - _resolve_reasoning_effort, - _resolve_reasoning_effort_style, - _resolve_responses_force_stateless_replay, - _resolve_responses_tool_choice_compat, - _resolve_thinking_compat_flags, -) -from .admin import ( # noqa: F401 — re-exported for backward compat - LOCAL_CONFIG_PATH, - load_local_admins, - save_local_admins, -) -from .webui_settings import ( # noqa: F401 — re-exported for backward compat +from .admin import LOCAL_CONFIG_PATH, load_local_admins, save_local_admins +from .config_class import Config, ConfigBuilder +from .toml_io import CONFIG_PATH, load_toml_data +from .webui_settings import ( DEFAULT_WEBUI_PASSWORD, DEFAULT_WEBUI_PORT, DEFAULT_WEBUI_URL, WebUISettings, load_webui_settings, ) -from .model_parsers import ( - _log_debug_info, - _merge_admins, - _parse_agent_model_config, - _parse_chat_model_config, - _parse_embedding_model_config, - _parse_grok_model_config, - _parse_historian_model_config, - _parse_image_edit_model_config, - _parse_image_gen_config, - _parse_image_gen_model_config, - _parse_naga_model_config, - _parse_rerank_model_config, - _parse_security_model_config, - _parse_summary_model_config, - _parse_vision_model_config, - _verify_required_fields, -) -from .domain_parsers import ( - _parse_api_config, - _parse_cognitive_config, - _parse_easter_egg_call_mode, - _parse_memes_config, - _parse_message_batcher_config, - _parse_naga_config, - _parse_render_cache_config, - _update_dataclass, -) - -logger = logging.getLogger(__name__) -CONFIG_PATH = Path("config.toml") - -# Re-export symbols that external modules import from this module. 
__all__ = [ "CONFIG_PATH", "Config", + "ConfigBuilder", "DEFAULT_WEBUI_PASSWORD", "DEFAULT_WEBUI_PORT", "DEFAULT_WEBUI_URL", @@ -124,1681 +27,3 @@ def load_dotenv( "load_webui_settings", "save_local_admins", ] - - -def _load_env() -> None: - try: - load_dotenv() - except Exception: - logger.debug("加载 .env 失败,继续使用 config.toml", exc_info=True) - - -def _build_toml_decode_hint(line: str) -> str: - hints: list[str] = [] - if "\\" in line: - hints.append( - 'Windows 路径建议用单引号(不转义)或双反斜杠,或直接用正斜杠,例如:path = \'D:\\AI\\bot\' / path = "D:\\\\AI\\\\bot" / path = "D:/AI/bot"' - ) - hints.append('多行文本请用三引号,例如:prompt = """..."""') - return ";".join(hints) - - -def _format_toml_decode_error( - path: Path, text: str, exc: tomllib.TOMLDecodeError -) -> str: - lineno: int | None = getattr(exc, "lineno", None) - colno: int | None = getattr(exc, "colno", None) - if not isinstance(lineno, int) or not isinstance(colno, int): - match = re.search(r"\(at line (\d+), column (\d+)\)", str(exc)) - if match: - lineno = int(match.group(1)) - colno = int(match.group(2)) - - if isinstance(lineno, int) and lineno > 0: - lines = text.splitlines() - line = lines[lineno - 1] if 0 <= (lineno - 1) < len(lines) else "" - caret_pos = max((colno or 1) - 1, 0) - caret = " " * min(caret_pos, len(line)) + "^" - hint = _build_toml_decode_hint(line) - location = f"line={lineno} col={colno or 1}" - return f"{exc} ({location})\n> {line}\n {caret}\n提示:{hint}" - return str(exc) - - -def load_toml_data( - config_path: Optional[Path] = None, *, strict: bool = False -) -> dict[str, Any]: - """读取 config.toml 并返回字典""" - path = config_path or CONFIG_PATH - if not path.exists(): - return {} - text = "" - try: - text = path.read_bytes().decode("utf-8-sig") - data = tomllib.loads(text) - if isinstance(data, dict): - return data - logger.warning("config.toml 内容不是对象结构") - return {} - except tomllib.TOMLDecodeError as exc: - message = _format_toml_decode_error(path, text, exc) - logger.error("config.toml 解析失败 (%s): %s", 
path.resolve(), message) - if strict: - raise ValueError(message) from exc - return {} - except UnicodeDecodeError as exc: - logger.error("config.toml 编码错误 (%s): %s", path.resolve(), exc) - if strict: - raise ValueError(str(exc)) from exc - return {} - except OSError as exc: - logger.error("读取 config.toml 失败: %s", exc) - if strict: - raise ValueError(str(exc)) from exc - return {} - - -@dataclass -class Config: - """应用配置""" - - bot_qq: int - superadmin_qq: int - admin_qqs: list[int] - # 访问控制模式:off / blacklist / allowlist - access_mode: str - # 访问控制(会话白名单 + 黑名单) - allowed_group_ids: list[int] - blocked_group_ids: list[int] - allowed_private_ids: list[int] - blocked_private_ids: list[int] - # 是否允许超级管理员在私聊中绕过 allowed_private_ids(仅私聊收发) - superadmin_bypass_allowlist: bool - # 是否允许超级管理员在私聊中绕过 blocked_private_ids(仅私聊收发) - superadmin_bypass_private_blacklist: bool - forward_proxy_qq: int | None - process_every_message: bool - process_private_message: bool - process_poke_message: bool - keyword_reply_enabled: bool - repeat_enabled: bool - repeat_threshold: int - repeat_cooldown_minutes: int - inverted_question_enabled: bool - context_recent_messages_limit: int - ai_request_max_retries: int - missing_tool_call_retries: int - nagaagent_mode_enabled: bool - onebot_ws_url: str - onebot_token: str - chat_model: ChatModelConfig - vision_model: VisionModelConfig - security_model_enabled: bool - security_model: SecurityModelConfig - naga_model: SecurityModelConfig - agent_model: AgentModelConfig - historian_model: AgentModelConfig - summary_model: AgentModelConfig - summary_model_configured: bool - grok_model: GrokModelConfig - model_pool_enabled: bool - log_level: str - log_file_path: str - log_max_size: int - log_backup_count: int - log_tty_enabled: bool - log_thinking: bool - tools_dot_delimiter: str - tools_description_truncate_enabled: bool - tools_description_max_len: int - tools_sanitize_verbose: bool - tools_description_preview_len: int - 
easter_egg_agent_call_message_mode: str - token_usage_max_size_mb: int - token_usage_max_archives: int - token_usage_max_total_mb: int - token_usage_archive_prune_mode: str - history_max_records: int - history_filtered_result_limit: int - history_search_scan_limit: int - history_summary_fetch_limit: int - history_summary_time_fetch_limit: int - history_onebot_fetch_limit: int - history_group_analysis_limit: int - attachment_remote_download_max_size_mb: int - attachment_cache_max_total_size_mb: int - attachment_cache_max_records: int - attachment_cache_max_age_days: int - attachment_url_reference_max_records: int - attachment_url_max_length: int - skills_hot_reload: bool - skills_hot_reload_interval: float - skills_hot_reload_debounce: float - agent_intro_autogen_enabled: bool - agent_intro_autogen_queue_interval: float - agent_intro_autogen_max_tokens: int - agent_intro_hash_path: str - searxng_url: str - grok_search_enabled: bool - use_proxy: bool - http_proxy: str - https_proxy: str - network_request_timeout: float - network_request_retries: int - render_browser_max_concurrency: int - api_xxapi_base_url: str - api_xingzhige_base_url: str - api_jkyai_base_url: str - api_seniverse_base_url: str - weather_api_key: str - xxapi_api_token: str - mcp_config_path: str - prefetch_tools: list[str] - prefetch_tools_hide: bool - webui_url: str - webui_port: int - webui_password: str - api: APIConfig - # Code Delivery Agent - code_delivery_enabled: bool - code_delivery_task_root: str - code_delivery_docker_image: str - code_delivery_container_name_prefix: str - code_delivery_container_name_suffix: str - code_delivery_command_timeout: int - code_delivery_max_command_output: int - code_delivery_default_archive_format: str - code_delivery_max_archive_size_mb: int - code_delivery_cleanup_on_finish: bool - code_delivery_cleanup_on_start: bool - code_delivery_llm_max_retries: int - code_delivery_notify_on_llm_failure: bool - code_delivery_container_memory_limit: str - 
code_delivery_container_cpu_limit: str - code_delivery_command_blacklist: list[str] - # messages 工具集 - messages_send_text_file_max_size_kb: int - messages_send_url_file_max_size_mb: int - # 嵌入模型 - embedding_model: EmbeddingModelConfig - rerank_model: RerankModelConfig - # 知识库 - knowledge_enabled: bool - knowledge_base_dir: str - knowledge_auto_scan: bool - knowledge_auto_embed: bool - knowledge_scan_interval: float - knowledge_embed_batch_size: int - knowledge_chunk_size: int - knowledge_chunk_overlap: int - knowledge_default_top_k: int - knowledge_enable_rerank: bool - knowledge_rerank_top_k: int - # Bilibili 视频提取 - bilibili_auto_extract_enabled: bool - bilibili_cookie: str - bilibili_prefer_quality: int - bilibili_max_duration: int - bilibili_max_file_size: int - bilibili_oversize_strategy: str - bilibili_danmaku_enabled: bool - bilibili_danmaku_batch_size: int - bilibili_danmaku_max_count: int - bilibili_auto_extract_group_ids: list[int] - bilibili_auto_extract_private_ids: list[int] - # arXiv 论文提取 - arxiv_auto_extract_enabled: bool - arxiv_max_file_size: int - arxiv_auto_extract_group_ids: list[int] - arxiv_auto_extract_private_ids: list[int] - arxiv_auto_extract_max_items: int - arxiv_author_preview_limit: int - arxiv_summary_preview_chars: int - # GitHub 仓库自动提取 - github_auto_extract_enabled: bool - github_request_timeout_seconds: float - github_auto_extract_group_ids: list[int] - github_auto_extract_private_ids: list[int] - github_auto_extract_max_items: int - # 认知记忆 - cognitive: CognitiveConfig - # 表情包库 - memes: MemeConfig - # 同 sender 短时多消息合并器 - message_batcher: MessageBatcherConfig - # HTML 渲染结果缓存 - render_cache: RenderCacheConfig - # Naga 集成 - naga: NagaConfig - # 生图工具配置 - image_gen: ImageGenConfig - models_image_gen: ImageGenModelConfig - models_image_edit: ImageGenModelConfig - _allowed_group_ids_set: set[int] = dataclass_field( - default_factory=set, - init=False, - repr=False, - ) - _blocked_group_ids_set: set[int] = dataclass_field( - 
default_factory=set, - init=False, - repr=False, - ) - _allowed_private_ids_set: set[int] = dataclass_field( - default_factory=set, - init=False, - repr=False, - ) - _blocked_private_ids_set: set[int] = dataclass_field( - default_factory=set, - init=False, - repr=False, - ) - _bilibili_group_ids_set: set[int] = dataclass_field( - default_factory=set, - init=False, - repr=False, - ) - _bilibili_private_ids_set: set[int] = dataclass_field( - default_factory=set, - init=False, - repr=False, - ) - _arxiv_group_ids_set: set[int] = dataclass_field( - default_factory=set, - init=False, - repr=False, - ) - _arxiv_private_ids_set: set[int] = dataclass_field( - default_factory=set, - init=False, - repr=False, - ) - _github_group_ids_set: set[int] = dataclass_field( - default_factory=set, - init=False, - repr=False, - ) - _github_private_ids_set: set[int] = dataclass_field( - default_factory=set, - init=False, - repr=False, - ) - - def __post_init__(self) -> None: - # 访问控制属于高频热路径,启动后缓存为 set 降低重复构建开销。 - normalized_mode = str(self.access_mode).strip().lower() - if normalized_mode not in {"off", "blacklist", "allowlist", "legacy"}: - normalized_mode = "off" - self.access_mode = normalized_mode - self._allowed_group_ids_set = {int(item) for item in self.allowed_group_ids} - self._blocked_group_ids_set = {int(item) for item in self.blocked_group_ids} - self._allowed_private_ids_set = {int(item) for item in self.allowed_private_ids} - self._blocked_private_ids_set = {int(item) for item in self.blocked_private_ids} - self._bilibili_group_ids_set = { - int(item) for item in self.bilibili_auto_extract_group_ids - } - self._bilibili_private_ids_set = { - int(item) for item in self.bilibili_auto_extract_private_ids - } - self._arxiv_group_ids_set = { - int(item) for item in self.arxiv_auto_extract_group_ids - } - self._arxiv_private_ids_set = { - int(item) for item in self.arxiv_auto_extract_private_ids - } - self._github_group_ids_set = { - int(item) for item in 
self.github_auto_extract_group_ids - } - self._github_private_ids_set = { - int(item) for item in self.github_auto_extract_private_ids - } - - @classmethod - def load(cls, config_path: Optional[Path] = None, strict: bool = True) -> "Config": - """从 config.toml 和本地配置加载配置""" - _load_env() - data = load_toml_data(config_path, strict=strict) - - bot_qq = _coerce_int(_get_value(data, ("core", "bot_qq"), "BOT_QQ"), 0) - superadmin_qq = _coerce_int( - _get_value(data, ("core", "superadmin_qq"), "SUPERADMIN_QQ"), 0 - ) - admin_qqs = _coerce_int_list(_get_value(data, ("core", "admin_qq"), "ADMIN_QQ")) - forward_proxy = _coerce_int( - _get_value(data, ("core", "forward_proxy_qq"), "FORWARD_PROXY_QQ"), - 0, - ) - forward_proxy_qq = forward_proxy if forward_proxy > 0 else None - process_every_message = _coerce_bool( - _get_value( - data, - ("core", "process_every_message"), - "PROCESS_EVERY_MESSAGE", - ), - True, - ) - process_private_message = _coerce_bool( - _get_value( - data, - ("core", "process_private_message"), - "PROCESS_PRIVATE_MESSAGE", - ), - True, - ) - process_poke_message = _coerce_bool( - _get_value( - data, - ("core", "process_poke_message"), - "PROCESS_POKE_MESSAGE", - ), - True, - ) - keyword_reply_raw = _get_value( - data, - ("easter_egg", "keyword_reply_enabled"), - "KEYWORD_REPLY_ENABLED", - ) - if keyword_reply_raw is None: - # 兼容旧配置:历史上放在 [core].keyword_reply_enabled - keyword_reply_raw = _get_value( - data, - ("core", "keyword_reply_enabled"), - None, - ) - keyword_reply_enabled = _coerce_bool(keyword_reply_raw, False) - repeat_enabled = _coerce_bool( - _get_value( - data, - ("easter_egg", "repeat_enabled"), - "EASTER_EGG_REPEAT_ENABLED", - ), - False, - ) - inverted_question_enabled = _coerce_bool( - _get_value( - data, - ("easter_egg", "inverted_question_enabled"), - "EASTER_EGG_INVERTED_QUESTION_ENABLED", - ), - False, - ) - repeat_threshold = _coerce_int( - _get_value( - data, - ("easter_egg", "repeat_threshold"), - "EASTER_EGG_REPEAT_THRESHOLD", - 
), - 3, - ) - if repeat_threshold < 2: - repeat_threshold = 2 - if repeat_threshold > 20: - repeat_threshold = 20 - repeat_cooldown_minutes = _coerce_int( - _get_value( - data, - ("easter_egg", "repeat_cooldown_minutes"), - "EASTER_EGG_REPEAT_COOLDOWN_MINUTES", - ), - 60, - ) - if repeat_cooldown_minutes < 0: - repeat_cooldown_minutes = 0 - context_recent_messages_limit = _coerce_int( - _get_value( - data, - ("core", "context_recent_messages_limit"), - "CONTEXT_RECENT_MESSAGES_LIMIT", - ), - 20, - ) - if context_recent_messages_limit < 0: - context_recent_messages_limit = 0 - - ai_request_max_retries = _coerce_int( - _get_value( - data, - ("core", "ai_request_max_retries"), - "AI_REQUEST_MAX_RETRIES", - ), - 2, - ) - if ai_request_max_retries < 0: - ai_request_max_retries = 0 - - missing_tool_call_retries = _coerce_int( - _get_value( - data, - ("core", "missing_tool_call_retries"), - "MISSING_TOOL_CALL_RETRIES", - ), - 3, - ) - if missing_tool_call_retries < 0: - missing_tool_call_retries = 0 - - nagaagent_mode_enabled = _coerce_bool( - _get_value( - data, - ("features", "nagaagent_mode_enabled"), - "NAGAAGENT_MODE_ENABLED", - ), - False, - ) - onebot_ws_url = _coerce_str( - _get_value(data, ("onebot", "ws_url"), "ONEBOT_WS_URL"), "" - ) - onebot_token = _coerce_str( - _get_value(data, ("onebot", "token"), "ONEBOT_TOKEN"), "" - ) - - embedding_model = _parse_embedding_model_config(data) - rerank_model = _parse_rerank_model_config(data) - - knowledge_enabled = _coerce_bool( - _get_value(data, ("knowledge", "enabled"), None), False - ) - knowledge_base_dir = _coerce_str( - _get_value(data, ("knowledge", "base_dir"), None), "knowledge" - ) - knowledge_auto_scan = _coerce_bool( - _get_value(data, ("knowledge", "auto_scan"), None), False - ) - knowledge_auto_embed = _coerce_bool( - _get_value(data, ("knowledge", "auto_embed"), None), False - ) - knowledge_scan_interval = _coerce_float( - _get_value(data, ("knowledge", "scan_interval"), None), 60.0 - ) - if 
knowledge_scan_interval <= 0: - knowledge_scan_interval = 60.0 - knowledge_embed_batch_size = _coerce_int( - _get_value(data, ("knowledge", "embed_batch_size"), None), 64 - ) - if knowledge_embed_batch_size <= 0: - knowledge_embed_batch_size = 64 - knowledge_chunk_size = _coerce_int( - _get_value(data, ("knowledge", "chunk_size"), None), 10 - ) - if knowledge_chunk_size <= 0: - knowledge_chunk_size = 10 - knowledge_chunk_overlap = _coerce_int( - _get_value(data, ("knowledge", "chunk_overlap"), None), 2 - ) - if knowledge_chunk_overlap < 0: - knowledge_chunk_overlap = 0 - knowledge_default_top_k = _coerce_int( - _get_value(data, ("knowledge", "default_top_k"), None), 5 - ) - if knowledge_default_top_k <= 0: - knowledge_default_top_k = 5 - knowledge_enable_rerank = _coerce_bool( - _get_value(data, ("knowledge", "enable_rerank"), None), False - ) - knowledge_rerank_top_k = _coerce_int( - _get_value(data, ("knowledge", "rerank_top_k"), None), 3 - ) - if knowledge_rerank_top_k <= 0: - knowledge_rerank_top_k = 3 - if knowledge_default_top_k <= 1 and knowledge_enable_rerank: - logger.warning( - "[配置] knowledge.default_top_k=%s,无法满足 rerank_top_k < default_top_k," - "已自动禁用重排", - knowledge_default_top_k, - ) - knowledge_enable_rerank = False - if knowledge_rerank_top_k >= knowledge_default_top_k: - fallback = knowledge_default_top_k - 1 - if fallback <= 0: - fallback = 1 - knowledge_enable_rerank = False - logger.warning( - "[配置] knowledge.rerank_top_k 需小于 knowledge.default_top_k," - "且当前 default_top_k=%s 无法满足约束,已自动禁用重排", - knowledge_default_top_k, - ) - else: - logger.warning( - "[配置] knowledge.rerank_top_k 需小于 knowledge.default_top_k," - "已回退: rerank_top_k=%s -> %s (default_top_k=%s)", - knowledge_rerank_top_k, - fallback, - knowledge_default_top_k, - ) - knowledge_rerank_top_k = fallback - - chat_model = _parse_chat_model_config(data) - vision_model = _parse_vision_model_config(data) - security_model_enabled = _coerce_bool( - _get_value( - data, - ("models", "security", 
"enabled"), - "SECURITY_MODEL_ENABLED", - ), - True, - ) - security_model = _parse_security_model_config(data, chat_model) - naga_model = _parse_naga_model_config(data, security_model) - agent_model = _parse_agent_model_config(data) - historian_model = _parse_historian_model_config(data, agent_model) - summary_model, summary_model_configured = _parse_summary_model_config( - data, agent_model - ) - grok_model = _parse_grok_model_config(data) - - model_pool_enabled = _coerce_bool( - _get_value(data, ("features", "pool_enabled"), "MODEL_POOL_ENABLED"), False - ) - - superadmin_qq, admin_qqs = _merge_admins( - superadmin_qq=superadmin_qq, admin_qqs=admin_qqs - ) - - access_mode_raw = _get_value(data, ("access", "mode"), "ACCESS_MODE") - allowed_group_ids = _coerce_int_list( - _get_value(data, ("access", "allowed_group_ids"), "ALLOWED_GROUP_IDS") - ) - blocked_group_ids = _coerce_int_list( - _get_value(data, ("access", "blocked_group_ids"), "BLOCKED_GROUP_IDS") - ) - allowed_private_ids = _coerce_int_list( - _get_value(data, ("access", "allowed_private_ids"), "ALLOWED_PRIVATE_IDS") - ) - blocked_private_ids = _coerce_int_list( - _get_value(data, ("access", "blocked_private_ids"), "BLOCKED_PRIVATE_IDS") - ) - superadmin_bypass_allowlist = _coerce_bool( - _get_value( - data, - ("access", "superadmin_bypass_allowlist"), - "SUPERADMIN_BYPASS_ALLOWLIST", - ), - True, - ) - superadmin_bypass_private_blacklist = _coerce_bool( - _get_value( - data, - ("access", "superadmin_bypass_private_blacklist"), - "SUPERADMIN_BYPASS_PRIVATE_BLACKLIST", - ), - False, - ) - if access_mode_raw is None: - # 兼容旧配置:未配置 mode 时沿用历史行为(群黑名单 + 白名单联动)。 - if ( - allowed_group_ids - or blocked_group_ids - or allowed_private_ids - or blocked_private_ids - ): - access_mode = "legacy" - logger.warning( - "[配置] access.mode 未设置,已启用兼容模式(legacy)。建议显式设置为 off/blacklist/allowlist。" - ) - else: - access_mode = "off" - else: - access_mode = _coerce_str(access_mode_raw, "off").lower() - if access_mode not in {"off", 
"blacklist", "allowlist"}: - logger.warning( - "[配置] access.mode 非法(仅支持 off/blacklist/allowlist),已回退为 off: %s", - access_mode, - ) - access_mode = "off" - - log_level = _coerce_str( - _get_value(data, ("logging", "level"), "LOG_LEVEL"), "INFO" - ).upper() - log_file_path = _coerce_str( - _get_value(data, ("logging", "file_path"), "LOG_FILE_PATH"), - "logs/bot.log", - ) - log_max_size_mb = _coerce_int( - _get_value(data, ("logging", "max_size_mb"), "LOG_MAX_SIZE_MB"), 10 - ) - log_backup_count = _coerce_int( - _get_value(data, ("logging", "backup_count"), "LOG_BACKUP_COUNT"), 5 - ) - log_tty_enabled = _coerce_bool( - _get_value(data, ("logging", "tty_enabled"), "LOG_TTY_ENABLED"), - False, - ) - log_thinking = _coerce_bool( - _get_value(data, ("logging", "log_thinking"), "LOG_THINKING"), True - ) - - tools_dot_delimiter = _coerce_str( - _get_value(data, ("tools", "dot_delimiter"), "TOOLS_DOT_DELIMITER"), "-_-" - ).strip() - if not tools_dot_delimiter: - tools_dot_delimiter = "-_-" - # dot_delimiter 必须满足 OpenAI 兼容的 function.name 约束。 - if "." 
in tools_dot_delimiter or not re.fullmatch( - r"[a-zA-Z0-9_-]+", tools_dot_delimiter - ): - logger.warning( - "[配置] tools.dot_delimiter 非法(仅允许 [a-zA-Z0-9_-] 且不能包含 '.'),已回退默认值: '-_-'(当前=%s)", - tools_dot_delimiter, - ) - tools_dot_delimiter = "-_-" - tools_description_max_len = _coerce_int( - _get_value( - data, ("tools", "description_max_len"), "TOOLS_DESCRIPTION_MAX_LEN" - ), - 1024, - ) - tools_description_truncate_enabled = _coerce_bool( - _get_value( - data, - ("tools", "description_truncate_enabled"), - "TOOLS_DESCRIPTION_TRUNCATE_ENABLED", - ), - False, - ) - tools_sanitize_verbose = _coerce_bool( - _get_value(data, ("tools", "sanitize_verbose"), "TOOLS_SANITIZE_VERBOSE"), - False, - ) - tools_description_preview_len = _coerce_int( - _get_value( - data, - ("tools", "description_preview_len"), - "TOOLS_DESCRIPTION_PREVIEW_LEN", - ), - 160, - ) - - easter_egg_mode_raw = _get_value( - data, - ("easter_egg", "agent_call_message_enabled"), - "EASTER_EGG_AGENT_CALL_MESSAGE_ENABLED", - ) - if easter_egg_mode_raw is None: - easter_egg_mode_raw = os.getenv("EASTER_EGG_AGENT_CALL_MESSAGE_MODE") - if easter_egg_mode_raw is not None: - _warn_env_fallback("EASTER_EGG_AGENT_CALL_MESSAGE_MODE") - else: - easter_egg_mode_raw = os.getenv("EASTER_EGG_CALL_MESSAGE_MODE") - if easter_egg_mode_raw is not None: - _warn_env_fallback("EASTER_EGG_CALL_MESSAGE_MODE") - - easter_egg_agent_call_message_mode = _parse_easter_egg_call_mode( - easter_egg_mode_raw - ) - - token_usage_max_size_mb = _coerce_int( - _get_value(data, ("token_usage", "max_size_mb"), "TOKEN_USAGE_MAX_SIZE_MB"), - 5, - ) - token_usage_max_archives = _coerce_int( - _get_value( - data, ("token_usage", "max_archives"), "TOKEN_USAGE_MAX_ARCHIVES" - ), - 30, - ) - token_usage_max_total_mb = _coerce_int( - _get_value( - data, ("token_usage", "max_total_mb"), "TOKEN_USAGE_MAX_TOTAL_MB" - ), - 0, - ) - token_usage_archive_prune_mode = _coerce_str( - _get_value( - data, - ("token_usage", "archive_prune_mode"), - 
"TOKEN_USAGE_ARCHIVE_PRUNE_MODE", - ), - "delete", - ) - - history_max_records = max( - 0, - _coerce_int( - _get_value(data, ("history", "max_records"), "HISTORY_MAX_RECORDS"), - 10000, - ), - ) - history_filtered_result_limit = max( - 1, - _coerce_int( - _get_value( - data, - ("history", "filtered_result_limit"), - "HISTORY_FILTERED_RESULT_LIMIT", - ), - 200, - ), - ) - history_search_scan_limit = max( - 1, - _coerce_int( - _get_value( - data, - ("history", "search_scan_limit"), - "HISTORY_SEARCH_SCAN_LIMIT", - ), - 10000, - ), - ) - history_summary_fetch_limit = max( - 1, - _coerce_int( - _get_value( - data, - ("history", "summary_fetch_limit"), - "HISTORY_SUMMARY_FETCH_LIMIT", - ), - 1000, - ), - ) - history_summary_time_fetch_limit = max( - 1, - _coerce_int( - _get_value( - data, - ("history", "summary_time_fetch_limit"), - "HISTORY_SUMMARY_TIME_FETCH_LIMIT", - ), - 5000, - ), - ) - history_onebot_fetch_limit = max( - 1, - _coerce_int( - _get_value( - data, - ("history", "onebot_fetch_limit"), - "HISTORY_ONEBOT_FETCH_LIMIT", - ), - 10000, - ), - ) - history_group_analysis_limit = max( - 1, - _coerce_int( - _get_value( - data, - ("history", "group_analysis_limit"), - "HISTORY_GROUP_ANALYSIS_LIMIT", - ), - 500, - ), - ) - attachment_remote_download_max_size_mb = max( - 0, - _coerce_int( - _get_value( - data, - ("attachments", "remote_download_max_size_mb"), - "ATTACHMENTS_REMOTE_DOWNLOAD_MAX_SIZE_MB", - ), - 25, - ), - ) - attachment_cache_max_total_size_mb = max( - 0, - _coerce_int( - _get_value( - data, - ("attachments", "cache_max_total_size_mb"), - "ATTACHMENTS_CACHE_MAX_TOTAL_SIZE_MB", - ), - 0, - ), - ) - attachment_cache_max_records = max( - 0, - _coerce_int( - _get_value( - data, - ("attachments", "cache_max_records"), - "ATTACHMENTS_CACHE_MAX_RECORDS", - ), - 2000, - ), - ) - attachment_cache_max_age_days = max( - 0, - _coerce_int( - _get_value( - data, - ("attachments", "cache_max_age_days"), - "ATTACHMENTS_CACHE_MAX_AGE_DAYS", - ), - 7, - ), - ) - 
attachment_url_reference_max_records = max( - 0, - _coerce_int( - _get_value( - data, - ("attachments", "url_reference_max_records"), - "ATTACHMENTS_URL_REFERENCE_MAX_RECORDS", - ), - 2000, - ), - ) - attachment_url_max_length = max( - 0, - _coerce_int( - _get_value( - data, - ("attachments", "url_max_length"), - "ATTACHMENTS_URL_MAX_LENGTH", - ), - 8192, - ), - ) - - skills_hot_reload = _coerce_bool( - _get_value(data, ("skills", "hot_reload"), "SKILLS_HOT_RELOAD"), True - ) - skills_hot_reload_interval = _coerce_float( - _get_value( - data, ("skills", "hot_reload_interval"), "SKILLS_HOT_RELOAD_INTERVAL" - ), - 2.0, - ) - skills_hot_reload_debounce = _coerce_float( - _get_value( - data, ("skills", "hot_reload_debounce"), "SKILLS_HOT_RELOAD_DEBOUNCE" - ), - 0.5, - ) - - agent_intro_autogen_enabled = _coerce_bool( - _get_value( - data, - ("skills", "intro_autogen_enabled"), - "AGENT_INTRO_AUTOGEN_ENABLED", - ), - True, - ) - agent_intro_autogen_queue_interval = _coerce_float( - _get_value( - data, - ("skills", "intro_autogen_queue_interval"), - "AGENT_INTRO_AUTOGEN_QUEUE_INTERVAL", - ), - 1.0, - ) - agent_intro_autogen_queue_interval = _normalize_queue_interval( - agent_intro_autogen_queue_interval - ) - agent_intro_autogen_max_tokens = _coerce_int( - _get_value( - data, - ("skills", "intro_autogen_max_tokens"), - "AGENT_INTRO_AUTOGEN_MAX_TOKENS", - ), - 8192, - ) - agent_intro_hash_path = _coerce_str( - _get_value(data, ("skills", "intro_hash_path"), "AGENT_INTRO_HASH_PATH"), - ".cache/agent_intro_hashes.json", - ) - - prefetch_tools_raw = _get_value( - data, ("skills", "prefetch_tools"), "PREFETCH_TOOLS" - ) - prefetch_tools = _coerce_str_list(prefetch_tools_raw) - if not prefetch_tools and prefetch_tools_raw is None: - prefetch_tools = ["get_current_time"] - prefetch_tools_hide = _coerce_bool( - _get_value(data, ("skills", "prefetch_tools_hide"), "PREFETCH_TOOLS_HIDE"), - True, - ) - - searxng_url = _coerce_str( - _get_value(data, ("search", "searxng_url"), 
"SEARXNG_URL"), "" - ) - grok_search_enabled = _coerce_bool( - _get_value( - data, - ("search", "grok_search_enabled"), - "GROK_SEARCH_ENABLED", - ), - False, - ) - - use_proxy = _coerce_bool( - _get_value(data, ("proxy", "use_proxy"), "USE_PROXY"), True - ) - http_proxy = _coerce_str( - _get_value(data, ("proxy", "http_proxy"), "http_proxy"), "" - ) - if not http_proxy: - http_proxy = _coerce_str(os.getenv("HTTP_PROXY"), "") - if http_proxy: - _warn_env_fallback("HTTP_PROXY") - https_proxy = _coerce_str( - _get_value(data, ("proxy", "https_proxy"), "https_proxy"), "" - ) - if not https_proxy: - https_proxy = _coerce_str(os.getenv("HTTPS_PROXY"), "") - if https_proxy: - _warn_env_fallback("HTTPS_PROXY") - - network_request_timeout = _coerce_float( - _get_value( - data, - ("network", "request_timeout_seconds"), - "NETWORK_REQUEST_TIMEOUT_SECONDS", - ), - 30.0, - ) - if network_request_timeout <= 0: - network_request_timeout = 480.0 - - network_request_retries = _coerce_int( - _get_value( - data, - ("network", "request_retries"), - "NETWORK_REQUEST_RETRIES", - ), - 0, - ) - if network_request_retries < 0: - network_request_retries = 0 - if network_request_retries > 5: - network_request_retries = 5 - - render_browser_max_concurrency = max( - 0, - _coerce_int( - _get_value( - data, - ("render", "browser_max_concurrency"), - "RENDER_BROWSER_MAX_CONCURRENCY", - ), - 0, - ), - ) - - api_xxapi_base_url = _normalize_base_url( - _coerce_str( - _get_value(data, ("api_endpoints", "xxapi_base_url"), "XXAPI_BASE_URL"), - "https://v2.xxapi.cn", - ), - "https://v2.xxapi.cn", - ) - api_xingzhige_base_url = _normalize_base_url( - _coerce_str( - _get_value( - data, - ("api_endpoints", "xingzhige_base_url"), - "XINGZHIGE_BASE_URL", - ), - "https://api.xingzhige.com", - ), - "https://api.xingzhige.com", - ) - api_jkyai_base_url = _normalize_base_url( - _coerce_str( - _get_value(data, ("api_endpoints", "jkyai_base_url"), "JKYAI_BASE_URL"), - "https://api.jkyai.top", - ), - 
"https://api.jkyai.top", - ) - api_seniverse_base_url = _normalize_base_url( - _coerce_str( - _get_value( - data, - ("api_endpoints", "seniverse_base_url"), - "SENIVERSE_BASE_URL", - ), - "https://api.seniverse.com/v3", - ), - "https://api.seniverse.com/v3", - ) - - weather_api_key = _coerce_str( - _get_value(data, ("weather", "api_key"), "WEATHER_API_KEY"), "" - ) - xxapi_api_token = _coerce_str( - _get_value(data, ("xxapi", "api_token"), "XXAPI_API_TOKEN"), "" - ) - - mcp_config_path = _coerce_str( - _get_value(data, ("mcp", "config_path"), "MCP_CONFIG_PATH"), - "config/mcp.json", - ) - - # Bilibili 配置 - bilibili_auto_extract_enabled = _coerce_bool( - _get_value(data, ("bilibili", "auto_extract_enabled"), None), False - ) - bilibili_cookie = _coerce_str( - _get_value(data, ("bilibili", "cookie"), None), "" - ) - if not bilibili_cookie: - # 兼容旧配置项:bilibili.sessdata - bilibili_cookie = _coerce_str( - _get_value(data, ("bilibili", "sessdata"), None), "" - ) - bilibili_prefer_quality = _coerce_int( - _get_value(data, ("bilibili", "prefer_quality"), None), 80 - ) - bilibili_max_duration = _coerce_int( - _get_value(data, ("bilibili", "max_duration"), None), 600 - ) - bilibili_max_file_size = _coerce_int( - _get_value(data, ("bilibili", "max_file_size"), None), 100 - ) - bilibili_oversize_strategy = _coerce_str( - _get_value(data, ("bilibili", "oversize_strategy"), None), "downgrade" - ) - if bilibili_oversize_strategy not in ("downgrade", "info"): - bilibili_oversize_strategy = "downgrade" - bilibili_danmaku_enabled = _coerce_bool( - _get_value(data, ("bilibili", "danmaku_enabled"), None), True - ) - bilibili_danmaku_batch_size = _coerce_int( - _get_value(data, ("bilibili", "danmaku_batch_size"), None), 100 - ) - if bilibili_danmaku_batch_size <= 0: - bilibili_danmaku_batch_size = 100 - bilibili_danmaku_max_count = _coerce_int( - _get_value(data, ("bilibili", "danmaku_max_count"), None), 0 - ) - if bilibili_danmaku_max_count < 0: - bilibili_danmaku_max_count = 0 - 
bilibili_auto_extract_group_ids = _coerce_int_list( - _get_value(data, ("bilibili", "auto_extract_group_ids"), None) - ) - bilibili_auto_extract_private_ids = _coerce_int_list( - _get_value(data, ("bilibili", "auto_extract_private_ids"), None) - ) - - # arXiv 配置 - arxiv_auto_extract_enabled = _coerce_bool( - _get_value(data, ("arxiv", "auto_extract_enabled"), None), False - ) - arxiv_max_file_size = _coerce_int( - _get_value(data, ("arxiv", "max_file_size"), None), 100 - ) - if arxiv_max_file_size < 0: - arxiv_max_file_size = 100 - arxiv_auto_extract_group_ids = _coerce_int_list( - _get_value(data, ("arxiv", "auto_extract_group_ids"), None) - ) - arxiv_auto_extract_private_ids = _coerce_int_list( - _get_value(data, ("arxiv", "auto_extract_private_ids"), None) - ) - arxiv_auto_extract_max_items = _coerce_int( - _get_value(data, ("arxiv", "auto_extract_max_items"), None), 5 - ) - if arxiv_auto_extract_max_items <= 0: - arxiv_auto_extract_max_items = 5 - if arxiv_auto_extract_max_items > 20: - arxiv_auto_extract_max_items = 20 - arxiv_author_preview_limit = _coerce_int( - _get_value(data, ("arxiv", "author_preview_limit"), None), 20 - ) - if arxiv_author_preview_limit <= 0: - arxiv_author_preview_limit = 20 - if arxiv_author_preview_limit > 100: - arxiv_author_preview_limit = 100 - arxiv_summary_preview_chars = _coerce_int( - _get_value(data, ("arxiv", "summary_preview_chars"), None), 1000 - ) - if arxiv_summary_preview_chars <= 0: - arxiv_summary_preview_chars = 1000 - if arxiv_summary_preview_chars > 8000: - arxiv_summary_preview_chars = 8000 - - # GitHub 配置 - github_auto_extract_enabled = _coerce_bool( - _get_value(data, ("github", "auto_extract_enabled"), None), False - ) - github_request_timeout_seconds = _coerce_float( - _get_value(data, ("github", "request_timeout_seconds"), None), 10.0 - ) - if github_request_timeout_seconds <= 0: - github_request_timeout_seconds = 10.0 - if github_request_timeout_seconds > 60.0: - github_request_timeout_seconds = 60.0 - 
github_auto_extract_group_ids = _coerce_int_list( - _get_value(data, ("github", "auto_extract_group_ids"), None) - ) - github_auto_extract_private_ids = _coerce_int_list( - _get_value(data, ("github", "auto_extract_private_ids"), None) - ) - github_auto_extract_max_items = _coerce_int( - _get_value(data, ("github", "auto_extract_max_items"), None), 3 - ) - if github_auto_extract_max_items <= 0: - github_auto_extract_max_items = 3 - if github_auto_extract_max_items > 10: - github_auto_extract_max_items = 10 - - # Code Delivery Agent 配置 - code_delivery_enabled = _coerce_bool( - _get_value(data, ("code_delivery", "enabled"), None), True - ) - code_delivery_task_root = _coerce_str( - _get_value(data, ("code_delivery", "task_root"), None), - "data/code_delivery", - ) - code_delivery_docker_image = _coerce_str( - _get_value(data, ("code_delivery", "docker_image"), None), - "ubuntu:24.04", - ) - code_delivery_container_name_prefix = _coerce_str( - _get_value(data, ("code_delivery", "container_name_prefix"), None), - "code_delivery_", - ) - code_delivery_container_name_suffix = _coerce_str( - _get_value(data, ("code_delivery", "container_name_suffix"), None), - "_runner", - ) - code_delivery_command_timeout = _coerce_int( - _get_value( - data, ("code_delivery", "default_command_timeout_seconds"), None - ), - 600, - ) - code_delivery_max_command_output = _coerce_int( - _get_value(data, ("code_delivery", "max_command_output_chars"), None), - 20000, - ) - code_delivery_default_archive_format = _coerce_str( - _get_value(data, ("code_delivery", "default_archive_format"), None), - "zip", - ) - if code_delivery_default_archive_format not in ("zip", "tar.gz"): - code_delivery_default_archive_format = "zip" - code_delivery_max_archive_size_mb = _coerce_int( - _get_value(data, ("code_delivery", "max_archive_size_mb"), None), 200 - ) - code_delivery_cleanup_on_finish = _coerce_bool( - _get_value(data, ("code_delivery", "cleanup_on_finish"), None), True - ) - 
code_delivery_cleanup_on_start = _coerce_bool( - _get_value(data, ("code_delivery", "cleanup_on_start"), None), True - ) - code_delivery_llm_max_retries = _coerce_int( - _get_value(data, ("code_delivery", "llm_max_retries_per_request"), None), - 5, - ) - code_delivery_notify_on_llm_failure = _coerce_bool( - _get_value(data, ("code_delivery", "notify_on_llm_failure"), None), - True, - ) - code_delivery_container_memory_limit = _coerce_str( - _get_value(data, ("code_delivery", "container_memory_limit"), None), - "", - ) - code_delivery_container_cpu_limit = _coerce_str( - _get_value(data, ("code_delivery", "container_cpu_limit"), None), - "", - ) - code_delivery_command_blacklist_raw = _get_value( - data, ("code_delivery", "command_blacklist"), None - ) - if isinstance(code_delivery_command_blacklist_raw, list): - code_delivery_command_blacklist = [ - str(x) for x in code_delivery_command_blacklist_raw - ] - else: - code_delivery_command_blacklist = [] - - # messages 工具集配置 - messages_send_text_file_max_size_kb = _coerce_int( - _get_value( - data, - ("messages", "send_text_file_max_size_kb"), - "MESSAGES_SEND_TEXT_FILE_MAX_SIZE_KB", - ), - 512, - ) - if messages_send_text_file_max_size_kb <= 0: - messages_send_text_file_max_size_kb = 512 - - messages_send_url_file_max_size_mb = _coerce_int( - _get_value( - data, - ("messages", "send_url_file_max_size_mb"), - "MESSAGES_SEND_URL_FILE_MAX_SIZE_MB", - ), - 100, - ) - if messages_send_url_file_max_size_mb <= 0: - messages_send_url_file_max_size_mb = 100 - - webui_settings = load_webui_settings(config_path) - api_config = _parse_api_config(data) - - cognitive = _parse_cognitive_config(data) - memes = _parse_memes_config(data) - message_batcher = _parse_message_batcher_config(data) - render_cache = _parse_render_cache_config(data) - naga = _parse_naga_config(data) - models_image_gen = _parse_image_gen_model_config(data) - models_image_edit = _parse_image_edit_model_config(data) - image_gen = _parse_image_gen_config(data) - - 
if strict: - _verify_required_fields( - bot_qq=bot_qq, - superadmin_qq=superadmin_qq, - onebot_ws_url=onebot_ws_url, - chat_model=chat_model, - vision_model=vision_model, - agent_model=agent_model, - knowledge_enabled=knowledge_enabled, - embedding_model=embedding_model, - ) - - _log_debug_info( - chat_model, - vision_model, - security_model, - naga_model, - agent_model, - summary_model, - grok_model, - ) - - return cls( - bot_qq=bot_qq, - superadmin_qq=superadmin_qq, - admin_qqs=admin_qqs, - access_mode=access_mode, - allowed_group_ids=allowed_group_ids, - blocked_group_ids=blocked_group_ids, - allowed_private_ids=allowed_private_ids, - blocked_private_ids=blocked_private_ids, - superadmin_bypass_allowlist=superadmin_bypass_allowlist, - superadmin_bypass_private_blacklist=superadmin_bypass_private_blacklist, - forward_proxy_qq=forward_proxy_qq, - process_every_message=process_every_message, - process_private_message=process_private_message, - process_poke_message=process_poke_message, - keyword_reply_enabled=keyword_reply_enabled, - repeat_enabled=repeat_enabled, - repeat_threshold=repeat_threshold, - repeat_cooldown_minutes=repeat_cooldown_minutes, - inverted_question_enabled=inverted_question_enabled, - context_recent_messages_limit=context_recent_messages_limit, - ai_request_max_retries=ai_request_max_retries, - missing_tool_call_retries=missing_tool_call_retries, - nagaagent_mode_enabled=nagaagent_mode_enabled, - onebot_ws_url=onebot_ws_url, - onebot_token=onebot_token, - chat_model=chat_model, - vision_model=vision_model, - security_model_enabled=security_model_enabled, - security_model=security_model, - naga_model=naga_model, - agent_model=agent_model, - historian_model=historian_model, - summary_model=summary_model, - summary_model_configured=summary_model_configured, - grok_model=grok_model, - model_pool_enabled=model_pool_enabled, - log_level=log_level, - log_file_path=log_file_path, - log_max_size=log_max_size_mb * 1024 * 1024, - 
log_backup_count=log_backup_count, - log_tty_enabled=log_tty_enabled, - log_thinking=log_thinking, - tools_dot_delimiter=tools_dot_delimiter, - tools_description_truncate_enabled=tools_description_truncate_enabled, - tools_description_max_len=tools_description_max_len, - tools_sanitize_verbose=tools_sanitize_verbose, - tools_description_preview_len=tools_description_preview_len, - easter_egg_agent_call_message_mode=easter_egg_agent_call_message_mode, - token_usage_max_size_mb=token_usage_max_size_mb, - token_usage_max_archives=token_usage_max_archives, - token_usage_max_total_mb=token_usage_max_total_mb, - token_usage_archive_prune_mode=token_usage_archive_prune_mode, - skills_hot_reload=skills_hot_reload, - history_max_records=history_max_records, - history_filtered_result_limit=history_filtered_result_limit, - history_search_scan_limit=history_search_scan_limit, - history_summary_fetch_limit=history_summary_fetch_limit, - history_summary_time_fetch_limit=history_summary_time_fetch_limit, - history_onebot_fetch_limit=history_onebot_fetch_limit, - history_group_analysis_limit=history_group_analysis_limit, - attachment_remote_download_max_size_mb=attachment_remote_download_max_size_mb, - attachment_cache_max_total_size_mb=attachment_cache_max_total_size_mb, - attachment_cache_max_records=attachment_cache_max_records, - attachment_cache_max_age_days=attachment_cache_max_age_days, - attachment_url_reference_max_records=attachment_url_reference_max_records, - attachment_url_max_length=attachment_url_max_length, - skills_hot_reload_interval=skills_hot_reload_interval, - skills_hot_reload_debounce=skills_hot_reload_debounce, - agent_intro_autogen_enabled=agent_intro_autogen_enabled, - agent_intro_autogen_queue_interval=agent_intro_autogen_queue_interval, - agent_intro_autogen_max_tokens=agent_intro_autogen_max_tokens, - agent_intro_hash_path=agent_intro_hash_path, - searxng_url=searxng_url, - grok_search_enabled=grok_search_enabled, - use_proxy=use_proxy, - 
http_proxy=http_proxy, - https_proxy=https_proxy, - network_request_timeout=network_request_timeout, - network_request_retries=network_request_retries, - render_browser_max_concurrency=render_browser_max_concurrency, - api_xxapi_base_url=api_xxapi_base_url, - api_xingzhige_base_url=api_xingzhige_base_url, - api_jkyai_base_url=api_jkyai_base_url, - api_seniverse_base_url=api_seniverse_base_url, - weather_api_key=weather_api_key, - xxapi_api_token=xxapi_api_token, - mcp_config_path=mcp_config_path, - prefetch_tools=prefetch_tools, - prefetch_tools_hide=prefetch_tools_hide, - webui_url=webui_settings.url, - webui_port=webui_settings.port, - webui_password=webui_settings.password, - api=api_config, - code_delivery_enabled=code_delivery_enabled, - code_delivery_task_root=code_delivery_task_root, - code_delivery_docker_image=code_delivery_docker_image, - code_delivery_container_name_prefix=code_delivery_container_name_prefix, - code_delivery_container_name_suffix=code_delivery_container_name_suffix, - code_delivery_command_timeout=code_delivery_command_timeout, - code_delivery_max_command_output=code_delivery_max_command_output, - code_delivery_default_archive_format=code_delivery_default_archive_format, - code_delivery_max_archive_size_mb=code_delivery_max_archive_size_mb, - code_delivery_cleanup_on_finish=code_delivery_cleanup_on_finish, - code_delivery_cleanup_on_start=code_delivery_cleanup_on_start, - code_delivery_llm_max_retries=code_delivery_llm_max_retries, - code_delivery_notify_on_llm_failure=code_delivery_notify_on_llm_failure, - code_delivery_container_memory_limit=code_delivery_container_memory_limit, - code_delivery_container_cpu_limit=code_delivery_container_cpu_limit, - code_delivery_command_blacklist=code_delivery_command_blacklist, - messages_send_text_file_max_size_kb=messages_send_text_file_max_size_kb, - messages_send_url_file_max_size_mb=messages_send_url_file_max_size_mb, - bilibili_auto_extract_enabled=bilibili_auto_extract_enabled, - 
bilibili_cookie=bilibili_cookie, - bilibili_prefer_quality=bilibili_prefer_quality, - bilibili_max_duration=bilibili_max_duration, - bilibili_max_file_size=bilibili_max_file_size, - bilibili_oversize_strategy=bilibili_oversize_strategy, - bilibili_danmaku_enabled=bilibili_danmaku_enabled, - bilibili_danmaku_batch_size=bilibili_danmaku_batch_size, - bilibili_danmaku_max_count=bilibili_danmaku_max_count, - bilibili_auto_extract_group_ids=bilibili_auto_extract_group_ids, - bilibili_auto_extract_private_ids=bilibili_auto_extract_private_ids, - arxiv_auto_extract_enabled=arxiv_auto_extract_enabled, - arxiv_max_file_size=arxiv_max_file_size, - arxiv_auto_extract_group_ids=arxiv_auto_extract_group_ids, - arxiv_auto_extract_private_ids=arxiv_auto_extract_private_ids, - arxiv_auto_extract_max_items=arxiv_auto_extract_max_items, - arxiv_author_preview_limit=arxiv_author_preview_limit, - arxiv_summary_preview_chars=arxiv_summary_preview_chars, - github_auto_extract_enabled=github_auto_extract_enabled, - github_request_timeout_seconds=github_request_timeout_seconds, - github_auto_extract_group_ids=github_auto_extract_group_ids, - github_auto_extract_private_ids=github_auto_extract_private_ids, - github_auto_extract_max_items=github_auto_extract_max_items, - embedding_model=embedding_model, - rerank_model=rerank_model, - knowledge_enabled=knowledge_enabled, - knowledge_base_dir=knowledge_base_dir, - knowledge_auto_scan=knowledge_auto_scan, - knowledge_auto_embed=knowledge_auto_embed, - knowledge_scan_interval=knowledge_scan_interval, - knowledge_embed_batch_size=knowledge_embed_batch_size, - knowledge_chunk_size=knowledge_chunk_size, - knowledge_chunk_overlap=knowledge_chunk_overlap, - knowledge_default_top_k=knowledge_default_top_k, - knowledge_enable_rerank=knowledge_enable_rerank, - knowledge_rerank_top_k=knowledge_rerank_top_k, - cognitive=cognitive, - memes=memes, - message_batcher=message_batcher, - render_cache=render_cache, - naga=naga, - image_gen=image_gen, - 
models_image_gen=models_image_gen, - models_image_edit=models_image_edit, - ) - - @property - def bilibili_sessdata(self) -> str: - """兼容旧字段名,等价于 bilibili_cookie。""" - return self.bilibili_cookie - - def allowlist_mode_enabled(self) -> bool: - """是否启用白名单限制模式。""" - - return self.access_mode in {"allowlist", "legacy"} and ( - bool(self.allowed_group_ids) or bool(self.allowed_private_ids) - ) - - def group_allowlist_enabled(self) -> bool: - """群聊白名单是否生效(显式 allowlist 模式按维度独立控制)。""" - - return bool(self.allowed_group_ids) - - def private_allowlist_enabled(self) -> bool: - """私聊白名单是否生效(显式 allowlist 模式按维度独立控制)。""" - - return bool(self.allowed_private_ids) - - def blacklist_mode_enabled(self) -> bool: - """是否启用黑名单限制模式。""" - - return self.access_mode in {"blacklist", "legacy"} and ( - bool(self.blocked_group_ids) or bool(self.blocked_private_ids) - ) - - def access_control_enabled(self) -> bool: - """是否启用访问控制。""" - - return self.allowlist_mode_enabled() or self.blacklist_mode_enabled() - - def group_access_denied_reason(self, group_id: int) -> str | None: - """群聊访问被拒绝原因。 - - 返回: - - "blacklist": 命中 access.blocked_group_ids - - "allowlist": allowlist 模式下不在 access.allowed_group_ids - - None: 允许访问 - """ - - normalized_group_id = int(group_id) - if self.access_mode == "off": - return None - if self.access_mode == "blacklist": - if normalized_group_id in self._blocked_group_ids_set: - return "blacklist" - return None - if self.access_mode == "legacy": - if normalized_group_id in self._blocked_group_ids_set: - return "blacklist" - if not self.allowlist_mode_enabled(): - return None - if normalized_group_id not in self._allowed_group_ids_set: - return "allowlist" - return None - if not self.group_allowlist_enabled(): - return None - if normalized_group_id not in self._allowed_group_ids_set: - return "allowlist" - return None - - def is_group_allowed(self, group_id: int) -> bool: - """群聊是否允许收发消息。""" - - return self.group_access_denied_reason(group_id) is None - - def 
private_access_denied_reason(self, user_id: int) -> str | None: - """私聊访问被拒绝原因。""" - - normalized_user_id = int(user_id) - if self.access_mode == "off": - return None - if self.access_mode == "blacklist": - if normalized_user_id not in self._blocked_private_ids_set: - return None - if ( - self.superadmin_bypass_private_blacklist - and normalized_user_id == int(self.superadmin_qq) - and self.superadmin_qq > 0 - ): - return None - return "blacklist" - if self.access_mode == "legacy": - if normalized_user_id in self._blocked_private_ids_set: - if ( - self.superadmin_bypass_private_blacklist - and normalized_user_id == int(self.superadmin_qq) - and self.superadmin_qq > 0 - ): - return None - return "blacklist" - if not self.allowlist_mode_enabled(): - return None - if ( - self.superadmin_bypass_allowlist - and normalized_user_id == int(self.superadmin_qq) - and self.superadmin_qq > 0 - ): - return None - if normalized_user_id not in self._allowed_private_ids_set: - return "allowlist" - return None - if not self.private_allowlist_enabled(): - return None - if ( - self.superadmin_bypass_allowlist - and normalized_user_id == int(self.superadmin_qq) - and self.superadmin_qq > 0 - ): - return None - if normalized_user_id not in self._allowed_private_ids_set: - return "allowlist" - return None - - def is_private_allowed(self, user_id: int) -> bool: - """私聊是否允许收发消息。""" - - return self.private_access_denied_reason(user_id) is None - - def is_bilibili_auto_extract_allowed_group(self, group_id: int) -> bool: - """群聊是否允许 bilibili 自动提取。""" - if self._bilibili_group_ids_set: - return int(group_id) in self._bilibili_group_ids_set - # 功能白名单为空时跟随全局 access 控制 - return self.is_group_allowed(group_id) - - def is_bilibili_auto_extract_allowed_private(self, user_id: int) -> bool: - """私聊是否允许 bilibili 自动提取。""" - if self._bilibili_private_ids_set: - return int(user_id) in self._bilibili_private_ids_set - # 功能白名单为空时跟随全局 access 控制 - return self.is_private_allowed(user_id) - - def 
is_arxiv_auto_extract_allowed_group(self, group_id: int) -> bool: - """群聊是否允许 arXiv 自动提取。""" - if self._arxiv_group_ids_set: - return int(group_id) in self._arxiv_group_ids_set - return self.is_group_allowed(group_id) - - def is_arxiv_auto_extract_allowed_private(self, user_id: int) -> bool: - """私聊是否允许 arXiv 自动提取。""" - if self._arxiv_private_ids_set: - return int(user_id) in self._arxiv_private_ids_set - return self.is_private_allowed(user_id) - - def is_github_auto_extract_allowed_group(self, group_id: int) -> bool: - """群聊是否允许 GitHub 仓库自动提取。""" - if self._github_group_ids_set: - return int(group_id) in self._github_group_ids_set - return self.is_group_allowed(group_id) - - def is_github_auto_extract_allowed_private(self, user_id: int) -> bool: - """私聊是否允许 GitHub 仓库自动提取。""" - if self._github_private_ids_set: - return int(user_id) in self._github_private_ids_set - return self.is_private_allowed(user_id) - - def should_process_group_message(self, is_at_bot: bool) -> bool: - """是否处理该条群消息。""" - - if self.process_every_message: - return True - return bool(is_at_bot) - - def should_process_private_message(self) -> bool: - """是否处理私聊消息回复。""" - - return bool(self.process_private_message) - - def should_process_poke_message(self) -> bool: - """是否处理拍一拍触发。""" - - return bool(self.process_poke_message) - - def get_context_recent_messages_limit(self) -> int: - """获取上下文最近历史消息条数上限。""" - - limit = int(self.context_recent_messages_limit) - if limit < 0: - return 0 - return limit - - def security_check_enabled(self) -> bool: - """是否启用安全模型检查。""" - - return bool(self.security_model_enabled) - - def update_from(self, new_config: "Config") -> dict[str, tuple[Any, Any]]: - changes: dict[str, tuple[Any, Any]] = {} - for field in fields(self): - name = field.name - old_value = getattr(self, name) - new_value = getattr(new_config, name) - if isinstance( - old_value, - ( - ChatModelConfig, - VisionModelConfig, - SecurityModelConfig, - AgentModelConfig, - GrokModelConfig, - ), - ): - 
changes.update(_update_dataclass(old_value, new_value, prefix=name)) - continue - if old_value != new_value: - setattr(self, name, new_value) - changes[name] = (old_value, new_value) - return changes - - def reload(self, strict: bool = False) -> dict[str, tuple[Any, Any]]: - new_config = Config.load(strict=strict) - return self.update_from(new_config) - - def add_admin(self, qq: int) -> bool: - if qq in self.admin_qqs: - return False - self.admin_qqs.append(qq) - local_admins = load_local_admins() - if qq not in local_admins: - local_admins.append(qq) - save_local_admins(local_admins) - return True - - def remove_admin(self, qq: int) -> bool: - if qq == self.superadmin_qq or qq not in self.admin_qqs: - return False - self.admin_qqs.remove(qq) - local_admins = load_local_admins() - if qq in local_admins: - local_admins.remove(qq) - save_local_admins(local_admins) - return True - - def is_superadmin(self, qq: int) -> bool: - return qq == self.superadmin_qq - - def is_admin(self, qq: int) -> bool: - return qq in self.admin_qqs diff --git a/src/Undefined/config/manager.py b/src/Undefined/config/manager.py index f2c4400b..f497f253 100644 --- a/src/Undefined/config/manager.py +++ b/src/Undefined/config/manager.py @@ -30,6 +30,10 @@ def load(self, strict: bool = True) -> Config: self._config = Config.load(config_path=self.config_path, strict=strict) return self._config + def replace(self, config: Config) -> None: + """替换缓存的配置实例(库嵌入 ``set_config`` 注入时使用)。""" + self._config = config + def reload(self, strict: bool = False) -> dict[str, tuple[Any, Any]]: if self._config is None: self._config = Config.load(config_path=self.config_path, strict=strict) diff --git a/src/Undefined/config/parsers/__init__.py b/src/Undefined/config/parsers/__init__.py new file mode 100644 index 00000000..faf5adeb --- /dev/null +++ b/src/Undefined/config/parsers/__init__.py @@ -0,0 +1,39 @@ +"""Model configuration parsers.""" + +# 模型配置解析:原始 dict → ChatModelConfig 等 dataclass +from .agent import 
_parse_agent_model_config +from .chat import _parse_chat_model_config +from .embedding import _parse_embedding_model_config, _parse_rerank_model_config +from .grok import _parse_grok_model_config +from .helpers import _log_debug_info, _merge_admins, _verify_required_fields +from .historian import _parse_historian_model_config +from .image import ( + _parse_image_edit_model_config, + _parse_image_gen_config, + _parse_image_gen_model_config, +) +from .naga import _parse_naga_model_config +from .pool import _parse_model_pool +from .security import _parse_security_model_config +from .summary import _parse_summary_model_config +from .vision import _parse_vision_model_config + +__all__ = [ + "_log_debug_info", + "_merge_admins", + "_parse_agent_model_config", + "_parse_chat_model_config", + "_parse_embedding_model_config", + "_parse_grok_model_config", + "_parse_historian_model_config", + "_parse_image_edit_model_config", + "_parse_image_gen_config", + "_parse_image_gen_model_config", + "_parse_model_pool", + "_parse_naga_model_config", + "_parse_rerank_model_config", + "_parse_security_model_config", + "_parse_summary_model_config", + "_parse_vision_model_config", + "_verify_required_fields", +] diff --git a/src/Undefined/config/parsers/agent.py b/src/Undefined/config/parsers/agent.py new file mode 100644 index 00000000..5170d181 --- /dev/null +++ b/src/Undefined/config/parsers/agent.py @@ -0,0 +1,163 @@ +"""Agent model parser.""" + +from __future__ import annotations + +# 模型配置解析:原始 dict → ChatModelConfig 等 dataclass + +import logging +from typing import Any + + +from ..coercers import ( + _coerce_bool, + _coerce_float, + _coerce_int, + _coerce_str, + _get_model_request_params, + _get_value, + _normalize_queue_interval, +) +from ..models import ( + AgentModelConfig, +) +from ..resolvers import ( + _resolve_api_mode, + _resolve_context_window_tokens, + _resolve_reasoning_effort, + _resolve_reasoning_effort_style, + _resolve_reasoning_content_replay, + 
_resolve_responses_force_stateless_replay, + _resolve_responses_tool_choice_compat, + _resolve_system_prompt_as_user, + _resolve_thinking_compat_flags, +) +from .pool import _parse_model_pool + +logger = logging.getLogger(__name__) + + +def _parse_agent_model_config(data: dict[str, Any]) -> AgentModelConfig: + queue_interval_seconds = _normalize_queue_interval( + _coerce_float( + _get_value( + data, + ("models", "agent", "queue_interval_seconds"), + "AGENT_MODEL_QUEUE_INTERVAL", + ), + 1.0, + ) + ) + thinking_include_budget, thinking_tool_call_compat = _resolve_thinking_compat_flags( + data=data, + model_name="agent", + include_budget_env_key="AGENT_MODEL_THINKING_INCLUDE_BUDGET", + tool_call_compat_env_key="AGENT_MODEL_THINKING_TOOL_CALL_COMPAT", + legacy_env_key="AGENT_MODEL_DEEPSEEK_NEW_COT_SUPPORT", + ) + api_mode = _resolve_api_mode(data, "agent", "AGENT_MODEL_API_MODE") + responses_tool_choice_compat = _resolve_responses_tool_choice_compat( + data, "agent", "AGENT_MODEL_RESPONSES_TOOL_CHOICE_COMPAT" + ) + responses_force_stateless_replay = _resolve_responses_force_stateless_replay( + data, "agent", "AGENT_MODEL_RESPONSES_FORCE_STATELESS_REPLAY" + ) + reasoning_content_replay = _resolve_reasoning_content_replay( + data, "agent", "AGENT_MODEL_REASONING_CONTENT_REPLAY" + ) + system_prompt_as_user = _resolve_system_prompt_as_user( + data, "agent", "AGENT_MODEL_SYSTEM_PROMPT_AS_USER" + ) + prompt_cache_enabled = _coerce_bool( + _get_value( + data, + ("models", "agent", "prompt_cache_enabled"), + "AGENT_MODEL_PROMPT_CACHE_ENABLED", + ), + True, + ) + reasoning_enabled = _coerce_bool( + _get_value( + data, + ("models", "agent", "reasoning_enabled"), + "AGENT_MODEL_REASONING_ENABLED", + ), + False, + ) + reasoning_effort = _resolve_reasoning_effort( + _get_value( + data, + ("models", "agent", "reasoning_effort"), + "AGENT_MODEL_REASONING_EFFORT", + ), + "medium", + ) + stream_enabled = _coerce_bool( + _get_value( + data, + ("models", "agent", "stream_enabled"), + 
"AGENT_MODEL_STREAM_ENABLED", + ), + False, + ) + context_window_tokens = _resolve_context_window_tokens( + data, "agent", "AGENT_MODEL_CONTEXT_WINDOW_TOKENS" + ) + config = AgentModelConfig( + api_url=_coerce_str( + _get_value(data, ("models", "agent", "api_url"), "AGENT_MODEL_API_URL"), + "", + ), + api_key=_coerce_str( + _get_value(data, ("models", "agent", "api_key"), "AGENT_MODEL_API_KEY"), + "", + ), + model_name=_coerce_str( + _get_value(data, ("models", "agent", "model_name"), "AGENT_MODEL_NAME"), + "", + ), + max_tokens=_coerce_int( + _get_value( + data, ("models", "agent", "max_tokens"), "AGENT_MODEL_MAX_TOKENS" + ), + 4096, + ), + context_window_tokens=context_window_tokens, + queue_interval_seconds=queue_interval_seconds, + api_mode=api_mode, + thinking_enabled=_coerce_bool( + _get_value( + data, + ("models", "agent", "thinking_enabled"), + "AGENT_MODEL_THINKING_ENABLED", + ), + False, + ), + thinking_budget_tokens=_coerce_int( + _get_value( + data, + ("models", "agent", "thinking_budget_tokens"), + "AGENT_MODEL_THINKING_BUDGET_TOKENS", + ), + 0, + ), + thinking_include_budget=thinking_include_budget, + reasoning_effort_style=_resolve_reasoning_effort_style( + _get_value( + data, + ("models", "agent", "reasoning_effort_style"), + "AGENT_MODEL_REASONING_EFFORT_STYLE", + ), + ), + thinking_tool_call_compat=thinking_tool_call_compat, + reasoning_content_replay=reasoning_content_replay, + system_prompt_as_user=system_prompt_as_user, + responses_tool_choice_compat=responses_tool_choice_compat, + responses_force_stateless_replay=responses_force_stateless_replay, + prompt_cache_enabled=prompt_cache_enabled, + reasoning_enabled=reasoning_enabled, + reasoning_effort=reasoning_effort, + stream_enabled=stream_enabled, + request_params=_get_model_request_params(data, "agent"), + ) + config.pool = _parse_model_pool(data, "agent", config) + return config diff --git a/src/Undefined/config/parsers/chat.py b/src/Undefined/config/parsers/chat.py new file mode 100644 
index 00000000..6f1028fa --- /dev/null +++ b/src/Undefined/config/parsers/chat.py @@ -0,0 +1,165 @@ +"""Chat model parser.""" + +from __future__ import annotations + +# 模型配置解析:原始 dict → ChatModelConfig 等 dataclass + +import logging +from typing import Any + + +from ..coercers import ( + _coerce_bool, + _coerce_float, + _coerce_int, + _coerce_str, + _get_model_request_params, + _get_value, + _normalize_queue_interval, +) +from ..models import ( + ChatModelConfig, +) +from ..resolvers import ( + _resolve_api_mode, + _resolve_context_window_tokens, + _resolve_reasoning_effort, + _resolve_reasoning_effort_style, + _resolve_reasoning_content_replay, + _resolve_responses_force_stateless_replay, + _resolve_responses_tool_choice_compat, + _resolve_system_prompt_as_user, + _resolve_thinking_compat_flags, +) +from .pool import _parse_model_pool + +logger = logging.getLogger(__name__) + + +# 解析 [models.chat]:主对话模型 API、thinking/reasoning、队列间隔与模型池 +def _parse_chat_model_config(data: dict[str, Any]) -> ChatModelConfig: + # 该模型独立的发车间隔(秒),0=立即发车 + queue_interval_seconds = _normalize_queue_interval( + _coerce_float( + _get_value( + data, + ("models", "chat", "queue_interval_seconds"), + "CHAT_MODEL_QUEUE_INTERVAL", + ), + 1.0, + ) + ) + # DeepSeek/兼容模型的 thinking 预算与 tool_call 互斥开关 + thinking_include_budget, thinking_tool_call_compat = _resolve_thinking_compat_flags( + data=data, + model_name="chat", + include_budget_env_key="CHAT_MODEL_THINKING_INCLUDE_BUDGET", + tool_call_compat_env_key="CHAT_MODEL_THINKING_TOOL_CALL_COMPAT", + legacy_env_key="CHAT_MODEL_DEEPSEEK_NEW_COT_SUPPORT", + ) + # OpenAI 兼容层:chat_completions / responses 及 reasoning 回放策略 + api_mode = _resolve_api_mode(data, "chat", "CHAT_MODEL_API_MODE") + responses_tool_choice_compat = _resolve_responses_tool_choice_compat( + data, "chat", "CHAT_MODEL_RESPONSES_TOOL_CHOICE_COMPAT" + ) + responses_force_stateless_replay = _resolve_responses_force_stateless_replay( + data, "chat", 
"CHAT_MODEL_RESPONSES_FORCE_STATELESS_REPLAY" + ) + reasoning_content_replay = _resolve_reasoning_content_replay( + data, "chat", "CHAT_MODEL_REASONING_CONTENT_REPLAY" + ) + system_prompt_as_user = _resolve_system_prompt_as_user( + data, "chat", "CHAT_MODEL_SYSTEM_PROMPT_AS_USER" + ) + prompt_cache_enabled = _coerce_bool( + _get_value( + data, + ("models", "chat", "prompt_cache_enabled"), + "CHAT_MODEL_PROMPT_CACHE_ENABLED", + ), + True, + ) + reasoning_enabled = _coerce_bool( + _get_value( + data, + ("models", "chat", "reasoning_enabled"), + "CHAT_MODEL_REASONING_ENABLED", + ), + False, + ) + reasoning_effort = _resolve_reasoning_effort( + _get_value( + data, + ("models", "chat", "reasoning_effort"), + "CHAT_MODEL_REASONING_EFFORT", + ), + "medium", + ) + stream_enabled = _coerce_bool( + _get_value( + data, + ("models", "chat", "stream_enabled"), + "CHAT_MODEL_STREAM_ENABLED", + ), + False, + ) + context_window_tokens = _resolve_context_window_tokens( + data, "chat", "CHAT_MODEL_CONTEXT_WINDOW_TOKENS" + ) + config = ChatModelConfig( + api_url=_coerce_str( + _get_value(data, ("models", "chat", "api_url"), "CHAT_MODEL_API_URL"), + "", + ), + api_key=_coerce_str( + _get_value(data, ("models", "chat", "api_key"), "CHAT_MODEL_API_KEY"), + "", + ), + model_name=_coerce_str( + _get_value(data, ("models", "chat", "model_name"), "CHAT_MODEL_NAME"), + "", + ), + context_window_tokens=context_window_tokens, + max_tokens=_coerce_int( + _get_value(data, ("models", "chat", "max_tokens"), "CHAT_MODEL_MAX_TOKENS"), + 8192, + ), + queue_interval_seconds=queue_interval_seconds, + api_mode=api_mode, + thinking_enabled=_coerce_bool( + _get_value( + data, + ("models", "chat", "thinking_enabled"), + "CHAT_MODEL_THINKING_ENABLED", + ), + False, + ), + thinking_budget_tokens=_coerce_int( + _get_value( + data, + ("models", "chat", "thinking_budget_tokens"), + "CHAT_MODEL_THINKING_BUDGET_TOKENS", + ), + 20000, + ), + thinking_include_budget=thinking_include_budget, + 
reasoning_effort_style=_resolve_reasoning_effort_style( + _get_value( + data, + ("models", "chat", "reasoning_effort_style"), + "CHAT_MODEL_REASONING_EFFORT_STYLE", + ), + ), + thinking_tool_call_compat=thinking_tool_call_compat, + reasoning_content_replay=reasoning_content_replay, + system_prompt_as_user=system_prompt_as_user, + responses_tool_choice_compat=responses_tool_choice_compat, + responses_force_stateless_replay=responses_force_stateless_replay, + prompt_cache_enabled=prompt_cache_enabled, + reasoning_enabled=reasoning_enabled, + reasoning_effort=reasoning_effort, + stream_enabled=stream_enabled, + request_params=_get_model_request_params(data, "chat"), + ) + config.pool = _parse_model_pool(data, "chat", config) + return config diff --git a/src/Undefined/config/parsers/embedding.py b/src/Undefined/config/parsers/embedding.py new file mode 100644 index 00000000..8c5f3aa7 --- /dev/null +++ b/src/Undefined/config/parsers/embedding.py @@ -0,0 +1,106 @@ +"""Embedding model parser.""" + +from __future__ import annotations + +# 模型配置解析:原始 dict → ChatModelConfig 等 dataclass + +import logging +from typing import Any + + +from ..coercers import ( + _coerce_float, + _coerce_int, + _coerce_str, + _get_model_request_params, + _get_value, + _normalize_queue_interval, +) +from ..models import ( + EmbeddingModelConfig, + RerankModelConfig, +) +from ..resolvers import ( + _resolve_context_window_tokens, +) + +logger = logging.getLogger(__name__) + + +def _parse_embedding_model_config(data: dict[str, Any]) -> EmbeddingModelConfig: + return EmbeddingModelConfig( + api_url=_coerce_str( + _get_value( + data, ("models", "embedding", "api_url"), "EMBEDDING_MODEL_API_URL" + ), + "", + ), + api_key=_coerce_str( + _get_value( + data, ("models", "embedding", "api_key"), "EMBEDDING_MODEL_API_KEY" + ), + "", + ), + model_name=_coerce_str( + _get_value( + data, ("models", "embedding", "model_name"), "EMBEDDING_MODEL_NAME" + ), + "", + ), + 
queue_interval_seconds=_normalize_queue_interval( + _coerce_float( + _get_value( + data, ("models", "embedding", "queue_interval_seconds"), None + ), + 0.0, + ), + 0.0, + ), + dimensions=_coerce_int( + _get_value(data, ("models", "embedding", "dimensions"), None), 0 + ) + or None, + query_instruction=_coerce_str( + _get_value(data, ("models", "embedding", "query_instruction"), None), "" + ), + context_window_tokens=_resolve_context_window_tokens( + data, "embedding", "EMBEDDING_MODEL_CONTEXT_WINDOW_TOKENS" + ), + document_instruction=_coerce_str( + _get_value(data, ("models", "embedding", "document_instruction"), None), + "", + ), + request_params=_get_model_request_params(data, "embedding"), + ) + + +def _parse_rerank_model_config(data: dict[str, Any]) -> RerankModelConfig: + queue_interval_seconds = _normalize_queue_interval( + _coerce_float( + _get_value(data, ("models", "rerank", "queue_interval_seconds"), None), + 0.0, + ), + 0.0, + ) + return RerankModelConfig( + api_url=_coerce_str( + _get_value(data, ("models", "rerank", "api_url"), "RERANK_MODEL_API_URL"), + "", + ), + api_key=_coerce_str( + _get_value(data, ("models", "rerank", "api_key"), "RERANK_MODEL_API_KEY"), + "", + ), + model_name=_coerce_str( + _get_value(data, ("models", "rerank", "model_name"), "RERANK_MODEL_NAME"), + "", + ), + queue_interval_seconds=queue_interval_seconds, + context_window_tokens=_resolve_context_window_tokens( + data, "rerank", "RERANK_MODEL_CONTEXT_WINDOW_TOKENS" + ), + query_instruction=_coerce_str( + _get_value(data, ("models", "rerank", "query_instruction"), None), "" + ), + request_params=_get_model_request_params(data, "rerank"), + ) diff --git a/src/Undefined/config/parsers/grok.py b/src/Undefined/config/parsers/grok.py new file mode 100644 index 00000000..efcbbadb --- /dev/null +++ b/src/Undefined/config/parsers/grok.py @@ -0,0 +1,129 @@ +"""Grok model parser.""" + +from __future__ import annotations + +# 模型配置解析:原始 dict → ChatModelConfig 等 dataclass + +import logging 
+from typing import Any + + +from ..coercers import ( + _coerce_bool, + _coerce_float, + _coerce_int, + _coerce_str, + _get_model_request_params, + _get_value, + _normalize_queue_interval, +) +from ..models import ( + GrokModelConfig, +) +from ..resolvers import ( + _resolve_context_window_tokens, + _resolve_reasoning_effort, + _resolve_reasoning_effort_style, +) + +logger = logging.getLogger(__name__) + + +def _parse_grok_model_config(data: dict[str, Any]) -> GrokModelConfig: + queue_interval_seconds = _normalize_queue_interval( + _coerce_float( + _get_value( + data, + ("models", "grok", "queue_interval_seconds"), + "GROK_MODEL_QUEUE_INTERVAL", + ), + 1.0, + ) + ) + context_window_tokens = _resolve_context_window_tokens( + data, "grok", "GROK_MODEL_CONTEXT_WINDOW_TOKENS" + ) + return GrokModelConfig( + api_url=_coerce_str( + _get_value(data, ("models", "grok", "api_url"), "GROK_MODEL_API_URL"), + "", + ), + api_key=_coerce_str( + _get_value(data, ("models", "grok", "api_key"), "GROK_MODEL_API_KEY"), + "", + ), + model_name=_coerce_str( + _get_value(data, ("models", "grok", "model_name"), "GROK_MODEL_NAME"), + "", + ), + max_tokens=_coerce_int( + _get_value(data, ("models", "grok", "max_tokens"), "GROK_MODEL_MAX_TOKENS"), + 8192, + ), + context_window_tokens=context_window_tokens, + queue_interval_seconds=queue_interval_seconds, + thinking_enabled=_coerce_bool( + _get_value( + data, + ("models", "grok", "thinking_enabled"), + "GROK_MODEL_THINKING_ENABLED", + ), + False, + ), + thinking_budget_tokens=_coerce_int( + _get_value( + data, + ("models", "grok", "thinking_budget_tokens"), + "GROK_MODEL_THINKING_BUDGET_TOKENS", + ), + 20000, + ), + thinking_include_budget=_coerce_bool( + _get_value( + data, + ("models", "grok", "thinking_include_budget"), + "GROK_MODEL_THINKING_INCLUDE_BUDGET", + ), + True, + ), + reasoning_effort_style=_resolve_reasoning_effort_style( + _get_value( + data, + ("models", "grok", "reasoning_effort_style"), + 
"GROK_MODEL_REASONING_EFFORT_STYLE", + ), + ), + prompt_cache_enabled=_coerce_bool( + _get_value( + data, + ("models", "grok", "prompt_cache_enabled"), + "GROK_MODEL_PROMPT_CACHE_ENABLED", + ), + True, + ), + reasoning_enabled=_coerce_bool( + _get_value( + data, + ("models", "grok", "reasoning_enabled"), + "GROK_MODEL_REASONING_ENABLED", + ), + False, + ), + reasoning_effort=_resolve_reasoning_effort( + _get_value( + data, + ("models", "grok", "reasoning_effort"), + "GROK_MODEL_REASONING_EFFORT", + ), + "medium", + ), + stream_enabled=_coerce_bool( + _get_value( + data, + ("models", "grok", "stream_enabled"), + "GROK_MODEL_STREAM_ENABLED", + ), + False, + ), + request_params=_get_model_request_params(data, "grok"), + ) diff --git a/src/Undefined/config/parsers/helpers.py b/src/Undefined/config/parsers/helpers.py new file mode 100644 index 00000000..19c82989 --- /dev/null +++ b/src/Undefined/config/parsers/helpers.py @@ -0,0 +1,122 @@ +"""Admin merge, validation, and debug helpers for model parsers.""" + +from __future__ import annotations + +# 模型配置解析:原始 dict → ChatModelConfig 等 dataclass + +import logging + +from ..admin import load_local_admins +from ..models import ( + AgentModelConfig, + ChatModelConfig, + EmbeddingModelConfig, + GrokModelConfig, + SecurityModelConfig, + VisionModelConfig, +) + +# 合并多来源配置 +logger = logging.getLogger(__name__) + + +# 合并多来源配置 +def _merge_admins(superadmin_qq: int, admin_qqs: list[int]) -> tuple[int, list[int]]: + # admins.json 与 config.toml 的 admin_qq 合并,去重后超管必在列表中 + local_admins = load_local_admins() + all_admins = list(set(admin_qqs + local_admins)) + if superadmin_qq and superadmin_qq not in all_admins: + all_admins.append(superadmin_qq) + # 校验必填字段 + return superadmin_qq, all_admins + + +# 校验必填字段 +def _verify_required_fields( + bot_qq: int, + superadmin_qq: int, + onebot_ws_url: str, + chat_model: ChatModelConfig, + vision_model: VisionModelConfig, + agent_model: AgentModelConfig, + knowledge_enabled: bool, + embedding_model: 
EmbeddingModelConfig, +) -> None: + missing: list[str] = [] + if bot_qq <= 0: + missing.append("core.bot_qq") + if superadmin_qq <= 0: + missing.append("core.superadmin_qq") + if not onebot_ws_url: + missing.append("onebot.ws_url") + if not chat_model.api_url: + missing.append("models.chat.api_url") + if not chat_model.api_key: + missing.append("models.chat.api_key") + if not chat_model.model_name: + missing.append("models.chat.model_name") + if not vision_model.api_url: + missing.append("models.vision.api_url") + if not vision_model.api_key: + missing.append("models.vision.api_key") + if not vision_model.model_name: + missing.append("models.vision.model_name") + if not agent_model.api_url: + missing.append("models.agent.api_url") + if not agent_model.api_key: + missing.append("models.agent.api_key") + if not agent_model.model_name: + missing.append("models.agent.model_name") + if knowledge_enabled: + if not embedding_model.api_url: + missing.append("models.embedding.api_url") + if not embedding_model.model_name: + missing.append("models.embedding.model_name") + if missing: + # 输出调试/诊断日志 + raise ValueError(f"缺少必需配置: {', '.join(missing)}") + + +# 输出调试/诊断日志 +def _log_debug_info( + chat_model: ChatModelConfig, + vision_model: VisionModelConfig, + security_model: SecurityModelConfig, + naga_model: SecurityModelConfig, + agent_model: AgentModelConfig, + summary_model: AgentModelConfig, + grok_model: GrokModelConfig, +) -> None: + configs: list[ + tuple[ + str, + ChatModelConfig + | VisionModelConfig + | SecurityModelConfig + | AgentModelConfig + | GrokModelConfig, + ] + ] = [ + ("chat", chat_model), + ("vision", vision_model), + ("security", security_model), + ("naga", naga_model), + ("agent", agent_model), + ("summary", summary_model), + ("grok", grok_model), + ] + for name, cfg in configs: + logger.debug( + "[配置] %s_model=%s api_url=%s api_key_set=%s api_mode=%s thinking=%s reasoning=%s/%s cot_compat=%s responses_tool_choice_compat=%s 
responses_force_stateless_replay=%s", + name, + cfg.model_name, + cfg.api_url, + bool(cfg.api_key), + getattr(cfg, "api_mode", "chat_completions"), + cfg.thinking_enabled, + getattr(cfg, "reasoning_enabled", False), + getattr(cfg, "reasoning_effort", "medium"), + getattr(cfg, "thinking_tool_call_compat", False), + getattr(cfg, "responses_tool_choice_compat", False), + getattr(cfg, "responses_force_stateless_replay", False), + ) diff --git a/src/Undefined/config/parsers/historian.py b/src/Undefined/config/parsers/historian.py new file mode 100644 index 00000000..74e06bf4 --- /dev/null +++ b/src/Undefined/config/parsers/historian.py @@ -0,0 +1,144 @@ +"""Historian model parser.""" + +from __future__ import annotations + +# 模型配置解析:原始 dict → ChatModelConfig 等 dataclass + +import logging +from typing import Any + +from Undefined.utils.request_params import merge_request_params + +from ..coercers import ( + _coerce_bool, + _coerce_float, + _coerce_int, + _coerce_str, + _get_value, + _normalize_queue_interval, +) +from ..models import ( + AgentModelConfig, +) +from ..resolvers import ( + _resolve_api_mode, + _resolve_reasoning_effort, + _resolve_reasoning_effort_style, + _resolve_responses_force_stateless_replay, + _resolve_responses_tool_choice_compat, + _resolve_thinking_compat_flags, +) + +logger = logging.getLogger(__name__) + + +def _parse_historian_model_config( + data: dict[str, Any], fallback: AgentModelConfig +) -> AgentModelConfig: + h = data.get("models", {}).get("historian", {}) + if not isinstance(h, dict) or not h: + return fallback + queue_interval_seconds = _coerce_float( + h.get("queue_interval_seconds"), fallback.queue_interval_seconds + ) + queue_interval_seconds = _normalize_queue_interval( + queue_interval_seconds, fallback.queue_interval_seconds + ) + thinking_include_budget, thinking_tool_call_compat = _resolve_thinking_compat_flags( + data={"models": {"historian": h}}, + model_name="historian", + 
include_budget_env_key="HISTORIAN_MODEL_THINKING_INCLUDE_BUDGET", + tool_call_compat_env_key="HISTORIAN_MODEL_THINKING_TOOL_CALL_COMPAT", + legacy_env_key="HISTORIAN_MODEL_DEEPSEEK_NEW_COT_SUPPORT", + ) + api_mode = _resolve_api_mode( + {"models": {"historian": h}}, + "historian", + "HISTORIAN_MODEL_API_MODE", + fallback.api_mode, + ) + responses_tool_choice_compat = _resolve_responses_tool_choice_compat( + {"models": {"historian": h}}, + "historian", + "HISTORIAN_MODEL_RESPONSES_TOOL_CHOICE_COMPAT", + fallback.responses_tool_choice_compat, + ) + responses_force_stateless_replay = _resolve_responses_force_stateless_replay( + {"models": {"historian": h}}, + "historian", + "HISTORIAN_MODEL_RESPONSES_FORCE_STATELESS_REPLAY", + fallback.responses_force_stateless_replay, + ) + prompt_cache_enabled = _coerce_bool( + _get_value( + {"models": {"historian": h}}, + ("models", "historian", "prompt_cache_enabled"), + "HISTORIAN_MODEL_PROMPT_CACHE_ENABLED", + ), + fallback.prompt_cache_enabled, + ) + context_window_tokens = _coerce_int( + h.get("context_window_tokens"), fallback.context_window_tokens + ) + return AgentModelConfig( + api_url=_coerce_str(h.get("api_url"), fallback.api_url), + api_key=_coerce_str(h.get("api_key"), fallback.api_key), + model_name=_coerce_str(h.get("model_name"), fallback.model_name), + max_tokens=_coerce_int(h.get("max_tokens"), fallback.max_tokens), + context_window_tokens=context_window_tokens, + queue_interval_seconds=queue_interval_seconds, + api_mode=api_mode, + thinking_enabled=_coerce_bool( + h.get("thinking_enabled"), fallback.thinking_enabled + ), + thinking_budget_tokens=_coerce_int( + h.get("thinking_budget_tokens"), fallback.thinking_budget_tokens + ), + thinking_include_budget=thinking_include_budget, + reasoning_effort_style=_resolve_reasoning_effort_style( + _get_value( + {"models": {"historian": h}}, + ("models", "historian", "reasoning_effort_style"), + "HISTORIAN_MODEL_REASONING_EFFORT_STYLE", + ), + 
fallback.reasoning_effort_style, + ), + thinking_tool_call_compat=thinking_tool_call_compat, + reasoning_content_replay=_coerce_bool( + h.get("reasoning_content_replay"), fallback.reasoning_content_replay + ), + system_prompt_as_user=_coerce_bool( + h.get("system_prompt_as_user"), fallback.system_prompt_as_user + ), + responses_tool_choice_compat=responses_tool_choice_compat, + responses_force_stateless_replay=responses_force_stateless_replay, + prompt_cache_enabled=prompt_cache_enabled, + reasoning_enabled=_coerce_bool( + _get_value( + {"models": {"historian": h}}, + ("models", "historian", "reasoning_enabled"), + "HISTORIAN_MODEL_REASONING_ENABLED", + ), + fallback.reasoning_enabled, + ), + reasoning_effort=_resolve_reasoning_effort( + _get_value( + {"models": {"historian": h}}, + ("models", "historian", "reasoning_effort"), + "HISTORIAN_MODEL_REASONING_EFFORT", + ), + fallback.reasoning_effort, + ), + stream_enabled=_coerce_bool( + _get_value( + {"models": {"historian": h}}, + ("models", "historian", "stream_enabled"), + "HISTORIAN_MODEL_STREAM_ENABLED", + ), + fallback.stream_enabled, + ), + request_params=merge_request_params( + fallback.request_params, + h.get("request_params"), + ), + ) diff --git a/src/Undefined/config/parsers/image.py b/src/Undefined/config/parsers/image.py new file mode 100644 index 00000000..cccf8460 --- /dev/null +++ b/src/Undefined/config/parsers/image.py @@ -0,0 +1,120 @@ +"""Image model parser.""" + +from __future__ import annotations + +# 模型配置解析:原始 dict → ChatModelConfig 等 dataclass + +import logging +from typing import Any + + +from ..coercers import ( + _coerce_float, + _coerce_int, + _coerce_str, + _get_model_request_params, + _get_value, +) +from ..models import ( + ImageGenConfig, + ImageGenModelConfig, +) + +logger = logging.getLogger(__name__) + + +def _parse_image_gen_model_config(data: dict[str, Any]) -> ImageGenModelConfig: + """解析 [models.image_gen] 生图模型配置""" + return ImageGenModelConfig( + api_url=_coerce_str( + 
def _parse_image_gen_config(data: dict[str, Any]) -> ImageGenConfig:
    """Parse the ``[image_gen]`` image-generation tool settings.

    Only ``provider`` is looked up with an environment key
    (``IMAGE_GEN_PROVIDER``); the remaining fields pass ``None`` as the
    environment key to ``_get_value``.

    Args:
        data: Raw configuration mapping handed to ``_get_value``.

    Returns:
        An ``ImageGenConfig`` built from the coerced values.
    """

    def raw(field: str, env_key: str | None = None) -> Any:
        # All settings of this tool live directly under the "image_gen" table.
        return _get_value(data, ("image_gen", field), env_key)

    provider = _coerce_str(raw("provider", "IMAGE_GEN_PROVIDER"), "xingzhige")
    xingzhige_size = _coerce_str(raw("xingzhige_size"), "1:1")
    openai_size = _coerce_str(raw("openai_size"), "")
    openai_quality = _coerce_str(raw("openai_quality"), "")
    openai_style = _coerce_str(raw("openai_style"), "")
    openai_timeout = _coerce_float(raw("openai_timeout"), 120.0)
    return ImageGenConfig(
        provider=provider,
        xingzhige_size=xingzhige_size,
        openai_size=openai_size,
        openai_quality=openai_quality,
        openai_style=openai_style,
        openai_timeout=openai_timeout,
    )
tool_call_compat_env_key="NAGA_MODEL_THINKING_TOOL_CALL_COMPAT", + legacy_env_key="NAGA_MODEL_DEEPSEEK_NEW_COT_SUPPORT", + ) + api_mode = _resolve_api_mode(data, "naga", "NAGA_MODEL_API_MODE") + responses_tool_choice_compat = _resolve_responses_tool_choice_compat( + data, "naga", "NAGA_MODEL_RESPONSES_TOOL_CHOICE_COMPAT" + ) + responses_force_stateless_replay = _resolve_responses_force_stateless_replay( + data, "naga", "NAGA_MODEL_RESPONSES_FORCE_STATELESS_REPLAY" + ) + reasoning_content_replay = _resolve_reasoning_content_replay( + data, + "naga", + "NAGA_MODEL_REASONING_CONTENT_REPLAY", + default=security_model.reasoning_content_replay, + ) + system_prompt_as_user = _resolve_system_prompt_as_user( + data, + "naga", + "NAGA_MODEL_SYSTEM_PROMPT_AS_USER", + default=security_model.system_prompt_as_user, + ) + prompt_cache_enabled = _coerce_bool( + _get_value( + data, + ("models", "naga", "prompt_cache_enabled"), + "NAGA_MODEL_PROMPT_CACHE_ENABLED", + ), + getattr(security_model, "prompt_cache_enabled", True), + ) + reasoning_enabled = _coerce_bool( + _get_value( + data, + ("models", "naga", "reasoning_enabled"), + "NAGA_MODEL_REASONING_ENABLED", + ), + getattr(security_model, "reasoning_enabled", False), + ) + reasoning_effort = _resolve_reasoning_effort( + _get_value( + data, + ("models", "naga", "reasoning_effort"), + "NAGA_MODEL_REASONING_EFFORT", + ), + getattr(security_model, "reasoning_effort", "medium"), + ) + stream_enabled = _coerce_bool( + _get_value( + data, + ("models", "naga", "stream_enabled"), + "NAGA_MODEL_STREAM_ENABLED", + ), + getattr(security_model, "stream_enabled", False), + ) + + if api_url and api_key and model_name: + context_window_tokens = _resolve_context_window_tokens( + data, + "naga", + "NAGA_MODEL_CONTEXT_WINDOW_TOKENS", + default=security_model.context_window_tokens, + ) + return SecurityModelConfig( + api_url=api_url, + api_key=api_key, + model_name=model_name, + max_tokens=_coerce_int( + _get_value( + data, + ("models", "naga", 
"max_tokens"), + "NAGA_MODEL_MAX_TOKENS", + ), + 160, + ), + context_window_tokens=context_window_tokens, + queue_interval_seconds=queue_interval_seconds, + api_mode=api_mode, + thinking_enabled=_coerce_bool( + _get_value( + data, + ("models", "naga", "thinking_enabled"), + "NAGA_MODEL_THINKING_ENABLED", + ), + False, + ), + thinking_budget_tokens=_coerce_int( + _get_value( + data, + ("models", "naga", "thinking_budget_tokens"), + "NAGA_MODEL_THINKING_BUDGET_TOKENS", + ), + 0, + ), + thinking_include_budget=thinking_include_budget, + reasoning_effort_style=_resolve_reasoning_effort_style( + _get_value( + data, + ("models", "naga", "reasoning_effort_style"), + "NAGA_MODEL_REASONING_EFFORT_STYLE", + ), + ), + thinking_tool_call_compat=thinking_tool_call_compat, + reasoning_content_replay=reasoning_content_replay, + system_prompt_as_user=system_prompt_as_user, + responses_tool_choice_compat=responses_tool_choice_compat, + responses_force_stateless_replay=responses_force_stateless_replay, + prompt_cache_enabled=prompt_cache_enabled, + reasoning_enabled=reasoning_enabled, + reasoning_effort=reasoning_effort, + stream_enabled=stream_enabled, + request_params=_get_model_request_params(data, "naga"), + ) + + logger.info( + "未配置 Naga 审核模型,将使用已解析的安全模型配置作为后备(安全模型本身可能已回退)" + ) + return SecurityModelConfig( + api_url=security_model.api_url, + api_key=security_model.api_key, + model_name=security_model.model_name, + max_tokens=security_model.max_tokens, + context_window_tokens=security_model.context_window_tokens, + queue_interval_seconds=security_model.queue_interval_seconds, + api_mode=security_model.api_mode, + thinking_enabled=security_model.thinking_enabled, + thinking_budget_tokens=security_model.thinking_budget_tokens, + thinking_include_budget=security_model.thinking_include_budget, + reasoning_effort_style=security_model.reasoning_effort_style, + thinking_tool_call_compat=security_model.thinking_tool_call_compat, + 
reasoning_content_replay=security_model.reasoning_content_replay, + system_prompt_as_user=security_model.system_prompt_as_user, + responses_tool_choice_compat=security_model.responses_tool_choice_compat, + responses_force_stateless_replay=security_model.responses_force_stateless_replay, + prompt_cache_enabled=security_model.prompt_cache_enabled, + reasoning_enabled=security_model.reasoning_enabled, + reasoning_effort=security_model.reasoning_effort, + stream_enabled=security_model.stream_enabled, + request_params=merge_request_params(security_model.request_params), + ) diff --git a/src/Undefined/config/parsers/pool.py b/src/Undefined/config/parsers/pool.py new file mode 100644 index 00000000..72abcce4 --- /dev/null +++ b/src/Undefined/config/parsers/pool.py @@ -0,0 +1,142 @@ +"""Model pool parser.""" + +from __future__ import annotations + +# 模型配置解析:原始 dict → ChatModelConfig 等 dataclass + +import logging +from typing import Any + +from Undefined.utils.request_params import merge_request_params + +from ..coercers import ( + _coerce_bool, + _coerce_float, + _coerce_int, + _coerce_str, + _normalize_queue_interval, + _VALID_API_MODES, +) +from ..models import AgentModelConfig, ChatModelConfig, ModelPool, ModelPoolEntry +from ..resolvers import ( + _resolve_reasoning_effort, + _resolve_reasoning_effort_style, +) + +logger = logging.getLogger(__name__) + + +def _parse_model_pool( + data: dict[str, Any], + model_section: str, + primary_config: ChatModelConfig | AgentModelConfig, +) -> ModelPool | None: + """解析模型池配置,缺省字段继承 primary_config""" + pool_data = data.get("models", {}).get(model_section, {}).get("pool") + if not isinstance(pool_data, dict): + return None + + enabled = _coerce_bool(pool_data.get("enabled"), False) + strategy = _coerce_str(pool_data.get("strategy"), "default").strip().lower() + if strategy not in ("default", "round_robin", "random"): + strategy = "default" + + raw_models = pool_data.get("models") + if not isinstance(raw_models, list): + return 
ModelPool(enabled=enabled, strategy=strategy) + + entries: list[ModelPoolEntry] = [] + for item in raw_models: + if not isinstance(item, dict): + continue + name = _coerce_str(item.get("model_name"), "").strip() + if not name: + continue + entries.append( + ModelPoolEntry( + api_url=_coerce_str(item.get("api_url"), primary_config.api_url), + api_key=_coerce_str(item.get("api_key"), primary_config.api_key), + model_name=name, + context_window_tokens=_coerce_int( + item.get("context_window_tokens"), + primary_config.context_window_tokens, + ), + max_tokens=_coerce_int( + item.get("max_tokens"), primary_config.max_tokens + ), + queue_interval_seconds=_normalize_queue_interval( + _coerce_float( + item.get("queue_interval_seconds"), + primary_config.queue_interval_seconds, + ), + primary_config.queue_interval_seconds, + ), + api_mode=( + _coerce_str(item.get("api_mode"), primary_config.api_mode) + .strip() + .lower() + ) + if _coerce_str(item.get("api_mode"), primary_config.api_mode) + .strip() + .lower() + in _VALID_API_MODES + else primary_config.api_mode, + thinking_enabled=_coerce_bool( + item.get("thinking_enabled"), primary_config.thinking_enabled + ), + thinking_budget_tokens=_coerce_int( + item.get("thinking_budget_tokens"), + primary_config.thinking_budget_tokens, + ), + thinking_include_budget=_coerce_bool( + item.get("thinking_include_budget"), + primary_config.thinking_include_budget, + ), + reasoning_effort_style=_resolve_reasoning_effort_style( + item.get("reasoning_effort_style"), + primary_config.reasoning_effort_style, + ), + thinking_tool_call_compat=_coerce_bool( + item.get("thinking_tool_call_compat"), + primary_config.thinking_tool_call_compat, + ), + reasoning_content_replay=_coerce_bool( + item.get("reasoning_content_replay"), + primary_config.reasoning_content_replay, + ), + system_prompt_as_user=_coerce_bool( + item.get("system_prompt_as_user"), + primary_config.system_prompt_as_user, + ), + responses_tool_choice_compat=_coerce_bool( + 
item.get("responses_tool_choice_compat"), + primary_config.responses_tool_choice_compat, + ), + responses_force_stateless_replay=_coerce_bool( + item.get("responses_force_stateless_replay"), + primary_config.responses_force_stateless_replay, + ), + prompt_cache_enabled=_coerce_bool( + item.get("prompt_cache_enabled"), + primary_config.prompt_cache_enabled, + ), + reasoning_enabled=_coerce_bool( + item.get("reasoning_enabled"), + primary_config.reasoning_enabled, + ), + reasoning_effort=_resolve_reasoning_effort( + item.get("reasoning_effort"), + primary_config.reasoning_effort, + ), + stream_enabled=_coerce_bool( + item.get("stream_enabled"), + getattr(primary_config, "stream_enabled", False), + ), + request_params=merge_request_params( + primary_config.request_params, + item.get("request_params"), + ), + ) + ) + + return ModelPool(enabled=enabled, strategy=strategy, models=entries) diff --git a/src/Undefined/config/parsers/security.py b/src/Undefined/config/parsers/security.py new file mode 100644 index 00000000..f5044a91 --- /dev/null +++ b/src/Undefined/config/parsers/security.py @@ -0,0 +1,196 @@ +"""Security model parser.""" + +from __future__ import annotations + +# 模型配置解析:原始 dict → ChatModelConfig 等 dataclass + +import logging +from typing import Any + +from Undefined.utils.request_params import merge_request_params + +from ..coercers import ( + _coerce_bool, + _coerce_float, + _coerce_int, + _coerce_str, + _get_model_request_params, + _get_value, + _normalize_queue_interval, +) +from ..models import ( + ChatModelConfig, + SecurityModelConfig, +) +from ..resolvers import ( + _resolve_api_mode, + _resolve_context_window_tokens, + _resolve_reasoning_effort, + _resolve_reasoning_effort_style, + _resolve_reasoning_content_replay, + _resolve_responses_force_stateless_replay, + _resolve_responses_tool_choice_compat, + _resolve_system_prompt_as_user, + _resolve_thinking_compat_flags, +) + +logger = logging.getLogger(__name__) + + +def _parse_security_model_config( 
+ data: dict[str, Any], chat_model: ChatModelConfig +) -> SecurityModelConfig: + api_url = _coerce_str( + _get_value(data, ("models", "security", "api_url"), "SECURITY_MODEL_API_URL"), + "", + ) + api_key = _coerce_str( + _get_value(data, ("models", "security", "api_key"), "SECURITY_MODEL_API_KEY"), + "", + ) + model_name = _coerce_str( + _get_value(data, ("models", "security", "model_name"), "SECURITY_MODEL_NAME"), + "", + ) + queue_interval_seconds = _coerce_float( + _get_value( + data, + ("models", "security", "queue_interval_seconds"), + "SECURITY_MODEL_QUEUE_INTERVAL", + ), + 1.0, + ) + queue_interval_seconds = _normalize_queue_interval(queue_interval_seconds) + + thinking_include_budget, thinking_tool_call_compat = _resolve_thinking_compat_flags( + data=data, + model_name="security", + include_budget_env_key="SECURITY_MODEL_THINKING_INCLUDE_BUDGET", + tool_call_compat_env_key="SECURITY_MODEL_THINKING_TOOL_CALL_COMPAT", + legacy_env_key="SECURITY_MODEL_DEEPSEEK_NEW_COT_SUPPORT", + ) + api_mode = _resolve_api_mode(data, "security", "SECURITY_MODEL_API_MODE") + responses_tool_choice_compat = _resolve_responses_tool_choice_compat( + data, "security", "SECURITY_MODEL_RESPONSES_TOOL_CHOICE_COMPAT" + ) + responses_force_stateless_replay = _resolve_responses_force_stateless_replay( + data, "security", "SECURITY_MODEL_RESPONSES_FORCE_STATELESS_REPLAY" + ) + reasoning_content_replay = _resolve_reasoning_content_replay( + data, "security", "SECURITY_MODEL_REASONING_CONTENT_REPLAY" + ) + system_prompt_as_user = _resolve_system_prompt_as_user( + data, "security", "SECURITY_MODEL_SYSTEM_PROMPT_AS_USER" + ) + prompt_cache_enabled = _coerce_bool( + _get_value( + data, + ("models", "security", "prompt_cache_enabled"), + "SECURITY_MODEL_PROMPT_CACHE_ENABLED", + ), + True, + ) + reasoning_enabled = _coerce_bool( + _get_value( + data, + ("models", "security", "reasoning_enabled"), + "SECURITY_MODEL_REASONING_ENABLED", + ), + False, + ) + reasoning_effort = 
_resolve_reasoning_effort( + _get_value( + data, + ("models", "security", "reasoning_effort"), + "SECURITY_MODEL_REASONING_EFFORT", + ), + "medium", + ) + stream_enabled = _coerce_bool( + _get_value( + data, + ("models", "security", "stream_enabled"), + "SECURITY_MODEL_STREAM_ENABLED", + ), + False, + ) + + context_window_tokens = _resolve_context_window_tokens( + data, "security", "SECURITY_MODEL_CONTEXT_WINDOW_TOKENS" + ) + if api_url and api_key and model_name: + return SecurityModelConfig( + api_url=api_url, + api_key=api_key, + model_name=model_name, + max_tokens=_coerce_int( + _get_value( + data, + ("models", "security", "max_tokens"), + "SECURITY_MODEL_MAX_TOKENS", + ), + 100, + ), + context_window_tokens=context_window_tokens, + queue_interval_seconds=queue_interval_seconds, + api_mode=api_mode, + thinking_enabled=_coerce_bool( + _get_value( + data, + ("models", "security", "thinking_enabled"), + "SECURITY_MODEL_THINKING_ENABLED", + ), + False, + ), + thinking_budget_tokens=_coerce_int( + _get_value( + data, + ("models", "security", "thinking_budget_tokens"), + "SECURITY_MODEL_THINKING_BUDGET_TOKENS", + ), + 0, + ), + thinking_include_budget=thinking_include_budget, + reasoning_effort_style=_resolve_reasoning_effort_style( + _get_value( + data, + ("models", "security", "reasoning_effort_style"), + "SECURITY_MODEL_REASONING_EFFORT_STYLE", + ), + ), + thinking_tool_call_compat=thinking_tool_call_compat, + reasoning_content_replay=reasoning_content_replay, + system_prompt_as_user=system_prompt_as_user, + responses_tool_choice_compat=responses_tool_choice_compat, + responses_force_stateless_replay=responses_force_stateless_replay, + prompt_cache_enabled=prompt_cache_enabled, + reasoning_enabled=reasoning_enabled, + reasoning_effort=reasoning_effort, + stream_enabled=stream_enabled, + request_params=_get_model_request_params(data, "security"), + ) + + logger.warning("未配置安全模型,将使用对话模型作为后备") + return SecurityModelConfig( + api_url=chat_model.api_url, + 
api_key=chat_model.api_key, + model_name=chat_model.model_name, + context_window_tokens=chat_model.context_window_tokens, + max_tokens=chat_model.max_tokens, + queue_interval_seconds=chat_model.queue_interval_seconds, + api_mode=chat_model.api_mode, + thinking_enabled=False, + thinking_budget_tokens=0, + thinking_include_budget=True, + reasoning_effort_style="openai", + thinking_tool_call_compat=chat_model.thinking_tool_call_compat, + reasoning_content_replay=chat_model.reasoning_content_replay, + system_prompt_as_user=chat_model.system_prompt_as_user, + responses_tool_choice_compat=chat_model.responses_tool_choice_compat, + responses_force_stateless_replay=chat_model.responses_force_stateless_replay, + prompt_cache_enabled=chat_model.prompt_cache_enabled, + reasoning_enabled=chat_model.reasoning_enabled, + reasoning_effort=chat_model.reasoning_effort, + stream_enabled=chat_model.stream_enabled, + request_params=merge_request_params(chat_model.request_params), + ) diff --git a/src/Undefined/config/parsers/summary.py b/src/Undefined/config/parsers/summary.py new file mode 100644 index 00000000..dd7757dc --- /dev/null +++ b/src/Undefined/config/parsers/summary.py @@ -0,0 +1,147 @@ +"""Summary model parser.""" + +from __future__ import annotations + +# 模型配置解析:原始 dict → ChatModelConfig 等 dataclass + +import logging +from typing import Any + +from Undefined.utils.request_params import merge_request_params + +from ..coercers import ( + _coerce_bool, + _coerce_float, + _coerce_int, + _coerce_str, + _get_value, + _normalize_queue_interval, +) +from ..models import ( + AgentModelConfig, +) +from ..resolvers import ( + _resolve_api_mode, + _resolve_reasoning_effort, + _resolve_reasoning_effort_style, + _resolve_responses_force_stateless_replay, + _resolve_responses_tool_choice_compat, + _resolve_thinking_compat_flags, +) + +logger = logging.getLogger(__name__) + + +def _parse_summary_model_config( + data: dict[str, Any], fallback: AgentModelConfig +) -> 
tuple[AgentModelConfig, bool]: + s = data.get("models", {}).get("summary", {}) + if not isinstance(s, dict) or not s: + return fallback, False + queue_interval_seconds = _coerce_float( + s.get("queue_interval_seconds"), fallback.queue_interval_seconds + ) + queue_interval_seconds = _normalize_queue_interval( + queue_interval_seconds, fallback.queue_interval_seconds + ) + thinking_include_budget, thinking_tool_call_compat = _resolve_thinking_compat_flags( + data={"models": {"summary": s}}, + model_name="summary", + include_budget_env_key="SUMMARY_MODEL_THINKING_INCLUDE_BUDGET", + tool_call_compat_env_key="SUMMARY_MODEL_THINKING_TOOL_CALL_COMPAT", + legacy_env_key="SUMMARY_MODEL_DEEPSEEK_NEW_COT_SUPPORT", + ) + api_mode = _resolve_api_mode( + {"models": {"summary": s}}, + "summary", + "SUMMARY_MODEL_API_MODE", + fallback.api_mode, + ) + responses_tool_choice_compat = _resolve_responses_tool_choice_compat( + {"models": {"summary": s}}, + "summary", + "SUMMARY_MODEL_RESPONSES_TOOL_CHOICE_COMPAT", + fallback.responses_tool_choice_compat, + ) + responses_force_stateless_replay = _resolve_responses_force_stateless_replay( + {"models": {"summary": s}}, + "summary", + "SUMMARY_MODEL_RESPONSES_FORCE_STATELESS_REPLAY", + fallback.responses_force_stateless_replay, + ) + prompt_cache_enabled = _coerce_bool( + _get_value( + {"models": {"summary": s}}, + ("models", "summary", "prompt_cache_enabled"), + "SUMMARY_MODEL_PROMPT_CACHE_ENABLED", + ), + fallback.prompt_cache_enabled, + ) + context_window_tokens = _coerce_int( + s.get("context_window_tokens"), fallback.context_window_tokens + ) + return ( + AgentModelConfig( + api_url=_coerce_str(s.get("api_url"), fallback.api_url), + api_key=_coerce_str(s.get("api_key"), fallback.api_key), + model_name=_coerce_str(s.get("model_name"), fallback.model_name), + max_tokens=_coerce_int(s.get("max_tokens"), fallback.max_tokens), + context_window_tokens=context_window_tokens, + queue_interval_seconds=queue_interval_seconds, + 
api_mode=api_mode, + thinking_enabled=_coerce_bool( + s.get("thinking_enabled"), fallback.thinking_enabled + ), + thinking_budget_tokens=_coerce_int( + s.get("thinking_budget_tokens"), fallback.thinking_budget_tokens + ), + thinking_include_budget=thinking_include_budget, + reasoning_effort_style=_resolve_reasoning_effort_style( + _get_value( + {"models": {"summary": s}}, + ("models", "summary", "reasoning_effort_style"), + "SUMMARY_MODEL_REASONING_EFFORT_STYLE", + ), + fallback.reasoning_effort_style, + ), + thinking_tool_call_compat=thinking_tool_call_compat, + reasoning_content_replay=_coerce_bool( + s.get("reasoning_content_replay"), fallback.reasoning_content_replay + ), + system_prompt_as_user=_coerce_bool( + s.get("system_prompt_as_user"), fallback.system_prompt_as_user + ), + responses_tool_choice_compat=responses_tool_choice_compat, + responses_force_stateless_replay=responses_force_stateless_replay, + prompt_cache_enabled=prompt_cache_enabled, + reasoning_enabled=_coerce_bool( + _get_value( + {"models": {"summary": s}}, + ("models", "summary", "reasoning_enabled"), + "SUMMARY_MODEL_REASONING_ENABLED", + ), + fallback.reasoning_enabled, + ), + reasoning_effort=_resolve_reasoning_effort( + _get_value( + {"models": {"summary": s}}, + ("models", "summary", "reasoning_effort"), + "SUMMARY_MODEL_REASONING_EFFORT", + ), + fallback.reasoning_effort, + ), + stream_enabled=_coerce_bool( + _get_value( + {"models": {"summary": s}}, + ("models", "summary", "stream_enabled"), + "SUMMARY_MODEL_STREAM_ENABLED", + ), + fallback.stream_enabled, + ), + request_params=merge_request_params( + fallback.request_params, + s.get("request_params"), + ), + ), + True, + ) diff --git a/src/Undefined/config/parsers/vision.py b/src/Undefined/config/parsers/vision.py new file mode 100644 index 00000000..9299c980 --- /dev/null +++ b/src/Undefined/config/parsers/vision.py @@ -0,0 +1,162 @@ +"""Vision model parser.""" + +from __future__ import annotations + +# 模型配置解析:原始 dict → 
ChatModelConfig 等 dataclass + +import logging +from typing import Any + + +from ..coercers import ( + _coerce_bool, + _coerce_float, + _coerce_int, + _coerce_str, + _get_model_request_params, + _get_value, + _normalize_queue_interval, +) +from ..models import ( + VisionModelConfig, +) +from ..resolvers import ( + _resolve_api_mode, + _resolve_context_window_tokens, + _resolve_reasoning_effort, + _resolve_reasoning_effort_style, + _resolve_reasoning_content_replay, + _resolve_responses_force_stateless_replay, + _resolve_responses_tool_choice_compat, + _resolve_system_prompt_as_user, + _resolve_thinking_compat_flags, +) + +logger = logging.getLogger(__name__) + + +def _parse_vision_model_config(data: dict[str, Any]) -> VisionModelConfig: + queue_interval_seconds = _normalize_queue_interval( + _coerce_float( + _get_value( + data, + ("models", "vision", "queue_interval_seconds"), + "VISION_MODEL_QUEUE_INTERVAL", + ), + 1.0, + ) + ) + thinking_include_budget, thinking_tool_call_compat = _resolve_thinking_compat_flags( + data=data, + model_name="vision", + include_budget_env_key="VISION_MODEL_THINKING_INCLUDE_BUDGET", + tool_call_compat_env_key="VISION_MODEL_THINKING_TOOL_CALL_COMPAT", + legacy_env_key="VISION_MODEL_DEEPSEEK_NEW_COT_SUPPORT", + ) + api_mode = _resolve_api_mode(data, "vision", "VISION_MODEL_API_MODE") + responses_tool_choice_compat = _resolve_responses_tool_choice_compat( + data, "vision", "VISION_MODEL_RESPONSES_TOOL_CHOICE_COMPAT" + ) + responses_force_stateless_replay = _resolve_responses_force_stateless_replay( + data, "vision", "VISION_MODEL_RESPONSES_FORCE_STATELESS_REPLAY" + ) + reasoning_content_replay = _resolve_reasoning_content_replay( + data, "vision", "VISION_MODEL_REASONING_CONTENT_REPLAY" + ) + system_prompt_as_user = _resolve_system_prompt_as_user( + data, "vision", "VISION_MODEL_SYSTEM_PROMPT_AS_USER" + ) + prompt_cache_enabled = _coerce_bool( + _get_value( + data, + ("models", "vision", "prompt_cache_enabled"), + 
"VISION_MODEL_PROMPT_CACHE_ENABLED", + ), + True, + ) + reasoning_enabled = _coerce_bool( + _get_value( + data, + ("models", "vision", "reasoning_enabled"), + "VISION_MODEL_REASONING_ENABLED", + ), + False, + ) + reasoning_effort = _resolve_reasoning_effort( + _get_value( + data, + ("models", "vision", "reasoning_effort"), + "VISION_MODEL_REASONING_EFFORT", + ), + "medium", + ) + stream_enabled = _coerce_bool( + _get_value( + data, + ("models", "vision", "stream_enabled"), + "VISION_MODEL_STREAM_ENABLED", + ), + False, + ) + context_window_tokens = _resolve_context_window_tokens( + data, "vision", "VISION_MODEL_CONTEXT_WINDOW_TOKENS" + ) + return VisionModelConfig( + api_url=_coerce_str( + _get_value(data, ("models", "vision", "api_url"), "VISION_MODEL_API_URL"), + "", + ), + api_key=_coerce_str( + _get_value(data, ("models", "vision", "api_key"), "VISION_MODEL_API_KEY"), + "", + ), + model_name=_coerce_str( + _get_value(data, ("models", "vision", "model_name"), "VISION_MODEL_NAME"), + "", + ), + max_tokens=_coerce_int( + _get_value( + data, + ("models", "vision", "max_tokens"), + "VISION_MODEL_MAX_TOKENS", + ), + 8192, + ), + context_window_tokens=context_window_tokens, + queue_interval_seconds=queue_interval_seconds, + api_mode=api_mode, + thinking_enabled=_coerce_bool( + _get_value( + data, + ("models", "vision", "thinking_enabled"), + "VISION_MODEL_THINKING_ENABLED", + ), + False, + ), + thinking_budget_tokens=_coerce_int( + _get_value( + data, + ("models", "vision", "thinking_budget_tokens"), + "VISION_MODEL_THINKING_BUDGET_TOKENS", + ), + 20000, + ), + thinking_include_budget=thinking_include_budget, + reasoning_effort_style=_resolve_reasoning_effort_style( + _get_value( + data, + ("models", "vision", "reasoning_effort_style"), + "VISION_MODEL_REASONING_EFFORT_STYLE", + ), + ), + thinking_tool_call_compat=thinking_tool_call_compat, + reasoning_content_replay=reasoning_content_replay, + system_prompt_as_user=system_prompt_as_user, + 
responses_tool_choice_compat=responses_tool_choice_compat, + responses_force_stateless_replay=responses_force_stateless_replay, + prompt_cache_enabled=prompt_cache_enabled, + reasoning_enabled=reasoning_enabled, + reasoning_effort=reasoning_effort, + stream_enabled=stream_enabled, + request_params=_get_model_request_params(data, "vision"), + ) diff --git a/src/Undefined/config/toml_io.py b/src/Undefined/config/toml_io.py new file mode 100644 index 00000000..b9c96d46 --- /dev/null +++ b/src/Undefined/config/toml_io.py @@ -0,0 +1,110 @@ +"""TOML file I/O and environment bootstrap for configuration.""" + +from __future__ import annotations + +# 配置 I/O:读取 config.toml、加载 .env bootstrap、格式化解析错误 + +import logging +import os +import re +import tomllib +from pathlib import Path +from typing import Any, IO, Optional + +try: + from dotenv import load_dotenv +except Exception: # pragma: no cover + StrPath = str | os.PathLike[str] + + def load_dotenv( + dotenv_path: StrPath | None = None, + stream: IO[str] | None = None, + verbose: bool = False, + override: bool = False, + interpolate: bool = True, + encoding: str | None = "utf-8", + ) -> bool: + return False + + +logger = logging.getLogger(__name__) + +CONFIG_PATH = Path("config.toml") + +__all__ = ["CONFIG_PATH", "_load_env", "load_toml_data"] + + +def _load_env() -> None: + # dotenv 仅 bootstrap 环境变量,不覆盖已有 os.environ(TOML 仍优先于 env) + try: + load_dotenv() + except Exception: + logger.debug("加载 .env 失败,继续使用 config.toml", exc_info=True) + + +def _build_toml_decode_hint(line: str) -> str: + """根据出错行内容生成 TOML 修复提示。""" + hints: list[str] = [] + if "\\" in line: + hints.append( + 'Windows 路径建议用单引号(不转义)或双反斜杠,或直接用正斜杠,例如:path = \'D:\\AI\\bot\' / path = "D:\\\\AI\\\\bot" / path = "D:/AI/bot"' + ) + hints.append('多行文本请用三引号,例如:prompt = """..."""') + return ";".join(hints) + + +def _format_toml_decode_error( + path: Path, text: str, exc: tomllib.TOMLDecodeError +) -> str: + """将 tomllib 解析异常格式化为带行号、caret 与中文提示的可读消息。""" + lineno: int | None 
= getattr(exc, "lineno", None) + colno: int | None = getattr(exc, "colno", None) + if not isinstance(lineno, int) or not isinstance(colno, int): + match = re.search(r"\(at line (\d+), column (\d+)\)", str(exc)) + if match: + lineno = int(match.group(1)) + colno = int(match.group(2)) + + if isinstance(lineno, int) and lineno > 0: + lines = text.splitlines() + line = lines[lineno - 1] if 0 <= (lineno - 1) < len(lines) else "" + caret_pos = max((colno or 1) - 1, 0) + caret = " " * min(caret_pos, len(line)) + "^" + hint = _build_toml_decode_hint(line) + location = f"line={lineno} col={colno or 1}" + return f"{exc} ({location})\n> {line}\n {caret}\n提示:{hint}" + return str(exc) + + +def load_toml_data( + config_path: Optional[Path] = None, *, strict: bool = False +) -> dict[str, Any]: + """读取 config.toml 并返回字典;文件不存在时返回空 dict。""" + path = config_path or CONFIG_PATH + if not path.exists(): + return {} + text = "" + try: + # utf-8-sig 兼容带 BOM 的编辑器输出 + text = path.read_bytes().decode("utf-8-sig") + data = tomllib.loads(text) + if isinstance(data, dict): + return data + logger.warning("config.toml 内容不是对象结构") + return {} + except tomllib.TOMLDecodeError as exc: + message = _format_toml_decode_error(path, text, exc) + logger.error("config.toml 解析失败 (%s): %s", path.resolve(), message) + if strict: + raise ValueError(message) from exc + return {} + except UnicodeDecodeError as exc: + logger.error("config.toml 编码错误 (%s): %s", path.resolve(), exc) + if strict: + raise ValueError(str(exc)) from exc + return {} + except OSError as exc: + logger.error("读取 config.toml 失败: %s", exc) + if strict: + raise ValueError(str(exc)) from exc + return {} diff --git a/src/Undefined/handlers.py b/src/Undefined/handlers.py deleted file mode 100644 index a793ae44..00000000 --- a/src/Undefined/handlers.py +++ /dev/null @@ -1,1399 +0,0 @@ -"""消息处理和命令分发""" - -from __future__ import annotations - -import asyncio -from dataclasses import dataclass -import logging -import os -from pathlib import Path 
-import random -import time -from typing import Any, Coroutine, Literal - -from Undefined.attachments import ( - append_attachment_text, - build_attachment_scope, - register_message_attachments, -) -from Undefined.ai import AIClient -from Undefined.config import Config -from Undefined.faq import FAQStorage -from Undefined.rate_limit import RateLimiter -from Undefined.services.queue_manager import QueueManager -from Undefined.onebot import ( - OneBotClient, - get_message_content, - get_message_sender_id, -) -from Undefined.utils.common import ( - extract_text, - parse_message_content_for_history, - matches_xinliweiyuan, -) -from Undefined.utils.fake_at import BotNicknameCache, strip_fake_at -from Undefined.utils.history import MessageHistoryManager -from Undefined.utils.scheduler import TaskScheduler -from Undefined.utils.sender import MessageSender -from Undefined.services.security import SecurityService -from Undefined.services.command import CommandDispatcher -from Undefined.services.ai_coordinator import AICoordinator -from Undefined.services.message_batcher import MessageBatcher, make_scope -from Undefined.services.model_pool import ModelPoolService -from Undefined.skills.pipelines import PipelineRegistry -from Undefined.skills.pipelines.context import build_pipeline_context -from Undefined.utils.resources import resolve_resource_path -from Undefined.utils.queue_intervals import build_model_queue_intervals - -from Undefined.scheduled_task_storage import ScheduledTaskStorage -from Undefined.utils.logging import log_debug_json, redact_string -from Undefined.utils.coerce import safe_int - -logger = logging.getLogger(__name__) - -KEYWORD_REPLY_HISTORY_PREFIX = "[系统关键词自动回复] " -REPEAT_REPLY_HISTORY_PREFIX = "[系统复读] " - - -def _is_private_model_pool_control_text(text: str) -> bool: - return ModelPoolService.is_private_control_text(text) - - -def _format_poke_history_text(display_name: str, user_id: int) -> str: - """格式化拍一拍历史文本。""" - return 
f"{display_name}(暱称)[{user_id}(QQ号)] 拍了拍你。" - - -@dataclass(frozen=True) -class PrivatePokeRecord: - poke_text: str - sender_name: str - - -@dataclass(frozen=True) -class GroupPokeRecord: - poke_text: str - sender_name: str - group_name: str - sender_role: str - sender_title: str - sender_level: str - - -class MessageHandler: - """消息处理器""" - - def __init__( - self, - config: Config, - onebot: OneBotClient, - ai: AIClient, - faq_storage: FAQStorage, - task_storage: ScheduledTaskStorage, - ) -> None: - self.config = config - self.onebot = onebot - self.ai = ai - self.faq_storage = faq_storage - # 初始化工具组件 - self.history_manager = MessageHistoryManager(config.history_max_records) - self.sender = MessageSender( - onebot, - self.history_manager, - config.bot_qq, - config, - attachment_registry=getattr(ai, "attachment_registry", None), - ) - - # 初始化服务 - self.security = SecurityService(config, ai._http_client) - self.rate_limiter = RateLimiter(config) - self.queue_manager = QueueManager( - max_retries=config.ai_request_max_retries, - ) - self.queue_manager.update_model_intervals(build_model_queue_intervals(config)) - - # 设置队列管理器到 AIClient(触发 Agent 介绍生成器启动) - ai.set_queue_manager(self.queue_manager) - - self.command_dispatcher = CommandDispatcher( - config, - self.sender, - ai, - faq_storage, - onebot, - self.security, - queue_manager=self.queue_manager, - rate_limiter=self.rate_limiter, - history_manager=self.history_manager, - ) - self.ai_coordinator = AICoordinator( - config, - ai, - self.queue_manager, - self.history_manager, - self.sender, - onebot, - TaskScheduler(ai, self.sender, onebot, self.history_manager, task_storage), - self.security, - command_dispatcher=self.command_dispatcher, - ) - - # 同 sender 短时多消息合并器;coordinator 决定是否旁路 - self.message_batcher = MessageBatcher( - config.message_batcher, - flush_callback=self.ai_coordinator.handle_batched_dispatch, - ) - self.ai_coordinator.set_batcher(self.message_batcher) - - self._background_tasks: 
set[asyncio.Task[None]] = set() - self._profile_name_refresh_cache: dict[tuple[str, int], str] = {} - self._bot_nickname_cache = BotNicknameCache(onebot, config.bot_qq) - self.pipeline_registry = PipelineRegistry() - self._pipelines_initialized = False - self._pipelines_init_lock = asyncio.Lock() - - # 复读功能状态(按群跟踪最近消息文本与发送者) - self._repeat_counter: dict[int, list[tuple[str, int]]] = {} - self._repeat_locks: dict[int, asyncio.Lock] = {} - # 复读冷却:group_id → {normalized_text → monotonic_timestamp} - self._repeat_cooldown: dict[int, dict[str, float]] = {} - - # 启动队列 - self.ai_coordinator.queue_manager.start(self.ai_coordinator.execute_reply) - - async def initialize(self) -> None: - """完成需要事件循环承载的异步初始化。""" - await self.init_pipelines() - - async def init_pipelines(self) -> None: - """异步加载自动处理管线并按配置启动热重载。""" - if getattr(self, "_pipelines_initialized", False): - return - init_lock = getattr(self, "_pipelines_init_lock", None) - if init_lock is None: - init_lock = asyncio.Lock() - self._pipelines_init_lock = init_lock - async with init_lock: - if getattr(self, "_pipelines_initialized", False): - return - await self.pipeline_registry.load_items_async() - self._pipelines_initialized = True - if getattr(self.config, "skills_hot_reload", False): - self.pipeline_registry.start_hot_reload( - interval=self.config.skills_hot_reload_interval, - debounce=self.config.skills_hot_reload_debounce, - ) - - def _get_repeat_lock(self, group_id: int) -> asyncio.Lock: - """获取或创建指定群的复读竞态保护锁。""" - lock = self._repeat_locks.get(group_id) - if lock is None: - lock = asyncio.Lock() - self._repeat_locks[group_id] = lock - return lock - - @staticmethod - def _normalize_repeat_text(text: str) -> str: - """规范化复读文本用于冷却比较(?→?)。""" - return text.replace("?", "?") - - def _is_repeat_on_cooldown(self, group_id: int, text: str) -> bool: - """检查指定群的文本是否在复读冷却期内。""" - cooldown_minutes = self.config.repeat_cooldown_minutes - if cooldown_minutes <= 0: - return False - group_cd = 
self._repeat_cooldown.get(group_id) - if not group_cd: - return False - key = self._normalize_repeat_text(text) - last_time = group_cd.get(key) - if last_time is None: - return False - return (time.monotonic() - last_time) < cooldown_minutes * 60 - - def _record_repeat_cooldown(self, group_id: int, text: str) -> None: - """记录复读冷却时间戳,同时清理已过期条目防止内存泄漏。""" - cooldown_seconds = self.config.repeat_cooldown_minutes * 60 - if cooldown_seconds <= 0: - return - key = self._normalize_repeat_text(text) - group_cd = self._repeat_cooldown.setdefault(group_id, {}) - now = time.monotonic() - # 清理已过期条目 - expired = [k for k, ts in group_cd.items() if (now - ts) >= cooldown_seconds] - for k in expired: - del group_cd[k] - group_cd[key] = now - - async def _annotate_meme_descriptions( - self, - attachments: list[dict[str, str]], - scope_key: str, - ) -> list[dict[str, str]]: - """为图片附件添加表情包描述(如果在表情库中找到)。 - - 采用批量查询:收集所有 SHA256 哈希值,一次性查询,然后映射结果。 - 最佳努力:任何失败时返回原始列表。 - """ - if not attachments: - return attachments - - ai_client = getattr(self, "ai", None) - if ai_client is None: - return attachments - - attachment_registry = getattr(ai_client, "attachment_registry", None) - if attachment_registry is None: - return attachments - - meme_service = getattr(ai_client, "_meme_service", None) - if meme_service is None or not getattr(meme_service, "enabled", False): - return attachments - - meme_store = getattr(meme_service, "_store", None) - if meme_store is None: - return attachments - - try: - # 1. 从图片附件收集唯一的 SHA256 哈希值 - uid_to_hash: dict[str, str] = {} - for att in attachments: - uid = att.get("uid", "") - if not uid.startswith("pic_"): - continue - record = attachment_registry.resolve(uid, scope_key) - if record and record.sha256: - uid_to_hash[uid] = record.sha256 - - if not uid_to_hash: - return attachments - - # 2. 
批量查询:去重哈希值 - unique_hashes = set(uid_to_hash.values()) - hash_to_desc: dict[str, str] = {} - for h in unique_hashes: - meme = await meme_store.find_by_sha256(h) - if meme and meme.description: - hash_to_desc[h] = meme.description - - if not hash_to_desc: - return attachments - - # 3. 构建带注释的新列表 - result: list[dict[str, str]] = [] - for att in attachments: - uid = att.get("uid", "") - sha = uid_to_hash.get(uid, "") - desc = hash_to_desc.get(sha, "") - if desc: - new_att = dict(att) - new_att["description"] = f"[表情包] {desc}" - result.append(new_att) - else: - result.append(att) - return result - except Exception: - logger.warning("表情包自动匹配失败,跳过", exc_info=True) - return attachments - - async def _collect_message_attachments( - self, - message_content: list[dict[str, Any]], - *, - group_id: int | None = None, - user_id: int | None = None, - request_type: str, - ) -> list[dict[str, str]]: - scope_key = build_attachment_scope( - group_id=group_id, - user_id=user_id, - request_type=request_type, - ) - if not scope_key: - return [] - ai_client = getattr(self, "ai", None) - attachment_registry = ( - getattr(ai_client, "attachment_registry", None) if ai_client else None - ) - if attachment_registry is None: - return [] - onebot = getattr(self, "onebot", None) - resolve_image_url = getattr(onebot, "get_image", None) if onebot else None - result = await register_message_attachments( - registry=attachment_registry, - segments=message_content, - scope_key=scope_key, - resolve_image_url=resolve_image_url, - get_forward_messages=getattr(onebot, "get_forward_msg", None) - if onebot - else None, - ) - attachments = result.attachments - # 为图片附件添加表情包描述 - attachments = await self._annotate_meme_descriptions(attachments, scope_key) - return attachments - - def _schedule_meme_ingest( - self, - *, - attachments: list[dict[str, str]], - chat_type: str, - chat_id: int, - sender_id: int, - message_id: int | None, - scope_key: str | None, - ) -> None: - if not attachments or not scope_key: - 
return - meme_service = getattr(self.ai, "_meme_service", None) - if meme_service is None or not getattr(meme_service, "enabled", False): - return - self._spawn_background_task( - f"meme_ingest:{chat_type}:{chat_id}:{sender_id}:{message_id or 0}", - meme_service.enqueue_incoming_attachments( - attachments=attachments, - chat_type=chat_type, - chat_id=chat_id, - sender_id=sender_id, - message_id=message_id, - scope_key=scope_key, - ), - ) - - async def _refresh_profile_display_names( - self, - *, - sender_id: int | None = None, - sender_name: str = "", - group_id: int | None = None, - group_name: str = "", - ) -> None: - ai_client = getattr(self, "ai", None) - cognitive_service = getattr(ai_client, "_cognitive_service", None) - if not cognitive_service or not getattr(cognitive_service, "enabled", False): - return - - if sender_id and sender_name.strip(): - await cognitive_service.sync_profile_display_name( - entity_type="user", - entity_id=str(sender_id), - preferred_name=sender_name.strip(), - ) - if group_id and group_name.strip(): - await cognitive_service.sync_profile_display_name( - entity_type="group", - entity_id=str(group_id), - preferred_name=group_name.strip(), - ) - - def _can_refresh_profile_display_names(self) -> bool: - ai_client = getattr(self, "ai", None) - cognitive_service = getattr(ai_client, "_cognitive_service", None) - return bool(cognitive_service and getattr(cognitive_service, "enabled", False)) - - def _schedule_profile_display_name_refresh( - self, - *, - task_name: str, - sender_id: int | None = None, - sender_name: str = "", - group_id: int | None = None, - group_name: str = "", - ) -> None: - if not self._can_refresh_profile_display_names(): - return - - cache = getattr(self, "_profile_name_refresh_cache", None) - if cache is None: - cache = {} - self._profile_name_refresh_cache = cache - - updates: dict[str, Any] = {} - rollback: list[tuple[tuple[str, int], str | None]] = [] - - normalized_sender_name = sender_name.strip() - if 
sender_id and normalized_sender_name: - sender_key = ("user", int(sender_id)) - previous = cache.get(sender_key) - if previous != normalized_sender_name: - cache[sender_key] = normalized_sender_name - rollback.append((sender_key, previous)) - updates["sender_id"] = sender_id - updates["sender_name"] = normalized_sender_name - - normalized_group_name = group_name.strip() - if group_id and normalized_group_name: - group_key = ("group", int(group_id)) - previous = cache.get(group_key) - if previous != normalized_group_name: - cache[group_key] = normalized_group_name - rollback.append((group_key, previous)) - updates["group_id"] = group_id - updates["group_name"] = normalized_group_name - - if not updates: - return - - async def _run_refresh() -> None: - try: - await self._refresh_profile_display_names(**updates) - except Exception: - for key, previous in rollback: - if previous is None: - cache.pop(key, None) - else: - cache[key] = previous - raise - - self._spawn_background_task(task_name, _run_refresh()) - - async def handle_message(self, event: dict[str, Any]) -> None: - """处理收到的消息事件""" - if logger.isEnabledFor(logging.DEBUG): - log_debug_json(logger, "[事件数据]", event) - post_type = event.get("post_type", "message") - - # 处理拍一拍事件(效果同被 @) - if post_type == "notice" and event.get("notice_type") == "poke": - target_id = event.get("target_id", 0) - # 只有拍机器人才响应 - if target_id != self.config.bot_qq: - logger.debug( - "[通知] 忽略拍一拍目标非机器人: target=%s", - target_id, - ) - return - - if not self.config.should_process_poke_message(): - logger.debug("[消息策略] 已关闭拍一拍处理,忽略此次 poke 事件") - return - - poke_group_id: int = event.get("group_id", 0) - poke_sender_id: int = event.get("user_id", 0) - - # 访问控制:命中群黑名单或不满足白名单限制时忽略 - if poke_group_id == 0: - if not self.config.is_private_allowed(poke_sender_id): - private_reason = ( - self.config.private_access_denied_reason(poke_sender_id) - or "unknown" - ) - logger.debug( - "[访问控制] 忽略私聊拍一拍: user=%s reason=%s (access enabled=%s)", - 
poke_sender_id, - private_reason, - self.config.access_control_enabled(), - ) - return - else: - if not self.config.is_group_allowed(poke_group_id): - group_reason = ( - self.config.group_access_denied_reason(poke_group_id) - or "unknown" - ) - logger.debug( - "[访问控制] 忽略群聊拍一拍: group=%s sender=%s reason=%s (access enabled=%s)", - poke_group_id, - poke_sender_id, - group_reason, - self.config.access_control_enabled(), - ) - return - - logger.info( - "[通知] 收到拍一拍: group=%s sender=%s", - poke_group_id, - poke_sender_id, - ) - logger.debug("[通知] 拍一拍事件数据: %s", str(event)[:200]) - - if poke_group_id == 0: - private_poke = await self._record_private_poke_history( - poke_sender_id, event - ) - logger.info("[通知] 私聊拍一拍,触发私聊回复") - await self.ai_coordinator.handle_private_reply( - poke_sender_id, - private_poke.poke_text, - [], - is_poke=True, - sender_name=private_poke.sender_name, - ) - else: - group_poke = await self._record_group_poke_history( - poke_group_id, - poke_sender_id, - event, - ) - logger.info( - "[通知] 群聊拍一拍,触发群聊回复: group=%s", - poke_group_id, - ) - await self.ai_coordinator.handle_auto_reply( - poke_group_id, - poke_sender_id, - group_poke.poke_text, - [], - is_poke=True, - sender_name=group_poke.sender_name, - group_name=group_poke.group_name, - sender_role=group_poke.sender_role, - sender_title=group_poke.sender_title, - sender_level=group_poke.sender_level, - ) - return - - # 处理私聊消息 - if event.get("message_type") == "private": - private_sender_id: int = get_message_sender_id(event) - private_message_content: list[dict[str, Any]] = get_message_content(event) - trigger_message_id = event.get("message_id") - - # 访问控制:命中黑/白名单规则时忽略(不入历史、不触发任何处理) - if not self.config.is_private_allowed(private_sender_id): - private_reason = ( - self.config.private_access_denied_reason(private_sender_id) - or "unknown" - ) - logger.debug( - "[访问控制] 忽略私聊消息: user=%s reason=%s (access enabled=%s)", - private_sender_id, - private_reason, - self.config.access_control_enabled(), - ) - 
return - - # 获取发送者昵称 - private_sender: dict[str, Any] = event.get("sender", {}) - private_sender_nickname: str = private_sender.get("nickname", "") - - # 获取私聊用户昵称 - user_name = private_sender_nickname - if not user_name: - try: - user_info = await self.onebot.get_stranger_info(private_sender_id) - if user_info: - user_name = user_info.get("nickname", "") - except Exception as exc: - logger.warning("获取用户昵称失败: %s", exc) - - text = extract_text(private_message_content, self.config.bot_qq) - # 并行执行附件收集和历史内容解析 - private_attachments, parsed_content_raw = await asyncio.gather( - self._collect_message_attachments( - private_message_content, - user_id=private_sender_id, - request_type="private", - ), - parse_message_content_for_history( - private_message_content, - self.config.bot_qq, - self.onebot.get_msg, - self.onebot.get_forward_msg, - ), - ) - safe_text = redact_string(text) - logger.info( - "[私聊消息] 发送者=%s 昵称=%s 内容=%s", - private_sender_id, - user_name or private_sender_nickname, - safe_text[:100], - ) - resolved_private_name = (user_name or private_sender_nickname or "").strip() - self._schedule_profile_display_name_refresh( - task_name=f"profile_name_refresh_private:{private_sender_id}", - sender_id=private_sender_id, - sender_name=resolved_private_name, - ) - - # 保存私聊消息到历史记录 - parsed_content = append_attachment_text( - parsed_content_raw, private_attachments - ) - safe_parsed = redact_string(parsed_content) - logger.debug( - "[历史记录] 保存私聊: user=%s content=%s...", - private_sender_id, - safe_parsed[:50], - ) - await self.history_manager.add_private_message( - user_id=private_sender_id, - text_content=parsed_content, - display_name=private_sender_nickname, - user_name=user_name, - message_id=trigger_message_id, - attachments=private_attachments, - ) - - # 如果是 bot 自己的消息,只保存不触发回复,避免无限循环 - if private_sender_id == self.config.bot_qq: - return - - self._schedule_meme_ingest( - attachments=private_attachments, - chat_type="private", - chat_id=private_sender_id, - 
sender_id=private_sender_id, - message_id=safe_int(trigger_message_id), - scope_key=build_attachment_scope( - user_id=private_sender_id, - request_type="private", - ), - ) - - if not self.config.should_process_private_message(): - logger.debug( - "[消息策略] 已关闭私聊处理: user=%s", - private_sender_id, - ) - return - - if ( - getattr(self.config, "model_pool_enabled", False) - and _is_private_model_pool_control_text(text) - ) and await self.ai_coordinator.model_pool.handle_private_message( - private_sender_id, - text, - ): - return - - private_command = self.command_dispatcher.parse_command(text) - if private_command: - await self._flush_command_buffer( - scope=make_scope(user_id=private_sender_id), - sender_id=private_sender_id, - ) - await self.command_dispatcher.dispatch_private( - user_id=private_sender_id, - sender_id=private_sender_id, - command=private_command, - ) - return - - await self._run_pipelines( - target_id=private_sender_id, - target_type="private", - text=text, - message_content=private_message_content, - ) - - await self.ai_coordinator.handle_private_reply( - private_sender_id, - text, - private_message_content, - attachments=private_attachments, - sender_name=user_name, - trigger_message_id=trigger_message_id, - ) - return - - # 只处理群消息 - if event.get("message_type") != "group": - return - - group_id: int = event.get("group_id", 0) - sender_id: int = get_message_sender_id(event) - message_content: list[dict[str, Any]] = get_message_content(event) - trigger_message_id = event.get("message_id") - - # 访问控制:命中黑/白名单规则时忽略(不入历史、不触发任何处理) - if not self.config.is_group_allowed(group_id): - group_reason = self.config.group_access_denied_reason(group_id) or "unknown" - logger.debug( - "[访问控制] 忽略群消息: group=%s sender=%s reason=%s (access enabled=%s)", - group_id, - sender_id, - group_reason, - self.config.access_control_enabled(), - ) - return - - # 获取发送者信息 - group_sender: dict[str, Any] = event.get("sender", {}) - sender_card: str = group_sender.get("card", "") - 
sender_nickname: str = group_sender.get("nickname", "") - sender_role: str = group_sender.get("role", "member") - sender_title: str = group_sender.get("title", "") - sender_level: str = str(group_sender.get("level", "")).strip() - - # 提取文本内容 - text = extract_text(message_content, self.config.bot_qq) - safe_text = redact_string(text) - logger.info( - f"[群消息] group={group_id} sender={sender_id} name={sender_card or sender_nickname} " - f"role={sender_role} | {safe_text[:100]}" - ) - - # 并行执行 3 个独立的异步操作:附件收集、群信息获取、历史内容解析 - async def _fetch_group_name() -> str: - try: - info = await self.onebot.get_group_info(group_id) - if info: - return str(info.get("group_name", "") or "") - except Exception as e: - logger.warning(f"获取群聊名失败: {e}") - return "" - - group_attachments, group_name, parsed_content_raw = await asyncio.gather( - self._collect_message_attachments( - message_content, - group_id=group_id, - request_type="group", - ), - _fetch_group_name(), - parse_message_content_for_history( - message_content, - self.config.bot_qq, - self.onebot.get_msg, - self.onebot.get_forward_msg, - ), - ) - - resolved_group_sender_name = (sender_card or sender_nickname or "").strip() - self._schedule_profile_display_name_refresh( - task_name=f"profile_name_refresh_group:{group_id}:{sender_id}", - sender_id=sender_id, - sender_name=resolved_group_sender_name, - group_id=group_id, - group_name=str(group_name or "").strip(), - ) - - # 保存消息到历史记录 - parsed_content = append_attachment_text(parsed_content_raw, group_attachments) - safe_parsed = redact_string(parsed_content) - logger.debug( - f"[历史记录] 保存群聊: group={group_id}, sender={sender_id}, content={safe_parsed[:50]}..." 
- ) - await self.history_manager.add_group_message( - group_id=group_id, - sender_id=sender_id, - text_content=parsed_content, - sender_card=sender_card, - sender_nickname=sender_nickname, - group_name=group_name, - role=sender_role, - title=sender_title, - level=sender_level, - message_id=trigger_message_id, - attachments=group_attachments, - ) - - # 如果是 bot 自己的消息,只保存不触发回复,避免无限循环 - # 同时把 bot 自身的发言写入复读计数器,使窗口中留有 bot 标记, - # 后续触发检查时会排除含 bot 的窗口,防止"bot 先发 → 用户跟发"或 - # "用户发到一半 bot 插入"等情况误触复读。 - if sender_id == self.config.bot_qq: - if self.config.repeat_enabled and text: - async with self._get_repeat_lock(group_id): - counter = self._repeat_counter.setdefault(group_id, []) - counter.append((text, sender_id)) - n = self.config.repeat_threshold - if len(counter) > n: - self._repeat_counter[group_id] = counter[-n:] - return - - self._schedule_meme_ingest( - attachments=group_attachments, - chat_type="group", - chat_id=group_id, - sender_id=sender_id, - message_id=safe_int(trigger_message_id), - scope_key=build_attachment_scope(group_id=group_id, request_type="group"), - ) - - # 检查是否 @ 了机器人(后续分流共用) - is_at_bot = self.ai_coordinator._is_at_bot(message_content) - - # 假@检测:识别 "@昵称" 纯文本形式 - # normalized_text 用于命令解析和 AI 路由,原始 text 已用于历史/日志 - is_fake_at = False - normalized_text = text - if not is_at_bot and ("@" in text or "@" in text): - nicknames = await self._bot_nickname_cache.get_nicknames(group_id) - if nicknames: - is_fake_at, normalized_text = strip_fake_at(text, nicknames) - if is_fake_at: - is_at_bot = True - logger.info( - "[假@] 识别到假@: group=%s sender=%s", - group_id, - sender_id, - ) - - # 关闭“每条消息处理”后,仅处理 @ 消息(私聊/拍一拍在其他分支中处理) - if not self.config.should_process_group_message(is_at_bot=is_at_bot): - logger.debug( - "[消息策略] 跳过群消息处理: group=%s sender=%s process_every_message=%s at_bot=%s", - group_id, - sender_id, - self.config.process_every_message, - is_at_bot, - ) - return - - # 只有被@时才处理斜杠命令(使用 normalized_text 以支持假@后的命令)。 - # 命令优先于自动处理管线,命中后不触发后续自动提取或 AI 回复。 - if 
is_at_bot: - command = self.command_dispatcher.parse_command(normalized_text) - if command: - await self._flush_command_buffer( - scope=make_scope(group_id=group_id), - sender_id=sender_id, - ) - await self.command_dispatcher.dispatch(group_id, sender_id, command) - return - - # 关键词自动回复:心理委员 (使用原始消息内容提取文本,保证关键词触发不受影响) - if self.config.keyword_reply_enabled and matches_xinliweiyuan(text): - rand_val = random.random() - if rand_val < 0.01: # 1% 飞起来 - message = f"[@{sender_id}] 再发让你飞起来" - logger.info("关键词回复: 再发让你飞起来") - await self.sender.send_group_message( - group_id, - message, - history_prefix=KEYWORD_REPLY_HISTORY_PREFIX, - ) - return - elif rand_val < 0.11: # 10% 发送图片 - try: - image_path = ( - resolve_resource_path("img/xlwy.jpg").resolve().as_uri() - ) - except Exception: - image_path = Path(os.path.abspath("img/xlwy.jpg")).as_uri() - message = f"[CQ:image,file={image_path}]" - # 50% 概率 @ 发送者 - if random.random() < 0.5: - message = f"[@{sender_id}] {message}" - logger.info("关键词回复: 发送图片 xlwy.jpg") - else: # 90% 原有逻辑 - if random.random() < 0.7: - reply = "受着" - else: - reply = "那咋了" - # 50% 概率 @ 发送者 - if random.random() < 0.5: - message = f"[@{sender_id}] {reply}" - else: - message = reply - logger.info(f"关键词回复: {reply}") - # 使用 sender 发送 - await self.sender.send_group_message( - group_id, - message, - history_prefix=KEYWORD_REPLY_HISTORY_PREFIX, - ) - return - - # 复读功能:连续 N 条相同消息(来自不同发送者)时复读,N = repeat_threshold - if self.config.repeat_enabled and text: - n = self.config.repeat_threshold - async with self._get_repeat_lock(group_id): - counter = self._repeat_counter.setdefault(group_id, []) - counter.append((text, sender_id)) - # 只保留最近 n 条 - if len(counter) > n: - self._repeat_counter[group_id] = counter[-n:] - counter = self._repeat_counter[group_id] - - if len(counter) >= n: - last_n = counter[-n:] - texts = [t for t, _ in last_n] - senders = [s for _, s in last_n] - if ( - len(set(texts)) == 1 - and len(set(senders)) == n - and self.config.bot_qq not in senders 
- ): - reply_text = texts[0] - # 冷却检查:同一内容在冷却期内不再复读 - if self._is_repeat_on_cooldown(group_id, reply_text): - self._repeat_counter[group_id] = [] - logger.debug( - "[复读] 冷却中跳过: group=%s text=%s", - group_id, - redact_string(reply_text)[:50], - ) - else: - if self.config.inverted_question_enabled: - stripped = reply_text.strip() - if set(stripped) <= {"?", "?"}: - reply_text = "¿" * len(stripped) - # 清空计数器防止重复触发 - self._repeat_counter[group_id] = [] - self._record_repeat_cooldown(group_id, texts[0]) - logger.info( - "[复读] 触发复读: group=%s text=%s", - group_id, - redact_string(reply_text)[:50], - ) - await self.sender.send_group_message( - group_id, - reply_text, - history_prefix=REPEAT_REPLY_HISTORY_PREFIX, - ) - return - - await self._run_pipelines( - target_id=group_id, - target_type="group", - text=text, - message_content=message_content, - ) - - # 提取文本内容 - # (已在上方提取用于日志记录) - - # 自动回复处理(使用 normalized_text 以去除假@前缀) - display_name = sender_card or sender_nickname or str(sender_id) - await self.ai_coordinator.handle_auto_reply( - group_id, - sender_id, - normalized_text, - message_content, - attachments=group_attachments, - sender_name=display_name, - group_name=group_name, - sender_role=sender_role, - sender_title=sender_title, - sender_level=sender_level, - trigger_message_id=trigger_message_id, - is_fake_at=is_fake_at, - ) - - async def _record_private_poke_history( - self, user_id: int, event: dict[str, Any] - ) -> PrivatePokeRecord: - """记录私聊拍一拍到历史。""" - sender = event.get("sender", {}) - sender_nickname = "" - if isinstance(sender, dict): - sender_nickname = str(sender.get("nickname", "")).strip() - - user_name = sender_nickname - if not user_name: - try: - user_info = await self.onebot.get_stranger_info(user_id) - if isinstance(user_info, dict): - user_name = str(user_info.get("nickname", "")).strip() - except Exception as exc: - logger.warning( - "[通知] 获取私聊拍一拍用户昵称失败: user=%s err=%s", - user_id, - exc, - ) - - resolved_sender_name = (sender_nickname or 
user_name).strip() - display_name = resolved_sender_name or f"QQ{user_id}" - normalized_user_name = user_name or display_name - poke_text = _format_poke_history_text(display_name, user_id) - self._schedule_profile_display_name_refresh( - task_name=f"profile_name_refresh_private_poke:{user_id}", - sender_id=user_id, - sender_name=resolved_sender_name, - ) - - try: - await self.history_manager.add_private_message( - user_id=user_id, - text_content=poke_text, - display_name=display_name, - user_name=normalized_user_name, - ) - except Exception as exc: - logger.warning( - "[历史记录] 写入私聊拍一拍失败: user=%s err=%s", - user_id, - exc, - ) - return PrivatePokeRecord(poke_text=poke_text, sender_name=display_name) - - async def _record_group_poke_history( - self, - group_id: int, - sender_id: int, - event: dict[str, Any], - ) -> GroupPokeRecord: - """记录群聊拍一拍到历史。""" - sender = event.get("sender", {}) - sender_card = "" - sender_nickname = "" - sender_role = "member" - sender_title = "" - sender_level = "" - if isinstance(sender, dict): - sender_card = str(sender.get("card", "")).strip() - sender_nickname = str(sender.get("nickname", "")).strip() - sender_role = str(sender.get("role", "member")).strip() or "member" - sender_title = str(sender.get("title", "")).strip() - sender_level = str(sender.get("level", "")).strip() - - if not sender_card and not sender_nickname: - try: - member_info = await self.onebot.get_group_member_info( - group_id, sender_id - ) - if isinstance(member_info, dict): - sender_card = str(member_info.get("card", "")).strip() - sender_nickname = str(member_info.get("nickname", "")).strip() - sender_role = ( - str(member_info.get("role", "member")).strip() or "member" - ) - sender_title = str(member_info.get("title", "")).strip() - sender_level = str(member_info.get("level", "")).strip() - except Exception as exc: - logger.warning( - "[通知] 获取拍一拍群成员信息失败: group=%s user=%s err=%s", - group_id, - sender_id, - exc, - ) - - group_name = "" - try: - group_info = await 
self.onebot.get_group_info(group_id) - if isinstance(group_info, dict): - group_name = str(group_info.get("group_name", "")).strip() - except Exception as exc: - logger.warning( - "[通知] 获取拍一拍群名失败: group=%s err=%s", - group_id, - exc, - ) - - resolved_sender_name = (sender_card or sender_nickname).strip() - resolved_group_name = group_name.strip() - display_name = resolved_sender_name or f"QQ{sender_id}" - poke_text = _format_poke_history_text(display_name, sender_id) - normalized_group_name = resolved_group_name or f"群{group_id}" - self._schedule_profile_display_name_refresh( - task_name=f"profile_name_refresh_group_poke:{group_id}:{sender_id}", - sender_id=sender_id, - sender_name=resolved_sender_name, - group_id=group_id, - group_name=resolved_group_name, - ) - - try: - await self.history_manager.add_group_message( - group_id=group_id, - sender_id=sender_id, - text_content=poke_text, - sender_card=sender_card, - sender_nickname=sender_nickname, - group_name=normalized_group_name, - role=sender_role, - title=sender_title, - level=sender_level, - ) - except Exception as exc: - logger.warning( - "[历史记录] 写入群聊拍一拍失败: group=%s sender=%s err=%s", - group_id, - sender_id, - exc, - ) - return GroupPokeRecord( - poke_text=poke_text, - sender_name=display_name, - group_name=normalized_group_name, - sender_role=sender_role, - sender_title=sender_title, - sender_level=sender_level, - ) - - async def _extract_bilibili_ids( - self, text: str, message_content: list[dict[str, Any]] - ) -> list[str]: - """从文本和消息段中提取 B 站视频 BV 号。""" - from Undefined.bilibili.parser import ( - extract_bilibili_ids, - extract_from_json_message, - ) - - bvids = await extract_bilibili_ids(text) - if not bvids: - bvids = await extract_from_json_message(message_content) - return bvids - - async def _flush_command_buffer(self, *, scope: str, sender_id: int) -> None: - batcher_config = getattr(self.config, "message_batcher", None) - if not getattr(batcher_config, "flush_on_command", False): - return - 
batcher = getattr(self, "message_batcher", None) - if batcher is None: - return - flushed = await batcher.flush_sender(scope, sender_id) - if not flushed: - logger.warning( - "[MessageBatcher] 命令触发 flush 当前 buffer 失败: scope=%s sender=%s", - scope, - sender_id, - ) - - async def _run_pipelines( - self, - *, - target_id: int, - target_type: Literal["group", "private"], - text: str, - message_content: list[dict[str, Any]], - ) -> bool: - """并行检测并处理所有命中的自动处理管线。""" - if not getattr(self, "_pipelines_initialized", False): - await self.init_pipelines() - context = build_pipeline_context( - self, - target_id=target_id, - target_type=target_type, - text=text, - message_content=message_content, - ) - detections = await self.pipeline_registry.run(context) - return bool(detections) - - async def apply_skills_hot_reload_config( - self, - *, - enabled: bool, - interval: float, - debounce: float, - ) -> None: - """跟随全局 skills 热重载配置更新管线。""" - if not enabled: - await self.pipeline_registry.stop_hot_reload() - logger.info("[pipelines] 热重载已随配置禁用") - return - - await self.pipeline_registry.stop_hot_reload() - self.pipeline_registry.start_hot_reload( - interval=interval, - debounce=debounce, - ) - - def _extract_arxiv_ids( - self, text: str, message_content: list[dict[str, Any]] - ) -> list[str]: - """从文本和消息段中提取 arXiv 论文 ID。""" - from Undefined.arxiv.parser import extract_arxiv_ids, extract_from_json_message - - paper_ids: list[str] = [] - seen: set[str] = set() - - for paper_id in extract_arxiv_ids(text): - if paper_id in seen: - continue - seen.add(paper_id) - paper_ids.append(paper_id) - - for paper_id in extract_from_json_message(message_content): - if paper_id in seen: - continue - seen.add(paper_id) - paper_ids.append(paper_id) - - return paper_ids - - def _extract_github_repo_ids( - self, text: str, message_content: list[dict[str, Any]] - ) -> list[str]: - """从文本和消息段中提取 GitHub 仓库 ID。""" - from Undefined.github.parser import ( - extract_from_json_message, - 
extract_github_repo_ids, - ) - - repo_ids: list[str] = [] - seen: set[str] = set() - - for repo_id in extract_github_repo_ids(text): - key = repo_id.lower() - if key in seen: - continue - seen.add(key) - repo_ids.append(repo_id) - - for repo_id in extract_from_json_message(message_content): - key = repo_id.lower() - if key in seen: - continue - seen.add(key) - repo_ids.append(repo_id) - - return repo_ids - - async def _handle_bilibili_extract( - self, - target_id: int, - bvids: list[str], - target_type: str, - ) -> None: - """处理 bilibili 视频自动提取和发送。""" - from Undefined.bilibili.sender import send_bilibili_video - - for bvid in bvids[:3]: # 最多同时处理 3 个 - try: - await send_bilibili_video( - video_id=bvid, - sender=self.sender, - onebot=self.onebot, - target_type=target_type, # type: ignore[arg-type] - target_id=target_id, - cookie=self.config.bilibili_cookie, - prefer_quality=self.config.bilibili_prefer_quality, - max_duration=self.config.bilibili_max_duration, - max_file_size=self.config.bilibili_max_file_size, - oversize_strategy=self.config.bilibili_oversize_strategy, - danmaku_enabled=self.config.bilibili_danmaku_enabled, - danmaku_batch_size=self.config.bilibili_danmaku_batch_size, - danmaku_max_count=self.config.bilibili_danmaku_max_count, - ) - except Exception as exc: - logger.error( - "[Bilibili] 自动提取失败 %s → %s:%s: %s", - bvid, - target_type, - target_id, - exc, - ) - try: - error_msg = f"视频提取失败: {exc}" - if target_type == "group": - await self.sender.send_group_message(target_id, error_msg) - else: - await self.sender.send_private_message(target_id, error_msg) - except Exception: - pass - - async def _handle_arxiv_extract( - self, - target_id: int, - paper_ids: list[str], - target_type: str, - ) -> None: - """处理 arXiv 论文自动提取和发送。""" - from Undefined.arxiv.sender import send_arxiv_paper - - max_items = max(1, int(self.config.arxiv_auto_extract_max_items)) - - for paper_id in paper_ids[:max_items]: - try: - result = await send_arxiv_paper( - paper_id=paper_id, - 
sender=self.sender, - target_type=target_type, # type: ignore[arg-type] - target_id=target_id, - max_file_size=self.config.arxiv_max_file_size, - author_preview_limit=self.config.arxiv_author_preview_limit, - summary_preview_chars=self.config.arxiv_summary_preview_chars, - context={ - "request_id": ( - f"arxiv_auto_extract:{target_type}:{target_id}:{paper_id}" - ) - }, - ) - logger.info( - "[arXiv] 自动提取完成 %s → %s:%s: %s", - paper_id, - target_type, - target_id, - result, - ) - except Exception: - logger.exception( - "[arXiv] 自动提取失败 %s → %s:%s", - paper_id, - target_type, - target_id, - ) - - async def _handle_github_extract( - self, - target_id: int, - repo_ids: list[str], - target_type: str, - ) -> None: - """处理 GitHub 仓库自动提取和发送。""" - from Undefined.github.sender import send_github_repo_card - - max_items = max( - 1, int(getattr(self.config, "github_auto_extract_max_items", 3)) - ) - request_timeout = float( - getattr(self.config, "github_request_timeout_seconds", 10.0) - ) - - for repo_id in repo_ids[:max_items]: - try: - result = await send_github_repo_card( - repo_id=repo_id, - sender=self.sender, - target_type=target_type, # type: ignore[arg-type] - target_id=target_id, - request_timeout=request_timeout, - context={ - "request_id": ( - f"github_auto_extract:{target_type}:{target_id}:{repo_id}" - ) - }, - ) - logger.info( - "[GitHub] 自动提取完成 %s → %s:%s: %s", - repo_id, - target_type, - target_id, - result, - ) - except Exception as exc: - logger.info( - "[GitHub] 自动提取跳过 %s → %s:%s: %s", - repo_id, - target_type, - target_id, - exc, - ) - - def _spawn_background_task( - self, - name: str, - coroutine: Coroutine[Any, Any, None], - ) -> None: - task = asyncio.create_task(coroutine, name=name) - self._background_tasks.add(task) - - def _finalize(done_task: asyncio.Task[None]) -> None: - self._background_tasks.discard(done_task) - try: - exc = done_task.exception() - except asyncio.CancelledError: - logger.debug("[后台任务] 已取消: %s", name) - return - if exc is not None: 
- logger.exception( - "[后台任务] 执行失败: name=%s", - name, - exc_info=(type(exc), exc, exc.__traceback__), - ) - - task.add_done_callback(_finalize) - - async def close(self) -> None: - """关闭消息处理器""" - logger.info("正在关闭消息处理器...") - if self._background_tasks: - logger.info( - "[后台任务] 等待自动提取任务收敛: count=%s", - len(self._background_tasks), - ) - await asyncio.gather( - *list(self._background_tasks), - return_exceptions=True, - ) - await self.pipeline_registry.stop_hot_reload() - await self.message_batcher.flush_all() - await self.ai_coordinator.queue_manager.drain() - await self.ai_coordinator.queue_manager.stop() - await self.history_manager.flush_pending_saves() - logger.info("消息处理器已关闭") diff --git a/src/Undefined/handlers/__init__.py b/src/Undefined/handlers/__init__.py new file mode 100644 index 00000000..631cda8e --- /dev/null +++ b/src/Undefined/handlers/__init__.py @@ -0,0 +1,28 @@ +"""消息处理和命令分发包。 + +聚合 ``MessageHandler`` 与各 mixin 子模块;保留 ``parse_message_content_for_history`` 等 +模块级符号供测试 monkeypatch(``import Undefined.handlers``)。 +""" + +from Undefined.handlers.message_flow import ( + KEYWORD_REPLY_HISTORY_PREFIX, + MessageHandler, +) +from Undefined.handlers.poke import GroupPokeRecord, PrivatePokeRecord +from Undefined.handlers.repeat import REPEAT_REPLY_HISTORY_PREFIX +from Undefined.utils.common import ( + extract_text, + matches_xinliweiyuan, + parse_message_content_for_history, +) + +__all__ = [ + "GroupPokeRecord", + "KEYWORD_REPLY_HISTORY_PREFIX", + "MessageHandler", + "PrivatePokeRecord", + "REPEAT_REPLY_HISTORY_PREFIX", + "extract_text", + "matches_xinliweiyuan", + "parse_message_content_for_history", +] diff --git a/src/Undefined/handlers/auto_extract.py b/src/Undefined/handlers/auto_extract.py new file mode 100644 index 00000000..8375f8cb --- /dev/null +++ b/src/Undefined/handlers/auto_extract.py @@ -0,0 +1,223 @@ +"""B 站 / arXiv / GitHub 链接自动提取 mixin。 + +从消息中解析外部资源 ID 并调用对应 sender 发送;由 ``MessageHandler`` 混入使用。 +""" + +from __future__ import annotations + 
+import logging +from typing import TYPE_CHECKING, Any + +if TYPE_CHECKING: + from Undefined.config import Config + from Undefined.onebot import OneBotClient + from Undefined.utils.sender import MessageSender + +logger = logging.getLogger(__name__) + + +class AutoExtractMixin: + """外部资源自动提取 mixin。""" + + if TYPE_CHECKING: + config: Config + sender: MessageSender + onebot: OneBotClient + + async def _extract_bilibili_ids( + self, text: str, message_content: list[dict[str, Any]] + ) -> list[str]: + """从文本和消息段中提取 B 站视频 BV 号。""" + from Undefined.bilibili.parser import ( + extract_bilibili_ids, + extract_from_json_message, + ) + + bvids = await extract_bilibili_ids(text) + if not bvids: + bvids = await extract_from_json_message(message_content) + return list(bvids) + + def _extract_arxiv_ids( + self, text: str, message_content: list[dict[str, Any]] + ) -> list[str]: + """从文本和消息段中提取 arXiv 论文 ID。""" + from Undefined.arxiv.parser import extract_arxiv_ids, extract_from_json_message + + paper_ids: list[str] = [] + seen: set[str] = set() + + for paper_id in extract_arxiv_ids(text): + if paper_id in seen: + continue + seen.add(paper_id) + paper_ids.append(paper_id) + + for paper_id in extract_from_json_message(message_content): + if paper_id in seen: + continue + seen.add(paper_id) + paper_ids.append(paper_id) + + return paper_ids + + def _extract_github_repo_ids( + self, text: str, message_content: list[dict[str, Any]] + ) -> list[str]: + """从文本和消息段中提取 GitHub 仓库 ID。""" + from Undefined.github.parser import ( + extract_from_json_message, + extract_github_repo_ids, + ) + + repo_ids: list[str] = [] + seen: set[str] = set() + + for repo_id in extract_github_repo_ids(text): + # 仓库 ID 大小写不敏感去重 + key = repo_id.lower() + if key in seen: + continue + seen.add(key) + repo_ids.append(repo_id) + + for repo_id in extract_from_json_message(message_content): + key = repo_id.lower() + if key in seen: + continue + seen.add(key) + repo_ids.append(repo_id) + + return repo_ids + + async def 
_handle_bilibili_extract( + self, + target_id: int, + bvids: list[str], + target_type: str, + ) -> None: + """处理 bilibili 视频自动提取和发送。""" + from Undefined.bilibili.sender import send_bilibili_video + + for bvid in bvids[:3]: + try: + # 单条消息最多自动提取 3 个 BV + await send_bilibili_video( + video_id=bvid, + sender=self.sender, + onebot=self.onebot, + target_type=target_type, # type: ignore[arg-type] + target_id=target_id, + cookie=self.config.bilibili_cookie, + prefer_quality=self.config.bilibili_prefer_quality, + max_duration=self.config.bilibili_max_duration, + max_file_size=self.config.bilibili_max_file_size, + oversize_strategy=self.config.bilibili_oversize_strategy, + danmaku_enabled=self.config.bilibili_danmaku_enabled, + danmaku_batch_size=self.config.bilibili_danmaku_batch_size, + danmaku_max_count=self.config.bilibili_danmaku_max_count, + ) + except Exception as exc: + logger.error( + "[Bilibili] 自动提取失败 %s → %s:%s: %s", + bvid, + target_type, + target_id, + exc, + ) + try: + error_msg = f"视频提取失败: {exc}" + if target_type == "group": + await self.sender.send_group_message(target_id, error_msg) + else: + await self.sender.send_private_message(target_id, error_msg) + except Exception: + pass + + async def _handle_arxiv_extract( + self, + target_id: int, + paper_ids: list[str], + target_type: str, + ) -> None: + """处理 arXiv 论文自动提取和发送。""" + from Undefined.arxiv.sender import send_arxiv_paper + + max_items = max(1, int(self.config.arxiv_auto_extract_max_items)) + + for paper_id in paper_ids[:max_items]: + try: + result = await send_arxiv_paper( + paper_id=paper_id, + sender=self.sender, + target_type=target_type, # type: ignore[arg-type] + target_id=target_id, + max_file_size=self.config.arxiv_max_file_size, + author_preview_limit=self.config.arxiv_author_preview_limit, + summary_preview_chars=self.config.arxiv_summary_preview_chars, + context={ + "request_id": ( + f"arxiv_auto_extract:{target_type}:{target_id}:{paper_id}" + ) + }, + ) + logger.info( + "[arXiv] 自动提取完成 %s 
→ %s:%s: %s", + paper_id, + target_type, + target_id, + result, + ) + except Exception: + logger.exception( + "[arXiv] 自动提取失败 %s → %s:%s", + paper_id, + target_type, + target_id, + ) + + async def _handle_github_extract( + self, + target_id: int, + repo_ids: list[str], + target_type: str, + ) -> None: + """处理 GitHub 仓库自动提取和发送。""" + from Undefined.github.sender import send_github_repo_card + + max_items = max( + 1, int(getattr(self.config, "github_auto_extract_max_items", 3)) + ) + request_timeout = float( + getattr(self.config, "github_request_timeout_seconds", 10.0) + ) + + for repo_id in repo_ids[:max_items]: + try: + result = await send_github_repo_card( + repo_id=repo_id, + sender=self.sender, + target_type=target_type, # type: ignore[arg-type] + target_id=target_id, + request_timeout=request_timeout, + context={ + "request_id": ( + f"github_auto_extract:{target_type}:{target_id}:{repo_id}" + ) + }, + ) + logger.info( + "[GitHub] 自动提取完成 %s → %s:%s: %s", + repo_id, + target_type, + target_id, + result, + ) + except Exception as exc: + logger.info( + "[GitHub] 自动提取跳过 %s → %s:%s: %s", + repo_id, + target_type, + target_id, + exc, + ) diff --git a/src/Undefined/handlers/message_flow.py b/src/Undefined/handlers/message_flow.py new file mode 100644 index 00000000..b9b2a5a5 --- /dev/null +++ b/src/Undefined/handlers/message_flow.py @@ -0,0 +1,845 @@ +"""消息主流程与 ``MessageHandler`` 核心实现。 + +协调私聊/群聊事件分发、附件收集、管线与 AI 回复;拍一拍、复读、自动提取由 mixin 提供。 +""" + +from __future__ import annotations + +import asyncio +import logging +import os +from pathlib import Path +import random +from typing import Any, Coroutine, Literal + +import Undefined.handlers as handlers_module +from Undefined.attachments import ( + append_attachment_text, + build_attachment_scope, + register_message_attachments, +) +from Undefined.ai import AIClient +from Undefined.config import Config +from Undefined.faq import FAQStorage +from Undefined.handlers.auto_extract import AutoExtractMixin +from 
Undefined.handlers.poke import PokeMixin +from Undefined.handlers.repeat import RepeatMixin +from Undefined.onebot import ( + OneBotClient, + get_message_content, + get_message_sender_id, +) +from Undefined.rate_limit import RateLimiter +from Undefined.scheduled_task_storage import ScheduledTaskStorage +from Undefined.services.ai_coordinator import AICoordinator +from Undefined.services.command import CommandDispatcher +from Undefined.services.message_batcher import MessageBatcher, make_scope +from Undefined.services.model_pool import ModelPoolService +from Undefined.services.queue_manager import QueueManager +from Undefined.services.security import SecurityService +from Undefined.skills.pipelines import PipelineRegistry +from Undefined.skills.pipelines.context import build_pipeline_context +from Undefined.utils.coerce import safe_int +from Undefined.utils.fake_at import BotNicknameCache, strip_fake_at +from Undefined.utils.history import MessageHistoryManager +from Undefined.utils.logging import log_debug_json, redact_string +from Undefined.utils.queue_intervals import build_model_queue_intervals +from Undefined.utils.resources import resolve_resource_path +from Undefined.utils.scheduler import TaskScheduler +from Undefined.utils.sender import MessageSender + +logger = logging.getLogger(__name__) + +KEYWORD_REPLY_HISTORY_PREFIX = "[系统关键词自动回复] " + + +def _is_private_model_pool_control_text(text: str) -> bool: + return bool(ModelPoolService.is_private_control_text(text)) + + +class MessageHandler(PokeMixin, RepeatMixin, AutoExtractMixin): + """消息处理器。 + + 接收 OneBot 事件、写入历史并协调命令分发、自动管线与 AI 回复; + 依赖 ``AICoordinator``、``CommandDispatcher`` 与 ``MessageBatcher``。 + """ + + def __init__( + self, + config: Config, + onebot: OneBotClient, + ai: AIClient, + faq_storage: FAQStorage, + task_storage: ScheduledTaskStorage, + ) -> None: + self.config = config + self.onebot = onebot + self.ai = ai + self.faq_storage = faq_storage + self.history_manager = 
MessageHistoryManager(config.history_max_records) + self.sender = MessageSender( + onebot, + self.history_manager, + config.bot_qq, + config, + attachment_registry=getattr(ai, "attachment_registry", None), + ) + + self.security = SecurityService(config, ai._http_client) + self.rate_limiter = RateLimiter(config) + self.queue_manager = QueueManager( + max_retries=config.ai_request_max_retries, + ) + self.queue_manager.update_model_intervals(build_model_queue_intervals(config)) + + ai.set_queue_manager(self.queue_manager) + + self.command_dispatcher = CommandDispatcher( + config, + self.sender, + ai, + faq_storage, + onebot, + self.security, + queue_manager=self.queue_manager, + rate_limiter=self.rate_limiter, + history_manager=self.history_manager, + ) + self.ai_coordinator = AICoordinator( + config, + ai, + self.queue_manager, + self.history_manager, + self.sender, + onebot, + TaskScheduler(ai, self.sender, onebot, self.history_manager, task_storage), + self.security, + command_dispatcher=self.command_dispatcher, + ) + + self.message_batcher = MessageBatcher( + config.message_batcher, + flush_callback=self.ai_coordinator.handle_batched_dispatch, + ) + self.ai_coordinator.set_batcher(self.message_batcher) + + self._background_tasks: set[asyncio.Task[None]] = set() + self._profile_name_refresh_cache: dict[tuple[str, int], str] = {} + self._bot_nickname_cache = BotNicknameCache(onebot, config.bot_qq) + self.pipeline_registry = PipelineRegistry() + self._pipelines_initialized = False + self._pipelines_init_lock = asyncio.Lock() + + self._repeat_counter: dict[int, list[tuple[str, int]]] = {} + self._repeat_locks: dict[int, asyncio.Lock] = {} + self._repeat_cooldown: dict[int, dict[str, float]] = {} + + self.ai_coordinator.queue_manager.start(self.ai_coordinator.execute_reply) + + async def initialize(self) -> None: + """完成需要事件循环承载的异步初始化。""" + await self.init_pipelines() + + async def init_pipelines(self) -> None: + """异步加载自动处理管线并按配置启动热重载。""" + if getattr(self, 
"_pipelines_initialized", False): + return + init_lock = getattr(self, "_pipelines_init_lock", None) + if init_lock is None: + init_lock = asyncio.Lock() + self._pipelines_init_lock = init_lock + async with init_lock: + if getattr(self, "_pipelines_initialized", False): + return + await self.pipeline_registry.load_items_async() + self._pipelines_initialized = True + if getattr(self.config, "skills_hot_reload", False): + self.pipeline_registry.start_hot_reload( + interval=self.config.skills_hot_reload_interval, + debounce=self.config.skills_hot_reload_debounce, + ) + + async def _annotate_meme_descriptions( + self, + attachments: list[dict[str, str]], + scope_key: str, + ) -> list[dict[str, str]]: + """为图片附件添加表情包描述(如果在表情库中找到)。 + + 采用批量查询:收集所有 SHA256 哈希值,一次性查询,然后映射结果。 + 最佳努力:任何失败时返回原始列表。 + """ + if not attachments: + return attachments + + ai_client = getattr(self, "ai", None) + if ai_client is None: + return attachments + + attachment_registry = getattr(ai_client, "attachment_registry", None) + if attachment_registry is None: + return attachments + + meme_service = getattr(ai_client, "_meme_service", None) + if meme_service is None or not getattr(meme_service, "enabled", False): + return attachments + + meme_store = getattr(meme_service, "_store", None) + if meme_store is None: + return attachments + + try: + # 仅 pic_ 前缀图片参与表情库匹配 + uid_to_hash: dict[str, str] = {} + for att in attachments: + uid = att.get("uid", "") + if not uid.startswith("pic_"): + continue + record = attachment_registry.resolve(uid, scope_key) + if record and record.sha256: + uid_to_hash[uid] = record.sha256 + + if not uid_to_hash: + return attachments + + unique_hashes = set(uid_to_hash.values()) + hash_to_desc: dict[str, str] = {} + for h in unique_hashes: + meme = await meme_store.find_by_sha256(h) + if meme and meme.description: + hash_to_desc[h] = meme.description + + if not hash_to_desc: + return attachments + + result: list[dict[str, str]] = [] + for att in attachments: + uid = 
att.get("uid", "") + sha = uid_to_hash.get(uid, "") + desc = hash_to_desc.get(sha, "") + if desc: + new_att = dict(att) + new_att["description"] = f"[表情包] {desc}" + result.append(new_att) + else: + result.append(att) + return result + except Exception: + logger.warning("表情包自动匹配失败,跳过", exc_info=True) + return attachments + + async def _collect_message_attachments( + self, + message_content: list[dict[str, Any]], + *, + group_id: int | None = None, + user_id: int | None = None, + request_type: str, + ) -> list[dict[str, str]]: + scope_key = build_attachment_scope( + group_id=group_id, + user_id=user_id, + request_type=request_type, + ) + if not scope_key: + return [] + ai_client = getattr(self, "ai", None) + attachment_registry = ( + getattr(ai_client, "attachment_registry", None) if ai_client else None + ) + if attachment_registry is None: + return [] + onebot = getattr(self, "onebot", None) + resolve_image_url = getattr(onebot, "get_image", None) if onebot else None + result = await register_message_attachments( + registry=attachment_registry, + segments=message_content, + scope_key=scope_key, + resolve_image_url=resolve_image_url, + get_forward_messages=getattr(onebot, "get_forward_msg", None) + if onebot + else None, + ) + attachments = result.attachments + # 命中表情库时为 AI 上下文补充 [表情包] 描述 + attachments = await self._annotate_meme_descriptions(attachments, scope_key) + return attachments + + def _schedule_meme_ingest( + self, + *, + attachments: list[dict[str, str]], + chat_type: str, + chat_id: int, + sender_id: int, + message_id: int | None, + scope_key: str | None, + ) -> None: + # 后台异步入库,不阻塞主消息处理 + if not attachments or not scope_key: + return + meme_service = getattr(self.ai, "_meme_service", None) + if meme_service is None or not getattr(meme_service, "enabled", False): + return + self._spawn_background_task( + f"meme_ingest:{chat_type}:{chat_id}:{sender_id}:{message_id or 0}", + meme_service.enqueue_incoming_attachments( + attachments=attachments, + 
chat_type=chat_type, + chat_id=chat_id, + sender_id=sender_id, + message_id=message_id, + scope_key=scope_key, + ), + ) + + async def _refresh_profile_display_names( + self, + *, + sender_id: int | None = None, + sender_name: str = "", + group_id: int | None = None, + group_name: str = "", + ) -> None: + ai_client = getattr(self, "ai", None) + cognitive_service = getattr(ai_client, "_cognitive_service", None) + if not cognitive_service or not getattr(cognitive_service, "enabled", False): + return + + if sender_id and sender_name.strip(): + await cognitive_service.sync_profile_display_name( + entity_type="user", + entity_id=str(sender_id), + preferred_name=sender_name.strip(), + ) + if group_id and group_name.strip(): + await cognitive_service.sync_profile_display_name( + entity_type="group", + entity_id=str(group_id), + preferred_name=group_name.strip(), + ) + + def _can_refresh_profile_display_names(self) -> bool: + ai_client = getattr(self, "ai", None) + cognitive_service = getattr(ai_client, "_cognitive_service", None) + return bool(cognitive_service and getattr(cognitive_service, "enabled", False)) + + def _schedule_profile_display_name_refresh( + self, + *, + task_name: str, + sender_id: int | None = None, + sender_name: str = "", + group_id: int | None = None, + group_name: str = "", + ) -> None: + if not self._can_refresh_profile_display_names(): + return + + cache = getattr(self, "_profile_name_refresh_cache", None) + if cache is None: + cache = {} + self._profile_name_refresh_cache = cache + + updates: dict[str, Any] = {} + rollback: list[tuple[tuple[str, int], str | None]] = [] + + normalized_sender_name = sender_name.strip() + if sender_id and normalized_sender_name: + sender_key = ("user", int(sender_id)) + previous = cache.get(sender_key) + if previous != normalized_sender_name: + cache[sender_key] = normalized_sender_name + rollback.append((sender_key, previous)) + updates["sender_id"] = sender_id + updates["sender_name"] = normalized_sender_name + + 
normalized_group_name = group_name.strip() + if group_id and normalized_group_name: + group_key = ("group", int(group_id)) + previous = cache.get(group_key) + if previous != normalized_group_name: + cache[group_key] = normalized_group_name + rollback.append((group_key, previous)) + updates["group_id"] = group_id + updates["group_name"] = normalized_group_name + + if not updates: + return + + async def _run_refresh() -> None: + try: + await self._refresh_profile_display_names(**updates) + except Exception: + # 刷新失败时回滚内存缓存,避免脏昵称长期生效 + for key, previous in rollback: + if previous is None: + cache.pop(key, None) + else: + cache[key] = previous + raise + + self._spawn_background_task(task_name, _run_refresh()) + + async def handle_message(self, event: dict[str, Any]) -> None: + """处理收到的消息事件。""" + if logger.isEnabledFor(logging.DEBUG): + log_debug_json(logger, "[事件数据]", event) + post_type = event.get("post_type", "message") + + # 拍一拍走 notice 旁路,不进入普通消息流水线 + if post_type == "notice" and event.get("notice_type") == "poke": + await self._handle_poke_notice(event) + return + + if event.get("message_type") == "private": + await self._handle_private_message(event) + return + + if event.get("message_type") != "group": + return + + await self._handle_group_message(event) + + async def _handle_private_message(self, event: dict[str, Any]) -> None: + """处理私聊消息事件。""" + private_sender_id: int = get_message_sender_id(event) + private_message_content: list[dict[str, Any]] = get_message_content(event) + trigger_message_id = event.get("message_id") + + if not self.config.is_private_allowed(private_sender_id): + private_reason = ( + self.config.private_access_denied_reason(private_sender_id) or "unknown" + ) + logger.debug( + "[访问控制] 忽略私聊消息: user=%s reason=%s (access enabled=%s)", + private_sender_id, + private_reason, + self.config.access_control_enabled(), + ) + return + + private_sender: dict[str, Any] = event.get("sender", {}) + private_sender_nickname: str = 
private_sender.get("nickname", "") + + user_name = private_sender_nickname + if not user_name: + try: + user_info = await self.onebot.get_stranger_info(private_sender_id) + if user_info: + user_name = user_info.get("nickname", "") + except Exception as exc: + logger.warning("获取用户昵称失败: %s", exc) + + text = handlers_module.extract_text(private_message_content, self.config.bot_qq) + private_attachments, parsed_content_raw = await asyncio.gather( + self._collect_message_attachments( + private_message_content, + user_id=private_sender_id, + request_type="private", + ), + handlers_module.parse_message_content_for_history( + private_message_content, + self.config.bot_qq, + self.onebot.get_msg, + self.onebot.get_forward_msg, + ), + ) + safe_text = redact_string(text) + logger.info( + "[私聊消息] 发送者=%s 昵称=%s 内容=%s", + private_sender_id, + user_name or private_sender_nickname, + safe_text[:100], + ) + resolved_private_name = (user_name or private_sender_nickname or "").strip() + self._schedule_profile_display_name_refresh( + task_name=f"profile_name_refresh_private:{private_sender_id}", + sender_id=private_sender_id, + sender_name=resolved_private_name, + ) + + parsed_content = append_attachment_text(parsed_content_raw, private_attachments) + safe_parsed = redact_string(parsed_content) + logger.debug( + "[历史记录] 保存私聊: user=%s content=%s...", + private_sender_id, + safe_parsed[:50], + ) + await self.history_manager.add_private_message( + user_id=private_sender_id, + text_content=parsed_content, + display_name=private_sender_nickname, + user_name=user_name, + message_id=trigger_message_id, + attachments=private_attachments, + ) + + # 机器人自身消息只写历史,不触发后续自动回复/入库 + if private_sender_id == self.config.bot_qq: + return + + self._schedule_meme_ingest( + attachments=private_attachments, + chat_type="private", + chat_id=private_sender_id, + sender_id=private_sender_id, + message_id=safe_int(trigger_message_id), + scope_key=build_attachment_scope( + user_id=private_sender_id, + 
async def _handle_group_message(self, event: dict[str, Any]) -> None:
    """Handle one group-chat message event end to end.

    Stages, in order:

    1. Drop the event when the group is not allowed by access control.
    2. Collect attachments, the group name and the history rendering of the
       message concurrently, then write the message to group history.
    3. For the bot's own messages: only feed the repeat counter and stop.
    4. Schedule meme ingest, detect a real or "fake" (plain-text) @-mention,
       and apply the group message policy.
    5. Try, in order: slash command (only when the bot is mentioned), keyword
       reply, repeat, then pipelines followed by the AI auto reply.

    Args:
        event: Raw message event dict; ``group_id``, ``sender`` and
            ``message_id`` are read from it directly.
    """
    group_id: int = event.get("group_id", 0)
    sender_id: int = get_message_sender_id(event)
    message_content: list[dict[str, Any]] = get_message_content(event)
    trigger_message_id = event.get("message_id")

    if not self.config.is_group_allowed(group_id):
        group_reason = self.config.group_access_denied_reason(group_id) or "unknown"
        logger.debug(
            "[访问控制] 忽略群消息: group=%s sender=%s reason=%s (access enabled=%s)",
            group_id,
            sender_id,
            group_reason,
            self.config.access_control_enabled(),
        )
        return

    group_sender: dict[str, Any] = event.get("sender", {})
    sender_card: str = group_sender.get("card", "")
    sender_nickname: str = group_sender.get("nickname", "")
    sender_role: str = group_sender.get("role", "member")
    sender_title: str = group_sender.get("title", "")
    sender_level: str = str(group_sender.get("level", "")).strip()

    text = handlers_module.extract_text(message_content, self.config.bot_qq)
    safe_text = redact_string(text)
    logger.info(
        f"[群消息] group={group_id} sender={sender_id} name={sender_card or sender_nickname} "
        f"role={sender_role} | {safe_text[:100]}"
    )

    async def _fetch_group_name() -> str:
        # Best effort: a failed lookup must not block message handling.
        try:
            info = await self.onebot.get_group_info(group_id)
            if info:
                return str(info.get("group_name", "") or "")
        except Exception as e:
            logger.warning(f"获取群聊名失败: {e}")
        return ""

    # The three lookups are independent of each other, so run them together.
    group_attachments, group_name, parsed_content_raw = await asyncio.gather(
        self._collect_message_attachments(
            message_content,
            group_id=group_id,
            request_type="group",
        ),
        _fetch_group_name(),
        handlers_module.parse_message_content_for_history(
            message_content,
            self.config.bot_qq,
            self.onebot.get_msg,
            self.onebot.get_forward_msg,
        ),
    )

    resolved_group_sender_name = (sender_card or sender_nickname or "").strip()
    self._schedule_profile_display_name_refresh(
        task_name=f"profile_name_refresh_group:{group_id}:{sender_id}",
        sender_id=sender_id,
        sender_name=resolved_group_sender_name,
        group_id=group_id,
        group_name=str(group_name or "").strip(),
    )

    parsed_content = append_attachment_text(parsed_content_raw, group_attachments)
    safe_parsed = redact_string(parsed_content)
    logger.debug(
        f"[历史记录] 保存群聊: group={group_id}, sender={sender_id}, content={safe_parsed[:50]}..."
    )
    await self.history_manager.add_group_message(
        group_id=group_id,
        sender_id=sender_id,
        text_content=parsed_content,
        sender_card=sender_card,
        sender_nickname=sender_nickname,
        group_name=group_name,
        role=sender_role,
        title=sender_title,
        level=sender_level,
        message_id=trigger_message_id,
        attachments=group_attachments,
    )

    # The bot's own messages count towards the repeat counter so that the bot
    # never echoes itself; nothing else is done for them.
    if sender_id == self.config.bot_qq:
        await self._append_bot_repeat_counter(group_id, text)
        return

    self._schedule_meme_ingest(
        attachments=group_attachments,
        chat_type="group",
        chat_id=group_id,
        sender_id=sender_id,
        message_id=safe_int(trigger_message_id),
        scope_key=build_attachment_scope(group_id=group_id, request_type="group"),
    )

    is_at_bot = self.ai_coordinator._is_at_bot(message_content)

    # No real CQ "at" segment: look for a "fake @", i.e. the bot's nickname
    # typed as plain text after an at sign.
    is_fake_at = False
    normalized_text = text
    # "\uff20" is the fullwidth commercial at; it is spelled as an escape so the
    # check cannot silently collapse into a duplicate of the ASCII "@" test.
    if not is_at_bot and ("@" in text or "\uff20" in text):
        nicknames = await self._bot_nickname_cache.get_nicknames(group_id)
        if nicknames:
            is_fake_at, normalized_text = strip_fake_at(text, nicknames)
            if is_fake_at:
                is_at_bot = True
                logger.info(
                    "[假@] 识别到假@: group=%s sender=%s",
                    group_id,
                    sender_id,
                )

    if not self.config.should_process_group_message(is_at_bot=is_at_bot):
        logger.debug(
            "[消息策略] 跳过群消息处理: group=%s sender=%s process_every_message=%s at_bot=%s",
            group_id,
            sender_id,
            self.config.process_every_message,
            is_at_bot,
        )
        return

    # Slash commands only apply when the bot is mentioned; ordinary group chat
    # is never intercepted as a command.
    if is_at_bot:
        command = self.command_dispatcher.parse_command(normalized_text)
        if command:
            await self._flush_command_buffer(
                scope=make_scope(group_id=group_id),
                sender_id=sender_id,
            )
            await self.command_dispatcher.dispatch(group_id, sender_id, command)
            return

    if self.config.keyword_reply_enabled and handlers_module.matches_xinliweiyuan(
        text
    ):
        if await self._handle_keyword_reply(group_id, sender_id):
            return

    # A triggered repeat replaces both the pipelines and the AI auto reply.
    if await self._maybe_trigger_repeat(group_id, sender_id, text):
        return

    await self._run_pipelines(
        target_id=group_id,
        target_type="group",
        text=text,
        message_content=message_content,
    )

    display_name = sender_card or sender_nickname or str(sender_id)
    await self.ai_coordinator.handle_auto_reply(
        group_id,
        sender_id,
        normalized_text,
        message_content,
        attachments=group_attachments,
        sender_name=display_name,
        group_name=group_name,
        sender_role=sender_role,
        sender_title=sender_title,
        sender_level=sender_level,
        trigger_message_id=trigger_message_id,
        is_fake_at=is_fake_at,
    )
await batcher.flush_sender(scope, sender_id) + if not flushed: + logger.warning( + "[MessageBatcher] 命令触发 flush 当前 buffer 失败: scope=%s sender=%s", + scope, + sender_id, + ) + + async def _run_pipelines( + self, + *, + target_id: int, + target_type: Literal["group", "private"], + text: str, + message_content: list[dict[str, Any]], + ) -> bool: + """并行检测并处理所有命中的自动处理管线。""" + if not getattr(self, "_pipelines_initialized", False): + await self.init_pipelines() + context = build_pipeline_context( + self, + target_id=target_id, + target_type=target_type, + text=text, + message_content=message_content, + ) + detections = await self.pipeline_registry.run(context) + return bool(detections) + + async def apply_skills_hot_reload_config( + self, + *, + enabled: bool, + interval: float, + debounce: float, + ) -> None: + """跟随全局 skills 热重载配置更新管线。""" + if not enabled: + await self.pipeline_registry.stop_hot_reload() + logger.info("[pipelines] 热重载已随配置禁用") + return + + await self.pipeline_registry.stop_hot_reload() + self.pipeline_registry.start_hot_reload( + interval=interval, + debounce=debounce, + ) + + def _spawn_background_task( + self, + name: str, + coroutine: Coroutine[Any, Any, None], + ) -> None: + task = asyncio.create_task(coroutine, name=name) + self._background_tasks.add(task) + + def _finalize(done_task: asyncio.Task[None]) -> None: + self._background_tasks.discard(done_task) + try: + exc = done_task.exception() + except asyncio.CancelledError: + logger.debug("[后台任务] 已取消: %s", name) + return + if exc is not None: + logger.exception( + "[后台任务] 执行失败: name=%s", + name, + exc_info=(type(exc), exc, exc.__traceback__), + ) + + task.add_done_callback(_finalize) + + async def close(self) -> None: + """关闭消息处理器。""" + logger.info("正在关闭消息处理器...") + if self._background_tasks: + logger.info( + "[后台任务] 等待自动提取任务收敛: count=%s", + len(self._background_tasks), + ) + await asyncio.gather( + *list(self._background_tasks), + return_exceptions=True, + ) + await 
self.pipeline_registry.stop_hot_reload() + await self.message_batcher.flush_all() + # 关闭前排空 AI 队列并落盘历史,避免丢回复/丢记录 + await self.ai_coordinator.queue_manager.drain() + await self.ai_coordinator.queue_manager.stop() + await self.history_manager.flush_pending_saves() + logger.info("消息处理器已关闭") diff --git a/src/Undefined/handlers/poke.py b/src/Undefined/handlers/poke.py new file mode 100644 index 00000000..06ba2738 --- /dev/null +++ b/src/Undefined/handlers/poke.py @@ -0,0 +1,293 @@ +"""拍一拍(poke)通知处理 mixin。 + +负责私聊/群聊拍一拍历史写入与 AI 回复触发;由 ``MessageHandler`` 混入使用。 +""" + +from __future__ import annotations + +import logging +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any + +if TYPE_CHECKING: + from Undefined.config import Config + from Undefined.onebot import OneBotClient + from Undefined.services.ai_coordinator import AICoordinator + from Undefined.utils.history import MessageHistoryManager + +logger = logging.getLogger(__name__) + + +def _format_poke_history_text(display_name: str, user_id: int) -> str: + """格式化拍一拍历史文本。""" + return f"{display_name}(暱称)[{user_id}(QQ号)] 拍了拍你。" + + +@dataclass(frozen=True) +class PrivatePokeRecord: + """私聊拍一拍历史记录摘要。""" + + poke_text: str + sender_name: str + + +@dataclass(frozen=True) +class GroupPokeRecord: + """群聊拍一拍历史记录摘要。""" + + poke_text: str + sender_name: str + group_name: str + sender_role: str + sender_title: str + sender_level: str + + +class PokeMixin: + """拍一拍事件处理 mixin。""" + + if TYPE_CHECKING: + config: Config + onebot: OneBotClient + ai_coordinator: AICoordinator + history_manager: MessageHistoryManager + + def _schedule_profile_display_name_refresh( + self, + *, + task_name: str, + sender_id: int | None = None, + sender_name: str = "", + group_id: int | None = None, + group_name: str = "", + ) -> None: ... 
async def _handle_poke_notice(self, event: dict[str, Any]) -> None:
    """Handle a poke notice and trigger the matching private/group AI reply.

    Only pokes whose ``target_id`` is the bot are processed, and only while
    poke handling is enabled. A ``group_id`` of 0 marks a private poke. After
    the access-control check the poke is written to history and then handed
    to the AI coordinator with ``is_poke=True``.

    Args:
        event: Raw notice event dict; ``target_id``, ``group_id`` and
            ``user_id`` are read from it.
    """
    cfg = self.config

    # Pokes aimed at anyone other than the bot are ignored.
    poked_id = event.get("target_id", 0)
    if poked_id != cfg.bot_qq:
        logger.debug(
            "[通知] 忽略拍一拍目标非机器人: target=%s",
            poked_id,
        )
        return

    if not cfg.should_process_poke_message():
        logger.debug("[消息策略] 已关闭拍一拍处理,忽略此次 poke 事件")
        return

    poke_group_id: int = event.get("group_id", 0)
    poke_sender_id: int = event.get("user_id", 0)
    # group_id == 0 means the poke happened in a private chat.
    is_private = poke_group_id == 0

    if is_private:
        if not cfg.is_private_allowed(poke_sender_id):
            private_reason = (
                cfg.private_access_denied_reason(poke_sender_id) or "unknown"
            )
            logger.debug(
                "[访问控制] 忽略私聊拍一拍: user=%s reason=%s (access enabled=%s)",
                poke_sender_id,
                private_reason,
                cfg.access_control_enabled(),
            )
            return
    elif not cfg.is_group_allowed(poke_group_id):
        group_reason = cfg.group_access_denied_reason(poke_group_id) or "unknown"
        logger.debug(
            "[访问控制] 忽略群聊拍一拍: group=%s sender=%s reason=%s (access enabled=%s)",
            poke_group_id,
            poke_sender_id,
            group_reason,
            cfg.access_control_enabled(),
        )
        return

    logger.info(
        "[通知] 收到拍一拍: group=%s sender=%s",
        poke_group_id,
        poke_sender_id,
    )
    logger.debug("[通知] 拍一拍事件数据: %s", str(event)[:200])

    if is_private:
        private_poke = await self._record_private_poke_history(
            poke_sender_id, event
        )
        logger.info("[通知] 私聊拍一拍,触发私聊回复")
        # Pokes are passed straight to the AI coordinator, flagged as pokes.
        await self.ai_coordinator.handle_private_reply(
            poke_sender_id,
            private_poke.poke_text,
            [],
            is_poke=True,
            sender_name=private_poke.sender_name,
        )
        return

    group_poke = await self._record_group_poke_history(
        poke_group_id,
        poke_sender_id,
        event,
    )
    logger.info(
        "[通知] 群聊拍一拍,触发群聊回复: group=%s",
        poke_group_id,
    )
    await self.ai_coordinator.handle_auto_reply(
        poke_group_id,
        poke_sender_id,
        group_poke.poke_text,
        [],
        is_poke=True,
        sender_name=group_poke.sender_name,
        group_name=group_poke.group_name,
        sender_role=group_poke.sender_role,
        sender_title=group_poke.sender_title,
        sender_level=group_poke.sender_level,
    )
"")).strip() + sender_role = str(sender.get("role", "member")).strip() or "member" + sender_title = str(sender.get("title", "")).strip() + sender_level = str(sender.get("level", "")).strip() + + if not sender_card and not sender_nickname: + try: + member_info = await self.onebot.get_group_member_info( + group_id, sender_id + ) + if isinstance(member_info, dict): + sender_card = str(member_info.get("card", "")).strip() + sender_nickname = str(member_info.get("nickname", "")).strip() + sender_role = ( + str(member_info.get("role", "member")).strip() or "member" + ) + sender_title = str(member_info.get("title", "")).strip() + sender_level = str(member_info.get("level", "")).strip() + except Exception as exc: + logger.warning( + "[通知] 获取拍一拍群成员信息失败: group=%s user=%s err=%s", + group_id, + sender_id, + exc, + ) + + group_name = "" + try: + group_info = await self.onebot.get_group_info(group_id) + if isinstance(group_info, dict): + group_name = str(group_info.get("group_name", "")).strip() + except Exception as exc: + logger.warning( + "[通知] 获取拍一拍群名失败: group=%s err=%s", + group_id, + exc, + ) + + resolved_sender_name = (sender_card or sender_nickname).strip() + resolved_group_name = group_name.strip() + display_name = resolved_sender_name or f"QQ{sender_id}" + poke_text = _format_poke_history_text(display_name, sender_id) + normalized_group_name = resolved_group_name or f"群{group_id}" + self._schedule_profile_display_name_refresh( + task_name=f"profile_name_refresh_group_poke:{group_id}:{sender_id}", + sender_id=sender_id, + sender_name=resolved_sender_name, + group_id=group_id, + group_name=resolved_group_name, + ) + + try: + await self.history_manager.add_group_message( + group_id=group_id, + sender_id=sender_id, + text_content=poke_text, + sender_card=sender_card, + sender_nickname=sender_nickname, + group_name=normalized_group_name, + role=sender_role, + title=sender_title, + level=sender_level, + ) + except Exception as exc: + logger.warning( + "[历史记录] 写入群聊拍一拍失败: 
class RepeatMixin:
    """Group-chat repeat ("echo") counting and triggering, mixed into ``MessageHandler``.

    For every group the mixin keeps the most recent ``(text, sender_id)``
    pairs. When the last ``repeat_threshold`` entries carry the same text,
    come from that many distinct senders and do not include the bot, the bot
    repeats the text once and records a per-text cooldown for that group.
    """

    if TYPE_CHECKING:
        config: Config
        sender: MessageSender
        _repeat_counter: dict[int, list[tuple[str, int]]]
        _repeat_locks: dict[int, asyncio.Lock]
        _repeat_cooldown: dict[int, dict[str, float]]

    def _get_repeat_lock(self, group_id: int) -> asyncio.Lock:
        """Return the lock guarding *group_id*'s counter, creating it on first use."""
        lock = self._repeat_locks.get(group_id)
        if lock is None:
            lock = asyncio.Lock()
            self._repeat_locks[group_id] = lock
        return lock

    @staticmethod
    def _normalize_repeat_text(text: str) -> str:
        """Return *text* as a cooldown key, folding fullwidth ``?`` into ASCII ``?``.

        The fullwidth question mark (U+FF1F) is written as an escape sequence
        so that the replacement cannot degrade into a no-op of the form
        ``replace("?", "?")``.
        """
        return text.replace("\uff1f", "?")

    def _is_repeat_on_cooldown(self, group_id: int, text: str) -> bool:
        """Tell whether *text* was repeated in *group_id* within the cooldown window.

        Always ``False`` when ``repeat_cooldown_minutes`` is zero or negative.
        """
        cooldown_minutes = self.config.repeat_cooldown_minutes
        if cooldown_minutes <= 0:
            return False
        group_cd = self._repeat_cooldown.get(group_id)
        if not group_cd:
            return False
        key = self._normalize_repeat_text(text)
        last_time = group_cd.get(key)
        if last_time is None:
            return False
        return bool((time.monotonic() - last_time) < cooldown_minutes * 60)

    def _record_repeat_cooldown(self, group_id: int, text: str) -> None:
        """Stamp *text* as just repeated in *group_id*.

        Expired entries of the same group are pruned on the way so the table
        cannot grow without bound. Does nothing when the cooldown is disabled.
        """
        cooldown_seconds = self.config.repeat_cooldown_minutes * 60
        if cooldown_seconds <= 0:
            return
        key = self._normalize_repeat_text(text)
        group_cd = self._repeat_cooldown.setdefault(group_id, {})
        now = time.monotonic()
        expired = [k for k, ts in group_cd.items() if (now - ts) >= cooldown_seconds]
        for k in expired:
            del group_cd[k]
        group_cd[key] = now

    async def _append_bot_repeat_counter(self, group_id: int, text: str) -> None:
        """Record one of the bot's own messages in the repeat counter.

        A run that contains the bot never triggers a repeat, so feeding the
        bot's messages in keeps it from echoing itself. Does nothing when the
        feature is disabled or *text* is empty.
        """
        if not self.config.repeat_enabled or not text:
            return
        async with self._get_repeat_lock(group_id):
            counter = self._repeat_counter.setdefault(group_id, [])
            counter.append((text, self.config.bot_qq))
            n = self.config.repeat_threshold
            if len(counter) > n:
                # Keep only the newest n entries.
                self._repeat_counter[group_id] = counter[-n:]

    async def _maybe_trigger_repeat(
        self,
        group_id: int,
        sender_id: int,
        text: str,
    ) -> bool:
        """Record a message and repeat it when the repeat condition is met.

        Args:
            group_id: Group the message was sent in.
            sender_id: Sender of the message.
            text: Plain text of the message.

        Returns:
            ``True`` when a repeat message was sent, otherwise ``False``.
        """
        if not self.config.repeat_enabled or not text:
            return False

        n = self.config.repeat_threshold
        async with self._get_repeat_lock(group_id):
            counter = self._repeat_counter.setdefault(group_id, [])
            counter.append((text, sender_id))
            if len(counter) > n:
                self._repeat_counter[group_id] = counter[-n:]
                counter = self._repeat_counter[group_id]

            if len(counter) < n:
                return False

            last_n = counter[-n:]
            texts = [t for t, _ in last_n]
            senders = [s for _, s in last_n]
            # Trigger only when the last n messages share one text, come from
            # n distinct senders, and the bot is not one of them.
            if not (
                len(set(texts)) == 1
                and len(set(senders)) == n
                and self.config.bot_qq not in senders
            ):
                return False

            reply_text = texts[0]
            if self._is_repeat_on_cooldown(group_id, reply_text):
                # Still cooling down: reset the counter so the same text does
                # not keep re-testing the condition on every new message.
                self._repeat_counter[group_id] = []
                logger.debug(
                    "[复读] 冷却中跳过: group=%s text=%s",
                    group_id,
                    redact_string(reply_text)[:50],
                )
                return False

            if self.config.inverted_question_enabled:
                stripped = reply_text.strip()
                # A message made only of question marks (ASCII or fullwidth
                # U+FF1F) is answered with inverted ones. ``stripped`` must be
                # non-empty: the empty set is a subset of every set, which
                # would otherwise turn a whitespace-only text into "".
                if stripped and set(stripped) <= {"?", "\uff1f"}:
                    reply_text = "¿" * len(stripped)

            self._repeat_counter[group_id] = []
            self._record_repeat_cooldown(group_id, texts[0])
            logger.info(
                "[复读] 触发复读: group=%s text=%s",
                group_id,
                redact_string(reply_text)[:50],
            )
            await self.sender.send_group_message(
                group_id,
                reply_text,
                history_prefix=REPEAT_REPLY_HISTORY_PREFIX,
            )
            return True
def sample_frame_indices(total: int, n: int) -> list[int]:
    """Pick up to *n* evenly spaced frame indices out of *total* frames.

    When *n* is at least *total*, every index is returned. A request for one
    frame yields the first frame only; for two or more frames the first and
    the last frame are always part of the result. A non-positive *n* (with a
    positive *total*) yields an empty list.
    """
    if n >= total:
        return list(range(total))
    if n == 1:
        return [0]
    last = total - 1
    if n == 2:
        return [0, last]
    gaps = n - 1
    picks = (round(step * last / gaps) for step in range(n))
    # dict.fromkeys drops repeated indices while keeping first-seen order.
    return list(dict.fromkeys(picks))
+_sample_frame_indices = sample_frame_indices +_compose_grid = compose_grid + +__all__ = [ + "compose_grid", + "extract_gif_frames", + "guess_suffix", + "is_retryable_llm_error", + "normalize_tags", + "now_iso", + "sample_frame_indices", + "_compose_grid", + "_extract_gif_frames", + "_guess_suffix", + "_is_retryable_llm_error", + "_normalize_tags", + "_now_iso", + "_sample_frame_indices", +] diff --git a/src/Undefined/memes/ingest.py b/src/Undefined/memes/ingest.py new file mode 100644 index 00000000..dd157bf5 --- /dev/null +++ b/src/Undefined/memes/ingest.py @@ -0,0 +1,625 @@ +"""MemeService 入库与后台任务处理。""" + +from __future__ import annotations + +import asyncio +from collections.abc import Mapping +from dataclasses import replace +import hashlib +import logging +import mimetypes +from pathlib import Path +import shutil +from typing import TYPE_CHECKING, Any +from uuid import uuid4 + +from PIL import Image + +from Undefined.memes._image_utils import ( + _compose_grid, + _extract_gif_frames, + _guess_suffix, + _is_retryable_llm_error, + _normalize_tags, + _now_iso, +) +from Undefined.memes.models import ( + IngestDigestLockEntry, + MemeRecord, + MemeSourceRecord, + build_search_text, +) + +if TYPE_CHECKING: + from Undefined.memes.store import MemeStore + from Undefined.memes.vector_store import MemeVectorStore + +logger = logging.getLogger(__name__) + + +class MemeIngestMixin: + if TYPE_CHECKING: + _ai_client: Any | None + _attachment_registry: Any | None + _ingest_digest_locks: dict[str, Any] + _ingest_digest_locks_guard: asyncio.Lock + _job_queue: Any | None + _store: MemeStore + _vector_store: MemeVectorStore + + def _blob_dir(self) -> Path: ... + def _cfg(self) -> Any: ... + def _invalidate_global_image_cache(self, uid: str) -> None: ... + def _preview_dir(self) -> Path: ... + def _queue_enabled(self) -> bool: ... + async def delete_meme(self, uid: str) -> bool: ... + def enabled(self) -> bool: ... 
+ + async def _acquire_ingest_digest_lock(self, digest: str) -> IngestDigestLockEntry: + async with self._ingest_digest_locks_guard: + entry = self._ingest_digest_locks.get(digest) + if entry is None: + entry = IngestDigestLockEntry(lock=asyncio.Lock()) + self._ingest_digest_locks[digest] = entry + entry.users += 1 + try: + await entry.lock.acquire() + except BaseException: + await self._release_ingest_digest_lock_reference(digest, entry) + raise + return entry + + async def _release_ingest_digest_lock_reference( + self, + digest: str, + entry: IngestDigestLockEntry, + *, + release_lock: bool = False, + ) -> None: + if release_lock and entry.lock.locked(): + entry.lock.release() + async with self._ingest_digest_locks_guard: + entry.users = max(0, entry.users - 1) + current = self._ingest_digest_locks.get(digest) + if current is entry and entry.users == 0 and not entry.lock.locked(): + self._ingest_digest_locks.pop(digest, None) + + def _delete_file_if_exists(self, path: Path) -> None: + try: + path.unlink(missing_ok=True) + except OSError: + logger.debug("[memes] 删除文件失败: path=%s", path, exc_info=True) + + def _cleanup_gif_frame_files(self, uid: str) -> None: + """清理 GIF 多帧分析产生的临时帧文件 ({uid}_f{i}.png)。""" + preview_dir = self._preview_dir() + for frame_file in preview_dir.glob(f"{uid}_f*.png"): + try: + frame_file.unlink(missing_ok=True) + except OSError: + logger.debug( + "[memes] 删除帧文件失败: path=%s", frame_file, exc_info=True + ) + + async def _cleanup_meme_artifacts( + self, + *, + uid: str | None, + blob_path: Path, + preview_path: Path | None, + ) -> None: + if uid: + try: + await self._store.delete(uid) + self._invalidate_global_image_cache(uid) + except Exception: + logger.exception( + "[memes] 清理记录失败: uid=%s", + uid, + ) + try: + await self._vector_store.delete(uid) + except Exception: + logger.exception( + "[memes] 清理向量索引失败: uid=%s", + uid, + ) + await asyncio.to_thread(self._delete_file_if_exists, blob_path) + if preview_path is not None and preview_path != 
blob_path: + await asyncio.to_thread(self._delete_file_if_exists, preview_path) + if uid: + await asyncio.to_thread(self._cleanup_gif_frame_files, uid) + + async def enqueue_incoming_attachments( + self, + *, + attachments: list[dict[str, str]], + chat_type: str, + chat_id: int, + sender_id: int, + message_id: int | None, + scope_key: str, + ) -> None: + if not self.enabled() or not self._queue_enabled(): + return + cfg = self._cfg() + if chat_type == "group" and not bool(cfg.auto_ingest_group): + return + if chat_type == "private" and not bool(cfg.auto_ingest_private): + return + + for item in attachments: + media_type = str(item.get("media_type") or item.get("kind") or "").strip() + uid = str(item.get("uid") or "").strip() + if media_type != "image" or not uid: + continue + job = { + "request_id": f"meme_ingest_{uid}", + "kind": "ingest", + "attachment_uid": uid, + "scope_key": scope_key, + "chat_type": chat_type, + "chat_id": str(chat_id), + "sender_id": str(sender_id), + "message_id": str(message_id or ""), + "queued_at": _now_iso(), + } + queue = self._job_queue + if queue is None: + return + await queue.enqueue(job) + + async def enqueue_reanalyze(self, uid: str) -> str | None: + if not self._queue_enabled(): + return None + queue = self._job_queue + if queue is None: + return None + result = await queue.enqueue( + { + "request_id": f"meme_reanalyze_{uid}", + "kind": "reanalyze", + "uid": uid, + "queued_at": _now_iso(), + } + ) + return str(result) + + async def enqueue_reindex(self, uid: str) -> str | None: + if not self._queue_enabled(): + return None + queue = self._job_queue + if queue is None: + return None + result = await queue.enqueue( + { + "request_id": f"meme_reindex_{uid}", + "kind": "reindex", + "uid": uid, + "queued_at": _now_iso(), + } + ) + return str(result) + + async def process_job(self, job: Mapping[str, Any]) -> None: + kind = str(job.get("kind") or "").strip().lower() + if kind == "ingest": + await self._process_ingest_job(job) + return 
async def _process_reanalyze_job(self, job: Mapping[str, Any]) -> None:
    """Re-run the AI judge and describe stages for an already stored meme.

    The job is a no-op when its ``uid`` is empty or unknown. When the judge
    stage says the image is not a meme, the meme is deleted. When the
    describe stage yields neither a description nor tags, the stored record
    is left untouched. Otherwise the record's auto description, tags and
    search text are rewritten and the vector index is refreshed.

    Temporary frame files produced for multi-frame GIF analysis are removed
    in a single ``finally`` block, so they are cleaned up on every way out of
    the AI stages, including task cancellation.

    Args:
        job: Queue job mapping; only ``uid`` is read.

    Raises:
        RuntimeError: If no AI client is configured.
        Exception: Retryable LLM errors (as classified by
            ``_is_retryable_llm_error``) are re-raised for the caller.
    """
    uid = str(job.get("uid") or "").strip()
    if not uid:
        return
    record = await self._store.get(uid)
    if record is None:
        return
    if self._ai_client is None:
        raise RuntimeError("reanalyze requires ai_client")
    analyze_path: str | list[str] = (
        record.preview_path if record.preview_path else record.blob_path
    )
    # Multi-frame GIF mode: analyse the sampled frames instead of the preview.
    if record.is_animated:
        cfg = self._cfg()
        if str(getattr(cfg, "gif_analysis_mode", "grid")).lower() == "multi":
            analyze_path = await self._prepare_gif_multi_frames(
                Path(record.blob_path), uid
            )

    is_meme = False
    described: Any = None
    try:
        try:
            judgement = await self._ai_client.judge_meme_image(analyze_path)
        except Exception as exc:
            if _is_retryable_llm_error(exc):
                raise
            logger.exception(
                "[memes] judge stage failed during reanalyze: uid=%s err=%s", uid, exc
            )
            return
        is_meme = bool(judgement.get("is_meme", False))
        if is_meme:
            try:
                described = await self._ai_client.describe_meme_image(analyze_path)
            except Exception as exc:
                if _is_retryable_llm_error(exc):
                    raise
                logger.exception(
                    "[memes] describe stage failed during reanalyze: uid=%s err=%s",
                    uid,
                    exc,
                )
                return
    finally:
        # A list means per-frame temp files were written for this run; they
        # are only needed while the AI calls above are in flight.
        if isinstance(analyze_path, list):
            await asyncio.to_thread(self._cleanup_gif_frame_files, uid)

    if not is_meme:
        await self.delete_meme(uid)
        return

    auto_description = str(described.get("description") or "").strip()
    next_tags = _normalize_tags(described.get("tags"))
    if not auto_description and not next_tags:
        logger.warning(
            "[memes] reanalyze describe failed, skip update: uid=%s", uid
        )
        return
    next_record = replace(
        record,
        auto_description=auto_description,
        ocr_text="",
        tags=next_tags,
        search_text=build_search_text(
            manual_description=record.manual_description,
            auto_description=auto_description,
            ocr_text="",
            tags=next_tags,
            aliases=record.aliases,
        ),
        updated_at=_now_iso(),
    )
    saved = await self._store.upsert_record(next_record)
    self._invalidate_global_image_cache(saved.uid)
    await self._vector_store.upsert(saved)
attachment_uid, + file_size, + cfg.max_source_image_bytes, + ) + return + + digest = await asyncio.to_thread(self._hash_file, source_path) + # 同一 SHA256 并发入库串行化,避免重复 AI 判定 + digest_lock_entry = await self._acquire_ingest_digest_lock(digest) + try: + existing = await self._store.find_by_sha256(digest) + if existing is not None and not Path(existing.blob_path).is_file(): + logger.warning( + "[memes] 检测到孤儿记录,删除后重新入库: uid=%s blob_path=%s", + existing.uid, + existing.blob_path, + ) + await self._cleanup_meme_artifacts( + uid=existing.uid, + blob_path=Path(existing.blob_path), + preview_path=( + Path(existing.preview_path) if existing.preview_path else None + ), + ) + existing = await self._store.find_by_sha256(digest) + if existing is not None and not Path(existing.blob_path).is_file(): + raise RuntimeError( + f"stale meme record cleanup failed: uid={existing.uid}" + ) + source = MemeSourceRecord( + uid=existing.uid if existing is not None else "", + source_type="message_attachment", + chat_type=str(job.get("chat_type") or ""), + chat_id=str(job.get("chat_id") or ""), + sender_id=str(job.get("sender_id") or ""), + message_id=str(job.get("message_id") or ""), + attachment_uid=attachment_uid, + source_url=str(attachment.source_ref or ""), + seen_at=_now_iso(), + ) + if existing is not None: + # 内容已存在:仅追加来源记录并刷新向量索引 + await self._store.add_source(replace(source, uid=existing.uid)) + await self._vector_store.upsert(existing) + return + + with Image.open(source_path) as image: + width, height = image.size + is_animated = bool(getattr(image, "is_animated", False)) + if is_animated and not bool(cfg.allow_gif): + return + + uid = await self._generate_uid() + suffix = _guess_suffix(source_path, str(attachment.mime_type or "")) + blob_path = self._blob_dir() / f"{uid}{suffix}" + cleanup_preview_path = ( + self._preview_dir() / f"{uid}.png" if is_animated else blob_path + ) + persisted_uid: str | None = None + + try: + preview_path = await self._prepare_blob_and_preview( + 
source_path=source_path, + target_uid=uid, + suffix=suffix, + is_animated=is_animated, + ) + if preview_path is not None: + cleanup_preview_path = preview_path + mime_type = str( + attachment.mime_type + or mimetypes.guess_type(source_path.name)[0] + or "application/octet-stream" + ) + analyze_path: str | list[str] = str( + preview_path if preview_path is not None else blob_path + ) + if ( + is_animated + and str(getattr(cfg, "gif_analysis_mode", "grid")).lower() + == "multi" + ): + analyze_path = await self._prepare_gif_multi_frames( + source_path, uid + ) + try: + judgement = await self._ai_client.judge_meme_image(analyze_path) + except Exception as exc: + if _is_retryable_llm_error(exc): + if isinstance(analyze_path, list): + await asyncio.to_thread(self._cleanup_gif_frame_files, uid) + raise + logger.exception( + "[memes] judge stage failed, treat as non-meme: uid=%s err=%s", + uid, + exc, + ) + judgement = {"is_meme": False} + if not bool(judgement.get("is_meme", False)): + if isinstance(analyze_path, list): + await asyncio.to_thread(self._cleanup_gif_frame_files, uid) + # 非表情包:清理已落盘文件,不入库 + await self._cleanup_meme_artifacts( + uid=None, + blob_path=blob_path, + preview_path=cleanup_preview_path, + ) + return + + try: + described = await self._ai_client.describe_meme_image(analyze_path) + except Exception as exc: + if _is_retryable_llm_error(exc): + if isinstance(analyze_path, list): + await asyncio.to_thread(self._cleanup_gif_frame_files, uid) + raise + logger.exception( + "[memes] describe stage failed, drop uid=%s err=%s", uid, exc + ) + described = {"description": "", "tags": []} + # GIF 多帧文件用完即清理 + if isinstance(analyze_path, list): + await asyncio.to_thread(self._cleanup_gif_frame_files, uid) + tags = _normalize_tags(described.get("tags")) + auto_description = str(described.get("description") or "").strip() + if not auto_description and not tags: + logger.warning( + "[memes] describe stage returned empty result, drop uid=%s", uid + ) + await 
self._cleanup_meme_artifacts( + uid=None, + blob_path=blob_path, + preview_path=cleanup_preview_path, + ) + return + now = _now_iso() + record = MemeRecord( + uid=uid, + content_sha256=digest, + blob_path=str(blob_path), + preview_path=( + str(preview_path) if preview_path is not None else None + ), + mime_type=mime_type, + file_size=file_size, + width=width, + height=height, + is_animated=is_animated, + enabled=True, + pinned=False, + auto_description=auto_description, + manual_description="", + ocr_text="", + tags=tags, + aliases=[], + search_text=build_search_text( + manual_description="", + auto_description=auto_description, + ocr_text="", + tags=tags, + aliases=[], + ), + use_count=0, + last_used_at="", + created_at=now, + updated_at=now, + status="ready", + segment_data={"subType": "1"}, + ) + saved = await self._store.upsert_record(record) + self._invalidate_global_image_cache(saved.uid) + persisted_uid = saved.uid + await self._store.add_source(replace(source, uid=saved.uid)) + await self._vector_store.upsert(saved) + except Exception: + await self._cleanup_meme_artifacts( + uid=persisted_uid, + blob_path=blob_path, + preview_path=cleanup_preview_path, + ) + raise + finally: + await self._release_ingest_digest_lock_reference( + digest, + digest_lock_entry, + release_lock=True, + ) + await self._prune_if_needed() + + async def _prepare_blob_and_preview( + self, + *, + source_path: Path, + target_uid: str, + suffix: str, + is_animated: bool, + ) -> Path | None: + blob_path = self._blob_dir() / f"{target_uid}{suffix}" + + def _copy() -> None: + shutil.copy2(source_path, blob_path) + + await asyncio.to_thread(_copy) + if not is_animated: + return blob_path + + cfg = self._cfg() + mode = str(getattr(cfg, "gif_analysis_mode", "grid")).lower() + n_frames = max(2, int(getattr(cfg, "gif_analysis_frames", 6))) + preview_path = self._preview_dir() / f"{target_uid}.png" + + def _render_preview() -> None: + frames = _extract_gif_frames(source_path, n_frames) + if mode 
== "multi": + frames[0].save(preview_path, format="PNG") + else: + _compose_grid(frames, preview_path) + for f in frames: + f.close() + + await asyncio.to_thread(_render_preview) + return preview_path + + async def _prepare_gif_multi_frames( + self, source_path: Path, target_uid: str + ) -> list[str]: + """multi 模式:将 GIF 各帧单独保存为 PNG,返回路径列表。""" + cfg = self._cfg() + n_frames = max(2, int(getattr(cfg, "gif_analysis_frames", 6))) + preview_dir = self._preview_dir() + + def _render_frames() -> list[str]: + frames = _extract_gif_frames(source_path, n_frames) + paths: list[str] = [] + for i, frame in enumerate(frames): + p = preview_dir / f"{target_uid}_f{i}.png" + frame.save(p, format="PNG") + frame.close() + paths.append(str(p)) + return paths + + return await asyncio.to_thread(_render_frames) + + def _hash_file(self, path: Path) -> str: + hasher = hashlib.sha256() + with path.open("rb") as handle: + while True: + chunk = handle.read(1024 * 1024) + if not chunk: + break + hasher.update(chunk) + return hasher.hexdigest() + + async def _generate_uid(self) -> str: + while True: + candidate = f"pic_{uuid4().hex[:8]}" + if await self._store.get(candidate) is not None: + continue + if ( + self._attachment_registry is not None + and self._attachment_registry.get(candidate) is not None + ): + continue + return candidate + + async def _prune_if_needed(self) -> None: + stats = await self._store.stats() + cfg = self._cfg() + total_count = int(stats.get("total_count", 0)) + total_bytes = int(stats.get("total_bytes", 0)) + if total_count <= int(cfg.max_items) and total_bytes <= int( + cfg.max_total_bytes + ): + return + candidates = await self._store.list_prune_candidates() + for candidate in candidates: + if candidate.pinned: + continue + if total_count <= int(cfg.max_items) and total_bytes <= int( + cfg.max_total_bytes + ): + break + deleted = await self._store.delete(candidate.uid) + if deleted is None: + continue + self._invalidate_global_image_cache(candidate.uid) + await 
self._vector_store.delete(candidate.uid) + await asyncio.to_thread( + self._delete_file_if_exists, Path(deleted.blob_path) + ) + if deleted.preview_path and deleted.preview_path != deleted.blob_path: + await asyncio.to_thread( + self._delete_file_if_exists, + Path(deleted.preview_path), + ) + total_count -= 1 + total_bytes -= int(deleted.file_size) diff --git a/src/Undefined/memes/models.py b/src/Undefined/memes/models.py index 99a3c989..d4637535 100644 --- a/src/Undefined/memes/models.py +++ b/src/Undefined/memes/models.py @@ -1,5 +1,6 @@ from __future__ import annotations +import asyncio from dataclasses import dataclass, field from typing import Any @@ -51,6 +52,12 @@ class MemeSourceRecord: seen_at: str +@dataclass +class IngestDigestLockEntry: + lock: asyncio.Lock + users: int = 0 + + @dataclass(frozen=True) class MemeSearchItem: uid: str diff --git a/src/Undefined/memes/search.py b/src/Undefined/memes/search.py new file mode 100644 index 00000000..abd2be24 --- /dev/null +++ b/src/Undefined/memes/search.py @@ -0,0 +1,471 @@ +"""MemeService 检索与列表操作。""" + +from __future__ import annotations + +import asyncio +import logging +from pathlib import Path +from typing import TYPE_CHECKING, Any + + +from Undefined.attachments import AttachmentRecord +from Undefined.memes._image_utils import _now_iso +from Undefined.memes.models import ( + MemeRecord, + MemeSearchItem, +) +from Undefined.utils.message_targets import resolve_message_target +from Undefined.utils.coerce import safe_int + +if TYPE_CHECKING: + import threading + + from Undefined.memes.store import MemeStore + from Undefined.memes.vector_store import MemeVectorStore + +logger = logging.getLogger(__name__) + + +class MemeSearchMixin: + if TYPE_CHECKING: + _cfg: Any + _global_image_cache: dict[str, AttachmentRecord] + _global_image_cache_lock: threading.Lock + _job_queue: Any | None + _retrieval_runtime: Any | None + _store: MemeStore + _vector_store: MemeVectorStore + + def resolve_global_image_sync(self, uid: 
str) -> AttachmentRecord | None: + normalized_uid = str(uid or "").strip() + if not normalized_uid: + return None + + with self._global_image_cache_lock: + cached = self._global_image_cache.get(normalized_uid) + if cached is not None: + return cached + + record = self._store.get_sync(normalized_uid) + if record is None or not record.enabled or record.status != "ready": + return None + # scope_key 留空:表情库 UID 全局可解析,不受会话 scope 限制 + attachment = AttachmentRecord( + uid=record.uid, + scope_key="", + kind="image", + media_type="image", + display_name=Path(record.blob_path).name, + source_kind="meme_library", + source_ref=Path(record.blob_path).resolve().as_uri(), + local_path=record.blob_path, + mime_type=record.mime_type, + sha256=record.content_sha256, + created_at=record.created_at, + segment_data={"subType": "1"}, + semantic_kind="meme", + description=record.description, + ) + with self._global_image_cache_lock: + self._global_image_cache[normalized_uid] = attachment + return attachment + + async def resolve_global_image(self, uid: str) -> AttachmentRecord | None: + return await asyncio.to_thread(self.resolve_global_image_sync, uid) + + async def get_meme(self, uid: str) -> dict[str, Any] | None: + record = await self._store.get(uid) + if record is None: + return None + sources = await self._store.get_sources(uid) + return { + "record": self.serialize_record(record), + "sources": [source.__dict__ for source in sources], + } + + async def get_record(self, uid: str) -> MemeRecord | None: + return await self._store.get(uid) + + def serialize_record(self, record: MemeRecord) -> dict[str, Any]: + preview_path = record.preview_path or record.blob_path + return { + "uid": record.uid, + "description": record.description, + "auto_description": record.auto_description, + "manual_description": record.manual_description, + "ocr_text": record.ocr_text, + "tags": list(record.tags), + "aliases": list(record.aliases), + "enabled": bool(record.enabled), + "pinned": 
bool(record.pinned), + "is_animated": bool(record.is_animated), + "mime_type": record.mime_type, + "file_size": record.file_size, + "width": record.width, + "height": record.height, + "blob_url": f"/api/v1/management/memes/{record.uid}/blob", + "preview_url": f"/api/v1/management/memes/{record.uid}/preview", + "use_count": record.use_count, + "last_used_at": record.last_used_at, + "created_at": record.created_at, + "updated_at": record.updated_at, + "status": record.status, + "search_text": record.search_text, + "preview_path": preview_path, + } + + def serialize_list_item(self, record: MemeRecord) -> dict[str, Any]: + return { + "uid": record.uid, + "description": record.description, + "enabled": bool(record.enabled), + "pinned": bool(record.pinned), + "is_animated": bool(record.is_animated), + "use_count": int(record.use_count), + "created_at": record.created_at, + "updated_at": record.updated_at, + "status": record.status, + } + + async def list_memes( + self, + *, + query: str = "", + enabled: bool | None = None, + animated: bool | None = None, + pinned: bool | None = None, + sort: str = "updated_at", + page: int = 1, + page_size: int = 50, + summary: bool = False, + ) -> dict[str, Any]: + items, total = await self._store.list_memes( + query=query, + enabled=enabled, + animated=animated, + pinned=pinned, + sort=sort, + page=page, + page_size=page_size, + ) + return { + "ok": True, + "total": total, + "window_total": total, + "total_exact": True, + "page": max(1, int(page)), + "page_size": max(1, min(200, int(page_size))), + "has_more": max(1, int(page)) * max(1, min(200, int(page_size))) < total, + "sort": str(sort or "updated_at"), + "items": [ + self.serialize_list_item(item) + if summary + else self.serialize_record(item) + for item in items + ], + } + + async def stats(self) -> dict[str, Any]: + stats = await self._store.stats() + cfg = self._cfg() + stats["max_items"] = int(cfg.max_items) + stats["max_total_bytes"] = int(cfg.max_total_bytes) + queue = 
self._job_queue.snapshot() if self._job_queue is not None else None + stats["queue"] = queue + return stats + + async def search_memes( + self, + query: str, + *, + query_mode: str = "hybrid", + keyword_query: str | None = None, + semantic_query: str | None = None, + top_k: int = 8, + include_disabled: bool = False, + sort: str = "relevance", + ) -> dict[str, Any]: + raw_query = str(query or "").strip() + raw_keyword_query = str(keyword_query or "").strip() + raw_semantic_query = str(semantic_query or "").strip() + normalized_mode = str(query_mode or "hybrid").strip().lower() + if normalized_mode not in {"keyword", "semantic", "hybrid"}: + normalized_mode = "hybrid" + resolved_keyword_query = raw_keyword_query or raw_query + resolved_semantic_query = raw_semantic_query or raw_query + if normalized_mode == "keyword" and not resolved_keyword_query: + return { + "ok": True, + "count": 0, + "query_mode": normalized_mode, + "keyword_query": resolved_keyword_query, + "semantic_query": resolved_semantic_query, + "items": [], + } + if normalized_mode == "semantic" and not resolved_semantic_query: + return { + "ok": True, + "count": 0, + "query_mode": normalized_mode, + "keyword_query": resolved_keyword_query, + "semantic_query": resolved_semantic_query, + "items": [], + } + if ( + normalized_mode == "hybrid" + and not resolved_keyword_query + and not resolved_semantic_query + ): + return { + "ok": True, + "count": 0, + "query_mode": normalized_mode, + "keyword_query": resolved_keyword_query, + "semantic_query": resolved_semantic_query, + "items": [], + } + + cfg = self._cfg() + keyword_hits: list[dict[str, Any]] = [] + if normalized_mode in {"keyword", "hybrid"} and resolved_keyword_query: + keyword_hits = await self._store.search_keyword( + resolved_keyword_query, + limit=max(int(cfg.keyword_top_k), int(top_k)), + include_disabled=include_disabled, + ) + semantic_hits: list[dict[str, Any]] = [] + if normalized_mode in {"semantic", "hybrid"} and resolved_semantic_query: + 
semantic_hits = await self._vector_store.query( + resolved_semantic_query, + top_k=max(int(cfg.semantic_top_k), int(top_k)), + include_disabled=include_disabled, + ) + merged: dict[str, dict[str, Any]] = {} + + for item in keyword_hits: + record: MemeRecord = item["record"] + merged[record.uid] = { + "record": record, + "keyword_score": float(item.get("keyword_score", 0.0)), + "semantic_score": 0.0, + "rerank_score": None, + } + + missing_semantic_uids = [ + str(item.get("uid") or "").strip() + for item in semantic_hits + if str(item.get("uid") or "").strip() + and str(item.get("uid") or "").strip() not in merged + ] + missing_records = ( + await self._store.get_many(missing_semantic_uids) + if missing_semantic_uids + else {} + ) + + for item in semantic_hits: + uid = str(item.get("uid") or "").strip() + if not uid: + continue + existing = merged.get(uid) + if existing is None: + stored_record = missing_records.get(uid) + if stored_record is None: + continue + existing = { + "record": stored_record, + "keyword_score": 0.0, + "semantic_score": 0.0, + "rerank_score": None, + } + merged[uid] = existing + existing["semantic_score"] = max( + float(existing.get("semantic_score", 0.0)), + float(item.get("semantic_score", 0.0)), + ) + + reranker = ( + self._retrieval_runtime.ensure_reranker() + if self._retrieval_runtime is not None + else None + ) + ranked_candidates = list(merged.values()) + rerank_query = resolved_semantic_query or resolved_keyword_query + if ( + reranker is not None + and normalized_mode in {"semantic", "hybrid"} + and rerank_query + and ranked_candidates + ): + documents = [ + candidate["record"].search_text for candidate in ranked_candidates + ] + reranked = await reranker.rerank( + rerank_query, + documents, + top_n=min(len(documents), int(cfg.rerank_top_k)), + ) + for item in reranked: + try: + index = int(item.get("index")) + except (TypeError, ValueError): + continue + if index < 0 or index >= len(ranked_candidates): + continue + 
ranked_candidates[index]["rerank_score"] = float( + item.get("relevance_score", 0.0) or 0.0 + ) + + def _final_score(item: dict[str, Any]) -> float: + rerank_score = item.get("rerank_score") + if rerank_score is not None: + return float(rerank_score) + # hybrid 模式:keyword 与 semantic 取较高分 + return max( + float(item.get("keyword_score", 0.0)), + float(item.get("semantic_score", 0.0)), + ) + + normalized_sort = str(sort or "relevance").strip().lower() + if normalized_sort == "use_count": + ranked_candidates.sort( + key=lambda item: ( + item["record"].pinned, + item["record"].use_count, + item["record"].updated_at, + _final_score(item), + ), + reverse=True, + ) + elif normalized_sort == "created_at": + ranked_candidates.sort( + key=lambda item: ( + item["record"].pinned, + item["record"].created_at, + item["record"].updated_at, + _final_score(item), + ), + reverse=True, + ) + elif normalized_sort == "updated_at": + ranked_candidates.sort( + key=lambda item: ( + item["record"].pinned, + item["record"].updated_at, + item["record"].use_count, + _final_score(item), + ), + reverse=True, + ) + else: + ranked_candidates.sort( + key=lambda item: ( + _final_score(item), + item["record"].pinned, + item["record"].use_count, + item["record"].updated_at, + ), + reverse=True, + ) + + items: list[dict[str, Any]] = [] + for candidate in ranked_candidates[: max(1, int(top_k))]: + record = candidate["record"] + search_item = MemeSearchItem( + uid=record.uid, + description=record.description, + tags=list(record.tags), + aliases=list(record.aliases), + enabled=bool(record.enabled), + pinned=bool(record.pinned), + is_animated=record.is_animated, + created_at=record.created_at, + updated_at=record.updated_at, + score=round(_final_score(candidate), 6), + keyword_score=round(float(candidate.get("keyword_score", 0.0)), 6), + semantic_score=round(float(candidate.get("semantic_score", 0.0)), 6), + rerank_score=( + round(float(candidate["rerank_score"]), 6) + if candidate.get("rerank_score") is 
not None + else None + ), + use_count=record.use_count, + ) + items.append(search_item.__dict__) + return { + "ok": True, + "count": len(items), + "query_mode": normalized_mode, + "keyword_query": resolved_keyword_query, + "semantic_query": resolved_semantic_query, + "sort": normalized_sort, + "items": items, + } + + async def send_meme_by_uid(self, uid: str, context: dict[str, Any]) -> str: + record = await self._store.get(uid) + if record is None or not record.enabled or record.status != "ready": + return f"发送失败:未找到可用表情包 UID:{uid}" + + sender = context.get("sender") + if sender is None: + return "发送失败:当前上下文缺少 sender" + + tool_args = { + "target_type": context.get("target_type"), + "target_id": context.get("target_id"), + } + target, target_error = resolve_message_target(tool_args, context) + if target_error or target is None: + return f"发送失败:{target_error or '无法确定目标会话'}" + target_type, target_id = target + + local_path = Path(record.blob_path) + if not local_path.is_file(): + return f"发送失败:表情包文件不存在:{uid}" + file_uri = local_path.resolve().as_uri() + cq_message = f"[CQ:image,file={file_uri},subType=1]" + history_message = f"[图片 uid={record.uid} name={local_path.name}]" + history_attachment = await self.resolve_global_image(uid) + history_attachments = ( + [history_attachment.prompt_ref()] + if history_attachment is not None + else None + ) + + if target_type == "group": + sent_message_id = await sender.send_group_message( + int(target_id), + cq_message, + history_message=history_message, + attachments=history_attachments, + ) + else: + preferred_temp_group_id = safe_int(context.get("group_id")) or None + sent_message_id = await sender.send_private_message( + int(target_id), + cq_message, + preferred_temp_group_id=preferred_temp_group_id, + history_message=history_message, + attachments=history_attachments, + ) + + now = _now_iso() + updated_record = await self._store.increment_use(uid, now) + if updated_record is not None: + await 
self._vector_store.upsert(updated_record) + if sent_message_id is not None: + return f"表情包已发送(message_id={sent_message_id})" + return "表情包已发送" + + async def blob_path_for_uid( + self, uid: str, *, preview: bool = False + ) -> Path | None: + record = await self._store.get(uid) + if record is None: + return None + path_text = ( + record.preview_path if preview and record.preview_path else record.blob_path + ) + path = Path(path_text) + return path if path.is_file() else None diff --git a/src/Undefined/memes/service.py b/src/Undefined/memes/service.py index 96f80abe..1f7ba719 100644 --- a/src/Undefined/memes/service.py +++ b/src/Undefined/memes/service.py @@ -2,29 +2,34 @@ import asyncio from collections.abc import Mapping -from dataclasses import dataclass, replace -from datetime import datetime +from dataclasses import replace import hashlib import logging -import math import mimetypes from pathlib import Path -import re import shutil import threading from typing import Any from uuid import uuid4 -from openai import APIConnectionError, APIStatusError, APITimeoutError from PIL import Image from Undefined.attachments import AttachmentRecord +from Undefined.memes._image_utils import ( + _compose_grid, + _extract_gif_frames, + _guess_suffix, + _is_retryable_llm_error, + _normalize_tags, + _now_iso, +) +from Undefined.memes._image_utils import _sample_frame_indices # noqa: F401 from Undefined.memes.models import ( + IngestDigestLockEntry, MemeRecord, MemeSearchItem, MemeSourceRecord, build_search_text, - normalize_string_list, ) from Undefined.memes.store import MemeStore from Undefined.memes.vector_store import MemeVectorStore @@ -34,118 +39,13 @@ logger = logging.getLogger(__name__) -_IMAGE_EXTENSIONS_BY_MIME = { - "image/png": ".png", - "image/jpeg": ".jpg", - "image/gif": ".gif", - "image/webp": ".webp", - "image/bmp": ".bmp", - "image/svg+xml": ".svg", -} -_TAG_SPLIT_RE = re.compile(r"[,,\n]+") - - -def _now_iso() -> str: - return 
datetime.now().isoformat(timespec="seconds") - - -def _guess_suffix(path: Path, mime_type: str) -> str: - suffix = path.suffix.lower() - if suffix: - return suffix - guessed = _IMAGE_EXTENSIONS_BY_MIME.get(mime_type) - if guessed: - return guessed - mime_guess = mimetypes.guess_extension(mime_type or "") - if mime_guess: - return mime_guess.lower() - return ".bin" - - -def _normalize_tags(raw_tags: list[str] | str | None) -> list[str]: - if raw_tags is None: - return [] - if isinstance(raw_tags, str): - parts = [part.strip() for part in _TAG_SPLIT_RE.split(raw_tags)] - return normalize_string_list(parts) - return normalize_string_list(raw_tags) - - -def _is_retryable_llm_error(exc: Exception) -> bool: - """判断 LLM 调用异常是否应触发 worker 级重试。""" - if isinstance(exc, (APIConnectionError, APITimeoutError)): - return True - if isinstance(exc, APIStatusError): - return exc.status_code == 429 or exc.status_code >= 500 - return False - - -def _extract_gif_frames(source_path: Path, n_frames: int) -> list[Image.Image]: - """从 GIF 中均匀采样 *n_frames* 帧(含首末帧),返回 RGBA Image 列表。""" - with Image.open(source_path) as image: - total = getattr(image, "n_frames", 1) - if total <= 1: - image.seek(0) - return [image.convert("RGBA").copy()] - n = min(n_frames, total) - if n <= 1: - image.seek(0) - return [image.convert("RGBA").copy()] - indices = _sample_frame_indices(total, n) - frames: list[Image.Image] = [] - for idx in indices: - image.seek(idx) - frames.append(image.convert("RGBA").copy()) - return frames - - -def _sample_frame_indices(total: int, n: int) -> list[int]: - """生成均匀采样的帧索引列表(始终包含首帧和末帧)。""" - if n >= total: - return list(range(total)) - if n == 1: - return [0] - if n == 2: - return [0, total - 1] - indices = [round(i * (total - 1) / (n - 1)) for i in range(n)] - # 去重并保持顺序 - seen: set[int] = set() - result: list[int] = [] - for idx in indices: - if idx not in seen: - seen.add(idx) - result.append(idx) - return result - - -def _compose_grid(frames: list[Image.Image], output_path: 
Path) -> None: - """将多帧拼接为网格图并保存为 PNG。""" - n = len(frames) - if n == 0: - return - if n == 1: - frames[0].save(output_path, format="PNG") - return - cols = math.ceil(math.sqrt(n)) - rows = math.ceil(n / cols) - fw, fh = frames[0].size - grid = Image.new("RGBA", (cols * fw, rows * fh), (0, 0, 0, 0)) - for i, frame in enumerate(frames): - resized = ( - frame.resize((fw, fh), Image.Resampling.LANCZOS) - if frame.size != (fw, fh) - else frame - ) - x = (i % cols) * fw - y = (i // cols) * fh - grid.paste(resized, (x, y)) - grid.save(output_path, format="PNG") - - -@dataclass -class _IngestDigestLockEntry: - lock: asyncio.Lock - users: int = 0 +__all__ = [ + "MemeService", + "_compose_grid", + "_extract_gif_frames", + "_is_retryable_llm_error", + "_sample_frame_indices", +] class MemeService: @@ -168,7 +68,7 @@ def __init__( self._attachment_registry = attachment_registry self._retrieval_runtime = retrieval_runtime # Serialize same-content ingest jobs within the process to avoid duplicates. - self._ingest_digest_locks: dict[str, _IngestDigestLockEntry] = {} + self._ingest_digest_locks: dict[str, IngestDigestLockEntry] = {} self._ingest_digest_locks_guard = asyncio.Lock() self._global_image_cache: dict[str, AttachmentRecord] = {} self._global_image_cache_lock = threading.Lock() @@ -196,11 +96,11 @@ def _cfg(self) -> Any: def _blob_dir(self) -> Path: return ensure_dir(Path(self._cfg().blob_dir)) - async def _acquire_ingest_digest_lock(self, digest: str) -> _IngestDigestLockEntry: + async def _acquire_ingest_digest_lock(self, digest: str) -> IngestDigestLockEntry: async with self._ingest_digest_locks_guard: entry = self._ingest_digest_locks.get(digest) if entry is None: - entry = _IngestDigestLockEntry(lock=asyncio.Lock()) + entry = IngestDigestLockEntry(lock=asyncio.Lock()) self._ingest_digest_locks[digest] = entry entry.users += 1 try: @@ -213,7 +113,7 @@ async def _acquire_ingest_digest_lock(self, digest: str) -> _IngestDigestLockEnt async def 
_release_ingest_digest_lock_reference( self, digest: str, - entry: _IngestDigestLockEntry, + entry: IngestDigestLockEntry, *, release_lock: bool = False, ) -> None: @@ -1177,7 +1077,6 @@ def _copy() -> None: def _render_preview() -> None: frames = _extract_gif_frames(source_path, n_frames) if mode == "multi": - # multi 模式也需要生成一张预览用于存储/展示,取首帧 frames[0].save(preview_path, format="PNG") else: _compose_grid(frames, preview_path) diff --git a/src/Undefined/onebot/__init__.py b/src/Undefined/onebot/__init__.py new file mode 100644 index 00000000..14c5690a --- /dev/null +++ b/src/Undefined/onebot/__init__.py @@ -0,0 +1,16 @@ +"""OneBot WebSocket 客户端包。""" + +# 统一 re-export,供 handlers 与 sender 直接 from Undefined.onebot import ... +from Undefined.onebot.client import OneBotClient +from Undefined.onebot.message import ( + get_message_content, + get_message_sender_id, + parse_message_time, +) + +__all__ = [ + "OneBotClient", + "parse_message_time", + "get_message_sender_id", + "get_message_content", +] diff --git a/src/Undefined/onebot.py b/src/Undefined/onebot/client.py similarity index 94% rename from src/Undefined/onebot.py rename to src/Undefined/onebot/client.py index 9a53d885..ce293428 100644 --- a/src/Undefined/onebot.py +++ b/src/Undefined/onebot/client.py @@ -1,11 +1,10 @@ -"""OneBot WebSocket 客户端""" +"""OneBot v11 WebSocket 客户端实现。""" import asyncio import json import logging import time from typing import Any, Callable, Coroutine -from datetime import datetime import websockets from websockets.asyncio.client import ClientConnection @@ -20,9 +19,11 @@ def _mark_message_sent_this_turn() -> None: ctx = RequestContext.current() if ctx is None: return + # 标记本轮已向用户发出消息,供 end 工具判断是否可静默结束。 ctx.set_resource("message_sent_this_turn", True) +# OneBot v11 WebSocket 客户端 class OneBotClient: """OneBot v11 WebSocket 客户端""" @@ -61,7 +62,6 @@ def connection_status(self) -> dict[str, Any]: async def connect(self) -> None: """连接到 OneBot WebSocket""" - # 构建带 token 的 URL url = self.ws_url 
if self.token: separator = "&" if "?" in url else "?" @@ -126,7 +126,6 @@ async def _call_api( if logger.isEnabledFor(logging.DEBUG): log_debug_json(logger, "[OneBot请求体]", request) - # 创建 Future 等待响应 future: asyncio.Future[dict[str, Any]] = asyncio.Future() self._pending_responses[echo] = future @@ -138,7 +137,6 @@ async def _call_api( response = await asyncio.wait_for(future, timeout=480.0) duration = time.perf_counter() - start_time - # 检查响应状态 status = response.get("status") if status == "failed": retcode = response.get("retcode", -1) @@ -241,7 +239,6 @@ async def get_group_msg_history( result = await self._call_api("get_group_msg_history", params) - # 安全获取消息列表 if result is None: logger.warning("get_group_msg_history 返回 None") return [] @@ -874,51 +871,3 @@ def stop(self) -> None: """停止运行""" self._should_stop = True self._running = False - - -def parse_message_time(message: dict[str, Any]) -> datetime: - """解析消息时间。 - - 兼容秒级/毫秒级时间戳与字符串输入,异常时回退到当前时间。 - """ - - raw_timestamp = message.get("time") - - if raw_timestamp is None: - return datetime.now() - - try: - timestamp = float(raw_timestamp) - except (TypeError, ValueError): - logger.debug("[OneBot] 无法解析消息时间戳,使用当前时间: %s", raw_timestamp) - return datetime.now() - - # 13 位毫秒时间戳自动降为秒。 - if timestamp > 1_000_000_000_000: - timestamp /= 1000.0 - - if timestamp <= 0: - return datetime.now() - - try: - return datetime.fromtimestamp(timestamp) - except (OSError, OverflowError, ValueError): - logger.debug("[OneBot] 时间戳越界,使用当前时间: %s", raw_timestamp) - return datetime.now() - - -def get_message_sender_id(message: dict[str, Any]) -> int: - """获取消息发送者 QQ 号""" - sender: dict[str, Any] = message.get("sender", {}) - user_id: int = sender.get("user_id", 0) - return user_id - - -def get_message_content(message: dict[str, Any]) -> list[dict[str, Any]]: - """获取消息内容(CQ 码数组格式)""" - msg = message.get("message", []) - if isinstance(msg, str): - # 如果是字符串格式,转换为数组格式 - return [{"type": "text", "data": {"text": msg}}] - content: list[dict[str, 
Any]] = msg - return content diff --git a/src/Undefined/onebot/message.py b/src/Undefined/onebot/message.py new file mode 100644 index 00000000..7fe13dc5 --- /dev/null +++ b/src/Undefined/onebot/message.py @@ -0,0 +1,58 @@ +"""OneBot 消息解析辅助函数。""" + +from __future__ import annotations + +import logging +from datetime import datetime +from typing import Any + +logger = logging.getLogger(__name__) + + +def parse_message_time(message: dict[str, Any]) -> datetime: + """解析消息时间。 + + 兼容秒级/毫秒级时间戳与字符串输入,异常时回退到当前时间。 + """ + + raw_timestamp = message.get("time") + + if raw_timestamp is None: + return datetime.now() + + try: + timestamp = float(raw_timestamp) + except (TypeError, ValueError): + logger.debug("[OneBot] 无法解析消息时间戳,使用当前时间: %s", raw_timestamp) + return datetime.now() + + # 13 位毫秒时间戳自动降为秒。 + if timestamp > 1_000_000_000_000: + timestamp /= 1000.0 + + if timestamp <= 0: + return datetime.now() + + try: + return datetime.fromtimestamp(timestamp) + except (OSError, OverflowError, ValueError): + # 越界或非法 epoch 回退当前时间,避免整条消息解析失败。 + logger.debug("[OneBot] 时间戳越界,使用当前时间: %s", raw_timestamp) + return datetime.now() + + +def get_message_sender_id(message: dict[str, Any]) -> int: + """获取消息发送者 QQ 号""" + sender: dict[str, Any] = message.get("sender", {}) + user_id: int = sender.get("user_id", 0) + return user_id + + +def get_message_content(message: dict[str, Any]) -> list[dict[str, Any]]: + """获取消息内容(CQ 码数组格式)""" + msg = message.get("message", []) + if isinstance(msg, str): + # 如果是字符串格式,转换为数组格式 + return [{"type": "text", "data": {"text": msg}}] + content: list[dict[str, Any]] = msg + return content diff --git a/src/Undefined/py.typed b/src/Undefined/py.typed new file mode 100644 index 00000000..e69de29b diff --git a/src/Undefined/services/ai_coordinator.py b/src/Undefined/services/ai_coordinator.py index f4efaf2f..7bc32036 100644 --- a/src/Undefined/services/ai_coordinator.py +++ b/src/Undefined/services/ai_coordinator.py @@ -105,7 +105,7 @@ def __init__( self.security = security 
self.command_dispatcher = command_dispatcher self.model_pool = ModelPoolService(ai, config, sender) - # batcher 由外部(handlers.py)创建并通过 set_batcher 注入;未注入时所有消息按单条流程直送。 + # batcher 由外部(handlers/message_flow)创建并通过 set_batcher 注入;未注入时所有消息按单条流程直送。 self._batcher: MessageBatcher | None = None def set_batcher(self, batcher: MessageBatcher | None) -> None: @@ -305,7 +305,6 @@ async def _execute_auto_reply(self, request: dict[str, Any]) -> None: # 用于向 batcher 注册 inflight 任务(仅当本请求源自合并桶时生效) batcher_scope: str | None = make_scope(group_id=group_id) if group_id else None - # 创建请求上下文 async with RequestContext( request_type="group", group_id=group_id, @@ -347,7 +346,6 @@ async def send_img_cb(tid: int, mtype: str, path: str) -> None: async def send_like_cb(uid: int, times: int = 1) -> None: await self.onebot.send_like(uid, times) - # 存储资源到上下文 ai_client = self.ai memory_storage = self.ai.memory_storage runtime_config = self.ai.runtime_config @@ -442,7 +440,6 @@ async def _execute_private_reply(self, request: dict[str, Any]) -> None: trigger_message_id = request.get("trigger_message_id") batcher_scope: str | None = make_scope(user_id=user_id) - # 创建请求上下文 async with RequestContext( request_type="private", user_id=user_id, @@ -479,7 +476,6 @@ async def send_private_cb( ) -> None: await self.sender.send_private_message(uid, msg, reply_to=reply_to) - # 存储资源到上下文 ai_client = self.ai memory_storage = self.ai.memory_storage runtime_config = self.ai.runtime_config @@ -591,7 +587,6 @@ async def _execute_stats_analysis(self, request: dict[str, Any]) -> None: logger.warning("[统计分析] 缺少 request_id,群=%s", group_id) return try: - # 加载提示词模板 prompt_template = _STATS_ANALYSIS_FALLBACK_PROMPT try: loaded_prompt = read_text_resource(_STATS_ANALYSIS_PROMPT_PATH).strip() @@ -615,7 +610,6 @@ async def _execute_stats_analysis(self, request: dict[str, Any]) -> None: data_summary=safe_data_summary ) - # 调用 AI 进行分析 messages = [ {"role": "system", "content": "你是一位专业的数据分析师。"}, {"role": "user", "content": 
full_prompt}, @@ -629,7 +623,6 @@ async def _execute_stats_analysis(self, request: dict[str, Any]) -> None: queue_lane=request.get("_queue_lane"), ) - # 提取分析结果 choices = result.get("choices", [{}]) if choices: content = choices[0].get("message", {}).get("content", "") @@ -647,7 +640,6 @@ async def _execute_stats_analysis(self, request: dict[str, Any]) -> None: request_id, ) - # 设置分析结果(通知等待的 _handle_stats 方法) if self.command_dispatcher: self.command_dispatcher.set_stats_analysis_result( group_id, request_id, analysis @@ -655,7 +647,6 @@ async def _execute_stats_analysis(self, request: dict[str, Any]) -> None: except Exception as exc: logger.exception("[统计分析] AI 分析失败: %s", exc) - # 出错时也通知等待,但返回空字符串 if self.command_dispatcher: self.command_dispatcher.set_stats_analysis_result( group_id, request_id, "" @@ -708,7 +699,6 @@ async def _execute_agent_intro_generation(self, request: dict[str, Any]) -> None return try: - # 获取提示词 from Undefined.skills.agents.intro_generator import AgentIntroGenerator agent_intro_generator = self.ai._agent_intro_generator @@ -721,7 +711,6 @@ async def _execute_agent_intro_generation(self, request: dict[str, Any]) -> None user_prompt, ) = await agent_intro_generator.get_intro_prompt_and_context(agent_name) - # 调用 AI 生成 messages = [ {"role": "system", "content": system_prompt or "你是一位智能助手。"}, {"role": "user", "content": user_prompt}, @@ -735,7 +724,6 @@ async def _execute_agent_intro_generation(self, request: dict[str, Any]) -> None queue_lane=request.get("_queue_lane"), ) - # 提取结果 choices = result.get("choices", [{}]) if choices: content = choices[0].get("message", {}).get("content", "") @@ -750,7 +738,6 @@ async def _execute_agent_intro_generation(self, request: dict[str, Any]) -> None request_id, ) - # 通知结果 agent_intro_generator.set_intro_generation_result( request_id, generated_content if generated_content else None ) @@ -761,7 +748,6 @@ async def _execute_agent_intro_generation(self, request: dict[str, Any]) -> None agent_name, exc, ) - # 
出错时也通知,返回 None try: agent_intro_generator = self.ai._agent_intro_generator agent_intro_generator.set_intro_generation_result(request_id, None) diff --git a/src/Undefined/services/commands/bugfix.py b/src/Undefined/services/commands/bugfix.py new file mode 100644 index 00000000..f4689a3e --- /dev/null +++ b/src/Undefined/services/commands/bugfix.py @@ -0,0 +1,189 @@ +"""Bug 修复归档命令(/bugfix)的实现逻辑。 + +本模块提供 ``BugfixCommandMixin``,供 ``CommandDispatcher`` 通过多重继承组合。 +通过回溯群聊记录并调用 AI 摘要,自动生成 FAQ 归档条目。 +""" + +from __future__ import annotations + +# 斜杠命令:目录扫描注册、权限/限流/子命令路由 + +import logging +from datetime import datetime +from typing import TYPE_CHECKING, Any +from uuid import uuid4 + +from Undefined.faq import extract_faq_title +from Undefined.onebot import ( + get_message_content, + get_message_sender_id, + parse_message_time, +) + +if TYPE_CHECKING: + from Undefined.config import Config + from Undefined.faq import FAQStorage + from Undefined.onebot import OneBotClient + from Undefined.utils.sender import MessageSender + +logger = logging.getLogger(__name__) + + +class BugfixCommandMixin: + """``/bugfix`` 命令相关方法集合,作为 ``CommandDispatcher`` 的 mixin 使用。""" + + if TYPE_CHECKING: + ai: Any + config: Config + faq_storage: FAQStorage + onebot: OneBotClient + sender: MessageSender + + async def _handle_bugfix( + self, group_id: int, admin_id: int, args: list[str] + ) -> None: + """处理 ``/bugfix`` 命令,通过分析聊天记录自动生成 FAQ 归档。""" + parsed = self._parse_bugfix_args(args) + if isinstance(parsed, str): + await self.sender.send_group_message(group_id, parsed) + return + + target_qqs, start_date, end_date, start_str, end_str = parsed + + await self.sender.send_group_message( + group_id, "🔍 正在获取对话记录进行回溯分析..." 
+ ) + + try: + messages = await self._fetch_messages( + group_id, target_qqs, start_date, end_date + ) + if not messages: + await self.sender.send_group_message( + group_id, "❌ 未找到符合条件的对话记录。" + ) + return + + processed_text = await self._process_messages(messages) + summary = await self._obtain_bugfix_summary(group_id, processed_text) + + title = extract_faq_title(summary) + if not title or title == "未命名问题": + title = await self.ai.generate_title(summary) + + faq = await self.faq_storage.create( + group_id=group_id, + target_qq=target_qqs[0], + start_time=start_str, + end_time=end_str, + title=title, + content=summary, + ) + + result_msg = f"✅ Bug 修复分析完成!\n\n📌 FAQ ID: {faq.id}\n📋 标题: {title}\n\n{summary}" + await self.sender.send_group_message(group_id, result_msg) + + except Exception as e: + error_id = uuid4().hex[:8] + logger.exception("Bugfix 失败: error_id=%s err=%s", error_id, e) + await self.sender.send_group_message( + group_id, + f"❌ Bug 修复分析失败,请稍后重试(错误码: {error_id})", + ) + + def _parse_bugfix_args( + self, args: list[str] + ) -> tuple[list[int], datetime, datetime, str, str] | str: + """解析 ``/bugfix`` 命令的参数。""" + if len(args) < 3: + return ( + "❌ 用法: /bugfix") + lines.append(f"{indent} [QQ号|@用户2] ... 
<开始时间> <结束时间>\n" + "时间格式: YYYY/MM/DD/HH:MM,结束时间可用 now\n" + "示例: /bugfix 123456 2024/12/01/09:00 now" + ) + + try: + target_qqs = [int(arg) for arg in args[:-2]] + start_str, end_str_raw = args[-2], args[-1] + start_date = datetime.strptime(start_str, "%Y/%m/%d/%H:%M") + + if end_str_raw.lower() == "now": + end_date, end_str = datetime.now(), "now" + else: + end_date, end_str = ( + datetime.strptime(end_str_raw, "%Y/%m/%d/%H:%M"), + end_str_raw, + ) + + return target_qqs, start_date, end_date, start_str, end_str + except ValueError: + return "❌ 参数格式错误:QQ号应为数字或 @ 提及,时间格式应为 YYYY/MM/DD/HH:MM。" + + async def _obtain_bugfix_summary(self, group_id: int, processed_text: str) -> str: + """利用 AI 生成聊天记录的 Bug 分析摘要。""" + total_tokens = self.ai.count_tokens(processed_text) + max_tokens = self.config.chat_model.max_tokens + + if total_tokens <= max_tokens: + return str(await self.ai.summarize_chat(processed_text)) + + await self.sender.send_group_message( + group_id, f"📊 消息较长({total_tokens} tokens),正在分段处理..." 
+ ) + chunks = self.ai.split_messages_by_tokens(processed_text, max_tokens) + summaries = [await self.ai.summarize_chat(chunk) for chunk in chunks] + return str(await self.ai.merge_summaries(summaries)) + + async def _fetch_messages( + self, + group_id: int, + target_qqs: list[int], + start_date: datetime, + end_date: datetime, + ) -> list[dict[str, Any]]: + """从 OneBot 拉取指定时间范围内目标用户的消息。""" + batch = await self.onebot.get_group_msg_history(group_id, count=2500) + if not batch: + return [] + target_qqs_set = set(target_qqs) + results = [] + for msg in batch: + msg_time = parse_message_time(msg) + if ( + start_date <= msg_time <= end_date + and get_message_sender_id(msg) in target_qqs_set + ): + # 后台循环处理队列 + results.append(msg) + return sorted(results, key=lambda m: m.get("time", 0)) + + # 后台循环处理队列 + async def _process_messages(self, messages: list[dict[str, Any]]) -> str: + """将原始 OneBot 消息序列化为 AI 可读的纯文本。""" + lines = [] + for msg in messages: + sender_id = get_message_sender_id(msg) + msg_time = parse_message_time(msg).strftime("%Y-%m-%d %H:%M:%S") + content = get_message_content(msg) + text_parts = [] + for segment in content: + seg_type, seg_data = segment.get("type", ""), segment.get("data", {}) + if seg_type == "text": + text_parts.append(seg_data.get("text", "")) + elif seg_type == "image": + file = seg_data.get("file", "") or seg_data.get("url", "") + if file: + try: + url = await self.onebot.get_image(file) + if url: + res = await self.ai.analyze_multimodal(url, "image") + text_parts.append( + f"[pic] {res.get('description', '')} {res.get('ocr_text', '')} [/pic]" + ) + except Exception: + text_parts.append("[pic]图片处理失败 [/pic]") + elif seg_type == "at": + text_parts.append(f"@{seg_data.get('qq', '')}") + if text_parts: + lines.append(f"[{msg_time}] {sender_id}: {''.join(text_parts)}") + return "\n".join(lines) diff --git a/src/Undefined/services/commands/stats.py b/src/Undefined/services/commands/stats.py new file mode 100644 index 00000000..a992f760 --- 
/dev/null +++ b/src/Undefined/services/commands/stats.py @@ -0,0 +1,822 @@ +"""Token 使用统计命令(/stats)的实现逻辑。 + +本模块提供 ``StatsCommandMixin``,供 ``CommandDispatcher`` 通过多重继承组合。 +群聊与私聊统计、图表生成、AI 分析队列交互均在此实现。 +""" + +from __future__ import annotations + +# 斜杠命令:目录扫描注册、权限/限流/子命令路由 + +import asyncio +import base64 +import logging +import re +from pathlib import Path +from typing import TYPE_CHECKING, Any +from uuid import uuid4 + +from Undefined.ai.queue_budget import ( + compute_queued_llm_timeout_seconds, + resolve_effective_retry_count, +) +from Undefined.token_usage_storage import TokenUsageStorage + +if TYPE_CHECKING: + from Undefined.config import Config + from Undefined.onebot import OneBotClient + from Undefined.utils.history import MessageHistoryManager + from Undefined.utils.sender import MessageSender + +# 尝试导入 matplotlib(可选依赖) +plt: Any +try: + import matplotlib.pyplot as plt + + _MATPLOTLIB_AVAILABLE = True +except ImportError: + plt = None + _MATPLOTLIB_AVAILABLE = False + +logger = logging.getLogger(__name__) + +_STATS_DEFAULT_DAYS = 7 +_STATS_MIN_DAYS = 1 +_STATS_MAX_DAYS = 365 +_STATS_MODEL_TOP_N = 8 +_STATS_CALL_TYPE_TOP_N = 12 +_STATS_DATA_SUMMARY_MAX_CHARS = 12000 +_STATS_AI_FLAGS = {"--ai", "-a"} +_STATS_TIME_RANGE_RE = re.compile(r"^\d+[dwm]?$", re.IGNORECASE) + + +class StatsCommandMixin: + """``/stats`` 命令相关方法集合,作为 ``CommandDispatcher`` 的 mixin 使用。""" + + if TYPE_CHECKING: + ai: Any + config: Config + history_manager: MessageHistoryManager + onebot: OneBotClient + queue_manager: Any + sender: MessageSender + + _token_usage_storage: TokenUsageStorage + _stats_analysis_results: dict[str, str] + _stats_analysis_events: dict[str, asyncio.Event] + + def _parse_time_range(self, time_str: str) -> int: + """解析时间范围字符串,返回天数。 + + 参数: + time_str: 时间范围字符串(如 ``7d``、``1w``、``30d``)。 + + 返回: + clamp 在 ``[_STATS_MIN_DAYS, _STATS_MAX_DAYS]`` 内的天数。 + """ + if not time_str: + return _STATS_DEFAULT_DAYS + + def _clamp_days(value: int) -> int: + if value < _STATS_MIN_DAYS: 
+ return _STATS_DEFAULT_DAYS + if value > _STATS_MAX_DAYS: + return _STATS_MAX_DAYS + return value + + time_str = time_str.lower().strip() + + if time_str.endswith("d"): + try: + return _clamp_days(int(time_str[:-1])) + except ValueError: + return _STATS_DEFAULT_DAYS + if time_str.endswith("w"): + try: + return _clamp_days(int(time_str[:-1]) * 7) + except ValueError: + return _STATS_DEFAULT_DAYS + if time_str.endswith("m"): + try: + return _clamp_days(int(time_str[:-1]) * 30) + except ValueError: + return _STATS_DEFAULT_DAYS + + try: + return _clamp_days(int(time_str)) + except ValueError: + return _STATS_DEFAULT_DAYS + + def _parse_stats_options(self, args: list[str]) -> tuple[int, bool]: + """解析 ``/stats`` 参数:时间范围 + AI 分析开关。""" + days = _STATS_DEFAULT_DAYS + enable_ai_analysis = False + picked_days = False + + for raw in args: + token = str(raw or "").strip() + if not token: + continue + lower = token.lower() + if lower in _STATS_AI_FLAGS: + enable_ai_analysis = True + continue + if not picked_days and _STATS_TIME_RANGE_RE.match(lower): + days = self._parse_time_range(lower) + picked_days = True + + return days, enable_ai_analysis + + async def _handle_stats( + self, group_id: int, sender_id: int, args: list[str] + ) -> None: + """处理群聊 ``/stats`` 命令,生成 token 使用统计图表(可选 AI 分析)。""" + if not _MATPLOTLIB_AVAILABLE: + await self.sender.send_group_message( + group_id, "❌ 缺少必要的库,无法生成图表。请安装 matplotlib。" + ) + return + + days, enable_ai_analysis = self._parse_stats_options(args) + + try: + summary = await self._token_usage_storage.get_summary(days=days) + if summary["total_calls"] == 0: + await self.sender.send_group_message( + group_id, f"📊 最近 {days} 天内无 Token 使用记录。" + ) + return + + from Undefined.utils.paths import RENDER_CACHE_DIR, ensure_dir + + img_dir = ensure_dir(RENDER_CACHE_DIR) + await self._generate_line_chart(summary, img_dir, days) + await self._generate_bar_chart(summary, img_dir) + await self._generate_pie_chart(summary, img_dir) + await 
self._generate_stats_table(summary, img_dir) + + ai_analysis = "" + if enable_ai_analysis: + ai_analysis = await self._run_stats_ai_analysis( + scope="group", + scope_id=group_id, + sender_id=sender_id, + summary=summary, + days=days, + ) + + forward_messages = self._build_stats_forward_nodes( + summary, img_dir, days, ai_analysis + ) + await self._send_group_forward_message( + group_id, + forward_messages, + history_message=self._build_stats_history_message( + summary, + days, + ai_analysis, + ), + ) + + from Undefined.utils.cache import cleanup_cache_dir + + cleanup_cache_dir(RENDER_CACHE_DIR) + + except Exception as e: + error_id = uuid4().hex[:8] + logger.exception( + "[Stats] 生成统计图表失败: error_id=%s err=%s", error_id, e + ) + await self.sender.send_group_message( + group_id, + f"❌ 生成统计图表失败,请稍后重试(错误码: {error_id})", + ) + + async def _send_group_forward_message( + self, + group_id: int, + messages: list[dict[str, Any]], + *, + history_message: str, + ) -> None: + """发送群组合并转发消息,并在需要时写入历史记录。""" + send_forward = getattr(self.sender, "send_group_forward_message", None) + if callable(send_forward): + await send_forward(group_id, messages, history_message=history_message) + return + + await self.onebot.send_forward_msg(group_id, messages) + if self.history_manager is None: + return + text_content = history_message.strip() + if not text_content: + return + + await self.history_manager.add_group_message( + group_id=group_id, + sender_id=getattr(self.config, "bot_qq", 0), + text_content=text_content, + sender_nickname="Bot", + group_name="", + ) + + @staticmethod + def _build_stats_history_message( + summary: dict[str, Any], + days: int, + ai_analysis: str, + ) -> str: + """构建写入群聊历史的 ``/stats`` 输出摘要文本。""" + lines = [ + f"[命令输出] /stats 最近 {days} 天 Token 使用统计", + f"总调用: {summary.get('total_calls', 0)}", + f"总 Token: {summary.get('total_tokens', 0)}", + f"输入 Token: {summary.get('prompt_tokens', 0)}", + f"输出 Token: {summary.get('completion_tokens', 0)}", + ] + if 
ai_analysis.strip(): + lines.extend(["", "AI 分析:", ai_analysis.strip()]) + return "\n".join(lines) + + async def _handle_stats_private( + self, + user_id: int, + sender_id: int, + args: list[str], + send_message: Any = None, + *, + is_webui_session: bool = False, + ) -> None: + """处理私聊 ``/stats``(含 WebUI 虚拟私聊适配)。""" + + async def _send_private(message: str) -> None: + if send_message is not None: + await send_message(message) + else: + await self.sender.send_private_message(user_id, message) + + days, enable_ai_analysis = self._parse_stats_options(args) + try: + summary = await self._token_usage_storage.get_summary(days=days) + if summary["total_calls"] == 0: + await _send_private(f"📊 最近 {days} 天内无 Token 使用记录。") + return + + ai_analysis = "" + if enable_ai_analysis: + ai_analysis = await self._run_stats_ai_analysis( + scope="private", + scope_id=0, + sender_id=sender_id, + summary=summary, + days=days, + ) + + if not _MATPLOTLIB_AVAILABLE: + message = "❌ 缺少必要的库,无法生成图表。请安装 matplotlib。" + if is_webui_session: + message += "\n\n" + self._build_stats_summary_text(summary) + if ai_analysis: + message += f"\n\n🤖 AI 智能分析\n{ai_analysis}" + await _send_private(message) + return + + from Undefined.utils.cache import cleanup_cache_dir + from Undefined.utils.paths import RENDER_CACHE_DIR, ensure_dir + + img_dir = ensure_dir(RENDER_CACHE_DIR) + await self._generate_line_chart(summary, img_dir, days) + await self._generate_bar_chart(summary, img_dir) + await self._generate_pie_chart(summary, img_dir) + await self._generate_stats_table(summary, img_dir) + + await _send_private(f"📊 最近 {days} 天的 Token 使用统计:") + for img_name in ["line_chart", "bar_chart", "pie_chart", "table"]: + img_path = img_dir / f"stats_{img_name}.png" + if img_path.exists(): + message = await self._build_private_stats_image_message( + img_path, + inline_base64=is_webui_session, + ) + await _send_private(message) + + await _send_private(self._build_stats_summary_text(summary)) + if ai_analysis: + await 
_send_private(f"🤖 AI 智能分析\n{ai_analysis}") + + cleanup_cache_dir(RENDER_CACHE_DIR) + except Exception as e: + error_id = uuid4().hex[:8] + logger.exception( + "[Stats] 私聊统计生成失败: error_id=%s user=%s err=%s", + error_id, + user_id, + e, + ) + await _send_private( + f"❌ 生成统计图表失败,请稍后重试(错误码: {error_id})" + ) + + async def _build_private_stats_image_message( + self, + image_path: Path, + *, + inline_base64: bool, + ) -> str: + """构建私聊统计图片的 OneBot CQ 码消息。""" + file_uri = image_path.absolute().as_uri() + if not inline_base64: + return f"[CQ:image,file={file_uri}]" + + try: + encoded = await asyncio.to_thread( + lambda: base64.b64encode(image_path.read_bytes()).decode("ascii") + ) + except Exception as exc: + logger.warning( + "[Stats] 图像 base64 编码失败,回退文件路径: path=%s err=%s", + file_uri, + exc, + ) + return f"[CQ:image,file={file_uri}]" + + return f"[CQ:image,file=base64://{encoded}]" + + async def _run_stats_ai_analysis( + self, + *, + scope: str, + scope_id: int, + sender_id: int, + summary: dict[str, Any], + days: int, + ) -> str: + """投递并等待 AI 对统计数据的分析结果。""" + if not self.queue_manager: + return "" + + data_summary = self._build_data_summary(summary, days) + request_id = uuid4().hex + analysis_event = asyncio.Event() + self._stats_analysis_events[request_id] = analysis_event + request_data = { + "type": "stats_analysis", + "group_id": scope_id, + "request_id": request_id, + "sender_id": sender_id, + "data_summary": data_summary, + "summary": summary, + "days": days, + "scope": scope, + } + receipt = await self.queue_manager.add_group_mention_request( + request_data, model_name=self.config.chat_model.model_name + ) + logger.info("[Stats] 已投递 AI 分析请求: scope=%s target=%s", scope, scope_id) + + wait_timeout = compute_queued_llm_timeout_seconds( + self.ai.runtime_config, + self.config.chat_model, + retry_count=resolve_effective_retry_count( + self.ai.runtime_config, self.queue_manager + ), + initial_wait_seconds=float( + getattr(receipt, "estimated_wait_seconds", 0.0) or 0.0 
+ ), + ) + try: + await asyncio.wait_for(analysis_event.wait(), timeout=wait_timeout) + ai_analysis = self._stats_analysis_results.pop(request_id, "") + logger.info( + "[Stats] 已获取 AI 分析结果: scope=%s len=%s", scope, len(ai_analysis) + ) + return ai_analysis + except asyncio.TimeoutError: + logger.warning( + "[Stats] AI 分析超时: scope=%s target=%s timeout=%.1fs", + scope, + scope_id, + wait_timeout, + ) + return "AI 分析超时,已先发送图表与汇总数据。" + finally: + self._stats_analysis_events.pop(request_id, None) + self._stats_analysis_results.pop(request_id, None) + + def _build_data_summary(self, summary: dict[str, Any], days: int) -> str: + """构建用于 AI 分析的统计数据摘要。""" + lines = [] + lines.append("📊 Token 使用综合分析数据:") + lines.append("") + + lines.append("【整体概况】") + lines.append(f"统计周期: {days} 天") + lines.append(f"总调用次数: {summary['total_calls']}") + lines.append(f"总 Token 消耗: {summary['total_tokens']:,}") + lines.append(f"平均响应时间: {summary['avg_duration']:.2f}s") + lines.append(f"涉及模型数: {len(summary['models'])}") + lines.append("") + + daily_stats = summary.get("daily_stats", {}) + if daily_stats: + dates = sorted(daily_stats.keys()) + total_daily_calls = sum(daily_stats[d]["calls"] for d in dates) + total_daily_tokens = sum(daily_stats[d]["tokens"] for d in dates) + avg_daily_calls = total_daily_calls / len(dates) if dates else 0 + avg_daily_tokens = total_daily_tokens / len(dates) if dates else 0 + + peak_day = ( + max(dates, key=lambda d: daily_stats[d]["tokens"]) if dates else "" + ) + peak_day_tokens = daily_stats[peak_day]["tokens"] if peak_day else 0 + + lines.append("【时间维度】") + lines.append(f"统计天数: {len(dates)} 天") + lines.append(f"每日平均调用: {avg_daily_calls:.1f} 次") + lines.append(f"每日平均 Token: {avg_daily_tokens:,.0f} 个") + lines.append(f"高峰日期: {peak_day} ({peak_day_tokens:,} tokens)") + lines.append("") + + models = summary.get("models", {}) + if models: + lines.append("【模型维度】") + total_tokens_all = summary["total_tokens"] + sorted_models = sorted( + models.items(), key=lambda x: 
x[1]["tokens"], reverse=True + ) + for model_name, model_data in sorted_models[:_STATS_MODEL_TOP_N]: + calls = model_data["calls"] + tokens = model_data["tokens"] + prompt_tokens = model_data["prompt_tokens"] + completion_tokens = model_data["completion_tokens"] + token_pct = ( + (tokens / total_tokens_all * 100) if total_tokens_all > 0 else 0 + ) + avg_per_call = tokens / calls if calls > 0 else 0 + io_ratio = completion_tokens / prompt_tokens if prompt_tokens > 0 else 0 + + lines.append(f"模型: {model_name}") + lines.append( + f" - 调用次数: {calls} ({calls / summary['total_calls'] * 100:.1f}%)" + ) + lines.append(f" - Token 消耗: {tokens:,} ({token_pct:.1f}%)") + lines.append(f" - 平均每次调用: {avg_per_call:.0f} tokens") + lines.append( + f" - 输入: {prompt_tokens:,} / 输出: {completion_tokens:,}" + ) + lines.append(f" - 输入/输出比: 1:{io_ratio:.2f}") + lines.append("") + + if len(sorted_models) > _STATS_MODEL_TOP_N: + others = sorted_models[_STATS_MODEL_TOP_N:] + others_calls = sum(int(item[1].get("calls", 0)) for item in others) + others_tokens = sum(int(item[1].get("tokens", 0)) for item in others) + others_pct = ( + (others_tokens / total_tokens_all * 100) + if total_tokens_all > 0 + else 0.0 + ) + lines.append( + f"其余 {len(others)} 个模型合计: 调用 {others_calls} 次, Token {others_tokens:,} ({others_pct:.1f}%)" + ) + lines.append("") + + call_types = summary.get("call_types", {}) + if call_types: + lines.append("【调用类型维度】") + sorted_types = sorted( + call_types.items(), key=lambda item: int(item[1]), reverse=True + ) + total_calls = max(1, int(summary.get("total_calls", 0))) + for call_type, count in sorted_types[:_STATS_CALL_TYPE_TOP_N]: + ratio = int(count) / total_calls * 100 + lines.append(f"- {call_type}: {count} 次 ({ratio:.1f}%)") + if len(sorted_types) > _STATS_CALL_TYPE_TOP_N: + rest_count = sum( + int(item[1]) for item in sorted_types[_STATS_CALL_TYPE_TOP_N:] + ) + ratio = rest_count / total_calls * 100 + lines.append( + f"- 其他 {len(sorted_types) - _STATS_CALL_TYPE_TOP_N} 类: 
{rest_count} 次 ({ratio:.1f}%)" + ) + lines.append("") + + prompt_tokens = summary.get("prompt_tokens", 0) + completion_tokens = summary.get("completion_tokens", 0) + total_tokens = summary.get("total_tokens", 0) + input_ratio = (prompt_tokens / total_tokens * 100) if total_tokens > 0 else 0 + output_ratio = ( + (completion_tokens / total_tokens * 100) if total_tokens > 0 else 0 + ) + output_per_input = completion_tokens / prompt_tokens if prompt_tokens > 0 else 0 + + lines.append("【效率指标】") + lines.append(f"输入 Token: {prompt_tokens:,} ({input_ratio:.1f}%)") + lines.append(f"输出 Token: {completion_tokens:,} ({output_ratio:.1f}%)") + lines.append(f"输入/输出比: 1:{output_per_input:.2f}") + lines.append("") + + if daily_stats and len(daily_stats) > 1: + lines.append("【趋势分析】") + dates = sorted(daily_stats.keys()) + first_day_tokens = daily_stats[dates[0]]["tokens"] + last_day_tokens = daily_stats[dates[-1]]["tokens"] + trend_change = ( + ((last_day_tokens - first_day_tokens) / first_day_tokens * 100) + if first_day_tokens > 0 + else 0 + ) + trend_desc = "增长" if trend_change > 0 else "下降" + lines.append( + f"总体趋势: {trend_desc} {abs(trend_change):.1f}% (从首日到末日)" + ) + lines.append("") + + summary_text = "\n".join(lines) + if len(summary_text) > _STATS_DATA_SUMMARY_MAX_CHARS: + trimmed = summary_text[: _STATS_DATA_SUMMARY_MAX_CHARS - 80].rstrip() + summary_text = ( + f"{trimmed}\n\n[数据摘要已截断,总长度 {len(summary_text)} 字符," + f"仅保留前 {_STATS_DATA_SUMMARY_MAX_CHARS} 字符]" + ) + logger.info( + "[Stats] 数据摘要过长已截断: original_len=%s max_len=%s", + len("\n".join(lines)), + _STATS_DATA_SUMMARY_MAX_CHARS, + ) + return summary_text + + def _build_stats_summary_text(self, summary: dict[str, Any]) -> str: + """构建统计结果的纯文本摘要。""" + return f"""📈 摘要汇总: +• 总调用次数: {summary["total_calls"]} +• 总消耗 Tokens: {summary["total_tokens"]:,} + └─ 输入: {summary["prompt_tokens"]:,} + └─ 输出: {summary["completion_tokens"]:,} +• 平均耗时: {summary["avg_duration"]:.2f}s +• 涉及模型数: {len(summary["models"])}""" + + def 
set_stats_analysis_result( + self, group_id: int, request_id: str, analysis: str + ) -> None: + """设置 AI 分析结果(由队列处理器调用)。""" + event = self._stats_analysis_events.get(request_id) + if not event: + logger.warning( + "[StatsAnalysis] 未找到等待事件,群: %s, 请求: %s", + group_id, + request_id, + ) + return + self._stats_analysis_results[request_id] = analysis + event.set() + + def _build_stats_forward_nodes( + self, + summary: dict[str, Any], + img_dir: Path, + days: int, + ai_analysis: str = "", + ) -> list[dict[str, Any]]: + """构建用于合并转发的统计图表节点列表。""" + # 对外入队 API + nodes = [] + bot_qq = str(self.config.bot_qq) + + # 对外入队 API + def add_node(content: str) -> None: + nodes.append( + { + "type": "node", + "data": {"name": "Bot", "uin": bot_qq, "content": content}, + } + ) + + add_node(f"📊 最近 {days} 天的 Token 使用统计:") + + for img_name in ["line_chart", "bar_chart", "pie_chart", "table"]: + img_path = img_dir / f"stats_{img_name}.png" + if img_path.exists(): + add_node(f"[CQ:image,file={img_path.absolute().as_uri()}]") + + add_node(self._build_stats_summary_text(summary)) + + if ai_analysis: + add_node(f"🤖 AI 智能分析\n{ai_analysis}") + + return nodes + + async def _generate_line_chart( + self, summary: dict[str, Any], img_dir: Path, days: int + ) -> None: + """生成折线图:时间趋势。""" + daily_stats = summary["daily_stats"] + if not daily_stats: + return + + dates = sorted(daily_stats.keys()) + tokens = [daily_stats[d]["tokens"] for d in dates] + prompt_tokens = [daily_stats[d]["prompt_tokens"] for d in dates] + completion_tokens = [daily_stats[d]["completion_tokens"] for d in dates] + + fig, ax = plt.subplots(figsize=(12, 7)) + + ax.plot( + dates, tokens, marker="o", linewidth=2, label="Total Token", color="#2196F3" + ) + ax.plot( + dates, + prompt_tokens, + marker="s", + linewidth=2, + label="Input Token", + color="#4CAF50", + ) + ax.plot( + dates, + completion_tokens, + marker="^", + linewidth=2, + label="Output Token", + color="#FF9800", + ) + + ax.set_title( + f"Token Usage Trend for Last 
{days} Days", fontsize=16, fontweight="bold" + ) + ax.set_xlabel("Date", fontsize=12) + ax.set_ylabel("Token Count", fontsize=12) + ax.legend(loc="upper left", fontsize=10) + ax.grid(True, alpha=0.3) + + plt.xticks(rotation=45, ha="right") + plt.tight_layout() + + filepath = img_dir / "stats_line_chart.png" + plt.savefig(filepath, dpi=150, bbox_inches="tight") + plt.close(fig) + + async def _generate_bar_chart(self, summary: dict[str, Any], img_dir: Path) -> None: + """生成柱状图:模型对比。""" + models = summary["models"] + if not models: + return + + model_names = list(models.keys()) + tokens = [models[m]["tokens"] for m in model_names] + prompt_tokens = [models[m]["prompt_tokens"] for m in model_names] + completion_tokens = [models[m]["completion_tokens"] for m in model_names] + + fig, ax = plt.subplots(figsize=(14, 8)) + + x = range(len(model_names)) + width = 0.25 + + bars1 = ax.bar( + [i - width for i in x], + tokens, + width, + label="Total Token", + color="#2196F3", + alpha=0.8, + ) + bars2 = ax.bar( + x, + prompt_tokens, + width, + label="Input Token", + color="#4CAF50", + alpha=0.8, + ) + bars3 = ax.bar( + [i + width for i in x], + completion_tokens, + width, + label="Output Token", + color="#FF9800", + alpha=0.8, + ) + + ax.set_title("Token Usage Comparison by Model", fontsize=16, fontweight="bold") + ax.set_xlabel("Model", fontsize=12) + ax.set_ylabel("Token Count", fontsize=12) + ax.set_xticks(x) + ax.set_xticklabels(model_names, rotation=45, ha="right") + ax.legend(loc="upper right", fontsize=10) + ax.grid(True, alpha=0.3, axis="y") + + for bars in [bars1, bars2, bars3]: + for bar in bars: + height = bar.get_height() + if height > 0: + ax.text( + bar.get_x() + bar.get_width() / 2.0, + height, + f"{int(height):,}", + ha="center", + va="bottom", + fontsize=8, + ) + + plt.tight_layout() + + filepath = img_dir / "stats_bar_chart.png" + plt.savefig(filepath, dpi=150, bbox_inches="tight") + plt.close(fig) + + async def _generate_pie_chart(self, summary: dict[str, 
Any], img_dir: Path) -> None: + """生成饼图:输入/输出比例。""" + prompt_tokens = summary["prompt_tokens"] + completion_tokens = summary["completion_tokens"] + + if prompt_tokens == 0 and completion_tokens == 0: + return + + fig, ax = plt.subplots(figsize=(12, 8)) + + labels = ["Input Token", "Output Token"] + sizes = [prompt_tokens, completion_tokens] + colors = ["#4CAF50", "#FF9800"] + explode = (0.05, 0.05) + + wedges, *_ = ax.pie( + sizes, + explode=explode, + labels=labels, + colors=colors, + autopct="%1.1f%%", + startangle=90, + textprops={"fontsize": 12}, + ) + + ax.set_title("Input/Output Token Ratio", fontsize=16, fontweight="bold", pad=20) + + ax.legend( + wedges, + [f"{labels[i]}: {sizes[i]:,}" for i in range(len(labels))], + loc="center left", + bbox_to_anchor=(1, 0, 0.5, 1), + fontsize=10, + ) + + plt.tight_layout() + + filepath = img_dir / "stats_pie_chart.png" + plt.savefig(filepath, dpi=150, bbox_inches="tight") + plt.close(fig) + + async def _generate_stats_table( + self, summary: dict[str, Any], img_dir: Path + ) -> None: + """生成统计表格图片。""" + models = summary["models"] + if not models: + return + + model_names = list(models.keys()) + data = [] + for model in model_names: + m = models[model] + data.append( + [ + model, + m["calls"], + f"{m['tokens']:,}", + f"{m['prompt_tokens']:,}", + f"{m['completion_tokens']:,}", + ] + ) + + fig, ax = plt.subplots(figsize=(14, 9)) + ax.axis("tight") + ax.axis("off") + + table = ax.table( + cellText=data, + colLabels=["Model", "Calls", "Total Token", "Input Token", "Output Token"], + cellLoc="center", + loc="center", + ) + + table.auto_set_font_size(False) + table.set_fontsize(10) + table.scale(1.2, 1.5) + + for i in range(5): + table[(0, i)].set_facecolor("#2196F3") + table[(0, i)].set_text_props(weight="bold", color="white") + + for i in range(1, len(data) + 1): + for j in range(5): + if i % 2 == 0: + table[(i, j)].set_facecolor("#f0f0f0") + + ax.set_title( + "Model Usage Statistics Details", fontsize=16, fontweight="bold", 
pad=20 + ) + + plt.tight_layout() + + filepath = img_dir / "stats_table.png" + plt.savefig(filepath, dpi=150, bbox_inches="tight") + plt.close(fig) diff --git a/src/Undefined/services/coordinator/__init__.py b/src/Undefined/services/coordinator/__init__.py new file mode 100644 index 00000000..ee91e712 --- /dev/null +++ b/src/Undefined/services/coordinator/__init__.py @@ -0,0 +1,88 @@ +"""AI 协调器:组合群聊、私聊、合并与后台任务 mixin。""" + +from __future__ import annotations + + +import logging +from pathlib import Path +from typing import Any + +from Undefined.config import Config +from Undefined.services.coordinator.background import BackgroundMixin +from Undefined.services.coordinator.batching import BatchingMixin +from Undefined.services.coordinator.group import GroupReplyMixin +from Undefined.services.coordinator.private import PrivateReplyMixin +from Undefined.services.message_batcher import MessageBatcher +from Undefined.services.model_pool import ModelPoolService +from Undefined.services.queue_manager import QueueManager +from Undefined.services.security import SecurityService +from Undefined.utils.history import MessageHistoryManager +from Undefined.utils.scheduler import TaskScheduler +from Undefined.utils.sender import MessageSender + +logger = logging.getLogger(__name__) + + +class AICoordinator( + GroupReplyMixin, + PrivateReplyMixin, + BatchingMixin, + BackgroundMixin, +): + """AI 协调器,处理 AI 回复逻辑、Prompt 构建和队列管理""" + + def __init__( + self, + config: Config, + ai: Any, # AIClient + queue_manager: QueueManager, + history_manager: MessageHistoryManager, + sender: MessageSender, + onebot: Any, # OneBotClient + scheduler: TaskScheduler, + security: SecurityService, + command_dispatcher: Any = None, + ) -> None: + self.config = config + self.ai = ai + self.queue_manager = queue_manager + self.history_manager = history_manager + self.sender = sender + self.onebot = onebot + self.scheduler = scheduler + self.security = security + self.command_dispatcher = command_dispatcher + 
self.model_pool = ModelPoolService(ai, config, sender) + # batcher 由外部(handlers/message_flow)创建并通过 set_batcher 注入;未注入时所有消息按单条流程直送。 + self._batcher: MessageBatcher | None = None + + def set_batcher(self, batcher: MessageBatcher | None) -> None: + """注入消息合并器;传 None 等同于禁用合并。""" + self._batcher = batcher + + @property + def batcher(self) -> MessageBatcher | None: + return self._batcher + + async def _send_image(self, tid: int, mtype: str, path: str) -> None: + """发送图片或语音消息到群聊或私聊""" + import os + + if not os.path.exists(path): + return + file_uri = Path(path).resolve().as_uri() + ext = os.path.splitext(path)[1].lower() + if ext in [".jpg", ".jpeg", ".png", ".gif", ".bmp", ".webp"]: + msg = f"[CQ:image,file={file_uri}]" + elif ext in [".mp3", ".wav", ".ogg", ".flac", ".m4a", ".aac"]: + msg = f"[CQ:record,file={file_uri}]" + else: + return + + try: + if mtype == "group": + await self.sender.send_group_message(tid, msg, auto_history=False) + elif mtype == "private": + await self.sender.send_private_message(tid, msg, auto_history=False) + except Exception: + logger.exception("发送媒体文件失败") diff --git a/src/Undefined/services/coordinator/background.py b/src/Undefined/services/coordinator/background.py new file mode 100644 index 00000000..b6c8cefe --- /dev/null +++ b/src/Undefined/services/coordinator/background.py @@ -0,0 +1,250 @@ +"""后台 LLM 任务:stats 分析、队列调用与 Agent 介绍生成。""" + +from __future__ import annotations + + +import logging +from typing import TYPE_CHECKING, Any, cast + +from Undefined.services.queue_manager import QUEUE_LANE_BACKGROUND +from Undefined.utils.resources import read_text_resource + +if TYPE_CHECKING: + from Undefined.ai import AIClient + from Undefined.config import Config + from Undefined.services.command import CommandDispatcher + +logger = logging.getLogger(__name__) + + +_STATS_ANALYSIS_PROMPT_PATH = "res/prompts/stats_analysis.txt" +_STATS_ANALYSIS_FALLBACK_PROMPT = ( + "你是一位专业的数据分析师。请根据以下 Token 使用统计数据提供分析:\n\n" + "{data_summary}\n\n" + 
async def execute_reply(self, request: dict[str, Any]) -> None:
    """Run one queued request; called by QueueManager when it dispatches.

    Args:
        request: Request dict. ``request["type"]`` selects the handler:
            ``auto_reply``, ``private_reply``, ``stats_analysis``,
            ``agent_intro_generation``, ``queued_llm_call`` or
            ``background_llm_call``.

    A request whose ``_message_batcher_token`` has a truthy ``cancelled``
    attribute is skipped (a speculative pre-fire superseded by a newer
    message). An unrecognised type is logged at warning level and dropped
    rather than disappearing silently.
    """
    req_type = request.get("type", "unknown")
    logger.debug("[执行请求] type=%s keys=%s", req_type, list(request.keys()))

    batch_token = request.get("_message_batcher_token")
    # A token-less request yields None here, and getattr falls back to False.
    if getattr(batch_token, "cancelled", False):
        logger.info(
            "[MessageBatcher] 跳过已取消的投机请求: type=%s scope=%s sender=%s batch=%s",
            req_type,
            getattr(batch_token, "scope", ""),
            getattr(batch_token, "sender_id", ""),
            getattr(batch_token, "batch_id", ""),
        )
        return

    if req_type == "auto_reply":
        await self._execute_auto_reply(request)
    elif req_type == "private_reply":
        await self._execute_private_reply(request)
    elif req_type == "stats_analysis":
        await self._execute_stats_analysis(request)
    elif req_type == "agent_intro_generation":
        await self._execute_agent_intro_generation(request)
    elif req_type in {"queued_llm_call", "background_llm_call"}:
        await self._execute_queued_llm_call(request)
    else:
        # Previously an unknown type fell through with only the debug line
        # above, so a misrouted request left no trace at normal log levels.
        logger.warning("[执行请求] 未知请求类型,已忽略: type=%s", req_type)
exc: + logger.warning("[统计分析] 读取提示词失败,使用内置模板: %s", exc) + + if "{data_summary}" not in prompt_template: + logger.warning( + "[统计分析] 提示词缺少 {data_summary} 占位符,自动追加", + ) + prompt_template = f"{prompt_template}\n\n{{data_summary}}" + + safe_data_summary = str(data_summary).strip() or "暂无统计数据摘要" + try: + full_prompt = prompt_template.format(data_summary=safe_data_summary) + except Exception as exc: + logger.warning("[统计分析] 提示词渲染失败,使用回退模板: %s", exc) + full_prompt = _STATS_ANALYSIS_FALLBACK_PROMPT.format( + data_summary=safe_data_summary + ) + + messages = [ + {"role": "system", "content": "你是一位专业的数据分析师。"}, + {"role": "user", "content": full_prompt}, + ] + + result = await self.ai.submit_queued_llm_call( + model_config=self.config.chat_model, + messages=messages, + max_tokens=2048, + call_type="stats_analysis", + queue_lane=request.get("_queue_lane"), + ) + + choices = result.get("choices", [{}]) + if choices: + content = choices[0].get("message", {}).get("content", "") + analysis = content.strip() + else: + analysis = "AI 分析未能生成结果" + + if not analysis: + analysis = "AI 分析结果为空,建议稍后重试。" + + logger.info( + "[统计分析] 分析完成: group=%s length=%s request_id=%s", + group_id, + len(analysis), + request_id, + ) + + if self.command_dispatcher: + self.command_dispatcher.set_stats_analysis_result( + group_id, request_id, analysis + ) + + except Exception as exc: + logger.exception("[统计分析] AI 分析失败: %s", exc) + if self.command_dispatcher: + self.command_dispatcher.set_stats_analysis_result( + group_id, request_id, "" + ) + + async def _execute_queued_llm_call(self, request: dict[str, Any]) -> None: + """执行队列中的 LLM 子请求。""" + request_id = request.get("request_id", "") + retry_count = int(request.get("_retry_count", 0) or 0) + queue_lane = str(request.get("_queue_lane") or QUEUE_LANE_BACKGROUND) + call_type = str(request.get("call_type", "background") or "background") + try: + max_tokens_raw = request.get("max_tokens") or getattr( + request["model_config"], "max_tokens", 4096 + ) + 
max_tokens = int(max_tokens_raw) if max_tokens_raw is not None else 4096 + result = await self.ai.request_model( + model_config=request["model_config"], + messages=request["messages"], + tools=request.get("tools"), + tool_choice=request.get("tool_choice", "auto"), + call_type=call_type, + max_tokens=max_tokens, + transport_state=request.get("transport_state"), + ) + self.ai.set_llm_call_result(request_id, result) + if retry_count > 0: + logger.info( + "[queued_llm_retry_success] request_id=%s call_type=%s model=%s lane=%s retry=%s", + request_id, + call_type, + getattr(request["model_config"], "model_name", "default"), + queue_lane, + retry_count, + ) + except Exception as exc: + retry_count = request.get("_retry_count", 0) + if retry_count >= self.config.ai_request_max_retries: + self.ai.set_llm_call_result(request_id, exc) + raise + + async def _execute_agent_intro_generation(self, request: dict[str, Any]) -> None: + """执行 Agent 自我介绍生成请求""" + request_id = request.get("request_id") + agent_name = request.get("agent_name") + + if not request_id or not agent_name: + logger.warning( + "[Agent介绍生成] 缺少必要参数: request_id=%s agent_name=%s", + request_id, + agent_name, + ) + return + + try: + from Undefined.skills.agents.intro_generator import AgentIntroGenerator + + agent_intro_generator = self.ai._agent_intro_generator + if not isinstance(agent_intro_generator, AgentIntroGenerator): + logger.error("[Agent介绍生成] 无法获取 AgentIntroGenerator 实例") + return + + ( + system_prompt, + user_prompt, + ) = await agent_intro_generator.get_intro_prompt_and_context(agent_name) + + messages = [ + {"role": "system", "content": system_prompt or "你是一位智能助手。"}, + {"role": "user", "content": user_prompt}, + ] + + result = await self.ai.submit_queued_llm_call( + model_config=self.ai.agent_config, + messages=messages, + max_tokens=agent_intro_generator.config.max_tokens, + call_type=f"agent:{agent_name}", + queue_lane=request.get("_queue_lane"), + ) + + choices = result.get("choices", [{}]) + if 
choices: + content = choices[0].get("message", {}).get("content", "") + generated_content = content.strip() + else: + generated_content = "" + + logger.info( + "[Agent介绍生成] 生成完成: agent=%s length=%s request_id=%s", + agent_name, + len(generated_content), + request_id, + ) + + agent_intro_generator.set_intro_generation_result( + request_id, generated_content if generated_content else None + ) + + except Exception as exc: + logger.exception( + "[Agent介绍生成] 生成失败: agent=%s error=%s", + agent_name, + exc, + ) + try: + agent_intro_generator = cast( + AgentIntroGenerator, self.ai._agent_intro_generator + ) + agent_intro_generator.set_intro_generation_result(request_id, None) + except Exception: + pass diff --git a/src/Undefined/services/coordinator/batching.py b/src/Undefined/services/coordinator/batching.py new file mode 100644 index 00000000..d3da2c63 --- /dev/null +++ b/src/Undefined/services/coordinator/batching.py @@ -0,0 +1,174 @@ +"""消息合并、分组 prompt 构建与队列投递。""" + +from __future__ import annotations + + +import logging +from typing import TYPE_CHECKING, Any + +from Undefined.services.coordinator.group import _GROUP_STRATEGY_FOOTER +from Undefined.services.coordinator.private import _PRIVATE_STRATEGY_FOOTER +from Undefined.services.message_batcher import BufferedMessage + +if TYPE_CHECKING: + from Undefined.config import Config + from Undefined.services.message_batcher import BufferedMessage as _BufferedMessage + from Undefined.services.model_pool import ModelPoolService + from Undefined.services.queue_manager import QueueManager + +logger = logging.getLogger(__name__) + + +# BatchingMixin:MessageBatcher 回调、合并 prompt 与队列路由 +class BatchingMixin: + """MessageBatcher 回调、合并 prompt 与队列路由。""" + + if TYPE_CHECKING: + config: Config + queue_manager: QueueManager + model_pool: ModelPoolService + + def _format_group_message_segment(self, item: _BufferedMessage) -> str: ... + def _format_private_message_segment(self, item: _BufferedMessage) -> str: ... 
+ + async def handle_batched_dispatch(self, items: list[BufferedMessage]) -> None: + """:class:`MessageBatcher` 的 flush_callback:把一批消息组装为单次请求并入队。""" + if not items: + return + await self._dispatch_grouped_request(items) + + @staticmethod + def _build_continuous_messages_note(items: list[BufferedMessage]) -> str: + """生成"连续消息说明"段。仅在 ``len(items) >= 2`` 时使用。""" + count = len(items) + first_t = items[0].arrival_time + last_t = items[-1].arrival_time + span = max(0.0, last_t - first_t) + return ( + f"\n\n 【连续消息说明】以上 {count} 条是同一用户在约 " + f"{span:.1f} 秒内连续发送的消息(按时间先后排列),代表本轮要回应的全部输入:\n" + f" - 这些 共同构成【当前输入批次】,不要把同批前几条误判为历史旧任务;" + f"批次之外的历史消息仍只作为背景,不能回溯拾荒\n" + f" - 先识别每条的意图,分清是【独立请求】还是【对前一条的修正/否定/补充/打断】\n" + f' · 若是【多个独立的不同意图/问题】(如"先帮我查 A,再翻译 B")' + f" → 每个都要回应,不要遗漏;与平时一样,可以多次 send_message 自然分发\n" + f' · 若后发是【对前发的修正/否定/补充/打断】(如"画猫" → "改成狗")' + f" → 以最后一次明确意图为准,旧的不再执行,可简短说明已采纳更新\n" + f' · 拿不准时偏向"独立请求",宁多勿漏\n' + f" - 整批在本轮一次性处理完即可,不要为同一意图重复输出(不要" + f'"中间一波、结尾再来一波"重复相同回复)\n' + f" - history 中若出现与当前轮 相同的条目,视为同一来源,不要重复处理" + ) + + def _build_grouped_prompt(self, items: list[BufferedMessage]) -> str: + """根据 BufferedMessage 列表构造合并后的完整 prompt。""" + if not items: + return "" + is_private = items[0].is_private + # prefix:拍一拍优先;否则任一 @bot + any_poke = any(it.is_poke for it in items) + any_at_bot = any(it.is_at_bot for it in items) + if any_poke: + prefix = "(用户拍了拍你) " + elif any_at_bot: + prefix = "(用户 @ 了你) " + else: + prefix = "" + + if is_private: + segments = [self._format_private_message_segment(it) for it in items] + else: + segments = [self._format_group_message_segment(it) for it in items] + body = prefix + "\n".join(segments) + if len(items) >= 2: + body += self._build_continuous_messages_note(items) + body += _GROUP_STRATEGY_FOOTER if not is_private else _PRIVATE_STRATEGY_FOOTER + return body + + async def _dispatch_grouped_request(self, items: list[BufferedMessage]) -> None: + """根据一组 BufferedMessage 决定优先级、构造 prompt 并入队。 + + 既是单条直送路径的统一出口,也是 :class:`MessageBatcher` 的 
flush_callback。 + """ + if not items: + return + first = items[0] + last = items[-1] + full_question = self._build_grouped_prompt(items) + any_poke = any(it.is_poke for it in items) + any_at_bot = any(it.is_at_bot for it in items) + + if first.is_private: + user_id = first.sender_id + request_data: dict[str, Any] = { + "type": "private_reply", + "user_id": user_id, + "sender_name": first.sender_name, + "text": last.text, + "full_question": full_question, + "trigger_message_id": last.trigger_message_id, + "batched_count": len(items), + } + if first.batch_token is not None: + request_data["_message_batcher_token"] = first.batch_token + effective_config = self.model_pool.select_chat_config( + self.config.chat_model, user_id=user_id + ) + request_data["selected_model_name"] = effective_config.model_name + logger.debug( + "[私聊回复] full_question_len=%s user=%s batched=%s", + len(full_question), + user_id, + len(items), + ) + if user_id == self.config.superadmin_qq: + # 私聊超管 → 最高优先级 superadmin lane + await self.queue_manager.add_superadmin_request( + request_data, model_name=effective_config.model_name + ) + else: + await self.queue_manager.add_private_request( + request_data, model_name=effective_config.model_name + ) + return + + # 群聊:按 sender 身份与 @bot 状态选择 4 级 lane 之一 + group_id = first.group_id or 0 + sender_id = first.sender_id + request_data = { + "type": "auto_reply", + "group_id": group_id, + "sender_id": sender_id, + "sender_name": first.sender_name, + "group_name": first.group_name, + "text": last.text, + "full_question": full_question, + "is_at_bot": any_at_bot, + "trigger_message_id": last.trigger_message_id, + "batched_count": len(items), + } + if first.batch_token is not None: + request_data["_message_batcher_token"] = first.batch_token + logger.debug( + "[自动回复] full_question_len=%s group=%s sender=%s batched=%s", + len(full_question), + group_id, + sender_id, + len(items), + ) + if sender_id == self.config.superadmin_qq: + logger.info("[AI] 投递至群聊超级管理员队列 
(batched=%s)", len(items)) + await self.queue_manager.add_group_superadmin_request( + request_data, model_name=self.config.chat_model.model_name + ) + elif any_at_bot: + trigger = "拍一拍" if any_poke else "@机器人" + logger.info("[AI] 触发原因: %s (batched=%s)", trigger, len(items)) + await self.queue_manager.add_group_mention_request( + request_data, model_name=self.config.chat_model.model_name + ) + else: + logger.info("[AI] 投递至普通请求队列 (batched=%s)", len(items)) + await self.queue_manager.add_group_normal_request( + request_data, model_name=self.config.chat_model.model_name + ) diff --git a/src/Undefined/services/coordinator/group.py b/src/Undefined/services/coordinator/group.py new file mode 100644 index 00000000..0dce633b --- /dev/null +++ b/src/Undefined/services/coordinator/group.py @@ -0,0 +1,405 @@ +"""群聊自动回复与 prompt 构建。""" + +from __future__ import annotations + + +import asyncio +import logging +import time +from datetime import datetime +from typing import TYPE_CHECKING, Any, Optional + +from Undefined.attachments import attachment_refs_to_xml +from Undefined.context import RequestContext +from Undefined.context_resource_registry import collect_context_resources +from Undefined.render import render_html_to_image, render_markdown_to_html +from Undefined.services.message_batcher import BufferedMessage, make_scope +from Undefined.utils.recent_messages import get_recent_messages_prefer_local +from Undefined.utils.xml import escape_xml_attr, escape_xml_text + +if TYPE_CHECKING: + from Undefined.config import Config + from Undefined.services.message_batcher import BufferedMessage + from Undefined.services.security import SecurityService + from Undefined.utils.history import MessageHistoryManager + from Undefined.utils.scheduler import TaskScheduler + from Undefined.utils.sender import MessageSender + +logger = logging.getLogger(__name__) + + +_GROUP_STRATEGY_FOOTER = """ + + 【回复策略 - 更克制,纯表情包才前置检索】 + 1. 如果用户 @ 了你或拍了拍你 → 【必须回复】 + 2. 
如果消息中明确提到了你(根据上下文判断用户是否在叫你或维持对话流) → 【必须回复】 + 3. 如果问题明确涉及某个项目/代码/部署细节(用户明确点名或上下文明确指向) → 【酌情回复,必要时先查证再回答】 + 4. 其他技术问题 → 【酌情回复,直接按用户提到的对象回答,不要引入无关的项目名/工具名作背景】 + 5. 先判断当前输入批次(无连续消息说明时就是最后一条消息)是不是在对你说: + - 如果明显是在和别人说话 → 【不要回复】 + - 如果你不能确定是不是在和你说话 → 【默认不回复】 + - 只有明确在和你说,或多人公开讨论且对话明显开放时,才进入下一步 + 6. 群聊里的主动参与只保留给公开、开放的技术或项目讨论: + - 只在多人公开讨论代码、AI、开发工具、项目进展、技术 bug 等,且不是别人之间定向交流时,才可以【极低频参与】 + - 默认更倾向不参与;不要长篇大论,一两句点到为止;如果别人已经在深入讨论且不需要你,保持沉默 + - 轻松互动、玩梗、吐槽本身不构成参与许可;只有在你已经决定要回复,且本轮明确是纯表情包/纯反应图时,才优先考虑表情包表达 + 7. 对于已经决定要回复的场景(包括被@、被拍一拍、轻量答疑,以及少量符合条件的主动参与): + - 只有明确纯表情包回复才先检索表情包,再用 memes.send_meme_by_uid 单独发一条图片消息 + - 其他需要文字承接、解释、答疑、推进任务、确认操作或表达具体态度的场景,第一轮必须优先把必要文字回复做好并调用 send_message + - 如果确实还想补表情包,把 memes.search_memes 和 memes.send_meme_by_uid 放到文字发送后的后续响应轮次,不要阻塞首条文字回复 + - 不要发送任何敷衍消息(如'懒得掺和'、'哦'等);不想回复就直接调用 end + - 严肃、任务型、高信息密度场景少发表情包,避免打断信息传递 + - 绝不要刷屏、绝不要每条都回 + 8. 对于本来就会回复的场景(私聊、被拍一拍、被@、轻量答疑): + - 如果表情包能自然增强语气、缓和语气或让表达更像真人,也只能作为后续可选补充 + - 但不要为了发表情包而牺牲信息传递;信息密度优先时仍以文字为主 + + 简单说:像个极度安静的群友。主动插话只留给公开、开放的技术或项目讨论;明显对别人说或拿不准时就闭嘴。已经决定要回复时,除非明确是纯表情包回复,否则先把文字回复做好,表情包最后再搜。""" + + +class GroupReplyMixin: + """群聊自动回复、注入防御与群聊 prompt 格式化。""" + + if TYPE_CHECKING: + ai: Any + config: Config + history_manager: MessageHistoryManager + onebot: Any + scheduler: TaskScheduler + security: SecurityService + sender: MessageSender + + async def _dispatch_grouped_request( + self, items: list[BufferedMessage] + ) -> None: ... + async def _send_image(self, tid: int, mtype: str, path: str) -> None: ... 
async def handle_auto_reply(
    self,
    group_id: int,
    sender_id: int,
    text: str,
    message_content: list[dict[str, Any]],
    attachments: list[dict[str, str]] | None = None,
    is_poke: bool = False,
    sender_name: str = "未知用户",
    group_name: str = "未知群聊",
    sender_role: str = "member",
    sender_title: str = "",
    sender_level: str = "",
    trigger_message_id: int | None = None,
    is_fake_at: bool = False,
) -> None:
    """Entry point for a group message: screen it, then route it.

    Steps:
        1. Treat the message as addressed to the bot when it is a poke, a
           fake @, or carries an ``at`` segment for the bot.
        2. For anyone but the superadmin, run injection detection. On a hit
           the last history entry is replaced by a placeholder, a defensive
           reply is sent only when the bot was addressed, and processing
           stops.
        3. Wrap the message in a ``BufferedMessage`` and route it: a poke
           always bypasses the batcher; with an enabled batcher, an @bot
           message arriving while that sender already has a buffer is
           dispatched alone, and everything else is submitted to the batcher;
           with no batcher (or one disabled for groups) the message is
           dispatched alone.
    """
    is_at_bot = is_poke or is_fake_at or self._is_at_bot(message_content)
    logger.debug(
        "[自动回复] group=%s sender=%s at_bot=%s fake_at=%s text_len=%s",
        group_id,
        sender_id,
        is_at_bot,
        is_fake_at,
        len(text),
    )

    if sender_id != self.config.superadmin_qq:
        # Lazy %-style arguments, matching the rest of this module's logging.
        logger.debug("[Security] 注入检测: group=%s, user=%s", group_id, sender_id)
        if await self.security.detect_injection(text, message_content):
            logger.warning(
                "[Security] 检测到注入攻击: group=%s, user=%s", group_id, sender_id
            )
            await self.history_manager.modify_last_group_message(
                group_id, sender_id, "<这句话检测到用户进行注入,已删除>"
            )
            if is_at_bot:
                await self._handle_injection_response(
                    group_id, text, sender_id=sender_id
                )
            return

    scope = make_scope(group_id=group_id)
    item = BufferedMessage(
        scope=scope,
        sender_id=sender_id,
        text=text,
        message_content=list(message_content),
        attachments=list(attachments or []),
        sender_name=sender_name,
        arrival_time=time.time(),
        is_private=False,
        trigger_message_id=trigger_message_id,
        is_poke=is_poke,
        is_at_bot=is_at_bot,
        is_fake_at=is_fake_at,
        group_id=group_id,
        group_name=group_name,
        sender_role=sender_role,
        sender_title=sender_title,
        sender_level=sender_level,
    )

    # A poke never goes through the batcher.
    if is_poke:
        await self._dispatch_grouped_request([item])
        return

    batcher = getattr(self, "_batcher", None)
    if batcher is not None and batcher.is_enabled_for(is_group=True):
        if is_at_bot and batcher.has_buffer(scope, sender_id):
            # An @bot arriving while a buffer exists is handled at once,
            # on its own, without disturbing the existing buffer.
            logger.info(
                "[自动回复] batch 内 @bot 旁路立即处理: group=%s sender=%s",
                group_id,
                sender_id,
            )
            await self._dispatch_grouped_request([item])
            return
        await batcher.submit(item)
        return

    await self._dispatch_grouped_request([item])
self.sender + history_manager = self.history_manager + onebot_client = self.onebot + scheduler = self.scheduler + send_message_callback = send_msg_cb + get_recent_messages_callback = get_recent_cb + get_image_url_callback = self.onebot.get_image + get_forward_msg_callback = self.onebot.get_forward_msg + send_like_callback = send_like_cb + send_private_message_callback = send_private_cb + send_image_callback = send_img_cb + resource_vars = dict(globals()) + resource_vars.update(locals()) + resources = collect_context_resources(resource_vars) + for key, value in resources.items(): + if value is not None: + ctx.set_resource(key, value) + if trigger_message_id is not None: + ctx.set_resource("trigger_message_id", trigger_message_id) + if request.get("_queue_lane"): + ctx.set_resource("queue_lane", request.get("_queue_lane")) + logger.debug( + "[上下文资源] group=%s keys=%s", + group_id, + ", ".join(sorted(resources.keys())), + ) + + try: + # 把当前 task 注册到 batcher,使其有能力在新消息到达时取消本次 LLM 调用 + batcher = getattr(self, "_batcher", None) + current_task = asyncio.current_task() + registered_task: asyncio.Task[Any] | None = None + if ( + batcher is not None + and batcher_scope is not None + and current_task is not None + ): + batcher.register_inflight( + batcher_scope, sender_id, current_task, ctx + ) + registered_task = current_task + try: + await self.ai.ask( + full_question, + send_message_callback=send_msg_cb, + get_recent_messages_callback=get_recent_cb, + get_image_url_callback=self.onebot.get_image, + get_forward_msg_callback=self.onebot.get_forward_msg, + send_like_callback=send_like_cb, + sender=self.sender, + history_manager=self.history_manager, + onebot_client=self.onebot, + scheduler=self.scheduler, + extra_context={ + "render_html_to_image": render_html_to_image, + "render_markdown_to_html": render_markdown_to_html, + "group_id": group_id, + "user_id": sender_id, + "is_at_bot": bool(request.get("is_at_bot", False)), + "sender_name": sender_name, + "group_name": 
group_name, + }, + ) + finally: + if ( + batcher is not None + and batcher_scope is not None + and registered_task is not None + ): + batcher.unregister_inflight( + batcher_scope, sender_id, registered_task + ) + except asyncio.CancelledError: + # 投机预发送被新消息抢占取消:不写错误日志、不重试 + logger.info( + "[自动回复] 任务被取消(投机抢占): group=%s sender=%s", + group_id, + sender_id, + ) + raise + except Exception: + logger.exception("自动回复执行出错") + raise + + def _is_at_bot(self, content: list[dict[str, Any]]) -> bool: + """检查消息内容中是否包含对机器人的 @ 提问""" + for seg in content: + if seg.get("type") == "at" and str( + seg.get("data", {}).get("qq", "") + ) == str(self.config.bot_qq): + return True + return False + + async def _handle_injection_response( + self, + tid: int, + text: str, + is_private: bool = False, + sender_id: Optional[int] = None, + ) -> None: + """当检测到注入攻击时,生成并发送特定的防御性回复""" + reply = await self.security.generate_injection_response(text) + if is_private: + await self.sender.send_private_message(tid, reply, auto_history=False) + await self.history_manager.add_private_message( + tid, "<对注入消息的回复>", "Bot", "Bot" + ) + else: + msg = f"[@{sender_id}] {reply}" if sender_id else reply + await self.sender.send_group_message(tid, msg, auto_history=False) + await self.history_manager.add_group_message( + tid, self.config.bot_qq, "<对注入消息的回复>", "Bot", "" + ) + + def _format_group_message_segment(self, item: BufferedMessage) -> str: + """格式化群聊单条 `` `` 块。""" + time_str = datetime.fromtimestamp(item.arrival_time).strftime( + "%Y-%m-%d %H:%M:%S" + ) + group_name = item.group_name or "未知群聊" + location = group_name if group_name.endswith("群") else f"{group_name}群" + safe_name = escape_xml_attr(item.sender_name or "未知用户") + safe_uid = escape_xml_attr(item.sender_id) + safe_gid = escape_xml_attr(item.group_id or 0) + safe_gname = escape_xml_attr(group_name) + safe_loc = escape_xml_attr(location) + safe_role = escape_xml_attr(item.sender_role or "member") + safe_title = escape_xml_attr(item.sender_title or "") 
+ safe_time = escape_xml_attr(time_str) + safe_text = escape_xml_text(item.text) + message_id_attr = "" + if item.trigger_message_id is not None: + message_id_attr = ( + f' message_id="{escape_xml_attr(item.trigger_message_id)}"' + ) + level_attr = ( + f' level="{escape_xml_attr(item.sender_level)}"' + if item.sender_level + else "" + ) + attachment_xml = ( + f"\n{attachment_refs_to_xml(item.attachments)}" if item.attachments else "" + ) + return ( + f' \n' + f" " + ) + + def _build_prompt( + self, + prefix: str, + name: str, + uid: int, + gid: int, + gname: str, + loc: str, + role: str, + title: str, + time_str: str, + text: str, + attachments: list[dict[str, str]] | None = None, + message_id: int | None = None, + level: str = "", + ) -> str: + """构建最终发送给 AI 的结构化 XML 消息 Prompt + + 包含回复策略提示、用户信息和原始文本内容。 + """ + safe_name = escape_xml_attr(name) + safe_uid = escape_xml_attr(uid) + safe_gid = escape_xml_attr(gid) + safe_gname = escape_xml_attr(gname) + safe_loc = escape_xml_attr(loc) + safe_role = escape_xml_attr(role) + safe_title = escape_xml_attr(title) + safe_time = escape_xml_attr(time_str) + safe_text = escape_xml_text(text) + message_id_attr = "" + if message_id is not None: + message_id_attr = f' message_id="{escape_xml_attr(message_id)}"' + level_attr = f' level="{escape_xml_attr(level)}"' if level else "" + attachment_xml = ( + f"\n{attachment_refs_to_xml(attachments)}" if attachments else "" + ) + return f"""{prefix}{safe_text} {attachment_xml}\n" + f"+ +{_GROUP_STRATEGY_FOOTER}""" diff --git a/src/Undefined/services/coordinator/private.py b/src/Undefined/services/coordinator/private.py new file mode 100644 index 00000000..e3a6ec74 --- /dev/null +++ b/src/Undefined/services/coordinator/private.py @@ -0,0 +1,282 @@ +"""私聊回复与私聊 prompt 格式化。""" + +from __future__ import annotations + + +import asyncio +import logging +import time +from datetime import datetime +from typing import TYPE_CHECKING, Any + +from Undefined.attachments import ( + 
build_attachment_scope, + dispatch_pending_file_sends, + render_message_with_pic_placeholders, + attachment_refs_to_xml, +) +from Undefined.context import RequestContext +from Undefined.context_resource_registry import collect_context_resources +from Undefined.render import render_html_to_image, render_markdown_to_html +from Undefined.services.message_batcher import BufferedMessage, make_scope +from Undefined.utils.recent_messages import get_recent_messages_prefer_local +from Undefined.utils.xml import escape_xml_attr, escape_xml_text + +if TYPE_CHECKING: + from Undefined.config import Config + from Undefined.services.message_batcher import BufferedMessage + from Undefined.services.security import SecurityService + from Undefined.utils.history import MessageHistoryManager + from Undefined.utils.scheduler import TaskScheduler + from Undefined.utils.sender import MessageSender + +logger = logging.getLogger(__name__) + + +_PRIVATE_STRATEGY_FOOTER = """ + +【私聊消息】 +这是私聊消息,用户专门来找你说话。你可以自由选择是否回复: +- 如果想回复,先调用 send_message 工具发送回复内容,然后调用 end 结束对话 +- 只有明确纯表情包回复时,才先用 memes.search_memes 查表情包,再用 memes.send_meme_by_uid 单独发图;其他场景先把文字回复做好,表情包最后再搜或不搜 +- 如果不想回复,直接调用 end 结束对话即可""" + + +class PrivateReplyMixin: + """私聊自动回复与私聊 prompt 格式化。""" + + if TYPE_CHECKING: + ai: Any + config: Config + history_manager: MessageHistoryManager + onebot: Any + scheduler: TaskScheduler + security: SecurityService + sender: MessageSender + + async def _dispatch_grouped_request( + self, items: list[BufferedMessage] + ) -> None: ... + async def _handle_injection_response( + self, + tid: int, + text: str, + is_private: bool = False, + sender_id: int | None = None, + ) -> None: ... + async def _send_image(self, tid: int, mtype: str, path: str) -> None: ... 
async def handle_private_reply(
    self,
    user_id: int,
    text: str,
    message_content: list[dict[str, Any]],
    attachments: list[dict[str, str]] | None = None,
    is_poke: bool = False,
    sender_name: str = "未知用户",
    trigger_message_id: int | None = None,
) -> None:
    """Entry point for a private message: screen it, then route it.

    For anyone but the superadmin, injection detection runs first. On a hit
    the last private history entry is replaced by a placeholder, a defensive
    reply is sent, and processing stops. Otherwise the message is wrapped in
    a ``BufferedMessage``: a poke is dispatched alone, a message is submitted
    to the batcher when one is attached and enabled for private chats, and in
    every other case it is dispatched alone.
    """
    logger.debug("[私聊回复] user=%s text_len=%s", user_id, len(text))

    if user_id != self.config.superadmin_qq:
        if await self.security.detect_injection(text, message_content):
            # Lazy %-style argument, matching the rest of this module's logging.
            logger.warning("[Security] 私聊注入攻击: user_id=%s", user_id)
            await self.history_manager.modify_last_private_message(
                user_id, "<这句话检测到用户进行注入,已删除>"
            )
            await self._handle_injection_response(user_id, text, is_private=True)
            return

    scope = make_scope(user_id=user_id)
    item = BufferedMessage(
        scope=scope,
        sender_id=user_id,
        text=text,
        message_content=list(message_content),
        attachments=list(attachments or []),
        sender_name=sender_name,
        arrival_time=time.time(),
        is_private=True,
        trigger_message_id=trigger_message_id,
        is_poke=is_poke,
    )

    if is_poke:
        # A poke bypasses the batcher and is queued alone.
        await self._dispatch_grouped_request([item])
        return

    batcher = getattr(self, "_batcher", None)
    if batcher is not None and batcher.is_enabled_for(is_group=False):
        await batcher.submit(item)
        return

    await self._dispatch_grouped_request([item])
start: int, end: int + ) -> list[dict[str, Any]]: + return await get_recent_messages_prefer_local( + chat_id=chat_id, + msg_type=msg_type, + start=start, + end=end, + onebot_client=self.onebot, + history_manager=self.history_manager, + bot_qq=self.config.bot_qq, + attachment_registry=getattr(self.ai, "attachment_registry", None), + ) + + async def send_img_cb(tid: int, mtype: str, path: str) -> None: + await self._send_image(tid, mtype, path) + + async def send_like_cb(uid: int, times: int = 1) -> None: + await self.onebot.send_like(uid, times) + + async def send_private_cb( + uid: int, msg: str, reply_to: int | None = None + ) -> None: + await self.sender.send_private_message(uid, msg, reply_to=reply_to) + + ai_client = self.ai + memory_storage = self.ai.memory_storage + runtime_config = self.ai.runtime_config + sender = self.sender + history_manager = self.history_manager + onebot_client = self.onebot + scheduler = self.scheduler + send_message_callback = send_msg_cb + get_recent_messages_callback = get_recent_cb + get_image_url_callback = self.onebot.get_image + get_forward_msg_callback = self.onebot.get_forward_msg + send_like_callback = send_like_cb + send_private_message_callback = send_private_cb + send_image_callback = send_img_cb + resource_vars = dict(globals()) + resource_vars.update(locals()) + resources = collect_context_resources(resource_vars) + for key, value in resources.items(): + if value is not None: + ctx.set_resource(key, value) + if trigger_message_id is not None: + ctx.set_resource("trigger_message_id", trigger_message_id) + if request.get("_queue_lane"): + ctx.set_resource("queue_lane", request.get("_queue_lane")) + logger.debug( + "[上下文资源] private user=%s keys=%s", + user_id, + ", ".join(sorted(resources.keys())), + ) + + try: + batcher = getattr(self, "_batcher", None) + current_task = asyncio.current_task() + registered_task: asyncio.Task[Any] | None = None + if ( + batcher is not None + and batcher_scope is not None + and current_task 
is not None + ): + batcher.register_inflight(batcher_scope, user_id, current_task, ctx) + registered_task = current_task + try: + result = await self.ai.ask( + full_question, + send_message_callback=send_msg_cb, + get_recent_messages_callback=get_recent_cb, + get_image_url_callback=self.onebot.get_image, + get_forward_msg_callback=self.onebot.get_forward_msg, + send_like_callback=send_like_cb, + sender=self.sender, + history_manager=self.history_manager, + onebot_client=self.onebot, + scheduler=self.scheduler, + extra_context={ + "render_html_to_image": render_html_to_image, + "render_markdown_to_html": render_markdown_to_html, + "user_id": user_id, + "is_private_chat": True, + "sender_name": sender_name, + "selected_model_name": request.get("selected_model_name"), + }, + ) + finally: + if ( + batcher is not None + and batcher_scope is not None + and registered_task is not None + ): + batcher.unregister_inflight( + batcher_scope, user_id, registered_task + ) + if result: + scope_key = build_attachment_scope( + user_id=user_id, + request_type="private", + ) + rendered = await render_message_with_pic_placeholders( + str(result), + registry=self.ai.attachment_registry, + scope_key=scope_key, + strict=False, + ) + await self.sender.send_private_message( + user_id, + rendered.delivery_text, + history_message=rendered.history_text, + ) + await dispatch_pending_file_sends( + rendered, + sender=self.sender, + target_type="private", + target_id=user_id, + registry=self.ai.attachment_registry, + ) + except asyncio.CancelledError: + logger.info("[私聊回复] 任务被取消(投机抢占): user=%s", user_id) + raise + except Exception: + logger.exception("私聊回复执行出错") + raise + + def _format_private_message_segment(self, item: BufferedMessage) -> str: + """格式化私聊单条 ``{safe_text} {attachment_xml} +`` 块。""" + time_str = datetime.fromtimestamp(item.arrival_time).strftime( + "%Y-%m-%d %H:%M:%S" + ) + safe_name = escape_xml_attr(item.sender_name or "未知用户") + safe_uid = escape_xml_attr(item.sender_id) + 
safe_time = escape_xml_attr(time_str) + safe_text = escape_xml_text(item.text) + message_id_attr = "" + if item.trigger_message_id is not None: + message_id_attr = ( + f' message_id="{escape_xml_attr(item.trigger_message_id)}"' + ) + attachment_xml = ( + f"\n{attachment_refs_to_xml(item.attachments)}" if item.attachments else "" + ) + return ( + f' \n' + f" " + ) diff --git a/src/Undefined/services/message_batcher/__init__.py b/src/Undefined/services/message_batcher/__init__.py new file mode 100644 index 00000000..c9195f1f --- /dev/null +++ b/src/Undefined/services/message_batcher/__init__.py @@ -0,0 +1,48 @@ +"""同 sender 短时多消息合并器(MessageBatcher)。 + +核心目标:把同一个 sender 在短时间内连续发出的消息合并到同一轮 AI 调用, +让模型一次看到全部 ``{safe_text} {attachment_xml}\n" + f"`` 块自行决定 "独立请求 / 修正 / 打断", +避免 N 条独立 LLM 调用造成的重复回复或行为打架。 + +时序:每个 (scope, sender_id) 桶内有两条独立的"静默计时器": + +- ``T1 = window_seconds`` —— "打字静默阈值"。静默达到 T1 视为用户写完, + 这一批 batch 结束。 +- ``T2 = pre_send_seconds`` —— "投机预发送阈值",要求严格小于 T1。 + 静默到 T2 时**先把当前 batch 提前发给 LLM 抢时间**(speculative pre-fire), + 但 batch 尚未结束;T1 才决定结束。 + +新消息到来: + +- 若桶处于 ``TYPING``(尚未 pre-fire):append 后重置 T1/T2。 +- 若桶处于 ``SPECULATING``(已 pre-fire,请求已入队或 inflight 在跑): + - 检查 inflight 是否已经 "向用户发出过任何消息" + (来自 ``RequestContext.get_resource("message_sent_this_turn")``)。 + - inflight 尚未发消息 → 调 ``inflight_task.cancel()``,桶回到 TYPING; + 新消息照常 append 到原有 items 后面,T1/T2 重置。 + - inflight 已经发过消息且 ``allow_cancel_after_send=False``(默认安全)→ + 保留旧 batch 让其自然走完,新消息开新 batch(即清空当前桶后立即重新作为首条入桶)。 + - inflight 已经发过消息但开关 = True → 仍 cancel(可能造成重复发送,仅极端场景)。 + +兼容回退:当 ``pre_send_seconds <= 0`` 或 ``>= window_seconds`` 时投机模式关闭, +退化为旧版 "T1 静默到期才发车" 的行为。 +""" + +# 同 sender 短时合并:T1 结束 batch,T2 投机预发送 +from Undefined.services.message_batcher.scheduler import MessageBatcher +from Undefined.services.message_batcher.state import ( + BatchDispatchToken, + BatchPhase, + BufferedMessage, + FlushCallback, + make_scope, +) + +__all__ = [ + "BatchDispatchToken", + "BatchPhase", + "BufferedMessage", + 
"FlushCallback", + "MessageBatcher", + "make_scope", +] diff --git a/src/Undefined/services/message_batcher.py b/src/Undefined/services/message_batcher/scheduler.py similarity index 85% rename from src/Undefined/services/message_batcher.py rename to src/Undefined/services/message_batcher/scheduler.py index 4ad94e11..45b0cf93 100644 --- a/src/Undefined/services/message_batcher.py +++ b/src/Undefined/services/message_batcher/scheduler.py @@ -1,144 +1,28 @@ -"""同 sender 短时多消息合并器(MessageBatcher)。 - -核心目标:把同一个 sender 在短时间内连续发出的消息合并到同一轮 AI 调用, -让模型一次看到全部 `` `` 块自行决定 "独立请求 / 修正 / 打断", -避免 N 条独立 LLM 调用造成的重复回复或行为打架。 - -时序:每个 (scope, sender_id) 桶内有两条独立的"静默计时器": - -- ``T1 = window_seconds`` —— "打字静默阈值"。静默达到 T1 视为用户写完, - 这一批 batch 结束。 -- ``T2 = pre_send_seconds`` —— "投机预发送阈值",要求严格小于 T1。 - 静默到 T2 时**先把当前 batch 提前发给 LLM 抢时间**(speculative pre-fire), - 但 batch 尚未结束;T1 才决定结束。 - -新消息到来: - -- 若桶处于 ``TYPING``(尚未 pre-fire):append 后重置 T1/T2。 -- 若桶处于 ``SPECULATING``(已 pre-fire,请求已入队或 inflight 在跑): - - 检查 inflight 是否已经 "向用户发出过任何消息" - (来自 ``RequestContext.get_resource("message_sent_this_turn")``)。 - - inflight 尚未发消息 → 调 ``inflight_task.cancel()``,桶回到 TYPING; - 新消息照常 append 到原有 items 后面,T1/T2 重置。 - - inflight 已经发过消息且 ``allow_cancel_after_send=False``(默认安全)→ - 保留旧 batch 让其自然走完,新消息开新 batch(即清空当前桶后立即重新作为首条入桶)。 - - inflight 已经发过消息但开关 = True → 仍 cancel(可能造成重复发送,仅极端场景)。 - -兼容回退:当 ``pre_send_seconds <= 0`` 或 ``>= window_seconds`` 时投机模式关闭, -退化为旧版 "T1 静默到期才发车" 的行为。 -""" +"""MessageBatcher 调度与 timer 逻辑。""" from __future__ import annotations +# 同 sender 短时合并:T1 结束 batch,T2 投机预发送 + import asyncio -import enum import logging import time -from dataclasses import dataclass, field -from typing import Any, Awaitable, Callable +from typing import Any from Undefined.config.models import MessageBatcherConfig +from Undefined.services.message_batcher.state import ( + BatchDispatchToken, + BatchPhase, + BufferedMessage, + FlushCallback, + _BatchState, + _InflightInfo, +) from Undefined.utils.coerce import 
was_message_sent logger = logging.getLogger(__name__) -@dataclass -class BatchDispatchToken: - """一次 batch 发车的身份令牌,用于取消已入队但尚未执行的投机请求。""" - - scope: str - sender_id: int - batch_id: int - speculative: bool = False - cancelled: bool = False - - def cancel(self) -> None: - self.cancelled = True - - -@dataclass -class BufferedMessage: - """缓冲中的单条消息上下文。""" - - scope: str - sender_id: int - text: str - message_content: list[dict[str, Any]] - attachments: list[dict[str, str]] - sender_name: str - arrival_time: float - is_private: bool - trigger_message_id: int | None = None - is_poke: bool = False - is_at_bot: bool = False - is_fake_at: bool = False - # 群聊扩展字段 - group_id: int | None = None - group_name: str = "" - sender_role: str = "member" - sender_title: str = "" - sender_level: str = "" - batch_token: BatchDispatchToken | None = None - - -FlushCallback = Callable[[list[BufferedMessage]], Awaitable[None]] -"""``flush_callback(items)``:batcher 决定 fire 时调用,调用方负责拼装 prompt 并入队执行。 - -调用约定: -- batcher 的 ``flush_callback`` **不应** 立即 await LLM 的完成, - 而是把请求扔进 QueueManager 后立即返回,真正的 LLM 任务由 coordinator 在 ``execute_reply`` - 开头调用 :meth:`MessageBatcher.register_inflight` 上报。 -- 若需要 batcher 关停时也等待 in-flight 收尾,由 :meth:`MessageBatcher.flush_all` 处理。 -""" - - -class BatchPhase(enum.Enum): - """桶状态机。""" - - TYPING = "typing" # 等待 T1/T2 静默 - SPECULATING = "speculating" # T2 已触发,请求已入队或 inflight 在跑;T1 仍未到 - FINALIZING = "finalizing" # T1 已到,等 inflight(若有)自然结束 - - -@dataclass -class _InflightInfo: - """inflight LLM 任务关联信息,由 coordinator 通过 ``register_inflight`` 上报。""" - - task: asyncio.Task[Any] - # ``RequestContext`` 引用,用于判断 ``message_sent_this_turn`` 资源 - request_context: Any = None - - -@dataclass -class _BatchState: - """单个 (scope, sender_id) 桶的状态。""" - - phase: BatchPhase = BatchPhase.TYPING - items: list[BufferedMessage] = field(default_factory=list) - first_arrival_monotonic: float = 0.0 - # T1 = window_seconds 静默 timer(决定 batch 结束) - t1_handle: asyncio.TimerHandle | None = None - # 
T2 = pre_send_seconds 静默 timer(决定 pre-fire);投机关闭时为 None - t2_handle: asyncio.TimerHandle | None = None - # SPECULATING 阶段记录 inflight LLM 任务(由 coordinator 通过 register_inflight 注入) - inflight: _InflightInfo | None = None - # T2 fire 时由 batcher 创建的 flush task;inflight 还未上报前用于兜底取消 - speculative_flush_task: asyncio.Task[Any] | None = None - # 当前 batch 的身份令牌;T2 入队后若又来新消息,可将旧 token 标记取消, - # coordinator 在真正执行前会跳过它。 - dispatch_token: BatchDispatchToken | None = None - - -def make_scope(*, group_id: int | None = None, user_id: int | None = None) -> str: - """构造合并 key 的 scope 字符串。""" - if group_id and group_id > 0: - return f"group:{group_id}" - if user_id is not None: - return f"private:{user_id}" - return "unknown" - - class MessageBatcher: """同 sender 短时合并器(含 T2 投机预发送)。""" @@ -187,17 +71,21 @@ def is_enabled_for(self, *, is_group: bool) -> bool: return False return cfg.group_enabled if is_group else cfg.private_enabled + # Report whether a bucket currently exists for (scope, sender_id) def has_buffer(self, scope: str, sender_id: int) -> bool: return (scope, sender_id) in self._buckets + # 立即触发 batch 发车 async def flush_sender(self, scope: str, sender_id: int) -> bool: return await self._handle_t1((scope, sender_id), raise_on_failure=False) @property def speculative_enabled(self) -> bool: + # 0 < pre_send < window 时启用 T2 投机预发送 cfg = self._config return 0 < cfg.pre_send_seconds < cfg.window_seconds + # 提交消息进入 (scope,sender) 合并桶并重置 T1/T2 计时器 async def submit(self, item: BufferedMessage) -> None: """提交一条消息进入合并桶。 @@ -353,6 +241,7 @@ async def submit(self, item: BufferedMessage) -> None: immediate_fire_items = self._pop_locked(key) else: # T1 delay + # fixed:从首条到达时刻起算绝对 T1;extend:每条新消息重置为 window_seconds if cfg.strategy == "fixed": target = state.first_arrival_monotonic + cfg.window_seconds t1_delay = max(0.0, target - now_mono) @@ -434,6 +323,7 @@ def register_inflight( sender_id, ) + # 注销 inflight 任务 def unregister_inflight( self, scope: str, sender_id: int, task: asyncio.Task[Any] ) -> None: diff --git
a/src/Undefined/services/message_batcher/state.py b/src/Undefined/services/message_batcher/state.py new file mode 100644 index 00000000..6dbea7c4 --- /dev/null +++ b/src/Undefined/services/message_batcher/state.py @@ -0,0 +1,106 @@ +"""MessageBatcher 数据模型与 scope 工具。""" + +from __future__ import annotations + +# 同 sender 短时合并:T1 结束 batch,T2 投机预发送 + +import asyncio +import enum +from dataclasses import dataclass, field +from typing import Any, Awaitable, Callable + + +@dataclass +class BatchDispatchToken: + """一次 batch 发车的身份令牌,用于取消已入队但尚未执行的投机请求。""" + + scope: str + sender_id: int + batch_id: int + speculative: bool = False + cancelled: bool = False + + def cancel(self) -> None: + self.cancelled = True + + +@dataclass +class BufferedMessage: + """缓冲中的单条消息上下文。""" + + scope: str + sender_id: int + text: str + message_content: list[dict[str, Any]] + attachments: list[dict[str, str]] + sender_name: str + arrival_time: float + is_private: bool + trigger_message_id: int | None = None + is_poke: bool = False + is_at_bot: bool = False + is_fake_at: bool = False + # 群聊扩展字段 + group_id: int | None = None + group_name: str = "" + sender_role: str = "member" + sender_title: str = "" + sender_level: str = "" + batch_token: BatchDispatchToken | None = None + + +FlushCallback = Callable[[list[BufferedMessage]], Awaitable[None]] +"""``flush_callback(items)``:batcher 决定 fire 时调用,调用方负责拼装 prompt 并入队执行。 + +调用约定: +- batcher 的 ``flush_callback`` **不应** 立即 await LLM 的完成, + 而是把请求扔进 QueueManager 后立即返回,真正的 LLM 任务由 coordinator 在 ``execute_reply`` + 开头调用 :meth:`MessageBatcher.register_inflight` 上报。 +- 若需要 batcher 关停时也等待 in-flight 收尾,由 :meth:`MessageBatcher.flush_all` 处理。 +""" + + +class BatchPhase(enum.Enum): + """桶状态机:TYPING → SPECULATING(可选) → FINALIZING。""" + + TYPING = "typing" # 等待 T1/T2 静默,用户仍在输入 + SPECULATING = "speculating" # T2 已触发投机 pre-fire,T1 未到,batch 未结束 + FINALIZING = "finalizing" # T1 已到,等待 inflight 自然结束后再释放桶 + + +@dataclass +class _InflightInfo: + """inflight LLM 任务关联信息,由 
coordinator 通过 ``register_inflight`` 上报。""" + + task: asyncio.Task[Any] + # ``RequestContext`` 引用,用于判断 ``message_sent_this_turn`` 资源 + request_context: Any = None + + +@dataclass +class _BatchState: + """单个 (scope, sender_id) 桶的状态。""" + + phase: BatchPhase = BatchPhase.TYPING + items: list[BufferedMessage] = field(default_factory=list) + first_arrival_monotonic: float = 0.0 + # T1 = window_seconds 静默 timer(决定 batch 结束) + t1_handle: asyncio.TimerHandle | None = None + # T2 = pre_send_seconds 静默 timer(决定 pre-fire);投机关闭时为 None + t2_handle: asyncio.TimerHandle | None = None + # SPECULATING 阶段记录 inflight LLM 任务(由 coordinator 通过 register_inflight 注入) + inflight: _InflightInfo | None = None + # T2 fire 时由 batcher 创建的 flush task;inflight 还未上报前用于兜底取消 + speculative_flush_task: asyncio.Task[Any] | None = None + # 当前 batch 的身份令牌;T2 入队后若又来新消息,可将旧 token 标记取消, + # coordinator 在真正执行前会跳过它。 + dispatch_token: BatchDispatchToken | None = None + + +def make_scope(*, group_id: int | None = None, user_id: int | None = None) -> str: + """构造合并 key 的 scope 字符串。""" + if group_id and group_id > 0: + return f"group:{group_id}" + if user_id is not None: + return f"private:{user_id}" + return "unknown" diff --git a/src/Undefined/skills/agents/code_delivery_agent/tools/read/handler.py b/src/Undefined/skills/agents/code_delivery_agent/tools/read/handler.py index aabb03ab..81ab2233 100644 --- a/src/Undefined/skills/agents/code_delivery_agent/tools/read/handler.py +++ b/src/Undefined/skills/agents/code_delivery_agent/tools/read/handler.py @@ -39,6 +39,7 @@ async def _read_single_file( async with aiofiles.open( full_path, "r", encoding="utf-8", errors="replace" ) as f: + # Count lines by iterating the file asynchronously, one line at a time async for _ in f: line_count += 1 diff --git a/src/Undefined/skills/agents/runner.py b/src/Undefined/skills/agents/runner.py deleted file mode 100644 index 9e3cc326..00000000 --- a/src/Undefined/skills/agents/runner.py +++ /dev/null @@ -1,387 +0,0 @@ -from __future__ import annotations - -import asyncio -import logging
-from pathlib import Path -from typing import Any - -import aiofiles - -from Undefined.config.models import AgentModelConfig -from Undefined.ai.transports.openai_transport import RESPONSES_OUTPUT_ITEMS_KEY -from Undefined.skills.agents.agent_tool_registry import AgentToolRegistry -from Undefined.skills.anthropic_skills import AnthropicSkillRegistry -from Undefined.utils.tool_calls import parse_tool_arguments - - -async def load_prompt_text(agent_dir: Path, default_prompt: str) -> str: - """从 agent 目录加载 prompt.md,缺失时返回默认提示词。""" - - prompt_path = agent_dir / "prompt.md" - if prompt_path.exists(): - async with aiofiles.open(prompt_path, "r", encoding="utf-8") as file: - return await file.read() - return default_prompt - - -def _filter_tools_for_runtime_config( - agent_name: str, - tools: list[dict[str, Any]], - runtime_config: Any | None, -) -> list[dict[str, Any]]: - if agent_name != "web_agent" or runtime_config is None: - return tools - - if bool(getattr(runtime_config, "grok_search_enabled", False)): - return tools - - filtered: list[dict[str, Any]] = [] - for tool in tools: - function = tool.get("function") if isinstance(tool, dict) else None - name = function.get("name") if isinstance(function, dict) else None - if name == "grok_search": - continue - filtered.append(tool) - return filtered - - -async def run_agent_with_tools( - *, - agent_name: str, - user_content: str, - context_messages: list[dict[str, str]] | None = None, - empty_user_content_message: str, - default_prompt: str, - context: dict[str, Any], - agent_dir: Path, - logger: logging.Logger, - max_iterations: int = 20, - tool_error_prefix: str = "错误", -) -> str: - """执行通用 Agent 循环。 - - 该方法统一处理: - - prompt 加载 - - LLM 迭代决策 - - tool call 并发执行 - - tool 结果回填 messages - """ - - if not user_content.strip(): - return empty_user_content_message - - tool_registry = AgentToolRegistry( - agent_dir / "tools", - current_agent_name=agent_name, - is_main_agent=False, - ) - tools = tool_registry.get_tools_schema() - 
runtime_config = context.get("runtime_config") - tools = _filter_tools_for_runtime_config(agent_name, tools, runtime_config) - - # 发现并加载 agent 私有 Anthropic Skills(可选) - agent_skills_dir = agent_dir / "anthropic_skills" - agent_skill_registry: AnthropicSkillRegistry | None = None - if agent_skills_dir.exists() and agent_skills_dir.is_dir(): - agent_skill_registry = AnthropicSkillRegistry(agent_skills_dir) - if agent_skill_registry.has_skills(): - # 将 anthropic skill tools 加入 agent 的可用工具列表 - tools = tools + agent_skill_registry.get_tools_schema() - logger.info( - "[Agent:%s] 加载了 %d 个私有 Anthropic Skills", - agent_name, - len(agent_skill_registry.get_all_skills()), - ) - - ai_client = context.get("ai_client") - if not ai_client: - return "AI client 未在上下文中提供" - - model_config_override = context.get("model_config_override") - if isinstance(model_config_override, AgentModelConfig): - agent_config = model_config_override - else: - agent_config = ai_client.agent_config - # 动态选择 agent 模型 - group_id = context.get("group_id", 0) or 0 - user_id = context.get("user_id", 0) or 0 - global_enabled = runtime_config.model_pool_enabled if runtime_config else False - agent_config = ai_client.model_selector.select_agent_config( - agent_config, - group_id=group_id, - user_id=user_id, - global_enabled=global_enabled, - ) - system_prompt = await load_prompt_text(agent_dir, default_prompt) - - # 注入 agent 私有 Anthropic Skills 元数据到 system prompt - if agent_skill_registry and agent_skill_registry.has_skills(): - skills_xml = agent_skill_registry.build_metadata_xml() - if skills_xml: - system_prompt = ( - f"{system_prompt}\n\n" - f"【可用的 Anthropic Skills】\n" - f"{skills_xml}\n\n" - f"注意:以上是你可用的 Anthropic Agent Skills。" - f"当任务与某个 skill 相关时," - f"可以调用对应的 skill tool(tool_name 字段)" - f"来获取该领域的详细指令和知识。" - ) - - agent_history = context.get("agent_history", []) - - messages: list[dict[str, Any]] = [{"role": "system", "content": system_prompt}] - if agent_history: - messages.extend(agent_history) - if 
context_messages: - messages.extend(context_messages) - messages.append({"role": "user", "content": user_content}) - transport_state: dict[str, Any] | None = None - queue_lane = context.get("queue_lane") - max_pre_tool_retries = max( - 0, int(getattr(runtime_config, "ai_request_max_retries", 0) or 0) - ) - pre_tool_failure_count = 0 - - for iteration in range(1, max_iterations + 1): - logger.debug("[Agent:%s] iteration=%s", agent_name, iteration) - message_checkpoint_len = len(messages) - transport_state_checkpoint = transport_state - try: - result = await ai_client.submit_queued_llm_call( - model_config=agent_config, - messages=messages, - max_tokens=agent_config.max_tokens, - call_type=f"agent:{agent_name}", - tools=tools if tools else None, - tool_choice="auto", - transport_state=transport_state, - queue_lane=queue_lane, - ) - except Exception as exc: - logger.exception( - "[Agent:%s] queued LLM 调用失败: lane=%s iteration=%s error=%s", - agent_name, - queue_lane, - iteration, - exc, - ) - raise RuntimeError("智能体模型请求失败") from exc - - try: - tool_execution_started = False - tool_name_map = ( - result.get("_tool_name_map") if isinstance(result, dict) else None - ) - api_to_internal: dict[str, str] = {} - if isinstance(tool_name_map, dict): - raw_api_to_internal = tool_name_map.get("api_to_internal") - if isinstance(raw_api_to_internal, dict): - api_to_internal = { - str(key): str(value) - for key, value in raw_api_to_internal.items() - } - - next_transport_state = ( - result.get("_transport_state") if isinstance(result, dict) else None - ) - transport_state = ( - next_transport_state if isinstance(next_transport_state, dict) else None - ) - - choice: dict[str, Any] = result.get("choices", [{}])[0] - message: dict[str, Any] = choice.get("message", {}) - content: str = message.get("content") or "" - reasoning_content: str | None = message.get("reasoning_content") - tool_calls: list[dict[str, Any]] = message.get("tool_calls", []) - - if content.strip() and tool_calls: - 
content = "" - - if not tool_calls: - return content - - assistant_message: dict[str, Any] = { - "role": "assistant", - "content": content, - "tool_calls": tool_calls, - } - output_items = message.get(RESPONSES_OUTPUT_ITEMS_KEY) - if isinstance(output_items, list): - assistant_message[RESPONSES_OUTPUT_ITEMS_KEY] = output_items - capture_reasoning = bool( - getattr(agent_config, "thinking_tool_call_compat", False) - ) or bool(getattr(agent_config, "reasoning_content_replay", False)) - if capture_reasoning and reasoning_content is not None: - assistant_message["reasoning_content"] = reasoning_content - messages.append(assistant_message) - - tool_tasks: list[asyncio.Future[Any]] = [] - tool_call_ids: list[str] = [] - tool_api_names: list[str] = [] - end_tool_call: dict[str, Any] | None = None - end_tool_args: dict[str, Any] = {} - - for tool_call in tool_calls: - call_id = str(tool_call.get("id", "")) - function: dict[str, Any] = tool_call.get("function", {}) - api_function_name = str(function.get("name", "")) - raw_args = function.get("arguments") - - internal_function_name = api_to_internal.get( - api_function_name, api_function_name - ) - logger.info( - "[Agent:%s] preparing tool=%s", - agent_name, - internal_function_name, - ) - - function_args = parse_tool_arguments( - raw_args, - logger=logger, - tool_name=api_function_name, - ) - - if not isinstance(function_args, dict): - function_args = {} - - # 检测 end 工具,暂存后统一处理 - if internal_function_name == "end": - if len(tool_calls) > 1: - logger.warning( - "[Agent:%s] end 与其他工具同时调用," - "将先执行其他工具,并回填 end 跳过结果", - agent_name, - ) - end_tool_call = tool_call - end_tool_args = function_args - continue - - tool_call_ids.append(call_id) - tool_api_names.append(api_function_name) - - # Anthropic Skill tool 路由 - # 工具名格式: skills ,如 skills-_-pdf-processing - skill_delimiter = ( - agent_skill_registry.dot_delimiter - if agent_skill_registry - else "-_-" - ) - is_agent_skill = internal_function_name.startswith( - 
f"skills{skill_delimiter}" - ) - if is_agent_skill and agent_skill_registry: - tool_tasks.append( - asyncio.ensure_future( - agent_skill_registry.execute_skill_tool( - internal_function_name, - function_args, - context, - ) - ) - ) - else: - tool_tasks.append( - asyncio.ensure_future( - tool_registry.execute_tool( - internal_function_name, - function_args, - context, - ) - ) - ) - - if tool_tasks: - tool_execution_started = True - logger.info( - "[Agent:%s] executing tools in parallel: count=%s", - agent_name, - len(tool_tasks), - ) - results = await asyncio.gather(*tool_tasks, return_exceptions=True) - - for index, tool_result in enumerate(results): - call_id = tool_call_ids[index] - api_tool_name = tool_api_names[index] - if isinstance(tool_result, Exception): - content_str = f"{tool_error_prefix}: {tool_result}" - else: - content_str = str(tool_result) - - messages.append( - { - "role": "tool", - "tool_call_id": call_id, - "name": api_tool_name, - "content": content_str, - } - ) - - # 处理 end 工具调用 - if end_tool_call: - end_call_id = str(end_tool_call.get("id", "")) - end_api_name = end_tool_call.get("function", {}).get("name", "end") - if tool_tasks: - # end 与其他工具同时调用:跳过执行,但仍回填 tool 响应, - # 避免 assistant.tool_calls 出现未配对的 tool_call_id。 - skip_content = ( - "end 与其他工具同轮调用,本轮未执行 end;" - "请根据其他工具结果继续决策。" - ) - messages.append( - { - "role": "tool", - "tool_call_id": end_call_id, - "name": end_api_name, - "content": skip_content, - } - ) - logger.info( - "[Agent:%s] end 与其他工具同时调用,已回填跳过响应", - agent_name, - ) - else: - # end 单独调用,正常执行(参数已在循环中解析) - tool_execution_started = True - end_result = await tool_registry.execute_tool( - "end", end_tool_args, context - ) - messages.append( - { - "role": "tool", - "tool_call_id": end_call_id, - "name": end_api_name, - "content": str(end_result), - } - ) - pre_tool_failure_count = 0 - - except Exception as exc: - if ( - not tool_execution_started - and pre_tool_failure_count < max_pre_tool_retries - ): - pre_tool_failure_count += 1 
- del messages[message_checkpoint_len:] - transport_state = transport_state_checkpoint - logger.warning( - "[Agent:%s] pre-tool 本地失败,重试当前 LLM 轮次: lane=%s retry=%s/%s iteration=%s error=%s", - agent_name, - queue_lane, - pre_tool_failure_count, - max_pre_tool_retries, - iteration, - exc, - ) - continue - logger.exception( - "[Agent:%s] 执行失败,已静默抑制: lane=%s iteration=%s error=%s", - agent_name, - queue_lane, - iteration, - exc, - ) - return "" - - return "达到最大迭代次数" diff --git a/src/Undefined/skills/agents/runner/__init__.py b/src/Undefined/skills/agents/runner/__init__.py new file mode 100644 index 00000000..09676932 --- /dev/null +++ b/src/Undefined/skills/agents/runner/__init__.py @@ -0,0 +1,12 @@ +"""Agent runner 子包:context 准备、工具执行与 LLM 迭代循环。""" + +# 对外 re-export,兼容 `from Undefined.skills.agents.runner import run_agent_with_tools` +from Undefined.skills.agents.runner.context import load_prompt_text +from Undefined.skills.agents.runner.loop import run_agent_with_tools +from Undefined.skills.agents.runner.tools import _filter_tools_for_runtime_config + +__all__ = [ + "load_prompt_text", + "run_agent_with_tools", + "_filter_tools_for_runtime_config", +] diff --git a/src/Undefined/skills/agents/runner/context.py b/src/Undefined/skills/agents/runner/context.py new file mode 100644 index 00000000..3e4e399e --- /dev/null +++ b/src/Undefined/skills/agents/runner/context.py @@ -0,0 +1,133 @@ +# Agent 运行前上下文准备(工具注册表、模型、消息链) +from __future__ import annotations + +from dataclasses import dataclass +from pathlib import Path +from typing import Any + +import aiofiles + +from Undefined.config.models import AgentModelConfig +from Undefined.skills.agents.agent_tool_registry import AgentToolRegistry +from Undefined.skills.anthropic_skills import AnthropicSkillRegistry + +from Undefined.skills.agents.runner.tools import _filter_tools_for_runtime_config + + +# 异步读取 agent 目录下的 prompt.md +async def load_prompt_text(agent_dir: Path, default_prompt: str) -> str: + """从 agent 目录加载 
prompt.md,缺失时返回默认提示词。""" + + prompt_path = agent_dir / "prompt.md" + if prompt_path.exists(): + async with aiofiles.open(prompt_path, "r", encoding="utf-8") as file: + return await file.read() + return default_prompt + + +@dataclass +# Bundle of registries, tool schema, model config, message chain and queue settings returned by prepare_agent_run +class PreparedAgentRun: + tool_registry: AgentToolRegistry + agent_skill_registry: AnthropicSkillRegistry | None + tools: list[dict[str, Any]] + agent_config: AgentModelConfig + messages: list[dict[str, Any]] + ai_client: Any + queue_lane: Any + max_pre_tool_retries: int + + +# 准备 Agent 运行上下文:工具、模型、消息链 +async def prepare_agent_run( + *, + agent_name: str, + user_content: str, + context_messages: list[dict[str, str]] | None, + default_prompt: str, + context: dict[str, Any], + agent_dir: Path, + logger: Any, +) -> PreparedAgentRun | str: + # 为当前 Agent 实例化私有工具注册表(含 callable agent 扫描) + tool_registry = AgentToolRegistry( + agent_dir / "tools", + current_agent_name=agent_name, + is_main_agent=False, + ) + tools = tool_registry.get_tools_schema() + runtime_config = context.get("runtime_config") + tools = _filter_tools_for_runtime_config(agent_name, tools, runtime_config) + + agent_skills_dir = agent_dir / "anthropic_skills" + agent_skill_registry: AnthropicSkillRegistry | None = None + # 可选:加载 Agent 目录下的 Anthropic Skills 并追加 tool schema + if agent_skills_dir.exists() and agent_skills_dir.is_dir(): + agent_skill_registry = AnthropicSkillRegistry(agent_skills_dir) + if agent_skill_registry.has_skills(): + tools = tools + agent_skill_registry.get_tools_schema() + logger.info( + "[Agent:%s] 加载了 %d 个私有 Anthropic Skills", + agent_name, + len(agent_skill_registry.get_all_skills()), + ) + + ai_client = context.get("ai_client") + if not ai_client: + return "AI client 未在上下文中提供" + + model_config_override = context.get("model_config_override") + if isinstance(model_config_override, AgentModelConfig): + agent_config = model_config_override + else: + agent_config = ai_client.agent_config + group_id = context.get("group_id", 0) or 0 +
user_id = context.get("user_id", 0) or 0 + global_enabled = runtime_config.model_pool_enabled if runtime_config else False + # 多模型池:按群/私聊上下文选择 Agent 专用模型配置 + agent_config = ai_client.model_selector.select_agent_config( + agent_config, + group_id=group_id, + user_id=user_id, + global_enabled=global_enabled, + ) + system_prompt = await load_prompt_text(agent_dir, default_prompt) + + if agent_skill_registry and agent_skill_registry.has_skills(): + skills_xml = agent_skill_registry.build_metadata_xml() + if skills_xml: + system_prompt = ( + f"{system_prompt}\n\n" + f"【可用的 Anthropic Skills】\n" + f"{skills_xml}\n\n" + f"注意:以上是你可用的 Anthropic Agent Skills。" + f"当任务与某个 skill 相关时," + f"可以调用对应的 skill tool(tool_name 字段)" + f"来获取该领域的详细指令和知识。" + ) + + agent_history = context.get("agent_history", []) + + # 组装 LLM 消息链:system → agent 历史 → 上下文 → 当前用户输入 + messages: list[dict[str, Any]] = [{"role": "system", "content": system_prompt}] + if agent_history: + messages.extend(agent_history) + if context_messages: + messages.extend(context_messages) + messages.append({"role": "user", "content": user_content}) + + queue_lane = context.get("queue_lane") + max_pre_tool_retries = max( + 0, int(getattr(runtime_config, "ai_request_max_retries", 0) or 0) + ) + + return PreparedAgentRun( + tool_registry=tool_registry, + agent_skill_registry=agent_skill_registry, + tools=tools, + agent_config=agent_config, + messages=messages, + ai_client=ai_client, + queue_lane=queue_lane, + max_pre_tool_retries=max_pre_tool_retries, + ) diff --git a/src/Undefined/skills/agents/runner/loop.py b/src/Undefined/skills/agents/runner/loop.py new file mode 100644 index 00000000..bc215e8b --- /dev/null +++ b/src/Undefined/skills/agents/runner/loop.py @@ -0,0 +1,182 @@ +# Agent LLM↔工具迭代循环核心 +from __future__ import annotations + +import logging +from pathlib import Path +from typing import Any + +from Undefined.ai.transports.openai_transport import RESPONSES_OUTPUT_ITEMS_KEY +from Undefined.skills.agents.runner.context 
import prepare_agent_run +from Undefined.skills.agents.runner.tools import execute_assistant_tool_calls + + +# Agent 主循环:LLM 决策 → 工具执行 → 结果回填 +async def run_agent_with_tools( + *, + agent_name: str, + user_content: str, + context_messages: list[dict[str, str]] | None = None, + empty_user_content_message: str, + default_prompt: str, + context: dict[str, Any], + agent_dir: Path, + logger: logging.Logger, + max_iterations: int = 20, + tool_error_prefix: str = "错误", +) -> str: + """执行通用 Agent 循环。 + + 该方法统一处理: + - prompt 加载 + - LLM 迭代决策 + - tool call 并发执行 + - tool 结果回填 messages + """ + + # 空输入直接返回提示 + if not user_content.strip(): + return empty_user_content_message + + prepared = await prepare_agent_run( + agent_name=agent_name, + user_content=user_content, + context_messages=context_messages, + default_prompt=default_prompt, + context=context, + agent_dir=agent_dir, + logger=logger, + ) + # prepare 失败时 prepared 为错误字符串 + if isinstance(prepared, str): + return prepared + + messages = prepared.messages + transport_state: dict[str, Any] | None = None + pre_tool_failure_count = 0 + + # Agent 主循环:LLM 决策 → 工具执行 → 结果回填,直到无 tool_calls + # 迭代上限防止无限 tool 循环 + for iteration in range(1, max_iterations + 1): + logger.debug("[Agent:%s] iteration=%s", agent_name, iteration) + # 记录 checkpoint,pre-tool 失败时可回滚 messages / transport_state + message_checkpoint_len = len(messages) + transport_state_checkpoint = transport_state + try: + # 通过队列提交 LLM 请求(含 tools 与 transport 多轮状态) + result = await prepared.ai_client.submit_queued_llm_call( + model_config=prepared.agent_config, + messages=messages, + max_tokens=prepared.agent_config.max_tokens, + call_type=f"agent:{agent_name}", + tools=prepared.tools if prepared.tools else None, + tool_choice="auto", + transport_state=transport_state, + queue_lane=prepared.queue_lane, + ) + except Exception as exc: + logger.exception( + "[Agent:%s] queued LLM 调用失败: lane=%s iteration=%s error=%s", + agent_name, + prepared.queue_lane, + iteration, + exc, + ) + 
raise RuntimeError("智能体模型请求失败") from exc + + try: + tool_execution_started = False + tool_name_map = ( + result.get("_tool_name_map") if isinstance(result, dict) else None + ) + # API 工具名与内部 registry 名称的映射(含 dot 分隔符转换) + api_to_internal: dict[str, str] = {} + if isinstance(tool_name_map, dict): + raw_api_to_internal = tool_name_map.get("api_to_internal") + if isinstance(raw_api_to_internal, dict): + api_to_internal = { + str(key): str(value) + for key, value in raw_api_to_internal.items() + } + + next_transport_state = ( + result.get("_transport_state") if isinstance(result, dict) else None + ) + # Responses API 等多轮 transport 状态,下一轮 LLM 调用需回传 + transport_state = ( + next_transport_state if isinstance(next_transport_state, dict) else None + ) + + choice: dict[str, Any] = result.get("choices", [{}])[0] + message: dict[str, Any] = choice.get("message", {}) + content: str = message.get("content") or "" + reasoning_content: str | None = message.get("reasoning_content") + tool_calls: list[dict[str, Any]] = message.get("tool_calls", []) + + # 模型同时返回文本与工具调用时,优先走工具路径 + if content.strip() and tool_calls: + content = "" + + # 无工具调用即视为最终回复 + if not tool_calls: + return content + + # 将 assistant 消息(含 tool_calls)追加到对话历史 + assistant_message: dict[str, Any] = { + "role": "assistant", + "content": content, + "tool_calls": tool_calls, + } + output_items = message.get(RESPONSES_OUTPUT_ITEMS_KEY) + if isinstance(output_items, list): + assistant_message[RESPONSES_OUTPUT_ITEMS_KEY] = output_items + # 部分模型需回放 reasoning_content 以兼容 thinking + tool_call + capture_reasoning = bool( + getattr(prepared.agent_config, "thinking_tool_call_compat", False) + ) or bool(getattr(prepared.agent_config, "reasoning_content_replay", False)) + if capture_reasoning and reasoning_content is not None: + assistant_message["reasoning_content"] = reasoning_content + messages.append(assistant_message) + + # 并发执行 tool_calls,结果以 role=tool 消息回填 + tool_execution_started = await execute_assistant_tool_calls( + 
agent_name=agent_name, + tool_calls=tool_calls, + api_to_internal=api_to_internal, + messages=messages, + tool_registry=prepared.tool_registry, + agent_skill_registry=prepared.agent_skill_registry, + context=context, + logger=logger, + tool_error_prefix=tool_error_prefix, + ) + pre_tool_failure_count = 0 + + except Exception as exc: + # pre-tool 本地异常:在未开始执行工具前可重试当前 LLM 轮次 + if ( + not tool_execution_started + and pre_tool_failure_count < prepared.max_pre_tool_retries + ): + pre_tool_failure_count += 1 + del messages[message_checkpoint_len:] + transport_state = transport_state_checkpoint + logger.warning( + "[Agent:%s] pre-tool 本地失败,重试当前 LLM 轮次: lane=%s retry=%s/%s iteration=%s error=%s", + agent_name, + prepared.queue_lane, + pre_tool_failure_count, + prepared.max_pre_tool_retries, + iteration, + exc, + ) + continue + logger.exception( + "[Agent:%s] 执行失败,已静默抑制: lane=%s iteration=%s error=%s", + agent_name, + prepared.queue_lane, + iteration, + exc, + ) + return "" + + return "达到最大迭代次数" diff --git a/src/Undefined/skills/agents/runner/tools.py b/src/Undefined/skills/agents/runner/tools.py new file mode 100644 index 00000000..31d79177 --- /dev/null +++ b/src/Undefined/skills/agents/runner/tools.py @@ -0,0 +1,179 @@ +# Agent 工具并发调度与 end 工具特殊处理 +from __future__ import annotations + +import asyncio +import logging +from typing import Any + +from Undefined.ai.tooling import END_CO_CALL_REJECT_CONTENT +from Undefined.skills.anthropic_skills import AnthropicSkillRegistry +from Undefined.skills.agents.agent_tool_registry import AgentToolRegistry +from Undefined.utils.tool_calls import parse_tool_arguments + + +# 按运行时配置过滤不可用工具 schema +def _filter_tools_for_runtime_config( + agent_name: str, + tools: list[dict[str, Any]], + runtime_config: Any | None, +) -> list[dict[str, Any]]: + # web_agent 在 grok 未启用时从 schema 中剔除 grok_search + if agent_name != "web_agent" or runtime_config is None: + return tools + + if bool(getattr(runtime_config, "grok_search_enabled", False)): + return 
tools + + filtered: list[dict[str, Any]] = [] + for tool in tools: + function = tool.get("function") if isinstance(tool, dict) else None + name = function.get("name") if isinstance(function, dict) else None + if name == "grok_search": + continue + filtered.append(tool) + return filtered + + +# 并发执行 tool_calls 并回填 tool 消息 +async def execute_assistant_tool_calls( + *, + agent_name: str, + tool_calls: list[dict[str, Any]], + api_to_internal: dict[str, str], + messages: list[dict[str, Any]], + tool_registry: AgentToolRegistry, + agent_skill_registry: AnthropicSkillRegistry | None, + context: dict[str, Any], + logger: logging.Logger, + tool_error_prefix: str, +) -> bool: + """并发执行 assistant 的 tool_calls,回填 tool 消息。返回是否已开始工具执行。""" + + tool_tasks: list[asyncio.Future[Any]] = [] + tool_call_ids: list[str] = [] + tool_api_names: list[str] = [] + end_tool_call: dict[str, Any] | None = None + end_tool_args: dict[str, Any] = {} + tool_execution_started = False + + for tool_call in tool_calls: + call_id = str(tool_call.get("id", "")) + function: dict[str, Any] = tool_call.get("function", {}) + api_function_name = str(function.get("name", "")) + raw_args = function.get("arguments") + + internal_function_name = api_to_internal.get( + api_function_name, api_function_name + ) + logger.info( + "[Agent:%s] preparing tool=%s", + agent_name, + internal_function_name, + ) + + function_args = parse_tool_arguments( + raw_args, + logger=logger, + tool_name=api_function_name, + ) + + if not isinstance(function_args, dict): + function_args = {} + + # end 工具延后处理:若与其他工具同批调用则返回拒绝 + if internal_function_name == "end": + if len(tool_calls) > 1: + logger.warning( + "[Agent:%s] end 与其他工具同时调用," + "将先执行其他工具,end 将返回拒绝结果", + agent_name, + ) + end_tool_call = tool_call + end_tool_args = function_args + continue + + tool_call_ids.append(call_id) + tool_api_names.append(api_function_name) + + skill_delimiter = ( + agent_skill_registry.dot_delimiter if agent_skill_registry else "-_-" + ) + # Anthropic 
Skill 走独立 registry,其余走 AgentToolRegistry + is_agent_skill = internal_function_name.startswith(f"skills{skill_delimiter}") + if is_agent_skill and agent_skill_registry: + tool_tasks.append( + asyncio.ensure_future( + agent_skill_registry.execute_skill_tool( + internal_function_name, + function_args, + context, + ) + ) + ) + else: + tool_tasks.append( + asyncio.ensure_future( + tool_registry.execute_tool( + internal_function_name, + function_args, + context, + ) + ) + ) + + if tool_tasks: + tool_execution_started = True + logger.info( + "[Agent:%s] executing tools in parallel: count=%s", + agent_name, + len(tool_tasks), + ) + # 同轮 tool_calls 并发执行,异常转为 tool 消息内容 + results = await asyncio.gather(*tool_tasks, return_exceptions=True) + + for index, tool_result in enumerate(results): + call_id = tool_call_ids[index] + api_tool_name = tool_api_names[index] + if isinstance(tool_result, Exception): + content_str = f"{tool_error_prefix}: {tool_result}" + else: + content_str = str(tool_result) + + messages.append( + { + "role": "tool", + "tool_call_id": call_id, + "name": api_tool_name, + "content": content_str, + } + ) + + if end_tool_call: + end_call_id = str(end_tool_call.get("id", "")) + end_api_name = end_tool_call.get("function", {}).get("name", "end") + if tool_tasks: + messages.append( + { + "role": "tool", + "tool_call_id": end_call_id, + "name": end_api_name, + "content": END_CO_CALL_REJECT_CONTENT, + } + ) + logger.info( + "[Agent:%s] end 与其他工具同时调用,其它工具已执行,end 已回填拒绝响应", + agent_name, + ) + else: + tool_execution_started = True + end_result = await tool_registry.execute_tool("end", end_tool_args, context) + messages.append( + { + "role": "tool", + "tool_call_id": end_call_id, + "name": end_api_name, + "content": str(end_result), + } + ) + + return tool_execution_started diff --git a/src/Undefined/skills/tools/__init__.py b/src/Undefined/skills/tools/__init__.py index 2f23af88..4c316330 100644 --- a/src/Undefined/skills/tools/__init__.py +++ 
b/src/Undefined/skills/tools/__init__.py @@ -114,13 +114,9 @@ def _log_tools_summary(self, include_mcp: bool = True) -> None: """ tool_names = list(self._items.keys()) - # 分类工具 basic_tools, toolset_tools, mcp_tools = self._categorize_tools(tool_names) - - # 按类别分组工具集工具 toolset_by_category = self._group_toolsets_by_category(toolset_tools) - # 输出统计信息 logger.info("=" * 60) if include_mcp: logger.info("工具加载完成统计 (包含 MCP)") diff --git a/src/Undefined/skills/tools/bilibili_video/handler.py b/src/Undefined/skills/tools/bilibili_video/handler.py index 2ff07a8e..2650f9ac 100644 --- a/src/Undefined/skills/tools/bilibili_video/handler.py +++ b/src/Undefined/skills/tools/bilibili_video/handler.py @@ -23,7 +23,6 @@ def _resolve_target( return None, "target_id 必须是整数" return (target_type, target_id), None # type: ignore[return-value] - # 从上下文推断 request_type = context.get("request_type") if request_type == "group": group_id = context.get("group_id") @@ -34,7 +33,6 @@ def _resolve_target( if user_id: return ("private", int(user_id)), None - # 兜底 group_id = context.get("group_id") if group_id: return ("group", int(group_id)), None @@ -51,13 +49,11 @@ async def execute(args: Dict[str, Any], context: Dict[str, Any]) -> str: if not video_id: return "video_id 不能为空" - # 解析目标 target, error = _resolve_target(args, context) if error or target is None: return f"目标解析失败: {error or '参数错误'}" target_type, target_id = target - # 获取配置 runtime_config = context.get("runtime_config") sender = context.get("sender") onebot = context.get("onebot_client") or context.get("onebot") @@ -67,7 +63,6 @@ async def execute(args: Dict[str, Any], context: Dict[str, Any]) -> str: if not sender or not onebot: return "缺少必要的运行时组件(sender/onebot)" - # 读取 bilibili 配置 cookie = "" prefer_quality = 80 max_duration = 600 diff --git a/src/Undefined/skills/tools/fetch_image_uid/handler.py b/src/Undefined/skills/tools/fetch_image_uid/handler.py index 7986c811..cc84bbcc 100644 --- 
a/src/Undefined/skills/tools/fetch_image_uid/handler.py +++ b/src/Undefined/skills/tools/fetch_image_uid/handler.py @@ -38,7 +38,6 @@ async def execute(args: Dict[str, Any], context: Dict[str, Any]) -> str: logger.exception("fetch_image_uid 注册失败: %s", exc) return f"获取图片失败:{exc}" - # 验证是否为图片类型 mime = str(getattr(record, "mime_type", "") or "").strip().lower() if mime and not mime.startswith(_IMAGE_MIME_PREFIX): return f"URL 内容不是图片类型(检测到 {mime}),仅支持图片" diff --git a/src/Undefined/skills/tools/get_current_time/handler.py b/src/Undefined/skills/tools/get_current_time/handler.py index 1c79700a..5996b743 100644 --- a/src/Undefined/skills/tools/get_current_time/handler.py +++ b/src/Undefined/skills/tools/get_current_time/handler.py @@ -37,7 +37,6 @@ async def execute(args: Dict[str, Any], context: Dict[str, Any]) -> str: indent=2, ) - # 默认返回 ISO 格式 return now.isoformat(timespec="seconds") @@ -51,7 +50,6 @@ def _format_text( """生成人类可读的文本格式""" lines = [] - # 公历信息 weekdays = ["星期一", "星期二", "星期三", "星期四", "星期五", "星期六", "星期日"] weekday = weekdays[now.weekday()] tz_offset = now.strftime("%z") @@ -62,7 +60,6 @@ def _format_text( f"{now.hour:02d}:{now.minute:02d}:{now.second:02d} ({tz_str})" ) - # 农历信息 if include_lunar and lunar: year_gz = lunar.getYearInGanZhi() zodiac = lunar.getYearShengXiao() @@ -70,19 +67,15 @@ def _format_text( day_cn = lunar.getDayInChinese() lines.append(f"农历:{year_gz}年({zodiac}年) {month_cn}{day_cn}") - # 干支信息 month_gz = lunar.getMonthInGanZhi() day_gz = lunar.getDayInGanZhi() lines.append(f"干支:{year_gz}年 {month_gz}月 {day_gz}日") - # 黄历信息 if include_almanac and lunar: - # 节气 jieqi = lunar.getCurrentJieQi() if jieqi: lines.append(f"节气:{jieqi.getName()}") - # 节日 festivals = [] if solar: solar_festivals = solar.getFestivals() @@ -92,7 +85,6 @@ def _format_text( if festivals: lines.append(f"节日:{' '.join(festivals)}") - # 宜忌 yi = lunar.getDayYi() if yi: lines.append(f"宜:{' '.join(yi)}") @@ -100,7 +92,6 @@ def _format_text( if ji: lines.append(f"忌:{' '.join(ji)}") 
- # 冲煞 chong = lunar.getDayChongDesc() sha = lunar.getDaySha() if chong or sha: @@ -109,7 +100,6 @@ def _format_text( chong_sha += f"煞{sha}" lines.append(f"冲煞:{chong_sha}") - # 胎神 tai = lunar.getDayPositionTai() if tai: lines.append(f"胎神:{tai}") @@ -127,7 +117,6 @@ def _format_json( """生成结构化 JSON 格式""" result: Dict[str, Any] = {} - # 公历信息 weekdays = ["星期一", "星期二", "星期三", "星期四", "星期五", "星期六", "星期日"] weekday = weekdays[now.weekday()] tz_offset = now.strftime("%z") @@ -145,7 +134,6 @@ def _format_json( "timezone": tz_str, } - # 农历信息 if include_lunar and lunar: result["lunar"] = { "year_cn": lunar.getYearInGanZhi(), @@ -159,16 +147,13 @@ def _format_json( }, } - # 黄历信息 if include_almanac and lunar: almanac: Dict[str, Any] = {} - # 节气 jieqi = lunar.getCurrentJieQi() if jieqi: almanac["solar_term"] = {"current": jieqi.getName()} - # 节日 festivals = [] if solar: solar_festivals = solar.getFestivals() @@ -178,7 +163,6 @@ def _format_json( if festivals: almanac["festivals"] = festivals - # 宜忌 yi = lunar.getDayYi() if yi: almanac["yi"] = yi @@ -186,7 +170,6 @@ def _format_json( if ji: almanac["ji"] = ji - # 冲煞 chong = lunar.getDayChongDesc() sha = lunar.getDaySha() if chong or sha: @@ -195,7 +178,6 @@ def _format_json( chong_sha += f"煞{sha}" almanac["chong"] = chong_sha - # 胎神 tai = lunar.getDayPositionTai() if tai: almanac["fetal_god"] = tai diff --git a/src/Undefined/skills/tools/get_picture/handler.py b/src/Undefined/skills/tools/get_picture/handler.py index d580dce4..5fbb6893 100644 --- a/src/Undefined/skills/tools/get_picture/handler.py +++ b/src/Undefined/skills/tools/get_picture/handler.py @@ -98,7 +98,6 @@ async def execute(args: Dict[str, Any], context: Dict[str, Any]) -> str: device = args.get("device", "pc") fourk_type = args.get("fourk_type", "acg") - # 参数验证 if delivery not in {"embed", "send"}: return f"delivery 无效:{delivery}。仅支持 embed 或 send" @@ -117,14 +116,12 @@ async def execute(args: Dict[str, Any], context: Dict[str, Any]) -> str: if picture_type == 
"random4kPic" and fourk_type not in ("acg", "wallpaper"): return "4K图片类型必须是 acg(二次元)或 wallpaper(风景)" - # 构造请求参数 params: Dict[str, Any] = {"return": "json"} if picture_type == "acg": params["type"] = device elif picture_type == "random4kPic": params["type"] = fourk_type - # 创建图片保存目录 from Undefined.utils.paths import IMAGE_CACHE_DIR, ensure_dir img_dir = ensure_dir(IMAGE_CACHE_DIR) @@ -133,7 +130,6 @@ async def execute(args: Dict[str, Any], context: Dict[str, Any]) -> str: base_url = _get_xxapi_base_url() api_url = f"{base_url}{API_PATHS[picture_type]}" - # 获取图片 success_count = 0 fail_count = 0 local_image_paths: list[str] = [] @@ -153,14 +149,12 @@ async def execute(args: Dict[str, Any], context: Dict[str, Any]) -> str: # 美腿类型直接返回 JPEG 图片,不需要解析 JSON if picture_type == "meitui": - # 验证响应内容类型 content_type = response.headers.get("content-type", "") if "image" not in content_type.lower(): logger.error(f"响应不是图片格式: {content_type}") fail_count += 1 continue - # 保存图片 filename = f"{picture_type}_{uuid.uuid4().hex[:16]}.jpg" filepath = img_dir / filename filepath.write_bytes(response.content) @@ -171,13 +165,11 @@ async def execute(args: Dict[str, Any], context: Dict[str, Any]) -> str: else: data = response.json() - # 检查响应 if data.get("code") != 200: logger.error(f"获取图片失败: {data.get('msg')}") fail_count += 1 continue - # 获取图片 URL image_url = data.get("data") if not image_url: logger.error("响应中未找到图片 URL") @@ -186,7 +178,6 @@ async def execute(args: Dict[str, Any], context: Dict[str, Any]) -> str: logger.info(f"图片 URL: {image_url}") - # 下载图片到本地 logger.info("正在下载图片到本地...") image_response = await request_with_retry( "GET", @@ -195,7 +186,6 @@ async def execute(args: Dict[str, Any], context: Dict[str, Any]) -> str: context=context, ) - # 保存图片 filename = f"{picture_type}_{uuid.uuid4().hex[:16]}.jpg" filepath = img_dir / filename filepath.write_bytes(image_response.content) @@ -214,7 +204,6 @@ async def execute(args: Dict[str, Any], context: Dict[str, Any]) -> str: 
logger.exception(f"获取图片失败: {e}") fail_count += 1 - # 如果没有获取到任何图片 if success_count == 0: return f"获取 {TYPE_NAMES[picture_type]} 图片失败,请稍后重试" diff --git a/src/Undefined/skills/tools/get_user_info/handler.py b/src/Undefined/skills/tools/get_user_info/handler.py index e8521458..5bb34fea 100644 --- a/src/Undefined/skills/tools/get_user_info/handler.py +++ b/src/Undefined/skills/tools/get_user_info/handler.py @@ -32,7 +32,6 @@ async def execute(args: Dict[str, Any], context: Dict[str, Any]) -> str: return "获取用户信息功能不可用(OneBot 客户端未设置)" try: - # 使用 get_stranger_info 获取详细信息 user_info = await onebot_client.get_stranger_info(user_id) if not user_info: @@ -40,12 +39,10 @@ async def execute(args: Dict[str, Any], context: Dict[str, Any]) -> str: result_parts = ["【QQ用户信息】"] - # 添加头像 URL (常用 API) result_parts.append( f"头像: http://q.qlogo.cn/headimg_dl?dst_uin={user_id}&spec=640" ) - # 处理性别 sex = user_info.get("sex") if sex == "male": user_info["sex"] = "男" @@ -59,8 +56,6 @@ async def execute(args: Dict[str, Any], context: Dict[str, Any]) -> str: if value is not None and value != "": result_parts.append(f"{display_name}: {value}") - # 如果有其他字段(取决于 OneBot 实现,如 NapCat/Go-CQHttp 可能有更多) - # 我们可以尝试输出一些常见的额外字段 extra_fields = { "remark": "备注", "signature": "签名", diff --git a/src/Undefined/skills/tools/python_interpreter/handler.py b/src/Undefined/skills/tools/python_interpreter/handler.py index 002bf560..f23365db 100644 --- a/src/Undefined/skills/tools/python_interpreter/handler.py +++ b/src/Undefined/skills/tools/python_interpreter/handler.py @@ -164,7 +164,6 @@ async def execute(args: Dict[str, Any], context: Dict[str, Any]) -> str: libraries: list[str] = args.get("libraries") or [] send_files: list[str] = args.get("send_files") or [] - # 验证库名 for lib in libraries: if not _SAFE_LIB_PATTERN.match(lib): return f"错误: 无效的库名 '{lib}'。" @@ -173,12 +172,10 @@ async def execute(args: Dict[str, Any], context: Dict[str, Any]) -> str: memory = MEMORY_LIMIT_WITH_LIBS if has_libs else MEMORY_LIMIT 
timeout = TIMEOUT_WITH_LIBS if has_libs else TIMEOUT - # 创建宿主机临时目录,绑定挂载到容器 /tmp host_tmpdir = tempfile.mkdtemp(prefix="pyinterp_") defer_cleanup = False try: - # 验证文件路径必须绑定到容器 /tmp,并且不能逃逸宿主机临时目录 for fpath in send_files: if _resolve_output_host_path(fpath, host_tmpdir) is None: return f"错误: 输出文件路径必须位于容器 /tmp 目录内: '{fpath}'" @@ -235,7 +232,6 @@ async def execute(args: Dict[str, Any], context: Dict[str, Any]) -> str: timeout=exec_timeout, ) - # 构建结果 parts: list[str] = [] if exit_code == 0: parts.append( @@ -246,7 +242,6 @@ async def execute(args: Dict[str, Any], context: Dict[str, Any]) -> str: f"代码执行失败 (退出代码: {exit_code}):\n{error_output}\n{response}" ) - # 执行成功时发送文件 if send_files and exit_code == 0: file_result = await _send_output_files(send_files, host_tmpdir, context) if file_result: diff --git a/src/Undefined/skills/tools/qq_like/handler.py b/src/Undefined/skills/tools/qq_like/handler.py index 87031020..3b635f4d 100644 --- a/src/Undefined/skills/tools/qq_like/handler.py +++ b/src/Undefined/skills/tools/qq_like/handler.py @@ -20,7 +20,6 @@ async def execute(args: Dict[str, Any], context: Dict[str, Any]) -> str: if target_user_id is None: return "请提供要点赞的目标QQ号(target_user_id参数)" - # 验证参数类型 try: target_user_id = int(target_user_id) times = int(times) @@ -37,7 +36,6 @@ async def execute(args: Dict[str, Any], context: Dict[str, Any]) -> str: return "点赞功能不可用(回调函数未设置)" try: - # 调用点赞回调 await send_like_callback(target_user_id, times) if times == 1: @@ -49,7 +47,6 @@ async def execute(args: Dict[str, Any], context: Dict[str, Any]) -> str: logger.exception(f"点赞失败: {e}") error_msg = str(e) - # 根据错误消息提供更友好的提示 if "SVIP 上限" in error_msg: return "点赞失败:今日给同一好友的点赞数已达SVIP上限" elif "点赞失败" in error_msg: diff --git a/src/Undefined/skills/tools/task_progress/handler.py b/src/Undefined/skills/tools/task_progress/handler.py index 668037a0..9327c70b 100644 --- a/src/Undefined/skills/tools/task_progress/handler.py +++ b/src/Undefined/skills/tools/task_progress/handler.py @@ -61,11 +61,9 @@ 
async def execute(args: Dict[str, Any], context: Dict[str, Any]) -> str: if not isinstance(tasks_input, list) or len(tasks_input) == 0: return "tasks 必须是非空数组" - # 从 context 获取或初始化任务列表 task_store: list[dict[str, Any]] = context.get("_task_progress", []) if action == "plan": - # 创建/替换计划 new_tasks: list[dict[str, Any]] = [] for item in tasks_input: tid_raw = item.get("id") @@ -80,7 +78,6 @@ async def execute(args: Dict[str, Any], context: Dict[str, Any]) -> str: status = "pending" new_tasks.append({"id": tid, "description": desc, "status": status}) - # 按 id 排序 new_tasks.sort(key=lambda t: t.get("id", 0)) task_store = new_tasks context["_task_progress"] = task_store @@ -88,11 +85,9 @@ async def execute(args: Dict[str, Any], context: Dict[str, Any]) -> str: logger.info(f"[task_progress] 创建计划,共 {len(task_store)} 个步骤") return f"计划已创建\n{_format_task_list(task_store)}" - # action == "update" if not task_store: return "还没有任务计划,请先用 plan 动作创建" - # 构建 id -> index 映射 id_to_idx = {t["id"]: i for i, t in enumerate(task_store)} updated: list[int] = [] @@ -112,7 +107,6 @@ async def execute(args: Dict[str, Any], context: Dict[str, Any]) -> str: return f"不存在 id={tid} 的步骤" task_store[idx]["status"] = new_status - # 允许更新描述 new_desc = item.get("description") if new_desc: task_store[idx]["description"] = new_desc diff --git a/src/Undefined/utils/paths.py b/src/Undefined/utils/paths.py index 99124f9d..05254147 100644 --- a/src/Undefined/utils/paths.py +++ b/src/Undefined/utils/paths.py @@ -2,6 +2,8 @@ from pathlib import Path +PACKAGE_ROOT = Path(__file__).resolve().parent.parent + DATA_DIR = Path("data") CACHE_DIR = DATA_DIR / "cache" RENDER_CACHE_DIR = CACHE_DIR / "render" diff --git a/src/Undefined/utils/render_cache.py b/src/Undefined/utils/render_cache.py index 2bf2afe8..12d7ed47 100644 --- a/src/Undefined/utils/render_cache.py +++ b/src/Undefined/utils/render_cache.py @@ -362,6 +362,7 @@ async def get_render_cache() -> HtmlRenderCache: 单例的 enabled / 容量由 ``[render.cache]`` 决定; 
禁用时仍返回单例对象,但所有 get/put 立即短路。 """ + # global global _cache if _cache is not None: await _cache.initialize() @@ -387,6 +388,7 @@ async def get_render_cache() -> HtmlRenderCache: async def close_render_cache() -> None: """关停时调用:刷盘并丢弃单例。""" + # global global _cache cache = _cache if cache is None: @@ -399,5 +401,6 @@ async def close_render_cache() -> None: def reset_render_cache() -> None: """仅供测试使用:丢弃单例(不刷盘),下次调用重新加载。""" + # global global _cache _cache = None diff --git a/src/Undefined/utils/resources.py b/src/Undefined/utils/resources.py index b614032f..a8ce1186 100644 --- a/src/Undefined/utils/resources.py +++ b/src/Undefined/utils/resources.py @@ -8,10 +8,11 @@ import shutil import tempfile +from Undefined.utils.paths import PACKAGE_ROOT + def _candidate_paths(relative_path: str) -> list[Path]: - module_path = Path(__file__).resolve() - package_root = module_path.parents[1] + package_root = PACKAGE_ROOT candidates = [ Path.cwd() / relative_path, # If installed from wheel, extra files may live under site-packages/ diff --git a/src/Undefined/utils/sender_helpers.py b/src/Undefined/utils/sender_helpers.py new file mode 100644 index 00000000..0138f289 --- /dev/null +++ b/src/Undefined/utils/sender_helpers.py @@ -0,0 +1,116 @@ +"""MessageSender 辅助函数。""" + +from pathlib import Path +from typing import Any +from urllib.parse import unquote, urlsplit + +from Undefined.attachments import attachment_refs_to_text + + +def _extract_message_id(result: object) -> int | None: + if not isinstance(result, dict): + return None + + message_id = result.get("message_id") + if message_id is None: + # OneBot 实现差异:message_id 可能在顶层或 data 子对象。 + data = result.get("data") + if isinstance(data, dict): + message_id = data.get("message_id") + + try: + return int(message_id) if message_id is not None else None + except (TypeError, ValueError): + return None + + +def _format_size(size_bytes: int | None) -> str: + if size_bytes is None or size_bytes < 0: + return "未知大小" + if size_bytes < 1024: + 
return f"{size_bytes}B" + if size_bytes < 1024 * 1024: + return f"{size_bytes / 1024:.1f}KB" + if size_bytes < 1024 * 1024 * 1024: + return f"{size_bytes / 1024 / 1024:.2f}MB" + return f"{size_bytes / 1024 / 1024 / 1024:.2f}GB" + + +def _build_file_history_message(file_name: str, size_bytes: int | None) -> str: + return f"[文件] {file_name} ({_format_size(size_bytes)})" + + +def _append_attachment_refs( + history_content: str, + attachments: list[dict[str, str]] | None, +) -> str: + refs_text = attachment_refs_to_text(attachments or []) + if not refs_text or refs_text in history_content: + return history_content + if not history_content: + return refs_text + return f"{history_content}\n{refs_text}" + + +def _merge_attachment_refs( + *groups: list[dict[str, str]] | None, +) -> list[dict[str, str]]: + merged: list[dict[str, str]] = [] + seen_uids: set[str] = set() + for group in groups: + for item in group or []: + uid = str(item.get("uid", "") or "").strip() + if uid and uid in seen_uids: + continue + if uid: + seen_uids.add(uid) + merged.append(item) + return merged + + +def _iter_segments_deep(value: object) -> list[dict[str, Any]]: + """递归收集消息段,用于合并转发中的本地媒体登记。""" + segments: list[dict[str, Any]] = [] + if isinstance(value, dict): + type_value = value.get("type") + data = value.get("data") + if type_value is not None and isinstance(data, dict): + segments.append(value) + content = data.get("content") + if isinstance(content, (list, dict)): + segments.extend(_iter_segments_deep(content)) + else: + for child in value.values(): + if isinstance(child, (list, dict)): + segments.extend(_iter_segments_deep(child)) + elif isinstance(value, list): + for child in value: + if isinstance(child, (list, dict)): + segments.extend(_iter_segments_deep(child)) + return segments + + +def _local_path_from_segment_source(source: Any) -> Path | None: + raw_source = str(source or "").strip() + if not raw_source: + return None + lowered = raw_source.lower() + if 
lowered.startswith(("http://", "https://", "base64://")): + return None + if lowered.startswith("file://"): + parsed = urlsplit(raw_source) + path = Path(unquote(parsed.path)).expanduser() + else: + path = Path(raw_source).expanduser() + if not path.is_absolute(): + path = (Path.cwd() / path).resolve() + else: + path = path.resolve() + return path if path.is_file() else None + + +def _get_file_size(file_path: str) -> int | None: + try: + return Path(file_path).stat().st_size + except OSError: + return None diff --git a/src/Undefined/webui/routes/_runtime.py b/src/Undefined/webui/routes/_runtime.py index 371c46b9..fa26c0d8 100644 --- a/src/Undefined/webui/routes/_runtime.py +++ b/src/Undefined/webui/routes/_runtime.py @@ -65,8 +65,9 @@ def _load_top_level_agent_names(root: Path) -> set[str]: def _get_local_agent_tool_names() -> set[str]: - skills_root = Path(__file__).resolve().parents[2] / "skills" - return _load_top_level_agent_names(skills_root / "agents") + from Undefined.utils.paths import PACKAGE_ROOT + + return _load_top_level_agent_names(PACKAGE_ROOT / "skills" / "agents") def _tool_invoke_proxy_timeout_seconds(tool_name: str) -> float | None: diff --git a/tests/test_ai_client_setup_paths.py b/tests/test_ai_client_setup_paths.py new file mode 100644 index 00000000..baac7c81 --- /dev/null +++ b/tests/test_ai_client_setup_paths.py @@ -0,0 +1,115 @@ +from __future__ import annotations + +import asyncio +from pathlib import Path + +import Undefined +from Undefined.services.commands.registry import CommandRegistry +from Undefined.skills.agents import AgentRegistry +from Undefined.skills.pipelines.registry import PipelineRegistry +from Undefined.skills.tools import ToolRegistry +from Undefined.utils.paths import PACKAGE_ROOT + + +def _skill_dirs(base: Path) -> set[str]: + if not base.is_dir(): + return set() + return { + item.name + for item in base.iterdir() + if item.is_dir() and (item / "config.json").exists() + } + + +def _toolset_tool_names(base: Path) -> 
set[str]: + if not base.is_dir(): + return set() + names: set[str] = set() + for config_path in base.rglob("config.json"): + rel = config_path.parent.relative_to(base) + if len(rel.parts) == 2: + names.add(".".join(rel.parts)) + return names + + +def test_package_root_matches_undefined_package_directory() -> None: + assert PACKAGE_ROOT == Path(Undefined.__file__).resolve().parent + + +def test_package_root_contains_skills_directories() -> None: + assert (PACKAGE_ROOT / "skills" / "tools").is_dir() + assert (PACKAGE_ROOT / "skills" / "agents").is_dir() + assert (PACKAGE_ROOT / "skills" / "anthropic_skills").is_dir() + assert (PACKAGE_ROOT / "skills" / "toolsets").is_dir() + assert (PACKAGE_ROOT / "skills" / "pipelines").is_dir() + assert (PACKAGE_ROOT / "skills" / "commands").is_dir() + + +def test_setup_wrong_path_does_not_exist() -> None: + """Regression: ai/client/setup.py used parents[1] and pointed at ai/skills.""" + wrong_root = Path(__file__).resolve().parents[2] / "src" / "Undefined" / "ai" + assert not (wrong_root / "skills" / "tools").exists() + + +def test_tool_registry_loads_all_skill_directories() -> None: + tools_dir = PACKAGE_ROOT / "skills" / "tools" + toolsets_dir = PACKAGE_ROOT / "skills" / "toolsets" + registry = ToolRegistry(tools_dir) + basic = [name for name in registry._items if "." not in name] + toolsets = [ + name for name in registry._items if "." 
in name and not name.startswith("mcp.") + ] + + basic_dirs = _skill_dirs(tools_dir) + toolset_names = _toolset_tool_names(toolsets_dir) + + assert len(basic) == len(basic_dirs) + assert len(toolsets) == len(toolset_names) + assert len(registry._items) == len(basic) + len(toolsets) + assert set(basic) == basic_dirs + assert set(toolsets) == toolset_names + + +def test_all_registered_tools_import_handlers() -> None: + registry = ToolRegistry(PACKAGE_ROOT / "skills" / "tools") + errors: list[str] = [] + + for name, item in sorted(registry._items.items()): + try: + registry._load_handler_for_item(item) + if item.handler is None: + errors.append(f"{name}: handler is None") + except Exception as exc: + errors.append(f"{name}: {exc}") + + assert errors == [] + + +def test_agent_registry_loads_expected_agents() -> None: + agents_dir = PACKAGE_ROOT / "skills" / "agents" + registry = AgentRegistry(agents_dir) + assert len(registry._items) == len(_skill_dirs(agents_dir)) + assert set(registry._items) == _skill_dirs(agents_dir) + + +def test_command_registry_loads_expected_commands() -> None: + commands_dir = PACKAGE_ROOT / "skills" / "commands" + registry = CommandRegistry(commands_dir) + registry.load_commands() + command_dirs = { + item.name + for item in commands_dir.iterdir() + if item.is_dir() and (item / "handler.py").exists() + } + assert len(registry._commands) == len(command_dirs) + assert set(registry._commands) == command_dirs + + +def test_pipeline_registry_loads_expected_pipelines() -> None: + async def _load() -> PipelineRegistry: + registry = PipelineRegistry(PACKAGE_ROOT / "skills" / "pipelines") + await registry.load_items_async() + return registry + + registry = asyncio.run(_load()) + assert set(registry._items) == {"arxiv", "bilibili", "github"} diff --git a/tests/test_ai_coordinator_queue_routing.py b/tests/test_ai_coordinator_queue_routing.py index 194e67c7..7a78a73f 100644 --- a/tests/test_ai_coordinator_queue_routing.py +++ 
b/tests/test_ai_coordinator_queue_routing.py @@ -6,8 +6,8 @@ import pytest -from Undefined.services import ai_coordinator as ai_coordinator_module from Undefined.services.ai_coordinator import AICoordinator +from Undefined.services.coordinator import group as coordinator_group_module @pytest.mark.asyncio @@ -245,7 +245,9 @@ async def _fake_ask(*_args: Any, **kwargs: Any) -> str: coordinator.scheduler = SimpleNamespace() monkeypatch.setattr( - ai_coordinator_module, "collect_context_resources", lambda _vars: {} + coordinator_group_module, + "collect_context_resources", + lambda _vars: {}, ) await coordinator._execute_auto_reply( diff --git a/tests/test_cli_startup_compat.py b/tests/test_cli_startup_compat.py new file mode 100644 index 00000000..82f0f557 --- /dev/null +++ b/tests/test_cli_startup_compat.py @@ -0,0 +1,128 @@ +"""CLI 启动兼容性与根包 import 行为测试(Phase 0 起持续有效)。""" + +from __future__ import annotations + +import importlib +import sys +from pathlib import Path +from typing import Iterator + +import pytest + +_MINIMAL_CONFIG = """ +[onebot] +ws_url = "ws://127.0.0.1:3999" +[models.chat] +api_url = "https://api.example/v1" +api_key = "sk-test" +model_name = "gpt-test" +""" + + +@pytest.fixture +def isolated_undefined_modules() -> Iterator[None]: + """测试前后清理 Undefined 相关 sys.modules 条目。""" + prefix = "Undefined" + saved = { + k: v + for k, v in sys.modules.items() + if k == prefix or k.startswith(f"{prefix}.") + } + for key in saved: + del sys.modules[key] + yield + for key in list(sys.modules): + if key == prefix or key.startswith(f"{prefix}."): + del sys.modules[key] + sys.modules.update(saved) + + +@pytest.fixture +def reset_config_singleton() -> Iterator[None]: + """重置 config 模块全局单例,避免测试间污染。""" + import Undefined.config as config_module + + saved_config = config_module._config + saved_manager = config_module._config_manager + config_module._config = None + config_module._config_manager = None + yield + config_module._config = saved_config + 
config_module._config_manager = saved_manager + + +def test_entry_point_undefined_main_run_importable() -> None: + from Undefined.main import run + + assert callable(run) + + +def test_entry_point_undefined_webui_run_importable() -> None: + from Undefined.webui import run + + assert callable(run) + + +def test_import_undefined_does_not_eagerly_load_onebot_or_handlers( + isolated_undefined_modules: None, +) -> None: + import Undefined # noqa: F401 + + assert "Undefined.onebot" not in sys.modules + assert "Undefined.handlers" not in sys.modules + assert "Undefined.main" not in sys.modules + + +def test_get_config_reads_config_toml_from_cwd( + monkeypatch: pytest.MonkeyPatch, + tmp_path: Path, + reset_config_singleton: None, +) -> None: + monkeypatch.chdir(tmp_path) + (tmp_path / "config.toml").write_text(_MINIMAL_CONFIG, encoding="utf-8") + + import Undefined.config as config_module + + config_module._config = None + config_module._config_manager = None + + cfg = config_module.get_config(strict=False) + assert cfg.onebot_ws_url == "ws://127.0.0.1:3999" + assert cfg.chat_model.model_name == "gpt-test" + + +def test_root_get_config_same_as_subpackage(isolated_undefined_modules: None) -> None: + import Undefined + + from Undefined.config import get_config as subpackage_get_config + + assert Undefined.get_config is subpackage_get_config + + +def test_import_undefined_does_not_import_webui_app( + isolated_undefined_modules: None, +) -> None: + import Undefined # noqa: F401 + + assert "Undefined.webui.app" not in sys.modules + + +def test_webui_run_deferred_import(monkeypatch: pytest.MonkeyPatch) -> None: + """webui.run 应延迟加载 app,import webui 包本身不拉起重型依赖。""" + import Undefined.webui as webui_module + + original_import = importlib.import_module + app_imported = False + + def tracking_import(name: str, package: object | None = None) -> object: + nonlocal app_imported + if name == "Undefined.webui.app": + app_imported = True + package_name = package if isinstance(package, str) 
else None + return original_import(name, package_name) + + monkeypatch.setattr(importlib, "import_module", tracking_import) + importlib.reload(webui_module) + + assert not app_imported + assert callable(webui_module.run) diff --git a/tests/test_config_env_only.py b/tests/test_config_env_only.py new file mode 100644 index 00000000..42049d2a --- /dev/null +++ b/tests/test_config_env_only.py @@ -0,0 +1,61 @@ +from __future__ import annotations + + +import pytest + +from Undefined.config import Config + + +def test_env_only_chat_model_fields(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setenv("ONEBOT_WS_URL", "ws://env-only:3001") + monkeypatch.setenv("CHAT_MODEL_API_URL", "https://env.example/v1") + monkeypatch.setenv("CHAT_MODEL_API_KEY", "env-key") + monkeypatch.setenv("CHAT_MODEL_NAME", "env-chat") + monkeypatch.setenv("VISION_MODEL_API_URL", "https://env.example/v1") + monkeypatch.setenv("VISION_MODEL_API_KEY", "env-key") + monkeypatch.setenv("VISION_MODEL_NAME", "env-vision") + monkeypatch.setenv("AGENT_MODEL_API_URL", "https://env.example/v1") + monkeypatch.setenv("AGENT_MODEL_API_KEY", "env-key") + monkeypatch.setenv("AGENT_MODEL_NAME", "env-agent") + + cfg = Config.from_mapping({}, strict=False) + assert cfg.onebot_ws_url == "ws://env-only:3001" + assert cfg.chat_model.model_name == "env-chat" + assert cfg.vision_model.model_name == "env-vision" + assert cfg.agent_model.model_name == "env-agent" + + +def test_env_overridden_by_mapping(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setenv("CHAT_MODEL_NAME", "from-env") + cfg = Config.from_mapping( + { + "onebot": {"ws_url": "ws://x"}, + "models": { + "chat": { + "api_url": "u", + "api_key": "k", + "model_name": "from-toml", + }, + "vision": {"api_url": "u", "api_key": "k", "model_name": "v"}, + "agent": {"api_url": "u", "api_key": "k", "model_name": "a"}, + }, + }, + strict=False, + ) + assert cfg.chat_model.model_name == "from-toml" + + +def test_http_proxy_env_fallback(monkeypatch: 
pytest.MonkeyPatch) -> None: + monkeypatch.setenv("HTTP_PROXY", "http://127.0.0.1:7890") + cfg = Config.from_mapping( + { + "onebot": {"ws_url": "ws://x"}, + "models": { + "chat": {"api_url": "u", "api_key": "k", "model_name": "m"}, + "vision": {"api_url": "u", "api_key": "k", "model_name": "v"}, + "agent": {"api_url": "u", "api_key": "k", "model_name": "a"}, + }, + }, + strict=False, + ) + assert cfg.http_proxy == "http://127.0.0.1:7890" diff --git a/tests/test_config_env_registry.py b/tests/test_config_env_registry.py new file mode 100644 index 00000000..385f3351 --- /dev/null +++ b/tests/test_config_env_registry.py @@ -0,0 +1,34 @@ +from __future__ import annotations + +from Undefined.config.env_registry import ( + ENV_ALTERNATES, + ENV_REGISTRY, + all_env_mappings, + env_key_for_path, +) + + +def test_registry_contains_core_paths() -> None: + assert ENV_REGISTRY[("core", "bot_qq")] == "BOT_QQ" + assert ENV_REGISTRY[("onebot", "ws_url")] == "ONEBOT_WS_URL" + assert ENV_REGISTRY[("models", "chat", "api_url")] == "CHAT_MODEL_API_URL" + + +def test_env_key_for_path_lookup() -> None: + assert env_key_for_path(("models", "agent", "model_name")) == "AGENT_MODEL_NAME" + assert env_key_for_path(("nonexistent",)) is None + + +def test_all_env_mappings_is_copy() -> None: + snapshot = all_env_mappings() + snapshot[("fake",)] = "FAKE" + assert ("fake",) not in ENV_REGISTRY + + +def test_alternate_env_keys_documented() -> None: + assert "HTTP_PROXY" in ENV_ALTERNATES + assert "EASTER_EGG_CALL_MESSAGE_MODE" in ENV_ALTERNATES + + +def test_registry_has_model_context_window_entries() -> None: + assert ("models", "chat", "context_window_tokens") in ENV_REGISTRY diff --git a/tests/test_config_from_mapping.py b/tests/test_config_from_mapping.py new file mode 100644 index 00000000..2e28f806 --- /dev/null +++ b/tests/test_config_from_mapping.py @@ -0,0 +1,98 @@ +from __future__ import annotations + +from pathlib import Path + +import pytest + +from Undefined.config import Config, 
set_config + + +_MINIMAL_MAPPING = { + "onebot": {"ws_url": "ws://127.0.0.1:3001"}, + "models": { + "chat": { + "api_url": "https://api.example/v1", + "api_key": "sk-test", + "model_name": "gpt-test", + }, + "vision": { + "api_url": "https://api.example/v1", + "api_key": "sk-test", + "model_name": "vision-test", + }, + "agent": { + "api_url": "https://api.example/v1", + "api_key": "sk-test", + "model_name": "agent-test", + }, + }, +} + + +def test_from_mapping_builds_without_toml() -> None: + cfg = Config.from_mapping(_MINIMAL_MAPPING, strict=False) + assert cfg.onebot_ws_url == "ws://127.0.0.1:3001" + assert cfg.chat_model.model_name == "gpt-test" + + +def test_builder_with_mapping() -> None: + cfg = Config.builder().with_mapping(_MINIMAL_MAPPING).build(strict=False) + assert cfg.agent_model.model_name == "agent-test" + + +def test_set_config_injects_singleton(monkeypatch: pytest.MonkeyPatch) -> None: + import Undefined.config as config_pkg + + monkeypatch.setattr(config_pkg, "_config", None) + monkeypatch.setattr(config_pkg, "_config_manager", None) + cfg = Config.from_mapping(_MINIMAL_MAPPING, strict=False) + set_config(cfg) + assert config_pkg.get_config(strict=False) is cfg + assert config_pkg.get_config_manager().load(strict=False) is cfg + + +def test_from_mapping_matches_load(tmp_path: Path) -> None: + toml = """ +[onebot] +ws_url = "ws://127.0.0.1:3001" +[models.chat] +api_url = "https://api.example/v1" +api_key = "sk-test" +model_name = "gpt-test" +[models.vision] +api_url = "https://api.example/v1" +api_key = "sk-test" +model_name = "vision-test" +[models.agent] +api_url = "https://api.example/v1" +api_key = "sk-test" +model_name = "agent-test" +""" + path = tmp_path / "config.toml" + path.write_text(toml, encoding="utf-8") + from_file = Config.load(path, strict=False) + from_map = Config.from_mapping( + { + "onebot": {"ws_url": "ws://127.0.0.1:3001"}, + "models": { + "chat": { + "api_url": "https://api.example/v1", + "api_key": "sk-test", + 
"model_name": "gpt-test", + }, + "vision": { + "api_url": "https://api.example/v1", + "api_key": "sk-test", + "model_name": "vision-test", + }, + "agent": { + "api_url": "https://api.example/v1", + "api_key": "sk-test", + "model_name": "agent-test", + }, + }, + }, + strict=False, + ) + assert from_file.chat_model.model_name == from_map.chat_model.model_name + assert from_file.onebot_ws_url == from_map.onebot_ws_url diff --git a/tests/test_end_defer_co_call.py b/tests/test_end_defer_co_call.py new file mode 100644 index 00000000..70730f13 --- /dev/null +++ b/tests/test_end_defer_co_call.py @@ -0,0 +1,241 @@ +from __future__ import annotations + +from types import SimpleNamespace +from typing import Any, cast +from unittest.mock import AsyncMock + +import pytest + +from Undefined.ai.client import AIClient +from Undefined.ai.tooling import END_CO_CALL_REJECT_CONTENT +from Undefined.config.models import ChatModelConfig + + +def _build_minimal_ai_client( + *, + execute_tool: Any, + llm_responses: list[dict[str, Any]], +) -> Any: + client: Any = object.__new__(AIClient) + client.runtime_config = cast( + Any, + SimpleNamespace( + log_thinking=False, + ai_request_max_retries=0, + missing_tool_call_retries=0, + ), + ) + client._prompt_builder = cast( + Any, + SimpleNamespace( + build_messages=AsyncMock( + return_value=[{"role": "user", "content": "hello"}] + ), + end_summaries=[], + ), + ) + client.tool_manager = cast( + Any, + SimpleNamespace( + get_openai_tools=lambda: [], + execute_tool=execute_tool, + ), + ) + client._filter_tools_for_runtime_config = lambda tools: tools + client._get_runtime_config = cast(Any, lambda: client.runtime_config) + client.model_selector = cast(Any, SimpleNamespace(wait_ready=AsyncMock())) + client.chat_config = ChatModelConfig( + api_url="https://api.openai.com/v1", + api_key="sk-test", + model_name="chat-model", + max_tokens=1024, + ) + client._find_chat_config_by_name = lambda _name: client.chat_config + client._search_wrapper = None + 
client._end_summary_storage = cast(Any, None) + client._send_private_message_callback = None + client._send_image_callback = None + client.memory_storage = None + client._knowledge_manager = None + client._cognitive_service = None + client._meme_service = None + client._crawl4ai_capabilities = SimpleNamespace( + available=False, + error=None, + proxy_config_available=False, + ) + + submit_calls: list[list[dict[str, Any]]] = [] + + async def _submit_queued_llm_call( + *, + messages: list[dict[str, Any]], + **kwargs: Any, + ) -> dict[str, Any]: + submit_calls.append(list(messages)) + index = len(submit_calls) - 1 + if index >= len(llm_responses): + raise RuntimeError("unexpected extra llm call") + return llm_responses[index] + + client.submit_queued_llm_call = AsyncMock(side_effect=_submit_queued_llm_call) + client._submit_calls = submit_calls + return client + + +@pytest.mark.asyncio +async def test_ai_ask_rejects_end_when_co_called_with_send_message() -> None: + execute_calls: list[str] = [] + + async def _execute_tool( + name: str, args: dict[str, Any], ctx: dict[str, Any] + ) -> str: + execute_calls.append(name) + if name == "send_message": + ctx["message_sent_this_turn"] = True + return "消息已发送(message_id=1)" + if name == "end": + ctx["conversation_ended"] = True + return "对话已结束" + return "ok" + + client = _build_minimal_ai_client( + execute_tool=_execute_tool, + llm_responses=[ + { + "choices": [ + { + "message": { + "content": "", + "tool_calls": [ + { + "id": "call_send", + "function": { + "name": "send_message", + "arguments": '{"message":"喵"}', + }, + }, + { + "id": "call_end", + "function": { + "name": "end", + "arguments": "{}", + }, + }, + ], + } + } + ], + }, + { + "choices": [ + { + "message": { + "content": "", + "tool_calls": [ + { + "id": "call_end_only", + "function": { + "name": "end", + "arguments": "{}", + }, + } + ], + } + } + ], + }, + ], + ) + + result = await AIClient.ask(client, "hello") + + assert result == "" + assert execute_calls == 
["send_message", "end"] + assert len(client._submit_calls) == 2 + + end_tool_messages = [ + m + for m in client._submit_calls[1] + if m.get("role") == "tool" and m.get("tool_call_id") == "call_end" + ] + assert len(end_tool_messages) == 1 + assert end_tool_messages[0]["content"] == END_CO_CALL_REJECT_CONTENT + + +@pytest.mark.asyncio +async def test_ai_ask_rejects_end_when_other_tool_failed() -> None: + execute_calls: list[str] = [] + + async def _execute_tool( + name: str, args: dict[str, Any], ctx: dict[str, Any] + ) -> str: + execute_calls.append(name) + if name == "send_message": + raise RuntimeError("send failed") + if name == "end": + ctx["conversation_ended"] = True + return "对话已结束" + return "ok" + + client = _build_minimal_ai_client( + execute_tool=_execute_tool, + llm_responses=[ + { + "choices": [ + { + "message": { + "content": "", + "tool_calls": [ + { + "id": "call_send", + "function": { + "name": "send_message", + "arguments": '{"message":"喵"}', + }, + }, + { + "id": "call_end", + "function": { + "name": "end", + "arguments": "{}", + }, + }, + ], + } + } + ], + }, + { + "choices": [ + { + "message": { + "content": "", + "tool_calls": [ + { + "id": "call_end_only", + "function": { + "name": "end", + "arguments": "{}", + }, + } + ], + } + } + ], + }, + ], + ) + + result = await AIClient.ask(client, "hello") + + assert result == "" + assert execute_calls == ["send_message", "end"] + assert len(client._submit_calls) == 2 + + end_tool_messages = [ + m + for m in client._submit_calls[1] + if m.get("role") == "tool" and m.get("tool_call_id") == "call_end" + ] + assert end_tool_messages[0]["content"] == END_CO_CALL_REJECT_CONTENT diff --git a/tests/test_handlers_meme_annotation.py b/tests/test_handlers_meme_annotation.py index a8ff4d62..1b35b322 100644 --- a/tests/test_handlers_meme_annotation.py +++ b/tests/test_handlers_meme_annotation.py @@ -1,4 +1,4 @@ -"""测试 handlers.py 中的表情包自动匹配功能""" +"""测试 handlers/message_flow 中的表情包自动匹配功能""" from __future__ import 
annotations diff --git a/tests/test_llm_request_params.py b/tests/test_llm_request_params.py index e34227b6..a917e6fa 100644 --- a/tests/test_llm_request_params.py +++ b/tests/test_llm_request_params.py @@ -1,5 +1,6 @@ from __future__ import annotations +import hashlib from types import SimpleNamespace from typing import Any, cast from unittest.mock import AsyncMock @@ -339,10 +340,11 @@ async def test_chat_request_auto_sets_prompt_cache_key_from_request_context() -> call_type="chat", ) + group_scope = hashlib.sha1(b"12345", usedforsecurity=False).hexdigest()[:8] assert fake_client.chat.completions.last_kwargs is not None assert ( fake_client.chat.completions.last_kwargs["prompt_cache_key"] - == "pc:gpt-test:chat:group:12345" + == f"pc:gpt-test:chat:group:{group_scope}" ) await requester._http_client.aclose() @@ -1130,6 +1132,7 @@ async def test_ai_client_request_model_prefetch_keeps_transport_count_from_calle "tool_manager", SimpleNamespace(maybe_merge_agent_tools=lambda _call_type, tools: tools), ) + setattr(client, "_filter_tools_for_runtime_config", lambda tools: tools) monkeypatch.setattr(client, "_maybe_prefetch_tools", prefetch_mock) cfg = ChatModelConfig( diff --git a/tests/test_llm_retry_suppression.py b/tests/test_llm_retry_suppression.py index 2012a693..c58a1746 100644 --- a/tests/test_llm_retry_suppression.py +++ b/tests/test_llm_retry_suppression.py @@ -8,7 +8,7 @@ import pytest -from Undefined.ai.client import AIClient +from Undefined.ai.client import AIClient, MISSING_TOOL_CALL_RETRY_HINT from Undefined.ai.queue_budget import compute_queued_llm_timeout_seconds from Undefined.config.models import ( AgentModelConfig, @@ -21,8 +21,11 @@ @pytest.mark.asyncio -async def test_ai_ask_reraises_queued_llm_error() -> None: +async def test_ai_ask_suppresses_queued_llm_error_when_retries_exhausted() -> None: client: Any = object.__new__(AIClient) + client.runtime_config = cast( + Any, SimpleNamespace(log_thinking=False, ai_request_max_retries=0) + ) 
client._prompt_builder = cast( Any, SimpleNamespace( @@ -34,9 +37,7 @@ async def test_ai_ask_reraises_queued_llm_error() -> None: ) client.tool_manager = cast(Any, SimpleNamespace(get_openai_tools=lambda: [])) client._filter_tools_for_runtime_config = lambda tools: tools - client._get_runtime_config = cast( - Any, lambda: cast(Any, SimpleNamespace(log_thinking=False)) - ) + client._get_runtime_config = cast(Any, lambda: client.runtime_config) client.model_selector = cast(Any, SimpleNamespace(wait_ready=AsyncMock())) client.chat_config = ChatModelConfig( api_url="https://api.openai.com/v1", @@ -60,8 +61,9 @@ async def test_ai_ask_reraises_queued_llm_error() -> None: proxy_config_available=False, ) - with pytest.raises(RuntimeError, match="boom"): - await AIClient.ask(client, "hello") + result = await AIClient.ask(client, "hello") + + assert result == "" @pytest.mark.asyncio @@ -185,13 +187,22 @@ async def test_ai_ask_limits_missing_tool_call_retries() -> None: max_tokens=1024, ) client._find_chat_config_by_name = lambda _name: client.chat_config - client.submit_queued_llm_call = AsyncMock( - side_effect=[ - {"choices": [{"message": {"content": "plain 1", "tool_calls": []}}]}, - {"choices": [{"message": {"content": "plain 2", "tool_calls": []}}]}, - {"choices": [{"message": {"content": "plain 3", "tool_calls": []}}]}, - ] - ) + llm_responses = [ + {"choices": [{"message": {"content": "plain 1", "tool_calls": []}}]}, + {"choices": [{"message": {"content": "plain 2", "tool_calls": []}}]}, + {"choices": [{"message": {"content": "plain 3", "tool_calls": []}}]}, + ] + submit_calls: list[list[dict[str, Any]]] = [] + + async def _submit_queued_llm_call( + *, + messages: list[dict[str, Any]], + **kwargs: Any, + ) -> dict[str, Any]: + submit_calls.append(list(messages)) + return llm_responses[len(submit_calls) - 1] + + client.submit_queued_llm_call = AsyncMock(side_effect=_submit_queued_llm_call) client._search_wrapper = None client._end_summary_storage = cast(Any, None) 
client._send_private_message_callback = None @@ -213,6 +224,13 @@ async def test_ai_ask_limits_missing_tool_call_retries() -> None: assert cast(AsyncMock, client.submit_queued_llm_call).await_count == 3 send_message.assert_awaited_once_with("plain 3") + second_call_messages = submit_calls[1] + assert second_call_messages[-2:] == [ + {"role": "assistant", "content": "plain 1"}, + {"role": "user", "content": MISSING_TOOL_CALL_RETRY_HINT}, + ] + assert "send_message" not in MISSING_TOOL_CALL_RETRY_HINT + @pytest.mark.asyncio async def test_agent_runner_reraises_queued_llm_error(tmp_path: Path) -> None: diff --git a/tests/test_package_layout.py b/tests/test_package_layout.py new file mode 100644 index 00000000..66833a0f --- /dev/null +++ b/tests/test_package_layout.py @@ -0,0 +1,32 @@ +"""打包布局与模块结构回归测试。""" + +from __future__ import annotations + +from pathlib import Path + +import Undefined + + +def test_py_typed_marker_exists() -> None: + pkg_root = Path(Undefined.__file__).resolve().parent + marker = pkg_root / "py.typed" + assert marker.is_file(), "src/Undefined/py.typed must exist for PEP 561" + + +def test_py_typed_declared_in_pyproject() -> None: + repo_root = Path(__file__).resolve().parents[1] + pyproject = (repo_root / "pyproject.toml").read_text(encoding="utf-8") + assert 'src/Undefined/py.typed" = "Undefined/py.typed"' in pyproject + + +def test_no_shadowed_monolith_modules() -> None: + """禁止 foo.py 与 foo/ 包目录并存(会导致一份实现成为不可达死代码)。""" + pkg_root = Path(Undefined.__file__).resolve().parent + violations: list[str] = [] + for path in pkg_root.rglob("*.py"): + if path.name == "__init__.py": + continue + package_dir = path.with_suffix("") + if package_dir.is_dir() and (package_dir / "__init__.py").is_file(): + violations.append(str(path.relative_to(pkg_root))) + assert violations == [], f"shadowed monolith modules: {violations}" diff --git a/tests/test_public_api_imports.py b/tests/test_public_api_imports.py new file mode 100644 index 00000000..c1b97d99 --- 
/dev/null +++ b/tests/test_public_api_imports.py @@ -0,0 +1,150 @@ +"""根包公共 API 与向后兼容 import 路径测试(Phase 3 API-FACADE 启用)。""" + +from __future__ import annotations + +import importlib +from typing import Any + +import pytest + +# 根包 lazy re-export 符号(见 docs/python-api.md) +_ROOT_EXPORTS: tuple[str, ...] = ( + "Config", + "get_config", + "set_config", + "AIClient", + "ToolRegistry", + "AgentRegistry", + "PipelineRegistry", + "BaseRegistry", + "AnthropicSkillRegistry", + "CognitiveService", + "KnowledgeManager", + "MemeService", + "AttachmentRegistry", + "RuntimeAPIServer", + "RuntimeAPIContext", +) + +# 拆分后须继续可用的 shim / 深层 import 路径 +_BACKWARD_COMPAT_PATHS: tuple[tuple[str, str], ...] = ( + ("Undefined.config.loader", "Config"), + ("Undefined.ai.client", "AIClient"), + ("Undefined.attachments", "AttachmentRegistry"), + ("Undefined.handlers", "MessageHandler"), + ("Undefined.onebot", "OneBotClient"), + ("Undefined.skills.tools", "ToolRegistry"), + ("Undefined.skills.agents", "AgentRegistry"), + ("Undefined.skills.pipelines.registry", "PipelineRegistry"), + ("Undefined.skills.registry", "BaseRegistry"), + ("Undefined.skills.anthropic_skills", "AnthropicSkillRegistry"), + ("Undefined.cognitive.service", "CognitiveService"), + ("Undefined.knowledge.manager", "KnowledgeManager"), + ("Undefined.memes.service", "MemeService"), + ("Undefined.api.app", "RuntimeAPIServer"), + ("Undefined.api._context", "RuntimeAPIContext"), +) + + +@pytest.mark.parametrize("symbol", _ROOT_EXPORTS) +def test_root_package_exports(symbol: str) -> None: + import Undefined + + assert hasattr(Undefined, symbol), f"Undefined.{symbol} missing from root exports" + getattr(Undefined, symbol) + + +def test_root_package_all_matches_exports() -> None: + import Undefined + + expected = {"__version__", *_ROOT_EXPORTS} + assert set(Undefined.__all__) == expected + + +def test_root_package_lazy_import_does_not_load_cli_modules() -> None: + import sys + + saved_modules = { + name: module + for name, module in 
sys.modules.items() + if name == "Undefined" or name.startswith("Undefined.") + } + try: + for name in saved_modules: + del sys.modules[name] + + import Undefined # noqa: F401 + + assert "Undefined.onebot" not in sys.modules + assert "Undefined.handlers" not in sys.modules + assert "Undefined.main" not in sys.modules + finally: + for name in list(sys.modules): + if ( + name == "Undefined" or name.startswith("Undefined.") + ) and name not in saved_modules: + del sys.modules[name] + sys.modules.update(saved_modules) + + +@pytest.mark.parametrize(("module_path", "symbol"), _BACKWARD_COMPAT_PATHS) +def test_backward_compat_import_path(module_path: str, symbol: str) -> None: + module = importlib.import_module(module_path) + assert hasattr(module, symbol), f"{module_path}.{symbol} missing" + getattr(module, symbol) + + +def test_set_config_not_used_by_default_get_config( + monkeypatch: pytest.MonkeyPatch, + tmp_path: Any, +) -> None: + """库嵌入 set_config() 与 CLI get_config() 路径隔离(CONFIG Phase 2)。""" + import Undefined.config as config_module + + from Undefined.config import Config, set_config + + config_module._config = None + config_module._config_manager = None + + monkeypatch.chdir(tmp_path) + (tmp_path / "config.toml").write_text( + """ +[onebot] +ws_url = "ws://127.0.0.1:3999" +[models.chat] +api_url = "https://api.example/v1" +api_key = "sk-test" +model_name = "from-file" +""", + encoding="utf-8", + ) + + cfg = config_module.get_config(strict=False) + assert cfg.chat_model.model_name == "from-file" + + injected = Config.from_mapping( + { + "onebot": {"ws_url": "ws://127.0.0.1:3001"}, + "models": { + "chat": { + "api_url": "https://api.example/v1", + "api_key": "sk-test", + "model_name": "injected", + }, + "vision": { + "api_url": "https://api.example/v1", + "api_key": "sk-test", + "model_name": "vision-test", + }, + "agent": { + "api_url": "https://api.example/v1", + "api_key": "sk-test", + "model_name": "agent-test", + }, + }, + }, + strict=False, + ) + 
set_config(injected) + assert config_module.get_config(strict=False) is injected + assert config_module.get_config_manager().load(strict=False) is injected diff --git a/uv.lock b/uv.lock index 3459b092..dd09c6b9 100644 --- a/uv.lock +++ b/uv.lock @@ -4626,7 +4626,7 @@ wheels = [ [[package]] name = "undefined-bot" -version = "3.4.2" +version = "3.5.0" source = { editable = "." } dependencies = [ { name = "aiofiles" },