diff --git a/astrbot/builtin_stars/astrbot/long_term_memory.py b/astrbot/builtin_stars/astrbot/long_term_memory.py
index e08cdc5157..9383bb3fce 100644
--- a/astrbot/builtin_stars/astrbot/long_term_memory.py
+++ b/astrbot/builtin_stars/astrbot/long_term_memory.py
@@ -9,6 +9,7 @@
 from astrbot.api.message_components import At, Image, Plain
 from astrbot.api.platform import MessageType
 from astrbot.api.provider import LLMResponse, Provider, ProviderRequest
+from astrbot.core.agent.message import TextPart
 from astrbot.core.astrbot_config_mgr import AstrBotConfigManager

 """
@@ -17,6 +18,9 @@ class LongTermMemory:
+    DEFAULT_MAX_GROUP_MESSAGES = 50
+    DEFAULT_GROUP_ICL_TOKEN_BUDGET = 4000
+
     def __init__(self, acm: AstrBotConfigManager, context: star.Context) -> None:
         self.acm = acm
         self.context = context
@@ -29,7 +33,19 @@ def cfg(self, event: AstrMessageEvent):
             max_cnt = int(cfg["provider_ltm_settings"]["group_message_max_cnt"])
         except BaseException as e:
             logger.error(e)
-            max_cnt = 300
+            max_cnt = self.DEFAULT_MAX_GROUP_MESSAGES
+        max_cnt = max(1, max_cnt)
+        try:
+            group_icl_token_budget = int(
+                cfg["provider_ltm_settings"].get(
+                    "group_icl_token_budget",
+                    self.DEFAULT_GROUP_ICL_TOKEN_BUDGET,
+                )
+            )
+        except BaseException as e:
+            logger.error(e)
+            group_icl_token_budget = self.DEFAULT_GROUP_ICL_TOKEN_BUDGET
+        group_icl_token_budget = max(1, group_icl_token_budget)
         image_caption_prompt = cfg["provider_settings"]["image_caption_prompt"]
         image_caption_provider_id = cfg["provider_ltm_settings"].get(
             "image_caption_provider_id"
@@ -45,6 +61,7 @@ def cfg(self, event: AstrMessageEvent):
         ar_whitelist = active_reply.get("whitelist", [])
         ret = {
             "max_cnt": max_cnt,
+            "group_icl_token_budget": group_icl_token_budget,
             "image_caption": image_caption,
             "image_caption_prompt": image_caption_prompt,
             "image_caption_provider_id": image_caption_provider_id,
@@ -56,6 +73,74 @@ def cfg(self, event: AstrMessageEvent):
         }
         return ret

+    @staticmethod
+    def _estimate_text_tokens(text: str) -> int:
+        chinese_count = len([c for c in text if "\u4e00" <= c <= "\u9fff"])
+        other_count = len(text) - chinese_count
+        return int(chinese_count * 0.6 + other_count * 0.3)
+
+    def _trim_text_to_token_budget(self, text: str, token_budget: int) -> str:
+        marker = "[truncated]\n"
+        marker_tokens = self._estimate_text_tokens(marker)
+        if self._estimate_text_tokens(text) <= token_budget:
+            return text
+        if token_budget <= marker_tokens:
+            return marker.strip()
+
+        low = 0
+        high = len(text)
+        best = ""
+        target_budget = token_budget - marker_tokens
+        while low <= high:
+            mid = (low + high) // 2
+            candidate = text[-mid:] if mid else ""
+            if self._estimate_text_tokens(candidate) <= target_budget:
+                best = candidate
+                low = mid + 1
+            else:
+                high = mid - 1
+        result = f"{marker}{best}"
+        while result and self._estimate_text_tokens(result) > token_budget:
+            result = result[:-1]
+        return result
+
+    def _build_chats_context(
+        self,
+        chats: list[str],
+        token_budget: int,
+    ) -> tuple[str, int, int]:
+        separator = "\n---\n"
+        separator_tokens = self._estimate_text_tokens(separator)
+        selected: list[str] = []
+        total_tokens = 0
+
+        for chat in reversed(chats):
+            chat_tokens = self._estimate_text_tokens(chat)
+            extra_tokens = chat_tokens + (separator_tokens if selected else 0)
+            if selected and total_tokens + extra_tokens > token_budget:
+                break
+            if not selected and chat_tokens > token_budget:
+                trimmed = self._trim_text_to_token_budget(chat, token_budget)
+                return trimmed, len(chats) - 1, self._estimate_text_tokens(trimmed)
+            selected.append(chat)
+            total_tokens += extra_tokens
+
+        selected.reverse()
+        omitted = len(chats) - len(selected)
+        chats_str = separator.join(selected)
+        if omitted > 0:
+            omitted_notice = (
+                f"[{omitted} earlier group messages omitted due to token budget]"
+            )
+            chats_str = f"{omitted_notice}{separator}{chats_str}"
+            total_tokens += (
+                self._estimate_text_tokens(omitted_notice) + separator_tokens
+            )
+            if total_tokens > token_budget:
+                chats_str = self._trim_text_to_token_budget(chats_str, token_budget)
+                total_tokens = self._estimate_text_tokens(chats_str)
+        return chats_str, omitted, total_tokens
+
     async def remove_session(self, event: AstrMessageEvent) -> int:
         cnt = 0
         if event.unified_msg_origin in self.session_chats:
@@ -125,6 +210,11 @@ async def handle_message(self, event: AstrMessageEvent) -> None:
                 parts.append(f" {comp.text}")
             elif isinstance(comp, Image):
                 if cfg["image_caption"]:
+                    logger.warning(
+                        "Group ICL image caption is enabled. Each group image may trigger an extra multimodal request. umo=%s, provider=%s",
+                        event.unified_msg_origin,
+                        cfg["image_caption_provider_id"],
+                    )
                     try:
                         url = comp.url if comp.url else comp.file
                         if not url:
@@ -153,9 +243,19 @@ async def on_req_llm(self, event: AstrMessageEvent, req: ProviderRequest) -> None:
         if event.unified_msg_origin not in self.session_chats:
             return

-        chats_str = "\n---\n".join(self.session_chats[event.unified_msg_origin])
-
         cfg = self.cfg(event)
+        chats_str, omitted, estimated_tokens = self._build_chats_context(
+            self.session_chats[event.unified_msg_origin],
+            cfg["group_icl_token_budget"],
+        )
+        if omitted > 0:
+            logger.warning(
+                "Group ICL context truncated by token budget. umo=%s, omitted=%s, estimated_tokens=%s, budget=%s",
+                event.unified_msg_origin,
+                omitted,
+                estimated_tokens,
+                cfg["group_icl_token_budget"],
+            )
         if cfg["enable_active_reply"]:
             prompt = req.prompt
             req.prompt = (
@@ -166,10 +266,17 @@ async def on_req_llm(self, event: AstrMessageEvent, req: ProviderRequest) -> None:
             )
             req.contexts = []  # Clear the context: with active reply enabled, all chat history lives in a single prompt.
         else:
-            req.system_prompt += (
-                "You are now in a chatroom. The chat history is as follows: \n"
+            req.extra_user_content_parts.append(
+                TextPart(
+                    text=(
+                        "Use the following recent group chat context only as background "
+                        "for this request.\n"
+                        "[Group Chat Context]\n"
+                        "Recent group chat messages, newest messages are kept when truncated:\n"
+                        f"{chats_str}"
+                    )
+                ).mark_as_temp()
             )
-            req.system_prompt += chats_str

     async def after_req_llm(
         self, event: AstrMessageEvent, llm_resp: LLMResponse
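Note on the budgeting above: `_estimate_text_tokens` is a cheap character-class heuristic, not a real tokenizer, so the budget is approximate by design. A standalone sketch of the same weighting (the function name here is illustrative, not part of the patch):

```python
# CJK characters are weighted at ~0.6 tokens each, everything else at ~0.3,
# mirroring the heuristic in _estimate_text_tokens above.
def estimate_text_tokens(text: str) -> int:
    cjk = sum(1 for c in text if "\u4e00" <= c <= "\u9fff")
    return int(cjk * 0.6 + (len(text) - cjk) * 0.3)

print(estimate_text_tokens("hello world"))  # 11 chars -> int(3.3) -> 3
print(estimate_text_tokens("你好"))          # 2 CJK chars -> int(1.2) -> 1
```

`_build_chats_context` then packs messages newest-first until the budget is hit, trims a single oversized message instead of dropping it outright, and re-trims after prepending the omission notice, so the returned estimate stays within the budget.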
diff --git a/astrbot/core/agent/runners/tool_loop_agent_runner.py b/astrbot/core/agent/runners/tool_loop_agent_runner.py
index 968426b8b4..f2ba3089b0 100644
--- a/astrbot/core/agent/runners/tool_loop_agent_runner.py
+++ b/astrbot/core/agent/runners/tool_loop_agent_runner.py
@@ -110,6 +110,8 @@ class _ToolExecutionInterrupted(Exception):
 class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
     TOOL_RESULT_MAX_ESTIMATED_TOKENS = 27_500
     TOOL_RESULT_PREVIEW_MAX_ESTIMATED_TOKENS = 7000
+    REQUEST_WARN_ESTIMATED_INPUT_TOKENS = 16_000
+    REQUEST_WARN_IMAGE_COUNT = 1
     EMPTY_OUTPUT_RETRY_ATTEMPTS = 3
     EMPTY_OUTPUT_RETRY_WAIT_MIN_S = 1
     EMPTY_OUTPUT_RETRY_WAIT_MAX_S = 4
@@ -176,6 +178,43 @@ def _get_persona_custom_error_message(self) -> str | None:
         event = getattr(self.run_context.context, "event", None)
         return extract_persona_custom_error_message_from_event(event)

+    @staticmethod
+    def _count_image_parts(messages: list[Message]) -> int:
+        count = 0
+        for message in messages:
+            if isinstance(message.content, list):
+                count += sum(
+                    1 for part in message.content if isinstance(part, ImageURLPart)
+                )
+        return count
+
+    def _log_request_cost_preflight(self) -> None:
+        estimated_input_tokens = EstimateTokenCounter().count_tokens(
+            self.run_context.messages
+        )
+        image_count = self._count_image_parts(self.run_context.messages)
+        logger.debug(
+            "LLM request preflight. provider=%s, model=%s, estimated_input_tokens=%s, image_count=%s",
+            self.provider.provider_config.get("id", ""),
+            self.provider.get_model(),
+            estimated_input_tokens,
+            image_count,
+        )
+        if estimated_input_tokens >= self.REQUEST_WARN_ESTIMATED_INPUT_TOKENS:
+            logger.warning(
+                "LLM request has high estimated input tokens. provider=%s, model=%s, estimated_input_tokens=%s",
+                self.provider.provider_config.get("id", ""),
+                self.provider.get_model(),
+                estimated_input_tokens,
+            )
+        if image_count > self.REQUEST_WARN_IMAGE_COUNT:
+            logger.warning(
+                "LLM request contains multiple images. provider=%s, model=%s, image_count=%s",
+                self.provider.provider_config.get("id", ""),
+                self.provider.get_model(),
+                image_count,
+            )
+
     async def _complete_with_assistant_response(self, llm_resp: LLMResponse) -> None:
         """Finalize the current step as a plain assistant response with no tool calls."""
         self.final_llm_resp = llm_resp
@@ -711,6 +750,7 @@ async def step(self):
             self.run_context.messages, trusted_token_usage=token_usage
         )
         self._simple_print_message_role("[AftCompact]")
+        self._log_request_cost_preflight()

         async for llm_response in self._iter_llm_responses_with_fallback():
             if llm_response.is_chunk:
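The preflight is warn-only: it never blocks or mutates the request, it just surfaces expensive calls and multi-image payloads in the logs before they are sent, once per step after compaction. A self-contained sketch of the policy, with thresholds assumed to mirror REQUEST_WARN_ESTIMATED_INPUT_TOKENS and REQUEST_WARN_IMAGE_COUNT:

```python
import logging

logger = logging.getLogger("preflight")

# Warn-only cost preflight; standalone stand-in for _log_request_cost_preflight.
def log_request_cost_preflight(
    estimated_input_tokens: int,
    image_count: int,
    token_warn: int = 16_000,
    image_warn: int = 1,
) -> None:
    logger.debug(
        "preflight: tokens=%s, images=%s", estimated_input_tokens, image_count
    )
    if estimated_input_tokens >= token_warn:
        logger.warning("high estimated input tokens: %s", estimated_input_tokens)
    if image_count > image_warn:
        logger.warning("request contains multiple images: %s", image_count)
```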
diff --git a/astrbot/core/astr_main_agent.py b/astrbot/core/astr_main_agent.py
index fd1a9aeb8c..f223751f61 100644
--- a/astrbot/core/astr_main_agent.py
+++ b/astrbot/core/astr_main_agent.py
@@ -152,10 +152,10 @@ class MainAgentBuildConfig:
     """The number of most recent turns to keep during llm_compress strategy."""
     llm_compress_provider_id: str = ""
     """The provider ID for the LLM used in context compression."""
-    max_context_length: int = -1
+    max_context_length: int = 30
     """The maximum number of turns to keep in context. -1 means no limit.
     This enforces the max turn count before compression."""
-    dequeue_context_length: int = 1
+    dequeue_context_length: int = 10
     """The number of oldest turns to remove when context length limit is reached."""
     fallback_max_context_tokens: int = 128000
     """Fallback max context tokens. When max_context_tokens is 0 and the model is not in LLM_METADATAS, use this value."""
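With these defaults, an unbounded history is no longer the out-of-the-box behavior: once a session passes 30 turns, the 10 oldest turns are evicted in one batch, which also keeps the remaining prompt prefix stable for longer than evicting one turn at a time would. A sketch of the implied policy (illustrative only; the actual enforcement lives in AstrBot's context management, not in this dataclass):

```python
# Assumed shape of the turn-limit policy implied by the new defaults.
def enforce_turn_limit(
    turns: list[str],
    max_context_length: int = 30,
    dequeue_context_length: int = 10,
) -> list[str]:
    if max_context_length < 0 or len(turns) <= max_context_length:
        return turns  # -1 preserves the old "no limit" behavior
    return turns[dequeue_context_length:]  # drop the oldest batch

history = [f"turn-{i}" for i in range(31)]  # one turn over the limit
assert len(enforce_turn_limit(history)) == 21
```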
diff --git a/astrbot/core/config/default.py b/astrbot/core/config/default.py
index 8df7ae27b0..4eb0d5180a 100644
--- a/astrbot/core/config/default.py
+++ b/astrbot/core/config/default.py
@@ -130,8 +130,8 @@
         ),
         "llm_compress_keep_recent": 6,
         "llm_compress_provider_id": "",
-        "max_context_length": -1,
-        "dequeue_context_length": 1,
+        "max_context_length": 30,
+        "dequeue_context_length": 10,
         "streaming_response": False,
         "show_tool_use_status": False,
         "show_tool_call_result": False,
@@ -217,7 +217,8 @@
     },
     "provider_ltm_settings": {
         "group_icl_enable": False,
-        "group_message_max_cnt": 300,
+        "group_message_max_cnt": 50,
+        "group_icl_token_budget": 4000,
         "image_caption": False,
         "image_caption_provider_id": "",
         "active_reply": {
@@ -2862,6 +2863,9 @@
     "group_message_max_cnt": {
         "type": "int",
     },
+    "group_icl_token_budget": {
+        "type": "int",
+    },
     "image_caption": {
         "type": "bool",
     },
@@ -4079,6 +4083,11 @@
         "description": "最大消息数量",
         "type": "int",
     },
+    "provider_ltm_settings.group_icl_token_budget": {
+        "description": "群聊上下文 Token 预算",
+        "type": "int",
+        "hint": "每次 LLM 请求注入的群聊上下文近似 token 上限。降低该值可减少费用并降低缓存失效影响。",
+    },
     "provider_ltm_settings.image_caption": {
         "description": "自动理解图片",
         "type": "bool",
diff --git a/astrbot/core/pipeline/process_stage/method/agent_sub_stages/internal.py b/astrbot/core/pipeline/process_stage/method/agent_sub_stages/internal.py
index fee641c192..9c4e6b0120 100644
--- a/astrbot/core/pipeline/process_stage/method/agent_sub_stages/internal.py
+++ b/astrbot/core/pipeline/process_stage/method/agent_sub_stages/internal.py
@@ -240,10 +240,19 @@ async def process(
             and not event.platform_meta.support_streaming_message
         )

+        system_prompt_before_hooks = req.system_prompt or ""
         if await call_event_hook(event, EventType.OnLLMRequestEvent, req):
             if reset_coro:
                 reset_coro.close()
             return
+        system_prompt_after_hooks = req.system_prompt or ""
+        if system_prompt_after_hooks != system_prompt_before_hooks:
+            logger.warning(
+                "LLM system prompt was modified by request hooks. umo=%s, before_chars=%s, after_chars=%s",
+                event.unified_msg_origin,
+                len(system_prompt_before_hooks),
+                len(system_prompt_after_hooks),
+            )

         # apply reset
         if reset_coro:
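The guard in internal.py compares only req.system_prompt before and after the OnLLMRequestEvent hooks, so with the LTM plugin now routing group context through extra_user_content_parts, any remaining plugin that still appends to the system prompt (which tends to invalidate provider-side prompt caches) becomes visible in the logs. The pattern, reduced to a standalone sketch (Req and the plain callables are stand-ins for AstrBot's ProviderRequest and hook dispatch):

```python
import logging
from dataclasses import dataclass
from typing import Callable

logger = logging.getLogger("hooks")


@dataclass
class Req:  # stand-in for ProviderRequest
    system_prompt: str = ""


def run_hooks_with_prompt_guard(req: Req, hooks: list[Callable[[Req], None]]) -> None:
    before = req.system_prompt or ""
    for hook in hooks:  # the real code dispatches OnLLMRequestEvent here
        hook(req)
    after = req.system_prompt or ""
    if after != before:
        logger.warning(
            "system prompt modified by hooks. before_chars=%s, after_chars=%s",
            len(before),
            len(after),
        )
```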
diff --git a/tests/unit/test_long_term_memory_cost_safety.py b/tests/unit/test_long_term_memory_cost_safety.py
new file mode 100644
index 0000000000..2f1e3f8bb8
--- /dev/null
+++ b/tests/unit/test_long_term_memory_cost_safety.py
@@ -0,0 +1,85 @@
+import pytest
+
+from astrbot.api.provider import ProviderRequest
+from astrbot.builtin_stars.astrbot.long_term_memory import LongTermMemory
+from astrbot.core.agent.message import (
+    Message,
+    TextPart,
+    dump_messages_with_checkpoints,
+)
+
+
+class DummyContext:
+    def get_config(self, umo=None):
+        return {
+            "provider_settings": {
+                "image_caption_prompt": "Describe the image.",
+            },
+            "provider_ltm_settings": {
+                "group_message_max_cnt": 50,
+                "group_icl_token_budget": 30,
+                "image_caption": False,
+                "image_caption_provider_id": "",
+                "active_reply": {
+                    "enable": False,
+                    "method": "possibility_reply",
+                    "possibility_reply": 0.1,
+                    "whitelist": [],
+                },
+            },
+        }
+
+
+class DummyEvent:
+    unified_msg_origin = "group:test"
+
+
+@pytest.mark.asyncio
+async def test_group_icl_uses_user_context_part_instead_of_system_prompt():
+    ltm = LongTermMemory(None, DummyContext())
+    ltm.session_chats[DummyEvent.unified_msg_origin] = [
+        "[alice/10:00:00]: old message",
+        "[bob/10:01:00]: recent message",
+    ]
+    req = ProviderRequest(prompt="hello", system_prompt="base system")
+
+    await ltm.on_req_llm(DummyEvent(), req)
+
+    assert "old message" not in req.system_prompt
+    assert "recent message" not in req.system_prompt
+    assert len(req.extra_user_content_parts) == 1
+    part = req.extra_user_content_parts[0]
+    assert isinstance(part, TextPart)
+    assert part._no_save is True
+    assert "only as background" in part.text
+    assert "[Group Chat Context]" in part.text
+    assert "recent message" in part.text
+
+
+@pytest.mark.asyncio
+async def test_group_icl_context_is_not_persisted_in_history():
+    ltm = LongTermMemory(None, DummyContext())
+    ltm.session_chats[DummyEvent.unified_msg_origin] = [
+        "[bob/10:01:00]: recent message",
+    ]
+    req = ProviderRequest(prompt="hello", system_prompt="base system")
+
+    await ltm.on_req_llm(DummyEvent(), req)
+    message = Message.model_validate(await req.assemble_context())
+    dumped = dump_messages_with_checkpoints([message])
+
+    assert "hello" in str(dumped)
+    assert "[Group Chat Context]" not in str(dumped)
+    assert "recent message" not in str(dumped)
+
+
+def test_group_icl_context_respects_token_budget():
+    ltm = LongTermMemory(None, DummyContext())
+    chats = [f"[user{i}/10:00:0{i}]: " + ("x" * 80) for i in range(5)]
+    budget = 30
+
+    chats_str, omitted, estimated_tokens = ltm._build_chats_context(chats, budget)
+
+    assert omitted > 0
+    assert chats_str
+    assert estimated_tokens <= budget
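For the budget test, the arithmetic works out as follows: each chat line is 18 prefix characters plus 80 "x" characters, all non-CJK, so it estimates to int(98 * 0.3) = 29 tokens. A 30-token budget therefore keeps at most one message, the other four are omitted, and the prepended omission notice forces a final re-trim back under the budget. A quick check using the same heuristic:

```python
line = "[user0/10:00:00]: " + "x" * 80  # 98 non-CJK characters
print(int(len(line) * 0.3))             # 29 estimated tokens vs. a budget of 30
```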