diff --git a/cmd/stack/mcp/root.go b/cmd/stack/mcp/root.go new file mode 100644 index 00000000..d4ea5fd2 --- /dev/null +++ b/cmd/stack/mcp/root.go @@ -0,0 +1,423 @@ +package mcp + +import ( + "bufio" + "bytes" + "context" + "encoding/json" + "errors" + "fmt" + "io" + "mime" + "net/http" + "net/url" + "os" + "strconv" + "strings" + + "github.com/spf13/cobra" + "golang.org/x/oauth2" + + fctl "github.com/formancehq/fctl/v3/pkg" +) + +const ( + transportFlag = "transport" + maxMCPMessageSize = 10 * 1024 * 1024 +) + +func NewCommand() *cobra.Command { + return fctl.NewStackCommand("mcp", + fctl.WithShortDescription("Run stack MCP integrations"), + fctl.WithChildCommands(NewServeCommand()), + ) +} + +func NewServeCommand() *cobra.Command { + return fctl.NewStackCommand("serve", + fctl.WithShortDescription("Start a stack MCP server"), + fctl.WithStringFlag(transportFlag, "stdio", "MCP transport to use (stdio)"), + fctl.WithArgs(cobra.NoArgs), + fctl.WithRunE(runServe), + ) +} + +func runServe(cmd *cobra.Command, _ []string) error { + transport := fctl.GetString(cmd, transportFlag) + if transport != "stdio" { + return fmt.Errorf("unsupported MCP transport %q: only stdio is currently supported", transport) + } + + _, profile, profileName, relyingParty, err := fctl.LoadAndAuthenticateCurrentProfile(cmd) + if err != nil { + return err + } + + organizationID, stackID, err := fctl.ResolveStackID(cmd, *profile) + if err != nil { + return err + } + + stackToken, stackAccess, err := fctl.EnsureStackAccess(cmd, relyingParty, stderrDialog{w: cmd.ErrOrStderr()}, profileName, *profile, organizationID, stackID) + if err != nil { + return err + } + + tokenSource := fctl.NewStackTokenSource( + *stackToken, + stackAccess, + relyingParty, + func(newToken fctl.AccessToken) error { + return fctl.WriteStackToken(cmd, profileName, stackID, newToken) + }, + cmd, + profileName, + organizationID, + stackID, + ) + httpClient := oauth2.NewClient(cmd.Context(), tokenSource) + + server := &stdioServer{ + in: os.Stdin, + out: os.Stdout, + err: cmd.ErrOrStderr(), + httpClient: httpClient, + stackURI: stackAccess.URI, + } + return server.Serve(cmd.Context()) +} + +type stderrDialog struct { + w io.Writer +} + +func (d stderrDialog) Info(msg string, args ...any) { + _, _ = fmt.Fprintf(d.w, msg+"\n", args...) +} + +type stdioServer struct { + in io.Reader + out io.Writer + err io.Writer + httpClient *http.Client + stackURI string + remote *remoteMCPClient +} + +type rpcMessage struct { + JSONRPC string `json:"jsonrpc,omitempty"` + ID json.RawMessage `json:"id,omitempty"` + Method string `json:"method,omitempty"` + Params json.RawMessage `json:"params,omitempty"` +} + +type rpcResponse struct { + JSONRPC string `json:"jsonrpc"` + ID json.RawMessage `json:"id"` + Result any `json:"result,omitempty"` + Error *rpcError `json:"error,omitempty"` +} + +type rpcError struct { + Code int `json:"code"` + Message string `json:"message"` +} + +func (s *stdioServer) Serve(ctx context.Context) error { + s.remote = newRemoteMCPClient(s.httpClient, s.stackURI) + reader := bufio.NewReader(s.in) + + type readResult struct { + data []byte + err error + } + reads := make(chan readResult, 1) + go func() { + for { + data, err := readMCPMessage(reader) + select { + case reads <- readResult{data: data, err: err}: + case <-ctx.Done(): + return + } + if err != nil { + return + } + } + }() + + for { + var read readResult + select { + case <-ctx.Done(): + return ctx.Err() + case read = <-reads: + } + + data, err := read.data, read.err + if err != nil { + if errors.Is(err, io.EOF) { + return nil + } + return err + } + if len(bytes.TrimSpace(data)) == 0 { + continue + } + + var msg rpcMessage + if err := json.Unmarshal(data, &msg); err != nil { + _ = s.writeResponse(rpcResponse{ + JSONRPC: "2.0", + ID: json.RawMessage("null"), + Error: &rpcError{Code: -32700, Message: "parse error"}, + }) + continue + } + + if len(msg.ID) == 0 { + s.handleNotification(msg) + continue + } + + result, rpcErr := s.handleRequest(ctx, msg) + resp := rpcResponse{ + JSONRPC: "2.0", + ID: msg.ID, + Result: result, + Error: rpcErr, + } + if err := s.writeResponse(resp); err != nil { + return err + } + } +} + +func (s *stdioServer) handleNotification(msg rpcMessage) { + switch msg.Method { + case "notifications/cancelled": + _, _ = fmt.Fprintf(s.err, "MCP request cancelled\n") + if err := s.remote.Notify(context.Background(), msg); err != nil { + _, _ = fmt.Fprintf(s.err, "forwarding MCP notification %q failed: %v\n", msg.Method, err) + } + case "notifications/initialized": + if err := s.remote.Notify(context.Background(), msg); err != nil { + _, _ = fmt.Fprintf(s.err, "forwarding MCP notification %q failed: %v\n", msg.Method, err) + } + } +} + +func (s *stdioServer) handleRequest(ctx context.Context, msg rpcMessage) (any, *rpcError) { + if msg.Method == "ping" { + return map[string]any{}, nil + } + result, err := s.remote.Request(ctx, msg) + if err != nil { + return nil, &rpcError{Code: -32000, Message: err.Error()} + } + return result, nil +} + +type remoteMCPClient struct { + httpClient *http.Client + endpoint string + sessionID string + protocolVersion string +} + +func newRemoteMCPClient(httpClient *http.Client, stackURI string) *remoteMCPClient { + if httpClient == nil { + httpClient = http.DefaultClient + } + base, err := url.Parse(stackURI) + if err != nil { + return &remoteMCPClient{httpClient: httpClient, endpoint: stackURI, protocolVersion: "2024-11-05"} + } + endpoint := base.ResolveReference(&url.URL{Path: "/api/mcp"}).String() + return &remoteMCPClient{httpClient: httpClient, endpoint: endpoint, protocolVersion: "2024-11-05"} +} + +func (c *remoteMCPClient) Request(ctx context.Context, msg rpcMessage) (any, error) { + resp, err := c.send(ctx, msg) + if err != nil { + return nil, err + } + if resp.Error != nil { + return nil, fmt.Errorf("remote MCP error %d: %s", resp.Error.Code, resp.Error.Message) + } + return resp.Result, nil +} + +func (c *remoteMCPClient) Notify(ctx context.Context, msg rpcMessage) error { + _, err := c.send(ctx, msg) + return err +} + +func (c *remoteMCPClient) send(ctx context.Context, msg rpcMessage) (*rpcResponse, error) { + if msg.Method == "initialize" { + c.captureProtocolVersion(msg.Params) + } + data, err := json.Marshal(msg) + if err != nil { + return nil, err + } + req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.endpoint, bytes.NewReader(data)) + if err != nil { + return nil, err + } + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Accept", "application/json, text/event-stream") + req.Header.Set("MCP-Protocol-Version", c.protocolVersion) + if c.sessionID != "" { + req.Header.Set("Mcp-Session-Id", c.sessionID) + } + + httpResp, err := c.httpClient.Do(req) + if err != nil { + return nil, err + } + defer httpResp.Body.Close() + + if sessionID := httpResp.Header.Get("Mcp-Session-Id"); sessionID != "" { + c.sessionID = sessionID + } + + if httpResp.StatusCode == http.StatusAccepted && len(msg.ID) == 0 { + return &rpcResponse{JSONRPC: "2.0"}, nil + } + + payload, err := io.ReadAll(httpResp.Body) + if err != nil { + return nil, err + } + if httpResp.StatusCode >= 300 { + return nil, fmt.Errorf("remote MCP HTTP %d: %s", httpResp.StatusCode, strings.TrimSpace(string(payload))) + } + if len(bytes.TrimSpace(payload)) == 0 { + return &rpcResponse{JSONRPC: "2.0"}, nil + } + + return decodeRemoteMCPResponse(httpResp.Header.Get("Content-Type"), payload) +} + +func (c *remoteMCPClient) captureProtocolVersion(params json.RawMessage) { + var initParams struct { + ProtocolVersion string `json:"protocolVersion"` + } + if err := json.Unmarshal(params, &initParams); err == nil && initParams.ProtocolVersion != "" { + c.protocolVersion = initParams.ProtocolVersion + } +} + +func decodeRemoteMCPResponse(contentType string, payload []byte) (*rpcResponse, error) { + mediaType, _, err := mime.ParseMediaType(contentType) + if err != nil { + mediaType = contentType + } + + if mediaType == "text/event-stream" { + return decodeSSEResponse(payload) + } + + var resp rpcResponse + if err := json.Unmarshal(payload, &resp); err != nil { + return nil, fmt.Errorf("decoding remote MCP response: %w", err) + } + return &resp, nil +} + +func decodeSSEResponse(payload []byte) (*rpcResponse, error) { + scanner := bufio.NewScanner(bytes.NewReader(payload)) + var events [][]string + var dataLines []string + for scanner.Scan() { + line := strings.TrimRight(scanner.Text(), "\r") + if line == "" { + if len(dataLines) > 0 { + events = append(events, dataLines) + dataLines = nil + } + continue + } + if strings.HasPrefix(line, "data:") { + dataLines = append(dataLines, strings.TrimSpace(strings.TrimPrefix(line, "data:"))) + } + } + if err := scanner.Err(); err != nil { + return nil, err + } + if len(dataLines) > 0 { + events = append(events, dataLines) + } + if len(events) == 0 { + return nil, fmt.Errorf("remote MCP SSE response did not contain data") + } + if len(events) > 1 { + // The stdio bridge writes one JSON-RPC response per request; streaming SSE is not supported yet. + return nil, fmt.Errorf("remote MCP SSE response contained multiple events") + } + var resp rpcResponse + if err := json.Unmarshal([]byte(strings.Join(events[0], "\n")), &resp); err != nil { + return nil, fmt.Errorf("decoding remote MCP SSE response: %w", err) + } + return &resp, nil +} + +func readMCPMessage(reader *bufio.Reader) ([]byte, error) { + for { + first, err := reader.Peek(1) + if err != nil { + return nil, err + } + if first[0] != '\n' && first[0] != '\r' && first[0] != ' ' && first[0] != '\t' { + break + } + _, _ = reader.ReadByte() + } + + headerOrJSON, err := reader.ReadString('\n') + if err != nil { + if errors.Is(err, io.EOF) && strings.TrimSpace(headerOrJSON) != "" { + return []byte(strings.TrimSpace(headerOrJSON)), nil + } + return nil, err + } + if strings.HasPrefix(strings.ToLower(headerOrJSON), "content-length:") { + _, lengthValue, _ := strings.Cut(headerOrJSON, ":") + lengthValue = strings.TrimSpace(lengthValue) + length, err := strconv.Atoi(lengthValue) + if err != nil { + return nil, fmt.Errorf("invalid Content-Length header: %w", err) + } + if length < 0 { + return nil, fmt.Errorf("invalid Content-Length header: must be non-negative") + } + if length > maxMCPMessageSize { + return nil, fmt.Errorf("invalid Content-Length header: exceeds maximum size %d", maxMCPMessageSize) + } + for { + line, err := reader.ReadString('\n') + if err != nil { + return nil, err + } + if strings.TrimSpace(line) == "" { + break + } + } + payload := make([]byte, length) + if _, err := io.ReadFull(reader, payload); err != nil { + return nil, err + } + return payload, nil + } + return []byte(strings.TrimSpace(headerOrJSON)), nil +} + +func (s *stdioServer) writeResponse(resp rpcResponse) error { + data, err := json.Marshal(resp) + if err != nil { + return err + } + _, err = s.out.Write(append(data, '\n')) + return err +} diff --git a/cmd/stack/mcp/root_test.go b/cmd/stack/mcp/root_test.go new file mode 100644 index 00000000..a4c19f0d --- /dev/null +++ b/cmd/stack/mcp/root_test.go @@ -0,0 +1,245 @@ +package mcp + +import ( + "bufio" + "bytes" + "context" + "encoding/json" + "errors" + "fmt" + "io" + "net/http" + "net/http/httptest" + "strings" + "sync/atomic" + "testing" +) + +func TestStdioServerForwardsInitialize(t *testing.T) { + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/api/mcp" { + t.Fatalf("path = %s, want /api/mcp", r.URL.Path) + } + var request rpcMessage + if err := json.NewDecoder(r.Body).Decode(&request); err != nil { + t.Fatalf("decoding remote request: %v", err) + } + if request.Method != "initialize" { + t.Fatalf("method = %s, want initialize", request.Method) + } + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"jsonrpc":"2.0","id":1,"result":{"protocolVersion":"2024-11-05","capabilities":{"tools":{}},"serverInfo":{"name":"remote"}}}`)) + })) + defer upstream.Close() + + input := `{"jsonrpc":"2.0","id":1,"method":"initialize","params":{"protocolVersion":"2024-11-05"}}` + "\n" + var output bytes.Buffer + server := &stdioServer{ + in: strings.NewReader(input), + out: &output, + err: &bytes.Buffer{}, + httpClient: upstream.Client(), + stackURI: upstream.URL, + } + + if err := server.Serve(context.Background()); err != nil { + t.Fatalf("Serve() error = %v", err) + } + + var response rpcResponse + if err := json.Unmarshal(bytes.TrimSpace(output.Bytes()), &response); err != nil { + t.Fatalf("invalid response: %v", err) + } + if string(response.ID) != "1" { + t.Fatalf("response id = %s, want 1", response.ID) + } + if response.Error != nil { + t.Fatalf("response error = %#v", response.Error) + } +} + +func TestReadMCPMessageContentLength(t *testing.T) { + payload := []byte(`{"jsonrpc":"2.0","id":"abc","method":"ping"}`) + input := fmt.Sprintf("Content-Length: %d\r\n\r\n%s", len(payload), payload) + + got, err := readMCPMessage(bufioReader(input)) + if err != nil { + t.Fatalf("readMCPMessage() error = %v", err) + } + if string(got) != string(payload) { + t.Fatalf("payload = %q, want %q", got, payload) + } +} + +func TestReadMCPMessageRejectsNegativeContentLength(t *testing.T) { + _, err := readMCPMessage(bufioReader("Content-Length: -1\r\n\r\n")) + if err == nil { + t.Fatalf("readMCPMessage() expected error") + } + if !strings.Contains(err.Error(), "must be non-negative") { + t.Fatalf("readMCPMessage() error = %v, want non-negative framing error", err) + } +} + +func TestReadMCPMessageRejectsOversizedContentLength(t *testing.T) { + input := fmt.Sprintf("Content-Length: %d\r\n\r\n", maxMCPMessageSize+1) + + _, err := readMCPMessage(bufioReader(input)) + if err == nil { + t.Fatalf("readMCPMessage() expected error") + } + if !strings.Contains(err.Error(), "exceeds maximum size") { + t.Fatalf("readMCPMessage() error = %v, want maximum size framing error", err) + } +} + +func TestStdioServerStopsWhenContextIsCancelled(t *testing.T) { + reader, writer := io.Pipe() + defer func() { _ = reader.Close() }() + defer func() { _ = writer.Close() }() + + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + server := &stdioServer{ + in: reader, + out: &bytes.Buffer{}, + err: &bytes.Buffer{}, + httpClient: http.DefaultClient, + stackURI: "http://127.0.0.1", + } + + if err := server.Serve(ctx); !errors.Is(err, context.Canceled) { + t.Fatalf("Serve() error = %v, want context.Canceled", err) + } +} + +func TestStdioServerForwardsCancelledNotification(t *testing.T) { + var gotMethod atomic.Value + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + var request rpcMessage + if err := json.NewDecoder(r.Body).Decode(&request); err != nil { + t.Fatalf("decoding remote request: %v", err) + } + gotMethod.Store(request.Method) + w.WriteHeader(http.StatusAccepted) + })) + defer upstream.Close() + + input := `{"jsonrpc":"2.0","method":"notifications/cancelled","params":{"requestId":1}}` + "\n" + server := &stdioServer{ + in: strings.NewReader(input), + out: &bytes.Buffer{}, + err: &bytes.Buffer{}, + httpClient: upstream.Client(), + stackURI: upstream.URL, + } + + if err := server.Serve(context.Background()); err != nil { + t.Fatalf("Serve() error = %v", err) + } + if got := gotMethod.Load(); got != "notifications/cancelled" { + t.Fatalf("forwarded method = %v, want notifications/cancelled", got) + } +} + +func TestStdioServerPreservesForwardedMethodErrors(t *testing.T) { + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + http.Error(w, "invalid_token", http.StatusUnauthorized) + })) + defer upstream.Close() + + input := `{"jsonrpc":"2.0","id":1,"method":"resources/list"}` + "\n" + var output bytes.Buffer + server := &stdioServer{ + in: strings.NewReader(input), + out: &output, + err: &bytes.Buffer{}, + httpClient: upstream.Client(), + stackURI: upstream.URL, + } + + if err := server.Serve(context.Background()); err != nil { + t.Fatalf("Serve() error = %v", err) + } + + var response rpcResponse + if err := json.Unmarshal(bytes.TrimSpace(output.Bytes()), &response); err != nil { + t.Fatalf("invalid response: %v", err) + } + if response.Error == nil { + t.Fatalf("response error is nil") + } + if response.Error.Code != -32000 { + t.Fatalf("response error code = %d, want -32000", response.Error.Code) + } + if !strings.Contains(response.Error.Message, "remote MCP HTTP 401") { + t.Fatalf("response error message = %q, want remote HTTP error", response.Error.Message) + } +} + +func TestRemoteMCPClientKeepsSessionID(t *testing.T) { + var count atomic.Int32 + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + requestNumber := count.Add(1) + if requestNumber == 1 { + w.Header().Set("Mcp-Session-Id", "session-123") + } else if got := r.Header.Get("Mcp-Session-Id"); got != "session-123" { + t.Fatalf("Mcp-Session-Id = %q, want session-123", got) + } + if got := r.Header.Get("MCP-Protocol-Version"); got != "2025-03-26" { + t.Fatalf("MCP-Protocol-Version = %q, want 2025-03-26", got) + } + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"jsonrpc":"2.0","id":1,"result":{}}`)) + })) + defer upstream.Close() + + client := newRemoteMCPClient(upstream.Client(), upstream.URL) + _, err := client.Request(context.Background(), rpcMessage{ + JSONRPC: "2.0", + ID: json.RawMessage("1"), + Method: "initialize", + Params: json.RawMessage(`{"protocolVersion":"2025-03-26"}`), + }) + if err != nil { + t.Fatalf("initialize request error = %v", err) + } + + _, err = client.Request(context.Background(), rpcMessage{ + JSONRPC: "2.0", + ID: json.RawMessage("2"), + Method: "tools/list", + }) + if err != nil { + t.Fatalf("tools/list request error = %v", err) + } +} + +func TestDecodeSSEResponse(t *testing.T) { + resp, err := decodeRemoteMCPResponse("text/event-stream", []byte("event: message\ndata: {\"jsonrpc\":\"2.0\",\"id\":1,\"result\":{\"ok\":true}}\n\n")) + if err != nil { + t.Fatalf("decodeRemoteMCPResponse() error = %v", err) + } + result, ok := resp.Result.(map[string]any) + if !ok || result["ok"] != true { + t.Fatalf("result = %#v, want ok=true", resp.Result) + } +} + +func TestDecodeSSEResponseRejectsMultipleEvents(t *testing.T) { + payload := []byte("event: message\ndata: {\"jsonrpc\":\"2.0\",\"id\":1,\"result\":{\"first\":true}}\n\n" + + "event: message\ndata: {\"jsonrpc\":\"2.0\",\"id\":1,\"result\":{\"second\":true}}\n\n") + + _, err := decodeRemoteMCPResponse("text/event-stream", payload) + if err == nil { + t.Fatalf("decodeRemoteMCPResponse() expected error") + } + if !strings.Contains(err.Error(), "multiple events") { + t.Fatalf("decodeRemoteMCPResponse() error = %v, want multiple events error", err) + } +} + +func bufioReader(s string) *bufio.Reader { + return bufio.NewReader(strings.NewReader(s)) +} diff --git a/cmd/stack/root.go b/cmd/stack/root.go index c16c8430..c7018e20 100644 --- a/cmd/stack/root.go +++ b/cmd/stack/root.go @@ -3,6 +3,7 @@ package stack import ( "github.com/spf13/cobra" + "github.com/formancehq/fctl/v3/cmd/stack/mcp" "github.com/formancehq/fctl/v3/cmd/stack/modules" "github.com/formancehq/fctl/v3/cmd/stack/users" fctl "github.com/formancehq/fctl/v3/pkg" @@ -24,6 +25,7 @@ func NewCommand() *cobra.Command { NewUpgradeCommand(), NewHistoryCommand(), NewProxyCommand(), + mcp.NewCommand(), users.NewCommand(), modules.NewCommand(), ), diff --git a/pkg/authentication.go b/pkg/authentication.go index 32deba71..a9932582 100644 --- a/pkg/authentication.go +++ b/pkg/authentication.go @@ -249,7 +249,6 @@ func Refresh(ctx context.Context, relyingParty client.RelyingParty, token Access } func FetchStackToken(ctx context.Context, httpClient *http.Client, stackURI, token string) (*oauth2.Token, error) { - form := url.Values{ "grant_type": []string{"urn:ietf:params:oauth:grant-type:jwt-bearer"}, "assertion": []string{token},