Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
.DS_Store
.idea/
.otel/
.vscode/
*.swp
*.swo
Expand Down
4 changes: 4 additions & 0 deletions internal/auth/auth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -380,6 +380,10 @@ func (s authTestStore) AppendTurn(context.Context, int64, llm.Message, llm.Messa
return nil
}

func (s authTestStore) ReplaceTailAndAppendTurn(context.Context, int64, int, llm.Message, llm.Message) error {
return nil
}

func (s authTestStore) Close() error {
return nil
}
Expand Down
157 changes: 132 additions & 25 deletions internal/chat/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,9 @@ import (
)

var (
ErrEmptyPrompt = errors.New("prompt must not be empty")
ErrTurnInProgress = errors.New("turn already in progress")
ErrEmptyPrompt = errors.New("prompt must not be empty")
ErrInvalidReplaceFrom = errors.New("replace_from must point to a user message or the end of the conversation")
ErrTurnInProgress = errors.New("turn already in progress")
)

type Service struct {
Expand All @@ -32,6 +33,7 @@ func NewPersistentService(client llm.Client, store Store) Service {
type Store interface {
Messages(context.Context, int64) ([]llm.Message, error)
AppendTurn(context.Context, int64, llm.Message, llm.Message) error
ReplaceTailAndAppendTurn(context.Context, int64, int, llm.Message, llm.Message) error
}

func (s Service) NewSession() *Session {
Expand Down Expand Up @@ -68,6 +70,8 @@ type SendOptions struct {
ReasoningEffort string
RenderingInstructions string
TelemetryComponent string
ReplaceFrom *int
Now func() time.Time
}

func (s *Session) Send(ctx context.Context, prompt string, opts SendOptions) (*TurnStream, error) {
Expand All @@ -91,14 +95,25 @@ func (s *Session) Send(ctx context.Context, prompt string, opts SendOptions) (*T
Effort: effort,
}
}
now := opts.Now
if now == nil {
now = time.Now
}

s.mu.Lock()
if s.inFlight {
s.mu.Unlock()
observability.RecordSpanError(span, ErrTurnInProgress)
return nil, ErrTurnInProgress
}
request.Messages = append(llm.CloneMessages(s.messages), userMessage.Clone())
keepMessages, err := replaceFromIndex(s.messages, opts.ReplaceFrom)
if err != nil {
s.mu.Unlock()
observability.RecordSpanError(span, err)
return nil, err
}
replaceTail := opts.ReplaceFrom != nil
request.Messages = append(llm.CloneMessages(s.messages[:keepMessages]), userMessage.Clone())
s.inFlight = true
s.mu.Unlock()

Expand All @@ -112,11 +127,14 @@ func (s *Session) Send(ctx context.Context, prompt string, opts SendOptions) (*T
}

return &TurnStream{
session: s,
stream: stream,
userMessage: userMessage,
ctx: ctx,
startedAt: startedAt,
session: s,
stream: stream,
userMessage: userMessage,
keepMessages: keepMessages,
replaceTail: replaceTail,
ctx: ctx,
startedAt: startedAt,
now: now,
}, nil
}

Expand All @@ -126,14 +144,46 @@ func (s *Session) Messages() []llm.Message {
return llm.CloneMessages(s.messages)
}

func (s *Session) ValidateReplaceFrom(replaceFrom *int) error {
s.mu.Lock()
defer s.mu.Unlock()
_, err := replaceFromIndex(s.messages, replaceFrom)
return err
}

func (s *Session) CommitStopped(ctx context.Context, prompt string, opts SendOptions) error {
prompt = strings.TrimSpace(prompt)
if prompt == "" {
return ErrEmptyPrompt
}
userMessage := llm.NewTextMessage(llm.RoleUser, prompt)

s.mu.Lock()
keepMessages, err := replaceFromIndex(s.messages, opts.ReplaceFrom)
if err != nil {
s.mu.Unlock()
return err
}
if err := s.appendOrReplaceTailLocked(ctx, opts.ReplaceFrom != nil, keepMessages, userMessage, llm.Message{Role: llm.RoleAssistant}); err != nil {
s.mu.Unlock()
return err
}
s.mu.Unlock()
return nil
}

type TurnStream struct {
session *Session
stream llm.Stream
userMessage llm.Message
ctx context.Context
startedAt time.Time
session *Session
stream llm.Stream
userMessage llm.Message
keepMessages int
replaceTail bool
ctx context.Context
startedAt time.Time
now func() time.Time

assistantParts []llm.Part
completedAt time.Time
completed bool
finalized bool
recorded bool
Expand All @@ -143,7 +193,9 @@ func (s *TurnStream) Next() (llm.Event, error) {
event, err := s.stream.Next()
if err != nil {
s.recordError(err)
s.abort()
if !errors.Is(err, context.Canceled) {
s.abort()
}
return llm.Event{}, err
}

Expand All @@ -156,7 +208,7 @@ func (s *TurnStream) Next() (llm.Event, error) {
s.mergeCompletedPart(event.Part)
case llm.EventCompleted:
s.completed = true
if err := s.finalize(); err != nil {
if err := s.finalize(false); err != nil {
s.recordError(err)
s.abort()
return event, err
Expand All @@ -181,6 +233,14 @@ func (s *TurnStream) Close() error {
return s.stream.Close()
}

func (s *TurnStream) CommitPartial() error {
return s.finalize(true)
}

func (s *TurnStream) CompletedAt() time.Time {
return s.completedAt
}

func (s *TurnStream) appendDelta(partType llm.PartType, delta string) {
if delta == "" {
return
Expand Down Expand Up @@ -230,29 +290,40 @@ func (s *TurnStream) mergeCompletedPart(part llm.Part) {
s.assistantParts = append(s.assistantParts, part.Clone())
}

func (s *TurnStream) finalize() error {
if !s.completed || s.finalized {
func (s *TurnStream) finalize(allowIncomplete bool) error {
if (!s.completed && !allowIncomplete) || s.finalized {
return nil
}

assistant := llm.Message{
Role: llm.RoleAssistant,
Parts: cloneParts(s.assistantParts),
Role: llm.RoleAssistant,
Parts: cloneParts(s.assistantParts),
CompletedAt: s.completionTime(),
}

s.session.mu.Lock()
defer s.session.mu.Unlock()
if s.session.store != nil {
if err := s.session.store.AppendTurn(context.Background(), s.session.conversationID, s.userMessage.Clone(), assistant.Clone()); err != nil {
return err
}
if err := s.session.appendOrReplaceTailLocked(context.Background(), s.replaceTail, s.keepMessages, s.userMessage, assistant); err != nil {
return err
}
s.finalized = true
s.session.messages = append(s.session.messages, s.userMessage.Clone(), assistant)
s.session.inFlight = false
return nil
}

func (s *TurnStream) completionTime() time.Time {
if !s.completed {
return time.Time{}
}
if s.completedAt.IsZero() {
now := s.now
if now == nil {
now = time.Now
}
s.completedAt = now().UTC()
}
return s.completedAt
}

func (s *TurnStream) abort() {
if s.finalized {
return
Expand Down Expand Up @@ -299,6 +370,42 @@ func (s *Session) releaseTurn() {
s.inFlight = false
}

func replaceFromIndex(messages []llm.Message, replaceFrom *int) (int, error) {
if replaceFrom == nil {
return len(messages), nil
}
keepMessages := *replaceFrom
if keepMessages < 0 || keepMessages > len(messages) {
return 0, ErrInvalidReplaceFrom
}
if keepMessages < len(messages) && messages[keepMessages].Role != llm.RoleUser {
return 0, ErrInvalidReplaceFrom
}
return keepMessages, nil
}

func (s *Session) appendOrReplaceTailLocked(ctx context.Context, replaceTail bool, keepMessages int, userMessage, assistant llm.Message) error {
if s.store != nil {
var err error
if replaceTail {
err = s.store.ReplaceTailAndAppendTurn(ctx, s.conversationID, keepMessages, userMessage.Clone(), assistant.Clone())
} else {
err = s.store.AppendTurn(ctx, s.conversationID, userMessage.Clone(), assistant.Clone())
}
if err != nil {
return err
}
}
nextMessages := llm.CloneMessages(s.messages)
if replaceTail {
nextMessages = llm.CloneMessages(s.messages[:keepMessages])
}
nextMessages = append(nextMessages, userMessage.Clone(), assistant.Clone())
s.messages = nextMessages
s.inFlight = false
return nil
}

func cloneParts(parts []llm.Part) []llm.Part {
if parts == nil {
return nil
Expand Down
Loading
Loading