diff --git a/cmd/friday/cmd_gw.go b/cmd/friday/cmd_gw.go index c998364..0c6e373 100644 --- a/cmd/friday/cmd_gw.go +++ b/cmd/friday/cmd_gw.go @@ -109,6 +109,14 @@ func (r *GatewayRunner) initCronjob(ctx context.Context, cfg *config.Config, gw cronjob.Init(cfg.Cronjob, gw.Enqueue) + // Start (which calls store.Load) MUST happen before AddJob so that + // Load's map replacement does not wipe the freshly registered built-in + // jobs. AddJob's idempotent update logic handles the case where a stale + // entry was loaded from disk. + if err := cronjob.Start(ctx); err != nil { + return err + } + s := cronjob.Default() for id, agCfg := range cfg.Agents { hbJob := cronjob.NewHeartbeatJob(id, agCfg.Workspace, 0) @@ -127,7 +135,7 @@ func (r *GatewayRunner) initCronjob(ctx context.Context, cfg *config.Config, gw } } - return cronjob.Start(ctx) + return nil } func (r *GatewayRunner) initLogger(cfg config.LoggingConfig) error { diff --git a/config.yaml.example b/config.yaml.example index b4da671..efe716e 100644 --- a/config.yaml.example +++ b/config.yaml.example @@ -68,6 +68,10 @@ agents: consolidate_every: 50 # Minimum interval between mid-conversation flushes. flush_cooldown: "2h" + # Total context window size in tokens for compaction threshold. + context_budget: 128000 + # Tokens reserved for new user message and LLM response. + reserve_tokens: 20000 # Channel definitions. Key = channel ID. channels: diff --git a/internal/agent/agent.go b/internal/agent/agent.go index 48c16b9..4b22762 100644 --- a/internal/agent/agent.go +++ b/internal/agent/agent.go @@ -36,6 +36,11 @@ import ( "github.com/tgifai/friday/internal/provider" ) +const ( + defaultContextBudget = 128_000 // 128k + defaultReserveTokens = 20_000 +) + // EnqueueFunc is a callback to submit messages into the gateway pipeline. type EnqueueFunc func(ctx context.Context, msg *channel.Message) error @@ -52,6 +57,8 @@ type Agent struct { enqueue EnqueueFunc // allows agent to self-enqueue messages (set by gateway) consolidateEvery int flushCooldown time.Duration + contextBudget int + reserveTokens int toolsRegistered sync.Map // providerID → true; ensures RegisterTools is called once per provider } @@ -85,6 +92,15 @@ func NewAgent(_ context.Context, cfg config.AgentConfig) (*Agent, error) { flushCooldown = cd } + contextBudget := cfg.Session.ContextBudget + if contextBudget <= 0 { + contextBudget = defaultContextBudget + } + reserveTokens := cfg.Session.ReserveTokens + if reserveTokens <= 0 { + reserveTokens = defaultReserveTokens + } + ag := &Agent{ id: cfg.ID, name: cfg.Name, @@ -94,6 +110,8 @@ func NewAgent(_ context.Context, cfg config.AgentConfig) (*Agent, error) { skills: skill.NewRegistry(cfg.Workspace), consolidateEvery: consolidateEvery, flushCooldown: flushCooldown, + contextBudget: contextBudget, + reserveTokens: reserveTokens, } return ag, nil @@ -234,7 +252,12 @@ func (ag *Agent) ProcessMessage(ctx context.Context, msg *channel.Message) (*cha } // get or create current session - sess := ag.sessMgr.GetOrCreateFor(msg.ChannelType, msg.ChannelID, msg.ChatID) + var sess *session.Session + if msg.SessionKey != "" { + sess = ag.sessMgr.GetOrCreate(msg.SessionKey) + } else { + sess = ag.sessMgr.GetOrCreateFor(msg.ChannelType, msg.ChannelID, msg.ChatID) + } msg.SessionKey = sess.SessionKey defer func() { if err := ag.sessMgr.Save(sess); err != nil { @@ -285,6 +308,12 @@ func (ag *Agent) ProcessMessage(ctx context.Context, msg *channel.Message) (*cha // Check if session has crossed the consolidation threshold. ag.maybeEnqueueFlush(ctx, sess) + // Clear isolated cron sessions to prevent unbounded history growth. + // Isolated sessions use key prefix "cron:" (vs "agent:" for main/heartbeat). + if msg.ChannelType == channel.Type("cron") && strings.HasPrefix(msg.SessionKey, "cron:") { + sess.Clear() + } + return resp, nil } diff --git a/internal/agent/compact.go b/internal/agent/compact.go new file mode 100644 index 0000000..f840591 --- /dev/null +++ b/internal/agent/compact.go @@ -0,0 +1,251 @@ +package agent + +import ( + "context" + "strings" + + "github.com/bytedance/sonic" + "github.com/cloudwego/eino/schema" + "github.com/tgifai/friday/internal/consts" + + "github.com/tgifai/friday/internal/agent/session" + "github.com/tgifai/friday/internal/pkg/logs" + "github.com/tgifai/friday/internal/provider" +) + +const ( + preFlushMaxIterations = 3 + minKeepTurns = 2 + + flushSkipSentinel = "FLUSH_SKIP" +) + +// maybeCompact checks whether the prompt messages exceed the context budget +// and, if so, runs the compaction pipeline: pre-flush → summarize → compact. +// Returns the (possibly rebuilt) prompt messages. +func (ag *Agent) maybeCompact( + ctx context.Context, + p provider.Provider, + modelSpec *provider.ModelSpec, + sess *session.Session, + promptMsgs []*schema.Message, + userMsg *schema.Message, +) []*schema.Message { + threshold := ag.contextBudget - ag.reserveTokens + if threshold <= 0 { + return promptMsgs + } + + // Estimate without allocating a combined slice. + estimated := session.EstimateTokens(promptMsgs) + session.EstimateMessageTokens(userMsg) + if estimated <= threshold { + return promptMsgs + } + + logs.CtxInfo(ctx, "[agent:%s] compaction triggered: estimated %d tokens > threshold %d", + ag.id, estimated, threshold) + + // Step 1: Pre-flush — give LLM a chance to persist important info. + ag.runPreFlush(ctx, p, modelSpec, promptMsgs, userMsg) + + // Step 2: Calculate keepCount. + history := sess.History() + keepBudget := threshold / 2 + keepCount := calculateKeepCount(history, keepBudget) + + // Step 3: Generate summary of old messages. + oldMsgs := history + if keepCount < len(history) { + oldMsgs = history[:len(history)-keepCount] + } + + summary := ag.generateSummary(ctx, p, modelSpec, oldMsgs, threshold) + if summary == nil { + // Fallback: trim without summary. + logs.CtxWarn(ctx, "[agent:%s] summary generation failed, falling back to trim", ag.id) + summary = &schema.Message{ + Role: schema.Assistant, + Content: "[Earlier conversation history was trimmed due to context limits]", + } + } + + // Step 4: Compact the session. + sess.Compact(summary, keepCount) + logs.CtxInfo(ctx, "[agent:%s] compaction complete: kept %d messages, removed %d", + ag.id, keepCount, len(history)-keepCount) + + // Rebuild prompt messages with compacted history. + return ag.buildMessages(ctx, sess, nil, p.Type()) +} + +// runPreFlush runs a short agent loop allowing the LLM to persist important +// information before compaction. Messages from this turn are NOT saved to session. +func (ag *Agent) runPreFlush( + ctx context.Context, + p provider.Provider, + modelSpec *provider.ModelSpec, + promptMsgs []*schema.Message, + userMsg *schema.Message, +) { + flushMsgs := make([]*schema.Message, 0, len(promptMsgs)+2) + flushMsgs = append(flushMsgs, promptMsgs...) + flushMsgs = append(flushMsgs, userMsg) + flushMsgs = append(flushMsgs, &schema.Message{ + Role: schema.System, + Content: consts.PromptPreFlush, + }) + + for iter := 0; iter < preFlushMaxIterations; iter++ { + resp, err := p.Generate(ctx, modelSpec.ModelName, flushMsgs) + if err != nil { + logs.CtxWarn(ctx, "[agent:%s] pre-flush LLM call failed: %v", ag.id, err) + return + } + if resp == nil { + return + } + + // Check for skip sentinel. + if strings.Contains(resp.Content, flushSkipSentinel) { + logs.CtxDebug(ctx, "[agent:%s] pre-flush: LLM signaled FLUSH_SKIP", ag.id) + return + } + + // If LLM made tool calls, execute them. + if len(resp.ToolCalls) > 0 { + flushMsgs = append(flushMsgs, resp) + for _, call := range resp.ToolCalls { + callMsg := ag.buildToolResultMessage(ctx, &call) + flushMsgs = append(flushMsgs, callMsg) + } + continue + } + + // No tool calls, LLM is done. + return + } +} + +// buildToolResultMessage executes a tool call and returns the result as a Tool message. +// This is the shared helper used by both runLoop and runPreFlush. +func (ag *Agent) buildToolResultMessage(ctx context.Context, call *schema.ToolCall) *schema.Message { + res, callErr := ag.tools.ExecuteToolCall(ctx, call) + callMsg := &schema.Message{ + Role: schema.Tool, + ToolName: call.Function.Name, + ToolCallID: call.ID, + } + if callErr != nil { + callMsg.Content = "ERROR: " + callErr.Error() + } else { + jsonStr, marshalErr := sonic.MarshalString(res) + if marshalErr != nil || jsonStr == "" { + callMsg.Content = "{}" + } else { + callMsg.Content = jsonStr + } + } + return callMsg +} + +// generateSummary asks the LLM to summarize old messages. Returns nil on failure. +// Truncates oldMsgs to fit within tokenBudget to avoid exceeding the context window. +func (ag *Agent) generateSummary( + ctx context.Context, + p provider.Provider, + modelSpec *provider.ModelSpec, + oldMsgs []*schema.Message, + tokenBudget int, +) *schema.Message { + // Truncate oldMsgs to fit within the token budget so the summary call + // itself doesn't exceed the model's context window. + truncated := truncateToFit(oldMsgs, tokenBudget) + + summaryMsgs := make([]*schema.Message, 0, len(truncated)+1) + summaryMsgs = append(summaryMsgs, &schema.Message{ + Role: schema.System, + Content: consts.PromptSummary, + }) + summaryMsgs = append(summaryMsgs, truncated...) + + resp, err := p.Generate(ctx, modelSpec.ModelName, summaryMsgs) + if err != nil { + logs.CtxWarn(ctx, "[agent:%s] summary generation failed: %v", ag.id, err) + return nil + } + if resp == nil || strings.TrimSpace(resp.Content) == "" { + return nil + } + + return &schema.Message{ + Role: schema.Assistant, + Content: resp.Content, + } +} + +// truncateToFit returns the most recent messages from msgs that fit within +// the given token budget. Keeps messages from the tail (newest first). +func truncateToFit(msgs []*schema.Message, tokenBudget int) []*schema.Message { + total := session.EstimateTokens(msgs) + if total <= tokenBudget { + return msgs + } + // Walk from tail, accumulate until budget is exceeded. + used := 0 + start := len(msgs) + for i := len(msgs) - 1; i >= 0; i-- { + t := session.EstimateMessageTokens(msgs[i]) + if used+t > tokenBudget { + break + } + used += t + start = i + } + return msgs[start:] +} + +// calculateKeepCount determines how many recent messages to keep based on +// a token budget. Always keeps at least minKeepTurns complete turns. +func calculateKeepCount(messages []*schema.Message, tokenBudget int) int { + if len(messages) == 0 { + return 0 + } + + used := 0 + count := 0 + minKeep := findMinKeepForTurns(messages, minKeepTurns) + + for i := len(messages) - 1; i >= 0; i-- { + msgTokens := session.EstimateMessageTokens(messages[i]) + if used+msgTokens > tokenBudget && count >= minKeep { + break + } + used += msgTokens + count++ + } + + if count < minKeep { + count = minKeep + } + if count > len(messages) { + count = len(messages) + } + return count +} + +// findMinKeepForTurns returns the minimum number of messages from the tail +// needed to include at least n complete user→assistant turns. +func findMinKeepForTurns(messages []*schema.Message, n int) int { + turns := 0 + count := 0 + for i := len(messages) - 1; i >= 0; i-- { + count++ + if messages[i].Role == schema.User { + turns++ + if turns >= n { + return count + } + } + } + return count // all messages if fewer than n turns +} diff --git a/internal/agent/loop.go b/internal/agent/loop.go index ce670b9..0bf7314 100644 --- a/internal/agent/loop.go +++ b/internal/agent/loop.go @@ -3,6 +3,7 @@ package agent import ( "context" "fmt" + "strings" "time" "github.com/bytedance/sonic" @@ -26,11 +27,10 @@ func (ag *Agent) runLoop(ctx context.Context, p provider.Provider, modelSpec *pr // Inject session into context so CLI providers can access metadata. ctx = session.WithContext(ctx, sess) promptMsgs := ag.buildMessages(ctx, sess, msg, p.Type()) - - // Include user message in the prompt but defer session persistence - // until the loop completes successfully, preventing orphaned user - // messages when all models fail. userMsg := buildUserMessage(msg) + + // Check token budget and compact if needed. + promptMsgs = ag.maybeCompact(ctx, p, modelSpec, sess, promptMsgs, userMsg) promptMsgs = append(promptMsgs, userMsg) maxIterations := defaultMaxIterations @@ -70,22 +70,9 @@ func (ag *Agent) runLoop(ctx context.Context, p provider.Provider, modelSpec *pr msgs = append(msgs, llmResp) for _, call := range llmResp.ToolCalls { logs.CtxDebug(ctx, "[agent:%s:%d] call: %+v", ag.id, iter, call) - res, callErr := ag.tools.ExecuteToolCall(ctx, &call) - callMsg := &schema.Message{ - Role: schema.Tool, - ToolName: call.Function.Name, - ToolCallID: call.ID, - } - if callErr != nil { - logs.CtxWarn(ctx, "[agent:%s] tool %q (call_id=%s) failed: %v", ag.id, call.Function.Name, call.ID, callErr) - callMsg.Content = "ERROR: " + callErr.Error() - } else { - jsonStr, marshalErr := sonic.MarshalString(res) - if marshalErr != nil || jsonStr == "" { - callMsg.Content = "{}" - } else { - callMsg.Content = jsonStr - } + callMsg := ag.buildToolResultMessage(ctx, &call) + if callMsg.Content != "" && strings.HasPrefix(callMsg.Content, "ERROR: ") { + logs.CtxWarn(ctx, "[agent:%s] tool %q (call_id=%s) failed: %s", ag.id, call.Function.Name, call.ID, callMsg.Content) } msgs = append(msgs, callMsg) } diff --git a/internal/agent/session/compact.go b/internal/agent/session/compact.go new file mode 100644 index 0000000..01863d7 --- /dev/null +++ b/internal/agent/session/compact.go @@ -0,0 +1,48 @@ +package session + +import "github.com/cloudwego/eino/schema" + +const CompactionSummaryKey = "compaction_summary" + +// Compact replaces old messages with a summary, keeping the most recent +// keepCount messages. Only the in-memory view changes; the JSONL file +// retains full history for audit. +// +// If keepCount >= len(messages), the summary is prepended without removing +// any messages. +func (s *Session) Compact(summary *schema.Message, keepCount int) { + s.mu.Lock() + defer s.mu.Unlock() + + if summary.Extra == nil { + summary.Extra = make(map[string]any) + } + summary.Extra[CompactionSummaryKey] = true + + msgLen := len(s.messages) + if keepCount < 0 { + keepCount = 0 + } + if keepCount >= msgLen { + keepCount = msgLen + } + + // Keep only the tail. + kept := make([]*schema.Message, keepCount) + copy(kept, s.messages[msgLen-keepCount:]) + s.messages = kept + + s.summaryMsg = summary + s.compactVersion++ + + // Reset persisted state so next Save does a full rewrite. + s.persistedMsgLen = 0 + s.markMutationLocked() +} + +// HasSummary reports whether this session has an active compaction summary. +func (s *Session) HasSummary() bool { + s.mu.RLock() + defer s.mu.RUnlock() + return s.summaryMsg != nil +} diff --git a/internal/agent/session/compact_test.go b/internal/agent/session/compact_test.go new file mode 100644 index 0000000..4480cf5 --- /dev/null +++ b/internal/agent/session/compact_test.go @@ -0,0 +1,120 @@ +package session + +import ( + "testing" + + "github.com/cloudwego/eino/schema" +) + +func TestSessionCompact_Basic(t *testing.T) { + sess := &Session{ + messages: make([]*schema.Message, 0, 8), + } + for i := 0; i < 10; i++ { + sess.Append(&schema.Message{Role: schema.User, Content: "msg"}) + } + + summary := &schema.Message{Role: schema.Assistant, Content: "summary of first 7 messages"} + sess.Compact(summary, 3) // keep last 3 + + history := sess.History() + if len(history) != 4 { // 1 summary + 3 recent + t.Fatalf("History() len = %d, want 4", len(history)) + } + if history[0].Content != "summary of first 7 messages" { + t.Errorf("first message should be summary, got %q", history[0].Content) + } + if !sess.HasSummary() { + t.Error("HasSummary() should return true after Compact") + } +} + +func TestSessionCompact_MsgCountUnchanged(t *testing.T) { + sess := &Session{ + messages: make([]*schema.Message, 0, 8), + } + for i := 0; i < 5; i++ { + sess.Append(&schema.Message{Role: schema.User, Content: "msg"}) + } + countBefore := sess.MsgCount() + + summary := &schema.Message{Role: schema.Assistant, Content: "summary"} + sess.Compact(summary, 2) + + if sess.MsgCount() != countBefore { + t.Errorf("MsgCount changed from %d to %d after Compact", countBefore, sess.MsgCount()) + } +} + +func TestSessionCompact_Double(t *testing.T) { + sess := &Session{ + messages: make([]*schema.Message, 0, 8), + } + for i := 0; i < 10; i++ { + sess.Append(&schema.Message{Role: schema.User, Content: "msg"}) + } + + // First compaction + sess.Compact(&schema.Message{Role: schema.Assistant, Content: "summary1"}, 5) + + // Add more messages + for i := 0; i < 5; i++ { + sess.Append(&schema.Message{Role: schema.User, Content: "new msg"}) + } + + // Second compaction — old summary is part of "old" messages + sess.Compact(&schema.Message{Role: schema.Assistant, Content: "summary2"}, 3) + + history := sess.History() + if len(history) != 4 { // 1 summary + 3 recent + t.Fatalf("History() len = %d, want 4", len(history)) + } + if history[0].Content != "summary2" { + t.Errorf("first message should be latest summary, got %q", history[0].Content) + } + if sess.compactVersion != 2 { + t.Errorf("compactVersion = %d, want 2", sess.compactVersion) + } +} + +func TestSessionCompact_KeepCountExceedsMessages(t *testing.T) { + sess := &Session{ + messages: make([]*schema.Message, 0, 8), + } + sess.Append(&schema.Message{Role: schema.User, Content: "only one"}) + + summary := &schema.Message{Role: schema.Assistant, Content: "summary"} + sess.Compact(summary, 10) // keepCount > len(messages) + + history := sess.History() + // Should just prepend summary, keep all messages + if len(history) != 2 { // 1 summary + 1 original + t.Fatalf("History() len = %d, want 2", len(history)) + } +} + +func TestSessionClear_ResetsSummary(t *testing.T) { + sess := &Session{ + messages: make([]*schema.Message, 0, 8), + } + sess.Append(&schema.Message{Role: schema.User, Content: "msg"}) + sess.Compact(&schema.Message{Role: schema.Assistant, Content: "summary"}, 0) + + sess.Clear() + + if sess.HasSummary() { + t.Error("HasSummary() should return false after Clear") + } + if len(sess.History()) != 0 { + t.Errorf("History() should be empty after Clear, got %d", len(sess.History())) + } +} + +func TestSessionHasSummary_BeforeCompact(t *testing.T) { + sess := &Session{ + messages: make([]*schema.Message, 0, 8), + } + if sess.HasSummary() { + t.Error("HasSummary() should return false before any compaction") + } +} diff --git a/internal/agent/session/integration_test.go b/internal/agent/session/integration_test.go new file mode 100644 index 0000000..de54f09 --- /dev/null +++ b/internal/agent/session/integration_test.go @@ -0,0 +1,105 @@ +package session + +import ( + "context" + "path/filepath" + "testing" + + "github.com/cloudwego/eino/schema" +) + +// TestCompactionLifecycle tests the full lifecycle: +// create → populate → compact → save → load → verify → append → save → load → verify +func TestCompactionLifecycle(t *testing.T) { + dir := t.TempDir() + storePath := filepath.Join(dir, "sessions") + store, err := newJSONLStore(storePath) + if err != nil { + t.Fatalf("newJSONLStore: %v", err) + } + ctx := context.Background() + sessKey := "agent:test:telegram:main:lifecycle" + + // 1. Create and populate. + sess := &Session{ + SessionKey: sessKey, + AgentID: "test", + messages: make([]*schema.Message, 0, 16), + } + for i := 0; i < 20; i++ { + sess.Append(&schema.Message{Role: schema.User, Content: "question"}) + sess.Append(&schema.Message{Role: schema.Assistant, Content: "answer"}) + } + if err := store.Save(ctx, sess); err != nil { + t.Fatalf("Save: %v", err) + } + if sess.MsgCount() != 40 { + t.Fatalf("MsgCount = %d, want 40", sess.MsgCount()) + } + + // 2. Compact: keep last 6 messages. + summary := &schema.Message{Role: schema.Assistant, Content: "Summary: 17 Q&A rounds about various topics"} + sess.Compact(summary, 6) + if err := store.Save(ctx, sess); err != nil { + t.Fatalf("Save after compact: %v", err) + } + + // 3. Load and verify. + loaded, err := store.Load(ctx, sessKey) + if err != nil { + t.Fatalf("Load: %v", err) + } + history := loaded.History() + if len(history) != 7 { // 1 summary + 6 kept + t.Fatalf("loaded History() len = %d, want 7", len(history)) + } + if !loaded.HasSummary() { + t.Error("loaded session should HasSummary()") + } + + // 4. Append new messages after loading. + loaded.Append(&schema.Message{Role: schema.User, Content: "new question"}) + loaded.Append(&schema.Message{Role: schema.Assistant, Content: "new answer"}) + if err := store.Save(ctx, loaded); err != nil { + t.Fatalf("Save after append: %v", err) + } + + // 5. Load again and verify. + loaded2, err := store.Load(ctx, sessKey) + if err != nil { + t.Fatalf("Load2: %v", err) + } + history2 := loaded2.History() + if len(history2) != 9 { // 1 summary + 6 kept + 2 new + t.Fatalf("loaded2 History() len = %d, want 9", len(history2)) + } + if history2[len(history2)-1].Content != "new answer" { + t.Errorf("last message = %q, want 'new answer'", history2[len(history2)-1].Content) + } + + // 6. Verify MsgCount survived. + if loaded2.MsgCount() != 42 { // 40 original + 2 new + t.Errorf("MsgCount = %d, want 42", loaded2.MsgCount()) + } +} + +// TestEstimateTokens_ThresholdDetection verifies the estimation is reasonable +// for deciding when compaction should trigger. +func TestEstimateTokens_ThresholdDetection(t *testing.T) { + // Build a large message list that should clearly exceed 100K tokens. + msgs := make([]*schema.Message, 0, 1000) + // Each message: 500 chars of 'x' → 500 bytes → 125 tokens. 1000 messages = ~125K tokens. + longContent := make([]byte, 500) + for i := range longContent { + longContent[i] = 'x' + } + contentStr := string(longContent) + for i := 0; i < 1000; i++ { + msgs = append(msgs, &schema.Message{Role: schema.User, Content: contentStr}) + } + + estimated := EstimateTokens(msgs) + if estimated < 100000 { + t.Errorf("EstimateTokens = %d, expected > 100000 for 500KB of content", estimated) + } +} diff --git a/internal/agent/session/session.go b/internal/agent/session/session.go index 00c0dbc..4760b38 100644 --- a/internal/agent/session/session.go +++ b/internal/agent/session/session.go @@ -35,12 +35,21 @@ type Session struct { persistedMsgLen int appendSaveCnt int + summaryMsg *schema.Message // active compaction summary (nil = never compacted) + compactVersion int // number of compactions performed + mu sync.RWMutex } func (s *Session) History() []*schema.Message { s.mu.RLock() defer s.mu.RUnlock() + if s.summaryMsg != nil { + msgs := make([]*schema.Message, 0, 1+len(s.messages)) + msgs = append(msgs, s.summaryMsg) + msgs = append(msgs, s.messages...) + return msgs + } msgs := make([]*schema.Message, len(s.messages)) copy(msgs, s.messages) return msgs @@ -53,6 +62,8 @@ func (s *Session) Clear() { s.messages = s.messages[:0] s.msgCnt = 0 s.metadata = nil + s.summaryMsg = nil + s.compactVersion = 0 s.updateTime = time.Now() s.markMutationLocked() } diff --git a/internal/agent/session/store_jsonl.go b/internal/agent/session/store_jsonl.go index 7e0396f..d65fd8b 100644 --- a/internal/agent/session/store_jsonl.go +++ b/internal/agent/session/store_jsonl.go @@ -57,6 +57,13 @@ type jsonlMessageRecord struct { Message *schema.Message `json:"msg"` } +type jsonlCompactRecord struct { + Type string `json:"_type"` + At string `json:"at"` + Version int `json:"version"` + Summary string `json:"summary"` +} + func NewJSONLManager(agentID string, workspace string) (*Manager, error) { if workspace == "" { return nil, fmt.Errorf("workspace cannot be empty") @@ -107,9 +114,14 @@ func (s *jsonlStore) Load(ctx context.Context, sessionKey string) (*Session, err scanner := bufio.NewScanner(f) scanner.Buffer(make([]byte, 0, 64*1024), 4*1024*1024) + type compactState struct { + summary *schema.Message + version int + } var ( - meta *jsonlMetadataRecord - msgs = make([]*schema.Message, 0, 16) + meta *jsonlMetadataRecord + msgs = make([]*schema.Message, 0, 16) + compactMeta *compactState ) for scanner.Scan() { @@ -138,6 +150,21 @@ func (s *jsonlStore) Load(ctx context.Context, sessionKey string) (*Session, err if r.Message != nil { msgs = append(msgs, r.Message) } + case "compact": + var cr jsonlCompactRecord + if err := sonic.UnmarshalString(line, &cr); err != nil { + return nil, fmt.Errorf("parse compact record: %w", err) + } + summaryMsg := &schema.Message{ + Role: schema.Assistant, + Content: cr.Summary, + Extra: map[string]any{CompactionSummaryKey: true}, + } + compactMeta = &compactState{ + summary: summaryMsg, + version: cr.Version, + } + msgs = msgs[:0] // Reset — messages after compact record are the kept ones. default: // Ignore unknown record types for forward compatibility. } @@ -204,6 +231,11 @@ func (s *jsonlStore) Load(ctx context.Context, sessionKey string) (*Session, err } } + if compactMeta != nil { + sess.summaryMsg = compactMeta.summary + sess.compactVersion = compactMeta.version + } + return sess, nil } @@ -240,6 +272,8 @@ func (s *jsonlStore) Save(ctx context.Context, sess *Session) error { Schema: jsonlSchema, Metadata: sess.metadata, } + summaryMsg := sess.summaryMsg + compactVersion := sess.compactVersion sess.mu.RUnlock() if !dirty { @@ -280,7 +314,7 @@ func (s *jsonlStore) Save(ctx context.Context, sess *Session) error { } if needCompact { - if err := s.rewrite(path, metaLine, messages); err != nil { + if err := s.rewriteWithCompact(path, metaLine, messages, summaryMsg, compactVersion); err != nil { return err } s.markPersisted(sess, currentMsgLen, true, version) @@ -314,22 +348,57 @@ func (s *jsonlStore) markPersisted(sess *Session, msgLen int, compacted bool, ex } func (s *jsonlStore) rewrite(path string, metaLine string, messages []*schema.Message) error { + return s.rewriteWithCompact(path, metaLine, messages, nil, 0) +} + +func (s *jsonlStore) rewriteWithCompact(path string, metaLine string, messages []*schema.Message, summary *schema.Message, compactVersion int) error { tmpPath := path + ".tmp" out, err := os.Create(tmpPath) if err != nil { return fmt.Errorf("create temp session file: %w", err) } - defer func() { + cleanup := func() { _ = out.Close() - }() + _ = os.Remove(tmpPath) + } writer := bufio.NewWriter(out) - if err := writeJSONLBatch(writer, metaLine, messages); err != nil { - _ = os.Remove(tmpPath) - return err + + // Write compact record before messages if session has been compacted. + if summary != nil { + cr := jsonlCompactRecord{ + Type: "compact", + At: time.Now().Format(time.RFC3339), + Version: compactVersion, + Summary: summary.Content, + } + compactLine, marshalErr := sonic.MarshalString(cr) + if marshalErr != nil { + cleanup() + return fmt.Errorf("marshal compact record: %w", marshalErr) + } + // Write meta, compact record, then messages via writeJSONLBatch for messages only. + if _, writeErr := writer.WriteString(metaLine + "\n"); writeErr != nil { + cleanup() + return fmt.Errorf("write metadata line: %w", writeErr) + } + if _, writeErr := writer.WriteString(compactLine + "\n"); writeErr != nil { + cleanup() + return fmt.Errorf("write compact record: %w", writeErr) + } + if writeErr := writeMessages(writer, messages); writeErr != nil { + cleanup() + return writeErr + } + } else { + if writeErr := writeJSONLBatch(writer, metaLine, messages); writeErr != nil { + cleanup() + return writeErr + } } + if err := writer.Flush(); err != nil { - _ = os.Remove(tmpPath) + cleanup() return fmt.Errorf("flush session file: %w", err) } if err := out.Close(); err != nil { @@ -340,7 +409,6 @@ func (s *jsonlStore) rewrite(path string, metaLine string, messages []*schema.Me _ = os.Remove(tmpPath) return fmt.Errorf("replace session file: %w", err) } - return nil } @@ -365,6 +433,10 @@ func writeJSONLBatch(writer *bufio.Writer, metaLine string, messages []*schema.M if _, err := writer.WriteString(metaLine + "\n"); err != nil { return fmt.Errorf("write metadata line: %w", err) } + return writeMessages(writer, messages) +} + +func writeMessages(writer *bufio.Writer, messages []*schema.Message) error { for _, msg := range messages { if msg == nil { continue diff --git a/internal/agent/session/store_jsonl_test.go b/internal/agent/session/store_jsonl_test.go new file mode 100644 index 0000000..6a3a055 --- /dev/null +++ b/internal/agent/session/store_jsonl_test.go @@ -0,0 +1,139 @@ +package session + +import ( + "context" + "os" + "path/filepath" + "strings" + "testing" + + "github.com/cloudwego/eino/schema" +) + +func TestJSONLStore_CompactPersistence(t *testing.T) { + dir := t.TempDir() + storePath := filepath.Join(dir, "sessions") + store, err := newJSONLStore(storePath) + if err != nil { + t.Fatalf("newJSONLStore: %v", err) + } + ctx := context.Background() + + sess := &Session{ + SessionKey: "agent:test:telegram:main:user1", + AgentID: "test", + messages: make([]*schema.Message, 0, 8), + } + for i := 0; i < 6; i++ { + sess.Append(&schema.Message{Role: schema.User, Content: "msg"}) + sess.Append(&schema.Message{Role: schema.Assistant, Content: "reply"}) + } + if err := store.Save(ctx, sess); err != nil { + t.Fatalf("initial Save: %v", err) + } + + summary := &schema.Message{Role: schema.Assistant, Content: "summary of earlier conversation"} + sess.Compact(summary, 4) + if err := store.Save(ctx, sess); err != nil { + t.Fatalf("Save after Compact: %v", err) + } + + loaded, err := store.Load(ctx, "agent:test:telegram:main:user1") + if err != nil { + t.Fatalf("Load: %v", err) + } + if loaded == nil { + t.Fatal("Load returned nil") + } + if !loaded.HasSummary() { + t.Error("loaded session should HasSummary()") + } + history := loaded.History() + if len(history) != 5 { + t.Fatalf("loaded History() len = %d, want 5", len(history)) + } + if history[0].Content != "summary of earlier conversation" { + t.Errorf("first message should be summary, got %q", history[0].Content) + } +} + +func TestJSONLStore_CompactThenAppend(t *testing.T) { + dir := t.TempDir() + storePath := filepath.Join(dir, "sessions") + store, err := newJSONLStore(storePath) + if err != nil { + t.Fatalf("newJSONLStore: %v", err) + } + ctx := context.Background() + + sess := &Session{ + SessionKey: "agent:test:telegram:main:user2", + AgentID: "test", + messages: make([]*schema.Message, 0, 8), + } + for i := 0; i < 4; i++ { + sess.Append(&schema.Message{Role: schema.User, Content: "old"}) + } + if err := store.Save(ctx, sess); err != nil { + t.Fatalf("Save: %v", err) + } + + sess.Compact(&schema.Message{Role: schema.Assistant, Content: "summary"}, 2) + if err := store.Save(ctx, sess); err != nil { + t.Fatalf("Save after compact: %v", err) + } + + sess.Append(&schema.Message{Role: schema.User, Content: "new msg"}) + if err := store.Save(ctx, sess); err != nil { + t.Fatalf("Save after append: %v", err) + } + + loaded, err := store.Load(ctx, "agent:test:telegram:main:user2") + if err != nil { + t.Fatalf("Load: %v", err) + } + history := loaded.History() + if len(history) != 4 { + t.Fatalf("loaded History() len = %d, want 4", len(history)) + } + if history[len(history)-1].Content != "new msg" { + t.Errorf("last message = %q, want 'new msg'", history[len(history)-1].Content) + } +} + +func TestJSONLStore_FullHistory_Preserved(t *testing.T) { + dir := t.TempDir() + storePath := filepath.Join(dir, "sessions") + store, err := newJSONLStore(storePath) + if err != nil { + t.Fatalf("newJSONLStore: %v", err) + } + ctx := context.Background() + + sess := &Session{ + SessionKey: "agent:test:telegram:main:user3", + AgentID: "test", + messages: make([]*schema.Message, 0, 8), + } + for i := 0; i < 4; i++ { + sess.Append(&schema.Message{Role: schema.User, Content: "msg"}) + } + if err := store.Save(ctx, sess); err != nil { + t.Fatalf("Save: %v", err) + } + + sess.Compact(&schema.Message{Role: schema.Assistant, Content: "summary"}, 2) + if err := store.Save(ctx, sess); err != nil { + t.Fatalf("Save after compact: %v", err) + } + + js := store.(*jsonlStore) + path := js.sessionFile("agent:test:telegram:main:user3") + data, err := os.ReadFile(path) + if err != nil { + t.Fatalf("ReadFile: %v", err) + } + if !strings.Contains(string(data), `"_type":"compact"`) { + t.Error("JSONL should contain a compact record") + } +} diff --git a/internal/agent/session/token.go b/internal/agent/session/token.go new file mode 100644 index 0000000..26e4ac6 --- /dev/null +++ b/internal/agent/session/token.go @@ -0,0 +1,27 @@ +package session + +import "github.com/cloudwego/eino/schema" + +// EstimateMessageTokens returns a rough token count for a single message. +// Uses byte-length / 4 as a heuristic (English ~1:4, Chinese ~1:2). +// Precision is not required — this is used for threshold detection only. +func EstimateMessageTokens(msg *schema.Message) int { + if msg == nil { + return 0 + } + total := len(msg.Content) + for _, tc := range msg.ToolCalls { + total += len(tc.Function.Name) + total += len(tc.Function.Arguments) + } + return total / 4 +} + +// EstimateTokens returns a rough token count for the given messages. +func EstimateTokens(msgs []*schema.Message) int { + total := 0 + for _, msg := range msgs { + total += EstimateMessageTokens(msg) + } + return total +} diff --git a/internal/agent/session/token_test.go b/internal/agent/session/token_test.go new file mode 100644 index 0000000..ee5dc18 --- /dev/null +++ b/internal/agent/session/token_test.go @@ -0,0 +1,71 @@ +package session + +import ( + "testing" + + "github.com/cloudwego/eino/schema" +) + +func TestEstimateTokens_Empty(t *testing.T) { + if got := EstimateTokens(nil); got != 0 { + t.Errorf("EstimateTokens(nil) = %d, want 0", got) + } + if got := EstimateTokens([]*schema.Message{}); got != 0 { + t.Errorf("EstimateTokens([]) = %d, want 0", got) + } +} + +func TestEstimateTokens_TextOnly(t *testing.T) { + msgs := []*schema.Message{ + {Role: schema.User, Content: "Hello world"}, // 11 bytes → 11/4 = 2 tokens + {Role: schema.Assistant, Content: "Hi there!"}, // 9 bytes → 9/4 = 2 tokens + } + got := EstimateTokens(msgs) + // 11/4 + 9/4 = 2 + 2 = 4 (per-message integer division) + if got != 4 { + t.Errorf("EstimateTokens = %d, want 4", got) + } +} + +func TestEstimateMessageTokens(t *testing.T) { + if got := EstimateMessageTokens(nil); got != 0 { + t.Errorf("EstimateMessageTokens(nil) = %d, want 0", got) + } + msg := &schema.Message{Role: schema.User, Content: "Hello world"} // 11 bytes → 2 + if got := EstimateMessageTokens(msg); got != 2 { + t.Errorf("EstimateMessageTokens = %d, want 2", got) + } +} + +func TestEstimateTokens_WithToolCalls(t *testing.T) { + msgs := []*schema.Message{ + { + Role: schema.Assistant, + ToolCalls: []schema.ToolCall{ + { + ID: "call_1", + Function: schema.FunctionCall{ + Name: "file_read", + Arguments: `{"path":"/tmp/test.txt"}`, + }, + }, + }, + }, + } + got := EstimateTokens(msgs) + if got <= 0 { + t.Errorf("EstimateTokens with tool calls should be > 0, got %d", got) + } +} + +func TestEstimateTokens_Chinese(t *testing.T) { + // Chinese characters are multi-byte but we count len(string) which is bytes. + // "你好世界" = 12 bytes in UTF-8 → 12/4 = 3 tokens + msgs := []*schema.Message{ + {Role: schema.User, Content: "你好世界"}, + } + got := EstimateTokens(msgs) + if got != 3 { + t.Errorf("EstimateTokens(Chinese) = %d, want 3", got) + } +} diff --git a/internal/config/config.go b/internal/config/config.go index 74a04af..5f9d6ef 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -71,6 +71,8 @@ type ( TTL string `yaml:"ttl"` // e.g. "24h", session expiry after last activity ConsolidateEvery int `yaml:"consolidate_every"` // trigger memory flush every N messages (default: 50) FlushCooldown string `yaml:"flush_cooldown"` // minimum interval between flushes (default: "2h") + ContextBudget int `yaml:"context_budget"` // model context window (tokens), default 128000 + ReserveTokens int `yaml:"reserve_tokens"` // tokens reserved for new messages + reply, default 20000 } ChannelConfig struct { diff --git a/internal/consts/prompt.go b/internal/consts/prompt.go new file mode 100644 index 0000000..cd417b5 --- /dev/null +++ b/internal/consts/prompt.go @@ -0,0 +1,22 @@ +package consts + +const PromptPreFlush = `You are about to lose access to older parts of this conversation due to context window limits. Before that happens, please review the conversation history and persist any important information you want to remember: + +- Key decisions and their reasoning +- File paths and code changes made +- Unfinished tasks or pending items +- User preferences discovered in this session + +Use the file_write tool to save to memory/MEMORY.md (durable facts) or memory/daily/.md (today's events). + +If nothing important needs saving, respond with "FLUSH_SKIP".` + +const PromptSummary = `You are a helpful AI assistant tasked with summarizing conversations. Summarize the following conversation history concisely. Preserve: +- Key decisions and their reasoning +- Important file paths, function names, and code changes +- Task progress: what was completed, what remains +- Any errors encountered and how they were resolved +- User preferences and constraints mentioned + +Format as structured notes, not a narrative. Use bullet points. +Keep the summary under 2000 tokens.` diff --git a/internal/cronjob/store.go b/internal/cronjob/store.go index 1b145d6..0ad0191 100644 --- a/internal/cronjob/store.go +++ b/internal/cronjob/store.go @@ -52,7 +52,7 @@ func (s *Store) Load() error { // Heartbeat and compact jobs are always re-registered at startup // by the gateway with fresh runtime fields (Workspace, etc.). // Discard any that were accidentally persisted to avoid stale state. - if IsHeartbeatJob(j.ID) || IsCompactJob(j.ID) { + if IsHeartbeatJob(j.ID) || IsCompactJob(j.ID) || IsFlushJob(j.ID) { continue } s.jobs[j.ID] = j