diff --git a/.gitignore b/.gitignore index d6aa6d2..0afd986 100644 --- a/.gitignore +++ b/.gitignore @@ -1,5 +1,6 @@ .DS_Store .idea/ +.otel/ .vscode/ *.swp *.swo diff --git a/internal/auth/auth_test.go b/internal/auth/auth_test.go index 2d022d5..60e1dd3 100644 --- a/internal/auth/auth_test.go +++ b/internal/auth/auth_test.go @@ -380,6 +380,10 @@ func (s authTestStore) AppendTurn(context.Context, int64, llm.Message, llm.Messa return nil } +func (s authTestStore) ReplaceTailAndAppendTurn(context.Context, int64, int, llm.Message, llm.Message) error { + return nil +} + func (s authTestStore) Close() error { return nil } diff --git a/internal/chat/service.go b/internal/chat/service.go index 62c182a..30e80e8 100644 --- a/internal/chat/service.go +++ b/internal/chat/service.go @@ -12,8 +12,9 @@ import ( ) var ( - ErrEmptyPrompt = errors.New("prompt must not be empty") - ErrTurnInProgress = errors.New("turn already in progress") + ErrEmptyPrompt = errors.New("prompt must not be empty") + ErrInvalidReplaceFrom = errors.New("replace_from must point to a user message or the end of the conversation") + ErrTurnInProgress = errors.New("turn already in progress") ) type Service struct { @@ -32,6 +33,7 @@ func NewPersistentService(client llm.Client, store Store) Service { type Store interface { Messages(context.Context, int64) ([]llm.Message, error) AppendTurn(context.Context, int64, llm.Message, llm.Message) error + ReplaceTailAndAppendTurn(context.Context, int64, int, llm.Message, llm.Message) error } func (s Service) NewSession() *Session { @@ -68,6 +70,8 @@ type SendOptions struct { ReasoningEffort string RenderingInstructions string TelemetryComponent string + ReplaceFrom *int + Now func() time.Time } func (s *Session) Send(ctx context.Context, prompt string, opts SendOptions) (*TurnStream, error) { @@ -91,6 +95,10 @@ func (s *Session) Send(ctx context.Context, prompt string, opts SendOptions) (*T Effort: effort, } } + now := opts.Now + if now == nil { + now = time.Now + } s.mu.Lock() if s.inFlight { @@ -98,7 +106,14 @@ func (s *Session) Send(ctx context.Context, prompt string, opts SendOptions) (*T observability.RecordSpanError(span, ErrTurnInProgress) return nil, ErrTurnInProgress } - request.Messages = append(llm.CloneMessages(s.messages), userMessage.Clone()) + keepMessages, err := replaceFromIndex(s.messages, opts.ReplaceFrom) + if err != nil { + s.mu.Unlock() + observability.RecordSpanError(span, err) + return nil, err + } + replaceTail := opts.ReplaceFrom != nil + request.Messages = append(llm.CloneMessages(s.messages[:keepMessages]), userMessage.Clone()) s.inFlight = true s.mu.Unlock() @@ -112,11 +127,14 @@ func (s *Session) Send(ctx context.Context, prompt string, opts SendOptions) (*T } return &TurnStream{ - session: s, - stream: stream, - userMessage: userMessage, - ctx: ctx, - startedAt: startedAt, + session: s, + stream: stream, + userMessage: userMessage, + keepMessages: keepMessages, + replaceTail: replaceTail, + ctx: ctx, + startedAt: startedAt, + now: now, }, nil } @@ -126,14 +144,46 @@ func (s *Session) Messages() []llm.Message { return llm.CloneMessages(s.messages) } +func (s *Session) ValidateReplaceFrom(replaceFrom *int) error { + s.mu.Lock() + defer s.mu.Unlock() + _, err := replaceFromIndex(s.messages, replaceFrom) + return err +} + +func (s *Session) CommitStopped(ctx context.Context, prompt string, opts SendOptions) error { + prompt = strings.TrimSpace(prompt) + if prompt == "" { + return ErrEmptyPrompt + } + userMessage := llm.NewTextMessage(llm.RoleUser, prompt) + + s.mu.Lock() + keepMessages, err := replaceFromIndex(s.messages, opts.ReplaceFrom) + if err != nil { + s.mu.Unlock() + return err + } + if err := s.appendOrReplaceTailLocked(ctx, opts.ReplaceFrom != nil, keepMessages, userMessage, llm.Message{Role: llm.RoleAssistant}); err != nil { + s.mu.Unlock() + return err + } + s.mu.Unlock() + return nil +} + type TurnStream struct { - session *Session - stream llm.Stream - userMessage llm.Message - ctx context.Context - startedAt time.Time + session *Session + stream llm.Stream + userMessage llm.Message + keepMessages int + replaceTail bool + ctx context.Context + startedAt time.Time + now func() time.Time assistantParts []llm.Part + completedAt time.Time completed bool finalized bool recorded bool @@ -143,7 +193,9 @@ func (s *TurnStream) Next() (llm.Event, error) { event, err := s.stream.Next() if err != nil { s.recordError(err) - s.abort() + if !errors.Is(err, context.Canceled) { + s.abort() + } return llm.Event{}, err } @@ -156,7 +208,7 @@ func (s *TurnStream) Next() (llm.Event, error) { s.mergeCompletedPart(event.Part) case llm.EventCompleted: s.completed = true - if err := s.finalize(); err != nil { + if err := s.finalize(false); err != nil { s.recordError(err) s.abort() return event, err @@ -181,6 +233,14 @@ func (s *TurnStream) Close() error { return s.stream.Close() } +func (s *TurnStream) CommitPartial() error { + return s.finalize(true) +} + +func (s *TurnStream) CompletedAt() time.Time { + return s.completedAt +} + func (s *TurnStream) appendDelta(partType llm.PartType, delta string) { if delta == "" { return @@ -230,29 +290,40 @@ func (s *TurnStream) mergeCompletedPart(part llm.Part) { s.assistantParts = append(s.assistantParts, part.Clone()) } -func (s *TurnStream) finalize() error { - if !s.completed || s.finalized { +func (s *TurnStream) finalize(allowIncomplete bool) error { + if (!s.completed && !allowIncomplete) || s.finalized { return nil } assistant := llm.Message{ - Role: llm.RoleAssistant, - Parts: cloneParts(s.assistantParts), + Role: llm.RoleAssistant, + Parts: cloneParts(s.assistantParts), + CompletedAt: s.completionTime(), } s.session.mu.Lock() defer s.session.mu.Unlock() - if s.session.store != nil { - if err := s.session.store.AppendTurn(context.Background(), s.session.conversationID, s.userMessage.Clone(), assistant.Clone()); err != nil { - return err - } + if err := s.session.appendOrReplaceTailLocked(context.Background(), s.replaceTail, s.keepMessages, s.userMessage, assistant); err != nil { + return err } s.finalized = true - s.session.messages = append(s.session.messages, s.userMessage.Clone(), assistant) - s.session.inFlight = false return nil } +func (s *TurnStream) completionTime() time.Time { + if !s.completed { + return time.Time{} + } + if s.completedAt.IsZero() { + now := s.now + if now == nil { + now = time.Now + } + s.completedAt = now().UTC() + } + return s.completedAt +} + func (s *TurnStream) abort() { if s.finalized { return @@ -299,6 +370,42 @@ func (s *Session) releaseTurn() { s.inFlight = false } +func replaceFromIndex(messages []llm.Message, replaceFrom *int) (int, error) { + if replaceFrom == nil { + return len(messages), nil + } + keepMessages := *replaceFrom + if keepMessages < 0 || keepMessages > len(messages) { + return 0, ErrInvalidReplaceFrom + } + if keepMessages < len(messages) && messages[keepMessages].Role != llm.RoleUser { + return 0, ErrInvalidReplaceFrom + } + return keepMessages, nil +} + +func (s *Session) appendOrReplaceTailLocked(ctx context.Context, replaceTail bool, keepMessages int, userMessage, assistant llm.Message) error { + if s.store != nil { + var err error + if replaceTail { + err = s.store.ReplaceTailAndAppendTurn(ctx, s.conversationID, keepMessages, userMessage.Clone(), assistant.Clone()) + } else { + err = s.store.AppendTurn(ctx, s.conversationID, userMessage.Clone(), assistant.Clone()) + } + if err != nil { + return err + } + } + nextMessages := llm.CloneMessages(s.messages) + if replaceTail { + nextMessages = llm.CloneMessages(s.messages[:keepMessages]) + } + nextMessages = append(nextMessages, userMessage.Clone(), assistant.Clone()) + s.messages = nextMessages + s.inFlight = false + return nil +} + func cloneParts(parts []llm.Part) []llm.Part { if parts == nil { return nil diff --git a/internal/chat/service_test.go b/internal/chat/service_test.go index 8b27a88..13c7b4b 100644 --- a/internal/chat/service_test.go +++ b/internal/chat/service_test.go @@ -6,12 +6,14 @@ import ( "io" "strings" "testing" + "time" "example.com/llm-chat-web/internal/llm" "example.com/llm-chat-web/internal/llm/dummy" ) func TestSessionSendStreamsAndStoresCompletedTurn(t *testing.T) { + completedAt := time.Date(2026, 6, 14, 12, 34, 56, 0, time.UTC) client := dummy.NewClient(dummy.Turn{ ReasoningChunks: []string{"think", "ing"}, TextChunks: []string{"ans", "wer"}, @@ -24,7 +26,10 @@ func TestSessionSendStreamsAndStoresCompletedTurn(t *testing.T) { }) session := NewService(client).NewSession() - stream, err := session.Send(context.Background(), " hello ", SendOptions{Model: "test-model"}) + stream, err := session.Send(context.Background(), " hello ", SendOptions{ + Model: "test-model", + Now: func() time.Time { return completedAt }, + }) if err != nil { t.Fatalf("Send() error = %v, want nil", err) } @@ -50,6 +55,9 @@ func TestSessionSendStreamsAndStoresCompletedTurn(t *testing.T) { if got := messages[1].Text(); got != "answer" { t.Fatalf("assistant text = %q, want answer", got) } + if !messages[1].CompletedAt.Equal(completedAt) { + t.Fatalf("assistant completed_at = %v, want %v", messages[1].CompletedAt, completedAt) + } reasoning := messages[1].Parts[0] if reasoning.Type != llm.PartReasoning { t.Fatalf("first assistant part type = %q, want reasoning", reasoning.Type) @@ -136,6 +144,121 @@ func TestSessionSendIncludesPriorTurnsAndReasoning(t *testing.T) { } } +func TestSessionSendCanReplaceTailFromUserMessage(t *testing.T) { + client := dummy.NewClient( + dummy.Turn{TextChunks: []string{"first answer"}}, + dummy.Turn{TextChunks: []string{"second answer"}}, + dummy.Turn{TextChunks: []string{"replacement answer"}}, + ) + session := NewService(client).NewSession() + + first, err := session.Send(context.Background(), "first", SendOptions{}) + if err != nil { + t.Fatalf("first Send() error = %v, want nil", err) + } + collectEvents(t, first) + second, err := session.Send(context.Background(), "second", SendOptions{}) + if err != nil { + t.Fatalf("second Send() error = %v, want nil", err) + } + collectEvents(t, second) + + replaceFrom := 2 + replacement, err := session.Send(context.Background(), "edited second", SendOptions{ReplaceFrom: &replaceFrom}) + if err != nil { + t.Fatalf("replacement Send() error = %v, want nil", err) + } + collectEvents(t, replacement) + + requests := client.Requests() + if len(requests) != 3 { + t.Fatalf("request count = %d, want 3", len(requests)) + } + if got := requests[2].Messages; len(got) != 3 || got[0].Text() != "first" || got[1].Text() != "first answer" || got[2].Text() != "edited second" { + t.Fatalf("replacement request messages = %#v, want first turn plus edited prompt", requests[2].Messages) + } + messages := session.Messages() + if len(messages) != 4 { + t.Fatalf("stored message count = %d, want 4", len(messages)) + } + if messages[0].Text() != "first" || messages[1].Text() != "first answer" || messages[2].Text() != "edited second" || messages[3].Text() != "replacement answer" { + t.Fatalf("stored messages = %#v, want tail replaced by edited turn", messages) + } +} + +func TestSessionSendRejectsReplaceFromAssistantMessage(t *testing.T) { + session := NewService(dummy.NewClient(dummy.Turn{TextChunks: []string{"answer"}})).NewSession() + stream, err := session.Send(context.Background(), "first", SendOptions{}) + if err != nil { + t.Fatalf("Send() error = %v, want nil", err) + } + collectEvents(t, stream) + + replaceFrom := 1 + _, err = session.Send(context.Background(), "bad edit", SendOptions{ReplaceFrom: &replaceFrom}) + if !errors.Is(err, ErrInvalidReplaceFrom) { + t.Fatalf("Send() error = %v, want ErrInvalidReplaceFrom", err) + } +} + +func TestSessionValidateReplaceFromAndCommitStopped(t *testing.T) { + session := NewService(dummy.NewClient()).NewSession() + if err := session.ValidateReplaceFrom(nil); err != nil { + t.Fatalf("ValidateReplaceFrom(nil) error = %v, want nil", err) + } + if err := session.CommitStopped(context.Background(), " first ", SendOptions{}); err != nil { + t.Fatalf("CommitStopped append error = %v, want nil", err) + } + messages := session.Messages() + if len(messages) != 2 || messages[0].Text() != "first" || messages[1].Role != llm.RoleAssistant { + t.Fatalf("messages after stopped append = %#v, want user plus empty assistant", messages) + } + + replaceFrom := 0 + if err := session.ValidateReplaceFrom(&replaceFrom); err != nil { + t.Fatalf("ValidateReplaceFrom(user index) error = %v, want nil", err) + } + if err := session.CommitStopped(context.Background(), "edited first", SendOptions{ReplaceFrom: &replaceFrom}); err != nil { + t.Fatalf("CommitStopped replace error = %v, want nil", err) + } + messages = session.Messages() + if len(messages) != 2 || messages[0].Text() != "edited first" { + t.Fatalf("messages after stopped replacement = %#v, want edited stopped turn", messages) + } + + assistantIndex := 1 + if err := session.ValidateReplaceFrom(&assistantIndex); !errors.Is(err, ErrInvalidReplaceFrom) { + t.Fatalf("ValidateReplaceFrom(assistant index) error = %v, want ErrInvalidReplaceFrom", err) + } + if err := session.CommitStopped(context.Background(), " ", SendOptions{}); !errors.Is(err, ErrEmptyPrompt) { + t.Fatalf("CommitStopped(empty) error = %v, want ErrEmptyPrompt", err) + } +} + +func TestSessionCommitStoppedRejectsInvalidReplaceFromAndStoreFailure(t *testing.T) { + session := NewService(dummy.NewClient()).NewSession() + invalidReplaceFrom := 1 + if err := session.CommitStopped(context.Background(), "bad edit", SendOptions{ReplaceFrom: &invalidReplaceFrom}); !errors.Is(err, ErrInvalidReplaceFrom) { + t.Fatalf("CommitStopped(invalid replace_from) error = %v, want ErrInvalidReplaceFrom", err) + } + if messages := session.Messages(); len(messages) != 0 { + t.Fatalf("messages after invalid stopped edit = %#v, want none", messages) + } + + store := &chatStore{appendErr: errors.New("stopped append failed")} + persistent, err := NewPersistentService(dummy.NewClient(), store).NewPersistedSession(context.Background(), 42) + if err != nil { + t.Fatalf("NewPersistedSession error = %v, want nil", err) + } + err = persistent.CommitStopped(context.Background(), "hello", SendOptions{}) + if err == nil || !strings.Contains(err.Error(), "stopped append failed") { + t.Fatalf("CommitStopped store error = %v, want stopped append failed", err) + } + if messages := persistent.Messages(); len(messages) != 0 { + t.Fatalf("messages after failed stopped append = %#v, want none", messages) + } +} + func TestPersistentSessionLoadsHistoryAndAppendsCompletedTurn(t *testing.T) { store := &chatStore{ messages: []llm.Message{ @@ -168,11 +291,45 @@ func TestPersistentSessionLoadsHistoryAndAppendsCompletedTurn(t *testing.T) { if len(store.appended) != 1 { t.Fatalf("append count = %d, want 1", len(store.appended)) } + if len(store.replaced) != 0 { + t.Fatalf("replace count = %d, want 0", len(store.replaced)) + } if store.appended[0].conversationID != 42 || store.appended[0].user.Text() != "fresh prompt" || store.appended[0].assistant.Text() != "fresh answer" { t.Fatalf("appended turn = %#v, want completed fresh turn in conversation 42", store.appended[0]) } } +func TestPersistentSessionReplaceTailPersistsEditedTurn(t *testing.T) { + store := &chatStore{ + messages: []llm.Message{ + llm.NewTextMessage(llm.RoleUser, "stored first"), + llm.NewTextMessage(llm.RoleAssistant, "stored first answer"), + llm.NewTextMessage(llm.RoleUser, "stored second"), + llm.NewTextMessage(llm.RoleAssistant, "stored second answer"), + }, + } + client := dummy.NewClient(dummy.Turn{TextChunks: []string{"edited answer"}}) + session, err := NewPersistentService(client, store).NewPersistedSession(context.Background(), 42) + if err != nil { + t.Fatalf("NewPersistedSession error = %v, want nil", err) + } + + replaceFrom := 2 + stream, err := session.Send(context.Background(), "edited second", SendOptions{ReplaceFrom: &replaceFrom}) + if err != nil { + t.Fatalf("Send() error = %v, want nil", err) + } + collectEvents(t, stream) + + if len(store.replaced) != 1 { + t.Fatalf("replace count = %d, want 1", len(store.replaced)) + } + replaced := store.replaced[0] + if replaced.conversationID != 42 || replaced.keepMessages != 2 || replaced.user.Text() != "edited second" || replaced.assistant.Text() != "edited answer" { + t.Fatalf("replaced turn = %#v, want replacement at message index 2", replaced) + } +} + func TestPersistentSessionFallsBackWithoutStore(t *testing.T) { session, err := NewPersistentService(dummy.NewClient(), nil).NewPersistedSession(context.Background(), 42) if err != nil { @@ -191,7 +348,7 @@ func TestPersistentSessionReturnsHistoryLoadFailure(t *testing.T) { } } -func TestPersistentSessionDoesNotAppendFailedOrAbortedTurn(t *testing.T) { +func TestPersistentSessionDoesNotAppendFailedOrClosedTurn(t *testing.T) { t.Run("failed stream", func(t *testing.T) { store := &chatStore{} session, err := NewPersistentService(failingClient{}, store).NewPersistedSession(context.Background(), 42) @@ -207,12 +364,12 @@ func TestPersistentSessionDoesNotAppendFailedOrAbortedTurn(t *testing.T) { if err == nil { t.Fatalf("Next() error = nil, want failure") } - if len(store.appended) != 0 { - t.Fatalf("append count = %d, want 0", len(store.appended)) + if len(store.appended) != 0 || len(store.replaced) != 0 { + t.Fatalf("stored turn count = appended %d replaced %d, want 0", len(store.appended), len(store.replaced)) } }) - t.Run("aborted stream", func(t *testing.T) { + t.Run("closed stream without partial commit", func(t *testing.T) { store := &chatStore{} session, err := NewPersistentService(eventClient{events: []llm.Event{{Type: llm.EventTextDelta, Delta: "partial"}}}, store).NewPersistedSession(context.Background(), 42) if err != nil { @@ -225,12 +382,49 @@ func TestPersistentSessionDoesNotAppendFailedOrAbortedTurn(t *testing.T) { if err := stream.Close(); err != nil { t.Fatalf("Close() error = %v, want nil", err) } - if len(store.appended) != 0 { - t.Fatalf("append count = %d, want 0", len(store.appended)) + if len(store.appended) != 0 || len(store.replaced) != 0 { + t.Fatalf("stored turn count = appended %d replaced %d, want 0", len(store.appended), len(store.replaced)) } }) } +func TestSessionCanCommitPartialTurn(t *testing.T) { + session := NewService(eventClient{ + events: []llm.Event{ + {Type: llm.EventReasoningDelta, Delta: "thinking"}, + {Type: llm.EventTextDelta, Delta: "partial"}, + }, + }).NewSession() + + stream, err := session.Send(context.Background(), "hello", SendOptions{}) + if err != nil { + t.Fatalf("Send() error = %v, want nil", err) + } + if event, nextErr := stream.Next(); nextErr != nil || event.Type != llm.EventReasoningDelta { + t.Fatalf("first Next() = %#v, %v; want reasoning delta", event, nextErr) + } + if event, nextErr := stream.Next(); nextErr != nil || event.Type != llm.EventTextDelta { + t.Fatalf("second Next() = %#v, %v; want text delta", event, nextErr) + } + if err := stream.CommitPartial(); err != nil { + t.Fatalf("CommitPartial() error = %v, want nil", err) + } + if err := stream.Close(); err != nil { + t.Fatalf("Close() error = %v, want nil", err) + } + + messages := session.Messages() + if len(messages) != 2 { + t.Fatalf("message count = %d, want 2", len(messages)) + } + if messages[0].Text() != "hello" || messages[1].Text() != "partial" { + t.Fatalf("messages = %#v, want committed partial turn", messages) + } + if messages[1].Parts[0].Type != llm.PartReasoning || messages[1].Parts[0].Text != "thinking" { + t.Fatalf("assistant parts = %#v, want partial reasoning and text", messages[1].Parts) + } +} + func TestPersistentSessionReturnsAppendFailureAndDoesNotKeepInFlight(t *testing.T) { store := &chatStore{appendErr: errors.New("append failed")} session, err := NewPersistentService(dummy.NewClient(dummy.Turn{TextChunks: []string{"answer"}}), store).NewPersistedSession(context.Background(), 42) @@ -580,7 +774,7 @@ func TestTurnStreamFinalizeAndClonePartsGuards(t *testing.T) { userMessage: llm.NewTextMessage(llm.RoleUser, "hello"), } - if err := turn.finalize(); err != nil { + if err := turn.finalize(false); err != nil { t.Fatalf("finalize before completion error = %v, want nil", err) } if got := len(session.Messages()); got != 0 { @@ -588,10 +782,10 @@ func TestTurnStreamFinalizeAndClonePartsGuards(t *testing.T) { } turn.completed = true - if err := turn.finalize(); err != nil { + if err := turn.finalize(false); err != nil { t.Fatalf("finalize after completion error = %v, want nil", err) } - if err := turn.finalize(); err != nil { + if err := turn.finalize(false); err != nil { t.Fatalf("second finalize error = %v, want nil", err) } if got := len(session.Messages()); got != 2 { @@ -603,6 +797,24 @@ func TestTurnStreamFinalizeAndClonePartsGuards(t *testing.T) { } } +func TestTurnStreamRecordErrorTreatsContextCancellationAsCancelled(t *testing.T) { + cancelled := &TurnStream{} + cancelled.recordError(context.Canceled) + if !cancelled.recorded { + t.Fatalf("cancelled stream was not recorded") + } + + failed := &TurnStream{} + failed.recordError(errors.New("stream failed")) + if !failed.recorded { + t.Fatalf("failed stream was not recorded") + } + failed.recordError(context.Canceled) + if !failed.recorded { + t.Fatalf("second recordError changed recorded state") + } +} + func TestSessionMessagesReturnsDeepCopy(t *testing.T) { session := NewService(eventClient{ events: []llm.Event{ @@ -677,6 +889,7 @@ type chatStore struct { messages []llm.Message messagesErr error appended []appendedTurn + replaced []replacedTurn appendErr error } @@ -686,6 +899,13 @@ type appendedTurn struct { assistant llm.Message } +type replacedTurn struct { + conversationID int64 + keepMessages int + user llm.Message + assistant llm.Message +} + func (s *chatStore) Messages(context.Context, int64) ([]llm.Message, error) { if s.messagesErr != nil { return nil, s.messagesErr @@ -705,6 +925,19 @@ func (s *chatStore) AppendTurn(_ context.Context, conversationID int64, user, as return nil } +func (s *chatStore) ReplaceTailAndAppendTurn(_ context.Context, conversationID int64, keepMessages int, user, assistant llm.Message) error { + if s.appendErr != nil { + return s.appendErr + } + s.replaced = append(s.replaced, replacedTurn{ + conversationID: conversationID, + keepMessages: keepMessages, + user: user.Clone(), + assistant: assistant.Clone(), + }) + return nil +} + func collectEvents(t *testing.T, stream *TurnStream) []llm.Event { t.Helper() defer stream.Close() diff --git a/internal/llm/llm.go b/internal/llm/llm.go index 7c3a5b9..bae19b7 100644 --- a/internal/llm/llm.go +++ b/internal/llm/llm.go @@ -1,6 +1,9 @@ package llm -import "context" +import ( + "context" + "time" +) type Role string @@ -45,8 +48,9 @@ func (p Part) Clone() Part { } type Message struct { - Role Role - Parts []Part + Role Role + Parts []Part + CompletedAt time.Time } func NewTextMessage(role Role, text string) Message { @@ -69,7 +73,7 @@ func (m Message) Text() string { } func (m Message) Clone() Message { - clone := Message{Role: m.Role} + clone := Message{Role: m.Role, CompletedAt: m.CompletedAt} if m.Parts != nil { clone.Parts = make([]Part, len(m.Parts)) for i, part := range m.Parts { diff --git a/internal/llm/openresponses/client_test.go b/internal/llm/openresponses/client_test.go index e9a6347..3d8c0de 100644 --- a/internal/llm/openresponses/client_test.go +++ b/internal/llm/openresponses/client_test.go @@ -562,6 +562,40 @@ func TestClientOmitsReasoningByDefault(t *testing.T) { } } +func TestClientDoesNotSendMessageCompletionMetadata(t *testing.T) { + var requestBody map[string]any + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if err := json.NewDecoder(r.Body).Decode(&requestBody); err != nil { + t.Fatalf("Decode request body error = %v", err) + } + w.Header().Set("Content-Type", "text/event-stream") + _, _ = io.WriteString(w, `data: {"type":"response.completed","sequence_number":1,"response":{"id":"resp_1"}}`+"\n\n") + })) + defer server.Close() + + completed := llm.NewTextMessage(llm.RoleAssistant, "prior answer") + completed.CompletedAt = time.Date(2026, 6, 14, 20, 16, 13, 0, time.UTC) + stream, err := NewClient(server.URL).Stream(context.Background(), llm.Request{ + Messages: []llm.Message{ + llm.NewTextMessage(llm.RoleUser, "prior prompt"), + completed, + llm.NewTextMessage(llm.RoleUser, "next prompt"), + }, + }) + if err != nil { + t.Fatalf("Stream() error = %v, want nil", err) + } + drainEvents(t, stream) + + encoded, err := json.Marshal(requestBody) + if err != nil { + t.Fatalf("Marshal request body error = %v, want nil", err) + } + if strings.Contains(string(encoded), "CompletedAt") || strings.Contains(string(encoded), "completed_at") || strings.Contains(string(encoded), "2026-06-14") { + t.Fatalf("request body = %s, did not expect assistant completion metadata", encoded) + } +} + func TestClientMapsPriorReasoningIntoInput(t *testing.T) { var requestBody map[string]any server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { diff --git a/internal/storage/sqlite.go b/internal/storage/sqlite.go index 06557ba..fcd76a1 100644 --- a/internal/storage/sqlite.go +++ b/internal/storage/sqlite.go @@ -94,7 +94,7 @@ CREATE TABLE IF NOT EXISTS schema_version ( ); INSERT INTO schema_version (version) -SELECT 1 +SELECT 0 WHERE NOT EXISTS (SELECT 1 FROM schema_version); UPDATE schema_version SET version = 1 WHERE version < 1; @@ -137,12 +137,20 @@ CREATE TABLE IF NOT EXISTS messages ( role TEXT NOT NULL CHECK (role IN ('user', 'assistant', 'system')), parts_json TEXT NOT NULL, created_at TEXT NOT NULL, + completed_at TEXT, UNIQUE (conversation_id, sequence) ); CREATE INDEX IF NOT EXISTS messages_conversation_sequence_idx ON messages(conversation_id, sequence); `) + if err != nil { + return err + } + if migrateErr := s.ensureMessagesCompletedAt(ctx); migrateErr != nil { + return migrateErr + } + _, err = s.db.ExecContext(ctx, `UPDATE schema_version SET version = 2 WHERE version < 2;`) return err } @@ -370,7 +378,7 @@ func (s *SQLite) Messages(ctx context.Context, conversationID int64) ([]llm.Mess return nil, ErrInvalidArgument } rows, err := s.db.QueryContext(ctx, ` -SELECT role, parts_json +SELECT role, parts_json, completed_at FROM messages WHERE conversation_id = ? ORDER BY sequence ASC @@ -384,11 +392,18 @@ ORDER BY sequence ASC for rows.Next() { var message llm.Message var partsJSON string - if err := rows.Scan(&message.Role, &partsJSON); err != nil { - return nil, err + var completedAt sql.NullString + if scanErr := rows.Scan(&message.Role, &partsJSON, &completedAt); scanErr != nil { + return nil, scanErr + } + if unmarshalErr := json.Unmarshal([]byte(partsJSON), &message.Parts); unmarshalErr != nil { + return nil, unmarshalErr } - if err := json.Unmarshal([]byte(partsJSON), &message.Parts); err != nil { - return nil, err + if completedAt.Valid && strings.TrimSpace(completedAt.String) != "" { + message.CompletedAt, err = parseTime(completedAt.String) + if err != nil { + return nil, err + } } messages = append(messages, message.Clone()) } @@ -402,48 +417,125 @@ func (s *SQLite) AppendTurn(ctx context.Context, conversationID int64, userMessa if conversationID <= 0 || userMessage.Role != llm.RoleUser || assistantMessage.Role != llm.RoleAssistant { return ErrInvalidArgument } - userParts, err := json.Marshal(userMessage.Parts) + tx, err := s.db.BeginTx(ctx, nil) if err != nil { return err } - assistantParts, err := json.Marshal(assistantMessage.Parts) - if err != nil { + defer rollback(tx) + + var maxSequence int + if err := tx.QueryRowContext(ctx, ` +SELECT COALESCE(MAX(sequence), 0) +FROM messages +WHERE conversation_id = ? +`, conversationID).Scan(&maxSequence); err != nil { return err } + if err := insertTurnAtSequence(ctx, tx, conversationID, maxSequence+1, userMessage, assistantMessage); err != nil { + return err + } + return tx.Commit() +} + +func (s *SQLite) ReplaceTailAndAppendTurn(ctx context.Context, conversationID int64, keepMessages int, userMessage, assistantMessage llm.Message) error { + if conversationID <= 0 || userMessage.Role != llm.RoleUser || assistantMessage.Role != llm.RoleAssistant { + return ErrInvalidArgument + } + if keepMessages < 0 { + return ErrInvalidArgument + } tx, err := s.db.BeginTx(ctx, nil) if err != nil { return err } defer rollback(tx) - var maxSequence int64 + var messageCount int if err := tx.QueryRowContext(ctx, ` -SELECT COALESCE(MAX(sequence), 0) +SELECT COUNT(*) FROM messages WHERE conversation_id = ? -`, conversationID).Scan(&maxSequence); err != nil { +`, conversationID).Scan(&messageCount); err != nil { + return err + } + if keepMessages > messageCount { + return ErrInvalidArgument + } + if _, err := tx.ExecContext(ctx, ` +DELETE FROM messages +WHERE conversation_id = ? AND sequence > ? +`, conversationID, keepMessages); err != nil { + return err + } + if err := insertTurnAtSequence(ctx, tx, conversationID, keepMessages+1, userMessage, assistantMessage); err != nil { + return err + } + return tx.Commit() +} + +func insertTurnAtSequence(ctx context.Context, tx *sql.Tx, conversationID int64, firstSequence int, userMessage, assistantMessage llm.Message) error { + userParts, err := json.Marshal(userMessage.Parts) + if err != nil { + return err + } + assistantParts, err := json.Marshal(assistantMessage.Parts) + if err != nil { return err } createdAt := formatTime(time.Now().UTC()) if _, err := tx.ExecContext(ctx, ` -INSERT INTO messages (conversation_id, sequence, role, parts_json, created_at) -VALUES (?, ?, ?, ?, ?) -`, conversationID, maxSequence+1, userMessage.Role, string(userParts), createdAt); err != nil { +INSERT INTO messages (conversation_id, sequence, role, parts_json, created_at, completed_at) +VALUES (?, ?, ?, ?, ?, ?) +`, conversationID, firstSequence, userMessage.Role, string(userParts), createdAt, nil); err != nil { if sqliteIsConstraint(err) { return ErrInvalidArgument } return err } if _, err := tx.ExecContext(ctx, ` -INSERT INTO messages (conversation_id, sequence, role, parts_json, created_at) -VALUES (?, ?, ?, ?, ?) -`, conversationID, maxSequence+2, assistantMessage.Role, string(assistantParts), createdAt); err != nil { +INSERT INTO messages (conversation_id, sequence, role, parts_json, created_at, completed_at) +VALUES (?, ?, ?, ?, ?, ?) +`, conversationID, firstSequence+1, assistantMessage.Role, string(assistantParts), createdAt, completedAtValue(assistantMessage.CompletedAt)); err != nil { if sqliteIsConstraint(err) { return ErrInvalidArgument } return err } - return tx.Commit() + return nil +} + +func (s *SQLite) ensureMessagesCompletedAt(ctx context.Context) error { + rows, err := s.db.QueryContext(ctx, `PRAGMA table_info(messages)`) + if err != nil { + return err + } + defer rows.Close() + + for rows.Next() { + var cid int + var name, typ string + var notNull int + var defaultValue any + var pk int + if scanErr := rows.Scan(&cid, &name, &typ, ¬Null, &defaultValue, &pk); scanErr != nil { + return scanErr + } + if name == "completed_at" { + return rows.Err() + } + } + if rowsErr := rows.Err(); rowsErr != nil { + return rowsErr + } + _, err = s.db.ExecContext(ctx, `ALTER TABLE messages ADD COLUMN completed_at TEXT`) + return err +} + +func completedAtValue(t time.Time) any { + if t.IsZero() { + return nil + } + return formatTime(t.UTC()) } type sqlExecer interface { diff --git a/internal/storage/sqlite_test.go b/internal/storage/sqlite_test.go index ccc7328..2e10ff0 100644 --- a/internal/storage/sqlite_test.go +++ b/internal/storage/sqlite_test.go @@ -26,8 +26,82 @@ func TestSQLiteMigrateIsIdempotent(t *testing.T) { if err := store.db.QueryRowContext(context.Background(), `SELECT version FROM schema_version`).Scan(&version); err != nil { t.Fatalf("schema version query error = %v", err) } - if version != 1 { - t.Fatalf("schema version = %d, want 1", version) + if version != 2 { + t.Fatalf("schema version = %d, want 2", version) + } +} + +func TestSQLiteMigrateAddsCompletedAtToVersionOneMessages(t *testing.T) { + db, openErr := sql.Open("sqlite", ":memory:") + if openErr != nil { + t.Fatalf("sql.Open memory error = %v, want nil", openErr) + } + store := NewSQLiteForDB(db) + ctx := context.Background() + if _, err := store.db.ExecContext(ctx, ` +CREATE TABLE schema_version (version INTEGER NOT NULL); +INSERT INTO schema_version (version) VALUES (1); +CREATE TABLE users ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + username TEXT NOT NULL UNIQUE, + password_hash BLOB NOT NULL, + created_at TEXT NOT NULL +); +CREATE TABLE conversations ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + user_id INTEGER NOT NULL REFERENCES users(id) ON DELETE CASCADE, + title TEXT NOT NULL, + is_default INTEGER NOT NULL DEFAULT 0 CHECK (is_default IN (0, 1)), + created_at TEXT NOT NULL +); +CREATE TABLE messages ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + conversation_id INTEGER NOT NULL REFERENCES conversations(id) ON DELETE CASCADE, + sequence INTEGER NOT NULL, + role TEXT NOT NULL CHECK (role IN ('user', 'assistant', 'system')), + parts_json TEXT NOT NULL, + created_at TEXT NOT NULL, + UNIQUE (conversation_id, sequence) +); +`); err != nil { + t.Fatalf("create version one schema error = %v, want nil", err) + } + + if err := store.Migrate(ctx); err != nil { + t.Fatalf("Migrate version one schema error = %v, want nil", err) + } + + var version int + if err := store.db.QueryRowContext(ctx, `SELECT version FROM schema_version`).Scan(&version); err != nil { + t.Fatalf("schema version query error = %v", err) + } + if version != 2 { + t.Fatalf("schema version = %d, want 2", version) + } + rows, err := store.db.QueryContext(ctx, `PRAGMA table_info(messages)`) + if err != nil { + t.Fatalf("messages table_info error = %v, want nil", err) + } + defer rows.Close() + hasCompletedAt := false + for rows.Next() { + var cid int + var name, typ string + var notNull int + var defaultValue any + var pk int + if err := rows.Scan(&cid, &name, &typ, ¬Null, &defaultValue, &pk); err != nil { + t.Fatalf("scan table_info error = %v, want nil", err) + } + if name == "completed_at" { + hasCompletedAt = true + } + } + if err := rows.Err(); err != nil { + t.Fatalf("table_info rows error = %v, want nil", err) + } + if !hasCompletedAt { + t.Fatalf("messages table missing completed_at column after migration") } } @@ -343,7 +417,8 @@ func TestSQLiteMessagesAreOrderedAndPartsJSONRoundTrips(t *testing.T) { } firstAssistant := llm.Message{ - Role: llm.RoleAssistant, + Role: llm.RoleAssistant, + CompletedAt: time.Date(2026, 6, 14, 10, 11, 12, 0, time.UTC), Parts: []llm.Part{ {Type: llm.PartReasoning, ID: "rs_1", Summary: []string{"thinking"}, EncryptedContent: "encrypted"}, {Type: llm.PartText, Text: "answer one"}, @@ -367,6 +442,12 @@ func TestSQLiteMessagesAreOrderedAndPartsJSONRoundTrips(t *testing.T) { if messages[0].Text() != "first" || messages[1].Text() != "answer one" || messages[2].Text() != "second" || messages[3].Text() != "answer two" { t.Fatalf("messages = %#v, want insertion order", messages) } + if !messages[1].CompletedAt.Equal(firstAssistant.CompletedAt) { + t.Fatalf("first assistant completed_at = %v, want %v", messages[1].CompletedAt, firstAssistant.CompletedAt) + } + if !messages[3].CompletedAt.IsZero() { + t.Fatalf("second assistant completed_at = %v, want zero value", messages[3].CompletedAt) + } reasoning := messages[1].Parts[0] if reasoning.ID != "rs_1" || reasoning.EncryptedContent != "encrypted" || len(reasoning.Summary) != 1 { t.Fatalf("round-tripped reasoning part = %#v, want metadata intact", reasoning) @@ -377,6 +458,39 @@ func TestSQLiteMessagesAreOrderedAndPartsJSONRoundTrips(t *testing.T) { } } +func TestSQLiteReplaceTailAndAppendTurnTruncatesAndAppends(t *testing.T) { + store := newMigratedTestSQLite(t) + ctx := context.Background() + user := createTestUser(t, store, "dina") + conversation, err := store.DefaultConversationForUser(ctx, user.ID) + if err != nil { + t.Fatalf("DefaultConversationForUser error = %v, want nil", err) + } + + if appendErr := store.AppendTurn(ctx, conversation.ID, llm.NewTextMessage(llm.RoleUser, "first"), llm.NewTextMessage(llm.RoleAssistant, "answer one")); appendErr != nil { + t.Fatalf("AppendTurn first error = %v, want nil", appendErr) + } + if appendErr := store.AppendTurn(ctx, conversation.ID, llm.NewTextMessage(llm.RoleUser, "second"), llm.NewTextMessage(llm.RoleAssistant, "answer two")); appendErr != nil { + t.Fatalf("AppendTurn second error = %v, want nil", appendErr) + } + + err = store.ReplaceTailAndAppendTurn(ctx, conversation.ID, 2, llm.NewTextMessage(llm.RoleUser, "edited second"), llm.NewTextMessage(llm.RoleAssistant, "replacement answer")) + if err != nil { + t.Fatalf("ReplaceTailAndAppendTurn error = %v, want nil", err) + } + + messages, err := store.Messages(ctx, conversation.ID) + if err != nil { + t.Fatalf("Messages error = %v, want nil", err) + } + if len(messages) != 4 { + t.Fatalf("message count = %d, want 4", len(messages)) + } + if messages[0].Text() != "first" || messages[1].Text() != "answer one" || messages[2].Text() != "edited second" || messages[3].Text() != "replacement answer" { + t.Fatalf("messages = %#v, want first turn plus edited replacement turn", messages) + } +} + func TestSQLiteAppendTurnRejectsInvalidInputWithoutPersisting(t *testing.T) { store := newMigratedTestSQLite(t) ctx := context.Background() @@ -404,6 +518,34 @@ func TestSQLiteAppendTurnRejectsInvalidInputWithoutPersisting(t *testing.T) { } } +func TestSQLiteReplaceTailAndAppendTurnRejectsInvalidKeepCount(t *testing.T) { + store := newMigratedTestSQLite(t) + ctx := context.Background() + user := createTestUser(t, store, "edie") + conversation, err := store.DefaultConversationForUser(ctx, user.ID) + if err != nil { + t.Fatalf("DefaultConversationForUser error = %v, want nil", err) + } + if appendErr := store.AppendTurn(ctx, conversation.ID, llm.NewTextMessage(llm.RoleUser, "first"), llm.NewTextMessage(llm.RoleAssistant, "answer")); appendErr != nil { + t.Fatalf("AppendTurn error = %v, want nil", appendErr) + } + + for _, keepMessages := range []int{-1, 3} { + replaceErr := store.ReplaceTailAndAppendTurn(ctx, conversation.ID, keepMessages, llm.NewTextMessage(llm.RoleUser, "bad"), llm.NewTextMessage(llm.RoleAssistant, "bad answer")) + if !errors.Is(replaceErr, ErrInvalidArgument) { + t.Fatalf("ReplaceTailAndAppendTurn keep=%d error = %v, want ErrInvalidArgument", keepMessages, replaceErr) + } + } + + messages, err := store.Messages(ctx, conversation.ID) + if err != nil { + t.Fatalf("Messages error = %v, want nil", err) + } + if len(messages) != 2 || messages[0].Text() != "first" || messages[1].Text() != "answer" { + t.Fatalf("messages after rejected replacements = %#v, want original turn intact", messages) + } +} + func TestSQLiteReadPathDataCorruptionReturnsErrors(t *testing.T) { store := newMigratedTestSQLite(t) ctx := context.Background() diff --git a/internal/storage/storage.go b/internal/storage/storage.go index bd7f073..3755252 100644 --- a/internal/storage/storage.go +++ b/internal/storage/storage.go @@ -25,6 +25,7 @@ type Store interface { DefaultConversationForUser(context.Context, int64) (Conversation, error) Messages(context.Context, int64) ([]llm.Message, error) AppendTurn(context.Context, int64, llm.Message, llm.Message) error + ReplaceTailAndAppendTurn(context.Context, int64, int, llm.Message, llm.Message) error Close() error } diff --git a/internal/web/assets/app.css b/internal/web/assets/app.css index 8cf1994..73a4356 100644 --- a/internal/web/assets/app.css +++ b/internal/web/assets/app.css @@ -32,6 +32,12 @@ button { --status-sweep-edge: rgb(136 136 136); --status-sweep-mid: rgb(192 192 192); --status-sweep-low: rgb(248 248 248); + --icon-button-color: var(--pico-muted-color); + --icon-button-hover-color: var(--pico-color); + --icon-button-inverse-color: color-mix(in srgb, var(--pico-background-color) 78%, transparent); + --icon-button-inverse-hover-color: var(--pico-background-color); + --icon-button-inverse-border-color: color-mix(in srgb, var(--pico-background-color) 30%, transparent); + --icon-button-inverse-hover-background: color-mix(in srgb, var(--pico-background-color) 18%, transparent); } @media (prefers-color-scheme: dark) { @@ -81,14 +87,18 @@ button { } .shell { + --shell-inline-padding: clamp(0.75rem, 2vw, 1.5rem); + --shell-max-width: 1080px; + + position: relative; display: grid; grid-template-rows: auto minmax(0, 1fr) auto; gap: 0; - width: min(1080px, 100%); + width: min(var(--shell-max-width), 100%); height: 100vh; height: 100svh; margin: 0 auto; - padding: 0 clamp(0.75rem, 2vw, 1.5rem); + padding: 0 var(--shell-inline-padding); } .topbar { @@ -197,16 +207,24 @@ button { .chat-panel { position: relative; min-height: 0; + width: 100vw; + margin: 0 calc(50% - 50vw); } .messages { + --message-role-offset: clamp(0.375rem, 2vw, 1.5rem); + --messages-inline-padding: max( + var(--shell-inline-padding), + calc((100vw - var(--shell-max-width)) / 2 + var(--shell-inline-padding)) + ); + display: flex; flex-direction: column; gap: 1rem; height: 100%; min-height: 0; overflow-y: auto; - padding: 1.25rem 0 1.5rem; + padding: 0 var(--messages-inline-padding); scroll-behavior: smooth; scrollbar-gutter: stable; } @@ -229,6 +247,19 @@ button { overflow-wrap: anywhere; } +.message[data-after-active-prompt="true"] { + opacity: 0.56; + filter: grayscale(0.35); + transition: + opacity 120ms ease, + filter 120ms ease; +} + +.messages[data-dirty-prompt="true"] .message[data-after-active-prompt="true"] { + pointer-events: none; + user-select: none; +} + .auth-shell { display: grid; min-height: 100vh; @@ -259,15 +290,50 @@ button { } .message-user { - align-self: stretch; - padding: 0.875rem 1rem; + --message-user-block-padding: 0.875rem; + --message-user-inline-padding: 1rem; + + align-self: flex-end; + width: calc(100% - var(--message-role-offset)); + padding: var(--message-user-block-padding) var(--message-user-inline-padding); border-radius: 0.5rem; background: var(--pico-color); color: var(--pico-background-color); } +.message-user[data-editable-prompt="true"] { + --prompt-control-reserve: 2.65rem; + + cursor: text; + padding-bottom: calc(var(--message-user-block-padding) + var(--prompt-control-reserve)); +} + +.message-user[data-active-prompt="true"] { + position: sticky; + top: 0; + bottom: 0; + z-index: 3; + box-shadow: + 0 0 0 var(--pico-outline-width) var(--pico-primary-focus), + 0 0 18px rgba(0, 0, 0, 0.08), + 0 8px 24px rgba(0, 0, 0, 0.14), + 0 2px 8px rgba(0, 0, 0, 0.1); +} + +.message-end-target:focus-visible { + outline: var(--pico-outline-width) solid var(--pico-primary-focus); + outline-offset: 0.125rem; +} + +.message-end-target > .message-text { + min-height: 1.5em; + min-height: 1lh; +} + .message-assistant { - align-self: stretch; + align-self: flex-start; + width: calc(100% - var(--message-role-offset)); + padding-right: 2.75rem; } .message-empty { @@ -723,18 +789,20 @@ button { } .message-actions { + position: absolute; + top: 0; + right: 0; display: flex; - justify-content: flex-end; gap: 0.25rem; - min-height: 1.75rem; - margin-top: 0.35rem; - opacity: 0.28; + opacity: 0; + pointer-events: none; transition: opacity 120ms ease; } -.message:hover .message-actions, -.message:focus-within .message-actions { +.message-assistant:hover .message-actions, +.message-assistant:focus-within .message-actions { opacity: 1; + pointer-events: auto; } .message-action, @@ -767,20 +835,6 @@ button { color: var(--pico-color); } -.message-user .message-actions { - opacity: 0.58; -} - -.message-user .message-action { - color: var(--pico-background-color); -} - -.message-user .message-action:hover, -.message-user .message-action:focus-visible { - background: color-mix(in srgb, var(--pico-background-color) 18%, transparent); - color: var(--pico-background-color); -} - .message-action svg, .icon-button svg, .scroll-bottom svg, @@ -831,7 +885,6 @@ button { .composer { margin: 0; - padding: 0 0 1rem; } .composer-box { @@ -844,6 +897,7 @@ button { border-radius: 0.9rem; background: var(--pico-card-background-color); box-shadow: 0 -0.35rem 1.5rem rgba(0, 0, 0, 0.04); + will-change: transform; } .composer-box:focus-within { @@ -887,25 +941,128 @@ button { white-space: nowrap; } +.composer-status:empty { + display: none; +} + .icon-button { border-radius: 999px; } -.stop-button { - border-color: var(--pico-del-color); - color: var(--pico-del-color); +.history-button, +.action-button { + border: var(--pico-border-width) solid var(--pico-muted-border-color); + border-color: var(--pico-muted-border-color); + background: transparent; + color: var(--icon-button-color); + box-shadow: none; } -.send-button { - border-color: var(--pico-primary); - background: var(--pico-primary); - color: var(--pico-primary-inverse); +.history-button:hover, +.history-button:focus-visible, +.action-button:hover, +.action-button:focus-visible { + background: var(--pico-muted-border-color); + color: var(--icon-button-hover-color); +} + +.action-button[data-action-state="stop"], +.action-button[data-action-state="stopping"] { + color: var(--icon-button-color); } .icon-button:disabled { opacity: 0.45; } +.message-edit-slot { + display: none; +} + +.message-user[data-editing="true"] { + cursor: default; +} + +.message-user[data-editing="true"] > .message-text, +.message-user[data-editing="true"] > .message-actions { + display: none; +} + +.message-user[data-editing="true"] > .message-edit-slot { + display: block; +} + +.message-user .composer { + padding: 0; +} + +.message-user .composer-box { + display: contents; + padding: 0; + border: 0; + border-radius: 0; + background: transparent; + box-shadow: none; +} + +.message-user .composer-box:focus-within { + box-shadow: none; +} + +.message-user .composer textarea { + display: block; + width: 100%; + max-height: none; + min-height: 0; + padding: 0; + color: var(--pico-background-color); + font: inherit; + line-height: inherit; + overflow-y: hidden; +} + +.message-user .composer textarea::placeholder { + color: color-mix(in srgb, var(--pico-background-color) 65%, transparent); +} + +.message-user .composer-controls { + position: absolute; + right: var(--message-user-inline-padding); + bottom: var(--message-user-block-padding); + left: var(--message-user-inline-padding); + justify-content: flex-end; +} + +.message-user .composer-status { + margin-right: auto; + text-align: left; +} + +.message-user .composer-status, +.message-user .history-button, +.message-user .action-button { + color: var(--icon-button-inverse-color); +} + +.message-user .history-button, +.message-user .action-button { + border-color: var(--icon-button-inverse-border-color); +} + +.message-user .history-button:hover, +.message-user .history-button:focus-visible, +.message-user .action-button:hover, +.message-user .action-button:focus-visible { + background: var(--icon-button-inverse-hover-background); + color: var(--icon-button-inverse-hover-color); +} + +.message-stopped .message-stopped-note { + margin-top: 0.45rem; + color: var(--pico-muted-color); + font-size: 0.75rem; +} + .sr-only { position: absolute; width: 1px; @@ -917,6 +1074,8 @@ button { @media (max-width: 640px) { .shell { + --shell-inline-padding: 0.75rem; + padding-inline: 0.75rem; } @@ -929,13 +1088,13 @@ button { display: none; } - .message-user, - .message-error { - padding: 0.8rem 0.875rem; + .message-user { + --message-user-block-padding: 0.8rem; + --message-user-inline-padding: 0.875rem; } - .composer { - padding-bottom: 0.75rem; + .message-error { + padding: 0.8rem 0.875rem; } .composer-box { diff --git a/internal/web/assets/app.js b/internal/web/assets/app.js index bee3b73..ade2746 100644 --- a/internal/web/assets/app.js +++ b/internal/web/assets/app.js @@ -5,9 +5,14 @@ const scrollButton = document.getElementById('scroll-bottom'); const form = document.getElementById('chat-form'); const prompt = document.getElementById('prompt'); - const sendButton = document.getElementById('send-button'); - const stopButton = document.getElementById('stop-button'); + const actionButton = document.getElementById('composer-action'); + const actionIcons = actionButton ? actionButton.querySelectorAll('[data-action-icon]') : []; + const revertButton = document.getElementById('revert-button'); + const previousButton = document.getElementById('previous-button'); + const nextButton = document.getElementById('next-button'); + const ffwdButton = document.getElementById('ffwd-button'); const composerStatus = document.getElementById('composer-status'); + const composerEndTarget = document.getElementById('composer-end-target'); const themeToggle = document.querySelector('[data-theme-toggle]'); const themeIcons = themeToggle ? themeToggle.querySelectorAll('[data-theme-icon]') : []; @@ -24,6 +29,10 @@ let abortRequested = false; let creatingTurn = false; let statusIDCounter = 0; + let nextMessageIndex = initialNextMessageIndex(); + let currentEditIndex = null; + let originalPromptValue = ''; + let endPromptValue = ''; let mermaidInitialized = false; let mermaidCurrentTheme = ''; let mermaidIDCounter = 0; @@ -102,10 +111,120 @@ function setStatus(text) { if (composerStatus) { - composerStatus.textContent = text; + composerStatus.textContent = text || ''; } } + function messageIndex(article) { + const value = Number.parseInt(article?.dataset.messageIndex || '', 10); + return Number.isFinite(value) ? value : -1; + } + + function initialNextMessageIndex() { + let maxIndex = -1; + messages.querySelectorAll('.message[data-message-index]').forEach(function (article) { + maxIndex = Math.max(maxIndex, messageIndex(article)); + }); + const serverIndex = Number.parseInt(messages.dataset.nextMessageIndex || '', 10); + if (Number.isFinite(serverIndex) && serverIndex >= 0) { + return Math.max(serverIndex, maxIndex + 1); + } + return maxIndex + 1; + } + + function transcriptPosition() { + return currentEditIndex === null ? nextMessageIndex : currentEditIndex; + } + + function userMessages() { + return Array.from(messages.querySelectorAll('.message-user[data-editable-prompt="true"][data-message-index]')) + .sort(function (left, right) { + return messageIndex(left) - messageIndex(right); + }); + } + + function userMessageBefore(index) { + let target = null; + userMessages().forEach(function (article) { + if (messageIndex(article) < index) { + target = article; + } + }); + return target; + } + + function userMessageAfter(index) { + return userMessages().find(function (article) { + return messageIndex(article) > index; + }) || null; + } + + function promptTextForArticle(article) { + const body = article?.querySelector('.message-text'); + return body ? body.textContent || '' : ''; + } + + function updatePromptHistoryState() { + const activeIndex = currentEditIndex; + const dirtySelectedPrompt = hasDirtySelectedPrompt(); + const messageData = messages.dataset; + if (composerEndTarget) { + if (activeIndex === null) { + composerEndTarget.dataset.activePrompt = 'true'; + composerEndTarget.dataset.editing = 'true'; + delete composerEndTarget.dataset.afterActivePrompt; + } else { + delete composerEndTarget.dataset.activePrompt; + delete composerEndTarget.dataset.editing; + composerEndTarget.dataset.afterActivePrompt = 'true'; + } + } + if (dirtySelectedPrompt) { + messageData.dirtyPrompt = 'true'; + } else { + delete messageData.dirtyPrompt; + } + messages.querySelectorAll('.message[data-message-index]').forEach(function (article) { + const index = messageIndex(article); + const data = article.dataset; + if (activeIndex !== null && index === activeIndex && article.matches('.message-user[data-editable-prompt="true"]')) { + data.activePrompt = 'true'; + } else { + delete data.activePrompt; + } + + if (activeIndex !== null && index > activeIndex) { + data.afterActivePrompt = 'true'; + } else { + delete data.afterActivePrompt; + } + }); + } + + function hasDirtyPrompt() { + return prompt.value !== originalPromptValue; + } + + function hasDirtySelectedPrompt() { + return currentEditIndex !== null && hasDirtyPrompt(); + } + + function canRevertPromptChanges() { + return !currentTurn && !creatingTurn && hasDirtySelectedPrompt(); + } + + function revertPromptChanges() { + if (!canRevertPromptChanges()) { + return false; + } + prompt.value = originalPromptValue; + syncPromptHeight(); + updatePromptHistoryState(); + updateComposerState(); + focusPrompt('end'); + return true; + } + function isNearBottom() { return messages.scrollHeight - messages.scrollTop - messages.clientHeight < nearBottomThreshold; } @@ -116,15 +235,34 @@ } } + function bottomScrollTop() { + return Math.max(0, messages.scrollHeight - messages.clientHeight); + } + + function forceScrollToBottom() { + const wasEndActive = composerEndTarget?.dataset.activePrompt; + if (wasEndActive) { + delete composerEndTarget.dataset.activePrompt; + void composerEndTarget.offsetHeight; + } + messages.scrollTo({ + top: bottomScrollTop(), + behavior: 'auto', + }); + if (wasEndActive) { + composerEndTarget.dataset.activePrompt = wasEndActive; + } + } + function scrollToBottom(force, wasNearBottom) { if (force || wasNearBottom) { - messages.scrollTop = messages.scrollHeight; + forceScrollToBottom(); } updateScrollButton(); } function insertMessage(article) { - messages.insertBefore(article, messagesEnd || null); + messages.insertBefore(article, composerEndTarget || messagesEnd || null); } function clearEmptyState() { @@ -135,7 +273,7 @@ } function ensureEmptyState() { - if (messages.querySelector('.message')) { + if (messages.querySelector('.message[data-message-index]')) { return; } const article = document.createElement('article'); @@ -154,6 +292,7 @@ if (message && message.article && message.article.parentNode === messages) { message.article.remove(); ensureEmptyState(); + updatePromptHistoryState(); updateScrollButton(); } } @@ -163,7 +302,232 @@ removeMessage(user); } - function createMessageActions() { + function articleForEditIndex(index) { + return messages.querySelector(`.message-user[data-editable-prompt="true"][data-message-index="${index}"]`); + } + + function editSlotFor(article) { + let slot = article.querySelector(':scope > .message-edit-slot'); + if (!slot) { + slot = createPromptEditSlot(); + const actions = article.querySelector(':scope > .message-actions'); + article.insertBefore(slot, actions || null); + } + return slot; + } + + function createPromptEditSlot() { + const slot = document.createElement('div'); + slot.className = 'message-edit-slot'; + return slot; + } + + function clearEditingArticle() { + messages.querySelectorAll('.message-user[data-editing="true"]').forEach(function (editing) { + delete editing.dataset.editing; + }); + } + + function animateComposerFrom(firstRect) { + if (!firstRect || window.matchMedia?.('(prefers-reduced-motion: reduce)').matches) { + return; + } + const box = form.querySelector('.composer-box'); + if (!box) { + return; + } + const lastRect = form.getBoundingClientRect(); + const dx = firstRect.left - lastRect.left; + const dy = firstRect.top - lastRect.top; + if (Math.abs(dx) < 1 && Math.abs(dy) < 1) { + return; + } + box.style.transition = 'none'; + box.style.transform = `translate(${dx}px, ${dy}px)`; + window.requestAnimationFrame(function () { + box.style.transition = 'transform 180ms ease'; + box.style.transform = 'translate(0, 0)'; + }); + box.addEventListener('transitionend', function cleanup(event) { + if (event.propertyName !== 'transform') { + return; + } + box.style.transition = ''; + box.style.transform = ''; + box.removeEventListener('transitionend', cleanup); + }); + } + + function focusPromptNow(selection, preventScroll) { + prompt.focus({ preventScroll: preventScroll !== false }); + const position = selection === 'start' ? 0 : prompt.value.length; + prompt.setSelectionRange(position, position); + } + + function focusPrompt(selection) { + window.requestAnimationFrame(function () { + focusPromptNow(selection); + }); + } + + function preferredScrollBehavior(force) { + if (force || window.matchMedia?.('(prefers-reduced-motion: reduce)').matches) { + return 'auto'; + } + return 'smooth'; + } + + function cssPixels(value) { + const pixels = Number.parseFloat(value); + return Number.isFinite(pixels) ? pixels : 0; + } + + function clampedScrollTop(value) { + return Math.min(bottomScrollTop(), Math.max(0, value)); + } + + function messageGap() { + const gap = Number.parseFloat(window.getComputedStyle(messages).rowGap); + return Number.isFinite(gap) ? gap : 0; + } + + function targetScrollBounds(target) { + const messagesRect = messages.getBoundingClientRect(); + const targetRect = target.getBoundingClientRect(); + let top = targetRect.top - messagesRect.top + messages.scrollTop; + if (target.dataset.activePrompt === 'true') { + const gap = messageGap(); + const previous = target.previousElementSibling; + const next = target.nextElementSibling; + if (previous) { + top = previous.offsetTop + previous.offsetHeight + gap; + } else if (next) { + top = next.offsetTop - target.offsetHeight - gap; + } + } + return { + top, + bottom: top + target.offsetHeight, + }; + } + + function targetVisualBounds(target) { + const messagesRect = messages.getBoundingClientRect(); + const targetRect = target.getBoundingClientRect(); + const top = targetRect.top - messagesRect.top + messages.scrollTop; + return { + top, + bottom: top + targetRect.height, + }; + } + + function scrollComposerIntoView(targetIndex, direction, force) { + const target = targetIndex === null ? composerEndTarget : articleForEditIndex(targetIndex); + if (!target) { + return; + } + if (force && targetIndex === null && direction === 'down') { + forceScrollToBottom(); + updateScrollButton(); + return; + } + const targetStyle = window.getComputedStyle(target); + const topInset = cssPixels(targetStyle.top); + const bottomInset = cssPixels(targetStyle.bottom); + const visualBounds = targetVisualBounds(target); + const visibleTop = messages.scrollTop + topInset; + const visibleBottom = messages.scrollTop + messages.clientHeight - bottomInset; + + if (!force && visualBounds.top >= visibleTop && visualBounds.bottom <= visibleBottom) { + updateScrollButton(); + return; + } + + const targetBounds = targetScrollBounds(target); + let nextScrollTop = messages.scrollTop; + if (direction === 'up') { + nextScrollTop = targetBounds.top - topInset; + } else if (direction === 'down') { + nextScrollTop = targetBounds.bottom - messages.clientHeight + bottomInset; + } else if (visualBounds.top < visibleTop) { + nextScrollTop = targetBounds.top - topInset; + } else if (visualBounds.bottom > visibleBottom) { + nextScrollTop = targetBounds.bottom - messages.clientHeight + bottomInset; + } + + nextScrollTop = clampedScrollTop(nextScrollTop); + if (Math.abs(nextScrollTop - messages.scrollTop) >= 1) { + messages.scrollTo({ + top: nextScrollTop, + behavior: preferredScrollBehavior(force), + }); + } + updateScrollButton(); + } + + function focusEndPromptOnLoad() { + const alignEndPrompt = function () { + focusPromptNow('end', false); + scrollComposerIntoView(null, 'down', true); + updateScrollButton(); + }; + window.requestAnimationFrame(alignEndPrompt); + window.setTimeout(alignEndPrompt, 0); + window.setTimeout(alignEndPrompt, 100); + window.addEventListener('load', alignEndPrompt, { once: true }); + } + + function moveComposerTo(targetIndex, selection, direction) { + const firstRect = form.getBoundingClientRect(); + if (document.activeElement instanceof HTMLElement && form.contains(document.activeElement)) { + document.activeElement.blur(); + } + if (currentEditIndex === null) { + endPromptValue = prompt.value; + } + clearEditingArticle(); + + if (targetIndex === null) { + editSlotFor(composerEndTarget).append(form); + composerEndTarget.dataset.editing = 'true'; + currentEditIndex = null; + prompt.value = endPromptValue; + originalPromptValue = endPromptValue; + } else { + const article = articleForEditIndex(targetIndex); + if (!article) { + return; + } + article.dataset.editing = 'true'; + editSlotFor(article).append(form); + currentEditIndex = targetIndex; + originalPromptValue = promptTextForArticle(article); + prompt.value = originalPromptValue; + } + + syncPromptHeight(); + updateComposerState(); + animateComposerFrom(firstRect); + scrollComposerIntoView(targetIndex, direction || 'nearest', false); + updatePromptHistoryState(); + focusPrompt(selection); + } + + function moveComposerToEndForSubmit() { + const firstRect = form.getBoundingClientRect(); + clearEditingArticle(); + editSlotFor(composerEndTarget).append(form); + composerEndTarget.dataset.editing = 'true'; + currentEditIndex = null; + endPromptValue = ''; + updatePromptHistoryState(); + animateComposerFrom(firstRect); + } + + function createMessageActions(role) { + if (role !== 'assistant') { + return null; + } const actions = document.createElement('div'); actions.className = 'message-actions'; actions.setAttribute('aria-label', 'Message actions'); @@ -173,6 +537,7 @@ copy.type = 'button'; copy.dataset.copyMessage = ''; copy.setAttribute('aria-label', 'Copy message'); + copy.title = 'Copy message'; copy.innerHTML = `${copyIcon}Copy message`; actions.append(copy); @@ -301,6 +666,12 @@ clearEmptyState(); const article = document.createElement('article'); article.className = `message message-${role}`; + const index = options && Number.isInteger(options.messageIndex) ? options.messageIndex : nextMessageIndex; + article.dataset.messageIndex = String(index); + nextMessageIndex = Math.max(nextMessageIndex, index + 1); + if (role === 'user') { + article.dataset.editablePrompt = 'true'; + } if (options && options.streaming) { article.classList.add('message-streaming'); } @@ -309,15 +680,53 @@ messageText.className = role === 'assistant' ? 'message-text markdown-body' : 'message-text message-plain'; messageText.textContent = text || ''; - if (role === 'assistant' && options && options.streaming) { - article.append(createThinkingStatus()); + article.append(messageText); + if (role === 'user') { + article.append(createPromptEditSlot()); + } + const actions = createMessageActions(role); + if (actions) { + article.append(actions); } - article.append(messageText, createMessageActions()); insertMessage(article); scrollToBottom(true, true); return { article, text: messageText }; } + function truncateMessagesFrom(index) { + const snapshot = { + articles: [], + nextMessageIndex, + }; + messages.querySelectorAll('.message[data-message-index]').forEach(function (article) { + if (messageIndex(article) >= index) { + snapshot.articles.push(article); + article.remove(); + } + }); + nextMessageIndex = index; + ensureEmptyState(); + updatePromptHistoryState(); + updateScrollButton(); + return snapshot; + } + + function restoreTruncatedMessages(snapshot) { + if (!snapshot) { + return; + } + const empty = messages.querySelector('.message-empty'); + if (empty) { + empty.remove(); + } + snapshot.articles.forEach(function (article) { + insertMessage(article); + }); + nextMessageIndex = snapshot.nextMessageIndex; + updatePromptHistoryState(); + updateScrollButton(); + } + function assignMessageIDs(user, assistant, turn) { if (turn.user_message_id) { user.article.id = `message-${turn.user_message_id}`; @@ -341,6 +750,10 @@ return Boolean(status && status.textContent); } + function assistantOutputStarted(assistant) { + return Boolean(assistant && (assistant.text.textContent || assistant.text.innerHTML)); + } + function setCompletedAt(assistant, completedAt) { if (!completedAt) { return; @@ -352,18 +765,57 @@ const actions = assistant.article.querySelector('.message-actions'); assistant.article.insertBefore(timestamp, actions || null); } - const date = new Date(completedAt); timestamp.dateTime = completedAt; - timestamp.textContent = Number.isNaN(date.getTime()) ? completedAt : `Completed ${date.toLocaleString()}`; + timestamp.textContent = completedAtText(completedAt); + } + + function completedAtText(completedAt) { + const date = new Date(completedAt); + return Number.isNaN(date.getTime()) ? completedAt : `Completed ${date.toLocaleString()}`; + } + + function localizeCompletedTimes(root) { + root.querySelectorAll('.message-completed-at[datetime]').forEach(function (timestamp) { + timestamp.textContent = completedAtText(timestamp.getAttribute('datetime') || ''); + }); } function updateComposerState() { const submitting = Boolean(currentTurn) || creatingTurn; - sendButton.disabled = submitting || prompt.value.trim() === ''; - stopButton.disabled = !currentTurn || abortRequested; + if (actionButton) { + const state = currentTurn ? (abortRequested ? 'stopping' : 'stop') : 'send'; + const label = currentTurn ? (abortRequested ? 'Stopping response' : 'Stop response') : 'Send message'; + const iconName = currentTurn ? 'stop' : 'play'; + actionButton.dataset.actionState = state; + actionButton.disabled = currentTurn ? abortRequested : creatingTurn || prompt.value.trim() === ''; + actionButton.setAttribute('aria-label', label); + actionButton.title = label; + actionIcons.forEach(function (icon) { + icon.toggleAttribute('hidden', icon.dataset.actionIcon !== iconName); + }); + } + updateHistoryButtons(submitting); prompt.disabled = submitting; } + function updateHistoryButtons(submitting) { + const busy = submitting; + const dirtySelectedPrompt = hasDirtySelectedPrompt(); + const previous = userMessageBefore(transcriptPosition()); + if (revertButton) { + revertButton.disabled = busy || !dirtySelectedPrompt; + } + if (previousButton) { + previousButton.disabled = busy || dirtySelectedPrompt || !previous; + } + if (nextButton) { + nextButton.disabled = busy || dirtySelectedPrompt || currentEditIndex === null; + } + if (ffwdButton) { + ffwdButton.disabled = busy || dirtySelectedPrompt || currentEditIndex === null; + } + } + function closeSource() { if (currentSource) { currentSource.close(); @@ -378,7 +830,7 @@ } } - function finishTurn(status) { + function finishTurn(status, options) { clearStreamErrorTimer(); closeSource(); currentTurn = null; @@ -387,7 +839,10 @@ abortRequested = false; creatingTurn = false; updateComposerState(); - setStatus(status || 'Ready'); + setStatus(status || ''); + if (!options || options.focus !== false) { + focusPrompt('end'); + } } function markTurnError(assistant, message) { @@ -403,6 +858,29 @@ assistant.text.replaceChildren(error); } + function markTurnStopped(assistant) { + if (!assistant) { + return; + } + assistant.article.classList.remove('message-streaming'); + assistant.article.classList.add('message-stopped'); + completeThinkingStatus(assistant.article); + if (!assistantHasContent(assistant)) { + const status = assistant.article.querySelector('.message-status'); + if (status) { + status.remove(); + } + } + let note = assistant.article.querySelector('.message-stopped-note'); + if (!note) { + note = document.createElement('div'); + note.className = 'message-stopped-note'; + note.textContent = 'Stopped'; + const actions = assistant.article.querySelector('.message-actions'); + assistant.article.insertBefore(note, actions || null); + } + } + function languageFromCode(code) { if (!code) { return 'code'; @@ -578,6 +1056,7 @@ } function enhanceMessage(article) { + localizeCompletedTimes(article); const body = article.querySelector('.markdown-body'); if (body) { enhanceCodeBlocks(body); @@ -636,14 +1115,18 @@ }, 1400); } - async function submitPrompt(text) { + async function submitPrompt(text, options) { + const body = { prompt: text }; + if (options && Number.isInteger(options.replaceFrom)) { + body.replace_from = options.replaceFrom; + } const response = await fetch('/chat/turns', { method: 'POST', headers: { 'Content-Type': 'application/json', [csrfHeaderName()]: csrfToken, }, - body: JSON.stringify({ prompt: text }), + body: JSON.stringify(body), }); if (!response.ok) { throw new Error(await response.text()); @@ -666,19 +1149,132 @@ async function requestAbort(turn) { abortRequested = true; updateComposerState(); - setStatus('Stopping response'); + setStatus(''); try { await abortTurn(turn); } catch (error) { if (currentTurn === turn) { abortRequested = false; updateComposerState(); - setStatus('Generating response'); + setStatus(''); } throw error; } } + function navigationDirection(targetIndex) { + const current = transcriptPosition(); + const target = targetIndex === null ? nextMessageIndex : targetIndex; + if (target < current) { + return 'up'; + } + if (target > current) { + return 'down'; + } + return 'nearest'; + } + + function navigateTo(targetIndex, selection, direction) { + moveComposerTo(targetIndex, selection || (targetIndex === null ? 'end' : 'end'), direction); + } + + function requestNavigation(targetIndex, selection) { + if (currentTurn || creatingTurn) { + return; + } + if (hasDirtySelectedPrompt()) { + return; + } + navigateTo(targetIndex, selection, navigationDirection(targetIndex)); + } + + function navigateBackward() { + const previous = userMessageBefore(transcriptPosition()); + if (previous) { + requestNavigation(messageIndex(previous), 'end'); + } + } + + function navigateForward() { + if (currentEditIndex === null) { + return; + } + const next = userMessageAfter(currentEditIndex); + requestNavigation(next ? messageIndex(next) : null, 'start'); + } + + function historyNavigationBusy() { + return Boolean(currentTurn) || creatingTurn || hasDirtySelectedPrompt(); + } + + function canNavigateBackward() { + return !historyNavigationBusy() && Boolean(userMessageBefore(transcriptPosition())); + } + + function canNavigateForward() { + return !historyNavigationBusy() && currentEditIndex !== null; + } + + function isEditablePromptInteractiveTarget(element) { + return Boolean(element.closest('a, button, input, label, select, textarea, [contenteditable="true"]')); + } + + function editablePromptArticle(element) { + const article = element.closest('.message-user[data-editable-prompt="true"][data-message-index]'); + return article && messages.contains(article) ? article : null; + } + + function canSelectPrompt(index) { + if (!hasDirtySelectedPrompt()) { + return true; + } + return index === currentEditIndex; + } + + function handleEditablePromptMouseDown(event) { + if (!(event.target instanceof Element) || isEditablePromptInteractiveTarget(event.target)) { + return; + } + const article = editablePromptArticle(event.target); + if (article) { + event.preventDefault(); + } + } + + function handleEditablePromptClick(event) { + if (!(event.target instanceof Element) || isEditablePromptInteractiveTarget(event.target)) { + return false; + } + const article = editablePromptArticle(event.target); + if (!article) { + return false; + } + const index = messageIndex(article); + if (index < 0) { + return false; + } + + event.preventDefault(); + if (!canSelectPrompt(index)) { + return true; + } + if (currentEditIndex === index) { + scrollComposerIntoView(index, 'nearest', false); + focusPrompt('end'); + return true; + } + requestNavigation(index, 'end'); + return true; + } + + function caretAtStart() { + return prompt.selectionStart === 0 && prompt.selectionEnd === 0; + } + + function caretAtEnd() { + return prompt.selectionStart === prompt.value.length && prompt.selectionEnd === prompt.value.length; + } + async function abortDisconnectedTurn(turn, user, assistant) { if (currentTurn !== turn) { return; @@ -693,7 +1289,7 @@ return; } if (currentTurn === turn) { - discardTurn(user, assistant); + markTurnStopped(assistant); finishTurn('Stream disconnected'); } } @@ -708,7 +1304,7 @@ currentSource.onopen = function () { clearStreamErrorTimer(); - setStatus('Generating response'); + setStatus(''); }; currentSource.addEventListener('preview', function (event) { @@ -723,6 +1319,7 @@ assistant.article.dataset.messageId = data.assistant_message_id; assistant.text.id = `message-body-${data.assistant_message_id}`; } + completeThinkingStatus(assistant.article); assistant.text.innerHTML = data.html || ''; enhanceMessage(assistant.article); scrollToBottom(false, wasNearBottom); @@ -736,6 +1333,9 @@ const wasNearBottom = isNearBottom(); const data = JSON.parse(event.data); ensureThinkingStatus(assistant.article).textContent += data.delta || ''; + if (assistantOutputStarted(assistant)) { + completeThinkingStatus(assistant.article); + } scrollToBottom(false, wasNearBottom); }); @@ -762,7 +1362,7 @@ removeMessage(assistant); } scrollToBottom(false, wasNearBottom); - finishTurn('Response complete'); + finishTurn(); }); currentSource.addEventListener('aborted', function () { @@ -770,8 +1370,8 @@ return; } clearStreamErrorTimer(); - discardTurn(user, assistant); - finishTurn('Response stopped'); + markTurnStopped(assistant); + finishTurn(''); }); currentSource.addEventListener('stream-error', function (event) { @@ -804,6 +1404,11 @@ function syncPromptHeight() { prompt.style.height = 'auto'; + if (prompt.closest('.message-user')) { + prompt.style.height = `${prompt.scrollHeight}px`; + prompt.style.overflowY = 'hidden'; + return; + } const maxHeight = parseFloat(window.getComputedStyle(prompt).maxHeight); const nextHeight = Number.isFinite(maxHeight) ? Math.min(prompt.scrollHeight, maxHeight) : prompt.scrollHeight; prompt.style.height = `${nextHeight}px`; @@ -818,18 +1423,27 @@ return; } - const user = addMessage('user', text); + const replaceFrom = currentEditIndex; + moveComposerToEndForSubmit(); + let truncatedSnapshot = null; + if (replaceFrom !== null) { + truncatedSnapshot = truncateMessagesFrom(replaceFrom); + } + + const userIndex = replaceFrom === null ? nextMessageIndex : replaceFrom; + const user = addMessage('user', text, { messageIndex: userIndex }); const assistant = addMessage('assistant', '', { streaming: true }); currentUser = user; currentAssistant = assistant; prompt.value = ''; + originalPromptValue = ''; syncPromptHeight(); creatingTurn = true; updateComposerState(); setStatus('Starting response'); try { - const turn = await submitPrompt(text); + const turn = await submitPrompt(text, replaceFrom === null ? null : { replaceFrom }); assignMessageIDs(user, assistant, turn); creatingTurn = false; subscribe(turn, user, assistant); @@ -837,12 +1451,26 @@ } catch (error) { creatingTurn = false; discardTurn(user, assistant); + restoreTruncatedMessages(truncatedSnapshot); finishTurn('Message not sent'); } }); prompt.addEventListener('keydown', function (event) { - if (event.key !== 'Enter' || event.shiftKey || event.ctrlKey || event.metaKey || event.altKey || event.isComposing || event.keyCode === 229) { + if (event.isComposing || event.keyCode === 229) { + return; + } + if (event.key === 'ArrowUp' && !event.shiftKey && !event.ctrlKey && !event.metaKey && !event.altKey && caretAtStart()) { + event.preventDefault(); + navigateBackward(); + return; + } + if (event.key === 'ArrowDown' && !event.shiftKey && !event.ctrlKey && !event.metaKey && !event.altKey && caretAtEnd()) { + event.preventDefault(); + navigateForward(); + return; + } + if (event.key !== 'Enter' || event.shiftKey || event.ctrlKey || event.metaKey || event.altKey) { return; } event.preventDefault(); @@ -852,29 +1480,77 @@ }); prompt.addEventListener('input', function () { + if (currentEditIndex === null) { + endPromptValue = prompt.value; + } syncPromptHeight(); + updatePromptHistoryState(); updateComposerState(); }); - stopButton.addEventListener('click', async function () { - if (!currentTurn) { - return; - } - const turn = currentTurn; - try { - await requestAbort(turn); - if (currentTurn === turn) { - discardTurn(currentUser, currentAssistant); - finishTurn('Response stopped'); + if (actionButton) { + actionButton.addEventListener('click', async function () { + if (!currentTurn) { + form.requestSubmit(); + return; } - } catch (error) { - if (currentTurn === turn && currentAssistant) { - markTurnError(currentAssistant, 'The turn could not be stopped.'); - finishTurn('Stop failed'); + const turn = currentTurn; + const assistant = currentAssistant; + try { + await requestAbort(turn); + if (currentTurn === turn) { + markTurnStopped(assistant); + finishTurn(''); + } + } catch (error) { + if (currentTurn === turn && assistant) { + markTurnError(assistant, 'The turn could not be stopped.'); + finishTurn('Stop failed'); + } } - } - }); + }); + } + if (revertButton) { + revertButton.addEventListener('click', function () { + revertPromptChanges(); + }); + } + + if (previousButton) { + previousButton.addEventListener('click', navigateBackward); + } + + if (nextButton) { + nextButton.addEventListener('click', navigateForward); + } + + if (ffwdButton) { + ffwdButton.addEventListener('click', function () { + requestNavigation(null, 'end'); + }); + } + + if (composerEndTarget) { + composerEndTarget.addEventListener('click', function (event) { + if (event.target instanceof Element && isEditablePromptInteractiveTarget(event.target)) { + return; + } + requestNavigation(null, 'end'); + }); + composerEndTarget.addEventListener('keydown', function (event) { + if (event.target instanceof Element && isEditablePromptInteractiveTarget(event.target)) { + return; + } + if (event.key !== 'Enter' && event.key !== ' ') { + return; + } + event.preventDefault(); + requestNavigation(null, 'end'); + }); + } + + messages.addEventListener('mousedown', handleEditablePromptMouseDown); messages.addEventListener('scroll', updateScrollButton, { passive: true }); if (scrollButton) { @@ -907,6 +1583,10 @@ return; } + if (handleEditablePromptClick(event)) { + return; + } + const copyMessage = event.target.closest('[data-copy-message]'); if (copyMessage) { const article = copyMessage.closest('.message'); @@ -963,6 +1643,8 @@ applyThemePreference(); enhanceAllMessages(); syncPromptHeight(); + updatePromptHistoryState(); updateComposerState(); + focusEndPromptOnLoad(); updateScrollButton(); })(); diff --git a/internal/web/server.go b/internal/web/server.go index 86961f8..7eef237 100644 --- a/internal/web/server.go +++ b/internal/web/server.go @@ -136,11 +136,14 @@ func (s *Server) handleIndex(w http.ResponseWriter, r *http.Request) { return } + messages := session.chat.Messages() data := pageData{ - CSRFToken: session.csrf, - ModelLabel: modelDisplayLabel(s.model), - Messages: viewMessages(session.chat.Messages(), s.markdown), - Username: session.username, + CSRFToken: session.csrf, + ModelLabel: modelDisplayLabel(s.model), + Messages: viewMessages(messages, s.markdown), + EndPrompt: endPromptViewMessage(), + NextMessageIndex: len(messages), + Username: session.username, } w.Header().Set("Content-Type", "text/html; charset=utf-8") if err := s.template.ExecuteTemplate(w, "index.html", data); err != nil { @@ -182,6 +185,10 @@ func (s *Server) handleCreateTurn(w http.ResponseWriter, r *http.Request) { writeJSONError(w, http.StatusBadRequest, "empty_prompt", "prompt must not be empty") return } + if validateErr := session.chat.ValidateReplaceFrom(request.ReplaceFrom); validateErr != nil { + writeJSONError(w, http.StatusBadRequest, "invalid_replace_from", "replace_from must point to a user message or the end of the conversation") + return + } turn, err := newTurnJobWithContext(context.WithoutCancel(r.Context()), prompt) if err != nil { @@ -205,6 +212,8 @@ func (s *Server) handleCreateTurn(w http.ResponseWriter, r *http.Request) { ReasoningEffort: s.reasoningEffort, RenderingInstructions: chat.WebRenderingInstructions(), TelemetryComponent: observability.ComponentWeb, + ReplaceFrom: request.ReplaceFrom, + Now: timeNow, }) writeJSON(w, http.StatusCreated, createTurnResponse{ @@ -687,7 +696,8 @@ func validCSRF(r *http.Request, token string) bool { } type createTurnRequest struct { - Prompt string `json:"prompt"` + Prompt string `json:"prompt"` + ReplaceFrom *int `json:"replace_from,omitempty"` } type createTurnResponse struct { @@ -698,10 +708,12 @@ type createTurnResponse struct { } type pageData struct { - CSRFToken string - ModelLabel string - Username string - Messages []viewMessage + CSRFToken string + ModelLabel string + Username string + Messages []viewMessage + EndPrompt viewMessage + NextMessageIndex int } type authPageData struct { @@ -714,11 +726,15 @@ type authPageData struct { } type viewMessage struct { - Role string - Label string - Text string - HTML template.HTML - Statuses []viewStatus + Index int + Role string + Label string + ID string + Text string + HTML template.HTML + Statuses []viewStatus + CompletedAt string + EndPrompt bool } type viewStatus struct { @@ -733,6 +749,7 @@ func viewMessages(messages []llm.Message, renderer assistantRenderer) []viewMess for messageIndex, message := range messages { text := message.Text() view := viewMessage{ + Index: messageIndex, Role: string(message.Role), Label: messageRoleLabel(message.Role), Text: text, @@ -746,6 +763,9 @@ func viewMessages(messages []llm.Message, renderer assistantRenderer) []viewMess } view.HTML = html view.Statuses = statuses + if !message.CompletedAt.IsZero() { + view.CompletedAt = message.CompletedAt.UTC().Format(time.RFC3339) + } } if view.Text == "" && view.HTML == "" && len(view.Statuses) == 0 { continue @@ -755,6 +775,15 @@ func viewMessages(messages []llm.Message, renderer assistantRenderer) []viewMess return out } +func endPromptViewMessage() viewMessage { + return viewMessage{ + Role: string(llm.RoleUser), + ID: "composer-end-target", + Text: "Latest prompt", + EndPrompt: true, + } +} + func assistantStatuses(parts []llm.Part, messageIndex int) []viewStatus { var statuses []viewStatus for partIndex, part := range parts { @@ -1076,6 +1105,10 @@ func (j *turnJob) run(session *chat.Session, opts chat.SendOptions) { stream, err := session.Send(j.ctx, j.prompt, opts) if err != nil { + if j.shouldAbort(err) { + j.emitAbortedAfterCommit(session, nil, opts) + return + } j.emitError(err) return } @@ -1089,11 +1122,19 @@ func (j *turnJob) run(session *chat.Session, opts chat.SendOptions) { event, err := stream.Next() if errors.Is(err, io.EOF) { if !completed { + if j.wasAbortRequested() { + j.emitAbortedAfterCommit(session, stream, opts) + return + } j.emitError(io.ErrUnexpectedEOF) } return } if err != nil { + if j.shouldAbort(err) { + j.emitAbortedAfterCommit(session, stream, opts) + return + } j.emitError(err) return } @@ -1130,6 +1171,10 @@ func (j *turnJob) run(session *chat.Session, opts chat.SendOptions) { log.Printf("markdown final render failed for turn %s: %v", j.id, err) html = escapedPlainTextHTML(fullMarkdown.String()) } + completedAt := stream.CompletedAt() + if completedAt.IsZero() { + completedAt = timeNow().UTC() + } j.emitTerminal("done", doneEvent{ TurnID: j.id, AssistantMessageID: j.assistantMessageID, @@ -1137,13 +1182,17 @@ func (j *turnJob) run(session *chat.Session, opts chat.SendOptions) { Usage: event.Usage, HTML: html, Statuses: assistantStatuses(assistantParts, -1), - CompletedAt: timeNow().UTC().Format(time.RFC3339), + CompletedAt: completedAt.UTC().Format(time.RFC3339), }) return } } } +func (j *turnJob) shouldAbort(err error) bool { + return j.wasAbortRequested() || errors.Is(err, context.Canceled) +} + func appendOutputDelta(parts *[]llm.Part, partType llm.PartType, delta string) { if delta == "" { return @@ -1200,13 +1249,35 @@ func escapedPlainTextHTML(text string) template.HTML { } func (j *turnJob) emitError(err error) { - if j.wasAbortRequested() || errors.Is(err, context.Canceled) { - j.emitTerminal("aborted", abortedEvent{ - TurnID: j.id, - AssistantMessageID: j.assistantMessageID, - }) + if j.shouldAbort(err) { + j.emitAborted() return } + j.emitStreamError() +} + +func (j *turnJob) emitAbortedAfterCommit(session *chat.Session, stream *chat.TurnStream, opts chat.SendOptions) { + var err error + if stream != nil { + err = stream.CommitPartial() + } else { + err = session.CommitStopped(context.Background(), j.prompt, opts) + } + if err != nil { + j.emitStreamError() + return + } + j.emitAborted() +} + +func (j *turnJob) emitAborted() { + j.emitTerminal("aborted", abortedEvent{ + TurnID: j.id, + AssistantMessageID: j.assistantMessageID, + }) +} + +func (j *turnJob) emitStreamError() { j.emitTerminal("stream-error", errorEvent{ TurnID: j.id, AssistantMessageID: j.assistantMessageID, diff --git a/internal/web/server_test.go b/internal/web/server_test.go index 88629c5..c5683df 100644 --- a/internal/web/server_test.go +++ b/internal/web/server_test.go @@ -84,12 +84,40 @@ func TestRootRendersChatPageAndSetsSessionCookie(t *testing.T) { `gpt-example`, `class="chat-panel"`, `id="scroll-bottom"`, + `title="Scroll to latest message"`, `id="composer-status"`, + `
`, + `id="revert-button"`, + `title="Revert prompt changes"`, + `id="previous-button"`, + `title="Previous prompt"`, + `id="next-button"`, + `title="Next prompt"`, + `id="ffwd-button"`, + `title="Latest prompt"`, + `id="composer-action"`, + `title="Send message"`, + `data-action-state="send"`, + `data-action-icon="play"`, + `data-action-icon="stop"`, + `id="composer-end-target"`, + `class="message message-user message-end-target"`, + `data-composer-end-target`, + `data-end-prompt="true"`, + `data-editable-prompt="true"`, + `data-editing="true"`, + `data-active-prompt="true"`, + ``, + ` -