Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 35 additions & 9 deletions internal/agent/budget.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,14 +66,31 @@ func (b *BudgetManager) Track(promptTokens, completionTokens int64) BudgetStatus

b.promptTokens += promptTokens
b.completionTokens += completionTokens

inputCost := float64(promptTokens) * b.pricing.InputPer1M / 1_000_000
outputCost := float64(completionTokens) * b.pricing.OutputPer1M / 1_000_000
b.totalCost += inputCost + outputCost
b.totalCost += b.costLocked(promptTokens, completionTokens, 0)

return b.statusLocked()
}

// costLocked returns the estimated USD cost of the given token counts.
//
// The prompt is billed in two parts: the cachedTokens subset at
// pricing.CacheReadPer1M and the remainder at pricing.InputPer1M. When no
// positive cache-read rate is configured, the cached subset is billed at the
// full input rate, so no discount is applied. completionTokens are billed at
// pricing.OutputPer1M. All rates are expressed per one million tokens.
//
// cachedTokens is raised to zero when negative and then capped at
// promptTokens. The method takes no lock itself; the caller must hold it.
func (b *BudgetManager) costLocked(promptTokens, completionTokens, cachedTokens int64) float64 {
	const tokensPerRateUnit = 1_000_000

	// Clamp the cached count so the uncached remainder can never go negative.
	cached := cachedTokens
	if cached < 0 {
		cached = 0
	}
	if cached > promptTokens {
		cached = promptTokens
	}
	uncached := promptTokens - cached

	// Fall back to the full input rate when there is no discount data.
	cachedRate := b.pricing.CacheReadPer1M
	if cachedRate <= 0 {
		cachedRate = b.pricing.InputPer1M
	}

	uncachedCost := float64(uncached) * b.pricing.InputPer1M / tokensPerRateUnit
	cachedCost := float64(cached) * cachedRate / tokensPerRateUnit
	completionCost := float64(completionTokens) * b.pricing.OutputPer1M / tokensPerRateUnit

	return uncachedCost + cachedCost + completionCost
}

// Check returns the current budget status and whether the budget has been exceeded.
func (b *BudgetManager) Check() (BudgetStatus, bool) {
b.mu.RLock()
Expand Down Expand Up @@ -150,17 +167,26 @@ func (m *budgetMiddleware) AfterModelRewriteState(
mc *adk.ModelContext,
) (context.Context, *adk.ChatModelAgentState, error) {
var promptTokens, completionTokens int64
var sessionPrompt, sessionCompletion, sessionCached int64
if m.tokenUsage != nil {
promptTokens, completionTokens, _ = m.tokenUsage.Get()
// Per-turn delta for the per-agent-turn TOKEN cap (max_tokens_per_turn):
// runner.BeginTurn sets the baseline at turn start. Reading cumulative
// Get() here made the "per turn" cap behave as a session total.
promptTokens, completionTokens, _ = m.tokenUsage.TurnUsage()
// Session-cumulative for the COST cap (max_cost_per_session): cost must
// accumulate across turns, not reset each turn.
full := m.tokenUsage.GetFull()
sessionPrompt = int64(full.PromptTokens)
sessionCompletion = int64(full.CompletionTokens)
sessionCached = int64(full.CachedTokens)
}

// Sync budget manager with per-agent token tracker values.
// promptTokens/completionTokens drive the per-turn token cap; totalCost is the
// session-cumulative cost (cached subset billed at the cache-read rate).
m.manager.mu.Lock()
m.manager.promptTokens = promptTokens
m.manager.completionTokens = completionTokens
inputCost := float64(promptTokens) * m.manager.pricing.InputPer1M / 1_000_000
outputCost := float64(completionTokens) * m.manager.pricing.OutputPer1M / 1_000_000
m.manager.totalCost = inputCost + outputCost
m.manager.totalCost = m.manager.costLocked(sessionPrompt, sessionCompletion, sessionCached)
m.manager.mu.Unlock()

status := m.manager.Status()
Expand Down
55 changes: 55 additions & 0 deletions internal/agent/budget_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
package agent

import (
"testing"

"github.com/cnjack/jcode/internal/config"
internalmodel "github.com/cnjack/jcode/internal/model"
)

// TestBudgetCost_CacheReadDiscount verifies the cached subset of the prompt is
// billed at the discounted cache-read rate when the registry has it, and at the
// full input rate otherwise (S3).
func TestBudgetCost_CacheReadDiscount(t *testing.T) {
bm := NewBudgetManager(nil, internalmodel.ModelPricing{InputPer1M: 10, OutputPer1M: 30, CacheReadPer1M: 1})
// 1000 prompt (800 cached) + 100 completion.
got := bm.costLocked(1000, 100, 800)
want := float64(200)*10/1e6 + float64(800)*1/1e6 + float64(100)*30/1e6
if got != want {
t.Errorf("discounted cost = %v, want %v", got, want)
}

// No cache pricing → cached billed at full input rate (no discount applied).
plain := NewBudgetManager(nil, internalmodel.ModelPricing{InputPer1M: 10, OutputPer1M: 30})
got = plain.costLocked(1000, 100, 800)
want = float64(1000)*10/1e6 + float64(100)*30/1e6
if got != want {
t.Errorf("no-discount cost = %v, want %v", got, want)
}

// cached clamped to prompt (never negative uncached portion).
if c := bm.costLocked(100, 0, 999); c < 0 {
t.Errorf("cached>prompt produced negative cost: %v", c)
}
}

// TestBudget_MaxTokensPerTurn checks that the per-turn token cap reacts to the
// token counts currently stored on the manager: a total below the cap must not
// report the budget as exceeded, a total above it must (C5: the middleware now
// feeds the manager the turn delta rather than the session cumulative).
func TestBudget_MaxTokensPerTurn(t *testing.T) {
	bm := NewBudgetManager(&config.BudgetConfig{MaxTokensPerTurn: 1000}, internalmodel.ModelPricing{})

	// Each step overwrites the tracked counts and states the expected verdict.
	steps := []struct {
		prompt, completion int64
		wantExceeded       bool
		failMsg            string
	}{
		{prompt: 600, completion: 300, wantExceeded: false, failMsg: "900 tokens should not exceed a 1000 per-turn cap"},
		{prompt: 600, completion: 500, wantExceeded: true, failMsg: "1100 tokens should exceed a 1000 per-turn cap"},
	}

	for _, step := range steps {
		// Write the counters directly under the manager's lock.
		bm.mu.Lock()
		bm.promptTokens, bm.completionTokens = step.prompt, step.completion
		bm.mu.Unlock()

		_, exceeded := bm.Check()
		if exceeded != step.wantExceeded {
			t.Error(step.failMsg)
		}
	}
}
9 changes: 6 additions & 3 deletions internal/agent/compaction.go
Original file line number Diff line number Diff line change
Expand Up @@ -179,11 +179,14 @@ func (m *compactionMiddleware) BeforeModelRewriteState(
state *adk.ChatModelAgentState,
mc *adk.ModelContext,
) (context.Context, *adk.ChatModelAgentState, error) {
// Estimate current token usage from the per-agent tracker.
// Estimate current context occupancy from the per-agent tracker. Use the LAST
// call's total (GetLastTotal), NOT the cumulative prompt sum: the agent
// re-sends the whole context on every tool-loop call, so cumulative prompt
// (e.g. 20k×5=100k) would trip compaction far too early while the real window
// is still ~20k.
var currentTokens int
if m.tokenUsage != nil {
promptTokens, _, _ := m.tokenUsage.Get()
currentTokens = int(promptTokens)
currentTokens = int(m.tokenUsage.GetLastTotal())
}

if !m.strategy.ShouldCompact(currentTokens, m.contextLimit) {
Expand Down
18 changes: 13 additions & 5 deletions internal/command/interactive.go
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,7 @@ func (s *interactiveState) createAgent() (*adk.ChatModelAgent, error) {
s.summCapture.Capture(summary.Content, contextN)
config.Logger().Printf("[summarization] Finalize: compacted %d context messages", contextN)
if s.agentTokenUsage != nil {
s.agentTokenUsage.Reset()
s.agentTokenUsage.ResetContext()
}
return append(systemMsgs, summary), nil
},
Expand Down Expand Up @@ -228,7 +228,8 @@ func (s *interactiveState) createAgent() (*adk.ChatModelAgent, error) {
if s.cfg.Budget != nil {
providerName, modelName := s.cfg.GetProviderModel()
inputPer1M, outputPer1M := s.registry.GetModelCost(providerName, modelName)
pricing := internalmodel.ModelPricing{InputPer1M: inputPer1M, OutputPer1M: outputPer1M}
cacheReadPer1M, _ := s.registry.GetModelCacheCost(providerName, modelName)
pricing := internalmodel.ModelPricing{InputPer1M: inputPer1M, OutputPer1M: outputPer1M, CacheReadPer1M: cacheReadPer1M}
budgetManager := agent.NewBudgetManager(s.cfg.Budget, pricing)
budgetMw := agent.NewBudgetMiddleware(budgetManager, s.agentTokenUsage, func(status agent.BudgetStatus) {
config.Logger().Printf("[budget] warning level=%d cost=%.4f", status.WarningLevel, status.EstimatedCost)
Expand All @@ -240,7 +241,7 @@ func (s *interactiveState) createAgent() (*adk.ChatModelAgent, error) {
compactionStrategy := agent.NewThresholdCompactionStrategy(compactThreshold, s.chatModel, 6)
compactionMw := agent.NewCompactionMiddleware(compactionStrategy, contextLimit, s.agentTokenUsage, func(savedTokens int) {
if s.agentTokenUsage != nil {
s.agentTokenUsage.Reset()
s.agentTokenUsage.ResetContext()
}
if s.p != nil {
s.p.Send(tui.CompactDoneMsg{OldTokens: 0, NewTokens: 0})
Expand Down Expand Up @@ -318,7 +319,7 @@ func (s *interactiveState) applyModeSwitch(newMode tui.AgentMode) {
config.Logger().Printf("[plan] agent creation failed: %v", err)
}
if s.agentTokenUsage != nil {
s.agentTokenUsage.Reset()
s.agentTokenUsage.ResetContext()
}
// Sync the TUI mode pill with the resulting unified mode (covers the
// plan-completion revert to Normal, which the user did not trigger directly).
Expand Down Expand Up @@ -592,6 +593,10 @@ func (s *interactiveState) handleConfig(cfgMsg *config.Config) {
return
}
s.chatModel = newChatModel
// Attribute subsequent usage to the newly selected model.
if s.rec != nil {
s.rec.SetModel(newModelName)
}

// Rebuild system prompt and tools to reflect config changes (e.g., SSH aliases)
if s.agentMode == tui.ModePlanning {
Expand Down Expand Up @@ -622,7 +627,7 @@ func (s *interactiveState) handleCompact() {
s.rec.RecordCompact(s.history[0].Content, oldLen-len(s.history))
}
if s.agentTokenUsage != nil {
s.agentTokenUsage.Reset()
s.agentTokenUsage.ResetContext()
}
s.p.Send(tui.CompactDoneMsg{
OldTokens: oldTokens,
Expand Down Expand Up @@ -665,6 +670,9 @@ func (s *interactiveState) handleAddModel() {
return
}
s.chatModel = newChatModel
if s.rec != nil {
s.rec.SetModel(newModelName)
}
if newAg, agErr := s.createAgent(); agErr == nil {
s.ag = newAg
}
Expand Down
Loading
Loading