119 changes: 113 additions & 6 deletions astrbot/builtin_stars/astrbot/long_term_memory.py
@@ -9,6 +9,7 @@
from astrbot.api.message_components import At, Image, Plain
from astrbot.api.platform import MessageType
from astrbot.api.provider import LLMResponse, Provider, ProviderRequest
from astrbot.core.agent.message import TextPart
from astrbot.core.astrbot_config_mgr import AstrBotConfigManager

"""
@@ -17,6 +18,9 @@


class LongTermMemory:
    DEFAULT_MAX_GROUP_MESSAGES = 50
    DEFAULT_GROUP_ICL_TOKEN_BUDGET = 4000

    def __init__(self, acm: AstrBotConfigManager, context: star.Context) -> None:
        self.acm = acm
        self.context = context
@@ -29,7 +33,19 @@ def cfg(self, event: AstrMessageEvent):
            max_cnt = int(cfg["provider_ltm_settings"]["group_message_max_cnt"])
        except BaseException as e:
            logger.error(e)
-            max_cnt = 300
            max_cnt = self.DEFAULT_MAX_GROUP_MESSAGES
        max_cnt = max(1, max_cnt)
        try:
            group_icl_token_budget = int(
                cfg["provider_ltm_settings"].get(
                    "group_icl_token_budget",
                    self.DEFAULT_GROUP_ICL_TOKEN_BUDGET,
                )
Comment on lines 34 to +43
suggestion (bug_risk): Catching BaseException here is broader than needed and may hide unexpected issues.

This except BaseException will also catch KeyboardInterrupt and SystemExit, which we usually want to propagate. Since only config parsing and int casting can fail here, narrowing to Exception (or even ValueError/TypeError) would avoid swallowing truly exceptional conditions while still handling malformed configs safely.

Suggested implementation:

        except (ValueError, TypeError, KeyError) as e:
            logger.error(e)
            max_cnt = self.DEFAULT_MAX_GROUP_MESSAGES
        except (ValueError, TypeError) as e:
            logger.error(e)
            group_icl_token_budget = self.DEFAULT_GROUP_ICL_TOKEN_BUDGET

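If this parse-with-fallback pattern keeps growing, the two try/except blocks could also collapse into one helper. A minimal sketch, assuming the module-level logger already used in this file; _get_int_setting is a hypothetical name, not something this PR defines:

import logging

logger = logging.getLogger(__name__)  # stand-in for this module's logger

def _get_int_setting(settings: dict, key: str, default: int, minimum: int = 1) -> int:
    """Read an int config value, falling back to `default` on malformed input."""
    try:
        value = int(settings.get(key, default))
    except (ValueError, TypeError) as e:  # missing keys fall back via .get()
        logger.error(e)
        value = default
    return max(minimum, value)

Both reads in cfg() would then become one-liners, e.g. max_cnt = _get_int_setting(cfg["provider_ltm_settings"], "group_message_max_cnt", self.DEFAULT_MAX_GROUP_MESSAGES).
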
            )
        except BaseException as e:
            logger.error(e)
            group_icl_token_budget = self.DEFAULT_GROUP_ICL_TOKEN_BUDGET
Comment on lines 34 to +47
medium

Catching BaseException is discouraged: it catches everything, including SystemExit and KeyboardInterrupt, which can keep the program from exiting normally and makes debugging harder. Prefer specific exceptions (such as ValueError or KeyError), or at least Exception.

Suggested change:
-        except BaseException as e:
+        except Exception as e:
             logger.error(e)
             max_cnt = self.DEFAULT_MAX_GROUP_MESSAGES
         max_cnt = max(1, max_cnt)
         try:
             group_icl_token_budget = int(
                 cfg["provider_ltm_settings"].get(
                     "group_icl_token_budget",
                     self.DEFAULT_GROUP_ICL_TOKEN_BUDGET,
                 )
             )
-        except BaseException as e:
+        except Exception as e:
             logger.error(e)
             group_icl_token_budget = self.DEFAULT_GROUP_ICL_TOKEN_BUDGET
        group_icl_token_budget = max(1, group_icl_token_budget)
        image_caption_prompt = cfg["provider_settings"]["image_caption_prompt"]
        image_caption_provider_id = cfg["provider_ltm_settings"].get(
            "image_caption_provider_id"
@@ -45,6 +61,7 @@ def cfg(self, event: AstrMessageEvent):
        ar_whitelist = active_reply.get("whitelist", [])
        ret = {
            "max_cnt": max_cnt,
            "group_icl_token_budget": group_icl_token_budget,
            "image_caption": image_caption,
            "image_caption_prompt": image_caption_prompt,
            "image_caption_provider_id": image_caption_provider_id,
@@ -56,6 +73,74 @@ def cfg(self, event: AstrMessageEvent):
        }
        return ret

    @staticmethod
    def _estimate_text_tokens(text: str) -> int:
        chinese_count = len([c for c in text if "\u4e00" <= c <= "\u9fff"])
        other_count = len(text) - chinese_count
        return int(chinese_count * 0.6 + other_count * 0.3)
Comment on lines +78 to +80
high

There are two problems with this token-estimation logic:

  1. Efficiency: the list comprehension [c for c in text if ...] materializes a full character list in memory, needless overhead for long text such as group chat history. Prefer sum(1 for c in text if ...).
  2. Accuracy: the 0.6 multiplier for Chinese characters is far too low. In mainstream model tokenizers (e.g. GPT-4o, DeepSeek), one Chinese character usually maps to 1.5 to 2.0 tokens, so 0.6 severely underestimates real consumption and weakens this PR's goal of reducing cost risk. A multiplier of about 1.5 is suggested.
Suggested change:
-        chinese_count = len([c for c in text if "\u4e00" <= c <= "\u9fff"])
-        other_count = len(text) - chinese_count
-        return int(chinese_count * 0.6 + other_count * 0.3)
+        chinese_count = sum(1 for c in text if "\u4e00" <= c <= "\u9fff")
+        other_count = len(text) - chinese_count
+        return int(chinese_count * 1.5 + other_count * 0.3)
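
Whichever multiplier lands, it is easy to sanity-check empirically. A throwaway comparison against a real tokenizer, as a sketch: it assumes the optional tiktoken package is installed, and the 1.5/0.3 factors simply mirror the suggestion above.

import tiktoken

enc = tiktoken.get_encoding("o200k_base")  # encoding used by the GPT-4o family
samples = ["今天的会议改到下午三点,请准时参加。", "plain ascii message for comparison"]
for text in samples:
    actual = len(enc.encode(text))
    chinese = sum(1 for c in text if "\u4e00" <= c <= "\u9fff")
    estimated = int(chinese * 1.5 + (len(text) - chinese) * 0.3)
    print(f"actual={actual} estimated={estimated} text={text[:12]!r}")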


    def _trim_text_to_token_budget(self, text: str, token_budget: int) -> str:
        marker = "[truncated]\n"
        marker_tokens = self._estimate_text_tokens(marker)
        if self._estimate_text_tokens(text) <= token_budget:
            return text
        if token_budget <= marker_tokens:
            return marker.strip()

        low = 0
        high = len(text)
        best = ""
        target_budget = token_budget - marker_tokens
        while low <= high:
            mid = (low + high) // 2
            candidate = text[-mid:] if mid else ""
            if self._estimate_text_tokens(candidate) <= target_budget:
                best = candidate
                low = mid + 1
            else:
                high = mid - 1
        result = f"{marker}{best}"
        while result and self._estimate_text_tokens(result) > token_budget:
            result = result[:-1]
        return result

    def _build_chats_context(
        self,
        chats: list[str],
        token_budget: int,
    ) -> tuple[str, int, int]:
        separator = "\n---\n"
        separator_tokens = self._estimate_text_tokens(separator)
        selected: list[str] = []
        total_tokens = 0

        for chat in reversed(chats):
            chat_tokens = self._estimate_text_tokens(chat)
            extra_tokens = chat_tokens + (separator_tokens if selected else 0)
            if selected and total_tokens + extra_tokens > token_budget:
                break
            if not selected and chat_tokens > token_budget:
                trimmed = self._trim_text_to_token_budget(chat, token_budget)
                return trimmed, len(chats) - 1, self._estimate_text_tokens(trimmed)
            selected.append(chat)
            total_tokens += extra_tokens

        selected.reverse()
        omitted = len(chats) - len(selected)
        chats_str = separator.join(selected)
        if omitted > 0:
            omitted_notice = (
                f"[{omitted} earlier group messages omitted due to token budget]"
            )
            chats_str = f"{omitted_notice}{separator}{chats_str}"
            total_tokens += (
                self._estimate_text_tokens(omitted_notice) + separator_tokens
            )
        if total_tokens > token_budget:
            chats_str = self._trim_text_to_token_budget(chats_str, token_budget)
            total_tokens = self._estimate_text_tokens(chats_str)
        return chats_str, omitted, total_tokens

    async def remove_session(self, event: AstrMessageEvent) -> int:
        cnt = 0
        if event.unified_msg_origin in self.session_chats:
@@ -125,6 +210,11 @@ async def handle_message(self, event: AstrMessageEvent) -> None:
                    parts.append(f" {comp.text}")
                elif isinstance(comp, Image):
                    if cfg["image_caption"]:
                        logger.warning(
                            "Group ICL image caption is enabled. Each group image may trigger an extra multimodal request. umo=%s, provider=%s",
                            event.unified_msg_origin,
                            cfg["image_caption_provider_id"],
                        )
Comment on lines +213 to +217
medium

Emitting a warning for every image in handle_message will flood the logs, especially in active group chats, making it hard for administrators to see other important entries. Consider a flag so the reminder fires only once per session or per startup.

                        if not getattr(self, "_image_caption_warned", False):
                            logger.warning(
                                "Group ICL image caption is enabled. Each group image may trigger an extra multimodal request. umo=%s, provider=%s",
                                event.unified_msg_origin,
                                cfg["image_caption_provider_id"],
                            )
                            self._image_caption_warned = True
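
If once per process is too coarse (one busy group would silence the reminder everywhere else), a per-origin set scopes it per session. A sketch; _caption_warned_origins is a hypothetical attribute that would start as an empty set in __init__, not something this PR adds:

                        if event.unified_msg_origin not in self._caption_warned_origins:
                            logger.warning(
                                "Group ICL image caption is enabled. Each group image may trigger an extra multimodal request. umo=%s, provider=%s",
                                event.unified_msg_origin,
                                cfg["image_caption_provider_id"],
                            )
                            self._caption_warned_origins.add(event.unified_msg_origin)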

                        try:
                            url = comp.url if comp.url else comp.file
                            if not url:
@@ -153,9 +243,19 @@ async def on_req_llm(self, event: AstrMessageEvent, req: ProviderRequest) -> None:
        if event.unified_msg_origin not in self.session_chats:
            return

-        chats_str = "\n---\n".join(self.session_chats[event.unified_msg_origin])
        cfg = self.cfg(event)
        chats_str, omitted, estimated_tokens = self._build_chats_context(
            self.session_chats[event.unified_msg_origin],
            cfg["group_icl_token_budget"],
        )
        if omitted > 0:
medium

The current logic only warns when omitted > 0, that is, when at least one whole message was dropped. If the group history holds a single message and that message gets truncated (omitted stays 0), no warning fires. Consider also checking whether chats_str starts with the truncation marker.

Suggested change:
-        if omitted > 0:
+        if omitted > 0 or chats_str.startswith("[truncated]"):
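
One caveat with this check: the "[truncated]" literal would then live in two places. A sketch of sharing the marker instead; TRUNCATION_MARKER is a hypothetical constant, not part of this PR:

class LongTermMemory:
    TRUNCATION_MARKER = "[truncated]\n"

    def _trim_text_to_token_budget(self, text: str, token_budget: int) -> str:
        marker = self.TRUNCATION_MARKER  # was a local literal
        ...

The warning condition would then read: if omitted > 0 or chats_str.startswith(LongTermMemory.TRUNCATION_MARKER.strip()):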

            logger.warning(
                "Group ICL context truncated by token budget. umo=%s, omitted=%s, estimated_tokens=%s, budget=%s",
                event.unified_msg_origin,
                omitted,
                estimated_tokens,
                cfg["group_icl_token_budget"],
            )
        if cfg["enable_active_reply"]:
            prompt = req.prompt
            req.prompt = (
@@ -166,10 +266,17 @@ async def on_req_llm(self, event: AstrMessageEvent, req: ProviderRequest) -> None:
            )
            req.contexts = []  # Clear the context: with active reply, the entire chat history lives in a single prompt.
        else:
-            req.system_prompt += (
-                "You are now in a chatroom. The chat history is as follows: \n"
            req.extra_user_content_parts.append(
                TextPart(
                    text=(
                        "Use the following recent group chat context only as background "
                        "for this request.\n"
                        "[Group Chat Context]\n"
                        "Recent group chat messages, newest messages are kept when truncated:\n"
                        f"{chats_str}"
                    )
                ).mark_as_temp()
            )
-            req.system_prompt += chats_str

    async def after_req_llm(
        self, event: AstrMessageEvent, llm_resp: LLMResponse
40 changes: 40 additions & 0 deletions astrbot/core/agent/runners/tool_loop_agent_runner.py
@@ -110,6 +110,8 @@ class _ToolExecutionInterrupted(Exception):
class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
    TOOL_RESULT_MAX_ESTIMATED_TOKENS = 27_500
    TOOL_RESULT_PREVIEW_MAX_ESTIMATED_TOKENS = 7000
    REQUEST_WARN_ESTIMATED_INPUT_TOKENS = 16_000
    REQUEST_WARN_IMAGE_COUNT = 1
    EMPTY_OUTPUT_RETRY_ATTEMPTS = 3
    EMPTY_OUTPUT_RETRY_WAIT_MIN_S = 1
    EMPTY_OUTPUT_RETRY_WAIT_MAX_S = 4
Comment on lines +114 to 117
nitpick: The REQUEST_WARN_IMAGE_COUNT threshold name/usage is slightly confusing.

With REQUEST_WARN_IMAGE_COUNT = 1 and the check image_count > self.REQUEST_WARN_IMAGE_COUNT, the warning only fires for 2+ images. That’s a reasonable behavior, but the name suggests “warn at 1 image.” Either switch the comparison to >= if you want to warn on the first image, or rename the constant (e.g., REQUEST_WARN_IMAGE_COUNT_THRESHOLD) to better reflect the current > semantics.
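
For illustration, the current > semantics under the suggested name. A sketch; should_warn and the threshold constant are hypothetical names, not code from this PR:

REQUEST_WARN_IMAGE_COUNT_THRESHOLD = 1  # warn when the count exceeds this value

def should_warn(image_count: int) -> bool:
    # `>` keeps the current behavior (fires at 2+ images);
    # switch to `>=` to warn on the first image instead.
    return image_count > REQUEST_WARN_IMAGE_COUNT_THRESHOLD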

@@ -176,6 +178,43 @@ def _get_persona_custom_error_message(self) -> str | None:
        event = getattr(self.run_context.context, "event", None)
        return extract_persona_custom_error_message_from_event(event)

    @staticmethod
    def _count_image_parts(messages: list[Message]) -> int:
        count = 0
        for message in messages:
            if isinstance(message.content, list):
                count += sum(
                    1 for part in message.content if isinstance(part, ImageURLPart)
                )
        return count

    def _log_request_cost_preflight(self) -> None:
        estimated_input_tokens = EstimateTokenCounter().count_tokens(
            self.run_context.messages
        )
        image_count = self._count_image_parts(self.run_context.messages)
        logger.debug(
            "LLM request preflight. provider=%s, model=%s, estimated_input_tokens=%s, image_count=%s",
            self.provider.provider_config.get("id", ""),
            self.provider.get_model(),
            estimated_input_tokens,
            image_count,
        )
        if estimated_input_tokens >= self.REQUEST_WARN_ESTIMATED_INPUT_TOKENS:
            logger.warning(
                "LLM request has high estimated input tokens. provider=%s, model=%s, estimated_input_tokens=%s",
                self.provider.provider_config.get("id", ""),
                self.provider.get_model(),
                estimated_input_tokens,
            )
        if image_count > self.REQUEST_WARN_IMAGE_COUNT:
            logger.warning(
                "LLM request contains multiple images. provider=%s, model=%s, image_count=%s",
                self.provider.provider_config.get("id", ""),
                self.provider.get_model(),
                image_count,
            )

    async def _complete_with_assistant_response(self, llm_resp: LLMResponse) -> None:
        """Finalize the current step as a plain assistant response with no tool calls."""
        self.final_llm_resp = llm_resp
@@ -711,6 +750,7 @@ async def step(self):
            self.run_context.messages, trusted_token_usage=token_usage
        )
        self._simple_print_message_role("[AftCompact]")
        self._log_request_cost_preflight()

        async for llm_response in self._iter_llm_responses_with_fallback():
            if llm_response.is_chunk:
4 changes: 2 additions & 2 deletions astrbot/core/astr_main_agent.py
@@ -152,10 +152,10 @@ class MainAgentBuildConfig:
"""The number of most recent turns to keep during llm_compress strategy."""
llm_compress_provider_id: str = ""
"""The provider ID for the LLM used in context compression."""
max_context_length: int = -1
max_context_length: int = 30
"""The maximum number of turns to keep in context. -1 means no limit.
This enforce max turns before compression"""
dequeue_context_length: int = 1
dequeue_context_length: int = 10
"""The number of oldest turns to remove when context length limit is reached."""
fallback_max_context_tokens: int = 128000
"""Fallback max context tokens. When max_context_tokens is 0 and the model is not in LLM_METADATAS, use this value."""
15 changes: 12 additions & 3 deletions astrbot/core/config/default.py
@@ -130,8 +130,8 @@
        ),
        "llm_compress_keep_recent": 6,
        "llm_compress_provider_id": "",
-        "max_context_length": -1,
-        "dequeue_context_length": 1,
        "max_context_length": 30,
        "dequeue_context_length": 10,
        "streaming_response": False,
        "show_tool_use_status": False,
        "show_tool_call_result": False,
Expand Down Expand Up @@ -217,7 +217,8 @@
    },
    "provider_ltm_settings": {
        "group_icl_enable": False,
-        "group_message_max_cnt": 300,
        "group_message_max_cnt": 50,
        "group_icl_token_budget": 4000,
        "image_caption": False,
        "image_caption_provider_id": "",
        "active_reply": {
@@ -2862,6 +2863,9 @@
"group_message_max_cnt": {
"type": "int",
},
"group_icl_token_budget": {
"type": "int",
},
"image_caption": {
"type": "bool",
},
@@ -4079,6 +4083,11 @@
"description": "最大消息数量",
"type": "int",
},
"provider_ltm_settings.group_icl_token_budget": {
"description": "群聊上下文 Token 预算",
"type": "int",
"hint": "每次 LLM 请求注入的群聊上下文近似 token 上限。降低该值可减少费用并降低缓存失效影响。",
},
"provider_ltm_settings.image_caption": {
"description": "自动理解图片",
"type": "bool",
@@ -240,10 +240,19 @@ async def process(
            and not event.platform_meta.support_streaming_message
        )

        system_prompt_before_hooks = req.system_prompt or ""
        if await call_event_hook(event, EventType.OnLLMRequestEvent, req):
            if reset_coro:
                reset_coro.close()
            return
        system_prompt_after_hooks = req.system_prompt or ""
        if system_prompt_after_hooks != system_prompt_before_hooks:
            logger.warning(
                "LLM system prompt was modified by request hooks. umo=%s, before_chars=%s, after_chars=%s",
                event.unified_msg_origin,
                len(system_prompt_before_hooks),
                len(system_prompt_after_hooks),
            )

        # apply reset
        if reset_coro:
85 changes: 85 additions & 0 deletions tests/unit/test_long_term_memory_cost_safety.py
@@ -0,0 +1,85 @@
import pytest

from astrbot.api.provider import ProviderRequest
from astrbot.builtin_stars.astrbot.long_term_memory import LongTermMemory
from astrbot.core.agent.message import (
    Message,
    TextPart,
    dump_messages_with_checkpoints,
)


class DummyContext:
    def get_config(self, umo=None):
        return {
            "provider_settings": {
                "image_caption_prompt": "Describe the image.",
            },
            "provider_ltm_settings": {
                "group_message_max_cnt": 50,
                "group_icl_token_budget": 30,
                "image_caption": False,
                "image_caption_provider_id": "",
                "active_reply": {
                    "enable": False,
                    "method": "possibility_reply",
                    "possibility_reply": 0.1,
                    "whitelist": [],
                },
            },
        }


class DummyEvent:
    unified_msg_origin = "group:test"


@pytest.mark.asyncio
async def test_group_icl_uses_user_context_part_instead_of_system_prompt():
    ltm = LongTermMemory(None, DummyContext())
    ltm.session_chats[DummyEvent.unified_msg_origin] = [
        "[alice/10:00:00]: old message",
        "[bob/10:01:00]: recent message",
    ]
    req = ProviderRequest(prompt="hello", system_prompt="base system")

    await ltm.on_req_llm(DummyEvent(), req)

    assert "old message" not in req.system_prompt
    assert "recent message" not in req.system_prompt
    assert len(req.extra_user_content_parts) == 1
    part = req.extra_user_content_parts[0]
    assert isinstance(part, TextPart)
    assert part._no_save is True
    assert "only as background" in part.text
    assert "[Group Chat Context]" in part.text
    assert "recent message" in part.text


@pytest.mark.asyncio
async def test_group_icl_context_is_not_persisted_in_history():
    ltm = LongTermMemory(None, DummyContext())
    ltm.session_chats[DummyEvent.unified_msg_origin] = [
        "[bob/10:01:00]: recent message",
    ]
    req = ProviderRequest(prompt="hello", system_prompt="base system")

    await ltm.on_req_llm(DummyEvent(), req)
    message = Message.model_validate(await req.assemble_context())
    dumped = dump_messages_with_checkpoints([message])

    assert "hello" in str(dumped)
    assert "[Group Chat Context]" not in str(dumped)
    assert "recent message" not in str(dumped)


def test_group_icl_context_respects_token_budget():
    ltm = LongTermMemory(None, DummyContext())
    chats = [f"[user{i}/10:00:0{i}]: " + ("x" * 80) for i in range(5)]
    budget = 30

    chats_str, omitted, estimated_tokens = ltm._build_chats_context(chats, budget)

    assert omitted > 0
    assert chats_str
    assert estimated_tokens <= budget
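
Assuming pytest and pytest-asyncio are installed (the first two tests are marked pytest.mark.asyncio), this file can be run on its own with pytest tests/unit/test_long_term_memory_cost_safety.py -q.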