From bf4952d4a1ce7431004a75c8243a2af3ad0ac1d7 Mon Sep 17 00:00:00 2001 From: Alejandro Amaral Date: Fri, 1 May 2026 23:31:12 -0300 Subject: [PATCH 1/2] feat: Enhance profile and comparison report formatting in sparkle-cli (#28) --- cmd/sparkle-cli/main.go | 108 ++++++- examples/config/config.example.yaml | 1 + internal/config/config.go | 1 + internal/config/types.go | 1 + internal/ollama/client.go | 17 +- internal/ollama/stream.go | 26 +- internal/ollama/types.go | 11 +- internal/profiler/tokenizer.go | 30 ++ internal/profiler/tracker.go | 472 ++++++++++++++++++++++++++++ internal/profiler/types.go | 55 ++++ internal/profiler/writer.go | 61 ++++ internal/search/search.go | 33 ++ internal/search/service_options.go | 12 +- internal/tui/model.go | 100 +++++- 14 files changed, 895 insertions(+), 33 deletions(-) create mode 100644 internal/profiler/tokenizer.go create mode 100644 internal/profiler/tracker.go create mode 100644 internal/profiler/types.go create mode 100644 internal/profiler/writer.go diff --git a/cmd/sparkle-cli/main.go b/cmd/sparkle-cli/main.go index ee2d0df..b9c4117 100644 --- a/cmd/sparkle-cli/main.go +++ b/cmd/sparkle-cli/main.go @@ -4,20 +4,30 @@ import ( "flag" "fmt" "os" + "sort" "strings" + "text/tabwriter" "github.com/logico/sparkle-cli/internal/config" + "github.com/logico/sparkle-cli/internal/profiler" "github.com/logico/sparkle-cli/internal/tui" ) func main() { + if len(os.Args) > 1 && strings.EqualFold(strings.TrimSpace(os.Args[1]), "stats") { + runStats(os.Args[2:]) + return + } + var configPath string var initialContext string var resultFile string + var profileEnabled bool flag.StringVar(&configPath, "config", "", "override config file path") flag.StringVar(&initialContext, "context", "", "seed the input with shell buffer content") flag.StringVar(&resultFile, "result-file", "", "write accepted output to this file instead of stdout") + flag.BoolVar(&profileEnabled, "profile", false, "enable runtime profiling and metrics persistence") flag.Parse() cfg, loadedConfigPath, err := config.Load(configPath) @@ -25,12 +35,25 @@ func main() { fmt.Fprintln(os.Stderr, err) os.Exit(4) } + if profileEnabled { + cfg.Profiler = true + } + + tracker, err := profiler.New(loadedConfigPath, cfg.Profiler) + if err != nil { + fmt.Fprintln(os.Stderr, err) + os.Exit(3) + } + defer func() { _ = tracker.Close() }() - output, exitCode, err := tui.Run(cfg, loadedConfigPath, initialContext) + output, exitCode, err := tui.Run(cfg, loadedConfigPath, initialContext, tracker) if err != nil { fmt.Fprintln(os.Stderr, err) os.Exit(3) } + if tracker.Enabled() { + printCurrentRunSummary(os.Stderr, tracker, "search") + } if exitCode == 0 && output != "" { if err := emitOutput(output, resultFile); err != nil { @@ -42,6 +65,89 @@ func main() { os.Exit(exitCode) } +func runStats(args []string) { + statsFlags := flag.NewFlagSet("stats", flag.ContinueOnError) + statsFlags.SetOutput(os.Stderr) + var configPath string + var command string + var last int + statsFlags.StringVar(&configPath, "config", "", "override config file path") + statsFlags.StringVar(&command, "command", "search", "command to inspect") + statsFlags.IntVar(&last, "last", 10, "number of historical runs to compare") + if err := statsFlags.Parse(args); err != nil { + os.Exit(2) + } + + _, loadedConfigPath, err := config.Load(configPath) + if err != nil { + fmt.Fprintln(os.Stderr, err) + os.Exit(4) + } + + tracker, err := profiler.New(loadedConfigPath, true) + if err != nil { + fmt.Fprintln(os.Stderr, err) + os.Exit(3) + } + defer func() { _ = tracker.Close() }() + + report, err := tracker.Comparison(strings.TrimSpace(command), last) + if err != nil { + fmt.Fprintln(os.Stderr, err) + os.Exit(3) + } + printComparisonReport(os.Stdout, report, last) +} + +func printCurrentRunSummary(output *os.File, tracker profiler.Tracker, command string) { + rows := tracker.CurrentRun(command) + if len(rows) == 0 { + return + } + fmt.Fprintln(output, "\n=== Profiling Summary ===") + fmt.Fprintf(output, "run_id: %s\n", tracker.RunID()) + fmt.Fprintf(output, "command: %s\n", command) + totalDuration := int64(0) + for _, row := range rows { + totalDuration += row.DurationMS + fmt.Fprintf(output, "- %s: %dms", row.StepName, row.DurationMS) + if row.TokensOut > 0 { + fmt.Fprintf(output, " | in=%d out=%d tps=%.2f", row.TokensIn, row.TokensOut, row.TPS) + } + fmt.Fprintln(output) + } + fmt.Fprintf(output, "total: %dms\n", totalDuration) +} + +func printComparisonReport(output *os.File, report profiler.ComparisonReport, last int) { + if len(report.Steps) == 0 { + fmt.Fprintln(output, "No profiling data found.") + return + } + sort.Slice(report.Steps, func(i, j int) bool { + return report.Steps[i].StepName < report.Steps[j].StepName + }) + fmt.Fprintf(output, "Profiling stats for command=%s\n", report.Command) + fmt.Fprintf(output, "Current run: %s\n", report.CurrentRun) + fmt.Fprintf(output, "Compared against: last %d runs\n", last) + fmt.Fprintln(output, "") + + tw := tabwriter.NewWriter(output, 0, 0, 2, ' ', 0) + fmt.Fprintln(tw, "step\tcurrent_ms\tavg_ms(last)\tcurrent_out\tavg_out(last)\tcurrent_tps\tavg_tps(last)") + for _, step := range report.Steps { + fmt.Fprintf(tw, "%s\t%d\t%.1f\t%d\t%.1f\t%.2f\t%.2f\n", + step.StepName, + step.CurrentDuration, + step.DurationMS, + step.CurrentTokensOut, + step.TokensOut, + step.CurrentTPS, + step.TPS, + ) + } + _ = tw.Flush() +} + func emitOutput(output, resultFile string) error { if strings.TrimSpace(resultFile) == "" { fmt.Print(output) diff --git a/examples/config/config.example.yaml b/examples/config/config.example.yaml index d80d069..ed00d18 100644 --- a/examples/config/config.example.yaml +++ b/examples/config/config.example.yaml @@ -19,6 +19,7 @@ qdrant_score_threshold: 0.90 qdrant_ttl_hours: 48 qdrant_pool_size: 3 logs: false +profiler: false editor: neovim slash_commands_file: ./slash-commands.yaml slash_commands_dir: ./slash-commands diff --git a/internal/config/config.go b/internal/config/config.go index d31bb2b..5d48cd6 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -191,6 +191,7 @@ func setDefaults(v *viper.Viper) { v.SetDefault("qdrant_pool_size", defaultQdrantPoolSize) v.SetDefault("theme", defaultTheme) v.SetDefault("logs", false) + v.SetDefault("profiler", false) v.SetDefault("editor", defaultEditor) commands := make(map[string]map[string]string, len(defaultCommands)) diff --git a/internal/config/types.go b/internal/config/types.go index b24ae64..1f8cead 100644 --- a/internal/config/types.go +++ b/internal/config/types.go @@ -33,6 +33,7 @@ type Config struct { QdrantPoolSize int `mapstructure:"qdrant_pool_size"` Theme string `mapstructure:"theme"` Logs bool `mapstructure:"logs"` + Profiler bool `mapstructure:"profiler"` Editor string `mapstructure:"editor"` SlashCommandsFile string `mapstructure:"slash_commands_file"` SlashCommandsDir string `mapstructure:"slash_commands_dir"` diff --git a/internal/ollama/client.go b/internal/ollama/client.go index a6ceed3..426b7ee 100644 --- a/internal/ollama/client.go +++ b/internal/ollama/client.go @@ -72,6 +72,11 @@ func (c *Client) StreamChatWithModel(ctx context.Context, model string, messages } func (c *Client) StreamChatWithModelWithThinking(ctx context.Context, model string, messages []ChatMessage, thinking bool, onChunk func(string) error) error { + _, err := c.StreamChatWithModelWithThinkingStats(ctx, model, messages, thinking, onChunk) + return err +} + +func (c *Client) StreamChatWithModelWithThinkingStats(ctx context.Context, model string, messages []ChatMessage, thinking bool, onChunk func(string) error) (StreamStats, error) { if strings.TrimSpace(model) == "" { model = c.model } @@ -84,34 +89,34 @@ func (c *Client) StreamChatWithModelWithThinking(ctx context.Context, model stri Stream: true, }) if err != nil { - return err + return StreamStats{}, err } req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.baseURL+"/api/chat", bytes.NewReader(body)) if err != nil { - return fmt.Errorf("create ollama request: %w", err) + return StreamStats{}, fmt.Errorf("create ollama request: %w", err) } req.Header.Set("Content-Type", "application/json") resp, err := c.http.Do(req) if err != nil { - return fmt.Errorf("request ollama: %w", err) + return StreamStats{}, fmt.Errorf("request ollama: %w", err) } defer func() { _ = resp.Body.Close() }() if resp.StatusCode != http.StatusOK { payload, readErr := io.ReadAll(io.LimitReader(resp.Body, 4096)) if readErr != nil { - return fmt.Errorf("ollama status %d", resp.StatusCode) + return StreamStats{}, fmt.Errorf("ollama status %d", resp.StatusCode) } message := strings.TrimSpace(string(payload)) if message == "" { message = http.StatusText(resp.StatusCode) } - return fmt.Errorf("ollama status %d: %s", resp.StatusCode, message) + return StreamStats{}, fmt.Errorf("ollama status %d: %s", resp.StatusCode, message) } - return ParseStream(resp.Body, onChunk) + return ParseStreamWithStats(resp.Body, onChunk) } func marshalRequest(request chatRequest) ([]byte, error) { diff --git a/internal/ollama/stream.go b/internal/ollama/stream.go index c791bfa..0958267 100644 --- a/internal/ollama/stream.go +++ b/internal/ollama/stream.go @@ -9,52 +9,60 @@ import ( ) func ParseStream(reader io.Reader, onChunk func(string) error) error { + _, err := ParseStreamWithStats(reader, onChunk) + return err +} + +func ParseStreamWithStats(reader io.Reader, onChunk func(string) error) (StreamStats, error) { decoder := json.NewDecoder(reader) thinkingOpen := false + stats := StreamStats{} for { var chunk chatChunk if err := decoder.Decode(&chunk); err != nil { if errors.Is(err, io.EOF) { - return nil + return stats, nil } - return fmt.Errorf("decode ollama stream: %w", err) + return stats, fmt.Errorf("decode ollama stream: %w", err) } if strings.TrimSpace(chunk.Error) != "" { - return errors.New(chunk.Error) + return stats, errors.New(chunk.Error) } if chunk.Message.Thinking != "" { if !thinkingOpen { if err := onChunk("<|channel|>thought\n"); err != nil { - return err + return stats, err } thinkingOpen = true } if err := onChunk(chunk.Message.Thinking); err != nil { - return err + return stats, err } } if chunk.Message.Content != "" { if thinkingOpen { if err := onChunk(""); err != nil { - return err + return stats, err } thinkingOpen = false } if err := onChunk(chunk.Message.Content); err != nil { - return err + return stats, err } } if chunk.Done { + stats.PromptEvalCount = chunk.PromptEvalCount + stats.EvalCount = chunk.EvalCount if thinkingOpen { if err := onChunk(""); err != nil { - return err + return stats, err } } - return nil + return stats, nil } } } diff --git a/internal/ollama/types.go b/internal/ollama/types.go index 9bc0e15..3cc6a0a 100644 --- a/internal/ollama/types.go +++ b/internal/ollama/types.go @@ -25,8 +25,15 @@ type chatChunk struct { Content string `json:"content"` Thinking string `json:"thinking"` } `json:"message"` - Done bool `json:"done"` - Error string `json:"error"` + Done bool `json:"done"` + Error string `json:"error"` + PromptEvalCount int `json:"prompt_eval_count"` + EvalCount int `json:"eval_count"` +} + +type StreamStats struct { + PromptEvalCount int + EvalCount int } type embedRequest struct { diff --git a/internal/profiler/tokenizer.go b/internal/profiler/tokenizer.go new file mode 100644 index 0000000..8cbb46a --- /dev/null +++ b/internal/profiler/tokenizer.go @@ -0,0 +1,30 @@ +package profiler + +import ( + "strings" + "sync" + + "github.com/pkoukk/tiktoken-go" +) + +const tokenizerEncoding = tiktoken.MODEL_CL100K_BASE + +var ( + tokenizerOnce sync.Once + tokenizerInst *tiktoken.Tiktoken + tokenizerErr error +) + +func EstimateTokens(input string) int { + trimmed := strings.TrimSpace(input) + if trimmed == "" { + return 0 + } + tokenizerOnce.Do(func() { + tokenizerInst, tokenizerErr = tiktoken.GetEncoding(tokenizerEncoding) + }) + if tokenizerErr != nil || tokenizerInst == nil { + return 0 + } + return len(tokenizerInst.Encode(trimmed, nil, nil)) +} diff --git a/internal/profiler/tracker.go b/internal/profiler/tracker.go new file mode 100644 index 0000000..c9c8842 --- /dev/null +++ b/internal/profiler/tracker.go @@ -0,0 +1,472 @@ +package profiler + +import ( + "database/sql" + "encoding/json" + "fmt" + "os" + "path/filepath" + "runtime/debug" + "sort" + "strings" + "sync" + "time" + + _ "modernc.org/sqlite" +) + +const sqliteFileName = "performance_logs.sqlite" + +type sqliteTracker struct { + db *sql.DB + runID string + commitHash string + + mu sync.Mutex + records []LogRecord +} + +type sqliteSpan struct { + tracker *sqliteTracker + command string + stepName string + metadata map[string]any + modelName string + tokensIn int + tokensOut int + tps float64 + start time.Time + ended bool +} + +type noopTracker struct{} + +type noopSpan struct{} + +func Disabled() Tracker { return noopTracker{} } + +func New(configPath string, enabled bool) (Tracker, error) { + if !enabled { + return noopTracker{}, nil + } + + dbPath, err := resolveDBPath(configPath) + if err != nil { + return nil, err + } + if err := os.MkdirAll(filepath.Dir(dbPath), 0o755); err != nil { + return nil, fmt.Errorf("create profiler dir: %w", err) + } + + db, err := sql.Open("sqlite", dbPath) + if err != nil { + return nil, fmt.Errorf("open sqlite profiler db: %w", err) + } + if err := initializeSchema(db); err != nil { + _ = db.Close() + return nil, err + } + + tracker := &sqliteTracker{ + db: db, + runID: newRunID(), + commitHash: resolveCommitHash(), + records: make([]LogRecord, 0, 8), + } + return tracker, nil +} + +func resolveDBPath(configPath string) (string, error) { + trimmed := strings.TrimSpace(configPath) + if trimmed != "" { + return filepath.Join(filepath.Dir(trimmed), sqliteFileName), nil + } + configRoot, err := os.UserConfigDir() + if err != nil { + return "", fmt.Errorf("resolve user config dir: %w", err) + } + return filepath.Join(configRoot, "sparkle-cli", sqliteFileName), nil +} + +func initializeSchema(db *sql.DB) error { + if db == nil { + return fmt.Errorf("sqlite db is nil") + } + const schema = ` +CREATE TABLE IF NOT EXISTS performance_logs ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + run_id TEXT NOT NULL, + command TEXT NOT NULL, + step_name TEXT NOT NULL, + duration_ms INTEGER NOT NULL, + tokens_in INTEGER NOT NULL DEFAULT 0, + tokens_out INTEGER NOT NULL DEFAULT 0, + tps REAL NOT NULL DEFAULT 0, + model_name TEXT NOT NULL DEFAULT '', + commit_hash TEXT NOT NULL DEFAULT '', + metadata_json TEXT NOT NULL DEFAULT '{}', + timestamp DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP +); +CREATE INDEX IF NOT EXISTS idx_performance_logs_command_time ON performance_logs(command, timestamp DESC); +CREATE INDEX IF NOT EXISTS idx_performance_logs_run ON performance_logs(run_id); +` + if _, err := db.Exec(schema); err != nil { + return fmt.Errorf("initialize sqlite schema: %w", err) + } + return nil +} + +func (t *sqliteTracker) Enabled() bool { return t != nil && t.db != nil } + +func (t *sqliteTracker) RunID() string { + if t == nil { + return "" + } + return t.runID +} + +func (t noopTracker) Enabled() bool { return false } +func (t noopTracker) RunID() string { return "" } + +func (t *sqliteTracker) StartSpan(command string, stepName string, metadata map[string]any) Span { + if t == nil || t.db == nil { + return noopSpan{} + } + return &sqliteSpan{ + tracker: t, + command: strings.TrimSpace(command), + stepName: strings.TrimSpace(stepName), + metadata: cloneMetadata(metadata), + start: time.Now().UTC(), + } +} + +func (t noopTracker) StartSpan(command string, stepName string, metadata map[string]any) Span { + return noopSpan{} +} + +func (s *sqliteSpan) SetModel(name string) { + s.modelName = strings.TrimSpace(name) +} + +func (s *sqliteSpan) SetTokens(tokensIn int, tokensOut int) { + s.tokensIn = max(tokensIn, 0) + s.tokensOut = max(tokensOut, 0) +} + +func (s *sqliteSpan) SetTPS(tps float64) { + if tps < 0 { + tps = 0 + } + s.tps = tps +} + +func (s *sqliteSpan) AddMetadata(metadata map[string]any) { + for key, value := range metadata { + s.metadata[key] = value + } +} + +func (s *sqliteSpan) End() { + if s == nil || s.tracker == nil || s.ended { + return + } + s.ended = true + + duration := time.Since(s.start) + if duration < 0 { + duration = 0 + } + now := time.Now().UTC() + + if s.command == "" { + s.command = "unknown" + } + if s.stepName == "" { + s.stepName = "unknown" + } + if s.metadata == nil { + s.metadata = map[string]any{} + } + s.metadata["run_id"] = s.tracker.runID + + payload, _ := json.Marshal(s.metadata) + record := LogRecord{ + RunID: s.tracker.runID, + Command: s.command, + StepName: s.stepName, + DurationMS: duration.Milliseconds(), + TokensIn: s.tokensIn, + TokensOut: s.tokensOut, + TPS: s.tps, + ModelName: s.modelName, + CommitHash: s.tracker.commitHash, + Metadata: cloneMetadata(s.metadata), + Timestamp: now, + } + + _, err := s.tracker.db.Exec( + `INSERT INTO performance_logs(run_id, command, step_name, duration_ms, tokens_in, tokens_out, tps, model_name, commit_hash, metadata_json, timestamp) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`, + record.RunID, + record.Command, + record.StepName, + record.DurationMS, + record.TokensIn, + record.TokensOut, + record.TPS, + record.ModelName, + record.CommitHash, + string(payload), + record.Timestamp, + ) + if err != nil { + return + } + + s.tracker.mu.Lock() + s.tracker.records = append(s.tracker.records, record) + s.tracker.mu.Unlock() +} + +func (s noopSpan) SetModel(name string) {} +func (s noopSpan) SetTokens(tokensIn int, tokensOut int) {} +func (s noopSpan) SetTPS(tps float64) {} +func (s noopSpan) AddMetadata(metadata map[string]any) {} +func (s noopSpan) End() {} + +func (t *sqliteTracker) CurrentRun(command string) []LogRecord { + if t == nil { + return nil + } + trimmed := strings.TrimSpace(command) + t.mu.Lock() + defer t.mu.Unlock() + rows := make([]LogRecord, 0, len(t.records)) + for _, record := range t.records { + if trimmed != "" && !strings.EqualFold(record.Command, trimmed) { + continue + } + rows = append(rows, record) + } + sort.Slice(rows, func(i, j int) bool { return rows[i].Timestamp.Before(rows[j].Timestamp) }) + return rows +} + +func (t noopTracker) CurrentRun(command string) []LogRecord { return nil } + +func (t *sqliteTracker) Comparison(command string, previousRuns int) (ComparisonReport, error) { + if t == nil || t.db == nil { + return ComparisonReport{}, nil + } + trimmed := strings.TrimSpace(command) + if trimmed == "" { + trimmed = "search" + } + if previousRuns <= 0 { + previousRuns = 10 + } + + runIDs, err := t.listRecentRuns(trimmed, previousRuns+1) + if err != nil { + return ComparisonReport{}, err + } + if len(runIDs) == 0 { + return ComparisonReport{Command: trimmed}, nil + } + currentRun := runIDs[0] + current, err := t.loadRunSteps(currentRun, trimmed) + if err != nil { + return ComparisonReport{}, err + } + currentByStep := map[string]LogRecord{} + for _, row := range current { + currentByStep[row.StepName] = row + } + + previous := runIDs[1:] + avgByStep := map[string]StepStats{} + if len(previous) > 0 { + for _, runID := range previous { + rows, loadErr := t.loadRunSteps(runID, trimmed) + if loadErr != nil { + return ComparisonReport{}, loadErr + } + for _, row := range rows { + stat := avgByStep[row.StepName] + stat.StepName = row.StepName + stat.DurationMS += float64(row.DurationMS) + stat.TokensIn += float64(row.TokensIn) + stat.TokensOut += float64(row.TokensOut) + stat.TPS += row.TPS + stat.Samples++ + avgByStep[row.StepName] = stat + } + } + } + + stepNames := make([]string, 0, len(currentByStep)+len(avgByStep)) + seen := map[string]struct{}{} + for name := range currentByStep { + seen[name] = struct{}{} + stepNames = append(stepNames, name) + } + for name := range avgByStep { + if _, ok := seen[name]; ok { + continue + } + stepNames = append(stepNames, name) + } + sort.Strings(stepNames) + + report := ComparisonReport{ + Command: trimmed, + CurrentRun: currentRun, + Steps: make([]StepStats, 0, len(stepNames)), + } + if len(current) > 0 { + report.CurrentTime = current[0].Timestamp + } + for _, name := range stepNames { + avg := avgByStep[name] + if avg.Samples > 0 { + div := float64(avg.Samples) + avg.DurationMS = avg.DurationMS / div + avg.TokensIn = avg.TokensIn / div + avg.TokensOut = avg.TokensOut / div + avg.TPS = avg.TPS / div + } + if currentRecord, ok := currentByStep[name]; ok { + avg.CurrentDuration = currentRecord.DurationMS + avg.CurrentTokensIn = currentRecord.TokensIn + avg.CurrentTokensOut = currentRecord.TokensOut + avg.CurrentTPS = currentRecord.TPS + } + report.Steps = append(report.Steps, avg) + } + return report, nil +} + +func (t noopTracker) Comparison(command string, previousRuns int) (ComparisonReport, error) { + return ComparisonReport{}, nil +} + +func (t *sqliteTracker) listRecentRuns(command string, limit int) ([]string, error) { + rows, err := t.db.Query( + `SELECT run_id, MAX(timestamp) AS ts + FROM performance_logs + WHERE command = ? + GROUP BY run_id + ORDER BY ts DESC + LIMIT ?`, + command, + limit, + ) + if err != nil { + return nil, fmt.Errorf("list recent profiler runs: %w", err) + } + defer rows.Close() + + runIDs := make([]string, 0, limit) + for rows.Next() { + var runID string + var ignored any + if scanErr := rows.Scan(&runID, &ignored); scanErr != nil { + return nil, fmt.Errorf("scan profiler run ids: %w", scanErr) + } + runIDs = append(runIDs, runID) + } + if err := rows.Err(); err != nil { + return nil, fmt.Errorf("iterate profiler run ids: %w", err) + } + return runIDs, nil +} + +func (t *sqliteTracker) loadRunSteps(runID string, command string) ([]LogRecord, error) { + rows, err := t.db.Query( + `SELECT id, run_id, command, step_name, duration_ms, tokens_in, tokens_out, tps, model_name, commit_hash, metadata_json, timestamp + FROM performance_logs + WHERE run_id = ? AND command = ? + ORDER BY timestamp ASC, id ASC`, + runID, + command, + ) + if err != nil { + return nil, fmt.Errorf("load profiler run steps: %w", err) + } + defer rows.Close() + + result := make([]LogRecord, 0, 8) + for rows.Next() { + var record LogRecord + var metadataJSON string + if scanErr := rows.Scan( + &record.ID, + &record.RunID, + &record.Command, + &record.StepName, + &record.DurationMS, + &record.TokensIn, + &record.TokensOut, + &record.TPS, + &record.ModelName, + &record.CommitHash, + &metadataJSON, + &record.Timestamp, + ); scanErr != nil { + return nil, fmt.Errorf("scan profiler row: %w", scanErr) + } + record.Metadata = map[string]any{} + if strings.TrimSpace(metadataJSON) != "" { + _ = json.Unmarshal([]byte(metadataJSON), &record.Metadata) + } + result = append(result, record) + } + if err := rows.Err(); err != nil { + return nil, fmt.Errorf("iterate profiler rows: %w", err) + } + return result, nil +} + +func (t *sqliteTracker) Close() error { + if t == nil || t.db == nil { + return nil + } + return t.db.Close() +} + +func (t noopTracker) Close() error { return nil } + +func cloneMetadata(input map[string]any) map[string]any { + if len(input) == 0 { + return map[string]any{} + } + out := make(map[string]any, len(input)) + for key, value := range input { + out[key] = value + } + return out +} + +func newRunID() string { + now := time.Now().UTC().Format("20060102T150405.000000000Z") + return fmt.Sprintf("%s-%d", now, os.Getpid()) +} + +func resolveCommitHash() string { + if value := strings.TrimSpace(os.Getenv("SPARKLE_COMMIT_HASH")); value != "" { + return value + } + info, ok := debug.ReadBuildInfo() + if !ok || info == nil { + return "unknown" + } + for _, setting := range info.Settings { + if setting.Key == "vcs.revision" { + if value := strings.TrimSpace(setting.Value); value != "" { + return value + } + } + } + return "unknown" +} diff --git a/internal/profiler/types.go b/internal/profiler/types.go new file mode 100644 index 0000000..190ac2b --- /dev/null +++ b/internal/profiler/types.go @@ -0,0 +1,55 @@ +package profiler + +import "time" + +type LogRecord struct { + ID int64 + RunID string + Command string + StepName string + DurationMS int64 + TokensIn int + TokensOut int + TPS float64 + ModelName string + CommitHash string + Metadata map[string]any + Timestamp time.Time +} + +type Span interface { + SetModel(name string) + SetTokens(tokensIn int, tokensOut int) + SetTPS(tps float64) + AddMetadata(metadata map[string]any) + End() +} + +type Tracker interface { + Enabled() bool + RunID() string + StartSpan(command string, stepName string, metadata map[string]any) Span + CurrentRun(command string) []LogRecord + Comparison(command string, previousRuns int) (ComparisonReport, error) + Close() error +} + +type StepStats struct { + StepName string + DurationMS float64 + TokensIn float64 + TokensOut float64 + TPS float64 + Samples int + CurrentDuration int64 + CurrentTokensIn int + CurrentTokensOut int + CurrentTPS float64 +} + +type ComparisonReport struct { + Command string + CurrentRun string + CurrentTime time.Time + Steps []StepStats +} diff --git a/internal/profiler/writer.go b/internal/profiler/writer.go new file mode 100644 index 0000000..8260846 --- /dev/null +++ b/internal/profiler/writer.go @@ -0,0 +1,61 @@ +package profiler + +import ( + "io" + "sync" + "time" +) + +type GenerationWriter struct { + inner io.Writer + + mu sync.Mutex + started bool + first time.Time + last time.Time + total int64 +} + +func NewGenerationWriter(inner io.Writer) *GenerationWriter { + if inner == nil { + inner = io.Discard + } + return &GenerationWriter{inner: inner} +} + +func (w *GenerationWriter) Write(p []byte) (int, error) { + n, err := w.inner.Write(p) + if n > 0 { + now := time.Now().UTC() + w.mu.Lock() + if !w.started { + w.started = true + w.first = now + } + w.last = now + w.total += int64(n) + w.mu.Unlock() + } + return n, err +} + +func (w *GenerationWriter) Window() (time.Time, time.Time) { + w.mu.Lock() + defer w.mu.Unlock() + return w.first, w.last +} + +func (w *GenerationWriter) Duration() time.Duration { + w.mu.Lock() + defer w.mu.Unlock() + if !w.started || w.last.Before(w.first) { + return 0 + } + return w.last.Sub(w.first) +} + +func (w *GenerationWriter) Bytes() int64 { + w.mu.Lock() + defer w.mu.Unlock() + return w.total +} diff --git a/internal/search/search.go b/internal/search/search.go index 31426dc..fd981b4 100644 --- a/internal/search/search.go +++ b/internal/search/search.go @@ -20,6 +20,7 @@ import ( "unicode/utf8" "github.com/cixtor/readability" + "github.com/logico/sparkle-cli/internal/profiler" "github.com/pkoukk/tiktoken-go" htmlnode "golang.org/x/net/html" "golang.org/x/sync/errgroup" @@ -213,6 +214,7 @@ type Service struct { parse articleParser embedder EmbeddingProvider embeddingModel string + tracker profiler.Tracker cache semanticCacheStore domainReputation DomainReputationProvider background sync.WaitGroup @@ -246,6 +248,7 @@ func NewService(searchURL string, options ...Option) *Service { service := &Service{ searchURL: strings.TrimSpace(searchURL), http: &http.Client{}, + tracker: profiler.Disabled(), parse: func(input io.Reader, pageURL string) (readability.Article, error) { return readability.New().Parse(input, pageURL) }, @@ -292,7 +295,20 @@ func (s *Service) Prepare(ctx context.Context, query string, searchQuery string, if strings.TrimSpace(s.searchURL) == "" { return PreparedPrompt{}, fmt.Errorf("search url is not configured") } + + retrievalSpan := profiler.Disabled().StartSpan("", "", nil) + profilingEnabled := s.tracker != nil && s.tracker.Enabled() + if profilingEnabled { + retrievalSpan = s.tracker.StartSpan("search", "embedding_retrieval", map[string]any{ + "query_length": len(trimmedQuery), + "cached": false, + }) + retrievalSpan.SetModel(strings.TrimSpace(s.embeddingModel)) + } + defer retrievalSpan.End() + if cached, ok := s.lookupCache(ctx, trimmedQuery, cacheSearchQuery, safeOnActivity, safeOnProgress); ok { + retrievalSpan.AddMetadata(map[string]any{"cached": true}) return cached, nil } searchPlan, err := resolvePreparedSearchPlan(ctx, trimmedQuery, trimmedSearchQuery, resolveSearchQuery) @@ -332,6 +348,19 @@ func (s *Service) Prepare(ctx context.Context, query string, searchQuery string, return PreparedPrompt{}, fmt.Errorf("could not extract readable content from search results") } rawDocuments := append([]Document(nil), documents...) + if profilingEnabled { + retrievalSpan.AddMetadata(map[string]any{ + "documents_retrieved": len(rawDocuments), + "search_queries": strings.Join(searchPlan.Queries(), " | "), + }) + } + + rerankSpan := profiler.Disabled().StartSpan("", "", nil) + if profilingEnabled { + rerankSpan = s.tracker.StartSpan("search", "rerank", map[string]any{ + "documents_in": len(documents), + }) + } notifyProgress(safeOnProgress, ProgressUpdate{ Key: "chunk-selection", Kind: ProgressKindStep, @@ -345,6 +374,10 @@ func (s *Service) Prepare(ctx context.Context, query string, searchQuery string, Text: fmt.Sprintf("Fragmentos relevantes seleccionados: %d", selectedChunks), State: ProgressDone, }) + if profilingEnabled { + rerankSpan.AddMetadata(map[string]any{"chunks_selected": selectedChunks}) + } + rerankSpan.End() notifyProgress(safeOnProgress, ProgressUpdate{ Key: "downloads", Kind: ProgressKindStep, diff --git a/internal/search/service_options.go b/internal/search/service_options.go index 3c93dd4..3d3e78d 100644 --- a/internal/search/service_options.go +++ b/internal/search/service_options.go @@ -1,6 +1,10 @@ package search -import "context" +import ( + "context" + + "github.com/logico/sparkle-cli/internal/profiler" +) type EmbeddingProvider interface { EmbedWithModel(ctx context.Context, model string, input []string) ([][]float32, error) @@ -24,6 +28,12 @@ type QdrantConfig struct { type Option func(*Service) +func WithProfiler(tracker profiler.Tracker) Option { + return func(s *Service) { + s.tracker = tracker + } +} + func WithEmbedder(provider EmbeddingProvider, model string) Option { return func(s *Service) { s.embedder = provider diff --git a/internal/tui/model.go b/internal/tui/model.go index 6652178..5696623 100644 --- a/internal/tui/model.go +++ b/internal/tui/model.go @@ -28,6 +28,7 @@ import ( "github.com/logico/sparkle-cli/internal/feedback" "github.com/logico/sparkle-cli/internal/i18n" "github.com/logico/sparkle-cli/internal/ollama" + "github.com/logico/sparkle-cli/internal/profiler" "github.com/logico/sparkle-cli/internal/search" "github.com/logico/sparkle-cli/internal/slash" ) @@ -314,6 +315,7 @@ type model struct { styles styles searchBuilder searchPromptBuilder feedbackStore *feedback.Store + profiler profiler.Tracker requesting bool userCanceled bool llmTimerActive bool @@ -396,8 +398,8 @@ type styles struct { modeIndicator lipgloss.Style } -func Run(cfg config.Config, configPath string, initialContext string) (string, int, error) { - tuiModel := newModel(cfg, initialContext) +func Run(cfg config.Config, configPath string, initialContext string, tracker profiler.Tracker) (string, int, error) { + tuiModel := newModelWithTracker(cfg, initialContext, tracker) tuiModel.configPath = configPath tuiModel.rebuildSearchRuntime() if cfg.Logs { @@ -447,6 +449,13 @@ func Run(cfg config.Config, configPath string, initialContext string) (string, i } func newModel(cfg config.Config, initialContext string) model { + return newModelWithTracker(cfg, initialContext, profiler.Disabled()) +} + +func newModelWithTracker(cfg config.Config, initialContext string, tracker profiler.Tracker) model { + if tracker == nil { + tracker = profiler.Disabled() + } if normalizedEditor, err := config.NormalizeEditor(cfg.Editor); err == nil { cfg.Editor = normalizedEditor } @@ -509,7 +518,7 @@ func newModel(cfg config.Config, initialContext string) model { modeIndicator: lipgloss.NewStyle().Foreground(lipgloss.Color(colors.accent)).Background(lipgloss.Color(colors.bgRaised)), } client := ollama.NewClient(cfg.OllamaURL, cfg.Model) - searchBuilder := newSearchBuilder(cfg, client) + searchBuilder := newSearchBuilder(cfg, client, tracker) model := model{ cfg: cfg, @@ -530,6 +539,7 @@ func newModel(cfg config.Config, initialContext string) model { colors: colors, styles: sty, searchBuilder: searchBuilder, + profiler: tracker, feedbackRating: feedback.VoteNeutral, status: localizer.Get("status.ready"), mode: modeNormal, @@ -546,13 +556,14 @@ func newModel(cfg config.Config, initialContext string) model { return model } -func newSearchBuilder(cfg config.Config, client *ollama.Client) searchPromptBuilder { - return newSearchBuilderWithReputation(cfg, client, nil) +func newSearchBuilder(cfg config.Config, client *ollama.Client, tracker profiler.Tracker) searchPromptBuilder { + return newSearchBuilderWithReputation(cfg, client, tracker, nil) } -func newSearchBuilderWithReputation(cfg config.Config, client *ollama.Client, domainReputation search.DomainReputationProvider) searchPromptBuilder { +func newSearchBuilderWithReputation(cfg config.Config, client *ollama.Client, tracker profiler.Tracker, domainReputation search.DomainReputationProvider) searchPromptBuilder { options := []search.Option{ search.WithEmbedder(client, cfg.SearchEmbeddingModel), + search.WithProfiler(tracker), search.WithQdrantCache(search.QdrantConfig{ Enabled: cfg.QdrantEnabled, Host: cfg.QdrantHost, @@ -598,7 +609,7 @@ func (m *model) rebuildSearchRuntime() { if err == nil { m.feedbackStore = feedbackStore } - m.searchBuilder = newSearchBuilderWithReputation(m.cfg, m.client, m.feedbackStore) + m.searchBuilder = newSearchBuilderWithReputation(m.cfg, m.client, m.profiler, m.feedbackStore) } func (m *model) applyRuntimeConfig(cfg config.Config, configPath string) { @@ -2458,6 +2469,7 @@ func (m *model) runRequestStream(ctx context.Context, cancel context.CancelFunc, stopSearchTimeout() requestMessages := m.buildRequestMessages(promptForModel, expansion.SystemPrompt, requestModel) + requestTokenUsage := countTokenUsage(requestMessages) m.logSessionEntry("model_used", requestModel) m.logSessionEntry("prompt_sent_to_model", promptForModel) systemPrompt := "" @@ -2468,7 +2480,28 @@ func (m *model) runRequestStream(ctx context.Context, cancel context.CancelFunc, } } m.logSessionEntry("system_prompt_sent_to_model", systemPrompt) - llmTimedOut, err = m.streamLLMWithAdaptiveTimeout(ctx, cancel, requestModel, requestMessages, m.mode == modeReasoning, func(chunk string) error { + profilingActive := m.profiler != nil && m.profiler.Enabled() + var generationSpan profiler.Span + var generationWriter *profiler.GenerationWriter + var generated strings.Builder + if profilingActive { + commandName := "chat" + if expansion.Kind == slash.KindSearch { + commandName = "search" + } + generationSpan = m.profiler.StartSpan(commandName, "generation", map[string]any{ + "slash_kind": expansion.Kind, + "mode": string(m.mode), + }) + generationSpan.SetModel(requestModel) + generationSpan.SetTokens(requestTokenUsage.total(), 0) + generationWriter = profiler.NewGenerationWriter(nil) + } + llmTimedOut, streamStats, err := m.streamLLMWithAdaptiveTimeoutWithStats(ctx, cancel, requestModel, requestMessages, m.mode == modeReasoning, func(chunk string) error { + if profilingActive { + _, _ = generationWriter.Write([]byte(chunk)) + generated.WriteString(chunk) + } select { case <-ctx.Done(): return ctx.Err() @@ -2476,6 +2509,20 @@ func (m *model) runRequestStream(ctx context.Context, cancel context.CancelFunc, return nil } }) + if profilingActive && expansion.Kind == slash.KindSearch { + tokensOut := streamStats.EvalCount + if tokensOut <= 0 { + tokensOut = profiler.EstimateTokens(generated.String()) + } + generationSpan.SetTokens(requestTokenUsage.total(), tokensOut) + duration := generationWriter.Duration() + if duration > 0 && tokensOut > 0 { + generationSpan.SetTPS(float64(tokensOut) / duration.Seconds()) + } + } + if profilingActive { + generationSpan.End() + } if err != nil { streamCh <- streamEvent{err: stageRequestErr(requestStageLLM, normalizeRequestErr(err, llmTimedOut))} return @@ -2943,10 +2990,25 @@ func (m *model) rewriteSearchPlan(ctx context.Context, originalQuery string) (se } } messages := []ollama.ChatMessage{{Role: "system", Content: search.BuildSearchRewritePromptWithExamples(originalQuery, examples)}} - response, timedOut, err := m.collectLLMResponse(ctx, m.searchQueryModel(), messages) + + rewriteModel := m.searchQueryModel() + rewriteSpan := m.profiler.StartSpan("search", "rewrite_query", map[string]any{ + "query_length": len(originalQuery), + }) + rewriteSpan.SetModel(rewriteModel) + + response, timedOut, streamStats, err := m.collectLLMResponseWithStats(ctx, rewriteModel, messages) if err != nil { + rewriteSpan.End() return search.SearchPlan{}, timedOut, err } + + if streamStats.EvalCount > 0 { + rewriteSpan.SetTokens(streamStats.PromptEvalCount, streamStats.EvalCount) + } else { + rewriteSpan.SetTokens(profiler.EstimateTokens(search.BuildSearchRewritePromptWithExamples(originalQuery, examples)), profiler.EstimateTokens(response)) + } + rewriteSpan.End() queries := search.ExtractSearchQueries(response) if len(queries) == 0 { queries = []string{strings.TrimSpace(originalQuery)} @@ -2999,18 +3061,28 @@ func (m *model) reduceSearchPrompt(ctx context.Context, requestModel string, pre } func (m *model) collectLLMResponse(ctx context.Context, requestModel string, messages []ollama.ChatMessage) (string, func() bool, error) { + response, timedOut, _, err := m.collectLLMResponseWithStats(ctx, requestModel, messages) + return response, timedOut, err +} + +func (m *model) collectLLMResponseWithStats(ctx context.Context, requestModel string, messages []ollama.ChatMessage) (string, func() bool, ollama.StreamStats, error) { var builder strings.Builder - timedOut, err := m.streamLLMWithAdaptiveTimeout(ctx, nil, requestModel, messages, false, func(chunk string) error { + timedOut, stats, err := m.streamLLMWithAdaptiveTimeoutWithStats(ctx, nil, requestModel, messages, false, func(chunk string) error { builder.WriteString(chunk) return nil }) if err != nil { - return "", timedOut, err + return "", timedOut, stats, err } - return strings.TrimSpace(builder.String()), timedOut, nil + return strings.TrimSpace(builder.String()), timedOut, stats, nil } func (m *model) streamLLMWithAdaptiveTimeout(ctx context.Context, cancel context.CancelFunc, requestModel string, messages []ollama.ChatMessage, thinking bool, onChunk func(string) error) (func() bool, error) { + timedOut, _, err := m.streamLLMWithAdaptiveTimeoutWithStats(ctx, cancel, requestModel, messages, thinking, onChunk) + return timedOut, err +} + +func (m *model) streamLLMWithAdaptiveTimeoutWithStats(ctx context.Context, cancel context.CancelFunc, requestModel string, messages []ollama.ChatMessage, thinking bool, onChunk func(string) error) (func() bool, ollama.StreamStats, error) { if cancel == nil { var innerCancel context.CancelFunc ctx, innerCancel = context.WithCancel(ctx) @@ -3026,7 +3098,7 @@ func (m *model) streamLLMWithAdaptiveTimeout(ctx context.Context, cancel context }() firstChunk := true - err := m.client.StreamChatWithModelWithThinking(ctx, requestModel, messages, thinking, func(chunk string) error { + stats, err := m.client.StreamChatWithModelWithThinkingStats(ctx, requestModel, messages, thinking, func(chunk string) error { if firstChunk { if stopCurrent != nil { stopCurrent() @@ -3038,7 +3110,7 @@ func (m *model) streamLLMWithAdaptiveTimeout(ctx context.Context, cancel context return onChunk(chunk) }) - return currentTimedOut, err + return currentTimedOut, stats, err } func normalizeRequestErr(err error, timedOut func() bool) error { From 7941338f57fc8d4b38cfcd561d94b60476a47db2 Mon Sep 17 00:00:00 2001 From: Alejandro Amaral Date: Wed, 13 May 2026 16:12:25 -0300 Subject: [PATCH 2/2] feat: added non interactive mode --- README.md | 13 +++ USER_DOCS/06-interaction-modes.md | 2 + cmd/sparkle-cli/main.go | 60 +++++++++++ internal/tui/direct.go | 160 ++++++++++++++++++++++++++++++ internal/tui/model.go | 22 +--- internal/tui/slash_input_test.go | 81 +++++++++++++++ 6 files changed, 318 insertions(+), 20 deletions(-) create mode 100644 internal/tui/direct.go diff --git a/README.md b/README.md index 4d69225..d8e69ce 100644 --- a/README.md +++ b/README.md @@ -26,6 +26,17 @@ https://github.com/user-attachments/assets/53324380-0ca7-4dfd-90d5-4f72a49cadc1 go run ./cmd/sparkle-cli --context "git log --oneline" ``` +### Non-interactive mode + +Use `direct` for scripts, pipes, and integrations that need a single final response on stdout. + +```bash +sparkle-cli direct -m normal "Por que el cielo es azul" | less +sparkle-cli direct -m reasoning "/search de que color es el cielo" +``` + +`direct` supports `-m normal` and `-m reasoning`. It also accepts `thinking` as an alias for `reasoning`. In this mode, sparkle-cli resolves slash commands, waits for the full LLM response, strips visible reasoning, and prints only the final answer to stdout. + ## End-User Documentation Detailed feature documentation for end users is available in [USER_DOCS/](USER_DOCS/): @@ -124,6 +135,8 @@ Key bindings inside the TUI: `Chat` mode sends the previous user and assistant messages as conversation context on each request. `Reasoning` mode keeps the existing thinking prompt behavior without adding prior turns. +For automation, use `sparkle-cli direct -m normal|reasoning "..."`. Direct mode is non-interactive and does not expose `Chat` mode. + Supported editors for `editor` are `neovim` (default), `vim`, `vscode`/`visual studio code`, and `emacs`. ## Zsh Bridge diff --git a/USER_DOCS/06-interaction-modes.md b/USER_DOCS/06-interaction-modes.md index f55fb36..029d577 100644 --- a/USER_DOCS/06-interaction-modes.md +++ b/USER_DOCS/06-interaction-modes.md @@ -10,6 +10,8 @@ sparkle-cli supports 3 interaction modes: Switch mode with `Ctrl+T`. +For non-interactive automation, `sparkle-cli direct -m normal "..."` and `sparkle-cli direct -m reasoning "..."` reuse the same two single-turn modes. Direct mode does not support `Chat` and always prints only the final answer to stdout. + ## Normal mode - Standard request/response behavior. diff --git a/cmd/sparkle-cli/main.go b/cmd/sparkle-cli/main.go index b9c4117..da689f8 100644 --- a/cmd/sparkle-cli/main.go +++ b/cmd/sparkle-cli/main.go @@ -18,6 +18,13 @@ func main() { runStats(os.Args[2:]) return } + if len(os.Args) > 1 { + switch strings.ToLower(strings.TrimSpace(os.Args[1])) { + case "direct", "run": + runDirect(os.Args[2:]) + return + } + } var configPath string var initialContext string @@ -65,6 +72,59 @@ func main() { os.Exit(exitCode) } +func runDirect(args []string) { + directFlags := flag.NewFlagSet("direct", flag.ContinueOnError) + directFlags.SetOutput(os.Stderr) + + var configPath string + var mode string + var profileEnabled bool + + directFlags.StringVar(&configPath, "config", "", "override config file path") + directFlags.StringVar(&mode, "m", "normal", "direct mode: normal or reasoning") + directFlags.BoolVar(&profileEnabled, "profile", false, "enable runtime profiling and metrics persistence") + if err := directFlags.Parse(args); err != nil { + os.Exit(2) + } + + prompt := strings.TrimSpace(strings.Join(directFlags.Args(), " ")) + if prompt == "" { + fmt.Fprintln(os.Stderr, "missing prompt for direct mode") + os.Exit(2) + } + + cfg, loadedConfigPath, err := config.Load(configPath) + if err != nil { + fmt.Fprintln(os.Stderr, err) + os.Exit(4) + } + if profileEnabled { + cfg.Profiler = true + } + + tracker, err := profiler.New(loadedConfigPath, cfg.Profiler) + if err != nil { + fmt.Fprintln(os.Stderr, err) + os.Exit(3) + } + defer func() { _ = tracker.Close() }() + + output, err := tui.RunDirect(cfg, loadedConfigPath, prompt, mode, tracker) + if err != nil { + fmt.Fprintln(os.Stderr, err) + os.Exit(3) + } + + if output != "" { + fmt.Print(output) + } + if tracker.Enabled() { + os.Exit(0) + } + + os.Exit(0) +} + func runStats(args []string) { statsFlags := flag.NewFlagSet("stats", flag.ContinueOnError) statsFlags.SetOutput(os.Stderr) diff --git a/internal/tui/direct.go b/internal/tui/direct.go new file mode 100644 index 0000000..13d0934 --- /dev/null +++ b/internal/tui/direct.go @@ -0,0 +1,160 @@ +package tui + +import ( + "context" + "fmt" + "io" + "strings" + + "github.com/logico/sparkle-cli/internal/config" + "github.com/logico/sparkle-cli/internal/profiler" + "github.com/logico/sparkle-cli/internal/search" + "github.com/logico/sparkle-cli/internal/slash" +) + +func RunDirect(cfg config.Config, configPath string, prompt string, mode string, tracker profiler.Tracker) (_ string, err error) { + directModel := newModelWithTracker(cfg, "", tracker) + directModel.configPath = configPath + directModel.rebuildSearchRuntime() + if cfg.Logs { + logger, loggerErr := newSessionLogger(configPath) + if loggerErr != nil { + return "", loggerErr + } + directModel.sessionLogger = logger + } + defer func() { + if closeErr := closeModelRuntime(&directModel); err == nil && closeErr != nil { + err = closeErr + } + }() + + return runDirectWithModel(&directModel, prompt, mode) +} + +func runDirectWithModel(directModel *model, prompt string, mode string) (string, error) { + if directModel == nil { + return "", fmt.Errorf("direct mode is not initialized") + } + + interactionMode, err := normalizeDirectMode(mode) + if err != nil { + return "", err + } + directModel.mode = interactionMode + + trimmedPrompt := strings.TrimSpace(prompt) + if trimmedPrompt == "" { + return "", fmt.Errorf("missing prompt for direct mode") + } + if err := validateDirectPrompt(trimmedPrompt); err != nil { + return "", err + } + + expansion, err := slash.Resolve(trimmedPrompt, directModel.cfg) + if err != nil { + return "", err + } + if expansion.Kind == slash.KindConfig { + return "", fmt.Errorf("slash command %s is not supported in direct mode", slashCommandConfig) + } + + directModel.logSessionEntry("user_input", trimmedPrompt) + + resolvedPrompt := expansion.Prompt + requestModel := strings.TrimSpace(directModel.cfg.Model) + if strings.TrimSpace(expansion.Model) != "" { + requestModel = strings.TrimSpace(expansion.Model) + } + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + streamCh := make(chan streamEvent) + go directModel.runRequestStream(ctx, cancel, resolvedPrompt, requestModel, expansion, streamCh) + + var rawResponse strings.Builder + var preparedDocs []search.Document + var cacheQuery string + var cacheDocs []search.Document + + for event := range streamCh { + if event.err != nil { + return "", event.err + } + if event.preparedPrompt != "" { + preparedDocs = append([]search.Document(nil), event.preparedDocs...) + cacheQuery = strings.TrimSpace(event.cacheQuery) + cacheDocs = append([]search.Document(nil), event.cacheDocs...) + } + if event.chunk != "" { + rawResponse.WriteString(event.chunk) + } + } + + finalOutput := strings.TrimSpace(rawResponse.String()) + directModel.logSessionEntry("llm_full_response", finalOutput) + if finalOutput == "" { + return "", nil + } + + if len(preparedDocs) > 0 { + finalOutput = directModel.appendSyntheticSourcesIfMissing(finalOutput, preparedDocs) + } + if _, answer, hasReasoning := splitThinkingOutput(finalOutput); hasReasoning { + finalOutput = answer + } + finalOutput = strings.TrimSpace(finalOutput) + + if finalOutput != "" && cacheQuery != "" && len(cacheDocs) > 0 { + if done := directModel.searchBuilder.PersistSemanticCache(cacheQuery, cacheDocs, nil); done != nil { + <-done + } + } + + return finalOutput, nil +} + +func normalizeDirectMode(value string) (interactionMode, error) { + switch strings.ToLower(strings.TrimSpace(value)) { + case "", string(modeNormal): + return modeNormal, nil + case string(modeReasoning), "thinking": + return modeReasoning, nil + default: + return "", fmt.Errorf("unsupported direct mode %q: use normal or reasoning", value) + } +} + +func validateDirectPrompt(prompt string) error { + parts := strings.Fields(strings.TrimSpace(prompt)) + if len(parts) == 0 { + return fmt.Errorf("missing prompt for direct mode") + } + if strings.EqualFold(parts[0], slashCommandHelp) { + return fmt.Errorf("slash command %s is not supported in direct mode", slashCommandHelp) + } + return nil +} + +func closeModelRuntime(m *model) error { + if m == nil { + return nil + } + if closer, ok := m.searchBuilder.(io.Closer); ok { + if err := closer.Close(); err != nil { + return err + } + } + if m.feedbackStore != nil { + if err := m.feedbackStore.Close(); err != nil { + return err + } + } + if m.sessionLogger != nil { + if err := m.sessionLogger.close(); err != nil { + return err + } + } + return nil +} diff --git a/internal/tui/model.go b/internal/tui/model.go index 5696623..ae294cf 100644 --- a/internal/tui/model.go +++ b/internal/tui/model.go @@ -423,26 +423,8 @@ func Run(cfg config.Config, configPath string, initialContext string, tracker pr if !ok { return "", 3, fmt.Errorf("unexpected final model type %T", finalModel) } - if closer, ok := result.searchBuilder.(io.Closer); ok { - if closeErr := closer.Close(); closeErr != nil { - if result.sessionLogger != nil { - _ = result.sessionLogger.close() - } - return "", 3, closeErr - } - } - if result.feedbackStore != nil { - if closeErr := result.feedbackStore.Close(); closeErr != nil { - if result.sessionLogger != nil { - _ = result.sessionLogger.close() - } - return "", 3, closeErr - } - } - if result.sessionLogger != nil { - if closeErr := result.sessionLogger.close(); closeErr != nil { - return "", 3, closeErr - } + if closeErr := closeModelRuntime(&result); closeErr != nil { + return "", 3, closeErr } return result.acceptedOutput, result.exitCode, nil diff --git a/internal/tui/slash_input_test.go b/internal/tui/slash_input_test.go index f914e23..e1ac6ef 100644 --- a/internal/tui/slash_input_test.go +++ b/internal/tui/slash_input_test.go @@ -4,6 +4,7 @@ import ( "context" "encoding/json" "errors" + "io" "net/http" "net/http/httptest" "os" @@ -1949,6 +1950,86 @@ func TestHandleKeyMsgTogglesThinkingMode(t *testing.T) { } } +func TestRunDirectReturnsOnlyFinalAnswerInReasoningMode(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set(jsonContentTypeHeader, jsonContentTypeValue) + _, _ = io.WriteString(w, "{\"message\":{\"content\":\"analizandorespuesta final\"},\"done\":false}\n") + _, _ = io.WriteString(w, doneChunkPayload) + })) + defer server.Close() + + got, err := RunDirect(config.Config{OllamaURL: server.URL, Model: "gemma4"}, "", "por que el cielo es azul", "thinking", nil) + if err != nil { + t.Fatalf("RunDirect() error = %v", err) + } + if got != assistantResponse { + t.Fatalf("RunDirect() = %q, want %q", got, assistantResponse) + } + if strings.Contains(got, "analizando") { + t.Fatalf("RunDirect() = %q, want no reasoning content", got) + } +} + +func TestRunDirectResolvesSearchSlashAndPersistsCache(t *testing.T) { + m := newModel(config.Config{ + Model: "gemma4", + Commands: map[string]config.SlashCommand{"search": {Template: "{{.Input}}", Kind: slash.KindSearch}}, + }, "") + persistCalled := false + m.searchBuilder = &stubSearchBuilder{ + prepared: search.PreparedPrompt{ + Prompt: finalPromptText, + Query: ollamaInstallQuestion, + Documents: []search.Document{{URL: testSourceURLA}}, + CacheQuery: ollamaInstallQuestion, + CacheDocs: []search.Document{{URL: testSourceURLA, Content: "contenido"}}, + }, + persist: func(query string, documents []search.Document, onProgress func(search.ProgressUpdate)) <-chan struct{} { + persistCalled = query == ollamaInstallQuestion && len(documents) == 1 + done := make(chan struct{}) + close(done) + return done + }, + } + m.client = ollama.NewClient("http://127.0.0.1:1", "gemma4") + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set(jsonContentTypeHeader, jsonContentTypeValue) + _, _ = io.WriteString(w, "{\"message\":{\"content\":\"respuesta sin cita\"},\"done\":false}\n") + _, _ = io.WriteString(w, doneChunkPayload) + })) + defer server.Close() + m.client = ollama.NewClient(server.URL, "gemma4") + + got, err := runDirectWithModel(&m, "/search "+ollamaInstallQuestion, "normal") + if err != nil { + t.Fatalf("runDirectWithModel() error = %v", err) + } + if !strings.Contains(got, "respuesta sin cita") { + t.Fatalf("runDirectWithModel() = %q, want answer body", got) + } + if !strings.Contains(got, sourcesFooterHeading) { + t.Fatalf("runDirectWithModel() = %q, want synthetic sources footer", got) + } + if !persistCalled { + t.Fatal("runDirectWithModel() did not persist semantic cache") + } +} + +func TestNormalizeDirectModeSupportsThinkingAlias(t *testing.T) { + got, err := normalizeDirectMode("thinking") + if err != nil { + t.Fatalf("normalizeDirectMode() error = %v", err) + } + if got != modeReasoning { + t.Fatalf("normalizeDirectMode() = %q, want %q", got, modeReasoning) + } + + if _, err := normalizeDirectMode("chat"); err == nil { + t.Fatal("normalizeDirectMode() error = nil, want error for chat") + } +} + func TestHandleKeyMsgTogglesReasoningVisibility(t *testing.T) { m := newModel(config.Config{}, "") m.viewport.Width = 80