Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 21 additions & 0 deletions apps/backend/servers/echo/handlers/http.go
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,27 @@ func Error(c echo.Context, logger *slog.Logger, err error) error {
return echo.NewHTTPError(status, ErrorResponse{Error: message})
}

// StreamErrorMessage converts err into the text placed in an SSE error event.
// It never writes to an HTTP response.
//
// A nil error yields the empty string. If err matches (via errors.Is) one of
// the chat sentinel errors listed below, its own message is returned as-is;
// every other error is masked as "internal server error".
//
// NOTE(review): the allow-list is intended to follow the same rules as Error();
// confirm the two stay in sync when either one changes.
func StreamErrorMessage(err error) string {
	if err == nil {
		return ""
	}

	// Sentinels whose messages are shown to the user verbatim.
	userVisible := []error{
		chaterrors.ErrProviderNotAvailable,
		chaterrors.ErrInvalidProviderConfig,
		chaterrors.ErrModelNotAvailable,
		chaterrors.ErrEmbeddedSnapshotNotReady,
		chaterrors.ErrInvalidSQL,
		chaterrors.ErrPromptTooLarge,
		chaterrors.ErrUnsupportedDatabaseKind,
		chaterrors.ErrMessageExecutionFailed,
	}
	for _, sentinel := range userVisible {
		if errors.Is(err, sentinel) {
			return err.Error()
		}
	}

	return "internal server error"
}

func logInternalError(c echo.Context, logger *slog.Logger, err error) {
if logger == nil {
logger = slog.Default()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,4 +32,5 @@ func (h *Handler) Register(group *echo.Group) {
group.DELETE("/conversations/:"+conversationIDParam, h.DeleteConversation)
group.GET("/conversations/:"+conversationIDParam+"/messages", h.ListMessages)
group.POST("/conversations/:"+conversationIDParam+"/messages", h.SendMessage)
group.POST("/conversations/:"+conversationIDParam+"/messages/stream", h.SendMessageStream)
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
package conversations

import (
"encoding/json"
"fmt"
"net/http"
"strings"

echohandlers "github.com/Uncensored-Developer/datalk/apps/backend/servers/echo/handlers"
chatapi "github.com/Uncensored-Developer/datalk/apps/backend/services/chat/api"
chattype "github.com/Uncensored-Developer/datalk/apps/backend/services/chat/pkg/chat"
"github.com/labstack/echo/v4"
)

// SendMessageStream serves POST /conversations/:id/messages/stream.
//
// The request is validated first: the user ID, the conversation ID path
// parameter, a bindable body and non-blank message content. Failures at this
// stage are returned as ordinary errors before any response bytes are written.
//
// After validation the response is committed as a 200 text/event-stream. Each
// event the chat service hands to the callback is JSON-encoded, written as a
// single "data: <json>\n\n" frame and flushed immediately. A service error is
// reported in-band as one final error event, so the handler returns nil once
// the stream has been opened.
func (h *Handler) SendMessageStream(c echo.Context) error {
	userID, err := echohandlers.UserID(c)
	if err != nil {
		return err
	}
	conversationID, err := echohandlers.Int64Param(c, conversationIDParam)
	if err != nil {
		return err
	}

	var body sendMessageRequest
	if bindErr := c.Bind(&body); bindErr != nil {
		return echo.NewHTTPError(http.StatusBadRequest, echohandlers.ErrorResponse{Error: "invalid request body"})
	}
	if len(strings.TrimSpace(body.Content)) == 0 {
		return echo.NewHTTPError(http.StatusBadRequest, echohandlers.ErrorResponse{Error: "message content is required"})
	}

	// Headers must be in place before the status line is written.
	res := c.Response()
	hdr := res.Header()
	for _, kv := range [][2]string{
		{"Content-Type", "text/event-stream"},
		{"Cache-Control", "no-cache"},
		{"Connection", "keep-alive"},
		{"X-Accel-Buffering", "no"}, // disable nginx buffering
	} {
		hdr.Set(kv[0], kv[1])
	}
	res.WriteHeader(http.StatusOK)

	flusher, canFlush := res.Writer.(http.Flusher)
	if !canFlush {
		// The underlying writer cannot stream; emit one error frame and stop.
		_, _ = fmt.Fprint(res, `data: {"type":"error","error":"streaming not supported"}`+"\n\n")
		return nil
	}

	// emit writes one SSE frame. Events that cannot be encoded are dropped, and
	// write errors are ignored because the status has already been sent.
	emit := func(ev chatapi.StreamEvent) {
		payload, marshalErr := json.Marshal(ev)
		if marshalErr != nil {
			return
		}
		_, _ = fmt.Fprintf(res, "data: %s\n\n", payload)
		flusher.Flush()
	}

	params := chattype.SendMessageParams{
		UserID:         userID,
		ConversationID: conversationID,
		Content:        body.Content,
		Provider:       body.Provider,
		Model:          body.Model,
	}
	if streamErr := h.service.SendMessageStream(c.Request().Context(), params, emit); streamErr != nil {
		// Translate the failure into a user-visible message and send it in-band.
		emit(chatapi.StreamEvent{
			Type:  chatapi.StreamEventError,
			Error: echohandlers.StreamErrorMessage(streamErr),
		})
	}

	return nil
}
5 changes: 5 additions & 0 deletions apps/backend/services/chat/api/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ type Service interface {
DeleteConversation(ctx context.Context, userID int32, conversationID int64) error
ListMessages(ctx context.Context, userID int32, filter chattype.ListMessagesFilter) ([]*chattype.MessageDetails, error)
SendMessage(ctx context.Context, params chattype.SendMessageParams) (*chattype.AssistantTurn, error)
SendMessageStream(ctx context.Context, params chattype.SendMessageParams, onEvent func(chattype.StreamEvent)) error
ListProviderConfigs(ctx context.Context) ([]*llmtypes.ProviderConfig, error)
SaveProviderConfig(ctx context.Context, params chat.SaveProviderConfigParams) (*llmtypes.ProviderConfig, error)
}
Expand Down Expand Up @@ -66,6 +67,10 @@ func (a *Api) SendMessage(ctx context.Context, params SendMessageParams) (*chatt
return a.service.SendMessage(ctx, params)
}

// SendMessageStream passes ctx, params and onEvent unchanged to the underlying
// service's SendMessageStream and returns that call's error as-is.
func (a *Api) SendMessageStream(ctx context.Context, params chattype.SendMessageParams, onEvent func(chattype.StreamEvent)) error {
return a.service.SendMessageStream(ctx, params, onEvent)
}

func (a *Api) ListProviderConfigs(ctx context.Context) ([]*llmtypes.ProviderConfig, error) {
return a.service.ListProviderConfigs(ctx)
}
Expand Down
1 change: 1 addition & 0 deletions apps/backend/services/chat/api/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ type Client interface {
DeleteConversation(ctx context.Context, userID int32, conversationID int64) error
ListMessages(ctx context.Context, userID int32, filter ListMessagesFilter) ([]*chattype.MessageDetails, error)
SendMessage(ctx context.Context, params SendMessageParams) (*chattype.AssistantTurn, error)
SendMessageStream(ctx context.Context, params chattype.SendMessageParams, onEvent func(chattype.StreamEvent)) error
ListProviderConfigs(ctx context.Context) ([]*llmtypes.ProviderConfig, error)
SaveProviderConfig(ctx context.Context, params SaveProviderConfigParams) (*llmtypes.ProviderConfig, error)
ListAvailableModels(ctx context.Context) ([]llmtypes.Model, error)
Expand Down
16 changes: 16 additions & 0 deletions apps/backend/services/chat/api/stream.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
package api

import chattype "github.com/Uncensored-Developer/datalk/apps/backend/services/chat/pkg/chat"

// The stream event types are declared in pkg/chat. They are aliased here so
// that callers importing this package as chatapi can write chatapi.StreamEvent
// and chatapi.StreamEventType without importing pkg/chat themselves.
type (
	StreamEventType = chattype.StreamEventType
	StreamEvent     = chattype.StreamEvent
)

// Stream event type values, re-exported unchanged from pkg/chat.
const (
	StreamEventUserMessage      = chattype.StreamEventUserMessage
	StreamEventToken            = chattype.StreamEventToken
	StreamEventAssistantMessage = chattype.StreamEventAssistantMessage
	StreamEventExecution        = chattype.StreamEventExecution
	StreamEventComplete         = chattype.StreamEventComplete
	StreamEventError            = chattype.StreamEventError
)
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"fmt"
"strings"

chaterrors "github.com/Uncensored-Developer/datalk/apps/backend/services/chat/pkg/errors"
llmtypes "github.com/Uncensored-Developer/datalk/apps/backend/services/chat/pkg/llm"
schematypes "github.com/Uncensored-Developer/datalk/apps/backend/services/schemas/pkg/schemas"
"github.com/mdobak/go-xerrors"
Expand Down Expand Up @@ -114,7 +115,11 @@ func ParseGenerateSQLResponse(rawRequest, rawResponse []byte, payloadText string
payload.SQL = strings.TrimSpace(payload.SQL)
payload.Explanation = strings.TrimSpace(payload.Explanation)
if payload.SQL == "" {
return nil, xerrors.New("structured SQL payload did not include sql")
explanation := strings.TrimSpace(payload.Explanation)
if explanation != "" {
return nil, xerrors.Newf("%s: %w", explanation, chaterrors.ErrInvalidSQL)
}
return nil, xerrors.Newf("the model could not generate SQL for this question: %w", chaterrors.ErrInvalidSQL)
}

return &llmtypes.GenerateSQLResponse{
Expand Down
133 changes: 133 additions & 0 deletions apps/backend/services/chat/internal/chat/llm/ollama/streaming.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
package ollama

import (
"bufio"
"bytes"
"encoding/json"
"io"
"net/http"
"strings"

llmcore "github.com/Uncensored-Developer/datalk/apps/backend/services/chat/internal/chat/llm"
llmtypes "github.com/Uncensored-Developer/datalk/apps/backend/services/chat/pkg/llm"
"github.com/mdobak/go-xerrors"
"golang.org/x/net/context"
)

// streamingChatRequest is the JSON body that GenerateSQLStream posts to the
// Ollama /api/chat endpoint. The type does not force streaming by itself;
// GenerateSQLStream sets Stream to true when it builds the request.
// NOTE(review): described as mirroring chatRequest, which is declared elsewhere
// in this package — confirm the two stay aligned.
type streamingChatRequest struct {
// Model is the model name taken from the GenerateSQLRequest.
Model string `json:"model"`
// Messages is the conversation, as produced by ollamaMessages.
Messages []message `json:"messages"`
// Format carries the structured-output schema from llmcore.GenerateSQLSchema.
Format map[string]any `json:"format"`
// Stream is sent as the "stream" field of the request.
Stream bool `json:"stream"`
}

// streamingChunk is one NDJSON line from the Ollama streaming response, as
// decoded by GenerateSQLStream.
type streamingChunk struct {
// Message holds this chunk's content fragment; GenerateSQLStream appends
// Message.Content to the accumulated response text.
Message message `json:"message"`
// Done marks the last chunk; GenerateSQLStream stops reading when it is true.
Done bool `json:"done"`
// DoneReason is reported as the finish reason unless it is blank or "stop".
DoneReason string `json:"done_reason,omitempty"`
// Token counts, read by GenerateSQLStream only from the chunk with done=true.
// They are pointers so that an absent count stays nil rather than becoming 0.
PromptEvalCount *int `json:"prompt_eval_count,omitempty"`
EvalCount *int `json:"eval_count,omitempty"`
}

// GenerateSQLStream calls the Ollama chat API (POST <baseURL>/api/chat) with
// streaming enabled and returns the parsed SQL response once the stream ends.
//
// The response body is read as NDJSON, one chunk per line. Every non-empty
// content fragment is appended to the accumulated response text and fed to an
// explanation streamer together with onToken. When the chunk marked done
// arrives, its token counts and finish reason are recorded and reading stops.
// The accumulated text is then handed to llmcore.ParseGenerateSQLResponse.
//
// An error is returned when:
//   - req.Model is blank;
//   - the request cannot be built or sent;
//   - the API answers with a status >= 400;
//   - a chunk carries an "error" field (Ollama reports failures that occur
//     after streaming has started this way);
//   - reading the stream fails;
//   - the stream produced no content;
//   - the accumulated text cannot be parsed.
func (c *Client) GenerateSQLStream(ctx context.Context, req llmtypes.GenerateSQLRequest, onToken func(string)) (*llmtypes.GenerateSQLResponse, error) {
	// Upper bound for a single NDJSON line. bufio.Scanner's default limit is
	// 64 KiB and exceeding it aborts the whole stream with ErrTooLong.
	const maxStreamLineBytes = 1 << 20

	if strings.TrimSpace(req.Model) == "" {
		return nil, xerrors.New("model is required")
	}

	requestBody := streamingChatRequest{
		Model:    req.Model,
		Messages: ollamaMessages(req),
		Format:   llmcore.GenerateSQLSchema(),
		Stream:   true,
	}

	rawRequest, err := json.Marshal(requestBody)
	if err != nil {
		return nil, xerrors.Newf("failed to marshal ollama streaming request: %w", err)
	}

	httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, c.baseURL+"/api/chat", bytes.NewReader(rawRequest))
	if err != nil {
		return nil, xerrors.Newf("failed to create ollama streaming request: %w", err)
	}
	httpReq.Header.Set("Content-Type", "application/json")

	resp, err := c.httpClient.Do(httpReq)
	if err != nil {
		return nil, xerrors.Newf("failed to call ollama streaming api: %w", err)
	}
	defer resp.Body.Close()

	if resp.StatusCode >= http.StatusBadRequest {
		// Best effort: a failed read leaves body short or empty and the
		// generic status error below still reports the failure.
		body, _ := io.ReadAll(resp.Body)
		if err := decodeHTTPError("ollama streaming api", resp.StatusCode, body); err != nil {
			return nil, err
		}
		return nil, xerrors.Newf("ollama streaming api failed with status %d", resp.StatusCode)
	}

	streamer := llmcore.NewExplanationStreamer()

	var (
		fullText        strings.Builder
		promptEvalCount *int
		evalCount       *int
		finishReason    *string
	)

	scanner := bufio.NewScanner(resp.Body)
	scanner.Buffer(make([]byte, 0, bufio.MaxScanTokenSize), maxStreamLineBytes)
	for scanner.Scan() {
		line := scanner.Bytes()
		if len(line) == 0 {
			continue
		}

		// Decode the regular chunk fields plus the optional "error" field that
		// Ollama emits when generation fails after the stream has started.
		var chunk struct {
			streamingChunk
			Error string `json:"error"`
		}
		if err := json.Unmarshal(line, &chunk); err != nil {
			continue // skip malformed lines
		}

		if msg := strings.TrimSpace(chunk.Error); msg != "" {
			// Surface the provider's own failure instead of letting it turn
			// into a misleading "empty response" or JSON parse error later.
			return nil, xerrors.Newf("ollama streaming api returned an error: %s", msg)
		}

		if token := chunk.Message.Content; token != "" {
			fullText.WriteString(token)
			streamer.Feed(token, onToken)
		}

		if chunk.Done {
			promptEvalCount = chunk.PromptEvalCount
			evalCount = chunk.EvalCount
			// "stop" is the normal completion; only other reasons are reported.
			if strings.TrimSpace(chunk.DoneReason) != "" && chunk.DoneReason != "stop" {
				reason := chunk.DoneReason
				finishReason = &reason
			}
			break
		}
	}

	if err := scanner.Err(); err != nil {
		return nil, xerrors.Newf("error reading ollama stream: %w", err)
	}

	responseText := strings.TrimSpace(fullText.String())
	if responseText == "" {
		return nil, xerrors.New("ollama streaming response was empty")
	}

	usage := &llmtypes.Usage{
		InputTokens:  promptEvalCount,
		OutputTokens: evalCount,
	}
	// A total is only meaningful when both counts were reported.
	if promptEvalCount != nil && evalCount != nil {
		total := *promptEvalCount + *evalCount
		usage.TotalTokens = &total
	}

	return llmcore.ParseGenerateSQLResponse(rawRequest, []byte(responseText), responseText, usage, finishReason)
}
11 changes: 9 additions & 2 deletions apps/backend/services/chat/internal/chat/llm/openai/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -338,11 +338,18 @@ func openAIInputMessages(req llmtypes.GenerateSQLRequest) []inputMessage {
})

for _, message := range promptMessages {
role := normalizeOpenAIRole(message.Role)
// OpenAI Responses API requires "output_text" for assistant turns
// and "input_text" for user/developer turns.
contentType := "input_text"
if role == "assistant" {
contentType = "output_text"
}
messages = append(messages, inputMessage{
Role: normalizeOpenAIRole(message.Role),
Role: role,
Content: []contentPart{
{
Type: "input_text",
Type: contentType,
Text: message.Content,
},
},
Expand Down
Loading