diff --git a/internal/agent/budget.go b/internal/agent/budget.go index bc877cc..32d44dd 100644 --- a/internal/agent/budget.go +++ b/internal/agent/budget.go @@ -66,14 +66,31 @@ func (b *BudgetManager) Track(promptTokens, completionTokens int64) BudgetStatus b.promptTokens += promptTokens b.completionTokens += completionTokens - - inputCost := float64(promptTokens) * b.pricing.InputPer1M / 1_000_000 - outputCost := float64(completionTokens) * b.pricing.OutputPer1M / 1_000_000 - b.totalCost += inputCost + outputCost + b.totalCost += b.costLocked(promptTokens, completionTokens, 0) return b.statusLocked() } +// costLocked computes the USD cost of the given token counts, charging the +// cached (cache-read) subset of the prompt at the discounted CacheRead rate when +// the registry has it (else at the full input rate). Caller must hold the lock. +func (b *BudgetManager) costLocked(promptTokens, completionTokens, cachedTokens int64) float64 { + if cachedTokens < 0 { + cachedTokens = 0 + } + if cachedTokens > promptTokens { + cachedTokens = promptTokens + } + cacheRate := b.pricing.CacheReadPer1M + if cacheRate <= 0 { + cacheRate = b.pricing.InputPer1M // no discount data: bill cached at full price + } + inputCost := float64(promptTokens-cachedTokens)*b.pricing.InputPer1M/1_000_000 + + float64(cachedTokens)*cacheRate/1_000_000 + outputCost := float64(completionTokens) * b.pricing.OutputPer1M / 1_000_000 + return inputCost + outputCost +} + // Check returns the current budget status and whether the budget has been exceeded. func (b *BudgetManager) Check() (BudgetStatus, bool) { b.mu.RLock() @@ -150,17 +167,26 @@ func (m *budgetMiddleware) AfterModelRewriteState( mc *adk.ModelContext, ) (context.Context, *adk.ChatModelAgentState, error) { var promptTokens, completionTokens int64 + var sessionPrompt, sessionCompletion, sessionCached int64 if m.tokenUsage != nil { - promptTokens, completionTokens, _ = m.tokenUsage.Get() + // Per-turn delta for the per-agent-turn TOKEN cap (max_tokens_per_turn): + // runner.BeginTurn sets the baseline at turn start. Reading cumulative + // Get() here made the "per turn" cap behave as a session total. + promptTokens, completionTokens, _ = m.tokenUsage.TurnUsage() + // Session-cumulative for the COST cap (max_cost_per_session): cost must + // accumulate across turns, not reset each turn. + full := m.tokenUsage.GetFull() + sessionPrompt = int64(full.PromptTokens) + sessionCompletion = int64(full.CompletionTokens) + sessionCached = int64(full.CachedTokens) } - // Sync budget manager with per-agent token tracker values. + // promptTokens/completionTokens drive the per-turn token cap; totalCost is the + // session-cumulative cost (cached subset billed at the cache-read rate). m.manager.mu.Lock() m.manager.promptTokens = promptTokens m.manager.completionTokens = completionTokens - inputCost := float64(promptTokens) * m.manager.pricing.InputPer1M / 1_000_000 - outputCost := float64(completionTokens) * m.manager.pricing.OutputPer1M / 1_000_000 - m.manager.totalCost = inputCost + outputCost + m.manager.totalCost = m.manager.costLocked(sessionPrompt, sessionCompletion, sessionCached) m.manager.mu.Unlock() status := m.manager.Status() diff --git a/internal/agent/budget_test.go b/internal/agent/budget_test.go new file mode 100644 index 0000000..52e1baa --- /dev/null +++ b/internal/agent/budget_test.go @@ -0,0 +1,55 @@ +package agent + +import ( + "testing" + + "github.com/cnjack/jcode/internal/config" + internalmodel "github.com/cnjack/jcode/internal/model" +) + +// TestBudgetCost_CacheReadDiscount verifies the cached subset of the prompt is +// billed at the discounted cache-read rate when the registry has it, and at the +// full input rate otherwise (S3). +func TestBudgetCost_CacheReadDiscount(t *testing.T) { + bm := NewBudgetManager(nil, internalmodel.ModelPricing{InputPer1M: 10, OutputPer1M: 30, CacheReadPer1M: 1}) + // 1000 prompt (800 cached) + 100 completion. + got := bm.costLocked(1000, 100, 800) + want := float64(200)*10/1e6 + float64(800)*1/1e6 + float64(100)*30/1e6 + if got != want { + t.Errorf("discounted cost = %v, want %v", got, want) + } + + // No cache pricing → cached billed at full input rate (no discount applied). + plain := NewBudgetManager(nil, internalmodel.ModelPricing{InputPer1M: 10, OutputPer1M: 30}) + got = plain.costLocked(1000, 100, 800) + want = float64(1000)*10/1e6 + float64(100)*30/1e6 + if got != want { + t.Errorf("no-discount cost = %v, want %v", got, want) + } + + // cached clamped to prompt (never negative uncached portion). + if c := bm.costLocked(100, 0, 999); c < 0 { + t.Errorf("cached>prompt produced negative cost: %v", c) + } +} + +// TestBudget_MaxTokensPerTurn verifies the per-turn cap trips on the (per-turn) +// token total it is given, not before (C5: the middleware now feeds it the +// turn delta rather than the session cumulative). +func TestBudget_MaxTokensPerTurn(t *testing.T) { + bm := NewBudgetManager(&config.BudgetConfig{MaxTokensPerTurn: 1000}, internalmodel.ModelPricing{}) + + bm.mu.Lock() + bm.promptTokens, bm.completionTokens = 600, 300 // 900 < 1000 + bm.mu.Unlock() + if _, exceeded := bm.Check(); exceeded { + t.Error("900 tokens should not exceed a 1000 per-turn cap") + } + + bm.mu.Lock() + bm.completionTokens = 500 // 1100 >= 1000 + bm.mu.Unlock() + if _, exceeded := bm.Check(); !exceeded { + t.Error("1100 tokens should exceed a 1000 per-turn cap") + } +} diff --git a/internal/agent/compaction.go b/internal/agent/compaction.go index e93c51b..e751cc4 100644 --- a/internal/agent/compaction.go +++ b/internal/agent/compaction.go @@ -179,11 +179,14 @@ func (m *compactionMiddleware) BeforeModelRewriteState( state *adk.ChatModelAgentState, mc *adk.ModelContext, ) (context.Context, *adk.ChatModelAgentState, error) { - // Estimate current token usage from the per-agent tracker. + // Estimate current context occupancy from the per-agent tracker. Use the LAST + // call's total (GetLastTotal), NOT the cumulative prompt sum: the agent + // re-sends the whole context on every tool-loop call, so cumulative prompt + // (e.g. 20k×5=100k) would trip compaction far too early while the real window + // is still ~20k. var currentTokens int if m.tokenUsage != nil { - promptTokens, _, _ := m.tokenUsage.Get() - currentTokens = int(promptTokens) + currentTokens = int(m.tokenUsage.GetLastTotal()) } if !m.strategy.ShouldCompact(currentTokens, m.contextLimit) { diff --git a/internal/command/interactive.go b/internal/command/interactive.go index 12fadac..88ddea4 100644 --- a/internal/command/interactive.go +++ b/internal/command/interactive.go @@ -185,7 +185,7 @@ func (s *interactiveState) createAgent() (*adk.ChatModelAgent, error) { s.summCapture.Capture(summary.Content, contextN) config.Logger().Printf("[summarization] Finalize: compacted %d context messages", contextN) if s.agentTokenUsage != nil { - s.agentTokenUsage.Reset() + s.agentTokenUsage.ResetContext() } return append(systemMsgs, summary), nil }, @@ -228,7 +228,8 @@ func (s *interactiveState) createAgent() (*adk.ChatModelAgent, error) { if s.cfg.Budget != nil { providerName, modelName := s.cfg.GetProviderModel() inputPer1M, outputPer1M := s.registry.GetModelCost(providerName, modelName) - pricing := internalmodel.ModelPricing{InputPer1M: inputPer1M, OutputPer1M: outputPer1M} + cacheReadPer1M, _ := s.registry.GetModelCacheCost(providerName, modelName) + pricing := internalmodel.ModelPricing{InputPer1M: inputPer1M, OutputPer1M: outputPer1M, CacheReadPer1M: cacheReadPer1M} budgetManager := agent.NewBudgetManager(s.cfg.Budget, pricing) budgetMw := agent.NewBudgetMiddleware(budgetManager, s.agentTokenUsage, func(status agent.BudgetStatus) { config.Logger().Printf("[budget] warning level=%d cost=%.4f", status.WarningLevel, status.EstimatedCost) @@ -240,7 +241,7 @@ func (s *interactiveState) createAgent() (*adk.ChatModelAgent, error) { compactionStrategy := agent.NewThresholdCompactionStrategy(compactThreshold, s.chatModel, 6) compactionMw := agent.NewCompactionMiddleware(compactionStrategy, contextLimit, s.agentTokenUsage, func(savedTokens int) { if s.agentTokenUsage != nil { - s.agentTokenUsage.Reset() + s.agentTokenUsage.ResetContext() } if s.p != nil { s.p.Send(tui.CompactDoneMsg{OldTokens: 0, NewTokens: 0}) @@ -318,7 +319,7 @@ func (s *interactiveState) applyModeSwitch(newMode tui.AgentMode) { config.Logger().Printf("[plan] agent creation failed: %v", err) } if s.agentTokenUsage != nil { - s.agentTokenUsage.Reset() + s.agentTokenUsage.ResetContext() } // Sync the TUI mode pill with the resulting unified mode (covers the // plan-completion revert to Normal, which the user did not trigger directly). @@ -592,6 +593,10 @@ func (s *interactiveState) handleConfig(cfgMsg *config.Config) { return } s.chatModel = newChatModel + // Attribute subsequent usage to the newly selected model. + if s.rec != nil { + s.rec.SetModel(newModelName) + } // Rebuild system prompt and tools to reflect config changes (e.g., SSH aliases) if s.agentMode == tui.ModePlanning { @@ -622,7 +627,7 @@ func (s *interactiveState) handleCompact() { s.rec.RecordCompact(s.history[0].Content, oldLen-len(s.history)) } if s.agentTokenUsage != nil { - s.agentTokenUsage.Reset() + s.agentTokenUsage.ResetContext() } s.p.Send(tui.CompactDoneMsg{ OldTokens: oldTokens, @@ -665,6 +670,9 @@ func (s *interactiveState) handleAddModel() { return } s.chatModel = newChatModel + if s.rec != nil { + s.rec.SetModel(newModelName) + } if newAg, agErr := s.createAgent(); agErr == nil { s.ag = newAg } diff --git a/internal/command/web.go b/internal/command/web.go index 887abdc..ec12f2d 100644 --- a/internal/command/web.go +++ b/internal/command/web.go @@ -4,8 +4,11 @@ import ( "context" "encoding/json" "fmt" + "os" "os/signal" "path/filepath" + "sync" + "sync/atomic" "syscall" "time" @@ -21,6 +24,7 @@ import ( "github.com/cnjack/jcode/internal/channel/ble" "github.com/cnjack/jcode/internal/config" "github.com/cnjack/jcode/internal/handler" + "github.com/cnjack/jcode/internal/mode" internalmodel "github.com/cnjack/jcode/internal/model" weixin "github.com/cnjack/jcode/internal/pkg/weixin" "github.com/cnjack/jcode/internal/prompts" @@ -94,14 +98,10 @@ func runWebServer(port int, host string, openBrowser bool) error { pwd := util.GetWorkDir() platform := util.GetSystemInfo() - envInfo := util.CollectEnvInfo(pwd) skillLoader := skills.NewLoaderWithDisabled(cfg.DisabledSkills) skillLoader.ScanProjectSkills(pwd) - systemPrompt := prompts.GetSystemPrompt(platform, pwd, "local", envInfo, skillLoader.Descriptions()) - planPrompt := prompts.GetPlanSystemPrompt(platform, pwd, "local", envInfo) - var providerName, modelName string if !needsSetup { providerName, modelName = cfg.GetProviderModel() @@ -114,47 +114,32 @@ func runWebServer(port int, host string, openBrowser bool) error { registry := internalmodel.NewModelRegistryWithConfig(cfg) - env := tools.NewEnv(pwd, platform) - bgManager := tools.NewBackgroundManager(env) - rec, _ := session.NewRecorder(pwd, providerName, modelName) - - // Shared token tracker for usage display (goal status, reminders, token - // updates). - agentTokenUsage := &internalmodel.TokenUsage{} - - // Load MCP tools. mcpTools is reassigned by reloadMCPTools (below) so the - // agent picks up server add/edit/delete/login without a restart — the - // buildAllTools closure reads this variable on each agent rebuild. - var mcpTools []tool.BaseTool + // Load MCP tools. mcpToolsPtr is swapped atomically by reloadMCPTools so a new + // task (built concurrently by buildWebTask) always reads a consistent slice + // header without a data race on hot-reload. + var mcpToolsPtr atomic.Pointer[[]tool.BaseTool] var initialMCPStatuses []tools.MCPStatus if len(cfg.MCPServers) > 0 { - mcpTools, initialMCPStatuses = tools.LoadMCPTools(ctx, cfg.MCPServers) + mt, statuses := tools.LoadMCPTools(ctx, cfg.MCPServers) + mcpToolsPtr.Store(&mt) + initialMCPStatuses = statuses } - - // reloadMCPTools re-establishes connections from the given server map and - // swaps in the fresh tool set. The Server rebuilds the agent afterwards and - // uses the returned statuses for the management UI. reloadMCPTools := func(servers map[string]*config.MCPServer) ([]tools.MCPStatus, error) { nt, statuses := tools.LoadMCPTools(ctx, servers) - mcpTools = nt + mcpToolsPtr.Store(&nt) return statuses, nil } - planStore := tools.NewPlanStore() - startupMode := resolveStartupMode(cfg, false) - approvalState := runner.NewApprovalStateWithMode(pwd, startupMode) - // Create WebHandler early so subagent tool can emit events through it. - webHandler := handler.NewWebHandler() + // Langfuse tracer (shared across tasks). + var langfuseTracer *telemetry.LangfuseTracer + if cfg.Telemetry != nil && cfg.Telemetry.Langfuse != nil { + langfuseTracer = telemetry.NewLangfuseTracer(cfg.Telemetry.Langfuse) + } - // Wrap handler with NotifyingHandler for WeChat push notifications. - // Callbacks check wechatClient.State() before sending, so this is safe - // even when the channel is disabled or not yet configured. - var finalHandler handler.AgentEventHandler + // WeChat client + shared push notifiers (process-level, reused by every task). wechatClient := weixin.NewClient() - - // Auto-enable if credentials exist and channel.web_enabled is true. if cfg.Channel != nil && cfg.Channel.WebEnabled && wechatClient.State() == channel.StateDisabled { if err := wechatClient.Enable(); err != nil { config.Logger().Printf("[wechat] web auto-enable failed: %v", err) @@ -162,163 +147,48 @@ func runWebServer(port int, host string, openBrowser bool) error { config.Logger().Printf("[wechat] web auto-enabled") } } - - // Always wrap with NotifyingHandler — the user can enable via the UI toggle. - notifyingH := handler.NewNotifyingHandler(webHandler, 10*time.Second) - notifyingH.SetApprovalNotifier(func(toolName, toolArgs string) { - if wechatClient.State() == channel.StateEnabled { - if err := wechatClient.SendText(channel.ApprovalMessage(toolName, toolArgs, "Please check the web interface")); err != nil { - config.Logger().Printf("[wechat] failed to send approval notification: %v", err) - } - } - }) - notifyingH.SetDoneNotifier(func(summary string, err error) { - if wechatClient.State() == channel.StateEnabled { - if sendErr := wechatClient.SendText(channel.DoneMessage(summary, err)); sendErr != nil { - config.Logger().Printf("[wechat] failed to send done notification: %v", sendErr) - } - } - }) - - // Register WeChat as a notifier for working/idle status pushes. - notifyingH.AddNotifier(channel.NewChannelNotifier(wechatClient)) - - // Register BLE notifier if enabled (lazy connect — will auto-discover JCODE-* devices). + var sharedBLE *ble.Notifier if cfg.Channel != nil && cfg.Channel.BLEEnabled { - bleNotifier := ble.New() - notifyingH.AddNotifier(bleNotifier) + sharedBLE = ble.New() } - finalHandler = notifyingH - - // Langfuse tracer. - var langfuseTracer *telemetry.LangfuseTracer - if cfg.Telemetry != nil && cfg.Telemetry.Langfuse != nil { - langfuseTracer = telemetry.NewLangfuseTracer(cfg.Telemetry.Langfuse) - } - - buildAllTools := func(cm model.ToolCallingChatModel) []tool.BaseTool { - all := []tool.BaseTool{ - env.NewReadTool(), env.NewEditTool(), env.NewWriteTool(), - env.NewExecuteTool(bgManager), env.NewGrepTool(), - env.NewTodoWriteTool(), env.NewTodoReadTool(), - env.NewGoalSetTool(), env.NewGoalGetTool(), env.NewGoalUpdateTool(), - env.NewSwitchEnvTool(), - env.NewCheckBackgroundTool(bgManager), - env.NewSubagentTool(&tools.SubagentDeps{ - ChatModel: cm, - Recorder: rec, - Notifier: func(name, agentType string, done bool, result string, err error) { - webHandler.OnSubagentEvent(name, agentType, done, result, err) - }, - ProgressFn: func(agentName, event, toolName, detail string) { - webHandler.OnSubagentProgress(agentName, event, toolName, detail) - }, - }), - tools.NewAskUserTool(&tools.AskUserDeps{ - BatchRequestFn: webHandler.RequestAskUser, - }), - skills.NewLoadSkillTool(skillLoader), - } - return append(all, mcpTools...) - } - - // buildPlanTools mirrors the TUI/ACP read-only plan tool set: no edit/write, - // no background execute, no subagent/team. This is what makes web Plan mode a - // real read-only mode rather than just a prompt prefix. - buildPlanTools := func() []tool.BaseTool { - return []tool.BaseTool{ - env.NewReadTool(), - env.NewExecuteTool(nil), - env.NewGrepTool(), - env.NewTodoWriteTool(), env.NewTodoReadTool(), - tools.NewAskUserTool(&tools.AskUserDeps{ - BatchRequestFn: webHandler.RequestAskUser, - }), - } - } - - // makeAgent assembles the middleware stack and tools for a given chat model - // and plan flag, then builds the agent. Plan mode swaps to the read-only - // prompt + tool set; Approval/Full access share the full set (they differ only on - // the approval axis, carried by approvalState). This is the cheap per-mode - // assembly — it does NOT rebuild the chat model (mirrors ACP's makeAgent). - makeAgent := func(cm model.ToolCallingChatModel, ctxLimit int, planMode bool) (*adk.ChatModelAgent, error) { - var middlewares []adk.AgentMiddleware - if langfuseTracer != nil { - middlewares = append(middlewares, langfuseTracer.AgentMiddleware()) - } - - var handlers []adk.ChatModelAgentMiddleware - - compactThreshold := cfg.CompactionThreshold() - reductionThreshold := compactThreshold - 0.15 - if reductionThreshold < 0.1 { - reductionThreshold = compactThreshold * 0.8 - } - - summMw, err := summarization.New(ctx, &summarization.Config{ - Model: cm, - Trigger: &summarization.TriggerCondition{ - ContextTokens: int(float64(ctxLimit) * compactThreshold), - }, - TranscriptFilePath: filepath.Join(config.ConfigDir(), "transcript.txt"), + // makeNotifyingHandler wraps a fresh per-task WebHandler with the shared push + // notifiers (WeChat + BLE) so a backgrounded task can still surface + // approval/done/working notifications without stealing UI focus. + makeNotifyingHandler := func(wh *handler.WebHandler) *handler.NotifyingHandler { + nh := handler.NewNotifyingHandler(wh, 10*time.Second) + nh.SetApprovalNotifier(func(toolName, toolArgs string) { + if wechatClient.State() == channel.StateEnabled { + if err := wechatClient.SendText(channel.ApprovalMessage(toolName, toolArgs, "Please check the web interface")); err != nil { + config.Logger().Printf("[wechat] failed to send approval notification: %v", err) + } + } }) - if err == nil { - handlers = append(handlers, summMw) - } - - reductionBackend := &agent.LocalReductionBackend{RootDir: config.ConfigDir()} - reductionMw, err := reduction.New(ctx, &reduction.Config{ - Backend: reductionBackend, - RootDir: filepath.Join(config.ConfigDir(), "reduction"), - MaxLengthForTrunc: 50000, - MaxTokensForClear: int64(float64(ctxLimit) * reductionThreshold), - ReadFileToolName: "read", - ToolConfig: map[string]*reduction.ToolReductionConfig{ - "read": {SkipClear: true}, - }, + nh.SetDoneNotifier(func(summary string, err error) { + if wechatClient.State() == channel.StateEnabled { + if sendErr := wechatClient.SendText(channel.DoneMessage(summary, err)); sendErr != nil { + config.Logger().Printf("[wechat] failed to send done notification: %v", sendErr) + } + } }) - if err == nil { - handlers = append(handlers, reductionMw) - } - - reminderMw := agent.NewReminderMiddleware(agent.ReminderConfig{ - TodoStore: env.TodoStore, - GoalStore: env.GoalStore, - PlanStore: planStore, - EnvLabel: "local", - IsRemote: env.IsRemote(), - ContextLimit: ctxLimit, - }, agentTokenUsage) - handlers = append(handlers, reminderMw) - - prompt := systemPrompt - toolList := buildAllTools(cm) - if planMode { - prompt = planPrompt - toolList = buildPlanTools() + nh.AddNotifier(channel.NewChannelNotifier(wechatClient)) + if sharedBLE != nil { + nh.AddNotifier(sharedBLE) } - return agent.NewAgent(ctx, cm, toolList, prompt, approvalState.RequestApproval, middlewares, handlers) + return nh } - // currentCM / currentCtxLimit cache the live chat model so a mode switch can - // re-assemble the agent without re-resolving config or rebuilding the model. - // currentPlanMode preserves the tool/prompt axis across model switches. - var currentCM model.ToolCallingChatModel - var currentCtxLimit int - currentPlanMode := startupMode.IsPlan() - - createAgent := func(prov, mod string) (*adk.ChatModelAgent, error) { - // Resolve provider config. - // Reload config to pick up any new providers added via setup. + // newChatModel resolves a provider/model into a live chat model + context + // limit. Shared because it has no per-task state — each task gets its own + // model instance from it. + newChatModel := func(prov, mod string) (model.ToolCallingChatModel, int, error) { currentCfg, err := config.LoadConfig() if err != nil { - return nil, fmt.Errorf("config error: %w", err) + return nil, 0, fmt.Errorf("config error: %w", err) } provCfg := currentCfg.GetProviders()[prov] if provCfg == nil { - return nil, fmt.Errorf("provider %q not configured", prov) + return nil, 0, fmt.Errorf("provider %q not configured", prov) } bURL := provCfg.BaseURL if bURL == "" { @@ -328,188 +198,344 @@ func runWebServer(port int, host string, openBrowser bool) error { Model: mod, APIKey: provCfg.APIKey, BaseURL: bURL, }) if err != nil { - return nil, fmt.Errorf("create model %s/%s: %w", prov, mod, err) + return nil, 0, fmt.Errorf("create model %s/%s: %w", prov, mod, err) } - ctxLimit := internalmodel.ResolveContextLimit(registry, currentCfg, prov, mod) - - currentCM = cm - currentCtxLimit = ctxLimit - return makeAgent(cm, ctxLimit, currentPlanMode) + return cm, ctxLimit, nil } - // rebuildForMode re-assembles the agent for a mode change, reusing the live - // chat model when available (cheap) and only swapping prompt + tools. - rebuildForMode := func(planMode bool) (*adk.ChatModelAgent, error) { - currentPlanMode = planMode - if currentCM == nil { - return createAgent(providerName, modelName) + // buildWebTask is the per-task engine factory. It produces a fully ISOLATED + // set of run state — its own env, background manager, recorder, token tracker, + // approval state, plan store, and event handler — so concurrent tasks never + // share mutable execution state. exec != nil binds the task to a remote SSH + // target instead of a local pwd. taskID != "" resumes an existing session. + buildWebTask := func(taskID, taskPwd, modeStr string, exec *tools.SSHExecutor) (*web.EngineConfig, error) { + startMode := startupMode + if modeStr != "" { + startMode = mode.Parse(modeStr) } - return makeAgent(currentCM, currentCtxLimit, planMode) - } - var ag *adk.ChatModelAgent - var agentErr error - if !needsSetup { - ag, agentErr = createAgent(providerName, modelName) - if agentErr != nil { - return fmt.Errorf("error creating agent: %w", agentErr) + // Fresh execution environment for this task only. + tenv := tools.NewEnv(taskPwd, platform) + promptPlatform := platform + envLabel := "local" + projectKey := taskPwd + var taskEnvInfo *util.EnvInfo + // Per-task skill loader: project skills are scanned into THIS task's loader + // so concurrent tasks in different projects don't bleed each other's project + // skills into one shared accumulator. (The process-wide skillLoader stays + // for the path-agnostic slash/list/toggle management UI.) + taskLoader := skills.NewLoaderWithDisabled(cfg.DisabledSkills) + if exec != nil { + tenv.SetSSH(exec, taskPwd) + promptPlatform = exec.Platform() + envLabel = fmt.Sprintf("%s (pwd: %s)", exec.Label(), taskPwd) + projectKey = remote.ProjectLabel(exec, taskPwd) + } else { + taskLoader.ScanProjectSkills(taskPwd) + taskEnvInfo = util.CollectEnvInfo(taskPwd) } - } - switchProject := func(newPwd string) (*adk.ChatModelAgent, *session.Recorder, error) { - // 0. Close any live remote SSH connection we're switching away from. - if prev, ok := env.Exec.(*tools.SSHExecutor); ok { - defer func() { _ = prev.Close() }() + tbg := tools.NewBackgroundManager(tenv) + trec, _ := session.NewRecorder(projectKey, providerName, modelName) + if taskID != "" && trec != nil { + trec.SetUUID(taskID) + } + ttok := &internalmodel.TokenUsage{} + tplan := tools.NewPlanStore() + tappr := runner.NewApprovalStateWithMode(taskPwd, startMode) + twh := handler.NewWebHandler() + tnotify := makeNotifyingHandler(twh) + tappr.SetHandler(tnotify) + + // Wire THIS task's todo/goal stores to THIS task's recorder + handler, so + // todos persist on resume and goal changes reach the task's UI and session + // file. (Each engine is built by this factory, including the bootstrap, so + // this replaces the old single bootstrap-only wiring in NewServer.) + if tenv.TodoStore != nil { + tenv.TodoStore.OnUpdate = func(items []tools.TodoItem) { + if trec == nil || !trec.HasRecording() { + return + } + snap := make([]session.TodoSnapshotItem, len(items)) + for i, it := range items { + snap[i] = session.TodoSnapshotItem{ID: it.ID, Title: it.Title, Status: string(it.Status)} + } + trec.RecordTodoSnapshot(snap) + } + } + if tenv.GoalStore != nil { + tenv.GoalStore.OnUpdate = func(g *tools.Goal) { + if trec != nil && trec.HasRecording() { + tools.GoalRecorderHook(trec)(g) + } + twh.Emit("goal_update", g) + } } - // 1. Update env working directory (all tools share the same *Env). - env.ResetToLocal(newPwd, platform) - - // 2. Update approval state workpath. - approvalState.SetWorkpath(newPwd) - - // 3. Re-scan project skills from the new directory. - skillLoader.ScanProjectSkills(newPwd) - - // 4. Re-render system prompt with the new pwd context. - // Since createAgent closure captures the `systemPrompt` variable, - // updating it here means createAgent will use the new value. - envInfo = util.CollectEnvInfo(newPwd) - systemPrompt = prompts.GetSystemPrompt(platform, newPwd, "local", envInfo, skillLoader.Descriptions()) - - // 5. Update the outer pwd variable (captured by createAgent closure's env). - pwd = newPwd - - // 6. Close old recorder, create new one scoped to the new project. - if rec != nil { - rec.Close() + // Per-task system/plan prompts (rendered for this task's pwd). + skillDescs := taskLoader.Descriptions() + var systemPrompt, planPrompt string + if exec != nil { + systemPrompt = prompts.GetSystemPrompt(promptPlatform, taskPwd, envLabel, nil, skillDescs) + planPrompt = prompts.GetPlanSystemPrompt(promptPlatform, taskPwd, envLabel, nil) + } else { + systemPrompt = prompts.GetSystemPrompt(platform, taskPwd, "local", taskEnvInfo, skillDescs) + planPrompt = prompts.GetPlanSystemPrompt(platform, taskPwd, "local", taskEnvInfo) } - newRec, _ := session.NewRecorder(newPwd, providerName, modelName) - rec = newRec - // 7. Rebuild the agent with updated prompt. - newAg, err := createAgent(providerName, modelName) - if err != nil { - return nil, nil, err + buildAllTools := func(cm model.ToolCallingChatModel) []tool.BaseTool { + all := []tool.BaseTool{ + tenv.NewReadTool(), tenv.NewEditTool(), tenv.NewWriteTool(), + tenv.NewExecuteTool(tbg), tenv.NewGrepTool(), + tenv.NewTodoWriteTool(), tenv.NewTodoReadTool(), + tenv.NewGoalSetTool(), tenv.NewGoalGetTool(), tenv.NewGoalUpdateTool(), + tenv.NewSwitchEnvTool(), + tenv.NewCheckBackgroundTool(tbg), + tenv.NewSubagentTool(&tools.SubagentDeps{ + ChatModel: cm, + Recorder: trec, + Notifier: func(name, agentType string, done bool, result string, err error) { + twh.OnSubagentEvent(name, agentType, done, result, err) + }, + ProgressFn: func(agentName, event, toolName, detail string) { + twh.OnSubagentProgress(agentName, event, toolName, detail) + }, + }), + tools.NewAskUserTool(&tools.AskUserDeps{ + BatchRequestFn: twh.RequestAskUser, + }), + skills.NewLoadSkillTool(taskLoader), + } + if mt := mcpToolsPtr.Load(); mt != nil { + all = append(all, (*mt)...) + } + return all } - return newAg, newRec, nil - } + buildPlanTools := func() []tool.BaseTool { + return []tool.BaseTool{ + tenv.NewReadTool(), + tenv.NewExecuteTool(nil), + tenv.NewGrepTool(), + tenv.NewTodoWriteTool(), tenv.NewTodoReadTool(), + tools.NewAskUserTool(&tools.AskUserDeps{ + BatchRequestFn: twh.RequestAskUser, + }), + } + } - // switchToRemote mirrors switchProject but binds the shared env to a remote - // SSH executor instead of a local path. It reuses the SAME agent/recorder - // rebuild sequence so the agent, system prompt and session recorder stay - // consistent with the local switch path. - switchToRemote := func(executor *tools.SSHExecutor, remotePwd string) (*adk.ChatModelAgent, *session.Recorder, error) { - // 0. Close the previous live remote SSH connection (if switching - // remote→remote); switching from local has nothing to close. - if prev, ok := env.Exec.(*tools.SSHExecutor); ok && prev != executor { - defer func() { _ = prev.Close() }() + // Per-task compaction paths — transcript + reduction must be task-scoped or + // concurrent summarization across tasks would corrupt a shared file. + taskUUID := "task" + if trec != nil { + taskUUID = trec.UUID() } + transcriptPath := filepath.Join(config.ConfigDir(), "transcripts", taskUUID+".txt") + reductionRoot := filepath.Join(config.ConfigDir(), "reduction", taskUUID) + _ = os.MkdirAll(filepath.Dir(transcriptPath), 0o755) + _ = os.MkdirAll(reductionRoot, 0o755) + + makeAgent := func(cm model.ToolCallingChatModel, ctxLimit int, planMode bool) (*adk.ChatModelAgent, error) { + var middlewares []adk.AgentMiddleware //nolint:staticcheck // langfuseTracer.AgentMiddleware()/agent.NewAgent still use the deprecated type + if langfuseTracer != nil { + middlewares = append(middlewares, langfuseTracer.AgentMiddleware()) + } - // 1. Point the shared env at the remote SSH executor. - env.SetSSH(executor, remotePwd) - remotePlatform := executor.Platform() + var handlers []adk.ChatModelAgentMiddleware - // 2. Approval state now governs the remote working directory. - approvalState.SetWorkpath(remotePwd) + compactThreshold := cfg.CompactionThreshold() + reductionThreshold := compactThreshold - 0.15 + if reductionThreshold < 0.1 { + reductionThreshold = compactThreshold * 0.8 + } - // 3. Re-render the system prompt with the remote env label + platform. - // Project skills are scanned from the LOCAL fs, so keep the existing - // descriptions rather than rescanning against the remote path. - envLabel := fmt.Sprintf("%s (pwd: %s)", executor.Label(), remotePwd) - systemPrompt = prompts.GetSystemPrompt(remotePlatform, remotePwd, envLabel, nil, skillLoader.Descriptions()) + summMw, err := summarization.New(ctx, &summarization.Config{ + Model: cm, + Trigger: &summarization.TriggerCondition{ + ContextTokens: int(float64(ctxLimit) * compactThreshold), + }, + TranscriptFilePath: transcriptPath, + }) + if err == nil { + handlers = append(handlers, summMw) + } - // 4. Update the captured pwd (env already points at the remote target). - pwd = remotePwd + reductionBackend := &agent.LocalReductionBackend{RootDir: reductionRoot} + reductionMw, err := reduction.New(ctx, &reduction.Config{ + Backend: reductionBackend, + RootDir: reductionRoot, + MaxLengthForTrunc: 50000, + MaxTokensForClear: int64(float64(ctxLimit) * reductionThreshold), + ReadFileToolName: "read", + ToolConfig: map[string]*reduction.ToolReductionConfig{ + "read": {SkipClear: true}, + }, + }) + if err == nil { + handlers = append(handlers, reductionMw) + } - // 5. Recorder scoped to a host-qualified project key so a remote path - // does not collide with a local path of the same name in the tree. - projectKey := remote.ProjectLabel(executor, remotePwd) - if rec != nil { - rec.Close() + reminderMw := agent.NewReminderMiddleware(agent.ReminderConfig{ + TodoStore: tenv.TodoStore, + GoalStore: tenv.GoalStore, + PlanStore: tplan, + EnvLabel: "local", + IsRemote: tenv.IsRemote(), + ContextLimit: ctxLimit, + }, ttok) + handlers = append(handlers, reminderMw) + + prompt := systemPrompt + toolList := buildAllTools(cm) + if planMode { + prompt = planPrompt + toolList = buildPlanTools() + } + return agent.NewAgent(ctx, cm, toolList, prompt, tappr.RequestApproval, middlewares, handlers) } - newRec, _ := session.NewRecorder(projectKey, providerName, modelName) - rec = newRec - // 6. Rebuild the agent with the updated remote prompt. - newAg, err := createAgent(providerName, modelName) - if err != nil { - return nil, nil, err + // Per-task chat-model cache so a model/mode switch rebuilds only this task. + // cmMu serializes the cache against breakdownFn (a GET handler) reading it. + var cmMu sync.Mutex + var currentCM model.ToolCallingChatModel + var currentCtxLimit int + currentPlanMode := startMode.IsPlan() + + createAgent := func(prov, mod string) (*adk.ChatModelAgent, error) { + cm, ctxLimit, err := newChatModel(prov, mod) + if err != nil { + return nil, err + } + cmMu.Lock() + plan := currentPlanMode + cmMu.Unlock() + ag, err := makeAgent(cm, ctxLimit, plan) + if err != nil { + return nil, err // don't poison the cache with a model whose agent failed to build + } + cmMu.Lock() + currentCM = cm + currentCtxLimit = ctxLimit + cmMu.Unlock() + return ag, nil } - return newAg, newRec, nil - } - // breakdownFn estimates how the live agent's context window is partitioned - // across system prompt / built-in tools / MCP tools / skills. It reads the - // captured assembly variables (systemPrompt, mcpTools, currentCM, skillLoader) - // by reference, so project switches and MCP reloads are reflected without any - // cache to invalidate. Built-in tools = all tools minus MCP tools. - breakdownFn := func() usage.ContextBreakdown { - var b usage.ContextBreakdown - skillDesc := skillLoader.Descriptions() - b.SkillsTokens = usage.Estimate(skillDesc) - // Skills are injected into the system prompt, so subtract to avoid - // double-counting them in the system-prompt bucket. - b.SystemPromptTokens = usage.Estimate(systemPrompt) - b.SkillsTokens - if b.SystemPromptTokens < 0 { - b.SystemPromptTokens = 0 - } - for _, mt := range mcpTools { - b.MCPToolsTokens += estimateToolTokens(ctx, mt) + rebuildForMode := func(planMode bool) (*adk.ChatModelAgent, error) { + cmMu.Lock() + currentPlanMode = planMode + cm, ctxLimit := currentCM, currentCtxLimit + cmMu.Unlock() + if cm == nil { + return createAgent(providerName, modelName) + } + return makeAgent(cm, ctxLimit, planMode) } - if currentCM != nil { - total := 0 - for _, at := range buildAllTools(currentCM) { - total += estimateToolTokens(ctx, at) + + breakdownFn := func() usage.ContextBreakdown { + var b usage.ContextBreakdown + skillDesc := taskLoader.Descriptions() + b.SkillsTokens = usage.Estimate(skillDesc) + b.SystemPromptTokens = usage.Estimate(systemPrompt) - b.SkillsTokens + if b.SystemPromptTokens < 0 { + b.SystemPromptTokens = 0 } - b.SystemToolsTokens = total - b.MCPToolsTokens - if b.SystemToolsTokens < 0 { - b.SystemToolsTokens = 0 + if mt := mcpToolsPtr.Load(); mt != nil { + for _, t := range *mt { + b.MCPToolsTokens += estimateToolTokens(ctx, t) + } + } + cmMu.Lock() + cm := currentCM + cmMu.Unlock() + if cm != nil { + total := 0 + for _, at := range buildAllTools(cm) { + total += estimateToolTokens(ctx, at) + } + b.SystemToolsTokens = total - b.MCPToolsTokens + if b.SystemToolsTokens < 0 { + b.SystemToolsTokens = 0 + } + } + return b + } + + var ag *adk.ChatModelAgent + if !needsSetup { + var err error + ag, err = createAgent(providerName, modelName) + if err != nil { + return nil, fmt.Errorf("error creating agent: %w", err) } } - return b + + return &web.EngineConfig{ + TaskID: taskID, + Pwd: taskPwd, + Mode: startMode.String(), + ProviderName: providerName, + ModelName: modelName, + Agent: ag, + Env: tenv, + TodoStore: tenv.TodoStore, + Recorder: trec, + TokenUsage: ttok, + ApprovalState: tappr, + Handler: twh, + EventHandler: tnotify, + BreakdownFn: breakdownFn, + CreateAgent: createAgent, + RebuildForMode: rebuildForMode, + }, nil + } + + // Bootstrap engine for the initial task. + bootEC, err := buildWebTask("", pwd, startupMode.String(), nil) + if err != nil { + return err } + bootNotifying, _ := bootEC.EventHandler.(*handler.NotifyingHandler) srv := web.NewServer(&web.ServerConfig{ - Port: port, - Host: host, - OpenBrowser: openBrowser, - Pwd: pwd, - Version: Version, - Agent: ag, - CreateAgent: createAgent, - RebuildForMode: rebuildForMode, + Port: port, + Host: host, + OpenBrowser: openBrowser, + Pwd: pwd, + Version: Version, + Agent: bootEC.Agent, + CreateAgent: bootEC.CreateAgent, + RebuildForMode: bootEC.RebuildForMode, + NewEngine: func(taskID, taskPwd, modeStr string) (*web.EngineConfig, error) { + return buildWebTask(taskID, taskPwd, modeStr, nil) + }, + NewRemoteEngine: func(taskID string, exec *tools.SSHExecutor, remotePwd, modeStr string) (*web.EngineConfig, error) { + return buildWebTask(taskID, remotePwd, modeStr, exec) + }, InitialMode: startupMode.String(), - SwitchProject: switchProject, - SwitchToRemote: switchToRemote, - TodoStore: env.TodoStore, - Recorder: rec, + TodoStore: bootEC.TodoStore, + Recorder: bootEC.Recorder, Tracer: langfuseTracer, - Env: env, + Env: bootEC.Env, ProviderName: providerName, ModelName: modelName, Config: cfg, Registry: registry, - ApprovalState: approvalState, + ApprovalState: bootEC.ApprovalState, SkillLoader: skillLoader, ReloadMCP: reloadMCPTools, InitialMCPStatuses: initialMCPStatuses, WechatClient: wechatClient, - WebHandler: webHandler, - EventHandler: finalHandler, + WebHandler: bootEC.Handler, + EventHandler: bootEC.EventHandler, NeedsSetup: needsSetup, - TokenUsage: agentTokenUsage, - ContextBreakdownFn: breakdownFn, + TokenUsage: bootEC.TokenUsage, + ContextBreakdownFn: bootEC.BreakdownFn, }) - // Set handler for approval routing. - // If WeChat channel wraps the handler, use the wrapping handler for notifications. - approvalState.SetHandler(finalHandler) - - // Set up inbound WeChat message handler now that srv exists. - // Always register regardless of WebEnabled — the user can enable via the UI. + // Set up inbound WeChat message handler now that srv exists. Always register + // regardless of WebEnabled — the user can enable via the UI. Inbound messages + // target the active task (no task_id channel). wechatClient.SetOnMessage(func(from, text string) { if wechatClient.State() != channel.StateEnabled { return // channel disabled, silently ignore @@ -521,11 +547,11 @@ func runWebServer(port int, host string, openBrowser bool) error { } }) - // Clean up WeChat on shutdown. + // Clean up WeChat + shared notifiers on shutdown. defer func() { - // Close all notifiers (BLE, etc.) - notifyingH.CloseNotifiers() - + if bootNotifying != nil { + bootNotifying.CloseNotifiers() + } if wechatClient.State() == channel.StateEnabled { // Best-effort, don't block shutdown go func() { _ = wechatClient.SendText(channel.GoodbyeMessage(time.Now())) }() @@ -538,13 +564,9 @@ func runWebServer(port int, host string, openBrowser bool) error { return fmt.Errorf("server error: %w", err) } - if rec != nil { - rec.Close() - } + srv.CloseAllEngines() if langfuseTracer != nil { langfuseTracer.Flush() } return nil } - -// Ensure unused imports are used (some may be used only indirectly). diff --git a/internal/model/chatmodel.go b/internal/model/chatmodel.go index 1f99e40..210a47d 100644 --- a/internal/model/chatmodel.go +++ b/internal/model/chatmodel.go @@ -38,8 +38,19 @@ type TokenUsage struct { lastCached int64 lastReasoning int64 lastCacheWrite int64 - byModel map[string]int64 - mu sync.RWMutex + // cacheSeen is set (sticky) once the provider returns a prompt_tokens_details + // object, so CacheObserved can report "caching supported" even on a 0-hit + // turn. Cleared by Reset (a session boundary), never by ResetContext. + cacheSeen int64 + // turnBase* snapshot the cumulative counters at the start of an agent turn + // (BeginTurn) so per-turn budgets measure THIS turn's consumption, not the + // whole session's. ResetContext deliberately leaves these untouched so a + // mid-turn compaction does not corrupt the per-turn measurement. + turnBasePrompt int64 + turnBaseCompletion int64 + turnBaseCached int64 + byModel map[string]int64 + mu sync.RWMutex } // AddParams carries one API call's token usage. Using a struct keeps the @@ -51,6 +62,11 @@ type AddParams struct { Cached int Reasoning int CacheWrite int + // CacheDetailsPresent is true when the provider returned a + // prompt_tokens_details object at all (even with cached_tokens:0), letting + // CacheObserved tell "supports caching, 0 hits" apart from "never reports + // caching". See https://platform.openai.com/docs/guides/prompt-caching. + CacheDetailsPresent bool } // TokenUsageDetail holds a token usage snapshot for tracing/observability and @@ -99,6 +115,9 @@ func (t *TokenUsage) Add(p AddParams) { atomic.StoreInt64(&t.lastCached, int64(p.Cached)) atomic.StoreInt64(&t.lastReasoning, int64(p.Reasoning)) atomic.StoreInt64(&t.lastCacheWrite, int64(p.CacheWrite)) + if p.CacheDetailsPresent || p.Cached > 0 { + atomic.StoreInt64(&t.cacheSeen, 1) + } } // Get returns the current token usage @@ -158,10 +177,14 @@ func (t *TokenUsage) CacheHitRate() float64 { } } -// CacheObserved reports whether any cache-read tokens have been seen, used to -// distinguish "cache hit rate is 0%" from "this provider never reports caching". +// CacheObserved reports whether the provider has reported cache details (a +// prompt_tokens_details object) — used to distinguish "cache hit rate is 0%" +// from "this provider never reports caching". It is true on the first turn that +// carries cache details even when cached_tokens is 0, and stays true for the +// session (cleared only by Reset). The CachedTokens>0 fallback keeps it correct +// for older snapshots recorded before the presence flag existed. func (t *TokenUsage) CacheObserved() bool { - return atomic.LoadInt64(&t.CachedTokens) > 0 + return atomic.LoadInt64(&t.cacheSeen) > 0 || atomic.LoadInt64(&t.CachedTokens) > 0 } // Reset resets the token tracker @@ -179,11 +202,59 @@ func (t *TokenUsage) Reset() { atomic.StoreInt64(&t.lastCached, 0) atomic.StoreInt64(&t.lastReasoning, 0) atomic.StoreInt64(&t.lastCacheWrite, 0) + atomic.StoreInt64(&t.cacheSeen, 0) + atomic.StoreInt64(&t.turnBasePrompt, 0) + atomic.StoreInt64(&t.turnBaseCompletion, 0) + atomic.StoreInt64(&t.turnBaseCached, 0) t.mu.Lock() t.byModel = nil t.mu.Unlock() } +// ResetContext clears only the "current context occupancy" snapshot (the last +// API call's per-call values), leaving the cumulative consumption ledger, the +// cache-support flag, the per-model breakdown, and the per-turn baseline +// intact. Call this after a compaction/summarization shrinks the live context: +// the context indicator should reflect the smaller window, but the session's +// accumulated spend must NOT be lost — it feeds budgets, the usage log, and +// cross-session stats. (Full Reset is for a genuine session boundary.) +func (t *TokenUsage) ResetContext() { + atomic.StoreInt64(&t.LastTotalTokens, 0) + atomic.StoreInt64(&t.lastPrompt, 0) + atomic.StoreInt64(&t.lastCompletion, 0) + atomic.StoreInt64(&t.lastCached, 0) + atomic.StoreInt64(&t.lastReasoning, 0) + atomic.StoreInt64(&t.lastCacheWrite, 0) +} + +// BeginTurn snapshots the cumulative counters as the baseline for the current +// agent turn so TurnUsage reports only this turn's delta. Called at the start of +// every runner turn. +func (t *TokenUsage) BeginTurn() { + atomic.StoreInt64(&t.turnBasePrompt, atomic.LoadInt64(&t.PromptTokens)) + atomic.StoreInt64(&t.turnBaseCompletion, atomic.LoadInt64(&t.CompletionTokens)) + atomic.StoreInt64(&t.turnBaseCached, atomic.LoadInt64(&t.CachedTokens)) +} + +// TurnUsage returns this turn's consumption (cumulative minus the BeginTurn +// baseline). Each value is clamped at 0 so a mid-turn Reset (which zeroes the +// cumulative and the baseline together) can never yield a negative delta. +func (t *TokenUsage) TurnUsage() (prompt, completion, cached int64) { + prompt = atomic.LoadInt64(&t.PromptTokens) - atomic.LoadInt64(&t.turnBasePrompt) + completion = atomic.LoadInt64(&t.CompletionTokens) - atomic.LoadInt64(&t.turnBaseCompletion) + cached = atomic.LoadInt64(&t.CachedTokens) - atomic.LoadInt64(&t.turnBaseCached) + if prompt < 0 { + prompt = 0 + } + if completion < 0 { + completion = 0 + } + if cached < 0 { + cached = 0 + } + return +} + // AddByModel adds token usage attributed to a specific model name. func (t *TokenUsage) AddByModel(model string, prompt, completion, total int) { if model == "" { @@ -213,8 +284,9 @@ func (t *TokenUsage) GetByModel() map[string]int64 { // ModelPricing contains cost information for a model. type ModelPricing struct { - InputPer1M float64 // cost per 1M input tokens - OutputPer1M float64 // cost per 1M output tokens + InputPer1M float64 // cost per 1M input tokens + OutputPer1M float64 // cost per 1M output tokens + CacheReadPer1M float64 // cost per 1M cache-read (cached input) tokens; 0 ⇒ no discount data, fall back to InputPer1M } // ModelInfo contains information about a model @@ -288,6 +360,7 @@ func extractUsage(u openai.Usage) AddParams { } if u.PromptTokensDetails != nil { p.Cached = u.PromptTokensDetails.CachedTokens + p.CacheDetailsPresent = true } if u.CompletionTokensDetails != nil { p.Reasoning = u.CompletionTokensDetails.ReasoningTokens diff --git a/internal/model/registry.go b/internal/model/registry.go index bd3a992..957ade4 100644 --- a/internal/model/registry.go +++ b/internal/model/registry.go @@ -239,6 +239,16 @@ func (r *ModelRegistry) GetModelCost(providerID, modelID string) (inputPer1M, ou return m.Cost.Input, m.Cost.Output } +// GetModelCacheCost returns the cache-read and cache-write prices (USD per 1M +// tokens) for a model, or 0 when the registry has no cache pricing for it. +func (r *ModelRegistry) GetModelCacheCost(providerID, modelID string) (cacheReadPer1M, cacheWritePer1M float64) { + _, m, ok := r.LookupModel(providerID, modelID) + if !ok || m == nil || m.Cost == nil { + return 0, 0 + } + return m.Cost.CacheRead, m.Cost.CacheWrite +} + // GetProviderAPI returns the API base URL for a provider from the registry. func (r *ModelRegistry) GetProviderAPI(providerID string) string { prov := r.GetProvider(providerID) diff --git a/internal/model/token_usage_test.go b/internal/model/token_usage_test.go index a411309..42fbe90 100644 --- a/internal/model/token_usage_test.go +++ b/internal/model/token_usage_test.go @@ -70,6 +70,68 @@ func TestTokenUsageDetail_Minus(t *testing.T) { } } +func TestTokenUsage_ResetContext_PreservesLedger(t *testing.T) { + u := &TokenUsage{} + u.Add(AddParams{Prompt: 1000, Completion: 200, Total: 1200, Cached: 800, CacheDetailsPresent: true}) + u.Add(AddParams{Prompt: 500, Completion: 100, Total: 600}) + u.AddByModel("m", 1500, 300, 1800) + + u.ResetContext() + + // Cumulative ledger must survive a context reset (the compaction case). + if got := u.GetFull(); got.PromptTokens != 1500 || got.CompletionTokens != 300 || got.TotalTokens != 1800 || got.CallCount != 2 { + t.Errorf("ResetContext wiped the cumulative ledger: %+v", got) + } + if !u.CacheObserved() { + t.Errorf("ResetContext should not clear the cache-support flag") + } + if u.GetByModel() == nil { + t.Errorf("ResetContext should not clear the per-model breakdown") + } + // Only the current-occupancy snapshot is cleared. + if got := u.GetLastTotal(); got != 0 { + t.Errorf("GetLastTotal after ResetContext = %d, want 0", got) + } + if got := u.GetLastDetail(); got.PromptTokens != 0 || got.CompletionTokens != 0 { + t.Errorf("GetLastDetail after ResetContext = %+v, want zero", got) + } +} + +func TestTokenUsage_TurnUsage(t *testing.T) { + u := &TokenUsage{} + u.Add(AddParams{Prompt: 1000, Completion: 200, Cached: 100}) + + u.BeginTurn() // baseline at turn start + + u.Add(AddParams{Prompt: 2000, Completion: 300, Cached: 500}) + u.Add(AddParams{Prompt: 2000, Completion: 100, Cached: 500}) + + p, c, cached := u.TurnUsage() + if p != 4000 || c != 400 || cached != 1000 { + t.Errorf("TurnUsage() = (%d,%d,%d), want (4000,400,1000)", p, c, cached) + } + + // A mid-turn Reset zeroes cumulative AND baseline together; the delta must + // clamp to >=0, never go negative. + u.Reset() + u.Add(AddParams{Prompt: 50, Completion: 10}) + if p, c, _ := u.TurnUsage(); p < 0 || c < 0 { + t.Errorf("TurnUsage() after Reset = (%d,%d), must not be negative", p, c) + } +} + +func TestTokenUsage_CacheObserved_DetailsButZeroHit(t *testing.T) { + u := &TokenUsage{} + // Provider reported a details object but served 0 cached tokens (cold cache). + u.Add(AddParams{Prompt: 1000, Completion: 100, Cached: 0, CacheDetailsPresent: true}) + if !u.CacheObserved() { + t.Errorf("CacheObserved() = false, want true when details present even at 0 hits") + } + if got := u.CacheHitRate(); got != 0 { + t.Errorf("CacheHitRate() = %v, want 0", got) + } +} + func TestTokenUsage_Reset(t *testing.T) { u := &TokenUsage{} u.Add(AddParams{Prompt: 100, Completion: 20, Total: 120, Cached: 80, Reasoning: 5}) diff --git a/internal/runner/runner.go b/internal/runner/runner.go index 69584cd..99add5d 100644 --- a/internal/runner/runner.go +++ b/internal/runner/runner.go @@ -38,11 +38,18 @@ func Run( if tokenUsage != nil { ctx = internalmodel.WithTokenTracker(ctx, tokenUsage) } - // Snapshot cumulative usage so we can record this turn's delta on completion. + // Snapshot cumulative usage so we can record this turn's delta on completion, + // and mark the per-turn baseline so the budget middleware measures THIS turn. var startSnap internalmodel.TokenUsageDetail if tokenUsage != nil { startSnap = tokenUsage.GetFull() + tokenUsage.BeginTurn() } + // Persist this turn's token delta on EVERY exit path — success, user + // cancellation, or a model error — so a turn that already made billable calls + // before stopping is not dropped from the usage log. recordUsageTurn is + // nil-tracker and zero-delta guarded, so a no-op turn records nothing. + defer recordUsageTurn(tokenUsage, startSnap, rec) // Resolve the context limit once (config + registry lookup) and reuse it for // every live update below. ctxLimit := modelContextLimit() @@ -141,9 +148,8 @@ todoLoop: } h.OnTokenUpdate(buildTokenUsage(tracker, ctxLimit)) - // Persist this turn's token delta to the global usage log for stats. - recordUsageTurn(tokenUsage, startSnap, rec) - + // This turn's token delta is persisted by the deferred recordUsageTurn, which + // also covers the early-return (cancel/error) paths above. h.OnAgentDone(nil) return resp } @@ -384,6 +390,9 @@ func recordUsageTurn(tracker *internalmodel.TokenUsage, start internalmodel.Toke CacheWrite: delta.CacheWriteTokens, Total: delta.TotalTokens, Calls: delta.CallCount, + // Record cache *support* (details object seen), not just a positive hit, so + // the stats page can show "—" vs a real 0% even on a cold-cache turn. + CacheSeen: tracker.CacheObserved(), } if rec != nil { ev.Session = rec.UUID() diff --git a/internal/session/index_test.go b/internal/session/index_test.go index d094a0f..85f7ec2 100644 --- a/internal/session/index_test.go +++ b/internal/session/index_test.go @@ -1,6 +1,10 @@ package session -import "testing" +import ( + "fmt" + "sync" + "testing" +) // TestRecorderIndexingRequiresContent locks the contract the web server's // todo/goal OnUpdate guard relies on: a recorder that has written nothing is @@ -39,3 +43,33 @@ func TestRecorderIndexingRequiresContent(t *testing.T) { } rec.Close() } + +// TestConcurrentIndexWritesNoLostUpdate guards the indexMu serialization: many +// goroutines adding distinct sessions concurrently must ALL survive in the +// index. Without the lock, the read-modify-rename writers lose updates (and +// corrupt the shared .tmp), so the final index would hold far fewer than N. +// Run with -race to also catch the data race. +func TestConcurrentIndexWritesNoLostUpdate(t *testing.T) { + t.Setenv("HOME", t.TempDir()) + + const project = "/proj/concurrent" + const n = 50 + + var wg sync.WaitGroup + for i := 0; i < n; i++ { + wg.Add(1) + go func(i int) { + defer wg.Done() + _ = addToIndex(project, SessionMeta{UUID: fmt.Sprintf("uuid-%d", i), Project: project}) + }(i) + } + wg.Wait() + + metas, err := ListSessions(project) + if err != nil { + t.Fatalf("ListSessions: %v", err) + } + if len(metas) != n { + t.Fatalf("lost updates: got %d sessions, want %d (concurrent addToIndex without serialization)", len(metas), n) + } +} diff --git a/internal/session/session.go b/internal/session/session.go index 79e9db4..4167338 100644 --- a/internal/session/session.go +++ b/internal/session/session.go @@ -161,8 +161,13 @@ func NewRecorder(project, provider, model string) (*Recorder, error) { }, nil } -// UUID returns the session identifier. -func (r *Recorder) UUID() string { return r.uuid } +// UUID returns the session identifier. Locked because SetUUID can update it +// concurrently (a resumed web task swaps the recorder's UUID under r.mu). +func (r *Recorder) UUID() string { + r.mu.Lock() + defer r.mu.Unlock() + return r.uuid +} // Project returns the workspace path this recorder is scoped to. func (r *Recorder) Project() string { return r.project } @@ -170,8 +175,23 @@ func (r *Recorder) Project() string { return r.project } // Provider returns the provider the session was opened with. func (r *Recorder) Provider() string { return r.provider } -// Model returns the model the session was opened with. -func (r *Recorder) Model() string { return r.model } +// Model returns the model currently attributed to recorded usage. It is the +// model the session was opened with unless SetModel updated it after a switch. +func (r *Recorder) Model() string { + r.mu.Lock() + defer r.mu.Unlock() + return r.model +} + +// SetModel updates the model attributed to subsequently recorded usage so a +// mid-session model switch attributes new turns to the new model rather than +// the one the session was opened with. The session-start header is unchanged +// (it records the opening model). +func (r *Recorder) SetModel(model string) { + r.mu.Lock() + r.model = model + r.mu.Unlock() +} // ValidateSessionID checks that a session ID is safe for use as a filename. // It rejects empty IDs, path traversal sequences, and path separators. @@ -569,8 +589,19 @@ func (r *Recorder) writeEntry(e Entry) error { return r.file.Sync() } +// indexMu serializes the read-modify-rename writers of the shared session index +// (session.json). Atomic rename prevents torn files but NOT lost updates: two +// concurrent writers each read the same old index, mutate, and rename — last one +// wins — and they also race on the shared ".tmp" path. Once multiple web +// task Engines create/title/update/delete sessions in parallel this is a real +// lost-update + tmp-corruption hazard, so every writer takes this lock. Readers +// (ListSessions/ListAllSessions) rely on rename atomicity and stay lock-free. +var indexMu sync.Mutex + // addToIndex adds a SessionMeta to the shared index file. func addToIndex(project string, meta SessionMeta) error { + indexMu.Lock() + defer indexMu.Unlock() indexPath, err := config.SessionsIndexPath() if err != nil { return err @@ -624,6 +655,8 @@ func generateTitle(content string) string { // updateIndexTitle updates the title of a session in the shared index file. func updateIndexTitle(project, uuid, title string) error { + indexMu.Lock() + defer indexMu.Unlock() indexPath, err := config.SessionsIndexPath() if err != nil { return err @@ -680,6 +713,8 @@ func DeleteSession(project, uuid string) error { // removed when the uuid was actually found in the index, which also prevents a // crafted uuid from deleting an arbitrary file. func DeleteSessionByUUID(uuid string) (bool, error) { + indexMu.Lock() + defer indexMu.Unlock() indexPath, err := config.SessionsIndexPath() if err != nil { return false, err @@ -732,6 +767,8 @@ func DeleteSessionByUUID(uuid string) (bool, error) { } func removeFromIndex(project, uuid string) error { + indexMu.Lock() + defer indexMu.Unlock() indexPath, err := config.SessionsIndexPath() if err != nil { return err @@ -791,6 +828,8 @@ func ListSessions(project string) ([]SessionMeta, error) { // or (nil, nil) if no session with that uuid exists. uuid is only compared in // memory (never used as a path), so no path validation is required here. func UpdateSessionMeta(uuid string, mutate func(*SessionMeta)) (*SessionMeta, error) { + indexMu.Lock() + defer indexMu.Unlock() indexPath, err := config.SessionsIndexPath() if err != nil { return nil, err diff --git a/internal/usage/event.go b/internal/usage/event.go index 36799f4..e7eba27 100644 --- a/internal/usage/event.go +++ b/internal/usage/event.go @@ -33,6 +33,12 @@ type Event struct { CacheWrite int `json:"cache_write,omitempty"` Total int `json:"total"` Calls int `json:"calls,omitempty"` // API calls in this turn + // CacheSeen is true when the provider reported a prompt_tokens_details object + // during the turn (caching supported), even if cached==0. Lets stats show "—" + // vs a real 0% without conflating "unsupported" with "supported, no hit". + // Absent in events written before this field existed (defaults to false, with + // a Cached>0 fallback in Aggregate). + CacheSeen bool `json:"cache_seen,omitempty"` } // RecordEvent stamps ev with the current time and appends it to the default diff --git a/internal/usage/stats.go b/internal/usage/stats.go index 276cf37..cd87f6f 100644 --- a/internal/usage/stats.go +++ b/internal/usage/stats.go @@ -59,8 +59,12 @@ func Aggregate(events []Event, today string) Aggregated { agg := Aggregated{Days: make(map[string]*DayBucket)} byModel := map[string]int64{} byProject := map[string]int64{} + anyCacheSeen := false for _, ev := range events { + if ev.CacheSeen { + anyCacheSeen = true + } agg.Totals.Total += int64(ev.Total) agg.Totals.Prompt += int64(ev.Prompt) agg.Totals.Completion += int64(ev.Completion) @@ -90,7 +94,10 @@ func Aggregate(events []Event, today string) Aggregated { agg.ActiveDays = len(agg.Days) agg.CurrentStreak = currentStreak(agg.Days, today) agg.LongestStreak = longestStreak(agg.Days) - agg.CacheSupported = agg.Totals.Cached > 0 + // A turn that reported cache details (CacheSeen) means the provider supports + // caching even if no tokens were served from cache; the Cached>0 fallback + // keeps older events (written before CacheSeen existed) correct. + agg.CacheSupported = anyCacheSeen || agg.Totals.Cached > 0 if agg.Totals.Prompt > 0 { r := float64(agg.Totals.Cached) / float64(agg.Totals.Prompt) switch { diff --git a/internal/usage/usage_test.go b/internal/usage/usage_test.go index 060c470..8483418 100644 --- a/internal/usage/usage_test.go +++ b/internal/usage/usage_test.go @@ -156,6 +156,20 @@ func TestAggregate_NoCacheSupport(t *testing.T) { } } +func TestAggregate_CacheSeenZeroHit(t *testing.T) { + // Provider reported cache details but served 0 cached tokens: caching IS + // supported, so the stats page should not claim it isn't. + e := ev("2026-06-21", "m", "/p", 100, 80, 0) + e.CacheSeen = true + a := Aggregate([]Event{e}, "2026-06-21") + if !a.CacheSupported { + t.Error("CacheSupported = false, want true when CacheSeen even with 0 cached") + } + if a.CacheHitRate != 0 { + t.Errorf("CacheHitRate = %v, want 0", a.CacheHitRate) + } +} + func TestAggregate_Trend(t *testing.T) { a := Aggregate([]Event{ ev("2026-06-21", "m", "/p", 100, 80, 40), diff --git a/internal/web/concurrency_test.go b/internal/web/concurrency_test.go new file mode 100644 index 0000000..33f344a --- /dev/null +++ b/internal/web/concurrency_test.go @@ -0,0 +1,248 @@ +package web + +import ( + "context" + "encoding/json" + "fmt" + "sync" + "testing" + "time" + + "github.com/cnjack/jcode/internal/handler" + "github.com/cnjack/jcode/internal/model" + "github.com/cnjack/jcode/internal/session" + "github.com/cnjack/jcode/internal/tools" +) + +// drain non-blockingly collects every queued WSEvent on a client's send channel. +func drain(c *WSClient) []WSEvent { + var out []WSEvent + for { + select { + case data := <-c.sendCh: + var ev WSEvent + if json.Unmarshal(data, &ev) == nil { + out = append(out, ev) + } + default: + return out + } + } +} + +func hasType(evs []WSEvent, typ string) bool { + for _, e := range evs { + if e.Type == typ { + return true + } + } + return false +} + +// TestWSClientWants locks the per-client subscription predicate that prevents a +// busy task from flooding a client viewing a different one. +func TestWSClientWants(t *testing.T) { + c := newWSClient(nil) + // Before subscribing, a client receives everything (legacy compatibility). + if !c.wants("task-A") || !c.wants("") { + t.Fatal("fresh client (subAll) should want all events") + } + c.subscribe([]string{"task-A"}) + if !c.wants("task-A") { + t.Error("subscribed task should be wanted") + } + if c.wants("task-B") { + t.Error("unsubscribed task must NOT be wanted after first subscribe") + } + if !c.wants("") { + t.Error("global events (empty task) must always be wanted") + } + c.subscribe([]string{"task-B"}) + if !c.wants("task-A") || !c.wants("task-B") { + t.Error("subscriptions are additive") + } + c.unsubscribe([]string{"task-A"}) + if c.wants("task-A") { + t.Error("unsubscribe should drop the task") + } +} + +// TestBrokerDeliversByTaskID proves the broker fans an event only to clients +// subscribed to its task (plus all clients for global events). +func TestBrokerDeliversByTaskID(t *testing.T) { + b := NewWSBroker() + viewer := newWSClient(nil) // only watching task-A + viewer.subscribe([]string{"task-A"}) + legacy := newWSClient(nil) // never subscribed → sees everything + b.mu.Lock() + b.clients[1] = viewer + b.clients[2] = legacy + b.mu.Unlock() + + b.Broadcast(WSEvent{TaskID: "task-A", Type: "a_event"}) + b.Broadcast(WSEvent{TaskID: "task-B", Type: "b_event"}) + b.Broadcast(WSEvent{Type: "global_event"}) + + got := drain(viewer) + if !hasType(got, "a_event") || !hasType(got, "global_event") { + t.Errorf("viewer should get its task + global, got %+v", got) + } + if hasType(got, "b_event") { + t.Errorf("viewer must NOT get another task's events, got %+v", got) + } + gotLegacy := drain(legacy) + if !hasType(gotLegacy, "a_event") || !hasType(gotLegacy, "b_event") || !hasType(gotLegacy, "global_event") { + t.Errorf("legacy (subAll) client should get every event, got %+v", gotLegacy) + } +} + +// TestEnginePumpStampsTaskID is the end-to-end pump test: an engine's handler +// event reaches a subscribed client tagged with that engine's task id. +func TestEnginePumpStampsTaskID(t *testing.T) { + s := &Server{Engine: &Engine{}, tasks: make(map[string]*Engine), wsBroker: NewWSBroker(), ctx: context.Background()} + h := handler.NewWebHandler() + eng := &Engine{taskID: "task-1", handler: h} + + client := newWSClient(nil) + client.subscribe([]string{"task-1"}) + s.wsBroker.mu.Lock() + s.wsBroker.clients[1] = client + s.wsBroker.mu.Unlock() + + _ = s.registerEngine(eng) // starts the per-engine pump + h.OnAgentText("hello from task 1") + + deadline := time.After(2 * time.Second) + for { + select { + case data := <-client.sendCh: + var ev WSEvent + _ = json.Unmarshal(data, &ev) + if ev.Type == "agent_text" { + if ev.TaskID != "task-1" { + t.Fatalf("event task_id = %q, want task-1", ev.TaskID) + } + return + } + case <-deadline: + t.Fatal("timed out waiting for the pumped agent_text event") + } + } +} + +// stubFactoryServer builds a Server whose newEngine factory produces fully +// isolated (but agent-less) engines, so engine lifecycle/routing can be tested +// without a live model. +func stubFactoryServer(t *testing.T) *Server { + t.Helper() + t.Setenv("HOME", t.TempDir()) + s := &Server{Engine: &Engine{}, tasks: make(map[string]*Engine), wsBroker: NewWSBroker(), ctx: context.Background()} + s.newEngine = func(taskID, pwd, modeStr string) (*EngineConfig, error) { + rec, _ := session.NewRecorder(pwd, "prov", "model") + if taskID != "" && rec != nil { + rec.SetUUID(taskID) + } + return &EngineConfig{ + TaskID: taskID, + Pwd: pwd, + Mode: modeStr, + Env: tools.NewEnv(pwd, "darwin/arm64"), + Recorder: rec, + TokenUsage: &model.TokenUsage{}, + Handler: handler.NewWebHandler(), + }, nil + } + return s +} + +// TestEngineIsolationAndRouting verifies concurrently-built engines are fully +// isolated (distinct env/recorder/token tracker) and individually routable. +func TestEngineIsolationAndRouting(t *testing.T) { + s := stubFactoryServer(t) + + const n = 16 + engines := make([]*Engine, n) + var wg sync.WaitGroup + for i := 0; i < n; i++ { + wg.Add(1) + go func(i int) { + defer wg.Done() + eng, err := s.buildLocalEngine("", fmt.Sprintf("/proj/%d", i), "build") + if err != nil { + t.Errorf("buildLocalEngine: %v", err) + return + } + engines[i] = eng + }(i) + } + wg.Wait() + + seenEnv := map[*tools.Env]bool{} + seenTok := map[*model.TokenUsage]bool{} + seenID := map[string]bool{} + for i, eng := range engines { + if eng == nil { + t.Fatalf("engine %d not built", i) + } + if seenEnv[eng.env] { + t.Errorf("engine %d shares an env with another task", i) + } + if seenTok[eng.tokenUsage] { + t.Errorf("engine %d shares a token tracker with another task", i) + } + if seenID[eng.taskID] { + t.Errorf("engine %d has a duplicate task id %q", i, eng.taskID) + } + seenEnv[eng.env] = true + seenTok[eng.tokenUsage] = true + seenID[eng.taskID] = true + + // Each engine must be routable by its id, and not collide with others. + if got := s.resolveEngine(eng.taskID); got != eng { + t.Errorf("resolveEngine(%q) returned the wrong engine", eng.taskID) + } + } + if got := s.resolveEngine("does-not-exist"); got != nil { + t.Errorf("resolveEngine(unknown) = %v, want nil", got) + } +} + +// TestPerTaskGateIndependence proves one running task does not block another: +// the busy flag is per-engine, not global. +func TestPerTaskGateIndependence(t *testing.T) { + a := &Engine{taskID: "a"} + b := &Engine{taskID: "b"} + + if !a.running.CompareAndSwap(false, true) { + t.Fatal("task a should acquire its own gate") + } + // a is busy; b must still be able to start. + if !b.running.CompareAndSwap(false, true) { + t.Fatal("task b must NOT be blocked by task a running") + } + // a cannot double-start while busy. + if a.running.CompareAndSwap(false, true) { + t.Fatal("task a must not start twice concurrently") + } + a.running.Store(false) + if !a.running.CompareAndSwap(false, true) { + t.Fatal("task a should restart after finishing") + } +} + +// TestDeleteEngineTeardown verifies a non-active task can be torn down (removed +// from the map, pump stopped) without disturbing the others. +func TestDeleteEngineTeardown(t *testing.T) { + s := stubFactoryServer(t) + keep, _ := s.buildLocalEngine("", "/proj/keep", "build") + drop, _ := s.buildLocalEngine("", "/proj/drop", "build") + + s.deleteEngine(drop.taskID) + + if got := s.resolveEngine(drop.taskID); got != nil { + t.Error("deleted engine should no longer resolve") + } + if got := s.resolveEngine(keep.taskID); got != keep { + t.Error("the other engine must remain routable") + } +} diff --git a/internal/web/engine.go b/internal/web/engine.go new file mode 100644 index 0000000..2c352a5 --- /dev/null +++ b/internal/web/engine.go @@ -0,0 +1,464 @@ +package web + +import ( + "context" + "fmt" + "sync" + "sync/atomic" + "time" + + "github.com/cloudwego/eino/adk" + + "github.com/cnjack/jcode/internal/handler" + "github.com/cnjack/jcode/internal/model" + "github.com/cnjack/jcode/internal/runner" + "github.com/cnjack/jcode/internal/session" + "github.com/cnjack/jcode/internal/tools" + "github.com/cnjack/jcode/internal/usage" +) + +// Engine is the per-task run state of the web server — one independent top-level +// session (a "task"). It holds exactly the fields that were Server singletons in +// the single-active design: the agent, its conversation history, recorder, +// per-task token tracker, approval axis, working dir, and event handler. +// +// The refactor toward concurrent tasks proceeds in stages: +// - INC-3 (this): the fields move out of Server into Engine, Server embeds one +// bootstrap *Engine, and field promotion keeps every existing s. +// reference compiling and behaving identically (concurrency stays 1). +// - INC-5: handlers resolve an Engine by task_id (Server.tasks) instead of +// relying on the promoted bootstrap. +// - INC-7+: the global run gate is removed so multiple Engines run at once; +// `running` becomes the per-task busy flag and each Engine gets its own ctx. +// +// Kernel types (session.Recorder, tools.Env, runner.ApprovalState, +// handler.WebHandler, model.TokenUsage) are reused unchanged — parallelism is +// achieved by instantiating one set of them per Engine, never by editing them. +type Engine struct { + // emu guards this engine's mutable run state (history, recorder, runCancel, + // sessionSnapshot). Named emu — not mu — so it does not collide with the + // promoted Server.mu when Engine is embedded in Server. + emu sync.Mutex + + // taskID is the task identity (== the recorder's session UUID once a message + // has been recorded). Empty for a freshly created, not-yet-messaged engine. + taskID string + + // pwd is the task's working directory, bound at creation. In the target model + // it is immutable for the task's lifetime; "switching project" creates a new + // Engine rather than mutating this one's env in place. + pwd string + + // --- run state (guarded today by Server.mu; gains its own lock in a later + // increment once Server.mu's single-run role is gone) --- + agent *adk.ChatModelAgent + history []adk.Message + running atomic.Bool // per-task busy flag (was the global Server.running gate) + runCancel context.CancelFunc + + // --- per-task model / mode axis --- + providerName string + modelName string + mode string // "build" / "plan" / "full_access" + + // --- per-task execution context --- + env *tools.Env // fresh per task; todo/goal/bg hang off it + todoStore *tools.TodoStore + approvalState *runner.ApprovalState + + // --- per-task accounting + event emission --- + recorder *session.Recorder + tokenUsage *model.TokenUsage + sessionSnapshot string // git tree hash at run start, for session-scoped diffs + handler *handler.WebHandler + eventHandler handler.AgentEventHandler // runner handler (may wrap handler in a NotifyingHandler) + breakdownFn func() usage.ContextBreakdown + + // per-task agent rebuild closures (bound to THIS task's env/model/prompt), so + // a model or mode switch rebuilds only this task's agent. + createAgent func(providerName, modelName string) (*adk.ChatModelAgent, error) + rebuildForMode func(planMode bool) (*adk.ChatModelAgent, error) + + // pumpCancel stops this engine's event-forwarding goroutine on teardown. + pumpCancel context.CancelFunc +} + +// EngineConfig carries the per-task pieces a factory (command.buildWebTask) +// produces for one task. The web package owns Engine's unexported fields, so the +// command package hands them over through this exported struct and the server +// assembles the Engine via newEngine. +type EngineConfig struct { + TaskID string + Pwd string + Mode string + ProviderName string + ModelName string + Agent *adk.ChatModelAgent + Env *tools.Env + TodoStore *tools.TodoStore + Recorder *session.Recorder + TokenUsage *model.TokenUsage + ApprovalState *runner.ApprovalState + Handler *handler.WebHandler + EventHandler handler.AgentEventHandler + BreakdownFn func() usage.ContextBreakdown + CreateAgent func(providerName, modelName string) (*adk.ChatModelAgent, error) + RebuildForMode func(planMode bool) (*adk.ChatModelAgent, error) +} + +// newEngine assembles an *Engine from the factory-produced config. The engine's +// identity (taskID) is its recorder's session UUID unless an explicit resume id +// was supplied. +func newEngine(c *EngineConfig) *Engine { + tu := c.TokenUsage + if tu == nil { + tu = &model.TokenUsage{} + } + taskID := c.TaskID + if taskID == "" && c.Recorder != nil { + taskID = c.Recorder.UUID() + } + return &Engine{ + taskID: taskID, + pwd: c.Pwd, + mode: c.Mode, + providerName: c.ProviderName, + modelName: c.ModelName, + agent: c.Agent, + env: c.Env, + todoStore: c.TodoStore, + recorder: c.Recorder, + tokenUsage: tu, + approvalState: c.ApprovalState, + handler: c.Handler, + eventHandler: c.EventHandler, + breakdownFn: c.BreakdownFn, + createAgent: c.CreateAgent, + rebuildForMode: c.RebuildForMode, + } +} + +// activeEngine returns the currently-foregrounded engine (the embedded bootstrap +// pointer), read under s.mu so it never tears against setActiveEngine's swap. +// Legacy non-task-routed handlers MUST go through this accessor (and the +// emu-locked Engine helpers below) instead of bare promoted s. reads. +func (s *Server) activeEngine() *Engine { + s.mu.RLock() + defer s.mu.RUnlock() + return s.Engine +} + +// --- emu-guarded accessors for an engine's MUTABLE run-state fields (agent, +// recorder, mode, provider, model). Immutable-after-build fields (pwd, env, +// todoStore, handler, tokenUsage, approvalState, breakdownFn) may be read +// directly off an *Engine snapshot obtained via activeEngine/resolveEngine. --- + +// activePwd returns the foreground engine's working directory (or "") via the +// s.mu-guarded accessor, so file/exec/git/pty handlers never read the swapped +// s.Engine pointer bare. +func (s *Server) activePwd() string { + if eng := s.activeEngine(); eng != nil { + return eng.pwd + } + return "" +} + +// activeHandler returns the foreground engine's WebHandler (or nil). +func (s *Server) activeHandler() *handler.WebHandler { + if eng := s.activeEngine(); eng != nil { + return eng.handler + } + return nil +} + +// activeMode returns the foreground engine's mode (or ""). +func (s *Server) activeMode() string { + if eng := s.activeEngine(); eng != nil { + return eng.curMode() + } + return "" +} + +// modelSnapshot returns the engine's provider/model/mode under emu. +func (e *Engine) modelSnapshot() (provider, model, modeStr string) { + e.emu.Lock() + defer e.emu.Unlock() + return e.providerName, e.modelName, e.mode +} + +// curMode returns the engine's mode under emu. +func (e *Engine) curMode() string { + e.emu.Lock() + defer e.emu.Unlock() + return e.mode +} + +// recUUID returns the engine recorder's UUID (or "") under emu. +func (e *Engine) recUUID() string { + e.emu.Lock() + defer e.emu.Unlock() + if e.recorder == nil { + return "" + } + return e.recorder.UUID() +} + +// applyModelSwitch swaps the engine's agent + provider/model under emu and +// re-tags the recorder with the new model. +func (e *Engine) applyModelSwitch(ag *adk.ChatModelAgent, provider, model string) { + e.emu.Lock() + defer e.emu.Unlock() + e.agent = ag + e.providerName = provider + e.modelName = model + if e.recorder != nil { + e.recorder.SetModel(model) + } +} + +// applyModeSwitch sets the engine's mode and (optionally) its rebuilt agent. +func (e *Engine) applyModeSwitch(modeStr string, ag *adk.ChatModelAgent) { + e.emu.Lock() + defer e.emu.Unlock() + e.mode = modeStr + if ag != nil { + e.agent = ag + } +} + +// setAgent swaps just the agent under emu (MCP reload, skill toggle, setup). +func (e *Engine) setAgent(ag *adk.ChatModelAgent) { + e.emu.Lock() + defer e.emu.Unlock() + e.agent = ag +} + +// resolveEngine returns the engine for taskID, or the active engine when taskID +// is empty (legacy / no-task_id callers). Returns nil when taskID is unknown. +func (s *Server) resolveEngine(taskID string) *Engine { + if taskID == "" { + return s.activeEngine() + } + s.tasksMu.RLock() + eng := s.tasks[taskID] + s.tasksMu.RUnlock() + if eng == nil { + // The active engine may not be in the map yet under its session UUID + // (a brand-new chat whose recorder UUID the client already knows). + if a := s.activeEngine(); a != nil && a.taskID == taskID { + return a + } + } + return eng +} + +// maxLiveEngines bounds the number of concurrently-live task engines so a +// client cannot mint unbounded engines (fd/goroutine/agent accumulation). +const maxLiveEngines = 64 + +var errTooManyTasks = fmt.Errorf("too many concurrent tasks") + +// registerEngine adds eng to the tasks map (keyed by task id), publishes its +// pump-cancel under tasksMu (so teardown observes it), and starts its event +// pump. Idempotent for an already-registered engine. Returns errTooManyTasks if +// registering a NEW engine would exceed the live-engine cap. +func (s *Server) registerEngine(eng *Engine) error { + if eng == nil || eng.taskID == "" { + return nil + } + s.tasksMu.Lock() + _, existed := s.tasks[eng.taskID] + if !existed && len(s.tasks) >= maxLiveEngines { + s.tasksMu.Unlock() + return errTooManyTasks + } + var pumpCtx context.Context + if !existed { + base := s.ctx + if base == nil { + base = context.Background() + } + ctx, cancel := context.WithCancel(base) + eng.pumpCancel = cancel // published under tasksMu; teardown reads it after a tasksMu acquisition + pumpCtx = ctx + } + s.tasks[eng.taskID] = eng + s.tasksMu.Unlock() + if pumpCtx != nil { + s.startPump(pumpCtx, eng) + } + return nil +} + +// startPump forwards eng's handler events to the WS broker, stamped with the +// engine's task id, until ctx is cancelled (teardown) or the channel closes. +// Each engine gets its own pump so concurrent tasks never serialize on one +// forwarding goroutine. +func (s *Server) startPump(ctx context.Context, eng *Engine) { + events := eng.handler.Events() + go func() { + for { + select { + case <-ctx.Done(): + return + case ev, ok := <-events: + if !ok { + return + } + s.wsBroker.Broadcast(WSEvent{TaskID: eng.taskID, Type: ev.Event, Data: ev.Data}) + } + } + }() +} + +// buildLocalEngine creates and registers a fresh local task engine. +func (s *Server) buildLocalEngine(taskID, pwd, modeStr string) (*Engine, error) { + if s.newEngine == nil { + return nil, fmt.Errorf("task creation is not supported") + } + ec, err := s.newEngine(taskID, pwd, modeStr) + if err != nil { + return nil, err + } + eng := newEngine(ec) + // Inherit the foreground task's current model selection rather than reverting + // to the startup default (the factory bakes in startup provider/model). + if cur := s.activeEngine(); cur != nil && eng.createAgent != nil { + prov, mdl, _ := cur.modelSnapshot() + if prov != "" && (prov != eng.providerName || mdl != eng.modelName) { + if ag, agErr := eng.createAgent(prov, mdl); agErr == nil { + eng.applyModelSwitch(ag, prov, mdl) + } + } + } + if err := s.registerEngine(eng); err != nil { + eng.teardown() + return nil, err + } + return eng, nil +} + +// buildRemoteEngine creates and registers a fresh remote (SSH) task engine. +func (s *Server) buildRemoteEngine(taskID string, exec *tools.SSHExecutor, remotePwd, modeStr string) (*Engine, error) { + if s.newRemoteEngine == nil { + return nil, fmt.Errorf("remote task creation is not supported") + } + ec, err := s.newRemoteEngine(taskID, exec, remotePwd, modeStr) + if err != nil { + return nil, err + } + eng := newEngine(ec) + if err := s.registerEngine(eng); err != nil { + eng.teardown() + return nil, err + } + return eng, nil +} + +// setActiveEngine makes eng the foreground engine (the one the promoted legacy +// handlers operate on). It reclaims the OUTGOING engine only when it is an unused +// throwaway (idle and never recorded) — e.g. a "new chat" the user navigated +// away from without typing — so the new-chat/switch-project path doesn't leak an +// engine each time. A running or already-recorded task is kept (real background +// work). The pointer swap is guarded by s.mu, matching activeEngine's read. +func (s *Server) setActiveEngine(eng *Engine) { + if eng == nil { + return + } + _ = s.registerEngine(eng) + s.mu.Lock() + prev := s.Engine + s.Engine = eng + s.mu.Unlock() + if prev != nil && prev != eng && !prev.running.Load() { + reclaim := false + prev.emu.Lock() + if prev.recorder == nil || !prev.recorder.HasRecording() { + reclaim = true + } + prev.emu.Unlock() + if reclaim { + s.deleteEngine(prev.taskID) + } + } +} + +// deleteEngine removes a task engine from the map and tears it down (stops its +// pump, cancels its run, closes its recorder). The active engine is never +// deleted out from under the foreground; callers guard against that. +func (s *Server) deleteEngine(taskID string) { + s.tasksMu.Lock() + eng := s.tasks[taskID] + delete(s.tasks, taskID) + s.tasksMu.Unlock() + if eng != nil { + eng.teardown() + } +} + +// teardown stops the engine's event pump, cancels any in-flight run, waits for +// it to drain, then closes the recorder — so the recorder is never closed under +// a live writer (which would truncate the session file). Shared resources (push +// notifiers, registry, skill loader, MCP clients) are owned by the Server and +// deliberately left untouched. Callers reach teardown only after a tasksMu +// acquisition (deleteEngine/CloseAllEngines), which establishes happens-before +// with registerEngine's pumpCancel publication. +func (e *Engine) teardown() { + if e.pumpCancel != nil { + e.pumpCancel() + } + e.emu.Lock() + c := e.runCancel + e.emu.Unlock() + if c != nil { + c() + } + // Best-effort wait for the run goroutine to flip running=false so its final + // RecordAssistant lands before we close. + for i := 0; i < 200 && e.running.Load(); i++ { + time.Sleep(5 * time.Millisecond) + } + if e.recorder != nil { + e.recorder.Close() + } +} + +// setTaskStatus broadcasts a global task_status event (so every client's sidebar +// can mark the task running/idle live) and best-effort persists Status + +// UpdatedAt so recency survives a reload. The broadcast carries the task id in +// its DATA (not the envelope TaskID) so it is delivered to all clients, not +// filtered to the task's subscribers. +func (s *Server) setTaskStatus(eng *Engine, running bool) { + if eng == nil || eng.taskID == "" { + return + } + status := "idle" + if running { + status = "running" + } + s.wsBroker.Broadcast(WSEvent{Type: "task_status", Data: map[string]any{ + "task_id": eng.taskID, + "running": running, + "status": status, + }}) + go func(id, st string) { + _, _ = session.UpdateSessionMeta(id, func(m *session.SessionMeta) { + m.Status = st + m.UpdatedAt = time.Now().Format(time.RFC3339) + }) + }(eng.taskID, status) +} + +// CloseAllEngines tears down every live engine. Called on server shutdown. +func (s *Server) CloseAllEngines() { + s.tasksMu.Lock() + engines := make([]*Engine, 0, len(s.tasks)) + for _, e := range s.tasks { + engines = append(engines, e) + } + s.tasks = make(map[string]*Engine) + s.tasksMu.Unlock() + for _, e := range engines { + e.teardown() + } +} diff --git a/internal/web/git.go b/internal/web/git.go index 0fe9b85..3eb2a75 100644 --- a/internal/web/git.go +++ b/internal/web/git.go @@ -4,6 +4,7 @@ import ( "encoding/json" "io" "net/http" + "os" "os/exec" "strings" ) @@ -15,7 +16,7 @@ import ( func (s *Server) handleGitBranches(w http.ResponseWriter, r *http.Request) { listCmd := exec.CommandContext(r.Context(), "git", "for-each-ref", "--format=%(refname:short)", "--sort=-committerdate", "refs/heads") - listCmd.Dir = s.pwd + listCmd.Dir = s.activePwd() out, err := listCmd.Output() if err != nil { writeJSON(w, http.StatusOK, map[string]any{"current": "", "branches": []string{}}) @@ -25,7 +26,7 @@ func (s *Server) handleGitBranches(w http.ResponseWriter, r *http.Request) { // `branch --show-current` reports the unborn branch of a fresh repo (e.g. // "main"), where `rev-parse --abbrev-ref HEAD` would just say "HEAD". curCmd := exec.CommandContext(r.Context(), "git", "branch", "--show-current") - curCmd.Dir = s.pwd + curCmd.Dir = s.activePwd() curOut, _ := curCmd.Output() current := strings.TrimSpace(string(curOut)) @@ -58,20 +59,38 @@ func (s *Server) handleGitBranches(w http.ResponseWriter, r *http.Request) { // handleGitCheckout switches to an existing branch, or creates and checks out a // new branch when create=true. It refuses while the agent is running (a branch -// switch rewrites the working tree under a live task) and surfaces git's own -// error verbatim (e.g. "Your local changes would be overwritten") rather than -// forcing a destructive checkout — the user decides how to resolve a dirty tree. +// switch rewrites the working tree under a live task). +// +// When a plain switch would clobber uncommitted work, git aborts +// non-destructively. Rather than surfacing that raw error we report it as a +// recoverable result (HTTP 200, blocked:true, plus the at-risk files) so the UI +// can ask the user how to proceed instead of dead-ending on an error. The +// client then retries with an explicit strategy: +// - "stash": `git stash push -u` the changes first, then switch (recoverable +// via `git stash pop`). +// - "force": `git checkout -f`, discarding the local changes. +// +// Any failure after a strategy was chosen is genuine and returned as an error. func (s *Server) handleGitCheckout(w http.ResponseWriter, r *http.Request) { - if s.running.Load() { + // Capture the active engine ONCE so the running guard and the checkout target + // the same task's repo even if the active engine is swapped concurrently. + eng := s.activeEngine() + if eng == nil { + writeJSON(w, http.StatusServiceUnavailable, map[string]string{"error": "no active task"}) + return + } + if eng.running.Load() { writeJSON(w, http.StatusConflict, map[string]string{ "error": "agent is running — stop it before switching branch", }) return } + dir := eng.pwd var req struct { - Branch string `json:"branch"` - Create bool `json:"create"` + Branch string `json:"branch"` + Create bool `json:"create"` + Strategy string `json:"strategy"` // "" (plain) | "stash" | "force" } if err := json.NewDecoder(io.LimitReader(r.Body, 1<<20)).Decode(&req); err != nil { writeJSON(w, http.StatusBadRequest, map[string]string{"error": "invalid request body"}) @@ -83,26 +102,95 @@ func (s *Server) handleGitCheckout(w http.ResponseWriter, r *http.Request) { return } + // Force stable English git output so the block detection below doesn't depend + // on the host locale. We present our own UI copy, so this C-locale text is + // only used internally, never shown to the user. + env := append(os.Environ(), "LC_ALL=C", "LANG=C") + + // "stash" strategy: tuck the working changes (including untracked files) away + // before switching so nothing is lost. A genuine stash failure is fatal. + stashed := false + if req.Strategy == "stash" { + stashCmd := exec.CommandContext(r.Context(), "git", "stash", "push", "-u", + "-m", "jcode: auto-stash before switching to "+branch) + stashCmd.Dir = dir + stashCmd.Env = env + stashOut, stashErr := stashCmd.CombinedOutput() + if stashErr != nil { + msg := strings.TrimSpace(string(stashOut)) + if msg == "" { + msg = stashErr.Error() + } + writeJSON(w, http.StatusConflict, map[string]string{"error": msg}) + return + } + stashed = !strings.Contains(string(stashOut), "No local changes to save") + } + args := []string{"checkout"} + if req.Strategy == "force" { + args = append(args, "-f") + } if req.Create { args = append(args, "-b") } args = append(args, branch) cmd := exec.CommandContext(r.Context(), "git", args...) - cmd.Dir = s.pwd + cmd.Dir = dir + cmd.Env = env out, err := cmd.CombinedOutput() if err != nil { msg := strings.TrimSpace(string(out)) if msg == "" { msg = err.Error() } + // A plain switch aborted by uncommitted work is recoverable: report it as + // such (the working tree is untouched) so the UI can offer stash/discard. + if req.Strategy == "" && checkoutBlockedByLocalChanges(msg) { + writeJSON(w, http.StatusOK, map[string]any{ + "branch": "", + "blocked": true, + "message": msg, + "files": parseOverwriteFiles(msg), + }) + return + } writeJSON(w, http.StatusConflict, map[string]string{"error": msg}) return } curCmd := exec.CommandContext(r.Context(), "git", "branch", "--show-current") - curCmd.Dir = s.pwd + curCmd.Dir = dir curOut, _ := curCmd.Output() - writeJSON(w, http.StatusOK, map[string]any{"branch": strings.TrimSpace(string(curOut))}) + writeJSON(w, http.StatusOK, map[string]any{ + "branch": strings.TrimSpace(string(curOut)), + "stashed": stashed, + }) +} + +// checkoutBlockedByLocalChanges reports whether a failed `git checkout` aborted +// because uncommitted work (modified or untracked files) would be overwritten — +// the recoverable case the UI resolves by stashing or discarding. Matched +// against C-locale git output. +func checkoutBlockedByLocalChanges(msg string) bool { + m := strings.ToLower(msg) + return strings.Contains(m, "would be overwritten by checkout") || + strings.Contains(m, "would be overwritten by merge") || + strings.Contains(m, "please commit your changes or stash them") +} + +// parseOverwriteFiles pulls the tab-indented paths git lists between the +// "would be overwritten" header and the trailing "Please commit…/Aborting" +// lines, so the UI can show exactly which files are at risk. +func parseOverwriteFiles(msg string) []string { + files := make([]string, 0, 8) + for _, line := range strings.Split(msg, "\n") { + if strings.HasPrefix(line, "\t") { + if f := strings.TrimSpace(line); f != "" { + files = append(files, f) + } + } + } + return files } diff --git a/internal/web/goal_test.go b/internal/web/goal_test.go index b5b005d..98e4fad 100644 --- a/internal/web/goal_test.go +++ b/internal/web/goal_test.go @@ -15,8 +15,10 @@ import ( // (no network listener, no agent), so the HTTP layer can be tested in-process. func newGoalTestServer() *Server { return &Server{ - env: tools.NewEnv("/tmp", "darwin/arm64"), - handler: handler.NewWebHandler(), + Engine: &Engine{ + env: tools.NewEnv("/tmp", "darwin/arm64"), + handler: handler.NewWebHandler(), + }, } } diff --git a/internal/web/mode_test.go b/internal/web/mode_test.go index b97de06..234bb14 100644 --- a/internal/web/mode_test.go +++ b/internal/web/mode_test.go @@ -19,13 +19,15 @@ import ( // recording the planMode flag every agent rebuild is asked for. func newModeTestServer(rebuilt *[]bool) *Server { return &Server{ - wsBroker: NewWSBroker(), - approvalState: runner.NewApprovalStateWithMode("/tmp", mode.Approval), - mode: "approval", - rebuildForMode: func(planMode bool) (*adk.ChatModelAgent, error) { - *rebuilt = append(*rebuilt, planMode) - return nil, nil + Engine: &Engine{ + approvalState: runner.NewApprovalStateWithMode("/tmp", mode.Approval), + mode: "approval", + rebuildForMode: func(planMode bool) (*adk.ChatModelAgent, error) { + *rebuilt = append(*rebuilt, planMode) + return nil, nil + }, }, + wsBroker: NewWSBroker(), } } diff --git a/internal/web/pty.go b/internal/web/pty.go index 24486e0..eb0da3a 100644 --- a/internal/web/pty.go +++ b/internal/web/pty.go @@ -17,9 +17,10 @@ import ( // ptySession represents a running PTY session. type ptySession struct { - id string - cmd *exec.Cmd - ptmx *os.File + id string + ownerID string // task id that created it, so a project/remote switch only closes its own + cmd *exec.Cmd + ptmx *os.File } // ptyManager manages PTY sessions. @@ -39,8 +40,8 @@ var upgrader = websocket.Upgrader{ CheckOrigin: func(r *http.Request) bool { return true }, } -// create starts a new PTY session and returns its ID. -func (m *ptyManager) create(workDir string) (string, error) { +// create starts a new PTY session owned by ownerID and returns its ID. +func (m *ptyManager) create(workDir, ownerID string) (string, error) { shell := os.Getenv("SHELL") if shell == "" { shell = "/bin/sh" @@ -57,7 +58,7 @@ func (m *ptyManager) create(workDir string) (string, error) { m.mu.Lock() m.nextID++ id := fmt.Sprintf("pty_%d", m.nextID) - sess := &ptySession{id: id, cmd: cmd, ptmx: ptmx} + sess := &ptySession{id: id, ownerID: ownerID, cmd: cmd, ptmx: ptmx} m.sessions[id] = sess m.mu.Unlock() @@ -104,7 +105,7 @@ func (m *ptyManager) kill(id string) { } } -// closeAll terminates all PTY sessions. +// closeAll terminates all PTY sessions (server shutdown). func (m *ptyManager) closeAll() { m.mu.Lock() sessions := make([]*ptySession, 0, len(m.sessions)) @@ -119,6 +120,27 @@ func (m *ptyManager) closeAll() { } } +// closeForTask terminates only the PTY sessions owned by taskID, leaving other +// concurrent tasks' terminals alive. An empty taskID matches nothing. +func (m *ptyManager) closeForTask(taskID string) { + if taskID == "" { + return + } + m.mu.Lock() + var sessions []*ptySession + for id, s := range m.sessions { + if s.ownerID == taskID { + sessions = append(sessions, s) + delete(m.sessions, id) + } + } + m.mu.Unlock() + for _, s := range sessions { + _ = s.cmd.Process.Kill() + _ = s.ptmx.Close() + } +} + // serveWS handles the WebSocket connection for a PTY session. // Data flows: PTY stdout → WebSocket → client, client → WebSocket → PTY stdin. func (m *ptyManager) serveWS(w http.ResponseWriter, r *http.Request, id string) { diff --git a/internal/web/remote.go b/internal/web/remote.go index 6c8dcd9..f6d595a 100644 --- a/internal/web/remote.go +++ b/internal/web/remote.go @@ -185,11 +185,9 @@ func (s *Server) handleRemoteListDir(w http.ResponseWriter, r *http.Request) { // remote executor at the chosen directory and rebuilds the agent (same path as // a local project switch). func (s *Server) handleRemoteBind(w http.ResponseWriter, r *http.Request) { - if s.running.Load() { - writeJSON(w, http.StatusConflict, map[string]string{"error": "agent is running, cannot switch workspace"}) - return - } - if s.switchToRemote == nil { + // No running gate: binding a remote workspace builds a NEW engine; the + // previous task keeps running in the background. + if s.newRemoteEngine == nil { writeJSON(w, http.StatusNotImplemented, map[string]string{"error": "remote workspaces are not supported"}) return } @@ -211,25 +209,25 @@ func (s *Server) handleRemoteBind(w http.ResponseWriter, r *http.Request) { remotePwd = remote.DiscoverPwd(r.Context(), pc.exec, "/root") } - // Tear down local PTYs (they belonged to the previous workspace). - s.ptyMgr.closeAll() - - ag, rec, err := s.switchToRemote(pc.exec, remotePwd) + // Snapshot the outgoing task once, then build the new engine BEFORE tearing + // anything down — a failed bind must not disrupt the current task's PTYs. + prevTaskID, curMode := "", "" + if cur := s.activeEngine(); cur != nil { + prevTaskID, curMode = cur.taskID, cur.curMode() + } + eng, err := s.buildRemoteEngine("", pc.exec, remotePwd, curMode) if err != nil { writeJSON(w, http.StatusInternalServerError, map[string]string{"error": fmt.Sprintf("failed to bind remote workspace: %v", err)}) return } + s.ptyMgr.closeForTask(prevTaskID) // outgoing task's PTYs only; others keep theirs + s.setActiveEngine(eng) label := remote.ProjectLabel(pc.exec, remotePwd) - s.mu.Lock() - s.pwd = remotePwd - s.agent = ag - s.recorder = rec - s.history = nil - s.mu.Unlock() - - s.todoStore.Update(nil) + if eng.todoStore != nil { + eng.todoStore.Update(nil) + } // Ownership of the executor has transferred to the live env; remove the // pending entry WITHOUT closing it. diff --git a/internal/web/server.go b/internal/web/server.go index b63ac1e..4ac1bfe 100644 --- a/internal/web/server.go +++ b/internal/web/server.go @@ -17,7 +17,6 @@ import ( "sort" "strings" "sync" - "sync/atomic" "time" "github.com/cloudwego/eino/adk" @@ -40,51 +39,48 @@ import ( // Server is the jcode web server. type Server struct { + // Engine is the bootstrap/active task's run state. Its fields (agent, history, + // recorder, pwd, env, tokenUsage, approvalState, handler, …) are PROMOTED onto + // Server, so existing s. accesses resolve to s.Engine. while + // there is a single active task. Per-task routing (Server.tasks) supersedes + // this promotion in a later increment. Always non-nil after NewServer. + *Engine + + // tasks holds every live task engine keyed by task id (session UUID). Wired in + // the routing increment; the bootstrap Engine above is the current de-facto + // entry until then. + tasks map[string]*Engine + tasksMu sync.RWMutex + port int host string openBrowser bool - pwd string - handler *handler.WebHandler wsBroker *WSBroker - mu sync.RWMutex - agent *adk.ChatModelAgent - history []adk.Message - running atomic.Bool - - // Cancel function for the currently running agent, protected by mu. - runCancel context.CancelFunc + // mu guards the shared-server maps and, during the single-active transition, + // the bootstrap Engine's run state (the role that moves to a per-Engine lock + // once tasks truly run in parallel). + mu sync.RWMutex // Server-level context (from Start), used for background agent work. ctx context.Context - // Active model info. - providerName string - modelName string - mode string // "build" or "plan" - // Dependencies set during initialization. - todoStore *tools.TodoStore - recorder *session.Recorder - tracer *telemetry.LangfuseTracer - env *tools.Env - cfg *config.Config - registry *model.ModelRegistry - - // createAgent rebuilds the agent (after config changes). - // Accepts provider and model names so the caller can switch models. - createAgent func(providerName, modelName string) (*adk.ChatModelAgent, error) - - // rebuildForMode re-assembles the agent for a session-mode change, swapping - // the tool/prompt axis (Plan = read-only) while reusing the live chat model. - rebuildForMode func(planMode bool) (*adk.ChatModelAgent, error) - - // switchProject changes the working directory and rebuilds the agent. - switchProject func(newPwd string) (*adk.ChatModelAgent, *session.Recorder, error) - - // switchToRemote binds the shared env to a remote SSH executor and rebuilds - // the agent (mirrors switchProject for remote targets). - switchToRemote func(executor *tools.SSHExecutor, remotePwd string) (*adk.ChatModelAgent, *session.Recorder, error) + tracer *telemetry.LangfuseTracer + cfg *config.Config + cfgMu sync.Mutex // serializes read-modify-write SaveConfig from concurrent handlers + registry *model.ModelRegistry + + // newEngine builds a fresh, fully-isolated task engine (its own env, agent, + // recorder, handler, approval state) at the given pwd/mode. This is how a new + // concurrent task — or a "switch project" — gets its run state without + // mutating any other task's. taskID is non-empty when resuming an existing + // session. nil in setup mode. + newEngine func(taskID, pwd, mode string) (*EngineConfig, error) + + // newRemoteEngine is newEngine's remote sibling: it builds a task engine bound + // to an SSH executor instead of a local pwd. + newRemoteEngine func(taskID string, executor *tools.SSHExecutor, remotePwd, mode string) (*EngineConfig, error) // remoteConns holds SSH connections established by the remote-connect wizard // that have not yet been bound to the live env (keyed by connection id). @@ -93,9 +89,6 @@ type Server struct { // PTY manager for terminal sessions. ptyMgr *ptyManager - // approvalState controls whether tool calls require approval. - approvalState *runner.ApprovalState - // skillLoader provides skill listing for slash commands. skillLoader *skills.Loader @@ -111,32 +104,17 @@ type Server struct { // mcpLogins tracks in-progress/finished OAuth logins per server name. Guarded by mu. mcpLogins map[string]*mcpLoginState - // sessionSnapshot holds the git tree hash at the start of an agent run, - // used to compute session-scoped diffs (agent changes only). - sessionSnapshot string - // wechatClient is the optional WeChat channel client. wechatClient channel.Channel - // eventHandler is the handler passed to the runner — may be a NotifyingHandler - // wrapping the WebHandler, or the WebHandler itself. - eventHandler handler.AgentEventHandler - // needsSetup is true when no providers are configured. The server starts in // setup mode and exposes setup API endpoints while blocking chat operations. needsSetup bool version string - // tokenUsage tracks per-call token totals for the agent runs, used for - // usage display (goal status, token updates). - tokenUsage *model.TokenUsage - // usageStore backs the global usage-statistics endpoint. nil falls back to // usage.Default(); tests inject a temp-dir store. usageStore *usage.Store - - // breakdownFn computes the live context-window breakdown for the active task. - breakdownFn func() usage.ContextBreakdown } // ServerConfig holds the configuration for creating a new Server. @@ -149,9 +127,9 @@ type ServerConfig struct { Agent *adk.ChatModelAgent CreateAgent func(providerName, modelName string) (*adk.ChatModelAgent, error) RebuildForMode func(planMode bool) (*adk.ChatModelAgent, error) - InitialMode string // unified startup mode string ("approval"/"plan"/"full_access") - SwitchProject func(newPwd string) (*adk.ChatModelAgent, *session.Recorder, error) - SwitchToRemote func(executor *tools.SSHExecutor, remotePwd string) (*adk.ChatModelAgent, *session.Recorder, error) + NewEngine func(taskID, pwd, mode string) (*EngineConfig, error) // factory for new concurrent task engines (local) + NewRemoteEngine func(taskID string, executor *tools.SSHExecutor, remotePwd, mode string) (*EngineConfig, error) // remote sibling of NewEngine + InitialMode string // unified startup mode string ("approval"/"plan"/"full_access") TodoStore *tools.TodoStore Recorder *session.Recorder Tracer *telemetry.LangfuseTracer @@ -182,97 +160,70 @@ func NewServer(cfg *ServerConfig) *Server { if cfg.EventHandler != nil { eh = cfg.EventHandler } - s := &Server{ - port: cfg.Port, - host: cfg.Host, - openBrowser: cfg.OpenBrowser, + // The bootstrap Engine carries the per-task run state of the initial session. + boot := &Engine{ pwd: cfg.Pwd, - version: cfg.Version, handler: h, - wsBroker: NewWSBroker(), agent: cfg.Agent, - createAgent: cfg.CreateAgent, - rebuildForMode: cfg.RebuildForMode, - switchProject: cfg.SwitchProject, - switchToRemote: cfg.SwitchToRemote, - remoteConns: newRemoteConnRegistry(), todoStore: cfg.TodoStore, recorder: cfg.Recorder, - tracer: cfg.Tracer, env: cfg.Env, providerName: cfg.ProviderName, modelName: cfg.ModelName, mode: mode.Parse(cfg.InitialMode).String(), - cfg: cfg.Config, - registry: cfg.Registry, - ptyMgr: newPTYManager(), approvalState: cfg.ApprovalState, - skillLoader: cfg.SkillLoader, - reloadMCP: cfg.ReloadMCP, - mcpStatuses: make(map[string]tools.MCPStatus), - mcpLogins: make(map[string]*mcpLoginState), - wechatClient: cfg.WechatClient, eventHandler: eh, - needsSetup: cfg.NeedsSetup, tokenUsage: cfg.TokenUsage, breakdownFn: cfg.ContextBreakdownFn, + createAgent: cfg.CreateAgent, + rebuildForMode: cfg.RebuildForMode, + } + if boot.tokenUsage == nil { + boot.tokenUsage = &model.TokenUsage{} } - if s.tokenUsage == nil { - s.tokenUsage = &model.TokenUsage{} + // The engine's identity is its recorder's session UUID; this is the task_id + // stamped on its events and the key in the tasks map. + if boot.taskID == "" && boot.recorder != nil { + boot.taskID = boot.recorder.UUID() } + s := &Server{ + Engine: boot, + tasks: make(map[string]*Engine), + port: cfg.Port, + host: cfg.Host, + openBrowser: cfg.OpenBrowser, + version: cfg.Version, + wsBroker: NewWSBroker(), + newEngine: cfg.NewEngine, + newRemoteEngine: cfg.NewRemoteEngine, + remoteConns: newRemoteConnRegistry(), + tracer: cfg.Tracer, + cfg: cfg.Config, + registry: cfg.Registry, + ptyMgr: newPTYManager(), + skillLoader: cfg.SkillLoader, + reloadMCP: cfg.ReloadMCP, + mcpStatuses: make(map[string]tools.MCPStatus), + mcpLogins: make(map[string]*mcpLoginState), + wechatClient: cfg.WechatClient, + needsSetup: cfg.NeedsSetup, + } + // The bootstrap engine is registered (and its pump started) in Start, once + // s.ctx exists. for _, st := range cfg.InitialMCPStatuses { s.mcpStatuses[st.Name] = st } - // Wire TodoStore → session recording. - // The callback always accesses s.recorder (protected by s.mu) so that - // handleNewSession / handleSwitchProject correctly use the latest recorder. - if cfg.TodoStore != nil { - cfg.TodoStore.OnUpdate = func(items []tools.TodoItem) { - s.mu.RLock() - r := s.recorder - s.mu.RUnlock() - // Only record into a session that already has real content (a user - // message created the file). Otherwise an ambient todo reset — e.g. - // clearing the previous session's todos when starting fresh — would - // be the first write, creating + indexing a phantom empty session. - if r != nil && r.HasRecording() { - snapItems := make([]session.TodoSnapshotItem, len(items)) - for i, it := range items { - snapItems[i] = session.TodoSnapshotItem{ - ID: it.ID, Title: it.Title, Status: string(it.Status), - } - } - r.RecordTodoSnapshot(snapItems) - } - } - } - - // Wire GoalStore → session recording, mirroring the TodoStore wiring above. - if cfg.Env != nil && cfg.Env.GoalStore != nil { - cfg.Env.GoalStore.OnUpdate = func(g *tools.Goal) { - s.mu.RLock() - r := s.recorder - s.mu.RUnlock() - // Same guard as the todo hook: a goal change must not be the first - // write that creates + indexes an otherwise-empty session (e.g. - // clearing the previous session's goal on reset). Always emit to the - // UI, but only persist once the session has real content. - if r != nil && r.HasRecording() { - tools.GoalRecorderHook(r)(g) - } - if s.handler != nil { - s.handler.Emit("goal_update", g) - } - } - } + // TodoStore/GoalStore → recorder/handler wiring is done PER TASK in the engine + // factory (command.buildWebTask), so each engine binds its OWN recorder and + // handler. (The bootstrap engine is built by that same factory.) return s } // Handler returns the underlying WebHandler for external wiring (e.g. approval routing). func (s *Server) Handler() *handler.WebHandler { - return s.handler + return s.activeHandler() } // Start starts the web server. Blocks until context is cancelled. @@ -331,6 +282,7 @@ func (s *Server) Start(ctx context.Context) error { mux.HandleFunc("GET /api/slash-commands", s.handleSlashCommands) mux.HandleFunc("GET /api/browse", s.handleBrowse) mux.HandleFunc("POST /api/project/switch", s.handleSwitchProject) + mux.HandleFunc("POST /api/project/validate", s.handleValidatePaths) mux.HandleFunc("POST /api/pty", s.handleCreatePTY) mux.HandleFunc("GET /api/pty", s.handleListPTY) mux.HandleFunc("DELETE /api/pty/{id}", s.handleKillPTY) @@ -378,8 +330,11 @@ func (s *Server) Start(ctx context.Context) error { Handler: corsHandler, } - // Forward WebHandler events to WebSocket clients. - go s.forwardEvents() + // Register the bootstrap engine (adds it to the tasks map + starts its event + // pump). New task engines register themselves on creation. + if s.Engine != nil { + _ = s.registerEngine(s.Engine) + } // Graceful shutdown on context cancellation. go func() { @@ -412,24 +367,16 @@ func (s *Server) Start(ctx context.Context) error { return nil } -// forwardEvents reads from the WebHandler event channel and broadcasts to WebSocket clients. -func (s *Server) forwardEvents() { - for ev := range s.handler.Events() { - s.wsBroker.Broadcast(WSEvent{ - Type: ev.Event, - Data: ev.Data, - }) - } -} - // --- API Handlers --- -// currentModelSupportsImage checks if the currently selected model supports image input. -func (s *Server) currentModelSupportsImage() bool { - if s.registry == nil { +// currentModelSupportsImage checks if the given engine's selected model supports +// image input. +func (s *Server) currentModelSupportsImage(eng *Engine) bool { + if s.registry == nil || eng == nil { return false } - _, m, ok := s.registry.LookupModel(s.providerName, s.modelName) + provider, mdl, _ := eng.modelSnapshot() + _, m, ok := s.registry.LookupModel(provider, mdl) if !ok || m == nil || m.Modalities == nil { return false } @@ -442,18 +389,17 @@ func (s *Server) currentModelSupportsImage() bool { } func (s *Server) handleHealth(w http.ResponseWriter, r *http.Request) { - s.mu.RLock() - sessionID := "" - if s.recorder != nil { - sessionID = s.recorder.UUID() - } - s.mu.RUnlock() + eng := s.activeEngine() - if s.needsSetup { + if s.needsSetup || eng == nil { + pwd := "" + if eng != nil { + pwd = eng.pwd + } writeJSON(w, http.StatusOK, map[string]any{ "status": "needs_setup", "version": s.version, - "pwd": s.pwd, + "pwd": pwd, "provider": "", "model": "", "mode": "build", @@ -464,53 +410,61 @@ func (s *Server) handleHealth(w http.ResponseWriter, r *http.Request) { return } + provider, mdl, modeStr := eng.modelSnapshot() writeJSON(w, http.StatusOK, map[string]any{ "status": "ok", "version": s.version, - "pwd": s.pwd, - "provider": s.providerName, - "model": s.modelName, - "mode": s.mode, - "session_id": sessionID, - "running": s.running.Load(), - "image_support": s.currentModelSupportsImage(), + "pwd": eng.pwd, + "provider": provider, + "model": mdl, + "mode": modeStr, + "session_id": eng.recUUID(), + "running": eng.running.Load(), + "image_support": s.currentModelSupportsImage(eng), }) } func (s *Server) handleStatus(w http.ResponseWriter, r *http.Request) { - full := s.tokenUsage.GetFull() + eng := s.activeEngine() + if eng == nil { + writeJSON(w, http.StatusServiceUnavailable, map[string]string{"error": "no active task"}) + return + } + full := eng.tokenUsage.GetFull() + provider, mdl, modeStr := eng.modelSnapshot() writeJSON(w, http.StatusOK, map[string]any{ - "running": s.running.Load(), + "running": eng.running.Load(), "ws_clients": s.wsBroker.ClientCount(), - "pwd": s.pwd, - "provider": s.providerName, - "model": s.modelName, - "mode": s.mode, + "pwd": eng.pwd, + "provider": provider, + "model": mdl, + "mode": modeStr, // Live token snapshot so a client reconnecting between turns can render // the context bar + cache hit rate without waiting for the next // token_update WS event. total_tokens = current context occupancy. "token": map[string]any{ - "total_tokens": s.tokenUsage.GetLastTotal(), + "total_tokens": eng.tokenUsage.GetLastTotal(), "prompt_tokens": full.PromptTokens, "completion_tokens": full.CompletionTokens, "cached_tokens": full.CachedTokens, "reasoning_tokens": full.ReasoningTokens, "cache_write_tokens": full.CacheWriteTokens, "call_count": full.CallCount, - "cache_hit_rate": s.tokenUsage.CacheHitRate(), - "cache_supported": s.tokenUsage.CacheObserved(), - "model_context_limit": s.currentModelContextLimit(), + "cache_hit_rate": eng.tokenUsage.CacheHitRate(), + "cache_supported": eng.tokenUsage.CacheObserved(), + "model_context_limit": s.currentModelContextLimit(eng), }, }) } -// currentModelContextLimit resolves the context window of the currently +// currentModelContextLimit resolves the context window of the given engine's // selected model, or 0 if unknown. -func (s *Server) currentModelContextLimit() int { - if s.registry == nil || s.cfg == nil { +func (s *Server) currentModelContextLimit(eng *Engine) int { + if s.registry == nil || s.cfg == nil || eng == nil { return 0 } - return model.ResolveContextLimit(s.registry, s.cfg, s.providerName, s.modelName) + provider, mdl, _ := eng.modelSnapshot() + return model.ResolveContextLimit(s.registry, s.cfg, provider, mdl) } // handleWorkspace returns lightweight git workspace info (branch + dirty) for @@ -523,12 +477,12 @@ func (s *Server) handleWorkspace(w http.ResponseWriter, r *http.Request) { // initialised repo with no commits still reports its unborn branch (e.g. // "main") instead of the literal "HEAD". branchCmd := exec.CommandContext(r.Context(), "git", "branch", "--show-current") - branchCmd.Dir = s.pwd + branchCmd.Dir = s.activePwd() branchOut, _ := branchCmd.Output() branch := strings.TrimSpace(string(branchOut)) statusCmd := exec.CommandContext(r.Context(), "git", "status", "--porcelain") - statusCmd.Dir = s.pwd + statusCmd.Dir = s.activePwd() statusOut, _ := statusCmd.Output() dirty := strings.TrimSpace(string(statusOut)) != "" @@ -544,42 +498,60 @@ func (s *Server) handleChat(w http.ResponseWriter, r *http.Request) { return } - // Use CompareAndSwap to atomically check and set running, preventing - // two concurrent requests from both entering submitMessage. - if !s.running.CompareAndSwap(false, true) { - writeJSON(w, http.StatusConflict, map[string]string{ - "error": "agent is already processing a request", - }) - return - } - // running is now true; submitMessage will proceed without re-setting it. - var req struct { Message string `json:"message"` Images []chatImage `json:"images,omitempty"` // optional: base64-encoded images Mode string `json:"mode,omitempty"` // "build" or "plan" - SessionID string `json:"session_id,omitempty"` // optional: continue existing session + SessionID string `json:"session_id,omitempty"` // optional: the task (session) to run } if err := json.NewDecoder(io.LimitReader(r.Body, 20<<20)).Decode(&req); err != nil { - s.running.Store(false) writeJSON(w, http.StatusBadRequest, map[string]string{"error": "invalid request body"}) return } if strings.TrimSpace(req.Message) == "" { - s.running.Store(false) writeJSON(w, http.StatusBadRequest, map[string]string{"error": "message is required"}) return } - mode := req.Mode - if mode == "" { - mode = s.mode + modeStr := req.Mode + if modeStr == "" { + modeStr = s.activeMode() } - sessionID := s.submitMessage(req.Message, mode, "", req.SessionID, req.Images) + // Resolve (or lazily create) the engine for this task. Different tasks run + // concurrently; the per-task running flag only blocks double-running the SAME + // task. + eng, err := s.engineForChat(req.SessionID, modeStr) + if err != nil { + writeJSON(w, http.StatusInternalServerError, map[string]string{"error": err.Error()}) + return + } + if !eng.running.CompareAndSwap(false, true) { + writeJSON(w, http.StatusConflict, map[string]string{ + "error": "this task is already processing a request", + }) + return + } + + sessionID := s.submitMessage(eng, req.Message, modeStr, "", req.SessionID, req.Images) writeJSON(w, http.StatusAccepted, map[string]string{"status": "processing", "session_id": sessionID}) } +// engineForChat resolves the engine a chat request targets. An empty task id (or +// one matching the active task) uses the active engine; a known live task uses +// its engine; an unknown id lazily spins up a fresh engine for it (a new task or +// the first message of a not-yet-live task), rooted at the active task's pwd. +func (s *Server) engineForChat(taskID, modeStr string) (*Engine, error) { + if eng := s.resolveEngine(taskID); eng != nil { + return eng, nil + } + pwd := "" + if a := s.activeEngine(); a != nil { + pwd = a.pwd + } + return s.buildLocalEngine(taskID, pwd, modeStr) +} + // chatImage represents a base64-encoded image in a chat request. type chatImage struct { Data string `json:"data"` // base64 data (without data: prefix) @@ -589,10 +561,14 @@ type chatImage struct { // SubmitMessage submits a message for agent processing from an external source // (e.g. WeChat inbound message). Returns false if the agent is busy. func (s *Server) SubmitMessage(message, source string) bool { - if !s.running.CompareAndSwap(false, true) { + eng := s.activeEngine() + if eng == nil { return false } - s.submitMessage(message, s.mode, source, "", nil) + if !eng.running.CompareAndSwap(false, true) { + return false + } + s.submitMessage(eng, message, eng.curMode(), source, "", nil) return true } @@ -602,9 +578,9 @@ func (s *Server) SubmitMessage(message, source string) bool { // continuity — if the current recorder has a different UUID, resume the // correct session instead of creating a new one. // images is an optional list of base64-encoded images to include in the message. -// The caller must have already set s.running to true (via CompareAndSwap). +// The caller must have already set eng.running to true (via CompareAndSwap). // Returns the session_id of the recorder used. -func (s *Server) submitMessage(message, mode, source, sessionID string, images []chatImage) string { +func (s *Server) submitMessage(eng *Engine, message, mode, source, sessionID string, images []chatImage) string { // Slash command rewrite: if the original message starts with "/", check for // skill slash commands and rewrite to load_skill instruction (same pattern as // ACP/TUI). This must happen BEFORE the plan-mode prefix is applied, otherwise @@ -639,7 +615,7 @@ func (s *Server) submitMessage(message, mode, source, sessionID string, images [ // Emit user_message event for external sources (e.g. WeChat) so web clients see it. // Web-originated messages are already added by the frontend's sendMessage(). if source != "" { - s.handler.Emit("user_message", map[string]string{ + eng.handler.Emit("user_message", map[string]string{ "content": message, "source": source, }) @@ -648,23 +624,23 @@ func (s *Server) submitMessage(message, mode, source, sessionID string, images [ // Ensure a recorder exists (lazy creation on first message). // If the client provided a session_id and the current recorder differs, // resume the client's session to prevent creating a duplicate. - s.mu.Lock() - if s.recorder == nil { - rec, _ := session.NewRecorder(s.pwd, s.providerName, s.modelName) + eng.emu.Lock() + if eng.recorder == nil { + rec, _ := session.NewRecorder(eng.pwd, eng.providerName, eng.modelName) if sessionID != "" { rec.SetUUID(sessionID) } - s.recorder = rec - } else if sessionID != "" && s.recorder.UUID() != sessionID { + eng.recorder = rec + } else if sessionID != "" && eng.recorder.UUID() != sessionID { // Client is continuing a session that doesn't match the current recorder. // Resume the client's session to keep all messages together. - s.recorder.Close() - rec, _ := session.NewRecorder(s.pwd, s.providerName, s.modelName) + eng.recorder.Close() + rec, _ := session.NewRecorder(eng.pwd, eng.providerName, eng.modelName) rec.SetUUID(sessionID) - s.recorder = rec + eng.recorder = rec } - recorder := s.recorder - s.mu.Unlock() + recorder := eng.recorder + eng.emu.Unlock() // Record user message. if recorder != nil { @@ -707,35 +683,38 @@ func (s *Server) submitMessage(message, mode, source, sessionID string, images [ userMsg = schema.UserMessage(agentMsg) } - s.mu.Lock() - s.history = append(s.history, userMsg) - history := make([]adk.Message, len(s.history)) - copy(history, s.history) - agent := s.agent - s.mu.Unlock() + eng.emu.Lock() + eng.history = append(eng.history, userMsg) + history := make([]adk.Message, len(eng.history)) + copy(history, eng.history) + agent := eng.agent + eng.emu.Unlock() - // Stream response via WebSocket — run agent in background. + // Stream response via WebSocket — run agent in background. Each task derives + // its own cancellable context so /stop cancels only that task. runCtx, runCancel := context.WithCancel(s.ctx) - s.mu.Lock() - s.runCancel = runCancel - s.mu.Unlock() + eng.emu.Lock() + eng.runCancel = runCancel + eng.emu.Unlock() go func() { + s.setTaskStatus(eng, true) defer func() { - s.running.Store(false) - s.mu.Lock() - s.runCancel = nil - s.mu.Unlock() + eng.running.Store(false) + eng.emu.Lock() + eng.runCancel = nil + eng.emu.Unlock() + s.setTaskStatus(eng, false) }() // Take a git snapshot before the agent run for session diff tracking. - s.takeSessionSnapshot() + s.takeSessionSnapshot(eng) - resp := runner.Run(runCtx, agent, history, s.eventHandler, recorder, s.todoStore, s.env.GoalStore, s.tracer, s.tokenUsage) + resp := runner.Run(runCtx, agent, history, eng.eventHandler, recorder, eng.todoStore, eng.env.GoalStore, s.tracer, eng.tokenUsage) if resp != "" { - s.mu.Lock() - s.history = append(s.history, &schema.Message{Role: schema.Assistant, Content: resp}) - s.mu.Unlock() + eng.emu.Lock() + eng.history = append(eng.history, &schema.Message{Role: schema.Assistant, Content: resp}) + eng.emu.Unlock() } }() @@ -751,10 +730,22 @@ func (s *Server) handleListAllTasks(w http.ResponseWriter, r *http.Request) { writeJSON(w, http.StatusInternalServerError, map[string]string{"error": err.Error()}) return } + // Snapshot which task ids are currently running (live engines) so the sidebar + // can show a running indicator even on a fresh page load. + running := make(map[string]bool) + s.tasksMu.RLock() + for id, e := range s.tasks { + if e != nil && e.running.Load() { + running[id] = true + } + } + s.tasksMu.RUnlock() + type taskItem struct { UUID string `json:"uuid"` Project string `json:"project"` CreatedAt string `json:"created_at"` + UpdatedAt string `json:"updated_at,omitempty"` Provider string `json:"provider"` Model string `json:"model"` Title string `json:"title,omitempty"` @@ -762,6 +753,7 @@ func (s *Server) handleListAllTasks(w http.ResponseWriter, r *http.Request) { Archived bool `json:"archived"` Unread bool `json:"unread"` Status string `json:"status,omitempty"` + Running bool `json:"running"` } items := make([]taskItem, 0) for project, metas := range all { @@ -770,6 +762,7 @@ func (s *Server) handleListAllTasks(w http.ResponseWriter, r *http.Request) { UUID: m.UUID, Project: project, CreatedAt: m.StartTime, + UpdatedAt: m.UpdatedAt, Provider: m.Provider, Model: m.Model, Title: m.Title, @@ -777,6 +770,7 @@ func (s *Server) handleListAllTasks(w http.ResponseWriter, r *http.Request) { Archived: m.Archived, Unread: m.Unread, Status: m.Status, + Running: running[m.UUID], }) } } @@ -824,7 +818,7 @@ func (s *Server) handleUpdateTask(w http.ResponseWriter, r *http.Request) { } func (s *Server) handleListSessions(w http.ResponseWriter, r *http.Request) { - metas, err := session.ListSessions(s.pwd) + metas, err := session.ListSessions(s.activePwd()) if err != nil { writeJSON(w, http.StatusInternalServerError, map[string]string{"error": err.Error()}) return @@ -868,8 +862,38 @@ func (s *Server) handleDeleteSession(w http.ResponseWriter, r *http.Request) { writeJSON(w, http.StatusBadRequest, map[string]string{"error": "session id is required"}) return } + // Tear down the live engine for this task (if any) so its run is cancelled and + // resources reclaimed. The active foreground engine is left in place — but its + // recorder is reset to a fresh session so post-delete writes don't land in the + // now-unlinked file (silent data loss). + if eng := s.resolveEngine(id); eng != nil { + eng.emu.Lock() + cancel := eng.runCancel + eng.emu.Unlock() + if cancel != nil { + cancel() + } + if eng != s.activeEngine() { + s.deleteEngine(id) + } else { + // Active task: wait for the cancelled run to drain so its final + // RecordAssistant/usage writes land before we close + reset the recorder + // (a post-close write would re-create and truncate the file). + for i := 0; i < 200 && eng.running.Load(); i++ { + time.Sleep(5 * time.Millisecond) + } + eng.emu.Lock() + if eng.recorder != nil && eng.recorder.UUID() == id { + eng.recorder.Close() + eng.recorder = nil + eng.history = nil + } + eng.emu.Unlock() + } + } + // Resolve the owning project across all projects: a task deleted from the - // sidebar tree may not belong to the active project (s.pwd). + // sidebar tree may not belong to the active project. if _, err := session.DeleteSessionByUUID(id); err != nil { writeJSON(w, http.StatusInternalServerError, map[string]string{"error": err.Error()}) return @@ -878,7 +902,12 @@ func (s *Server) handleDeleteSession(w http.ResponseWriter, r *http.Request) { } func (s *Server) handleTruncateHistory(w http.ResponseWriter, r *http.Request) { - if s.running.Load() { + eng := s.activeEngine() + if eng == nil { + writeJSON(w, http.StatusServiceUnavailable, map[string]string{"error": "no active task"}) + return + } + if eng.running.Load() { writeJSON(w, http.StatusConflict, map[string]string{"error": "agent is currently running"}) return } @@ -894,15 +923,15 @@ func (s *Server) handleTruncateHistory(w http.ResponseWriter, r *http.Request) { return } - // Capture the recorder reference under the lock but do file I/O outside - // so we don't block other goroutines. - s.mu.Lock() - rec := s.recorder + // Capture the recorder under eng.emu (same lock submitMessage uses) but do + // file I/O outside the lock. + eng.emu.Lock() + rec := eng.recorder + eng.emu.Unlock() sessionID := "" if rec != nil { sessionID = rec.UUID() } - s.mu.Unlock() // Persist first — if the file rewrite fails we abort without touching // the in-memory history so state never diverges. @@ -914,13 +943,13 @@ func (s *Server) handleTruncateHistory(w http.ResponseWriter, r *http.Request) { } } - // Now truncate in-memory history. - s.mu.Lock() + // Now truncate in-memory history under eng.emu. + eng.emu.Lock() truncAt := 0 if req.BeforeUserMessage > 0 { userCount := 0 - truncAt = len(s.history) // default: keep all - for i, msg := range s.history { + truncAt = len(eng.history) // default: keep all + for i, msg := range eng.history { if msg.Role == schema.User { if userCount == req.BeforeUserMessage { truncAt = i @@ -931,11 +960,11 @@ func (s *Server) handleTruncateHistory(w http.ResponseWriter, r *http.Request) { } } if truncAt == 0 { - s.history = nil + eng.history = nil } else { - s.history = s.history[:truncAt] + eng.history = eng.history[:truncAt] } - s.mu.Unlock() + eng.emu.Unlock() writeJSON(w, http.StatusOK, map[string]any{ "status": "ok", @@ -944,124 +973,94 @@ func (s *Server) handleTruncateHistory(w http.ResponseWriter, r *http.Request) { } func (s *Server) handleNewSession(w http.ResponseWriter, r *http.Request) { - // Parse optional request body for resume session ID. + // Parse optional resume session ID + project. Creating a task no longer + // blocks on "is the agent running" — tasks run concurrently. var req struct { SessionID string `json:"session_id,omitempty"` + Pwd string `json:"pwd,omitempty"` } - _ = json.NewDecoder(io.LimitReader(r.Body, 1<<16)).Decode(&req) - - // Only block creating a brand-new session while the agent is running. - // Resuming (loading) an existing session is always allowed — the web UI - // may refresh at any time and needs to restore its view. - if req.SessionID == "" && s.running.Load() { - writeJSON(w, http.StatusConflict, map[string]string{"error": "agent is currently running"}) + // The body is optional (empty = brand-new task → EOF), but a non-empty + // malformed body should be rejected rather than creating a zero-value task. + if err := json.NewDecoder(io.LimitReader(r.Body, 1<<16)).Decode(&req); err != nil && err != io.EOF { + writeJSON(w, http.StatusBadRequest, map[string]string{"error": "invalid request body"}) return } - // When resuming while running, skip recorder/history replacement — just - // return the requested session_id so the frontend can populate the UI - // from the session entries (which it already fetched via GET). - if req.SessionID != "" && s.running.Load() { - writeJSON(w, http.StatusOK, map[string]any{ - "status": "ok", - "session_id": req.SessionID, - }) - return + // Already-live task: just focus it (do not disturb its run). + if req.SessionID != "" { + if eng := s.resolveEngine(req.SessionID); eng != nil { + s.setActiveEngine(eng) + writeJSON(w, http.StatusOK, map[string]any{"status": "ok", "session_id": eng.taskID}) + return + } } - s.mu.Lock() - // Close the old recorder. - if s.recorder != nil { - s.recorder.Close() - s.recorder = nil + if s.newEngine == nil { + writeJSON(w, http.StatusServiceUnavailable, map[string]string{"error": "task creation is not supported"}) + return } - // Only create a recorder when resuming an existing session. - // For brand-new conversations the recorder is created lazily in - // submitMessage on the first actual user message, which avoids - // persisting empty sessions. - if req.SessionID != "" { - rec, _ := session.NewRecorder(s.pwd, s.providerName, s.modelName) - if rec != nil { - rec.SetUUID(req.SessionID) - s.recorder = rec + // Each new/resumed task gets its OWN engine (env, agent, recorder, handler), + // so it runs independently of every other task. + pwd := req.Pwd + if pwd == "" { + if a := s.activeEngine(); a != nil { + pwd = a.pwd } } + eng, err := s.buildLocalEngine(req.SessionID, pwd, s.activeMode()) + if err != nil { + writeJSON(w, http.StatusInternalServerError, map[string]string{"error": err.Error()}) + return + } - // Prepare todo items while holding the lock, but apply them after unlocking - // to avoid deadlock: todoStore.Update → OnUpdate → s.mu.RLock. - var updateTodos bool - var todoItems []tools.TodoItem - var resuming bool - var goalSnap *session.GoalSnapshot + // Resume: hydrate the fresh engine with the persisted conversation/todos/goal. if req.SessionID != "" { - resuming = true - // Resuming: load prior conversation into history so the agent has context. - entries, _ := session.LoadSession(req.SessionID) - // Reconstruct full state (message history, todos, goal). + entries, lerr := session.LoadSession(req.SessionID) + if lerr != nil { + // Stale/nonexistent session id: don't silently register a phantom empty + // engine under it — tear the just-built engine down and report not-found. + s.deleteEngine(eng.taskID) + writeJSON(w, http.StatusNotFound, map[string]string{"error": "session not found"}) + return + } st := session.ReconstructState(entries) - s.history = st.History - goalSnap = st.Goal - // Always update (an empty list clears leftovers from the previous session). - if s.todoStore != nil { - updateTodos = true - todoItems = make([]tools.TodoItem, len(st.Todos)) + eng.emu.Lock() + eng.history = st.History + eng.emu.Unlock() + if eng.todoStore != nil { + items := make([]tools.TodoItem, len(st.Todos)) for i, t := range st.Todos { - todoItems[i] = tools.TodoItem{ID: t.ID, Title: t.Title, Status: tools.TodoStatus(t.Status)} + items[i] = tools.TodoItem{ID: t.ID, Title: t.Title, Status: tools.TodoStatus(t.Status)} } + eng.todoStore.Update(items) } - } else { - s.history = nil - // Mark that todos should be reset. - if s.todoStore != nil { - updateTodos = true + if eng.env != nil && eng.env.GoalStore != nil { + eng.env.GoalStore.RestoreFromSnapshot(st.Goal) + if eng.handler != nil { + eng.handler.Emit("goal_update", eng.env.GoalStore.Get()) + } } } - s.mu.Unlock() - // Apply todo updates outside the lock to avoid deadlock with OnUpdate callback. - if updateTodos && s.todoStore != nil { - s.todoStore.Update(todoItems) - } + s.setActiveEngine(eng) - // Apply the session's goal state outside the lock. Restore is silent (no - // OnUpdate, so nothing is re-recorded into the session file); broadcast - // the new state to clients explicitly. A brand-new session always resets - // the store so a goal from the previous session does not leak across. - if s.env != nil && s.env.GoalStore != nil { - if resuming { - s.env.GoalStore.RestoreFromSnapshot(goalSnap) - } else { - s.env.GoalStore.Restore(nil) - } - if s.handler != nil { - s.handler.Emit("goal_update", s.env.GoalStore.Get()) - } - } - - // Notify clients. When resuming an existing session, do NOT broadcast session_reset - // (which would wipe the UI that the frontend is about to repopulate from history). + // Brand-new task: tell its view to start clean. if req.SessionID == "" { - s.wsBroker.Broadcast(WSEvent{Type: "session_reset", Data: map[string]string{}}) + s.wsBroker.Broadcast(WSEvent{TaskID: eng.taskID, Type: "session_reset", Data: map[string]string{}}) } - s.mu.RLock() - newSessionID := "" - if s.recorder != nil { - newSessionID = s.recorder.UUID() - } - s.mu.RUnlock() - - writeJSON(w, http.StatusOK, map[string]any{ - "status": "ok", - "session_id": newSessionID, - }) + writeJSON(w, http.StatusOK, map[string]any{"status": "ok", "session_id": eng.taskID}) } func (s *Server) handleListModels(w http.ResponseWriter, r *http.Request) { + curProvider, curModel := "", "" + if eng := s.activeEngine(); eng != nil { + curProvider, curModel, _ = eng.modelSnapshot() + } if s.registry == nil || s.cfg == nil { writeJSON(w, http.StatusOK, map[string]any{ - "current": map[string]string{"provider": s.providerName, "model": s.modelName}, + "current": map[string]string{"provider": curProvider, "model": curModel}, "providers": []any{}, }) return @@ -1124,16 +1123,20 @@ func (s *Server) handleListModels(w http.ResponseWriter, r *http.Request) { } writeJSON(w, http.StatusOK, map[string]any{ - "current": map[string]string{"provider": s.providerName, "model": s.modelName}, + "current": map[string]string{"provider": curProvider, "model": curModel}, "providers": result, }) } func (s *Server) handleSwitchModel(w http.ResponseWriter, r *http.Request) { - if s.running.Load() { - writeJSON(w, http.StatusConflict, map[string]string{"error": "agent is currently running"}) + eng := s.activeEngine() + if eng == nil { + writeJSON(w, http.StatusServiceUnavailable, map[string]string{"error": "no active task"}) return } + // No running gate: applyModelSwitch swaps eng.agent under eng.emu (the lock the + // run reads it under), so a mid-run switch is safe and takes effect next turn — + // consistent with mode/approval switching. var req struct { Provider string `json:"provider"` @@ -1148,18 +1151,14 @@ func (s *Server) handleSwitchModel(w http.ResponseWriter, r *http.Request) { return } - ag, err := s.createAgent(req.Provider, req.Model) + // Rebuild THIS task's agent for the new model and swap it in under eng.emu + // (the same lock submitMessage uses to read the agent). Keep history. + ag, err := eng.createAgent(req.Provider, req.Model) if err != nil { writeJSON(w, http.StatusInternalServerError, map[string]string{"error": err.Error()}) return } - - s.mu.Lock() - s.agent = ag - s.providerName = req.Provider - s.modelName = req.Model - // Keep history — allow continuing the conversation with a different model. - s.mu.Unlock() + eng.applyModelSwitch(ag, req.Provider, req.Model) // Track in recent models. if state, err := config.LoadModelState(); err == nil { @@ -1167,8 +1166,7 @@ func (s *Server) handleSwitchModel(w http.ResponseWriter, r *http.Request) { _ = config.SaveModelState(state) } - // Notify clients. - s.wsBroker.Broadcast(WSEvent{Type: "model_changed", Data: map[string]string{ + s.wsBroker.Broadcast(WSEvent{Type: "model_changed", TaskID: eng.taskID, Data: map[string]string{ "provider": req.Provider, "model": req.Model, }}) @@ -1193,24 +1191,34 @@ func (s *Server) handleSwitchMode(w http.ResponseWriter, r *http.Request) { } sm := mode.Parse(req.Mode) - // Lock around the agent rebuild + mode/approval mutation so we don't race an - // in-flight submitMessage that reads s.agent under the same lock. The new - // agent (and tool set) takes effect on the next run, like TUI/ACP. - s.mu.Lock() - s.mode = sm.String() - if s.approvalState != nil { - s.approvalState.SetSessionMode(sm) // approval axis (Full access → auto) + eng := s.activeEngine() + if eng == nil { + writeJSON(w, http.StatusServiceUnavailable, map[string]string{"error": "no active task"}) + return } - if s.rebuildForMode != nil { - if ag, err := s.rebuildForMode(sm.IsPlan()); err == nil { - s.agent = ag // tool/prompt axis (Plan → read-only) - } else { + // No running gate: applyModeSwitch writes eng.agent under eng.emu, the same + // lock submitMessage reads it under, so a mid-run switch is safe and simply + // takes effect on the next turn (matching TUI/ACP and the "Allow all" path). + + // Rebuild this task's agent FIRST. If the rebuild fails, abort without + // changing the mode/approval axis — otherwise plan mode could be reported while + // a write-capable agent stays live. + var newAg *adk.ChatModelAgent + if eng.rebuildForMode != nil { + ag, err := eng.rebuildForMode(sm.IsPlan()) + if err != nil { config.Logger().Printf("[web] mode switch agent rebuild error: %v", err) + writeJSON(w, http.StatusInternalServerError, map[string]string{"error": "failed to switch mode"}) + return } + newAg = ag } - s.mu.Unlock() + if eng.approvalState != nil { + eng.approvalState.SetSessionMode(sm) // approval axis (Full access → auto) + } + eng.applyModeSwitch(sm.String(), newAg) - s.wsBroker.Broadcast(WSEvent{Type: "mode_changed", Data: map[string]string{ + s.wsBroker.Broadcast(WSEvent{Type: "mode_changed", TaskID: eng.taskID, Data: map[string]string{ "mode": sm.String(), }}) @@ -1233,27 +1241,29 @@ func (s *Server) handleGetConfig(w http.ResponseWriter, r *http.Request) { } func (s *Server) handleGetTodos(w http.ResponseWriter, r *http.Request) { - if s.todoStore == nil { + eng := s.activeEngine() + if eng == nil || eng.todoStore == nil { writeJSON(w, http.StatusOK, []any{}) return } - items := s.todoStore.Items() - writeJSON(w, http.StatusOK, items) + writeJSON(w, http.StatusOK, eng.todoStore.Items()) } // handleGetGoal returns the current session goal (or null when none is set). func (s *Server) handleGetGoal(w http.ResponseWriter, _ *http.Request) { - if s.env == nil || s.env.GoalStore == nil { + eng := s.activeEngine() + if eng == nil || eng.env == nil || eng.env.GoalStore == nil { writeJSON(w, http.StatusOK, nil) return } - writeJSON(w, http.StatusOK, s.env.GoalStore.Get()) + writeJSON(w, http.StatusOK, eng.env.GoalStore.Get()) } // handleSetGoal sets (or replaces) the session goal. Unless start=false, it also // kicks off an agent run so work begins immediately. func (s *Server) handleSetGoal(w http.ResponseWriter, r *http.Request) { - if s.env == nil || s.env.GoalStore == nil { + eng := s.activeEngine() + if eng == nil || eng.env == nil || eng.env.GoalStore == nil { writeJSON(w, http.StatusServiceUnavailable, map[string]string{"error": "goals not available"}) return } @@ -1270,13 +1280,14 @@ func (s *Server) handleSetGoal(w http.ResponseWriter, r *http.Request) { writeJSON(w, http.StatusBadRequest, map[string]string{"error": err.Error()}) return } - g := s.env.GoalStore.Set(objective) + g := eng.env.GoalStore.Set(objective) if req.Start == nil || *req.Start { // Start working immediately when idle; if busy, the continuation guard - // will pick the goal up after the current run finishes. - if s.running.CompareAndSwap(false, true) { - s.submitMessage(tools.GoalKickoffPrompt(objective), s.mode, "", "", nil) + // will pick the goal up after the current run finishes. Targets the active + // task. + if eng.running.CompareAndSwap(false, true) { + s.submitMessage(eng, tools.GoalKickoffPrompt(objective), eng.curMode(), "", "", nil) } } writeJSON(w, http.StatusOK, g) @@ -1284,8 +1295,8 @@ func (s *Server) handleSetGoal(w http.ResponseWriter, r *http.Request) { // handleClearGoal removes the session goal. func (s *Server) handleClearGoal(w http.ResponseWriter, _ *http.Request) { - if s.env != nil && s.env.GoalStore != nil { - s.env.GoalStore.Clear() + if eng := s.activeEngine(); eng != nil && eng.env != nil && eng.env.GoalStore != nil { + eng.env.GoalStore.Clear() } writeJSON(w, http.StatusOK, map[string]string{"status": "cleared"}) } @@ -1293,6 +1304,7 @@ func (s *Server) handleClearGoal(w http.ResponseWriter, _ *http.Request) { func (s *Server) handleApproval(w http.ResponseWriter, r *http.Request) { var req struct { ID string `json:"id"` + TaskID string `json:"task_id"` Approved bool `json:"approved"` ApproveAll bool `json:"approve_all"` } @@ -1300,15 +1312,22 @@ func (s *Server) handleApproval(w http.ResponseWriter, r *http.Request) { writeJSON(w, http.StatusBadRequest, map[string]string{"error": "invalid request"}) return } - if err := s.handler.ResolveApproval(req.ID, req.Approved, req.ApproveAll); err != nil { + // Route the resolve to the requesting task's handler. resolveEngine maps an + // empty task_id to the active task (legacy clients) but a NON-empty unknown id + // to nil — so a stray id can't resolve against the active task's handler-local + // approval ids. + reng := s.resolveEngine(req.TaskID) + if reng == nil || reng.handler == nil { + writeJSON(w, http.StatusNotFound, map[string]string{"error": "no such task"}) + return + } + if err := reng.handler.ResolveApproval(req.ID, req.Approved, req.ApproveAll); err != nil { writeJSON(w, http.StatusNotFound, map[string]string{"error": err.Error()}) return } - // "Allow all" promotes the session to auto-approve (the runner flips its - // ApprovalState on resolve). Mirror that into the user-facing selector: keep - // s.mode in sync (so /api/health reports it) and broadcast mode_changed so the - // chat composer's Ask-for-approval / Full-access pill updates to Full access. - s.syncModeAfterApproval(req.Approved, req.ApproveAll) + // "Allow all" promotes that task to auto-approve (the runner flips its + // ApprovalState on resolve). Mirror it onto that task's mode + selector. + s.syncModeAfterApproval(reng, req.Approved, req.ApproveAll) writeJSON(w, http.StatusOK, map[string]string{"status": "ok"}) } @@ -1317,15 +1336,13 @@ func (s *Server) handleApproval(w http.ResponseWriter, r *http.Request) { // (or a deny) leaves the mode untouched. The runner's ApprovalState is the // source of truth for the approval axis; this only projects it onto the unified // selector the frontend renders. -func (s *Server) syncModeAfterApproval(approved, approveAll bool) { - if !approved || !approveAll { +func (s *Server) syncModeAfterApproval(eng *Engine, approved, approveAll bool) { + if !approved || !approveAll || eng == nil { return } sm := mode.FullAccess - s.mu.Lock() - s.mode = sm.String() - s.mu.Unlock() - s.wsBroker.Broadcast(WSEvent{Type: "mode_changed", Data: map[string]string{ + eng.applyModeSwitch(sm.String(), nil) + s.wsBroker.Broadcast(WSEvent{Type: "mode_changed", TaskID: eng.taskID, Data: map[string]string{ "mode": sm.String(), }}) } @@ -1334,8 +1351,15 @@ func (s *Server) syncModeAfterApproval(approved, approveAll bool) { // The frontend pulls this after rebuilding the timeline (page reload / session // resume / WS reconnect) so an in-flight approval is re-attached as a card // instead of leaving the agent blocked forever. -func (s *Server) handlePendingApproval(w http.ResponseWriter, _ *http.Request) { - writeJSON(w, http.StatusOK, s.handler.PendingApprovalRequests()) +func (s *Server) handlePendingApproval(w http.ResponseWriter, r *http.Request) { + // Empty task_id → active task; non-empty unknown → empty (don't leak another + // task's pending requests under a stray id). + eng := s.resolveEngine(r.URL.Query().Get("task_id")) + if eng == nil || eng.handler == nil { + writeJSON(w, http.StatusOK, []handler.WebApprovalRequestData{}) + return + } + writeJSON(w, http.StatusOK, eng.handler.PendingApprovalRequests()) } // handleAskUser resolves a pending ask_user request with the user's answers, @@ -1346,6 +1370,7 @@ func (s *Server) handlePendingApproval(w http.ResponseWriter, _ *http.Request) { func (s *Server) handleAskUser(w http.ResponseWriter, r *http.Request) { var req struct { ID string `json:"id"` + TaskID string `json:"task_id"` Answers []struct { QuestionHeader string `json:"question_header"` Answer string `json:"answer"` @@ -1366,7 +1391,14 @@ func (s *Server) handleAskUser(w http.ResponseWriter, r *http.Request) { }) } - if err := s.handler.ResolveAskUser(req.ID, resp); err != nil { + // Route the answer to the requesting task's handler. Empty task_id → active; + // non-empty unknown → reject (ids are handler-local). + eng := s.resolveEngine(req.TaskID) + if eng == nil || eng.handler == nil { + writeJSON(w, http.StatusNotFound, map[string]string{"error": "no such task"}) + return + } + if err := eng.handler.ResolveAskUser(req.ID, resp); err != nil { writeJSON(w, http.StatusNotFound, map[string]string{"error": err.Error()}) return } @@ -1377,21 +1409,41 @@ func (s *Server) handleAskUser(w http.ResponseWriter, r *http.Request) { // The frontend pulls this after rebuilding the timeline (page reload / session // resume) so an in-flight question is re-attached to its tool card instead of // leaving the agent blocked forever. -func (s *Server) handlePendingAskUser(w http.ResponseWriter, _ *http.Request) { - writeJSON(w, http.StatusOK, s.handler.PendingAskUserRequests()) +func (s *Server) handlePendingAskUser(w http.ResponseWriter, r *http.Request) { + eng := s.resolveEngine(r.URL.Query().Get("task_id")) + if eng == nil || eng.handler == nil { + writeJSON(w, http.StatusOK, []handler.WebAskUserRequestData{}) + return + } + writeJSON(w, http.StatusOK, eng.handler.PendingAskUserRequests()) +} + +// withinWorkspace reports whether abs is the workspace root or strictly inside +// it. Uses filepath.Rel rather than strings.HasPrefix so a sibling like /repo2 +// can't escape /repo, and an empty root rejects everything. +func withinWorkspace(root, abs string) bool { + if root == "" { + return false + } + rel, err := filepath.Rel(root, abs) + if err != nil { + return false + } + return rel == "." || (rel != ".." && !strings.HasPrefix(rel, ".."+string(filepath.Separator))) } func (s *Server) handleListFiles(w http.ResponseWriter, r *http.Request) { + pwd := s.activePwd() dir := r.URL.Query().Get("path") if dir == "" { - dir = s.pwd + dir = pwd } else if !filepath.IsAbs(dir) { - dir = filepath.Join(s.pwd, dir) + dir = filepath.Join(pwd, dir) } - // Prevent path traversal. + // Prevent path traversal / sibling escape. abs := filepath.Clean(dir) - if !strings.HasPrefix(abs, s.pwd) { + if !withinWorkspace(pwd, abs) { writeJSON(w, http.StatusBadRequest, map[string]string{"error": "invalid path"}) return } @@ -1431,14 +1483,15 @@ func (s *Server) handleReadFile(w http.ResponseWriter, r *http.Request) { return } + pwd := s.activePwd() abs := path if !filepath.IsAbs(abs) { - abs = filepath.Join(s.pwd, abs) + abs = filepath.Join(pwd, abs) } - // Prevent path traversal. + // Prevent path traversal / sibling escape. abs = filepath.Clean(abs) - if !strings.HasPrefix(abs, s.pwd) { + if !withinWorkspace(pwd, abs) { writeJSON(w, http.StatusBadRequest, map[string]string{"error": "path outside workspace"}) return } @@ -1480,7 +1533,7 @@ func (s *Server) handleExec(w http.ResponseWriter, r *http.Request) { defer cancel() cmd := exec.CommandContext(ctx, "sh", "-c", req.Command) - cmd.Dir = s.pwd + cmd.Dir = s.activePwd() output, err := cmd.CombinedOutput() exitCode := 0 @@ -1528,7 +1581,7 @@ func (s *Server) handleDiff(w http.ResponseWriter, r *http.Request) { } cmd := exec.CommandContext(s.ctx, "git", args...) - cmd.Dir = s.pwd + cmd.Dir = s.activePwd() output, _ := cmd.CombinedOutput() // Parse diff into structured entries @@ -1551,7 +1604,7 @@ func (s *Server) handleDiff(w http.ResponseWriter, r *http.Request) { case "branch": statCmd = exec.CommandContext(s.ctx, "git", "diff", "HEAD~1", "--stat", "--no-color") } - statCmd.Dir = s.pwd + statCmd.Dir = s.activePwd() _, _ = statCmd.CombinedOutput() // Parse unified diff into per-file entries @@ -1635,30 +1688,36 @@ func countDiffLines(patch string) (adds, dels int) { // takeSessionSnapshot records the current git working tree state // so that session-scoped diffs can be computed later. -func (s *Server) takeSessionSnapshot() { +func (s *Server) takeSessionSnapshot(eng *Engine) { + if eng == nil { + return + } // Use "git stash create" to get a tree-ish of the current state without // actually stashing. If there are no changes, use HEAD. cmd := exec.CommandContext(s.ctx, "git", "stash", "create") - cmd.Dir = s.pwd + cmd.Dir = eng.pwd out, err := cmd.Output() snapshot := strings.TrimSpace(string(out)) if err != nil || snapshot == "" { // No local changes — use HEAD as baseline cmd2 := exec.CommandContext(s.ctx, "git", "rev-parse", "HEAD") - cmd2.Dir = s.pwd + cmd2.Dir = eng.pwd out2, _ := cmd2.Output() snapshot = strings.TrimSpace(string(out2)) } - s.mu.Lock() - s.sessionSnapshot = snapshot - s.mu.Unlock() + eng.emu.Lock() + eng.sessionSnapshot = snapshot + eng.emu.Unlock() } // handleSessionDiff computes the diff between the session start snapshot and current state. func (s *Server) handleSessionDiff(w http.ResponseWriter, _ *http.Request) { - s.mu.RLock() - snapshot := s.sessionSnapshot - s.mu.RUnlock() + snapshot := "" + if eng := s.activeEngine(); eng != nil { + eng.emu.Lock() + snapshot = eng.sessionSnapshot + eng.emu.Unlock() + } type diffEntry struct { File string `json:"file"` @@ -1678,7 +1737,7 @@ func (s *Server) handleSessionDiff(w http.ResponseWriter, _ *http.Request) { // Diff from snapshot to current working tree cmd := exec.CommandContext(s.ctx, "git", "diff", snapshot, "--no-color") - cmd.Dir = s.pwd + cmd.Dir = s.activePwd() output, _ := cmd.CombinedOutput() var entries []diffEntry @@ -1805,17 +1864,16 @@ func (s *Server) reloadMCPAndRebuild() error { } s.mu.Unlock() } - if s.createAgent != nil && !s.needsSetup { - s.mu.RLock() - prov, mod := s.providerName, s.modelName - s.mu.RUnlock() - ag, err := s.createAgent(prov, mod) - if err != nil { - return err + if !s.needsSetup { + // Rebuild the foreground task's agent so the new MCP tools take effect. + if eng := s.activeEngine(); eng != nil && eng.createAgent != nil { + prov, mod, _ := eng.modelSnapshot() + ag, err := eng.createAgent(prov, mod) + if err != nil { + return err + } + eng.setAgent(ag) } - s.mu.Lock() - s.agent = ag - s.mu.Unlock() } return nil } @@ -2166,7 +2224,11 @@ func (s *Server) handleBrowse(w http.ResponseWriter, r *http.Request) { } func (s *Server) handleCreatePTY(w http.ResponseWriter, r *http.Request) { - id, err := s.ptyMgr.create(s.pwd) + pwd, owner := "", "" + if eng := s.activeEngine(); eng != nil { + pwd, owner = eng.pwd, eng.taskID + } + id, err := s.ptyMgr.create(pwd, owner) if err != nil { writeJSON(w, http.StatusInternalServerError, map[string]string{"error": err.Error()}) return @@ -2189,14 +2251,39 @@ func (s *Server) handlePTYWebSocket(w http.ResponseWriter, r *http.Request) { s.ptyMgr.serveWS(w, r, id) } -func (s *Server) handleSwitchProject(w http.ResponseWriter, r *http.Request) { - if s.running.Load() { - writeJSON(w, http.StatusConflict, map[string]string{ - "error": "agent is running, cannot switch project", - }) +// handleValidatePaths reports which of the given local paths no longer exist (or +// are not directories). The web UI keeps its workspace list in localStorage and +// can't stat the disk itself, so it calls this to prune dead workspaces from the +// picker instead of letting the user click one and hit "path does not exist". +// Callers send local paths only; ssh:// labels can't be stat'd here and would be +// wrongly reported missing, so they must be filtered out client-side. +func (s *Server) handleValidatePaths(w http.ResponseWriter, r *http.Request) { + var req struct { + Paths []string `json:"paths"` + } + if err := json.NewDecoder(io.LimitReader(r.Body, 1<<20)).Decode(&req); err != nil { + writeJSON(w, http.StatusBadRequest, map[string]string{"error": "invalid request body"}) return } - if s.switchProject == nil { + + missing := []string{} + for _, p := range req.Paths { + if p == "" { + continue + } + if info, err := os.Stat(p); err != nil || !info.IsDir() { + missing = append(missing, p) + } + } + + writeJSON(w, http.StatusOK, map[string]any{"missing": missing}) +} + +func (s *Server) handleSwitchProject(w http.ResponseWriter, r *http.Request) { + // No running gate: "switch project" builds a NEW independent engine and leaves + // the previous task running in the background — switching to another task while + // one is chatting is the whole point of concurrent tasks. + if s.newEngine == nil { writeJSON(w, http.StatusNotImplemented, map[string]string{ "error": "project switching is not supported", }) @@ -2222,27 +2309,30 @@ func (s *Server) handleSwitchProject(w http.ResponseWriter, r *http.Request) { return } - // Kill all PTY sessions (they were in the old directory). - s.ptyMgr.closeAll() + // Snapshot the outgoing task once, build the new engine BEFORE tearing down its + // PTYs — a failed build must not kill the current task's terminals. + prevTaskID, curMode := "", "" + if cur := s.activeEngine(); cur != nil { + prevTaskID, curMode = cur.taskID, cur.curMode() + } - // Call the switchProject callback to rebuild env, prompt, agent. - ag, rec, err := s.switchProject(req.Path) + // "Switch project" = build a fresh engine rooted at the new path and make it + // active. This replaces in-place env mutation, so no other live task's + // execution context is disturbed. + eng, err := s.buildLocalEngine("", req.Path, curMode) if err != nil { writeJSON(w, http.StatusInternalServerError, map[string]string{ "error": fmt.Sprintf("failed to switch project: %v", err), }) return } + s.ptyMgr.closeForTask(prevTaskID) // outgoing task's PTYs only + s.setActiveEngine(eng) - s.mu.Lock() - s.pwd = req.Path - s.agent = ag - s.recorder = rec - s.history = nil - s.mu.Unlock() - - // Reset todos. - s.todoStore.Update(nil) + // Reset todos for the (now empty) active task view. + if eng.todoStore != nil { + eng.todoStore.Update(nil) + } // Broadcast project change to clients. s.wsBroker.Broadcast(WSEvent{ @@ -2260,13 +2350,20 @@ func (s *Server) handleSwitchProject(w http.ResponseWriter, r *http.Request) { func (s *Server) handleGetApprovalMode(w http.ResponseWriter, r *http.Request) { autoApprove := false - if s.approvalState != nil { - autoApprove = s.approvalState.GetMode() == handler.ModeAuto + if eng := s.activeEngine(); eng != nil && eng.approvalState != nil { + autoApprove = eng.approvalState.GetMode() == handler.ModeAuto } writeJSON(w, http.StatusOK, map[string]any{"auto_approve": autoApprove}) } func (s *Server) handleSetApprovalMode(w http.ResponseWriter, r *http.Request) { + eng := s.activeEngine() + if eng == nil { + writeJSON(w, http.StatusServiceUnavailable, map[string]string{"error": "no active task"}) + return + } + // No running gate: the rebuild is emu-safe and applies next turn, consistent + // with the "Allow all" approval path which also flips full_access mid-run. var req struct { AutoApprove bool `json:"auto_approve"` } @@ -2280,35 +2377,40 @@ func (s *Server) handleSetApprovalMode(w http.ResponseWriter, r *http.Request) { if req.AutoApprove { sm = mode.FullAccess } - s.mu.Lock() - s.mode = sm.String() - if s.approvalState != nil { - s.approvalState.SetSessionMode(sm) - } - if s.rebuildForMode != nil { - if ag, err := s.rebuildForMode(false); err == nil { - s.agent = ag - } else { + // Rebuild first; abort the toggle if the rebuild fails (don't desync the + // reported mode from the live agent). + var newAg *adk.ChatModelAgent + if eng.rebuildForMode != nil { + ag, err := eng.rebuildForMode(false) + if err != nil { config.Logger().Printf("[web] approval mode agent rebuild error: %v", err) + writeJSON(w, http.StatusInternalServerError, map[string]string{"error": "failed to set approval mode"}) + return } + newAg = ag + } + if eng.approvalState != nil { + eng.approvalState.SetSessionMode(sm) } + eng.applyModeSwitch(sm.String(), newAg) // Persist as the default startup mode so the preference survives restarts — - // resolveStartupMode reads cfg.DefaultMode. This makes the Settings toggle a - // true "default", not just a one-off runtime flip. + // resolveStartupMode reads cfg.DefaultMode. cfgMu serializes the config RMW. + s.cfgMu.Lock() if s.cfg != nil { s.cfg.DefaultMode = sm.String() if err := config.SaveConfig(s.cfg); err != nil { config.Logger().Printf("[web] approval mode save config failed: %v", err) } } - s.mu.Unlock() + s.cfgMu.Unlock() s.wsBroker.Broadcast(WSEvent{ - Type: "approval_mode_changed", - Data: map[string]any{"auto_approve": req.AutoApprove}, + Type: "approval_mode_changed", + TaskID: eng.taskID, + Data: map[string]any{"auto_approve": req.AutoApprove}, }) // Also emit the unified mode event so updated clients keep their selector synced. - s.wsBroker.Broadcast(WSEvent{Type: "mode_changed", Data: map[string]string{"mode": sm.String()}}) + s.wsBroker.Broadcast(WSEvent{Type: "mode_changed", TaskID: eng.taskID, Data: map[string]string{"mode": sm.String()}}) writeJSON(w, http.StatusOK, map[string]any{"auto_approve": req.AutoApprove}) } @@ -2350,51 +2452,87 @@ func (s *Server) handleWebSocket(w http.ResponseWriter, r *http.Request) { if err := json.Unmarshal(msg, &incoming); err != nil { continue } - s.handleWSMessage(incoming) + s.handleWSMessage(client, incoming) } } -func (s *Server) handleWSMessage(msg WSIncoming) { +func (s *Server) handleWSMessage(client *WSClient, msg WSIncoming) { switch msg.Type { case "ping": - s.wsBroker.Broadcast(WSEvent{Type: "pong"}) + // Unicast the pong to the pinging client (broadcasting it woke every + // client unnecessarily). + if data, err := json.Marshal(WSEvent{Type: "pong"}); err == nil { + client.send(data) + } + case "subscribe": + var data struct { + TaskIDs []string `json:"task_ids"` + } + if json.Unmarshal(msg.Data, &data) == nil { + client.subscribe(data.TaskIDs) + } + case "unsubscribe": + var data struct { + TaskIDs []string `json:"task_ids"` + } + if json.Unmarshal(msg.Data, &data) == nil { + client.unsubscribe(data.TaskIDs) + } case "approval": var data struct { ID string `json:"id"` + TaskID string `json:"task_id"` Approved bool `json:"approved"` ApproveAll bool `json:"approve_all"` } if err := json.Unmarshal(msg.Data, &data); err != nil { return } - if err := s.handler.ResolveApproval(data.ID, data.Approved, data.ApproveAll); err != nil { + // Empty task_id → active task (legacy); non-empty unknown → drop (ids are + // handler-local and could collide with another task's). + reng := s.resolveEngine(data.TaskID) + if reng == nil || reng.handler == nil { + return + } + if err := reng.handler.ResolveApproval(data.ID, data.Approved, data.ApproveAll); err != nil { config.Logger().Printf("[ws] resolve approval failed for id=%q: %v", data.ID, err) return } // Same mode-sync as the POST path: an "allow all" over WS must also // update the selector pill the user is looking at. - s.syncModeAfterApproval(data.Approved, data.ApproveAll) + s.syncModeAfterApproval(reng, data.Approved, data.ApproveAll) } } // --- Stop handler --- func (s *Server) handleStop(w http.ResponseWriter, r *http.Request) { - if !s.running.Load() { + // Cancel only the targeted task. task_id comes via query or JSON body; absent, + // fall back to the active task (legacy clients). + taskID := r.URL.Query().Get("task_id") + if taskID == "" { + var req struct { + TaskID string `json:"task_id"` + } + _ = json.NewDecoder(io.LimitReader(r.Body, 1<<16)).Decode(&req) + taskID = req.TaskID + } + + eng := s.resolveEngine(taskID) + if eng == nil || !eng.running.Load() { writeJSON(w, http.StatusOK, map[string]string{"status": "not_running"}) return } - s.mu.RLock() - cancel := s.runCancel - s.mu.RUnlock() - + eng.emu.Lock() + cancel := eng.runCancel + eng.emu.Unlock() if cancel != nil { cancel() } - // Notify clients. - s.handler.OnAgentDone(fmt.Errorf("stopped by user")) + // Notify clients on that task's channel. + eng.handler.OnAgentDone(fmt.Errorf("stopped by user")) writeJSON(w, http.StatusOK, map[string]string{"status": "stopped"}) } @@ -2423,7 +2561,7 @@ func (s *Server) handleListSSH(w http.ResponseWriter, r *http.Request) { } current := "local" - if s.env != nil && s.env.IsRemote() { + if eng := s.activeEngine(); eng != nil && eng.env != nil && eng.env.IsRemote() { current = "ssh" } @@ -2488,7 +2626,10 @@ func (s *Server) handleToggleSkill(w http.ResponseWriter, r *http.Request) { return } - s.mu.Lock() + // cfgMu (not s.mu) serializes the cfg read-modify-write+save, so concurrent + // approval-mode / MCP / skill saves can't clobber each other in memory or on + // disk. + s.cfgMu.Lock() // Rebuild the disabled set from config. disabled := make(map[string]bool, len(s.cfg.DisabledSkills)) for _, n := range s.cfg.DisabledSkills { @@ -2506,23 +2647,21 @@ func (s *Server) handleToggleSkill(w http.ResponseWriter, r *http.Request) { sort.Strings(list) s.cfg.DisabledSkills = list if err := config.SaveConfig(s.cfg); err != nil { - s.mu.Unlock() + s.cfgMu.Unlock() writeJSON(w, http.StatusInternalServerError, map[string]string{"error": err.Error()}) return } - s.mu.Unlock() + s.cfgMu.Unlock() s.skillLoader.SetDisabled(list) - // Rebuild the agent so the system prompt (skill descriptions) and load_skill - // tool reflect the change on the next run. - if s.createAgent != nil && !s.needsSetup { - s.mu.RLock() - prov, mod := s.providerName, s.modelName - s.mu.RUnlock() - if ag, err := s.createAgent(prov, mod); err == nil { - s.mu.Lock() - s.agent = ag - s.mu.Unlock() + // Rebuild the foreground task's agent so the system prompt (skill descriptions) + // and load_skill tool reflect the change on the next run. + if !s.needsSetup { + if eng := s.activeEngine(); eng != nil && eng.createAgent != nil { + prov, mod, _ := eng.modelSnapshot() + if ag, err := eng.createAgent(prov, mod); err == nil { + eng.setAgent(ag) + } } } writeJSON(w, http.StatusOK, map[string]any{"status": "ok", "name": name, "enabled": req.Enabled}) @@ -2754,22 +2893,31 @@ func (s *Server) handleSetupComplete(w http.ResponseWriter, r *http.Request) { return } - // Create the agent with the new config. - ag, err := s.createAgent(req.Provider, req.Model) + // Create the foreground task's agent with the new config. + eng := s.activeEngine() + if eng == nil || eng.createAgent == nil { + writeJSON(w, http.StatusInternalServerError, map[string]string{"error": "no active task to configure"}) + return + } + ag, err := eng.createAgent(req.Provider, req.Model) if err != nil { writeJSON(w, http.StatusInternalServerError, map[string]string{"error": "failed to create agent: " + err.Error()}) return } - + eng.applyModelSwitch(ag, req.Provider, req.Model) + // Publish the new config + registry to the live server so endpoints + // (/api/models, context-limit, etc.) reflect the just-configured provider + // without a restart. + s.cfgMu.Lock() + s.cfg = cfg + s.registry = model.NewModelRegistryWithConfig(cfg) + s.cfgMu.Unlock() s.mu.Lock() - s.agent = ag - s.providerName = req.Provider - s.modelName = req.Model s.needsSetup = false s.mu.Unlock() // Notify clients that setup is complete. - s.wsBroker.Broadcast(WSEvent{Type: "model_changed", Data: map[string]string{ + s.wsBroker.Broadcast(WSEvent{Type: "model_changed", TaskID: eng.taskID, Data: map[string]string{ "provider": req.Provider, "model": req.Model, }}) diff --git a/internal/web/tasks_test.go b/internal/web/tasks_test.go index 6d2a19a..58bc18a 100644 --- a/internal/web/tasks_test.go +++ b/internal/web/tasks_test.go @@ -35,7 +35,7 @@ func seedIndex(t *testing.T, sessions map[string][]session.SessionMeta) { // P0-1: GET /api/workspace on a non-git directory returns empty branch + not dirty. func TestWorkspaceNonGit(t *testing.T) { - s := &Server{ctx: context.Background(), pwd: t.TempDir()} + s := &Server{Engine: &Engine{pwd: t.TempDir()}, ctx: context.Background()} rec := httptest.NewRecorder() s.handleWorkspace(rec, httptest.NewRequest(http.MethodGet, "/api/workspace", nil)) if rec.Code != http.StatusOK { @@ -163,7 +163,7 @@ func TestDeleteTaskCrossProject(t *testing.T) { "/work/other": {{UUID: "oth-1", Project: "/work/other"}, {UUID: "oth-2", Project: "/work/other"}}, }) // Active project is /work/active; delete a task in /work/other. - s := &Server{pwd: "/work/active"} + s := &Server{Engine: &Engine{pwd: "/work/active"}} rec := httptest.NewRecorder() req := httptest.NewRequest(http.MethodDelete, "/api/sessions/oth-1", nil) req.SetPathValue("id", "oth-1") diff --git a/internal/web/usage.go b/internal/web/usage.go index 9af8b18..6438235 100644 --- a/internal/web/usage.go +++ b/internal/web/usage.go @@ -84,25 +84,27 @@ func (s *Server) handleUsageStats(w http.ResponseWriter, r *http.Request) { // meaningful after the fact). func (s *Server) handleTaskStats(w http.ResponseWriter, r *http.Request) { id := r.PathValue("id") + resp := map[string]any{"uuid": id} - s.mu.RLock() - activeUUID := "" - if s.recorder != nil { - activeUUID = s.recorder.UUID() + // Resolve the live engine for this task id: by task id, or (covering a + // recorder swap) by the active engine's recorder UUID. + eng := s.resolveEngine(id) + if eng == nil { + if a := s.activeEngine(); a != nil && a.recUUID() == id { + eng = a + } } - s.mu.RUnlock() - - resp := map[string]any{"uuid": id} - if id != "" && id == activeUUID { - full := s.tokenUsage.GetFull() - last := s.tokenUsage.GetLastDetail() + // Any LIVE task engine (not just the foreground one) reports live stats. + if id != "" && eng != nil { + full := eng.tokenUsage.GetFull() + last := eng.tokenUsage.GetLastDetail() var bd usage.ContextBreakdown - if s.breakdownFn != nil { - bd = s.breakdownFn() + if eng.breakdownFn != nil { + bd = eng.breakdownFn() } - bd.ContextLimit = s.currentModelContextLimit() + bd.ContextLimit = s.currentModelContextLimit(eng) // Messages occupy whatever the last prompt held beyond the static // assembly (system prompt + tools + MCP + skills). if msg := last.PromptTokens - bd.StaticTotal(); msg > 0 { @@ -111,8 +113,8 @@ func (s *Server) handleTaskStats(w http.ResponseWriter, r *http.Request) { resp["is_active"] = true resp["context"] = bd - resp["cache_hit_rate"] = s.tokenUsage.CacheHitRate() - resp["cache_supported"] = s.tokenUsage.CacheObserved() + resp["cache_hit_rate"] = eng.tokenUsage.CacheHitRate() + resp["cache_supported"] = eng.tokenUsage.CacheObserved() resp["tokens"] = map[string]any{ "total_tokens": full.TotalTokens, "prompt_tokens": full.PromptTokens, diff --git a/internal/web/usage_test.go b/internal/web/usage_test.go index 513a81c..ebc2e47 100644 --- a/internal/web/usage_test.go +++ b/internal/web/usage_test.go @@ -129,10 +129,12 @@ func TestTaskStatsActive(t *testing.T) { tu.Add(model.AddParams{Prompt: 1000, Completion: 200, Total: 1200, Cached: 800}) s := &Server{ - recorder: rec, - tokenUsage: tu, - breakdownFn: func() usage.ContextBreakdown { - return usage.ContextBreakdown{SystemPromptTokens: 100, SystemToolsTokens: 200, MCPToolsTokens: 50, SkillsTokens: 30} + Engine: &Engine{ + recorder: rec, + tokenUsage: tu, + breakdownFn: func() usage.ContextBreakdown { + return usage.ContextBreakdown{SystemPromptTokens: 100, SystemToolsTokens: 200, MCPToolsTokens: 50, SkillsTokens: 30} + }, }, } rr := httptest.NewRecorder() @@ -177,7 +179,7 @@ func TestTaskStatsHistorical(t *testing.T) { mustRecord(t, store, usage.Event{Date: today, Session: "sess-B", Model: "m", Prompt: 999, Cached: 0, Completion: 9, Total: 1008, Calls: 1}) // No recorder → every query is treated as historical. - s := &Server{usageStore: store} + s := &Server{Engine: &Engine{}, usageStore: store} rr := httptest.NewRecorder() req := httptest.NewRequest(http.MethodGet, "/api/tasks/sess-A/stats", nil) req.SetPathValue("id", "sess-A") diff --git a/internal/web/ws.go b/internal/web/ws.go index d8c061c..4b1edec 100644 --- a/internal/web/ws.go +++ b/internal/web/ws.go @@ -15,13 +15,58 @@ type WSClient struct { sendCh chan []byte mu sync.Mutex closeCh sync.Once + + // subMu guards the task subscription set. subAll is true until the client + // sends its first `subscribe` (so legacy clients that never subscribe keep + // receiving every task's events). Once subscribed, only the listed task ids + // (plus global events, TaskID=="") are delivered, preventing a busy task from + // flooding a client that is only viewing a quiet one. + subMu sync.Mutex + subs map[string]bool + subAll bool } func newWSClient(conn *websocket.Conn) *WSClient { return &WSClient{ conn: conn, sendCh: make(chan []byte, 256), + subs: make(map[string]bool), + subAll: true, + } +} + +// subscribe replaces the client's subscription set with the given task ids +// (additive across calls). After the first call the client stops receiving +// every task and only gets its subscribed ids + global events. +func (c *WSClient) subscribe(taskIDs []string) { + c.subMu.Lock() + defer c.subMu.Unlock() + c.subAll = false + for _, id := range taskIDs { + if id != "" { + c.subs[id] = true + } + } +} + +// unsubscribe drops task ids from the client's subscription set. +func (c *WSClient) unsubscribe(taskIDs []string) { + c.subMu.Lock() + defer c.subMu.Unlock() + for _, id := range taskIDs { + delete(c.subs, id) + } +} + +// wants reports whether this client should receive an event for taskID. Global +// events (empty taskID) always pass. +func (c *WSClient) wants(taskID string) bool { + if taskID == "" { + return true } + c.subMu.Lock() + defer c.subMu.Unlock() + return c.subAll || c.subs[taskID] } func (c *WSClient) writePump() { @@ -93,6 +138,9 @@ func (b *WSBroker) Broadcast(event WSEvent) { b.mu.RLock() defer b.mu.RUnlock() for _, client := range b.clients { + if !client.wants(event.TaskID) { + continue + } client.send(data) } } @@ -118,7 +166,12 @@ func (b *WSBroker) ClientCount() int { // WSEvent is a WebSocket message envelope. type WSEvent struct { Type string `json:"type"` - Data any `json:"data,omitempty"` + // TaskID tags the event with the task (engine) it came from so the client can + // route it to the right task view, and so the broker can deliver it only to + // clients subscribed to that task. Empty for global/server-wide events + // (mcp_changed, model_changed, pong, …), which every client receives. + TaskID string `json:"task_id,omitempty"` + Data any `json:"data,omitempty"` } // WSIncoming represents a message from the client over WebSocket. diff --git a/web/package.json b/web/package.json index 0cef742..ce1422d 100644 --- a/web/package.json +++ b/web/package.json @@ -18,6 +18,7 @@ "@headlessui/tailwindcss": "^0.2.2", "@headlessui/vue": "^1.7.23", "@heroicons/vue": "^2.2.0", + "@lobehub/icons-static-svg": "^1.91.0", "@tailwindcss/typography": "^0.5.19", "@tailwindcss/vite": "^4.2.2", "@tauri-apps/api": "^2.9.0", diff --git a/web/pnpm-lock.yaml b/web/pnpm-lock.yaml index 49e4abc..edc3a53 100644 --- a/web/pnpm-lock.yaml +++ b/web/pnpm-lock.yaml @@ -17,6 +17,9 @@ importers: '@heroicons/vue': specifier: ^2.2.0 version: 2.2.0(vue@3.5.32(typescript@6.0.2)) + '@lobehub/icons-static-svg': + specifier: ^1.91.0 + version: 1.91.0 '@tailwindcss/typography': specifier: ^0.5.19 version: 0.5.19(tailwindcss@4.2.2) @@ -388,6 +391,9 @@ packages: '@jridgewell/trace-mapping@0.3.31': resolution: {integrity: sha512-zzNR+SdQSDJzc8joaeP8QQoCQr8NuYx2dIIytl1QeBEZHJ9uW6hebsrYgbz8hJwUQao3TWCMtmfV8Nu1twOLAw==} + '@lobehub/icons-static-svg@1.91.0': + resolution: {integrity: sha512-ZDflEq0uUvAkH4WK4h3qNvvY09ts4OqUb5azD7A0xKfcuYhffGwB1Q/As2RguZYq4Gh4v925CJ8iodiClzc4zw==} + '@napi-rs/wasm-runtime@1.1.3': resolution: {integrity: sha512-xK9sGVbJWYb08+mTJt3/YV24WxvxpXcXtP6B172paPZ+Ts69Re9dAr7lKwJoeIx8OoeuimEiRZ7umkiUVClmmQ==} peerDependencies: @@ -2483,6 +2489,8 @@ snapshots: '@jridgewell/resolve-uri': 3.1.2 '@jridgewell/sourcemap-codec': 1.5.5 + '@lobehub/icons-static-svg@1.91.0': {} + '@napi-rs/wasm-runtime@1.1.3(@emnapi/core@1.9.2)(@emnapi/runtime@1.9.2)': dependencies: '@emnapi/core': 1.9.2 diff --git a/web/src/App.vue b/web/src/App.vue index db0e009..44640cb 100644 --- a/web/src/App.vue +++ b/web/src/App.vue @@ -52,6 +52,10 @@ provide('openRemoteConnect', openRemoteConnect) // workspace's session, instead of WorkspacePicker landing on a blank welcome // while the projects modal restored the session. provide('onWorkspaceSwitched', () => onProjectSwitched()) +// Sidebar "+" on a workspace row: switch to that workspace (full workspace-scoped +// reload, like onProjectSwitched) but land on a fresh welcome screen so the next +// message starts a brand-new task there — instead of restoring its last session. +provide('onNewTaskInProject', (path: string) => startNewTaskInProject(path)) // When the wizard is launched from Settings it stacks ON TOP of the Settings // overlay. headlessui treats a click inside the wizard as an "outside" click for @@ -103,8 +107,28 @@ function scrollToBottom(smooth = true) { }) } +// Sending a message is an explicit action: snap to the latest even if the user +// had scrolled up into history. Marking atBottom also re-arms the timeline watch +// so the streaming reply keeps following. +function onComposerSent() { + isAtBottom.value = true + nextTick(() => scrollToBottom()) +} + // WebSocket connection const { connected } = useWebSocket({ + activeTaskId: () => store.currentSessionId, + onTaskStatus: (taskId, running) => { + // Live sidebar running indicator for ANY task (incl. backgrounded ones). + projectStore.setTaskRunning(taskId, running) + // Keep the composer's Stop/Send button in sync with the *viewed* task, so a + // run that starts or stops while you're looking at it flips the button live + // (task_status is the one event that reaches us for the active task without + // a fresh agent_start/agent_done round-trip). + if (taskId === store.currentSessionId) store.isRunning = running + // Re-sync persisted status + recency/order from the server. + projectStore.fetchAllTasks() + }, onAgentStart: () => { store.isRunning = true }, onAgentText: (data) => store.appendAgentText(data.text), onToolCall: (data) => store.addToolCall(data.name, data.args, data.tool_call_id, data.display_info), @@ -120,7 +144,7 @@ const { connected } = useWebSocket({ store.addApprovalRequest(data) notify(t('notifications.approvalNeeded'), t('notifications.approvalBody')) }, - onAskUserRequest: (data) => store.attachAskUserRequest(data.id, data.questions), + onAskUserRequest: (data) => store.attachAskUserRequest(data.id, data.questions, data.task_id), onSessionReset: () => store.clearChat(), onModelChanged: (data) => { store.providerName = data.provider @@ -336,6 +360,22 @@ async function onProjectSwitched() { await store.restoreCurrentSession() } +// Switch to a workspace (if not already active) and open a fresh welcome screen +// there, so the next message starts a new task in it. Mirrors onProjectSwitched's +// workspace-scoped reload but deliberately skips restoreCurrentSession — the +// whole point is a blank composer, not the project's last conversation. +async function startNewTaskInProject(path: string): Promise { + const active = projectStore.activeProject?.path || store.pwd + if (path !== active) { + const ok = await projectStore.openProject(path) + if (!ok) return false + await store.fetchHealth() + loadWorkspaceState() + } + await store.newSession() + return true +} + function onSetupComplete() { needsSetup.value = false connectionError.value = false @@ -446,7 +486,7 @@ function startResize(e: MouseEvent) {
- +
@@ -457,7 +497,7 @@ function startResize(e: MouseEvent) {
@@ -531,8 +571,13 @@ function startResize(e: MouseEvent) {
- - + +
+ + +
@@ -658,6 +703,17 @@ function startResize(e: MouseEvent) { cursor: not-allowed; } +/* Feather (羽化) the bottom edge of the timeline so messages softly dissolve + into the surface as they meet the composer, instead of being hard-cut. A mask + is background-agnostic (content goes transparent, revealing the panel behind) + so it adapts to light/dark with no color to maintain. The fade height tracks + the timeline's bottom padding (py-6 ≈ 24px) so the last message stays crisp at + rest and only fades while scrolling past the edge. */ +.messages-feather { + -webkit-mask-image: linear-gradient(to bottom, #000 calc(100% - 28px), transparent 100%); + mask-image: linear-gradient(to bottom, #000 calc(100% - 28px), transparent 100%); +} + /* The conversation + composer live in one inset surface panel so the chat canvas reads as distinct from the sidebar shell, wrapped with breathing room above (below the top bar) and below (above the window edge) — 包裹感. */ diff --git a/web/src/components/BranchPicker.vue b/web/src/components/BranchPicker.vue index 815410a..3e5253a 100644 --- a/web/src/components/BranchPicker.vue +++ b/web/src/components/BranchPicker.vue @@ -10,9 +10,13 @@ withDefaults(defineProps<{ placement?: 'top' | 'bottom' }>(), { placement: 'top' }) -const { current, branches, switching, error, checkout } = useBranch() +const { current, branches, switching, error, pending, checkout, resolvePending, cancelPending } = + useBranch() const { t } = useI18n() +// How many at-risk files to list before collapsing into a "+N more" line. +const MAX_FILES = 8 + const query = ref('') const creating = ref(false) const newName = ref('') @@ -35,7 +39,18 @@ async function pick(branch: string, close: () => void) { reset() close() } - // On failure keep the panel open so the error message stays visible. + // On failure keep the panel open: either an error message or — when git refused + // a dirty-tree switch — the confirmation dialog stays visible. +} + +// Retry a blocked switch with the chosen strategy (stash keeps changes, force +// discards them). Closes the panel on success. +async function applyStrategy(strategy: 'stash' | 'force', close: () => void) { + const ok = await resolvePending(strategy) + if (ok) { + reset() + close() + } } function startCreate() { @@ -59,6 +74,7 @@ function reset() { creating.value = false newName.value = '' error.value = '' + cancelPending() } @@ -67,7 +83,12 @@ function reset() {
- +
{{ error }}
+ +
+ + + +
+
{{ t('branches.confirmHint') }}
-
{{ error }}
+ @@ -311,6 +364,92 @@ function reset() { word-break: break-word; } +.bp-confirm { + padding: 6px 6px 4px; +} +.bp-confirm-title { + font-size: 12.5px; + font-weight: 600; + color: var(--color-foreground); +} +.bp-confirm-intro { + margin-top: 4px; + font-size: 11.5px; + line-height: 1.45; + color: var(--color-muted-foreground); +} +.bp-confirm-files { + margin: 8px 0 2px; + padding: 6px 8px; + max-height: 132px; + overflow-y: auto; + list-style: none; + border: 1px solid var(--color-border); + border-radius: var(--radius-md); + background: var(--color-background); +} +.bp-confirm-files li { + font-family: var(--font-mono); + font-size: 11px; + line-height: 1.6; + color: var(--color-foreground); + white-space: nowrap; + overflow: hidden; + text-overflow: ellipsis; +} +.bp-confirm-files .bp-confirm-more { + font-family: inherit; + color: var(--color-muted-foreground); +} +.bp-confirm-actions { + display: flex; + flex-wrap: wrap; + gap: 6px; + margin-top: 10px; +} +.bp-confirm-btn { + flex: 1 1 auto; + height: 30px; + padding: 0 10px; + border: 1px solid var(--color-border); + border-radius: var(--radius-md); + background: var(--color-surface); + color: var(--color-foreground); + font-size: 12px; + font-weight: 500; + cursor: pointer; + transition: background 0.12s, opacity 0.15s; +} +.bp-confirm-btn:hover:not(:disabled) { + background: var(--color-muted); +} +.bp-confirm-btn:disabled { + opacity: 0.5; + cursor: not-allowed; +} +.bp-confirm-btn.stash { + border-color: transparent; + background: var(--color-accent-neutral); + color: var(--color-surface); +} +.bp-confirm-btn.stash:hover:not(:disabled) { + background: var(--color-accent-neutral); + opacity: 0.9; +} +.bp-confirm-btn.discard { + border-color: var(--color-error-fg); + color: var(--color-error-fg); +} +.bp-confirm-btn.discard:hover:not(:disabled) { + background: var(--color-error-bg); +} +.bp-confirm-hint { + margin-top: 8px; + font-size: 10.5px; + line-height: 1.45; + color: var(--color-muted-foreground); +} + .bp-actions { margin-top: 4px; padding-top: 4px; diff --git a/web/src/components/ChatInput.vue b/web/src/components/ChatInput.vue index 69b8f4f..4bf6165 100644 --- a/web/src/components/ChatInput.vue +++ b/web/src/components/ChatInput.vue @@ -7,6 +7,7 @@ import type { SlashCommandInfo, ChatImage } from '@/types/api' import WorkspacePicker from '@/components/WorkspacePicker.vue' import BranchPicker from '@/components/BranchPicker.vue' import ContextCapacityPopup from '@/components/ContextCapacityPopup.vue' +import ProviderIcon from '@/components/ProviderIcon.vue' import { HandRaisedIcon, ShieldExclamationIcon, ClipboardDocumentListIcon, BoltIcon, PlusIcon, PaperClipIcon, XMarkIcon, ChevronDownIcon, StopIcon, PaperAirplaneIcon, MagnifyingGlassIcon, SquaresPlusIcon, PhotoIcon, WrenchScrewdriverIcon, CheckIcon, StarIcon, SparklesIcon } from '@heroicons/vue/24/outline' import { StarIcon as StarIconSolid, CheckCircleIcon } from '@heroicons/vue/24/solid' @@ -17,6 +18,11 @@ withDefaults(defineProps<{ pickerPlacement?: 'top' | 'bottom' }>(), { pickerPlacement: 'top', }) +// Fired when the user dispatches a message (sent now or queued while a turn is +// in flight). The parent uses it to snap the timeline to the bottom so you see +// your message land even if you'd scrolled up into history. +const emit = defineEmits<{ sent: [] }>() + const store = useChatStore() const { t } = useI18n() const input = ref('') @@ -96,46 +102,6 @@ function getModelDisplayName(providerId: string, modelId: string): string { return modelId } -// Provider identity tile — a tinted squircle with the provider's initial. The -// single "identity primitive" reused across the selector, the Manage dialog, -// and (conceptually) the approval card. Color is keyed off the provider id so -// it's stable without a server-provided brand asset. -const PROVIDER_COLORS: Record = { - anthropic: '#D97757', - openai: '#10A37F', - google: '#4285F4', - deepseek: '#4D6BFE', - moonshot: '#1A1A1A', - zhipu: '#3B5BFE', -} -// Distinct fallbacks for providers without an explicit brand color (e.g. the -// "ZHIPU AI Coding Plan" id, third-party gateways). All are saturated enough -// for white initials and deliberately exclude washed/gray tones so the tile -// never disappears into the surface. Keyed by a stable hash of the id so the -// same provider always gets the same color. -const PROVIDER_FALLBACK_COLORS = [ - '#7C5CFC', // violet - '#E0567A', // rose - '#0EA5A4', // teal - '#E0922F', // amber - '#2E7D5B', // green - '#3D7DE0', // azure - '#C2410C', // burnt orange - '#6366F1', // indigo -] -function providerColor(id: string): string { - const explicit = PROVIDER_COLORS[id] - if (explicit) return explicit - let h = 0 - for (let i = 0; i < id.length; i++) h = (h * 31 + id.charCodeAt(i)) >>> 0 - return PROVIDER_FALLBACK_COLORS[h % PROVIDER_FALLBACK_COLORS.length]! -} -function providerInitial(name: string): string { - const n = (name || '?').trim() - // Latin initial; for CJK names fall back to the first code point. - return /[A-Za-z]/.test(n) ? n[0]!.toLowerCase() : n[0] ?? '?' -} - // Compact context-limit label for the subline: 200000 → "200K", 1000000 → "1M". function formatContext(limit?: number): string | null { if (!limit || limit <= 0) return null @@ -156,12 +122,6 @@ function modelSubline(providerId: string, m: { id: string; context_limit?: numbe return parts.join(' · ') } -// Resolve a provider's display name from its id (the dropdown rows carry the -// name, but recent/favorite refs only have the id). -function providerDisplayName(id: string): string { - return store.providers.find((p) => p.id === id)?.name ?? id -} - // Look up the ModelInfo for a provider+model pair (used by recent/favorite refs // which only carry ids). function modelInfoFor(provider: string, model: string) { @@ -325,6 +285,7 @@ async function send() { } else { store.sendMessage(text || '(see attached images)', images) } + emit('sent') } function selectModel(provider: string, model: string) { @@ -674,11 +635,11 @@ watch(() => store.imageSupport, (supported) => { :aria-expanded="showModelPicker" @click.stop="showModelPicker = !showModelPicker; showModePicker = false; showAddMenu = false" > - {{ providerInitial(providerDisplayName(store.providerName)) }} + :provider="store.providerName" + :size="16" + /> {{ store.modelName ? getModelDisplayName(store.providerName, store.modelName) : 'model' }} @@ -701,10 +662,7 @@ watch(() => store.imageSupport, (supported) => {
- {{ providerInitial(providerDisplayName(store.providerName)) }} + {{ getModelDisplayName(store.providerName, store.modelName) }} {{ modelSubline(store.providerName, currentModelInfo) }} @@ -726,7 +684,7 @@ watch(() => store.imageSupport, (supported) => { class="mm-row" @click="selectModel(r.provider, r.model)" > - {{ providerInitial(providerDisplayName(r.provider)) }} + {{ getModelDisplayName(r.provider, r.model) }} {{ modelSubline(r.provider, modelInfoFor(r.provider, r.model)) }} @@ -754,7 +712,7 @@ watch(() => store.imageSupport, (supported) => { @keydown.enter.prevent="selectModel(p.id, m.id)" @keydown.space.prevent="selectModel(p.id, m.id)" > - {{ providerInitial(p.name) }} + {{ m.name || m.id }} {{ modelSubline(p.id, m) }} @@ -855,7 +813,7 @@ watch(() => store.imageSupport, (supported) => {