From 5c092f07e2350ec8380a629899925e4b27eb7602 Mon Sep 17 00:00:00 2001
From: Israel Akintunde
Date: Wed, 3 Jun 2026 01:18:49 +0100
Subject: [PATCH] feat: redesign chat UI and add SSE token streaming

Chat UI:
- Remove page title header from chat view; replace empty state with conversational welcome centered in the message area
- Redesign message bubbles: user messages use a neutral elevated pill, assistant messages render as plain text with no border box
- Move model selector out of the compose bar into a chip below-right; compose area is now just a text field + send icon
- Add animated typing indicator (three-dot pulse) while waiting for the first token; replace with streaming text + blinking cursor as tokens arrive
- Optimistic UI: user bubble and cleared input appear instantly on send before the API round-trip completes
- Permanent timestamps below every bubble (smart-formatted: time / Yesterday / date)
- Shake animation + red border on the input when validation fails; no inline error text
- Hidden scrollbar; reliable scroll-to-bottom using direct scrollTop assignment with setTimeout deferral after state updates
- Center chat column (max 800px, mx auto) within the right content area; fix AppShell main box mx:auto so it centers on wide screens

Backend:
- Fix Ollama model ID parsing: colons in model tags (e.g.
llama3.2:1b) were mistaken for provider separators in normalizeProviderModels
- Fix OpenAI Responses API: assistant history messages must use content type output_text, not input_text
- Map empty-SQL LLM responses to ErrInvalidSQL (400) with the model's explanation instead of a generic 500 internal server error
- Add SSE streaming endpoint POST /conversations/:id/messages/stream that streams explanation tokens via ExplanationStreamer as they arrive from Ollama (stream:true NDJSON) and OpenAI (stream:true SSE deltas), then runs SQL validation and execution before sending the final events
- Add StreamingClient interface; both Ollama and OpenAI clients implement GenerateSQLStream; non-streaming clients fall back automatically
- Add StreamErrorMessage helper to map service errors to SSE error events
---
 apps/backend/servers/echo/handlers/http.go | 21 +
 .../resources/conversations/handler.go | 1 +
 .../conversations/send_message_stream.go | 79 ++
 apps/backend/services/chat/api/api.go | 5 +
 apps/backend/services/chat/api/client.go | 1 +
 apps/backend/services/chat/api/stream.go | 16 +
 .../internal/chat/llm/generate_sql_helpers.go | 7 +-
 .../internal/chat/llm/ollama/streaming.go | 133 +++
 .../chat/internal/chat/llm/openai/client.go | 11 +-
 .../internal/chat/llm/openai/streaming.go | 198 ++++
 .../chat/internal/chat/llm/registry.go | 23 +-
 .../chat/internal/chat/llm/streaming.go | 83 ++
 .../chat/internal/chat/send_message_stream.go | 222 ++++
 apps/backend/services/chat/pkg/chat/stream.go | 22 +
 apps/web/src/api/client.ts | 18 +
 apps/web/src/components/layout/AppShell.tsx | 2 +-
 apps/web/src/pages/chat/ChatPage.tsx | 1007 ++++++++++-------
 17 files changed, 1458 insertions(+), 391 deletions(-)
 create mode 100644 apps/backend/servers/echo/handlers/resources/conversations/send_message_stream.go
 create mode 100644 apps/backend/services/chat/api/stream.go
 create mode 100644 apps/backend/services/chat/internal/chat/llm/ollama/streaming.go
 create mode 100644
apps/backend/services/chat/internal/chat/llm/openai/streaming.go create mode 100644 apps/backend/services/chat/internal/chat/llm/streaming.go create mode 100644 apps/backend/services/chat/internal/chat/send_message_stream.go create mode 100644 apps/backend/services/chat/pkg/chat/stream.go diff --git a/apps/backend/servers/echo/handlers/http.go b/apps/backend/servers/echo/handlers/http.go index 3f6986a..8446433 100644 --- a/apps/backend/servers/echo/handlers/http.go +++ b/apps/backend/servers/echo/handlers/http.go @@ -117,6 +117,27 @@ func Error(c echo.Context, logger *slog.Logger, err error) error { return echo.NewHTTPError(status, ErrorResponse{Error: message}) } +// StreamErrorMessage maps an error to a user-visible string for SSE error events, +// using the same sentinel-based rules as Error() but without writing an HTTP response. +func StreamErrorMessage(err error) string { + if err == nil { + return "" + } + switch { + case errors.Is(err, chaterrors.ErrProviderNotAvailable), + errors.Is(err, chaterrors.ErrInvalidProviderConfig), + errors.Is(err, chaterrors.ErrModelNotAvailable), + errors.Is(err, chaterrors.ErrEmbeddedSnapshotNotReady), + errors.Is(err, chaterrors.ErrInvalidSQL), + errors.Is(err, chaterrors.ErrPromptTooLarge), + errors.Is(err, chaterrors.ErrUnsupportedDatabaseKind), + errors.Is(err, chaterrors.ErrMessageExecutionFailed): + return err.Error() + default: + return "internal server error" + } +} + func logInternalError(c echo.Context, logger *slog.Logger, err error) { if logger == nil { logger = slog.Default() diff --git a/apps/backend/servers/echo/handlers/resources/conversations/handler.go b/apps/backend/servers/echo/handlers/resources/conversations/handler.go index aa6c005..d125108 100644 --- a/apps/backend/servers/echo/handlers/resources/conversations/handler.go +++ b/apps/backend/servers/echo/handlers/resources/conversations/handler.go @@ -32,4 +32,5 @@ func (h *Handler) Register(group *echo.Group) { 
group.DELETE("/conversations/:"+conversationIDParam, h.DeleteConversation) group.GET("/conversations/:"+conversationIDParam+"/messages", h.ListMessages) group.POST("/conversations/:"+conversationIDParam+"/messages", h.SendMessage) + group.POST("/conversations/:"+conversationIDParam+"/messages/stream", h.SendMessageStream) } diff --git a/apps/backend/servers/echo/handlers/resources/conversations/send_message_stream.go b/apps/backend/servers/echo/handlers/resources/conversations/send_message_stream.go new file mode 100644 index 0000000..49c6aa9 --- /dev/null +++ b/apps/backend/servers/echo/handlers/resources/conversations/send_message_stream.go @@ -0,0 +1,79 @@ +package conversations + +import ( + "encoding/json" + "fmt" + "net/http" + "strings" + + echohandlers "github.com/Uncensored-Developer/datalk/apps/backend/servers/echo/handlers" + chatapi "github.com/Uncensored-Developer/datalk/apps/backend/services/chat/api" + chattype "github.com/Uncensored-Developer/datalk/apps/backend/services/chat/pkg/chat" + "github.com/labstack/echo/v4" +) + +// SendMessageStream handles POST /conversations/:id/messages/stream. +// It runs the full send-message pipeline and writes SSE events to the response +// as each stage completes, including token-by-token explanation streaming. +func (h *Handler) SendMessageStream(c echo.Context) error { + userID, err := echohandlers.UserID(c) + if err != nil { + return err + } + + conversationID, err := echohandlers.Int64Param(c, conversationIDParam) + if err != nil { + return err + } + + var req sendMessageRequest + if err := c.Bind(&req); err != nil { + return echo.NewHTTPError(http.StatusBadRequest, echohandlers.ErrorResponse{Error: "invalid request body"}) + } + if strings.TrimSpace(req.Content) == "" { + return echo.NewHTTPError(http.StatusBadRequest, echohandlers.ErrorResponse{Error: "message content is required"}) + } + + // Set SSE response headers before writing anything. 
+ c.Response().Header().Set("Content-Type", "text/event-stream") + c.Response().Header().Set("Cache-Control", "no-cache") + c.Response().Header().Set("Connection", "keep-alive") + c.Response().Header().Set("X-Accel-Buffering", "no") // disable nginx buffering + c.Response().WriteHeader(http.StatusOK) + + flusher, ok := c.Response().Writer.(http.Flusher) + if !ok { + // Streaming not supported by the underlying writer — write a single error event. + _, _ = fmt.Fprintf(c.Response(), "data: {\"type\":\"error\",\"error\":\"streaming not supported\"}\n\n") + return nil + } + + writeEvent := func(event chatapi.StreamEvent) { + data, err := json.Marshal(event) + if err != nil { + return + } + _, _ = fmt.Fprintf(c.Response(), "data: %s\n\n", data) + flusher.Flush() + } + + serviceErr := h.service.SendMessageStream( + c.Request().Context(), + chattype.SendMessageParams{ + UserID: userID, + ConversationID: conversationID, + Content: req.Content, + Provider: req.Provider, + Model: req.Model, + }, + writeEvent, + ) + + if serviceErr != nil { + // Map to a user-visible error message (same logic as the non-streaming handler). 
+ msg := echohandlers.StreamErrorMessage(serviceErr) + writeEvent(chatapi.StreamEvent{Type: chatapi.StreamEventError, Error: msg}) + } + + return nil +} diff --git a/apps/backend/services/chat/api/api.go b/apps/backend/services/chat/api/api.go index b2a013b..ea5b5bf 100644 --- a/apps/backend/services/chat/api/api.go +++ b/apps/backend/services/chat/api/api.go @@ -19,6 +19,7 @@ type Service interface { DeleteConversation(ctx context.Context, userID int32, conversationID int64) error ListMessages(ctx context.Context, userID int32, filter chattype.ListMessagesFilter) ([]*chattype.MessageDetails, error) SendMessage(ctx context.Context, params chattype.SendMessageParams) (*chattype.AssistantTurn, error) + SendMessageStream(ctx context.Context, params chattype.SendMessageParams, onEvent func(chattype.StreamEvent)) error ListProviderConfigs(ctx context.Context) ([]*llmtypes.ProviderConfig, error) SaveProviderConfig(ctx context.Context, params chat.SaveProviderConfigParams) (*llmtypes.ProviderConfig, error) } @@ -66,6 +67,10 @@ func (a *Api) SendMessage(ctx context.Context, params SendMessageParams) (*chatt return a.service.SendMessage(ctx, params) } +func (a *Api) SendMessageStream(ctx context.Context, params chattype.SendMessageParams, onEvent func(chattype.StreamEvent)) error { + return a.service.SendMessageStream(ctx, params, onEvent) +} + func (a *Api) ListProviderConfigs(ctx context.Context) ([]*llmtypes.ProviderConfig, error) { return a.service.ListProviderConfigs(ctx) } diff --git a/apps/backend/services/chat/api/client.go b/apps/backend/services/chat/api/client.go index 47b900c..d0335ed 100644 --- a/apps/backend/services/chat/api/client.go +++ b/apps/backend/services/chat/api/client.go @@ -17,6 +17,7 @@ type Client interface { DeleteConversation(ctx context.Context, userID int32, conversationID int64) error ListMessages(ctx context.Context, userID int32, filter ListMessagesFilter) ([]*chattype.MessageDetails, error) SendMessage(ctx context.Context, params 
SendMessageParams) (*chattype.AssistantTurn, error) + SendMessageStream(ctx context.Context, params chattype.SendMessageParams, onEvent func(chattype.StreamEvent)) error ListProviderConfigs(ctx context.Context) ([]*llmtypes.ProviderConfig, error) SaveProviderConfig(ctx context.Context, params SaveProviderConfigParams) (*llmtypes.ProviderConfig, error) ListAvailableModels(ctx context.Context) ([]llmtypes.Model, error) diff --git a/apps/backend/services/chat/api/stream.go b/apps/backend/services/chat/api/stream.go new file mode 100644 index 0000000..23ce5c6 --- /dev/null +++ b/apps/backend/services/chat/api/stream.go @@ -0,0 +1,16 @@ +package api + +import chattype "github.com/Uncensored-Developer/datalk/apps/backend/services/chat/pkg/chat" + +// Re-export stream event types from pkg/chat so callers can use chatapi.StreamEvent etc. +type StreamEventType = chattype.StreamEventType +type StreamEvent = chattype.StreamEvent + +const ( + StreamEventUserMessage = chattype.StreamEventUserMessage + StreamEventToken = chattype.StreamEventToken + StreamEventAssistantMessage = chattype.StreamEventAssistantMessage + StreamEventExecution = chattype.StreamEventExecution + StreamEventComplete = chattype.StreamEventComplete + StreamEventError = chattype.StreamEventError +) diff --git a/apps/backend/services/chat/internal/chat/llm/generate_sql_helpers.go b/apps/backend/services/chat/internal/chat/llm/generate_sql_helpers.go index a53b4dd..7d92edd 100644 --- a/apps/backend/services/chat/internal/chat/llm/generate_sql_helpers.go +++ b/apps/backend/services/chat/internal/chat/llm/generate_sql_helpers.go @@ -5,6 +5,7 @@ import ( "fmt" "strings" + chaterrors "github.com/Uncensored-Developer/datalk/apps/backend/services/chat/pkg/errors" llmtypes "github.com/Uncensored-Developer/datalk/apps/backend/services/chat/pkg/llm" schematypes "github.com/Uncensored-Developer/datalk/apps/backend/services/schemas/pkg/schemas" "github.com/mdobak/go-xerrors" @@ -114,7 +115,11 @@ func 
ParseGenerateSQLResponse(rawRequest, rawResponse []byte, payloadText string payload.SQL = strings.TrimSpace(payload.SQL) payload.Explanation = strings.TrimSpace(payload.Explanation) if payload.SQL == "" { - return nil, xerrors.New("structured SQL payload did not include sql") + explanation := strings.TrimSpace(payload.Explanation) + if explanation != "" { + return nil, xerrors.Newf("%s: %w", explanation, chaterrors.ErrInvalidSQL) + } + return nil, xerrors.Newf("the model could not generate SQL for this question: %w", chaterrors.ErrInvalidSQL) } return &llmtypes.GenerateSQLResponse{ diff --git a/apps/backend/services/chat/internal/chat/llm/ollama/streaming.go b/apps/backend/services/chat/internal/chat/llm/ollama/streaming.go new file mode 100644 index 0000000..fa9df0d --- /dev/null +++ b/apps/backend/services/chat/internal/chat/llm/ollama/streaming.go @@ -0,0 +1,133 @@ +package ollama + +import ( + "bufio" + "bytes" + "encoding/json" + "io" + "net/http" + "strings" + + llmcore "github.com/Uncensored-Developer/datalk/apps/backend/services/chat/internal/chat/llm" + llmtypes "github.com/Uncensored-Developer/datalk/apps/backend/services/chat/pkg/llm" + "github.com/mdobak/go-xerrors" + "golang.org/x/net/context" +) + +// streamingChatRequest mirrors chatRequest but always sets Stream: true. +type streamingChatRequest struct { + Model string `json:"model"` + Messages []message `json:"messages"` + Format map[string]any `json:"format"` + Stream bool `json:"stream"` +} + +// streamingChunk is one NDJSON line from the Ollama streaming response. +type streamingChunk struct { + Message message `json:"message"` + Done bool `json:"done"` + DoneReason string `json:"done_reason,omitempty"` + // Usage fields (only present when done=true) + PromptEvalCount *int `json:"prompt_eval_count,omitempty"` + EvalCount *int `json:"eval_count,omitempty"` +} + +// GenerateSQLStream calls the Ollama chat API with streaming enabled. 
+// As JSON tokens arrive they are fed into an explanationStreamer; each +// explanation character is forwarded to onToken. When the stream ends the +// accumulated JSON is parsed and the full GenerateSQLResponse is returned. +func (c *Client) GenerateSQLStream(ctx context.Context, req llmtypes.GenerateSQLRequest, onToken func(string)) (*llmtypes.GenerateSQLResponse, error) { + if strings.TrimSpace(req.Model) == "" { + return nil, xerrors.New("model is required") + } + + requestBody := streamingChatRequest{ + Model: req.Model, + Messages: ollamaMessages(req), + Format: llmcore.GenerateSQLSchema(), + Stream: true, + } + + rawRequest, err := json.Marshal(requestBody) + if err != nil { + return nil, xerrors.Newf("failed to marshal ollama streaming request: %w", err) + } + + httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, c.baseURL+"/api/chat", bytes.NewReader(rawRequest)) + if err != nil { + return nil, xerrors.Newf("failed to create ollama streaming request: %w", err) + } + httpReq.Header.Set("Content-Type", "application/json") + + resp, err := c.httpClient.Do(httpReq) + if err != nil { + return nil, xerrors.Newf("failed to call ollama streaming api: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode >= http.StatusBadRequest { + body, _ := io.ReadAll(resp.Body) + if err := decodeHTTPError("ollama streaming api", resp.StatusCode, body); err != nil { + return nil, err + } + return nil, xerrors.Newf("ollama streaming api failed with status %d", resp.StatusCode) + } + + streamer := llmcore.NewExplanationStreamer() + + var ( + fullText strings.Builder + promptEvalCount *int + evalCount *int + finishReason *string + ) + + scanner := bufio.NewScanner(resp.Body) + for scanner.Scan() { + line := scanner.Bytes() + if len(line) == 0 { + continue + } + + var chunk streamingChunk + if err := json.Unmarshal(line, &chunk); err != nil { + continue // skip malformed lines + } + + token := chunk.Message.Content + if token != "" { + 
fullText.WriteString(token) + streamer.Feed(token, onToken) + } + + if chunk.Done { + promptEvalCount = chunk.PromptEvalCount + evalCount = chunk.EvalCount + if strings.TrimSpace(chunk.DoneReason) != "" && chunk.DoneReason != "stop" { + reason := chunk.DoneReason + finishReason = &reason + } + break + } + } + + if err := scanner.Err(); err != nil { + return nil, xerrors.Newf("error reading ollama stream: %w", err) + } + + responseText := strings.TrimSpace(fullText.String()) + if responseText == "" { + return nil, xerrors.New("ollama streaming response was empty") + } + + usage := &llmtypes.Usage{ + InputTokens: promptEvalCount, + OutputTokens: evalCount, + } + if promptEvalCount != nil && evalCount != nil { + total := *promptEvalCount + *evalCount + usage.TotalTokens = &total + } + + return llmcore.ParseGenerateSQLResponse(rawRequest, []byte(responseText), responseText, usage, finishReason) +} diff --git a/apps/backend/services/chat/internal/chat/llm/openai/client.go b/apps/backend/services/chat/internal/chat/llm/openai/client.go index a0e3a92..1cbb914 100644 --- a/apps/backend/services/chat/internal/chat/llm/openai/client.go +++ b/apps/backend/services/chat/internal/chat/llm/openai/client.go @@ -338,11 +338,18 @@ func openAIInputMessages(req llmtypes.GenerateSQLRequest) []inputMessage { }) for _, message := range promptMessages { + role := normalizeOpenAIRole(message.Role) + // OpenAI Responses API requires "output_text" for assistant turns + // and "input_text" for user/developer turns. 
+ contentType := "input_text" + if role == "assistant" { + contentType = "output_text" + } messages = append(messages, inputMessage{ - Role: normalizeOpenAIRole(message.Role), + Role: role, Content: []contentPart{ { - Type: "input_text", + Type: contentType, Text: message.Content, }, }, diff --git a/apps/backend/services/chat/internal/chat/llm/openai/streaming.go b/apps/backend/services/chat/internal/chat/llm/openai/streaming.go new file mode 100644 index 0000000..523f2c7 --- /dev/null +++ b/apps/backend/services/chat/internal/chat/llm/openai/streaming.go @@ -0,0 +1,198 @@ +package openai + +import ( + "bufio" + "bytes" + "context" + "encoding/json" + "io" + "net/http" + "strings" + + llmcore "github.com/Uncensored-Developer/datalk/apps/backend/services/chat/internal/chat/llm" + llmtypes "github.com/Uncensored-Developer/datalk/apps/backend/services/chat/pkg/llm" + "github.com/mdobak/go-xerrors" +) + +// streamingGenerateRequest adds stream:true to the Responses API payload. +type streamingGenerateRequest struct { + Model string `json:"model"` + Input []inputMessage `json:"input"` + Text textConfig `json:"text"` + Stream bool `json:"stream"` +} + +// sseEvent is a parsed Server-Sent Events line pair. +type sseEvent struct { + Event string + Data string +} + +// outputTextDelta is the payload of response.output_text.delta SSE events. +type outputTextDelta struct { + Type string `json:"type"` + Delta string `json:"delta"` +} + +// outputTextDone is the payload of response.output_text.done SSE events. +type outputTextDone struct { + Type string `json:"type"` + Text string `json:"text"` +} + +// streamingUsage holds usage data from the response.completed event. 
+type streamingUsage struct { + InputTokens *int `json:"input_tokens"` + OutputTokens *int `json:"output_tokens"` + TotalTokens *int `json:"total_tokens"` +} + +type streamingCompletedResponse struct { + Response struct { + Status string `json:"status"` + Usage streamingUsage `json:"usage"` + OutputText string `json:"output_text"` + } `json:"response"` +} + +// GenerateSQLStream calls the OpenAI Responses API with stream:true. +// response.output_text.delta events feed the explanation streamer, and the +// complete JSON text is assembled from response.output_text.done. +func (c *Client) GenerateSQLStream(ctx context.Context, req llmtypes.GenerateSQLRequest, onToken func(string)) (*llmtypes.GenerateSQLResponse, error) { + if strings.TrimSpace(req.Model) == "" { + return nil, xerrors.New("model is required") + } + + requestBody := streamingGenerateRequest{ + Model: req.Model, + Input: openAIInputMessages(req), + Text: textConfig{Format: map[string]any{"type": "json_schema", "name": "sql_generation", "strict": true, "schema": llmcore.GenerateSQLSchema()}}, + Stream: true, + } + + rawRequest, err := json.Marshal(requestBody) + if err != nil { + return nil, xerrors.Newf("failed to marshal openai streaming request: %w", err) + } + + httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, c.baseURL+"/v1/responses", bytes.NewReader(rawRequest)) + if err != nil { + return nil, xerrors.Newf("failed to create openai streaming request: %w", err) + } + httpReq.Header.Set("Authorization", "Bearer "+c.apiKey) + httpReq.Header.Set("Content-Type", "application/json") + + resp, err := c.httpClient.Do(httpReq) + if err != nil { + return nil, xerrors.Newf("failed to call openai streaming api: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode >= http.StatusBadRequest { + body, _ := io.ReadAll(resp.Body) + if err := decodeHTTPError("openai streaming api", resp.StatusCode, body); err != nil { + return nil, err + } + return nil, xerrors.Newf("openai streaming api 
failed with status %d", resp.StatusCode) + } + + streamer := llmcore.NewExplanationStreamer() + + var ( + fullText string + usageData streamingUsage + finishReason *string + ) + + scanner := bufio.NewScanner(resp.Body) + var currentEvent string + + for scanner.Scan() { + line := scanner.Text() + + if strings.HasPrefix(line, "event: ") { + currentEvent = strings.TrimPrefix(line, "event: ") + continue + } + + if strings.HasPrefix(line, "data: ") { + data := strings.TrimPrefix(line, "data: ") + if data == "[DONE]" { + break + } + + switch currentEvent { + case "response.output_text.delta": + var delta outputTextDelta + if json.Unmarshal([]byte(data), &delta) == nil && delta.Delta != "" { + streamer.Feed(delta.Delta, onToken) + } + + case "response.output_text.done": + var done outputTextDone + if json.Unmarshal([]byte(data), &done) == nil { + fullText = done.Text + } + + case "response.completed": + var completed streamingCompletedResponse + if json.Unmarshal([]byte(data), &completed) == nil { + usageData = completed.Response.Usage + if completed.Response.Status != "" && completed.Response.Status != "completed" { + status := completed.Response.Status + finishReason = &status + } + if fullText == "" { + fullText = completed.Response.OutputText + } + } + } + + currentEvent = "" + } + } + + if err := scanner.Err(); err != nil { + return nil, xerrors.Newf("error reading openai stream: %w", err) + } + + responseText := strings.TrimSpace(fullText) + if responseText == "" { + return nil, xerrors.New("openai streaming response was empty") + } + + return llmcore.ParseGenerateSQLResponse( + rawRequest, + []byte(responseText), + responseText, + &llmtypes.Usage{ + InputTokens: usageData.InputTokens, + OutputTokens: usageData.OutputTokens, + TotalTokens: usageData.TotalTokens, + }, + finishReason, + ) +} + +// sseEvents is kept for reference but parsing is done inline above. +var _ = sseEvent{} // suppress unused warning + +// parseSSELine is a helper used in tests. 
+func parseSSELine(line string) (event, data string) { + if after, ok := strings.CutPrefix(line, "event: "); ok { + return after, "" + } + if after, ok := strings.CutPrefix(line, "data: "); ok { + return "", after + } + return "", "" +} + +// NewExplanationStreamerForTesting exposes the streamer for package tests. +func NewExplanationStreamerForTesting() *llmcore.ExplanationStreamer { + return llmcore.NewExplanationStreamer() +} + +// Ensure bufio and bytes are used. +var _ = bufio.NewScanner +var _ = bytes.NewReader diff --git a/apps/backend/services/chat/internal/chat/llm/registry.go b/apps/backend/services/chat/internal/chat/llm/registry.go index 078cde3..fc3c4f7 100644 --- a/apps/backend/services/chat/internal/chat/llm/registry.go +++ b/apps/backend/services/chat/internal/chat/llm/registry.go @@ -246,17 +246,22 @@ func normalizeProviderModels(provider llmtypes.Provider, models []llmtypes.Model qualifiedModelID := QualifiedModelID(provider, rawModelID) if strings.Contains(rawModelID, ":") { + // Only treat the colon as a provider:model separator if the part + // before it is actually a known provider. Ollama model IDs use ":" + // as a tag separator (e.g. "llama3.2:1b"), which must not be + // mistaken for a qualified provider prefix. parsedProvider, parsedModelID, err := ParseQualifiedModelID(rawModelID) - if err != nil { - return nil, err + if err == nil { + // Providers are allowed to return already-qualified ids, but the + // embedded provider prefix still has to match the provider being queried. + if parsedProvider != provider { + return nil, xerrors.Newf("provider %s returned model id for different provider %s", provider, parsedProvider) + } + qualifiedModelID = QualifiedModelID(parsedProvider, parsedModelID) + rawModelID = parsedModelID } - // Providers are allowed to return already-qualified ids, but the - // embedded provider prefix still has to match the provider being queried. 
- if parsedProvider != provider { - return nil, xerrors.Newf("provider %s returned model id for different provider %s", provider, parsedProvider) - } - qualifiedModelID = QualifiedModelID(parsedProvider, parsedModelID) - rawModelID = parsedModelID + // If parsing failed the colon is part of the model's own id (e.g. an + // Ollama tag). Fall through and use the already-computed qualifiedModelID. } // Deduplicate after normalization so "gpt-5.2" and "openai:gpt-5.2" diff --git a/apps/backend/services/chat/internal/chat/llm/streaming.go b/apps/backend/services/chat/internal/chat/llm/streaming.go new file mode 100644 index 0000000..ed11531 --- /dev/null +++ b/apps/backend/services/chat/internal/chat/llm/streaming.go @@ -0,0 +1,83 @@ +package llm + +import ( + "context" + "strings" + + llmtypes "github.com/Uncensored-Developer/datalk/apps/backend/services/chat/pkg/llm" +) + +// StreamingClient extends Client with token-by-token streaming support. +// Providers that implement this interface will have their explanation text +// streamed token by token via the onToken callback. +type StreamingClient interface { + Client + GenerateSQLStream(ctx context.Context, req llmtypes.GenerateSQLRequest, onToken func(string)) (*llmtypes.GenerateSQLResponse, error) +} + +// ExplanationStreamer extracts the "explanation" field value from a streaming +// JSON response and emits its characters via an emit callback as they arrive. +// +// The LLM outputs structured JSON (e.g. {"sql":"...","explanation":"..."}). As +// tokens stream in the full text accumulates. The streamer detects the start of +// the explanation string value and emits each subsequent character until the +// closing (unescaped) double-quote is found. 
+type ExplanationStreamer struct { + buf strings.Builder + state int // 0 = before explanation, 1 = in explanation, 2 = done + start int // byte position in buf where explanation value begins + sent int // number of explanation bytes already emitted +} + +// NewExplanationStreamer creates a new ExplanationStreamer ready for use. +func NewExplanationStreamer() *ExplanationStreamer { + return &ExplanationStreamer{} +} + +// Feed appends token to the internal buffer and emits any newly-available +// explanation characters via emit. +func (s *ExplanationStreamer) Feed(token string, emit func(string)) { + s.buf.WriteString(token) + + if s.state == 2 { + return + } + + full := s.buf.String() + + if s.state == 0 { + // Look for the opening of the explanation string value. Both + // "explanation":"value" and "explanation": "value" are handled. + for _, prefix := range []string{`"explanation":"`, `"explanation": "`} { + if idx := strings.Index(full, prefix); idx >= 0 { + s.start = idx + len(prefix) + s.state = 1 + break + } + } + if s.state == 0 { + return + } + } + + // state == 1: emit newly-buffered explanation characters. + for pos := s.start + s.sent; pos < len(full); pos++ { + ch := full[pos] + + if ch == '"' { + // Count preceding backslashes to determine if this quote is escaped. + bs := 0 + for i := pos - 1; i >= s.start && full[i] == '\\'; i-- { + bs++ + } + if bs%2 == 0 { + // Unescaped closing quote — explanation value is complete. 
+ s.state = 2 + return + } + } + + emit(string(ch)) + s.sent++ + } +} diff --git a/apps/backend/services/chat/internal/chat/send_message_stream.go b/apps/backend/services/chat/internal/chat/send_message_stream.go new file mode 100644 index 0000000..de696c1 --- /dev/null +++ b/apps/backend/services/chat/internal/chat/send_message_stream.go @@ -0,0 +1,222 @@ +package chat + +import ( + "context" + "log/slog" + "strings" + "time" + + chatllm "github.com/Uncensored-Developer/datalk/apps/backend/services/chat/internal/chat/llm" + "github.com/Uncensored-Developer/datalk/apps/backend/services/chat/internal/chat/sqlrunner" + chattype "github.com/Uncensored-Developer/datalk/apps/backend/services/chat/pkg/chat" + chaterrors "github.com/Uncensored-Developer/datalk/apps/backend/services/chat/pkg/errors" + llmtypes "github.com/Uncensored-Developer/datalk/apps/backend/services/chat/pkg/llm" + schematypes "github.com/Uncensored-Developer/datalk/apps/backend/services/schemas/pkg/schemas" + "github.com/mdobak/go-xerrors" +) + +// SendMessageStream runs the full send-message pipeline, emitting SSE-style +// events via onEvent as each stage completes. The explanation text from the LLM +// is streamed token-by-token when the underlying provider supports it. +func (s *Service) SendMessageStream(ctx context.Context, params chattype.SendMessageParams, onEvent func(chattype.StreamEvent)) error { + userContent := strings.TrimSpace(params.Content) + if userContent == "" { + return xerrors.New("message content is required") + } + + // ── 1. 
Validate conversation access ────────────────────────────────────── + conversation, connection, err := s.getOwnedConversation(ctx, params.UserID, params.ConversationID) + if err != nil { + s.logSendMessageFailure("stream: failed to validate conversation access", err, params, nil) + return err + } + + if !isSupportedChatDatabase(connection.Database) { + err := xerrors.Newf("%s: %w", connection.Database, chaterrors.ErrUnsupportedDatabaseKind) + s.logSendMessageFailure("stream: unsupported chat database", err, params, connection) + return err + } + + // ── 2. Resolve LLM client ───────────────────────────────────────────────── + resolved, err := s.clientResolver.ResolveClient(ctx, params.Provider, params.Model) + if err != nil { + s.logSendMessageFailure("stream: failed to resolve llm client", err, params, connection) + return err + } + if resolved == nil || resolved.ProviderConfig == nil || resolved.Client == nil { + err := xerrors.Newf("model resolver returned incomplete client: %w", chaterrors.ErrModelNotAvailable) + s.logSendMessageFailure("stream: llm resolver returned incomplete client", err, params, connection) + return err + } + + // ── 3. 
Load history & schema context ───────────────────────────────────── + history, err := s.loadRecentHistory(ctx, conversation.ID, 0) + if err != nil { + s.logSendMessageFailure("stream: failed to load recent chat history", err, params, connection) + return err + } + + lastAssistantSQL, err := s.latestAssistantSQL(ctx, history) + if err != nil { + s.logSendMessageFailure("stream: failed to load previous assistant sql", err, params, connection) + return err + } + + retrievalQuery := buildRetrievalQuery(userContent, toConversationMessages(history), lastAssistantSQL) + schemaContext, err := s.schemaRetriever.RetrieveRelevantSchemaContext(ctx, schematypes.RetrieveRelevantSchemaContextParams{ + ConnectionID: conversation.ConnectionID, + QueryText: retrievalQuery, + Limit: defaultSchemaChunkLimit, + }) + if err != nil { + s.logSendMessageFailure("stream: failed to retrieve schema context", err, params, connection) + return err + } + if schemaContext == nil { + err := xerrors.Newf("schema retrieval returned no context: %w", chaterrors.ErrEmbeddedSnapshotNotReady) + s.logSendMessageFailure("stream: schema retrieval returned no context", err, params, connection) + return err + } + + retrieval := &chattype.MessageRetrieval{ + SnapshotID: schemaContext.SnapshotID, + QueryText: retrievalQuery, + Chunks: schemaContext.Chunks, + RetrievedAt: schemaContext.RetrievedAt, + } + + generateReq := buildGenerateSQLRequest(conversation, connection.Database, userContent, history, schemaContext, resolved.ProviderModelID) + if err := enforceGenerateSQLRequestLimits(&generateReq); err != nil { + s.logSendMessageFailure("stream: generated sql request exceeded limits", err, params, connection, + slog.Int("schema_chunks", len(generateReq.Schema.Chunks)), + slog.Int("max_prompt_bytes", generateReq.Options.MaxPromptBytes), + ) + return err + } + + // ── 4. 
Save user message immediately so the client can display it ───────── + userMessage := &chattype.Message{ + ConversationID: conversation.ID, + Role: chattype.MessageRoleUser, + Content: userContent, + Status: chattype.MessageStatusCompleted, + } + if err := s.storage.InsertMessage(ctx, userMessage); err != nil { + s.logSendMessageFailure("stream: failed to persist user message", err, params, connection) + return xerrors.Newf("failed to persist user message: %w", err) + } + if err := s.storage.InsertRetrieval(ctx, &chattype.MessageRetrieval{ + MessageID: userMessage.ID, + SnapshotID: retrieval.SnapshotID, + QueryText: retrieval.QueryText, + Chunks: retrieval.Chunks, + RetrievedAt: retrieval.RetrievedAt, + }); err != nil { + // Non-fatal — log but continue. + s.Logger().Warn("stream: failed to persist retrieval", slog.Any("err", err)) + } + + onEvent(chattype.StreamEvent{Type: chattype.StreamEventUserMessage, Message: userMessage}) + + // ── 5. Generate SQL — stream explanation tokens when provider supports it ─ + llmStarted := time.Now() + var generateResp *llmtypes.GenerateSQLResponse + + if sc, ok := resolved.Client.(chatllm.StreamingClient); ok { + generateResp, err = sc.GenerateSQLStream(ctx, generateReq, func(token string) { + onEvent(chattype.StreamEvent{Type: chattype.StreamEventToken, Content: token}) + }) + } else { + generateResp, err = resolved.Client.GenerateSQL(ctx, generateReq) + } + + llmLatencyMS := elapsedMilliseconds(llmStarted) + if err != nil { + s.logSendMessageFailure("stream: failed to generate sql", err, params, connection, slog.Int("llm_latency_ms", int(llmLatencyMS))) + return xerrors.Newf("failed to generate sql: %w", err) + } + if generateResp == nil { + err := xerrors.New("provider returned empty sql response") + s.logSendMessageFailure("stream: provider returned empty sql response", err, params, connection, slog.Int("llm_latency_ms", int(llmLatencyMS))) + return err + } + + // ── 6. 
Validate SQL ─────────────────────────────────────────────────────── + if err := sqlrunner.NewValidator().Validate(connection.Database, generateResp.SQL); err != nil { + s.logSendMessageFailure("stream: failed to validate generated sql", err, params, connection, slog.Int("llm_latency_ms", int(llmLatencyMS))) + return xerrors.Newf("failed to validate sql: %w", err) + } + + // ── 7. Execute SQL ──────────────────────────────────────────────────────── + execStarted := time.Now() + result, err := s.sqlRunner.Run(ctx, *connection, generateResp.SQL, sqlrunner.RunOptions{ + Timeout: defaultQueryTimeout, + RowLimit: defaultResultRowLimit, + }) + executionLatencyMS := elapsedMilliseconds(execStarted) + if err != nil { + s.logSendMessageFailure("stream: failed to execute generated sql", err, params, connection, slog.Int("execution_latency_ms", int(executionLatencyMS))) + return err + } + + // ── 8. Persist assistant message + execution ────────────────────────────── + explanation := strings.TrimSpace(generateResp.Explanation) + if explanation == "" { + explanation = "Query executed successfully." 
+ } + + assistantMessage := &chattype.Message{ + ConversationID: conversation.ID, + Role: chattype.MessageRoleAssistant, + Content: explanation, + Provider: &params.Provider, + Model: &params.Model, + Status: chattype.MessageStatusCompleted, + } + + execution := &chattype.MessageExecution{ + ConnectionID: connection.ID, + DatabaseKind: connection.Database, + GeneratedSQL: generateResp.SQL, + NormalizedSQL: generateResp.SQL, + Result: *result, + ExecutionLatencyMS: executionLatencyMS, + ExecutedAt: time.Now().UTC(), + } + + if err := s.storage.InTransaction(ctx, func(txCtx context.Context) error { + if err := s.storage.InsertMessage(txCtx, assistantMessage); err != nil { + return xerrors.Newf("failed to persist assistant message: %w", err) + } + llmCall := buildLLMCall(assistantMessage.ID, resolved, generateReq, generateResp, llmLatencyMS) + if err := s.storage.InsertLLMCall(txCtx, llmCall); err != nil { + return xerrors.Newf("failed to persist llm call: %w", err) + } + execution.MessageID = assistantMessage.ID + if err := s.storage.InsertExecution(txCtx, execution); err != nil { + return xerrors.Newf("failed to persist execution: %w", err) + } + return nil + }); err != nil { + s.logSendMessageFailure("stream: failed to persist assistant turn", err, params, connection) + return err + } + + // ── 9. 
Emit final events ────────────────────────────────────────────────── + onEvent(chattype.StreamEvent{Type: chattype.StreamEventAssistantMessage, Message: assistantMessage}) + onEvent(chattype.StreamEvent{Type: chattype.StreamEventExecution, Execution: execution}) + onEvent(chattype.StreamEvent{Type: chattype.StreamEventComplete}) + + s.Logger().Info( + "stream chat message completed", + slog.Int64("conversation_id", conversation.ID), + slog.Int("user_id", int(params.UserID)), + slog.String("provider", string(params.Provider)), + slog.String("model", params.Model), + slog.Int("llm_latency_ms", int(llmLatencyMS)), + slog.Int("execution_latency_ms", int(executionLatencyMS)), + slog.Int("result_rows", int(result.RowCount)), + ) + + return nil +} diff --git a/apps/backend/services/chat/pkg/chat/stream.go b/apps/backend/services/chat/pkg/chat/stream.go new file mode 100644 index 0000000..32fba26 --- /dev/null +++ b/apps/backend/services/chat/pkg/chat/stream.go @@ -0,0 +1,22 @@ +package chat + +// StreamEventType identifies the kind of event sent over the SSE stream. +type StreamEventType string + +const ( + StreamEventUserMessage StreamEventType = "user_message" + StreamEventToken StreamEventType = "token" + StreamEventAssistantMessage StreamEventType = "assistant_message" + StreamEventExecution StreamEventType = "execution" + StreamEventComplete StreamEventType = "complete" + StreamEventError StreamEventType = "error" +) + +// StreamEvent is a single SSE payload, JSON-encoded as the SSE data field. 
+type StreamEvent struct { + Type StreamEventType `json:"type"` + Content string `json:"content,omitempty"` + Message *Message `json:"message,omitempty"` + Execution *MessageExecution `json:"execution,omitempty"` + Error string `json:"error,omitempty"` +} diff --git a/apps/web/src/api/client.ts b/apps/web/src/api/client.ts index dbf3658..3b8e8ef 100644 --- a/apps/web/src/api/client.ts +++ b/apps/web/src/api/client.ts @@ -51,6 +51,24 @@ export class ApiClient { return this.request(path, { ...options, method: "DELETE" }); } + /** + * Opens a streaming POST connection. Returns the raw Response so the caller + * can read the body as a ReadableStream (for SSE parsing). + */ + async stream(path: string, body?: unknown): Promise { + const response = await fetch(this.urlFor(path), { + method: "POST", + headers: this.headersFor({ body }), + body: serializeBody(body), + }); + + if (!response.ok) { + throw await parseApiError(response); + } + + return response; + } + async request(path: string, options: RequestOptions = {}): Promise { return this.fetchJson(path, options, true); } diff --git a/apps/web/src/components/layout/AppShell.tsx b/apps/web/src/components/layout/AppShell.tsx index fd97873..d18380c 100644 --- a/apps/web/src/components/layout/AppShell.tsx +++ b/apps/web/src/components/layout/AppShell.tsx @@ -242,7 +242,7 @@ export function AppShell({ title, children }: AppShellProps) { - + {children} diff --git a/apps/web/src/pages/chat/ChatPage.tsx b/apps/web/src/pages/chat/ChatPage.tsx index 9cb0163..66a5bfc 100644 --- a/apps/web/src/pages/chat/ChatPage.tsx +++ b/apps/web/src/pages/chat/ChatPage.tsx @@ -1,7 +1,6 @@ import CloseFullscreenOutlinedIcon from "@mui/icons-material/CloseFullscreenOutlined"; import CodeOutlinedIcon from "@mui/icons-material/CodeOutlined"; import InfoOutlinedIcon from "@mui/icons-material/InfoOutlined"; -import KeyboardArrowRightOutlinedIcon from "@mui/icons-material/KeyboardArrowRightOutlined"; import OpenInFullOutlinedIcon from 
"@mui/icons-material/OpenInFullOutlined"; import PsychologyAltOutlinedIcon from "@mui/icons-material/PsychologyAltOutlined"; import SendOutlinedIcon from "@mui/icons-material/SendOutlined"; @@ -14,13 +13,10 @@ import Collapse from "@mui/material/Collapse"; import Dialog from "@mui/material/Dialog"; import DialogContent from "@mui/material/DialogContent"; import DialogTitle from "@mui/material/DialogTitle"; -import Divider from "@mui/material/Divider"; -import FormControl from "@mui/material/FormControl"; import IconButton from "@mui/material/IconButton"; -import InputLabel from "@mui/material/InputLabel"; +import Menu from "@mui/material/Menu"; import MenuItem from "@mui/material/MenuItem"; import Paper from "@mui/material/Paper"; -import Select from "@mui/material/Select"; import Stack from "@mui/material/Stack"; import Table from "@mui/material/Table"; import TableBody from "@mui/material/TableBody"; @@ -31,8 +27,8 @@ import TableRow from "@mui/material/TableRow"; import TextField from "@mui/material/TextField"; import Tooltip from "@mui/material/Tooltip"; import Typography from "@mui/material/Typography"; -import { useMutation, useQuery, useQueryClient } from "@tanstack/react-query"; -import { useEffect, useMemo, useRef, useState } from "react"; +import { useQuery, useQueryClient } from "@tanstack/react-query"; +import { useCallback, useEffect, useMemo, useRef, useState } from "react"; import { Controller, useForm } from "react-hook-form"; import { useParams } from "react-router-dom"; import { useAuth } from "../../auth/AuthProvider"; @@ -153,48 +149,29 @@ function MessagePanel({ onRetryModels: () => void; }) { const queryClient = useQueryClient(); - const messagesEndRef = useRef(null); + const scrollContainerRef = useRef(null); const streamTimersRef = useRef([]); const [streamedNaturalResponses, setStreamedNaturalResponses] = useState>({}); const lastMessageID = messages.at(-1)?.message.id; + const [isAIResponding, setIsAIResponding] = useState(false); + const 
[optimisticContent, setOptimisticContent] = useState(null); + const [streamingText, setStreamingText] = useState(""); - useEffect(() => { - if (lastMessageID && typeof messagesEndRef.current?.scrollIntoView === "function") { - messagesEndRef.current.scrollIntoView({ behavior: "smooth", block: "end" }); - } - }, [lastMessageID]); - - useEffect(() => { - const timers = streamTimersRef.current; - return () => { - for (const timer of timers) { - window.clearInterval(timer); - } - }; + /** Always-instant, direct DOM scroll — most reliable across browsers. */ + const scrollToBottom = useCallback(() => { + const el = scrollContainerRef.current; + if (!el) return; + el.scrollTop = el.scrollHeight; }, []); - if (!conversation) { - return ( - - ); - } - - const streamNaturalResponse = (messageID: number, fullText: string) => { + const streamNaturalResponse = useCallback((messageID: number, fullText: string) => { const chunks = fullText.match(/\S+\s*/g) ?? [fullText]; let index = 0; setStreamedNaturalResponses((current) => ({ ...current, [messageID]: "" })); - const timer = window.setInterval(() => { index += 1; - const visibleText = chunks.slice(0, index).join(""); - setStreamedNaturalResponses((current) => ({ - ...current, - [messageID]: visibleText, - })); - + const visible = chunks.slice(0, index).join(""); + setStreamedNaturalResponses((current) => ({ ...current, [messageID]: visible })); if (index >= chunks.length) { window.clearInterval(timer); window.setTimeout(() => { @@ -207,34 +184,83 @@ function MessagePanel({ } }, 45); streamTimersRef.current.push(timer); - }; + }, []); - const handleSendSuccess = (response: SendMessageResponse) => { - const naturalResponse = response.assistant_message.natural_response?.trim(); - if (!naturalResponse) { - return false; + // When real messages arrive from the server: + // 1. Clear the optimistic / streaming state (triggers a re-render) + // 2. 
Defer the scroll with setTimeout(0) so it runs after React has + // committed the updated DOM with the new real messages. + useEffect(() => { + if (!lastMessageID) return; + setOptimisticContent(null); + setStreamingText(""); + const id = setTimeout(scrollToBottom, 0); + + // Kick off typewriter for any new assistant message that has a natural response + const lastMsg = messages.at(-1); + if (lastMsg?.message.role === "assistant" && lastMsg.message.natural_response) { + streamNaturalResponse(lastMsg.message.id, lastMsg.message.natural_response); } - queryClient.setQueryData( - ["chat-messages", conversation.id], - (current = []) => { - const nextItems: MessageListItem[] = [ - { message: response.user_message, retrieval: response.retrieval }, - { - message: response.assistant_message, - execution: response.execution, - }, - ]; - const nextIDs = new Set(nextItems.map((item) => item.message.id)); - return [ - ...current.filter((item) => !nextIDs.has(item.message.id)), - ...nextItems, - ]; + return () => clearTimeout(id); + }, [lastMessageID, scrollToBottom, streamNaturalResponse]); // eslint-disable-line react-hooks/exhaustive-deps + + // Scroll instantly the moment the user's bubble appears. + useEffect(() => { + if (optimisticContent) scrollToBottom(); + }, [optimisticContent, scrollToBottom]); + + // Follow the streaming cursor on every token — direct scrollTop set is + // cheap enough to run on every render without throttling. 
+ useEffect(() => { + if (streamingText) scrollToBottom(); + }, [streamingText, scrollToBottom]); + + // Cleanup timers on unmount + useEffect(() => { + const timers = streamTimersRef.current; + return () => { for (const t of timers) window.clearInterval(t); }; + }, []); + + const handleOptimisticMessage = useCallback((content: string) => { + setOptimisticContent(content); + setStreamingText(""); + }, []); + + const handleStreamToken = useCallback((token: string) => { + setStreamingText((prev) => prev + token); + }, []); + + const handleStreamComplete = useCallback(() => { + // streaming text stays until real messages reload via lastMessageID effect + }, []); + + // Merge real messages with the optimistic user message + const allMessages: MessageListItem[] = useMemo(() => { + if (!optimisticContent || !conversation) return messages; + return [ + ...messages, + { + message: { + id: -1, + conversation_id: conversation.id, + role: "user" as const, + content: optimisticContent, + status: "pending" as const, + created_at: new Date().toISOString(), + }, }, + ]; + }, [messages, optimisticContent, conversation]); + + if (!conversation) { + return ( + ); - streamNaturalResponse(response.assistant_message.id, naturalResponse); - return true; - }; + } return ( - - {conversation.title} - - Connection {conversation.connection_id} - - - + {/* Models error banner */} {modelsError ? ( } - sx={{ mb: 2 }} + sx={{ mb: 1.5 }} > {modelsError} ) : null} + {/* Messages scroll area */} {isLoading ? : null} + {messagesError ? ( ) : null} + + {/* Conversational welcome — no card, no border */} {!isLoading && !messagesError && messages.length === 0 ? ( - + + + {conversation.title} + + + Ask anything about your data. I'll write the SQL and show you the results. + + ) : null} - - {messages.map((item) => ( - - ))} - - + + {/* Message list (real + optimistic) */} + {allMessages.length > 0 ? 
( + + {allMessages.map((item) => ( + + ))} + + ) : null} + + {/* Streaming AI bubble — shows tokens as they arrive */} + {isAIResponding && streamingText ? ( + + + {streamingText} + {/* Blinking cursor */} + + + + ) : null} + + {/* Typing indicator dots — shown while waiting for first token */} + {isAIResponding && !streamingText ? ( + + + {[0, 1, 2].map((i) => ( + + ))} + + + ) : null} + - + {/* Compose bar — no hard border separator */} + ); } +function formatTimestamp(iso: string): string { + const date = new Date(iso); + const now = new Date(); + const isToday = + date.getFullYear() === now.getFullYear() && + date.getMonth() === now.getMonth() && + date.getDate() === now.getDate(); + + const time = date.toLocaleTimeString([], { hour: "numeric", minute: "2-digit" }); + if (isToday) return time; + + const yesterday = new Date(now); + yesterday.setDate(now.getDate() - 1); + const isYesterday = + date.getFullYear() === yesterday.getFullYear() && + date.getMonth() === yesterday.getMonth() && + date.getDate() === yesterday.getDate(); + + if (isYesterday) return `Yesterday ${time}`; + return `${date.toLocaleDateString([], { month: "short", day: "numeric" })} ${time}`; +} + function MessageItem({ item, streamedNaturalResponse, @@ -341,111 +485,66 @@ function MessageItem({ item: MessageListItem; streamedNaturalResponse?: string; }) { - const [detailsOpen, setDetailsOpen] = useState(false); const isAssistant = item.message.role === "assistant"; - const hasModelInfo = Boolean(item.message.provider || item.message.model); - const hasNaturalResponse = isAssistant && Boolean(item.message.natural_response); - const hasHiddenDetails = hasNaturalResponse && Boolean(item.message.content || item.execution); - const messageText = hasNaturalResponse - ? streamedNaturalResponse ?? item.message.natural_response + const timestamp = item.message.created_at ? 
formatTimestamp(item.message.created_at) : null; + + // Prefer natural_response (with optional typewriter effect) over raw content + const displayText = isAssistant && item.message.natural_response + ? (streamedNaturalResponse ?? item.message.natural_response) : item.message.content; return ( - - - theme.transitions.create("opacity", { - duration: theme.transitions.duration.shortest, - }), - }, - "&:hover .assistant-message-controls, &:focus-within .assistant-message-controls": { - opacity: 1, - }, - }} + {/* Bubble */} + + theme.palette.mode === "dark" ? "#374151" : "#dde4f0", + color: (theme) => + theme.palette.mode === "dark" ? "#f9fafb" : "#111827", + borderRadius: "12px 12px 3px 12px", + px: 2, + py: 1.25, + } + } > - - {isAssistant && (hasModelInfo || hasHiddenDetails) ? ( - - {hasModelInfo ? ( - - - - - - - - ) : null} - {hasHiddenDetails ? ( - - setDetailsOpen((open) => !open)} - size="small" - > - - theme.transitions.create("transform", { - duration: theme.transitions.duration.shortest, - }), - }} - /> - - - ) : null} - - ) : null} - {messageText} - {item.message.error_message ? ( - {item.message.error_message} - ) : null} - {hasNaturalResponse ? ( - - - - {item.message.content ? ( - - {item.message.content} - - ) : null} - {item.execution ? : null} - - - ) : item.execution ? ( - - ) : null} - - - + + {displayText} + + {item.message.error_message ? ( + {item.message.error_message} + ) : null} + {item.execution ? : null} + + + {/* Permanent timestamp */} + {timestamp ? ( + + {timestamp} + + ) : null} + ); } @@ -462,67 +561,72 @@ function ExecutionPanel({ execution }: { execution: MessageExecution }) { ].join(" | "); return ( - + - - - - - - - - - - setSqlOpen((open) => !open)} - size="small" - > - - - - - setFullscreenOpen(true)} - > - - - - - - - {execution.result.truncated ? ( - - ) : null} + {/* Toolbar */} + + + + + + + + setSqlOpen((open) => !open)} + size="small" + > + + + + + setFullscreenOpen(true)} + > + + + + + {execution.result.truncated ? 
( + + ) : null} + {execution.generated_sql} + {isScalarResult ? ( ) : ( )} + setFullscreenOpen(false)}> @@ -555,12 +659,18 @@ function ScalarResult({ execution }: { execution: MessageExecution }) { py: 1.75, bgcolor: "action.hover", borderStyle: "dashed", + borderRadius: 2, }} > {column.name} - + {formatCellValue(value)} @@ -577,18 +687,20 @@ function ResultTable({ execution }: { execution: MessageExecution }) { } return ( - - + +
{execution.result.columns.map((column) => ( - {column.name} + + {column.name} + ))} {execution.result.rows.map((row, index) => ( - + {execution.result.columns.map((column) => ( {formatCellValue(row[column.name])} @@ -602,52 +714,101 @@ function ResultTable({ execution }: { execution: MessageExecution }) { ); } +// ── SSE event types (mirrors backend pkg/chat/stream.go) ───────────────────── +type SSEEventType = + | "user_message" + | "token" + | "assistant_message" + | "execution" + | "complete" + | "error"; + +type SSEEvent = { + type: SSEEventType; + content?: string; + message?: MessageListItem["message"]; + execution?: MessageExecution; + error?: string; +}; + +async function* readSSE(response: Response): AsyncGenerator { + const reader = response.body!.getReader(); + const decoder = new TextDecoder(); + let buffer = ""; + + while (true) { + const { value, done } = await reader.read(); + if (done) break; + buffer += decoder.decode(value, { stream: true }); + + const lines = buffer.split("\n"); + buffer = lines.pop() ?? 
""; + + for (const line of lines) { + if (line.startsWith("data: ")) { + try { + yield JSON.parse(line.slice(6)) as SSEEvent; + } catch { + // skip malformed lines + } + } + } + } +} + function SendMessageForm({ conversationID, models, - onSendSuccess, + onPendingChange, + onOptimisticMessage, + onStreamToken, + onStreamComplete, }: { conversationID: number; models: ChatModel[]; - onSendSuccess: (response: SendMessageResponse) => boolean; + onPendingChange: (pending: boolean) => void; + onOptimisticMessage: (content: string) => void; + onStreamToken: (token: string) => void; + onStreamComplete: () => void; }) { const { apiClient } = useAuth(); const queryClient = useQueryClient(); + const [modelMenuAnchor, setModelMenuAnchor] = useState(null); + const [isPending, setIsPending] = useState(false); + const [sendError, setSendError] = useState(null); + const [shaking, setShaking] = useState(false); + const [requireNaturalResponse, setRequireNaturalResponse] = useState(() => { + if (typeof window === "undefined") return true; + const stored = window.localStorage.getItem(requireNaturalResponseKey); + return stored !== "false"; + }); + + const shake = useCallback(() => { + setShaking(true); + setTimeout(() => setShaking(false), 400); + }, []); + const defaultModel = useMemo(() => { const storedModel = typeof window === "undefined" ? null : window.localStorage.getItem(lastChatModelKey); - if (storedModel && models.some((model) => model.id === storedModel)) { return storedModel; } - return models[0]?.id ?? 
""; }, [models]); - const [requireNaturalResponse, setRequireNaturalResponse] = useState(() => { - if (typeof window === "undefined") { - return true; - } - const stored = window.localStorage.getItem(requireNaturalResponseKey); - if (stored === "false") { - return false; - } - if (stored === "true") { - return true; - } - return true; - }); + const { control, formState: { errors }, handleSubmit, register, reset, - setError, } = useForm({ values: { content: "", model: defaultModel }, }); + const contentField = register("content", { validate: (value) => value.trim() ? true : "Message is required", }); @@ -657,150 +818,240 @@ function SendMessageForm({ [models], ); - const mutation = useMutation({ - mutationFn: (values: SendForm) => { - const selectedModel = selectedModelByID.get(values.model); - return apiClient.post( - `/chat/conversations/${conversationID}/messages`, + useEffect(() => { + onPendingChange(isPending); + }, [isPending, onPendingChange]); + + const doStream = useCallback(async (values: SendForm) => { + const content = values.content.trim(); + if (!content || isPending) return; + + const selectedModel = selectedModelByID.get(values.model); + setSendError(null); + setIsPending(true); + onOptimisticMessage(content); + reset({ content: "", model: values.model }); + window.localStorage.setItem(lastChatModelKey, values.model); + + try { + const response = await apiClient.stream( + `/chat/conversations/${conversationID}/messages/stream`, { - content: values.content.trim(), + content, provider: selectedModel?.provider as Provider, model: values.model, require_natural_response: requireNaturalResponse, }, ); - }, - onSuccess(response, values) { - window.localStorage.setItem(lastChatModelKey, values.model); - window.localStorage.setItem(requireNaturalResponseKey, String(requireNaturalResponse)); - reset({ content: "", model: values.model }); - const responseHandled = onSendSuccess(response); - if (!responseHandled) { - void queryClient.invalidateQueries({ 
queryKey: ["chat-messages", conversationID] }); + + for await (const event of readSSE(response)) { + switch (event.type) { + case "token": + onStreamToken(event.content ?? ""); + break; + case "complete": + window.localStorage.setItem(requireNaturalResponseKey, String(requireNaturalResponse)); + void queryClient.invalidateQueries({ queryKey: ["chat-messages", conversationID] }); + void queryClient.invalidateQueries({ queryKey: ["chat-conversations"] }); + void queryClient.invalidateQueries({ queryKey: ["chat-conversation", conversationID] }); + onStreamComplete(); + break; + case "error": + setSendError(event.error ?? "Something went wrong"); + shake(); + onStreamComplete(); + break; + } } - void queryClient.invalidateQueries({ queryKey: ["chat-conversations"] }); - void queryClient.invalidateQueries({ queryKey: ["chat-conversation", conversationID] }); - }, - onError(error) { - setError("content", { message: errorMessage(error) }); - }, - }); + } catch (err) { + setSendError(errorMessage(err)); + shake(); + onStreamComplete(); + } finally { + setIsPending(false); + } + }, [isPending, selectedModelByID, apiClient, conversationID, queryClient, onOptimisticMessage, onStreamToken, onStreamComplete, reset]); return ( - mutation.mutate(values))} - variant="outlined" - sx={{ - p: 1, - borderRadius: 3, - bgcolor: "background.paper", - boxShadow: (theme) => theme.shadows[1], - }} - > - + {/* Input row: text field + send button only */} + { void doStream(values); }, + () => shake(), // validation failed + )} + elevation={2} + sx={{ + borderRadius: 1.5, + bgcolor: "background.paper", + border: "1px solid", + borderColor: (errors.content || sendError) ? "error.main" : "divider", + overflow: "hidden", + display: "flex", + alignItems: "flex-end", + gap: 0, + transition: "border-color 0.15s, box-shadow 0.15s", + "&:focus-within": { + borderColor: (errors.content || sendError) ? "error.main" : "primary.main", + boxShadow: (theme) => + `0 0 0 2px ${(errors.content || sendError) ? 
theme.palette.error.main : theme.palette.primary.main}22`, }, + ...(shaking + ? { + animation: "shake 0.35s cubic-bezier(.36,.07,.19,.97) both", + "@keyframes shake": { + "0%, 100%": { transform: "translateX(0)" }, + "15%": { transform: "translateX(-6px)" }, + "30%": { transform: "translateX(5px)" }, + "45%": { transform: "translateX(-4px)" }, + "60%": { transform: "translateX(3px)" }, + "75%": { transform: "translateX(-2px)" }, + "90%": { transform: "translateX(1px)" }, + }, + } + : {}), }} - {...contentField} - onKeyDown={(event) => { - if (event.key === "Enter" && !event.shiftKey && !mutation.isPending) { - event.preventDefault(); - void handleSubmit((values) => mutation.mutate(values))(); - } - }} - /> - - ( - - Model - - - )} + > + { + if (event.key === "Enter" && !event.shiftKey && !isPending) { + event.preventDefault(); + void handleSubmit((values) => { void doStream(values); }, () => shake())(); + } + }} /> + + + + + + {isPending ? ( + + ) : ( + + )} + + + + + + + {/* Below-input row: keyboard hint (left) + model selector (right) */} + + + Enter to send · Shift+Enter for new line + + + + + {/* Natural response toggle */} { - const nextValue = !requireNaturalResponse; - setRequireNaturalResponse(nextValue); - window.localStorage.setItem(requireNaturalResponseKey, String(nextValue)); + const next = !requireNaturalResponse; + setRequireNaturalResponse(next); + window.localStorage.setItem(requireNaturalResponseKey, String(next)); }} - size="small" > - - - - - {mutation.isPending ? 
( - - ) : ( - - )} - - - + + {/* Model selector — bottom right, outside the input */} + { + const selected = selectedModelByID.get(field.value); + return ( + <> + setModelMenuAnchor(e.currentTarget)} + disabled={models.length === 0} + sx={{ + borderRadius: 999, + fontSize: "0.72rem", + cursor: "pointer", + maxWidth: 200, + height: 24, + }} + /> + setModelMenuAnchor(null)} + anchorOrigin={{ vertical: "top", horizontal: "right" }} + transformOrigin={{ vertical: "bottom", horizontal: "right" }} + slotProps={{ + paper: { sx: { minWidth: 220, borderRadius: 2, mb: 0.5 } }, + }} + > + {models.map((model) => ( + { + field.onChange(model.id); + setModelMenuAnchor(null); + }} + sx={{ borderRadius: 1, mx: 0.5, my: 0.25 }} + > + + + {model.display_name} + + {model.description ? ( + + {model.description} + + ) : null} + + + ))} + + + ); + }} + /> - {errors.content?.message || errors.model?.message ? ( - - {errors.content?.message ?? errors.model?.message} - - ) : null} - + ); }