From ec8b204e68a97c8d6fa41c71d3be97715a25a1b4 Mon Sep 17 00:00:00 2001 From: Yaroslav Shevchuk Date: Fri, 8 May 2026 11:01:30 +0000 Subject: [PATCH 1/6] mrtr implementation proposal --- design/mrtr.md | 297 +++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 297 insertions(+) create mode 100644 design/mrtr.md diff --git a/design/mrtr.md b/design/mrtr.md new file mode 100644 index 00000000..44bd810e --- /dev/null +++ b/design/mrtr.md @@ -0,0 +1,297 @@ +## Context + +A proposal for implementing Multi Round-Trip Requests +(MRTR) as defined in [SEP-2322](https://github.com/CaitieM20/modelcontextprotocol/blob/de6d76fba3078eda957dadb3cec51ca8ab851b5c/seps/2322-MRTR.md). + +In the new protocol version servers can't initiate requests to clients, but when a server requires additional input for completing `tools/call`, `prompts/get`, or `resources/read` it can return an incomplete result along with a set of `inputRequests`. The client fulfills them locally and retries the same call with `inputResponses` attached. + +## Goals + +**Must have:** + +* Backward compatibility. +* Correct representation on the wire. + +**Nice to have:** + +* Minimal changes to the exported API surface. +* Hard for server implementers to construct an invalid payload. +* Simple input request handling for clients. +* Protocol-version-independent code. +* Consistency with the rest of the SDK. + +## Proposal + +`ServerSession` methods return an error for new-version protocol connections. + +`InputRequest`/`InputResponse` is introduced as a sealed-interface: +```go +// Implemented by *ElicitParams, *CreateMessageParams, *ListRootsParams +type InputRequest interface{ isInputRequest() } + +type InputRequestMap map[string]InputRequest +// MarshalJSON encodes as map[string]struct{ Method string; Params InputRequest } +func (m InputRequestMap) MarshalJSON() ([]byte, error) { ... } +// UnmarshalJSON decodes from map[string]struct{ Method string; Params InputRequest } +func (m *InputRequestMap) UnmarshalJSON(data []byte) error { ... } + +// Implemented by *ElicitResult, *CreateMessageResult, *ListRootsResult. +type InputResponse interface{ isInputResult() } + +type InputResponseMap map[string]InputResponse +// MarshalJSON encodes as map[string]struct{ Method string; Result InputResponse } +func (m InputResponseMap) MarshalJSON() ([]byte, error) { ... } +// UnmarshalJSON decodes from map[string]struct{ Method string; Result InputResponse } +func (m *InputResponseMap) UnmarshalJSON(data []byte) error { ... } +``` + +All affected methods' `*Params` are extended with `InputResponseMap` and `RequestState` fields: +```go +type CallToolParams struct { + ... + InputResponses InputResponseMap `json:"inputResponses,omitempty"` + RequestState string `json:"requestState,omitempty"` +} +// Same for GetPromptParams, ReadResourceParams +``` + +`InputRequests` and `RequestState` fields are added directly to `CallToolResult`, `GetPromptResult`, and `ReadResourceResult` as exported. +Result type discriminator (completed, input_required) is unexported so that SDK users don't need to set it to the correct constant in addition to setting either `Content` or `InputRequests`. Handler execution result is validated and augmented before marshaling: +```go +type CallToolResult struct { + ... + InputRequests InputRequestMap `json:"inputRequests,omitempty"` + RequestState string `json:"requestState,omitempty"` + resultType string // set by the SDK and used in MarshalJSON() +} +// Same for GetPromptResult, ReadResourceResult. +``` +Alternatively, the field could only exist on `wire struct`, but this would make us return `complete` to older clients or empty string to newer clients, because there's no access to negotiated protocol version in `MarshalJSON`. + +Servers request additional input by constructing a correct struct literal: +```go +mcp.AddTool(s, tool, func(ctx context.Context, req *mcp.CallToolRequest, in MyIn) (*mcp.CallToolResult, MyOut, error) { + if !hasConfirmation(in) { + return &mcp.CallToolResult{ + InputRequests: InputRequestMap{"confirm": &mcp.ElicitParams{Message: "Sure?"}}, + RequestState: "state-token", + }, zero, nil + } + return &mcp.CallToolResult{Content: []mcp.Content{&mcp.TextContent{Text: "done"}}}, myOut, nil +}) +``` +An incomplete result should be converted to an error response when communicating with an older client. This is a reasonable default for server implementers to not branch by protocol version in their code. + +An unexported middleware is installed in the client, which similarly to `urlElicitationMiddleware` will automatically invoke handlers for the corresponding methods on incomplete results and retry the original request. `ClientOptions` will be extended with configuration knobs: +```go +type MRTROptions struct { + ... + MaxRetries int + Disabled bool +} +client := mcp.NewClient(impl, &mcp.ClientOptions{ + MRTR: &mcp.MRTROptions{MaxRetries: 3}, +}) +``` + +Alternatively, clients have an option to disable it and write a retry loop manually: +```go +client := mcp.NewClient(impl, &mcp.ClientOptions{MRTR: &mcp.MRTROptions{Disabled: true}}) +result, err := client.CallTool(ctx, &mcp.CallToolParams{Name: "my-tool"}) +if result.InputRequests != nil { ... } +``` + +This is error-prone because according to the SEP an empty map can be returned for load-shedding purposes. `NeedsInput` function can be provided, but nothing would force users to use it. + +**Pros** + +This is arguably the simplest and the most transparent approach which is also closest to the spec. +What gets explicitly set on the server can be observed on the wire and on the client. +The opt-out middleware follows the principle of the least surprise for app developers. If client method handlers were provided they will continue to be invoked regardless of the protocol version in use. The `Disabled` option lets "power-users" build any custom handling logic. + +**Cons** + +The biggest downside of the proposal is that server developers can construct incorrect responses and this will only be validated at runtime. + +## Alternatives considered + +### Unexported fields + +MRTR fields can be unexported, accessible only through getters, constructible only through constructor functions, and handled explicitly in custom `(Unm|M)arshalJSON`. This will make it impossible for developers to construct incorrect responses and for clients to perform an erroneous `len(result.InputRequests) > 0` check in the load-shedding case. +```go +type CallToolResult struct { + ... + inputRequests InputRequestMap + requestState string + resultType string +} + +func (r *CallToolResult) InputRequests() (InputRequestMap, bool) { ... } + +// InputRequiredResult struct exists for backward-compatibility in case of new fields being needed for input request results. +type InputRequiredResult struct { + InputRequests InputRequestMap + RequestState string +} + +// RequireInput constructs a tool call, prompt or resource result with input requests set. +// mrtrResult provides methods for setting private fields on these types. +func RequireInput[T any, TP interface { *T; mrtrResult }](r InputRequiredResult) TP { ... } +``` + +On the server: +```go +mcp.AddTool(s, tool, func(ctx context.Context, req *mcp.CallToolRequest, in MyIn) (*mcp.CallToolResult, MyOut, error) { + if !hasConfirmation(in) { + return mcp.RequireInput[mcp.CallToolResult](mcp.InputRequiredResult{ + InputRequests: mcp.InputRequestMap{"confirm": &mcp.ElicitParams{Message: "Deploy to production?"}}, + RequestState: "deployment-123", + }), nil, nil + } + return &mcp.CallToolResult{ Content: []mcp.Content{&mcp.TextContent{Text: "done"}}}, myOut, nil +}) +``` + +On the client: +```go +result, err := client.CallTool(ctx, &mcp.CallToolParams{Name: "my-tool"}) +if requests, ok := result.InputRequests(); ok { ... } +``` + +The biggest downside of this approach is the obscure data model with hidden fields. An incomplete `mcp.CallToolResult` looks like an uninitialized struct until `InputRequests` method result is examined. +In addition to this, the verbose `RequireInput` syntax (no auto type inference from assignment target) does not look idiomatic and fits poorly into the existing SDK APIs. + +--- + +### `InputRequiredError` type + +We could explore a different data channel - `error` return value. This would give us the natural "happy path is when all inputs are provided" flow on the server side, and good result interpretability on the client side (impossible to confuse with a successful response). +The new error could be converted to the correct wire representation at the marshaling stage. +```go +type InputRequiredError struct { + InputRequests InputRequestMap + RequestState string +} + +func (e *InputRequiredError) Error() string { + return fmt.Sprintf("input required: %d request(s)", len(e.InputRequests)) +} +``` + +On the server: +```go +mcp.AddTool(s, tool, func(ctx context.Context, req *mcp.CallToolRequest, in MyIn) (*mcp.CallToolResult, MyOut, error) { + if !hasConfirmation(in) { + return nil, zero, &mcp.InputRequiredError{ + InputRequests: mcp.InputRequestMap{"confirm": &mcp.ElicitParams{Message: "Sure?"}}, + RequestState: "state-token", + } + } + return &mcp.CallToolResult{ Content: []mcp.Content{&mcp.TextContent{Text: "done"}}}, myOut, nil +}) +``` + +On the client: +```go +result, err := client.CallTool(ctx, &mcp.CallToolParams{Name: "my-tool"}) +var inputReqErr *mcp.InputRequiredError +if errors.As(err, &inputReqErr) { ... } +``` + +The downsides of this approach are: +* The drift from the protocol, where MRTR is not an error flow. +* Obscure "customError -> non-error protocol type on wirte -> customError" data lifecycle. +* Things get confusing for error-processing middleware. + +--- + +### New functions + +We could introduce new functions with a different handler signature where the return type is a sealed interface. This would give us compiler-enforced correctness for values constructed by tool handlers and clients would be forced to unpack `mcp.RoundTripCallToolResult` and make a concious decision for how to handle it. +```go +type RoundTripToolHandler func(context.Context, *CallToolRequest) (RoundTripCallToolResult, error) +type RoundTripToolHandlerFor[In, Out any] func(context.Context, *CallToolRequest, In) (RoundTripCallToolResult, Out, error) + +// RoundTripCallToolResult is implemented by CallToolResult and IncompleteResult +type RoundTripCallToolResult interface { isMRTRResult() } + +type IncompleteResult struct { + ... + InputRequests InputRequestMap `json:"inputRequests,omitempty"` + RequestState string `json:"requestState,omitempty"` +} + +func (s *Server) AddRoundTripTool(t *Tool, h RoundTripToolHandler) +func AddRoundTripTool[In, Out any](s *Server, t *Tool, h RoundTripToolHandlerFor[In, Out]) +``` + +`Server.AddTool` wraps the old `ToolHandler` into a `RoundTripToolHandler` to update its function signature: +```go +mcp.AddRoundTripTool(s, tool, func(ctx context.Context, req *mcp.CallToolRequest, in MyIn) (mcp.RoundTripCallToolResult, MyOut, error) { + if needsInput(in) { + return &mcp.IncompleteResult{ + ResultType: mcp.ResultTypeInputRequired, + InputRequests: InputRequestMap{"confirm": &mcp.ElicitParams{Message: "Sure?"}}, + }, zero, nil + } + return &mcp.CallToolResult{Content: []mcp.Content{&mcp.TextContent{Text: "done"}}}, myOut, nil +}) +``` + +The downsides of this approach are: +* SEP suggests `ResultType` will potentially be extended with new values, `RoundTrip` in new function names will not allow us to cleanly extend the sealed interface with new types. But an overly generic name for new functions will make the API use-case less clear. +* Different code +* SDK takes the same action (puts it on the wire) regardless of the returned type, it exists only for enforcing correctness of the user code. +* Exported API surface bloat: +7 exported functions. + +--- + +### Exported Middleware + +We could flip "unexported MRTR middleware with opt-out option" to "exported middleware with opt-in requirement". +```go +func AutoMRTR(opts *MRTROptions) Middleware { ... } +type MRTROptions struct { + MaxRetries int +} +client := mcp.NewClient(impl, nil) +client.AddSendingMiddleware(mcp.AutoMRTR(&mcp.MRTROptions{ + MaxRetries: 5, +})) +``` +This would change semantics of `*Handler` fields - depending on the protocol version in use, an extra initialization step will be required for them to "take effect". + +--- + +### Protocol version bridging + +When a handler returns an input-required result to an old client the SDK could bridge by invoking `ServerSession.Elicit`/`CreateMessage`/`ListRoots` on the `ServerSession` and re-invoking the handler with the collected `inputResponses`: +```go +func mrtrBridgeMiddleware(next mcp.MethodHandler) mcp.MethodHandler { + return func(ctx context.Context, method string, req mcp.Request) (mcp.Result, error) { + res, err := next(ctx, method, req) + if err != nil || !needsInput(res) || isNewProtocol(req) { + return res, err + } + // Fulfill input requests via server-initiated calls, inject inputResponses into the request, + // re-invoke the handler. + } +} +``` + +Or even more radically: instead of making `ServerSession.Elicit`, `CreateMessage`, and `ListRoots` return errors for newer clients, we could convert these invocations into MRTR wire format transparently: suspend the handler, return `input_required`, and resume when the client retries with `inputResponses`: +```go +mcp.AddTool(s, tool, func(ctx context.Context, req *mcp.CallToolRequest, in MyIn) (*mcp.CallToolResult, MyOut, error) { + result, err := req.Session.Elicit(ctx, &mcp.ElicitParams{Message: "Deploy to production?"}) + if err != nil { + return nil, zero, err + } + if result.Action != "accept" { + return &mcp.CallToolResult{Content: []mcp.Content{&mcp.TextContent{Text: "cancelled"}}}, zero, nil + } + return &mcp.CallToolResult{Content: []mcp.Content{&mcp.TextContent{Text: "deployed"}}}, myOut, nil +}) +``` + +Such bridging would contradict the design goal of MRTR — servers shouldn't hold resources between round trips, and it should be possible for a retry to arrive on any server instance in a multi-server deployment. + From 3bda0df3bff8130c407d024961417a837f2e97fa Mon Sep 17 00:00:00 2001 From: Yaroslav Shevchuk Date: Tue, 26 May 2026 07:53:18 +0000 Subject: [PATCH 2/6] server and client basic mrtr support --- mcp/client.go | 10 +- mcp/content_nil_test.go | 3 +- mcp/mcp_test.go | 5 +- mcp/mrtr.go | 212 +++++++++++++++++++++++++ mcp/mrtr_test.go | 185 ++++++++++++++++++++++ mcp/protocol.go | 343 +++++++++++++++++++++++++++++++++++++++- mcp/server.go | 33 +++- mcp/streamable_test.go | 12 +- 8 files changed, 783 insertions(+), 20 deletions(-) create mode 100644 mcp/mrtr.go create mode 100644 mcp/mrtr_test.go diff --git a/mcp/client.go b/mcp/client.go index 6e24c5a3..8bea349a 100644 --- a/mcp/client.go +++ b/mcp/client.go @@ -58,13 +58,17 @@ func NewClient(impl *Implementation, options *ClientOptions) *Client { opts.Logger = ensureLogger(nil) } - return &Client{ + c := &Client{ impl: impl, opts: opts, roots: newFeatureSet(func(r *Root) string { return r.URI }), sendingMethodHandler_: defaultSendingMethodHandler, receivingMethodHandler_: defaultReceivingMethodHandler[*ClientSession], } + if opts.MRTR == nil || !opts.MRTR.Disabled { + c.AddSendingMiddleware(mrtrMiddleware(c)) + } + return c } // ClientOptions configures the behavior of the client. @@ -154,6 +158,10 @@ type ClientOptions struct { ResourceUpdatedHandler func(context.Context, *ResourceUpdatedNotificationRequest) LoggingMessageHandler func(context.Context, *LoggingMessageRequest) ProgressNotificationHandler func(context.Context, *ProgressNotificationClientRequest) + // MRTR configures the automatic MRTR (Multi Round-Trip Requests) middleware. + // By default (nil), the middleware is enabled with default settings. + // Set Disabled to true to opt out of automatic MRTR handling. + MRTR *MRTROptions // If non-zero, defines an interval for regular "ping" requests. // If the peer fails to respond to pings originating from the keepalive check, // the session is automatically closed. diff --git a/mcp/content_nil_test.go b/mcp/content_nil_test.go index ba263e6f..86bb86ac 100644 --- a/mcp/content_nil_test.go +++ b/mcp/content_nil_test.go @@ -15,6 +15,7 @@ import ( "testing" "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" "github.com/modelcontextprotocol/go-sdk/mcp" ) @@ -223,4 +224,4 @@ func TestContentUnmarshalNilWithInvalidContent(t *testing.T) { } } -var ctrCmpOpts = []cmp.Option{cmp.AllowUnexported(mcp.CallToolResult{})} +var ctrCmpOpts = []cmp.Option{cmpopts.IgnoreUnexported(mcp.CallToolResult{}, mcp.GetPromptResult{}, mcp.ReadResourceResult{})} diff --git a/mcp/mcp_test.go b/mcp/mcp_test.go index 14173231..bad35086 100644 --- a/mcp/mcp_test.go +++ b/mcp/mcp_test.go @@ -24,6 +24,7 @@ import ( "time" "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" "github.com/google/jsonschema-go/jsonschema" "github.com/modelcontextprotocol/go-sdk/internal/jsonrpc2" "github.com/modelcontextprotocol/go-sdk/jsonrpc" @@ -206,7 +207,7 @@ func TestEndToEnd(t *testing.T) { Role: "user", }}, } - if diff := cmp.Diff(wantReview, gotReview); diff != "" { + if diff := cmp.Diff(wantReview, gotReview, ctrCmpOpts...); diff != "" { t.Errorf("prompts/get 'code_review' mismatch (-want +got):\n%s", diff) } @@ -2371,4 +2372,4 @@ func TestSetErrorPreservesContent(t *testing.T) { } } -var ctrCmpOpts = []cmp.Option{cmp.AllowUnexported(CallToolResult{})} +var ctrCmpOpts = []cmp.Option{cmpopts.IgnoreUnexported(CallToolResult{}, GetPromptResult{}, ReadResourceResult{})} diff --git a/mcp/mrtr.go b/mcp/mrtr.go new file mode 100644 index 00000000..c942237f --- /dev/null +++ b/mcp/mrtr.go @@ -0,0 +1,212 @@ +// Copyright 2025 The Go MCP SDK Authors. All rights reserved. +// Use of this source code is governed by the license +// that can be found in the LICENSE file. + +package mcp + +import ( + "context" + "fmt" + "log/slog" + + "github.com/modelcontextprotocol/go-sdk/jsonrpc" +) + +const defaultMRTRMaxRetries = 3 + +// MRTROptions configures the client-side MRTR (Multi Round-Trip Requests) +// middleware. The middleware is enabled by default and automatically fulfills input +// requests from the server by invoking the appropriate client handlers and +// retrying the original call. +type MRTROptions struct { + // MaxRetries is the maximum number of MRTR retry attempts after the + // initial call. Defaults to 3 if the provided value is <= 0. + MaxRetries int + // Disabled prevents the automatic MRTR middleware from being installed. + // When true, the client returns input-required results directly and + // callers must handle the retry loop themselves using [CallToolResult.NeedsInput], + // [GetPromptResult.NeedsInput], or [ReadResourceResult.NeedsInput]. + Disabled bool +} + +type mrtrResult interface { + setResultType(ResultType) + inputRequests() map[string]InputRequest + setInputRequest(k string, v InputRequest) + hasContent() bool +} + +func handleMRTRResult(ss *ServerSession, logger *slog.Logger, res mrtrResult) error { + hasInputRequests := res.inputRequests() != nil + + protocolVersion := latestProtocolVersion + if iparams := ss.InitializeParams(); iparams != nil { + protocolVersion = iparams.ProtocolVersion + } + supportsMRTR := protocolVersion >= protocolVersion20260630 + + switch { + case hasInputRequests && !supportsMRTR: + return fmt.Errorf("protocol version %q does not support input requests (< %q)", protocolVersion, protocolVersion20260630) + + case hasInputRequests && res.hasContent(): + logger.Warn("handler returned both content and inputRequests; inputRequests takes precedence") + return &jsonrpc.Error{ + Code: jsonrpc.CodeInternalError, + Message: "server bug: result has both content and inputRequests", + } + + case hasInputRequests: + res.setResultType(ResultTypeInputRequired) + + case supportsMRTR: + res.setResultType(ResultTypeComplete) + } + return nil +} + +func mrtrMiddleware(c *Client) Middleware { + return func(next MethodHandler) MethodHandler { + return func(ctx context.Context, method string, req Request) (Result, error) { + if method != methodCallTool && method != methodGetPrompt && method != methodReadResource { + return next(ctx, method, req) + } + + maxRetries := defaultMRTRMaxRetries + if c.opts.MRTR != nil && c.opts.MRTR.MaxRetries > 0 { + maxRetries = c.opts.MRTR.MaxRetries + } + + for retries := 0; ; retries++ { + res, err := next(ctx, method, req) + if err != nil { + return res, err + } + irm := mrtrInputRequests(res) + if len(irm) == 0 { + return res, nil + } + if retries >= maxRetries { + return nil, fmt.Errorf("MRTR: exceeded maximum retries (%d)", maxRetries) + } + cs, ok := req.GetSession().(*ClientSession) + if !ok { + return res, nil + } + responses, err := fulfillInputRequests(ctx, cs, irm) + if err != nil { + return nil, err + } + setMRTRRetryParams(req, responses, mrtrRequestState(res)) + } + } + } +} + +func mrtrInputRequests(res Result) InputRequestMap { + switch r := res.(type) { + case *CallToolResult: + if r.NeedsInput() { + return r.InputRequests + } + case *GetPromptResult: + if r.NeedsInput() { + return r.InputRequests + } + case *ReadResourceResult: + if r.NeedsInput() { + return r.InputRequests + } + } + return nil +} + +func mrtrRequestState(res Result) string { + switch r := res.(type) { + case *CallToolResult: + return r.RequestState + case *GetPromptResult: + return r.RequestState + case *ReadResourceResult: + return r.RequestState + } + return "" +} + +func setMRTRRetryParams(req Request, responses InputResponseMap, state string) { + switch p := req.GetParams().(type) { + case *CallToolParams: + p.InputResponses = responses + p.RequestState = state + case *GetPromptParams: + p.InputResponses = responses + p.RequestState = state + case *ReadResourceParams: + p.InputResponses = responses + p.RequestState = state + } +} + +func fulfillInputRequests(ctx context.Context, cs *ClientSession, requests InputRequestMap) (InputResponseMap, error) { + responses := make(InputResponseMap) + for id, ir := range requests { + resp, err := fulfillInputRequest(ctx, cs, ir) + if err != nil { + return nil, fmt.Errorf("MRTR: fulfilling input request %q: %w", id, err) + } + responses[id] = resp + } + return responses, nil +} + +func fulfillInputRequest(ctx context.Context, cs *ClientSession, ir InputRequest) (InputResponse, error) { + switch p := ir.(type) { + case *ElicitParams: + return cs.client.elicit(ctx, newClientRequest(cs, p)) + case *CreateMessageParams: + return fulfillCreateMessage(ctx, cs, p) + case *ListRootsParams: + return cs.client.listRoots(ctx, newClientRequest(cs, p)) + default: + return nil, fmt.Errorf("unknown input request type: %T", ir) + } +} + +func fulfillCreateMessage(ctx context.Context, cs *ClientSession, p *CreateMessageParams) (*CreateMessageResult, error) { + if cs.client.opts.CreateMessageHandler != nil { + return cs.client.opts.CreateMessageHandler(ctx, &CreateMessageRequest{Session: cs, Params: p}) + } + if cs.client.opts.CreateMessageWithToolsHandler != nil { + var msgs []*SamplingMessageV2 + for _, m := range p.Messages { + msgs = append(msgs, &SamplingMessageV2{Content: []Content{m.Content}, Role: m.Role}) + } + wtp := &CreateMessageWithToolsParams{ + Meta: p.Meta, + IncludeContext: p.IncludeContext, + MaxTokens: p.MaxTokens, + Messages: msgs, + Metadata: p.Metadata, + ModelPreferences: p.ModelPreferences, + StopSequences: p.StopSequences, + SystemPrompt: p.SystemPrompt, + Temperature: p.Temperature, + } + result, err := cs.client.opts.CreateMessageWithToolsHandler(ctx, &CreateMessageWithToolsRequest{Session: cs, Params: wtp}) + if err != nil { + return nil, err + } + var content Content + if len(result.Content) > 0 { + content = result.Content[0] + } + return &CreateMessageResult{ + Meta: result.Meta, + Content: content, + Model: result.Model, + Role: result.Role, + StopReason: result.StopReason, + }, nil + } + return nil, fmt.Errorf("client does not support CreateMessage") +} diff --git a/mcp/mrtr_test.go b/mcp/mrtr_test.go new file mode 100644 index 00000000..64d10f3a --- /dev/null +++ b/mcp/mrtr_test.go @@ -0,0 +1,185 @@ +// Copyright 2025 The Go MCP SDK Authors. All rights reserved. +// Use of this source code is governed by the license +// that can be found in the LICENSE file. + +package mcp + +import ( + "context" + "slices" + "testing" + + "github.com/google/go-cmp/cmp" +) + +type deployResult struct { + Deployed bool `json:"deployed"` + Reason string `json:"reason,omitempty"` +} + +func TestMRTR_ManualRetry(t *testing.T) { + orig := supportedProtocolVersions + supportedProtocolVersions = append(slices.Clone(orig), protocolVersion20260630) + t.Cleanup(func() { supportedProtocolVersions = orig }) + + ctx := context.Background() + + srv := NewServer(testImpl, nil) + AddTool(srv, &Tool{Name: "deploy"}, func(ctx context.Context, req *CallToolRequest, input struct{}) (*CallToolResult, *deployResult, error) { + if len(req.Params.InputResponses) == 0 { + return &CallToolResult{ + InputRequests: InputRequestMap{"confirm": &ElicitParams{Message: "Deploy to production?"}}, + RequestState: "deployment-123", + }, nil, nil + } + + resp, ok := req.Params.InputResponses["confirm"] + if !ok { + return &CallToolResult{ + InputRequests: InputRequestMap{"confirm": &ElicitParams{Message: "Please confirm (retry)"}}, + }, nil, nil + } + + if req.Params.RequestState == "" { + return &CallToolResult{}, &deployResult{Deployed: false, Reason: "no_state"}, nil + } + if elicitResult := resp.(*ElicitResult); elicitResult != nil && elicitResult.Action != "accept" { + return &CallToolResult{}, &deployResult{Deployed: false, Reason: "cancelled"}, nil + } + + return &CallToolResult{}, &deployResult{Deployed: true}, nil + }) + + conn := mustConnectMRTR(t, srv, &ClientOptions{ + MRTR: &MRTROptions{Disabled: true}, + }) + + // Round 1: initiate deployment + res, err := conn.CallTool(ctx, &CallToolParams{Name: "deploy"}) + if err != nil { + t.Fatalf("CallTool() error = %v", err) + } + if !res.NeedsInput() { + t.Fatal("NeedsInput() = false, want true") + } + if got := len(res.InputRequests); got != 1 { + t.Fatalf("len(res.InputRequests) = %d, want 1", got) + } + if _, ok := res.InputRequests["confirm"].(*ElicitParams); !ok { + t.Fatalf("res.InputRequests[confirm] type = %T, want *ElicitParams", res.InputRequests["confirm"]) + } + + // Round 2: retry with confirmation + res, err = conn.CallTool(ctx, &CallToolParams{ + Name: "deploy", + InputResponses: InputResponseMap{ + "confirm": &ElicitResult{Action: "accept", Content: map[string]any{"ok": true}}, + }, + RequestState: res.RequestState, + }) + if err != nil { + t.Fatalf("CallTool() follow-up error = %v", err) + } + if res.NeedsInput() { + t.Fatal("NeedsInput() = true after follow-up, want false") + } + + if diff := cmp.Diff(map[string]any{"deployed": true}, res.StructuredContent, ctrCmpOpts...); diff != "" { + t.Errorf("result mismatch (-want +got):\n%s", diff) + } +} + +func TestMRTR_AutoRetry(t *testing.T) { + orig := supportedProtocolVersions + supportedProtocolVersions = append(slices.Clone(orig), protocolVersion20260630) + t.Cleanup(func() { supportedProtocolVersions = orig }) + + ctx := context.Background() + + srv := NewServer(testImpl, nil) + AddTool(srv, &Tool{Name: "deploy"}, func(ctx context.Context, req *CallToolRequest, input struct{}) (*CallToolResult, *deployResult, error) { + if len(req.Params.InputResponses) == 0 { + return &CallToolResult{ + InputRequests: InputRequestMap{"confirm": &ElicitParams{Message: "Deploy to production?"}}, + RequestState: "deployment-123", + }, nil, nil + } + resp, ok := req.Params.InputResponses["confirm"] + if !ok { + return nil, nil, nil + } + elicitResult := resp.(*ElicitResult) + if elicitResult == nil || elicitResult.Action != "accept" { + return &CallToolResult{}, &deployResult{Deployed: false, Reason: "cancelled"}, nil + } + return &CallToolResult{}, &deployResult{Deployed: true}, nil + }) + + conn := mustConnectMRTR(t, srv, &ClientOptions{ + ElicitationHandler: func(_ context.Context, req *ElicitRequest) (*ElicitResult, error) { + return &ElicitResult{Action: "accept", Content: map[string]any{"ok": true}}, nil + }, + }) + + res, err := conn.CallTool(ctx, &CallToolParams{Name: "deploy"}) + if err != nil { + t.Fatalf("CallTool() error = %v", err) + } + if res.NeedsInput() { + t.Fatal("NeedsInput() = true after auto-retry, want false") + } + + if diff := cmp.Diff(map[string]any{"deployed": true}, res.StructuredContent, ctrCmpOpts...); diff != "" { + t.Errorf("result mismatch (-want +got):\n%s", diff) + } +} + +func TestMRTR_MaxRetries(t *testing.T) { + orig := supportedProtocolVersions + supportedProtocolVersions = append(slices.Clone(orig), protocolVersion20260630) + t.Cleanup(func() { supportedProtocolVersions = orig }) + + ctx := context.Background() + + srv := NewServer(testImpl, nil) + AddTool(srv, &Tool{Name: "loop"}, func(ctx context.Context, req *CallToolRequest, input struct{}) (*CallToolResult, any, error) { + return &CallToolResult{ + InputRequests: InputRequestMap{"confirm": &ElicitParams{Message: "Again?"}}, + RequestState: "loop-state", + }, nil, nil + }) + + conn := mustConnectMRTR(t, srv, &ClientOptions{ + ElicitationHandler: func(_ context.Context, req *ElicitRequest) (*ElicitResult, error) { + return &ElicitResult{Action: "accept"}, nil + }, + MRTR: &MRTROptions{MaxRetries: 2}, + }) + + _, err := conn.CallTool(ctx, &CallToolParams{Name: "loop"}) + if err == nil { + t.Fatal("CallTool() err = nil, want error for exceeded max retries") + } +} + +func mustConnectMRTR(t *testing.T, s *Server, clientOpts *ClientOptions) *ClientSession { + t.Helper() + st, ct := NewInMemoryTransports() + ss, err := s.Connect(t.Context(), st, nil) + if err != nil { + t.Fatalf("server.Connect() error = %v", err) + } + t.Cleanup(func() { + _ = ss.Close() + }) + + c := NewClient(testImpl, clientOpts) + cs, err := c.Connect(t.Context(), ct, &ClientSessionOptions{protocolVersion: protocolVersion20260630}) + if err != nil { + t.Fatalf("client.Connect() error = %v", err) + } + t.Cleanup(func() { + _ = cs.Close() + }) + return cs +} diff --git a/mcp/protocol.go b/mcp/protocol.go index 1646788a..04033b4d 100644 --- a/mcp/protocol.go +++ b/mcp/protocol.go @@ -13,6 +13,172 @@ import ( "github.com/modelcontextprotocol/go-sdk/internal/mcpgodebug" ) +// ResultType indicates whether a result is complete or requires further input +// from the client via the MRTR (Multi Round-Trip Requests) protocol. +type ResultType string + +const ( + // ResultTypeComplete indicates the result is final. + // This is the default when ResultType is empty. + ResultTypeComplete ResultType = "complete" + + // ResultTypeInputRequired indicates the server needs additional client + // input before it can complete the request. The client should fulfill the + // InputRequests and retry the call with the responses. + ResultTypeInputRequired ResultType = "input_required" +) + +// InputRequest is a sealed interface for parameters that a server can include +// in an MRTR input-required result. Implementations are [*ElicitParams], +// [*CreateMessageParams], and [*ListRootsParams]. +type InputRequest interface{ isInputRequest() } + +// InputRequestMap maps server-assigned request IDs to [InputRequest] values. +// It is used in result types to tell the client what input the server needs. +type InputRequestMap map[string]InputRequest + +func (m InputRequestMap) MarshalJSON() ([]byte, error) { + type wire struct { + Method string `json:"method"` + Params InputRequest `json:"params,omitempty"` + } + typeToMethod := func(v InputRequest) (string, error) { + switch v.(type) { + case *ElicitParams: + return methodElicit, nil + case *CreateMessageParams: + return methodCreateMessage, nil + case *ListRootsParams: + return methodListRoots, nil + default: + return "", fmt.Errorf("unsupported type: %T", v) + } + } + converted := map[string]*wire{} + for k, v := range m { + method, err := typeToMethod(v) + if err != nil { + return nil, err + } + converted[k] = &wire{Method: method, Params: v} + } + return json.Marshal(converted) +} + +func (m *InputRequestMap) UnmarshalJSON(data []byte) error { + type raw struct { + Method string `json:"method"` + Params json.RawMessage `json:"params"` + } + var rawMap map[string]*raw + if err := json.Unmarshal(data, &rawMap); err != nil { + return err + } + result := make(InputRequestMap, len(rawMap)) + for k, raw := range rawMap { + switch raw.Method { + case methodElicit: + var p ElicitParams + if err := json.Unmarshal(raw.Params, &p); err != nil { + return err + } + result[k] = &p + case methodCreateMessage: + var p CreateMessageParams + if err := json.Unmarshal(raw.Params, &p); err != nil { + return err + } + result[k] = &p + case methodListRoots: + var p ListRootsParams + if err := json.Unmarshal(raw.Params, &p); err != nil { + return err + } + result[k] = &p + default: + return fmt.Errorf("unsupported InputRequest method: %q", raw.Method) + } + } + *m = result + return nil +} + +// InputResponse is a sealed interface for results that a client sends back +// when fulfilling an MRTR input request. Implementations are [*ElicitResult], +// [*CreateMessageResult], and [*ListRootsResult]. +type InputResponse interface{ isInputResult() } + +// InputResponseMap maps request IDs (from [InputRequestMap]) to [InputResponse] +// values. It is used in params types when retrying a call after an +// input-required result. +type InputResponseMap map[string]InputResponse + +func (m InputResponseMap) MarshalJSON() ([]byte, error) { + type wire struct { + Method string `json:"method"` + Result InputResponse `json:"result,omitempty"` + } + typeToMethod := func(v InputResponse) (string, error) { + switch v.(type) { + case *ElicitResult: + return methodElicit, nil + case *CreateMessageResult: + return methodCreateMessage, nil + case *ListRootsResult: + return methodListRoots, nil + default: + return "", fmt.Errorf("unsupported type: %T", v) + } + } + converted := map[string]*wire{} + for k, v := range m { + method, err := typeToMethod(v) + if err != nil { + return nil, err + } + converted[k] = &wire{Method: method, Result: v} + } + return json.Marshal(converted) +} + +func (m *InputResponseMap) UnmarshalJSON(data []byte) error { + type raw struct { + Method string `json:"method"` + Result json.RawMessage `json:"result"` + } + var rawMap map[string]*raw + if err := json.Unmarshal(data, &rawMap); err != nil { + return err + } + result := make(InputResponseMap, len(rawMap)) + for k, raw := range rawMap { + switch raw.Method { + case methodElicit: + var p ElicitResult + if err := json.Unmarshal(raw.Result, &p); err != nil { + return err + } + result[k] = &p + case methodCreateMessage: + var p CreateMessageResult + if err := json.Unmarshal(raw.Result, &p); err != nil { + return err + } + result[k] = &p + case methodListRoots: + var p ListRootsResult + if err := json.Unmarshal(raw.Result, &p); err != nil { + return err + } + result[k] = &p + default: + return fmt.Errorf("unsupported InputRequest method: %q", raw.Method) + } + } + *m = result + return nil +} + // Optional annotations for the client. The client can use annotations to inform // how objects are used or displayed. type Annotations struct { @@ -46,6 +212,14 @@ type CallToolParams struct { // Arguments holds the tool arguments. It can hold any value that can be // marshaled to JSON. Arguments any `json:"arguments,omitempty"` + + // InputResponses maps input request IDs to responses, provided when + // retrying a call after receiving a result with ResultType + // ResultTypeInputRequired. + InputResponses InputResponseMap `json:"inputResponses,omitempty"` + // RequestState is the opaque state from the previous input-required result. + // The client must echo this back when retrying. + RequestState string `json:"requestState,omitempty"` } // CallToolParamsRaw is passed to tool handlers on the server. Its arguments @@ -61,6 +235,14 @@ type CallToolParamsRaw struct { // is the responsibility of the tool handler to unmarshal and validate the // Arguments (see [AddTool]). Arguments json.RawMessage `json:"arguments,omitempty"` + + // InputResponses maps input request IDs to responses, provided when + // retrying a call after receiving a result with ResultType + // ResultTypeInputRequired. + InputResponses InputResponseMap `json:"inputResponses,omitempty"` + // RequestState is the opaque state from the previous input-required result. + // The client must echo this back when retrying. + RequestState string `json:"requestState,omitempty"` } // A CallToolResult is the server's response to a tool call. @@ -107,6 +289,24 @@ type CallToolResult struct { // the Content field. IsError bool `json:"isError,omitempty"` + // InputRequests is a map of server-assigned IDs to input requests. + // Populated only when ResultType is ResultTypeInputRequired. + // The client must fulfill these and echo the IDs back in InputResponses + // when retrying the call. + InputRequests InputRequestMap `json:"inputRequests,omitempty"` + + // RequestState is an opaque string the client must echo back when + // retrying after an input-required result. Servers use this to carry + // context between independent requests. + // + // Unauthenticated servers must encrypt, sign and verify this value. + RequestState string `json:"requestState,omitempty"` + + // ResultType indicates whether this result is complete or requires further + // client input. Empty or ResultTypeComplete means the call succeeded + // normally. ResultTypeInputRequired means the client should fulfill the + // InputRequests and retry the call. + resultType ResultType // The error passed to setError, if any. // It is not marshaled, and therefore it is only visible on the server. // Its only use is in server sending middleware, where it can be accessed @@ -145,13 +345,35 @@ func (r *CallToolResult) GetError() error { func (*CallToolResult) isResult() {} -// UnmarshalJSON handles the unmarshalling of content into the Content -// interface. +func (r *CallToolResult) setResultType(rt ResultType) { r.resultType = rt } +func (r *CallToolResult) inputRequests() map[string]InputRequest { return r.InputRequests } +func (r *CallToolResult) setInputRequest(k string, v InputRequest) { r.InputRequests[k] = v } +func (r *CallToolResult) hasContent() bool { + return len(r.Content) > 0 || r.StructuredContent != nil +} + +// NeedsInput reports whether this result requires further client input. +// This is true when the server returned ResultType "input_required". +// When NeedsInput returns true, check InputRequests for the set of +// requests the server needs fulfilled before retrying the call. +// An empty InputRequests with NeedsInput true indicates load-shedding. +func (r *CallToolResult) NeedsInput() bool { return r.resultType == ResultTypeInputRequired } + +func (x *CallToolResult) MarshalJSON() ([]byte, error) { + type res CallToolResult // avoid recursion + type wire struct { + res + ResultType ResultType `json:"resultType,omitempty"` + } + return json.Marshal(wire{res: res(*x), ResultType: x.resultType}) +} + func (x *CallToolResult) UnmarshalJSON(data []byte) error { type res CallToolResult // avoid recursion var wire struct { res - Content []*wireContent `json:"content"` + Content []*wireContent `json:"content"` + ResultType ResultType `json:"resultType"` } if err := internaljson.Unmarshal(data, &wire); err != nil { return err @@ -160,6 +382,7 @@ func (x *CallToolResult) UnmarshalJSON(data []byte) error { if wire.res.Content, err = contentsFromWire(wire.Content, nil); err != nil { return err } + wire.res.resultType = wire.ResultType *x = CallToolResult(wire.res) return nil } @@ -422,6 +645,7 @@ type CreateMessageParams struct { } func (x *CreateMessageParams) isParams() {} +func (x *CreateMessageParams) isInputRequest() {} func (x *CreateMessageParams) GetProgressToken() any { return getProgressToken(x) } func (x *CreateMessageParams) SetProgressToken(t any) { setProgressToken(x, t) } @@ -547,7 +771,8 @@ type CreateMessageResult struct { StopReason string `json:"stopReason,omitempty"` } -func (*CreateMessageResult) isResult() {} +func (*CreateMessageResult) isResult() {} +func (*CreateMessageResult) isInputResult() {} func (r *CreateMessageResult) UnmarshalJSON(data []byte) error { type result CreateMessageResult // avoid recursion var wire struct { @@ -651,6 +876,13 @@ type GetPromptParams struct { Arguments map[string]string `json:"arguments,omitempty"` // The name of the prompt or prompt template. Name string `json:"name"` + + // InputResponses maps input request IDs to responses, provided when + // retrying a call after receiving a result with ResultType + // ResultTypeInputRequired. + InputResponses InputResponseMap `json:"inputResponses,omitempty"` + // RequestState is the opaque state from the previous input-required result. + RequestState string `json:"requestState,omitempty"` } func (x *GetPromptParams) isParams() {} @@ -665,10 +897,53 @@ type GetPromptResult struct { // An optional description for the prompt. Description string `json:"description,omitempty"` Messages []*PromptMessage `json:"messages"` + + // InputRequests is populated when ResultType is ResultTypeInputRequired. + // See [CallToolResult.InputRequests]. + InputRequests InputRequestMap `json:"inputRequests,omitempty"` + // RequestState is the opaque state for MRTR retries. + // See [CallToolResult.RequestState]. + RequestState string `json:"requestState,omitempty"` + + // ResultType indicates whether this result is complete or requires further + // client input. See [CallToolResult.ResultType] for details. + resultType ResultType } func (*GetPromptResult) isResult() {} +func (r *GetPromptResult) setResultType(rt ResultType) { r.resultType = rt } +func (r *GetPromptResult) inputRequests() map[string]InputRequest { return r.InputRequests } +func (r *GetPromptResult) setInputRequest(k string, v InputRequest) { r.inputRequests()[k] = v } +func (r *GetPromptResult) hasContent() bool { return len(r.Messages) > 0 } + +// NeedsInput reports whether this result requires further client input. +// See [CallToolResult.NeedsInput] for details. +func (r *GetPromptResult) NeedsInput() bool { return r.resultType == ResultTypeInputRequired } + +func (x *GetPromptResult) MarshalJSON() ([]byte, error) { + type res GetPromptResult + type wire struct { + res + ResultType ResultType `json:"resultType,omitempty"` + } + return json.Marshal(wire{res: res(*x), ResultType: x.resultType}) +} + +func (x *GetPromptResult) UnmarshalJSON(data []byte) error { + type res GetPromptResult + var wire struct { + res + ResultType ResultType `json:"resultType"` + } + if err := internaljson.Unmarshal(data, &wire); err != nil { + return err + } + wire.res.resultType = wire.ResultType + *x = GetPromptResult(wire.res) + return nil +} + // InitializeParams is sent by the client to initialize the session. type InitializeParams struct { // This property is reserved by the protocol to allow clients and servers to @@ -833,6 +1108,7 @@ type ListRootsParams struct { } func (x *ListRootsParams) isParams() {} +func (x *ListRootsParams) isInputRequest() {} func (x *ListRootsParams) GetProgressToken() any { return getProgressToken(x) } func (x *ListRootsParams) SetProgressToken(t any) { setProgressToken(x, t) } @@ -846,7 +1122,8 @@ type ListRootsResult struct { Roots []*Root `json:"roots"` } -func (*ListRootsResult) isResult() {} +func (*ListRootsResult) isResult() {} +func (*ListRootsResult) isInputResult() {} type ListToolsParams struct { // This property is reserved by the protocol to allow clients and servers to @@ -1086,6 +1363,13 @@ type ReadResourceParams struct { // The URI of the resource to read. The URI can use any protocol; it is up to // the server how to interpret it. URI string `json:"uri"` + + // InputResponses maps input request IDs to responses, provided when + // retrying a call after receiving a result with ResultType + // ResultTypeInputRequired. + InputResponses InputResponseMap `json:"inputResponses,omitempty"` + // RequestState is the opaque state from the previous input-required result. + RequestState string `json:"requestState,omitempty"` } func (x *ReadResourceParams) isParams() {} @@ -1098,10 +1382,53 @@ type ReadResourceResult struct { // attach additional metadata to their responses. Meta `json:"_meta,omitempty"` Contents []*ResourceContents `json:"contents"` + + // InputRequests is populated when ResultType is ResultTypeInputRequired. + // See [CallToolResult.InputRequests]. + InputRequests InputRequestMap `json:"inputRequests,omitempty"` + // RequestState is the opaque state for MRTR retries. + // See [CallToolResult.RequestState]. + RequestState string `json:"requestState,omitempty"` + + // ResultType indicates whether this result is complete or requires further + // client input. See [CallToolResult.ResultType] for details. + resultType ResultType } func (*ReadResourceResult) isResult() {} +func (r *ReadResourceResult) setResultType(rt ResultType) { r.resultType = rt } +func (r *ReadResourceResult) inputRequests() map[string]InputRequest { return r.InputRequests } +func (r *ReadResourceResult) setInputRequest(k string, v InputRequest) { r.inputRequests()[k] = v } +func (r *ReadResourceResult) hasContent() bool { return len(r.Contents) > 0 } + +// NeedsInput reports whether this result requires further client input. +// See [CallToolResult.NeedsInput] for details. +func (r *ReadResourceResult) NeedsInput() bool { return r.resultType == ResultTypeInputRequired } + +func (x *ReadResourceResult) MarshalJSON() ([]byte, error) { + type res ReadResourceResult + type wire struct { + res + ResultType ResultType `json:"resultType,omitempty"` + } + return json.Marshal(wire{res: res(*x), ResultType: x.resultType}) +} + +func (x *ReadResourceResult) UnmarshalJSON(data []byte) error { + type res ReadResourceResult + var wire struct { + res + ResultType ResultType `json:"resultType"` + } + if err := internaljson.Unmarshal(data, &wire); err != nil { + return err + } + wire.res.resultType = wire.ResultType + *x = ReadResourceResult(wire.res) + return nil +} + // A known resource that the server is capable of reading. type Resource struct { // See [specification/2025-06-18/basic/index#general-fields] for notes on _meta @@ -1468,7 +1795,8 @@ type ElicitParams struct { ElicitationID string `json:"elicitationId,omitempty"` } -func (x *ElicitParams) isParams() {} +func (x *ElicitParams) isParams() {} +func (x *ElicitParams) isInputRequest() {} func (x *ElicitParams) GetProgressToken() any { return getProgressToken(x) } func (x *ElicitParams) SetProgressToken(t any) { setProgressToken(x, t) } @@ -1488,7 +1816,8 @@ type ElicitResult struct { Content map[string]any `json:"content,omitempty"` } -func (*ElicitResult) isResult() {} +func (*ElicitResult) isResult() {} +func (*ElicitResult) isInputResult() {} // ElicitationCompleteParams is sent from the server to the client, informing it that an out-of-band elicitation interaction has completed. type ElicitationCompleteParams struct { diff --git a/mcp/server.go b/mcp/server.go index 183226d1..ce056393 100644 --- a/mcp/server.go +++ b/mcp/server.go @@ -370,8 +370,12 @@ func toolForErr[In, Out any](t *Tool, h ToolHandlerFor[In, Out], cache *SchemaCa } // Marshal the output and put the RawMessage in the StructuredContent field. + // Skip when the handler returned input requests (MRTR): content and + // inputRequests are mutually exclusive on the wire. var outval any = out - if elemZero != nil { + if res.InputRequests != nil { + outval = nil + } else if elemZero != nil { // Avoid typed nil, which will serialize as JSON null. // Instead, use the zero value of the unpointered type. var z Out @@ -742,7 +746,13 @@ func (s *Server) getPrompt(ctx context.Context, req *GetPromptRequest) (*GetProm Message: fmt.Sprintf("unknown prompt %q", req.Params.Name), } } - return prompt.handler(ctx, req) + res, err := prompt.handler(ctx, req) + if err == nil && res != nil { + if err := handleMRTRResult(req.Session, s.opts.Logger, res); err != nil { + return nil, err + } + } + return res, err } func (s *Server) listTools(_ context.Context, req *ListToolsRequest) (*ListToolsResult, error) { @@ -775,10 +785,15 @@ func (s *Server) callTool(ctx context.Context, req *CallToolRequest) (*CallToolR } } res, err := st.handler(ctx, req) - if err == nil && res != nil && res.Content == nil { - res2 := *res - res2.Content = []Content{} // avoid "null" - res = &res2 + if err == nil && res != nil { + if err := handleMRTRResult(req.Session, s.opts.Logger, res); err != nil { + return nil, err + } + if res.Content == nil && res.resultType != ResultTypeInputRequired { + res2 := *res + res2.Content = []Content{} // avoid "null" + res = &res2 + } } return res, err } @@ -826,6 +841,12 @@ func (s *Server) readResource(ctx context.Context, req *ReadResourceRequest) (*R if err != nil { return nil, err } + if err := handleMRTRResult(req.Session, s.opts.Logger, res); err != nil { + return nil, err + } + if res.resultType == ResultTypeInputRequired { + return res, nil + } if res == nil || res.Contents == nil { return nil, fmt.Errorf("reading resource %s: read handler returned nil information", uri) } diff --git a/mcp/streamable_test.go b/mcp/streamable_test.go index d2e54224..7ca109a6 100644 --- a/mcp/streamable_test.go +++ b/mcp/streamable_test.go @@ -769,6 +769,12 @@ func req(id int64, method string, params any) *jsonrpc.Request { return r } +func completeCallToolResult() *CallToolResult { + r := &CallToolResult{Content: []Content{}} + r.resultType = ResultTypeComplete + return r +} + func resp(id int64, result any, err error) *jsonrpc.Response { return &jsonrpc.Response{ ID: jsonrpc2.Int64ID(id), @@ -1968,7 +1974,7 @@ func TestStreamableMcpHeaderValidation(t *testing.T) { }, messages: []jsonrpc.Message{req(2, "tools/call", &CallToolParams{Name: "my-tool"})}, wantStatusCode: http.StatusOK, - wantMessages: []jsonrpc.Message{resp(2, &CallToolResult{Content: []Content{}}, nil)}, + wantMessages: []jsonrpc.Message{resp(2, completeCallToolResult(), nil)}, }, { method: "POST", @@ -2012,7 +2018,7 @@ func TestStreamableMcpHeaderValidation(t *testing.T) { }, messages: []jsonrpc.Message{req(6, "tools/call", &CallToolParams{Name: "my-tool"})}, wantStatusCode: http.StatusOK, - wantMessages: []jsonrpc.Message{resp(6, &CallToolResult{Content: []Content{}}, nil)}, + wantMessages: []jsonrpc.Message{resp(6, completeCallToolResult(), nil)}, }, { method: "POST", @@ -2027,7 +2033,7 @@ func TestStreamableMcpHeaderValidation(t *testing.T) { Arguments: map[string]any{"region": "us-west1", "query": "SELECT 1"}, })}, wantStatusCode: http.StatusOK, - wantMessages: []jsonrpc.Message{resp(7, &CallToolResult{Content: []Content{}}, nil)}, + wantMessages: []jsonrpc.Message{resp(7, completeCallToolResult(), nil)}, }, { method: "POST", From 2f6ee7fa643be7e66ce622a9525f47be2221e7a5 Mon Sep 17 00:00:00 2001 From: Yaroslav Shevchuk Date: Tue, 26 May 2026 10:54:33 +0000 Subject: [PATCH 3/6] implement bridging with older protocol clients --- design/mrtr.md | 50 +++-------- mcp/client.go | 2 +- mcp/mrtr.go | 162 +++++++++++++++++++++++++++++---- mcp/mrtr_test.go | 229 +++++++++++++++++++++++++++++++++++++++-------- mcp/server.go | 4 +- 5 files changed, 350 insertions(+), 97 deletions(-) diff --git a/design/mrtr.md b/design/mrtr.md index 44bd810e..5974e3f3 100644 --- a/design/mrtr.md +++ b/design/mrtr.md @@ -80,12 +80,13 @@ mcp.AddTool(s, tool, func(ctx context.Context, req *mcp.CallToolRequest, in MyIn return &mcp.CallToolResult{Content: []mcp.Content{&mcp.TextContent{Text: "done"}}}, myOut, nil }) ``` -An incomplete result should be converted to an error response when communicating with an older client. This is a reasonable default for server implementers to not branch by protocol version in their code. +The SDK validates at runtime that a handler does not return both content and `InputRequests` — doing so logs a warning and returns a `CodeInternalError` JSON-RPC error. -An unexported middleware is installed in the client, which similarly to `urlElicitationMiddleware` will automatically invoke handlers for the corresponding methods on incomplete results and retry the original request. `ClientOptions` will be extended with configuration knobs: +An unexported receiving middleware is installed on the server for backward compatibility with older clients. When a handler returns `InputRequests` and the connected client uses a protocol version that does not support MRTR, the middleware fulfills the requests by calling `ServerSession.Elicit`/`CreateMessage`/`ListRoots` on the client directly and reinvokes the handler once with the collected `InputResponses`. If any of these calls fail, the entire request fails. Input requests are fulfilled concurrently. This lets server developers write protocol-version-independent code. + +An unexported sending middleware is installed on the client, which similarly to `urlElicitationMiddleware` will automatically invoke handlers for the corresponding methods on incomplete results and retry the original request. `ClientOptions` is extended with configuration knobs: ```go type MRTROptions struct { - ... MaxRetries int Disabled bool } @@ -94,24 +95,25 @@ client := mcp.NewClient(impl, &mcp.ClientOptions{ }) ``` -Alternatively, clients have an option to disable it and write a retry loop manually: +Alternatively, clients have an option to disable it and write a retry loop manually using `NeedsInput()`: ```go client := mcp.NewClient(impl, &mcp.ClientOptions{MRTR: &mcp.MRTROptions{Disabled: true}}) result, err := client.CallTool(ctx, &mcp.CallToolParams{Name: "my-tool"}) -if result.InputRequests != nil { ... } +if result.NeedsInput() { ... } ``` -This is error-prone because according to the SEP an empty map can be returned for load-shedding purposes. `NeedsInput` function can be provided, but nothing would force users to use it. +`NeedsInput()` checks the unexported `resultType` field rather than `InputRequests`, correctly handling the load-shedding case where the server returns `input_required` with an empty map. **Pros** This is arguably the simplest and the most transparent approach which is also closest to the spec. What gets explicitly set on the server can be observed on the wire and on the client. -The opt-out middleware follows the principle of the least surprise for app developers. If client method handlers were provided they will continue to be invoked regardless of the protocol version in use. The `Disabled` option lets "power-users" build any custom handling logic. +The opt-out client middleware follows the principle of the least surprise for app developers. If client method handlers were provided they will continue to be invoked regardless of the protocol version in use. The `Disabled` option lets "power-users" build any custom handling logic. +The server middleware makes handler code protocol-version-independent — the same handler works for both old and new clients. **Cons** -The biggest downside of the proposal is that server developers can construct incorrect responses and this will only be validated at runtime. +The biggest downside of the proposal is that server developers can construct incorrect responses (both content and input requests) and this will only be validated at runtime. ## Alternatives considered @@ -263,35 +265,7 @@ This would change semantics of `*Handler` fields - depending on the protocol ver --- -### Protocol version bridging - -When a handler returns an input-required result to an old client the SDK could bridge by invoking `ServerSession.Elicit`/`CreateMessage`/`ListRoots` on the `ServerSession` and re-invoking the handler with the collected `inputResponses`: -```go -func mrtrBridgeMiddleware(next mcp.MethodHandler) mcp.MethodHandler { - return func(ctx context.Context, method string, req mcp.Request) (mcp.Result, error) { - res, err := next(ctx, method, req) - if err != nil || !needsInput(res) || isNewProtocol(req) { - return res, err - } - // Fulfill input requests via server-initiated calls, inject inputResponses into the request, - // re-invoke the handler. - } -} -``` - -Or even more radically: instead of making `ServerSession.Elicit`, `CreateMessage`, and `ListRoots` return errors for newer clients, we could convert these invocations into MRTR wire format transparently: suspend the handler, return `input_required`, and resume when the client retries with `inputResponses`: -```go -mcp.AddTool(s, tool, func(ctx context.Context, req *mcp.CallToolRequest, in MyIn) (*mcp.CallToolResult, MyOut, error) { - result, err := req.Session.Elicit(ctx, &mcp.ElicitParams{Message: "Deploy to production?"}) - if err != nil { - return nil, zero, err - } - if result.Action != "accept" { - return &mcp.CallToolResult{Content: []mcp.Content{&mcp.TextContent{Text: "cancelled"}}}, zero, nil - } - return &mcp.CallToolResult{Content: []mcp.Content{&mcp.TextContent{Text: "deployed"}}}, myOut, nil -}) -``` +### Server API protocol version bridging -Such bridging would contradict the design goal of MRTR — servers shouldn't hold resources between round trips, and it should be possible for a retry to arrive on any server instance in a multi-server deployment. +Converting `ServerSession.Elicit`/`CreateMessage`/`ListRoots` calls into MRTR wire format transparently (suspend the handler, return `input_required`, resume on retry). Rejected because of a significant implementation effort and the fact that it contradicts the design goal of MRTR where servers shouldn't hold resources between round trips, and it should be possible for a retry to arrive on any server instance in a multi-server deployment. diff --git a/mcp/client.go b/mcp/client.go index 8bea349a..381be027 100644 --- a/mcp/client.go +++ b/mcp/client.go @@ -66,7 +66,7 @@ func NewClient(impl *Implementation, options *ClientOptions) *Client { receivingMethodHandler_: defaultReceivingMethodHandler[*ClientSession], } if opts.MRTR == nil || !opts.MRTR.Disabled { - c.AddSendingMiddleware(mrtrMiddleware(c)) + c.AddSendingMiddleware(clientMRTRMiddleware(c)) } return c } diff --git a/mcp/mrtr.go b/mcp/mrtr.go index c942237f..e449790b 100644 --- a/mcp/mrtr.go +++ b/mcp/mrtr.go @@ -8,6 +8,7 @@ import ( "context" "fmt" "log/slog" + "sync" "github.com/modelcontextprotocol/go-sdk/jsonrpc" ) @@ -39,33 +40,36 @@ type mrtrResult interface { func handleMRTRResult(ss *ServerSession, logger *slog.Logger, res mrtrResult) error { hasInputRequests := res.inputRequests() != nil - protocolVersion := latestProtocolVersion - if iparams := ss.InitializeParams(); iparams != nil { - protocolVersion = iparams.ProtocolVersion - } - supportsMRTR := protocolVersion >= protocolVersion20260630 - - switch { - case hasInputRequests && !supportsMRTR: - return fmt.Errorf("protocol version %q does not support input requests (< %q)", protocolVersion, protocolVersion20260630) - - case hasInputRequests && res.hasContent(): + if hasInputRequests && res.hasContent() { logger.Warn("handler returned both content and inputRequests; inputRequests takes precedence") return &jsonrpc.Error{ Code: jsonrpc.CodeInternalError, Message: "server bug: result has both content and inputRequests", } + } - case hasInputRequests: - res.setResultType(ResultTypeInputRequired) + supportsMRTR := sessionSupportsMRTR(ss) + switch { + case hasInputRequests && supportsMRTR: + res.setResultType(ResultTypeInputRequired) case supportsMRTR: res.setResultType(ResultTypeComplete) } + // For older clients the resultType is left unset. The serverMRTRMiddleware fulfills the + // requests by calling the client directly and retries the handler. return nil } -func mrtrMiddleware(c *Client) Middleware { +func sessionSupportsMRTR(ss *ServerSession) bool { + protocolVersion := latestProtocolVersion + if iparams := ss.InitializeParams(); iparams != nil { + protocolVersion = iparams.ProtocolVersion + } + return protocolVersion >= protocolVersion20260630 +} + +func clientMRTRMiddleware(c *Client) Middleware { return func(next MethodHandler) MethodHandler { return func(ctx context.Context, method string, req Request) (Result, error) { if method != methodCallTool && method != methodGetPrompt && method != methodReadResource { @@ -103,6 +107,108 @@ func mrtrMiddleware(c *Client) Middleware { } } +// serverMRTRMiddleware is a receiving middleware for servers that transparently +// handles MRTR for clients on older protocol versions. When a handler returns +// InputRequests and the client does not support MRTR, the middleware fulfills +// the requests by calling the client reinvokes the handler once with the responses. +func serverMRTRMiddleware() Middleware { + return func(next MethodHandler) MethodHandler { + return func(ctx context.Context, method string, req Request) (Result, error) { + if method != methodCallTool && method != methodGetPrompt && method != methodReadResource { + return next(ctx, method, req) + } + ss, ok := req.GetSession().(*ServerSession) + if !ok { + return next(ctx, method, req) + } + if sessionSupportsMRTR(ss) { + return next(ctx, method, req) + } + + res, err := next(ctx, method, req) + if err != nil { + return res, err + } + irm := serverMRTRInputRequests(res) + if len(irm) == 0 { + return res, nil + } + responses, err := fulfillServerInputRequests(ctx, ss, irm) + if err != nil { + return nil, err + } + setMRTRRetryParams(req, responses, mrtrRequestState(res)) + return next(ctx, method, req) + } + } +} + +// serverMRTRInputRequests returns input requests from a result for old clients +// where resultType is not set. It checks InputRequests directly. +func serverMRTRInputRequests(res Result) InputRequestMap { + if res == nil { + return nil + } + switch r := res.(type) { + case *CallToolResult: + if r == nil { + return nil + } + return r.InputRequests + case *GetPromptResult: + if r == nil { + return nil + } + return r.InputRequests + case *ReadResourceResult: + if r == nil { + return nil + } + return r.InputRequests + } + return nil +} + +func fulfillServerInputRequests(ctx context.Context, ss *ServerSession, requests InputRequestMap) (InputResponseMap, error) { + type result struct { + id string + resp InputResponse + err error + } + results := make(chan result, len(requests)) + var wg sync.WaitGroup + for id, ir := range requests { + wg.Go(func() { + resp, err := fulfillServerInputRequest(ctx, ss, ir) + results <- result{id, resp, err} + }) + } + wg.Wait() + close(results) + + responses := make(InputResponseMap, len(results)) + for r := range results { + if r.err != nil { + return nil, fmt.Errorf("MRTR: fulfilling input request %q: %w", r.id, r.err) + } + responses[r.id] = r.resp + } + return responses, nil +} + +func fulfillServerInputRequest(ctx context.Context, ss *ServerSession, ir InputRequest) (InputResponse, error) { + switch p := ir.(type) { + case *ElicitParams: + return ss.Elicit(ctx, p) + case *CreateMessageParams: + return ss.CreateMessage(ctx, p) + case *ListRootsParams: + return ss.ListRoots(ctx, p) + default: + return nil, fmt.Errorf("unknown input request type: %T", ir) + } +} + func mrtrInputRequests(res Result) InputRequestMap { switch r := res.(type) { case *CallToolResult: @@ -138,6 +244,9 @@ func setMRTRRetryParams(req Request, responses InputResponseMap, state string) { case *CallToolParams: p.InputResponses = responses p.RequestState = state + case *CallToolParamsRaw: + p.InputResponses = responses + p.RequestState = state case *GetPromptParams: p.InputResponses = responses p.RequestState = state @@ -148,13 +257,28 @@ func setMRTRRetryParams(req Request, responses InputResponseMap, state string) { } func fulfillInputRequests(ctx context.Context, cs *ClientSession, requests InputRequestMap) (InputResponseMap, error) { - responses := make(InputResponseMap) + type result struct { + id string + resp InputResponse + err error + } + results := make(chan result, len(requests)) + var wg sync.WaitGroup for id, ir := range requests { - resp, err := fulfillInputRequest(ctx, cs, ir) - if err != nil { - return nil, fmt.Errorf("MRTR: fulfilling input request %q: %w", id, err) + wg.Go(func() { + resp, err := fulfillInputRequest(ctx, cs, ir) + results <- result{id, resp, err} + }) + } + wg.Wait() + close(results) + + responses := make(InputResponseMap, len(results)) + for r := range results { + if r.err != nil { + return nil, fmt.Errorf("MRTR: fulfilling input request %q: %w", r.id, r.err) } - responses[id] = resp + responses[r.id] = r.resp } return responses, nil } diff --git a/mcp/mrtr_test.go b/mcp/mrtr_test.go index 64d10f3a..23980261 100644 --- a/mcp/mrtr_test.go +++ b/mcp/mrtr_test.go @@ -6,18 +6,20 @@ package mcp import ( "context" + "fmt" "slices" "testing" "github.com/google/go-cmp/cmp" + "github.com/google/jsonschema-go/jsonschema" ) -type deployResult struct { - Deployed bool `json:"deployed"` - Reason string `json:"reason,omitempty"` -} - func TestMRTR_ManualRetry(t *testing.T) { + type deployResult struct { + Deployed bool `json:"deployed"` + Reason string `json:"reason,omitempty"` + } + orig := supportedProtocolVersions supportedProtocolVersions = append(slices.Clone(orig), protocolVersion20260630) t.Cleanup(func() { supportedProtocolVersions = orig }) @@ -94,43 +96,107 @@ func TestMRTR_AutoRetry(t *testing.T) { supportedProtocolVersions = append(slices.Clone(orig), protocolVersion20260630) t.Cleanup(func() { supportedProtocolVersions = orig }) - ctx := context.Background() + tests := []struct { + name string + inputRequests InputRequestMap + wantResult map[string]any + }{ + { + name: "elicit", + inputRequests: InputRequestMap{ + "confirm": &ElicitParams{Message: "Deploy?"}, + }, + wantResult: map[string]any{"ids": []any{"confirm"}}, + }, + { + name: "createMessage", + inputRequests: InputRequestMap{ + "summarize": &CreateMessageParams{ + Messages: []*SamplingMessage{{Role: "user", Content: &TextContent{Text: "summarize"}}}, + MaxTokens: 100, + }, + }, + wantResult: map[string]any{"ids": []any{"summarize"}}, + }, + { + name: "listRoots", + inputRequests: InputRequestMap{ + "roots": &ListRootsParams{}, + }, + wantResult: map[string]any{"ids": []any{"roots"}}, + }, + { + name: "all three", + inputRequests: InputRequestMap{ + "confirm": &ElicitParams{Message: "OK?"}, + "draft": &CreateMessageParams{ + Messages: []*SamplingMessage{{Role: "user", Content: &TextContent{Text: "write"}}}, + MaxTokens: 50, + }, + "roots": &ListRootsParams{}, + }, + wantResult: map[string]any{"ids": []any{"confirm", "draft", "roots"}}, + }, + } - srv := NewServer(testImpl, nil) - AddTool(srv, &Tool{Name: "deploy"}, func(ctx context.Context, req *CallToolRequest, input struct{}) (*CallToolResult, *deployResult, error) { - if len(req.Params.InputResponses) == 0 { - return &CallToolResult{ - InputRequests: InputRequestMap{"confirm": &ElicitParams{Message: "Deploy to production?"}}, - RequestState: "deployment-123", - }, nil, nil - } - resp, ok := req.Params.InputResponses["confirm"] - if !ok { - return nil, nil, nil - } - elicitResult := resp.(*ElicitResult) - if elicitResult == nil || elicitResult.Action != "accept" { - return &CallToolResult{}, &deployResult{Deployed: false, Reason: "cancelled"}, nil - } - return &CallToolResult{}, &deployResult{Deployed: true}, nil - }) + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ctx := context.Background() - conn := mustConnectMRTR(t, srv, &ClientOptions{ - ElicitationHandler: func(_ context.Context, req *ElicitRequest) (*ElicitResult, error) { - return &ElicitResult{Action: "accept", Content: map[string]any{"ok": true}}, nil - }, - }) + srv := NewServer(testImpl, nil) + inputRequests := tt.inputRequests + AddTool(srv, &Tool{Name: "act"}, func(ctx context.Context, req *CallToolRequest, input struct{}) (*CallToolResult, any, error) { + if len(req.Params.InputResponses) == 0 { + return &CallToolResult{ + InputRequests: inputRequests, + RequestState: "state-1", + }, nil, nil + } + // Collect the IDs of fulfilled responses. + var ids []string + for id := range req.Params.InputResponses { + ids = append(ids, id) + } + slices.Sort(ids) + return &CallToolResult{}, map[string]any{"ids": ids}, nil + }) - res, err := conn.CallTool(ctx, &CallToolParams{Name: "deploy"}) - if err != nil { - t.Fatalf("CallTool() error = %v", err) - } - if res.NeedsInput() { - t.Fatal("NeedsInput() = true after auto-retry, want false") - } + conn := mustConnectMRTR(t, srv, &ClientOptions{ + ElicitationHandler: func(_ context.Context, req *ElicitRequest) (*ElicitResult, error) { + return &ElicitResult{Action: "accept"}, nil + }, + CreateMessageHandler: func(_ context.Context, req *CreateMessageRequest) (*CreateMessageResult, error) { + return &CreateMessageResult{ + Model: "test-model", + Role: "assistant", + Content: &TextContent{Text: "response"}, + }, nil + }, + }) + conn.client.AddRoots(&Root{URI: "file:///workspace", Name: "workspace"}) - if diff := cmp.Diff(map[string]any{"deployed": true}, res.StructuredContent, ctrCmpOpts...); diff != "" { - t.Errorf("result mismatch (-want +got):\n%s", diff) + res, err := conn.CallTool(ctx, &CallToolParams{Name: "act"}) + if err != nil { + t.Fatalf("CallTool() error = %v", err) + } + if res.NeedsInput() { + t.Fatal("NeedsInput() = true after auto-retry, want false") + } + + // Sort the expected IDs for stable comparison. + if wantIDs, ok := tt.wantResult["ids"].([]any); ok { + slices.SortFunc(wantIDs, func(a, b any) int { + if a.(string) < b.(string) { + return -1 + } + return 1 + }) + } + + if diff := cmp.Diff(tt.wantResult, res.StructuredContent, ctrCmpOpts...); diff != "" { + t.Errorf("result mismatch (-want +got):\n%s", diff) + } + }) } } @@ -162,6 +228,93 @@ func TestMRTR_MaxRetries(t *testing.T) { } } +func TestMRTR_ServerMiddleware(t *testing.T) { + // mrtrToolHandler returns a ToolHandler (plain, non-generic) that requests + // the given inputRequests on the first call and returns the fulfilled + // response IDs on the second. + mrtrToolHandler := func(inputRequests InputRequestMap) ToolHandler { + return func(ctx context.Context, req *CallToolRequest) (*CallToolResult, error) { + if len(req.Params.InputResponses) == 0 { + return &CallToolResult{ + InputRequests: inputRequests, + RequestState: "state-1", + }, nil + } + var ids []string + for id := range req.Params.InputResponses { + ids = append(ids, id) + } + slices.Sort(ids) + content := &TextContent{Text: fmt.Sprintf("%v", ids)} + return &CallToolResult{Content: []Content{content}}, nil + } + } + + tests := []struct { + name string + inputRequests InputRequestMap + wantText string + }{ + { + name: "elicit via ToolHandler", + inputRequests: InputRequestMap{ + "confirm": &ElicitParams{Message: "Sure?"}, + }, + wantText: "[confirm]", + }, + { + name: "elicit and listRoots via ToolHandler", + inputRequests: InputRequestMap{ + "confirm": &ElicitParams{Message: "OK?"}, + "roots": &ListRootsParams{}, + }, + wantText: "[confirm roots]", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ctx := context.Background() + + srv := NewServer(testImpl, nil) + srv.AddTool( + &Tool{Name: "act", InputSchema: &jsonschema.Schema{Type: "object"}}, + mrtrToolHandler(tt.inputRequests), + ) + + // Connect with an OLD protocol version where MRTR is not supported. + // The server middleware should handle it transparently. + st, ct := NewInMemoryTransports() + ss, err := srv.Connect(t.Context(), st, nil) + if err != nil { + t.Fatalf("server.Connect() error = %v", err) + } + t.Cleanup(func() { _ = ss.Close() }) + + c := NewClient(testImpl, &ClientOptions{ + MRTR: &MRTROptions{Disabled: true}, + ElicitationHandler: func(_ context.Context, req *ElicitRequest) (*ElicitResult, error) { + return &ElicitResult{Action: "accept"}, nil + }, + }) + c.AddRoots(&Root{URI: "file:///workspace", Name: "workspace"}) + cs, err := c.Connect(t.Context(), ct, &ClientSessionOptions{}) + if err != nil { + t.Fatalf("client.Connect() error = %v", err) + } + t.Cleanup(func() { _ = cs.Close() }) + + res, err := cs.CallTool(ctx, &CallToolParams{Name: "act"}) + if err != nil { + t.Fatalf("CallTool() error = %v", err) + } + if got := res.Content[0].(*TextContent).Text; got != tt.wantText { + t.Errorf("result text = %q, want %q", got, tt.wantText) + } + }) + } +} + func mustConnectMRTR(t *testing.T, s *Server, clientOpts *ClientOptions) *ClientSession { t.Helper() st, ct := NewInMemoryTransports() diff --git a/mcp/server.go b/mcp/server.go index ce056393..c5c6cf3e 100644 --- a/mcp/server.go +++ b/mcp/server.go @@ -187,7 +187,7 @@ func NewServer(impl *Implementation, options *ServerOptions) *Server { opts.Logger = ensureLogger(nil) } - return &Server{ + s := &Server{ impl: impl, opts: opts, prompts: newFeatureSet(func(p *serverPrompt) string { return p.prompt.Name }), @@ -199,6 +199,8 @@ func NewServer(impl *Implementation, options *ServerOptions) *Server { resourceSubscriptions: make(map[string]map[*ServerSession]bool), pendingNotifications: make(map[string]*time.Timer), } + s.AddReceivingMiddleware(serverMRTRMiddleware()) + return s } // AddPrompt adds a [Prompt] to the server, or replaces one with the same name. From 4d082090212417bfc6d5a48699ea9d71c334323c Mon Sep 17 00:00:00 2001 From: Yaroslav Shevchuk Date: Tue, 26 May 2026 11:42:43 +0000 Subject: [PATCH 4/6] cleanup --- mcp/mrtr.go | 91 ++++++++++++++++++++++++------------------------- mcp/protocol.go | 28 +++++++-------- 2 files changed, 58 insertions(+), 61 deletions(-) diff --git a/mcp/mrtr.go b/mcp/mrtr.go index e449790b..18913700 100644 --- a/mcp/mrtr.go +++ b/mcp/mrtr.go @@ -33,7 +33,6 @@ type MRTROptions struct { type mrtrResult interface { setResultType(ResultType) inputRequests() map[string]InputRequest - setInputRequest(k string, v InputRequest) hasContent() bool } @@ -41,7 +40,7 @@ func handleMRTRResult(ss *ServerSession, logger *slog.Logger, res mrtrResult) er hasInputRequests := res.inputRequests() != nil if hasInputRequests && res.hasContent() { - logger.Warn("handler returned both content and inputRequests; inputRequests takes precedence") + logger.Warn("handler returned both content and inputRequests") return &jsonrpc.Error{ Code: jsonrpc.CodeInternalError, Message: "server bug: result has both content and inputRequests", @@ -56,7 +55,7 @@ func handleMRTRResult(ss *ServerSession, logger *slog.Logger, res mrtrResult) er case supportsMRTR: res.setResultType(ResultTypeComplete) } - // For older clients the resultType is left unset. The serverMRTRMiddleware fulfills the + // For older clients the resultType is left unset. The serverMRTRMiddleware fulfills the // requests by calling the client directly and retries the handler. return nil } @@ -110,7 +109,7 @@ func clientMRTRMiddleware(c *Client) Middleware { // serverMRTRMiddleware is a receiving middleware for servers that transparently // handles MRTR for clients on older protocol versions. When a handler returns // InputRequests and the client does not support MRTR, the middleware fulfills -// the requests by calling the client reinvokes the handler once with the responses. +// the requests by calling the client directly and reinvokes the handler once with the responses. func serverMRTRMiddleware() Middleware { return func(next MethodHandler) MethodHandler { return func(ctx context.Context, method string, req Request) (Result, error) { @@ -186,7 +185,7 @@ func fulfillServerInputRequests(ctx context.Context, ss *ServerSession, requests wg.Wait() close(results) - responses := make(InputResponseMap, len(results)) + responses := make(InputResponseMap, len(requests)) for r := range results { if r.err != nil { return nil, fmt.Errorf("MRTR: fulfilling input request %q: %w", r.id, r.err) @@ -210,19 +209,25 @@ func fulfillServerInputRequest(ctx context.Context, ss *ServerSession, ir InputR } func mrtrInputRequests(res Result) InputRequestMap { + if res == nil { + return nil + } switch r := res.(type) { case *CallToolResult: - if r.NeedsInput() { - return r.InputRequests + if r == nil || !r.NeedsInput() { + return nil } + return r.InputRequests case *GetPromptResult: - if r.NeedsInput() { - return r.InputRequests + if r == nil || !r.NeedsInput() { + return nil } + return r.InputRequests case *ReadResourceResult: - if r.NeedsInput() { - return r.InputRequests + if r == nil || !r.NeedsInput() { + return nil } + return r.InputRequests } return nil } @@ -273,7 +278,7 @@ func fulfillInputRequests(ctx context.Context, cs *ClientSession, requests Input wg.Wait() close(results) - responses := make(InputResponseMap, len(results)) + responses := make(InputResponseMap, len(requests)) for r := range results { if r.err != nil { return nil, fmt.Errorf("MRTR: fulfilling input request %q: %w", r.id, r.err) @@ -297,40 +302,34 @@ func fulfillInputRequest(ctx context.Context, cs *ClientSession, ir InputRequest } func fulfillCreateMessage(ctx context.Context, cs *ClientSession, p *CreateMessageParams) (*CreateMessageResult, error) { - if cs.client.opts.CreateMessageHandler != nil { - return cs.client.opts.CreateMessageHandler(ctx, &CreateMessageRequest{Session: cs, Params: p}) + var msgs []*SamplingMessageV2 + for _, m := range p.Messages { + msgs = append(msgs, &SamplingMessageV2{Content: []Content{m.Content}, Role: m.Role}) } - if cs.client.opts.CreateMessageWithToolsHandler != nil { - var msgs []*SamplingMessageV2 - for _, m := range p.Messages { - msgs = append(msgs, &SamplingMessageV2{Content: []Content{m.Content}, Role: m.Role}) - } - wtp := &CreateMessageWithToolsParams{ - Meta: p.Meta, - IncludeContext: p.IncludeContext, - MaxTokens: p.MaxTokens, - Messages: msgs, - Metadata: p.Metadata, - ModelPreferences: p.ModelPreferences, - StopSequences: p.StopSequences, - SystemPrompt: p.SystemPrompt, - Temperature: p.Temperature, - } - result, err := cs.client.opts.CreateMessageWithToolsHandler(ctx, &CreateMessageWithToolsRequest{Session: cs, Params: wtp}) - if err != nil { - return nil, err - } - var content Content - if len(result.Content) > 0 { - content = result.Content[0] - } - return &CreateMessageResult{ - Meta: result.Meta, - Content: content, - Model: result.Model, - Role: result.Role, - StopReason: result.StopReason, - }, nil + wtp := &CreateMessageWithToolsParams{ + Meta: p.Meta, + IncludeContext: p.IncludeContext, + MaxTokens: p.MaxTokens, + Messages: msgs, + Metadata: p.Metadata, + ModelPreferences: p.ModelPreferences, + StopSequences: p.StopSequences, + SystemPrompt: p.SystemPrompt, + Temperature: p.Temperature, + } + result, err := cs.client.createMessage(ctx, &CreateMessageWithToolsRequest{Session: cs, Params: wtp}) + if err != nil { + return nil, err + } + var content Content + if len(result.Content) > 0 { + content = result.Content[0] } - return nil, fmt.Errorf("client does not support CreateMessage") + return &CreateMessageResult{ + Meta: result.Meta, + Content: content, + Model: result.Model, + Role: result.Role, + StopReason: result.StopReason, + }, nil } diff --git a/mcp/protocol.go b/mcp/protocol.go index be7b34f2..aba3fad6 100644 --- a/mcp/protocol.go +++ b/mcp/protocol.go @@ -106,7 +106,7 @@ func (m *InputRequestMap) UnmarshalJSON(data []byte) error { // InputResponse is a sealed interface for results that a client sends back // when fulfilling an MRTR input request. Implementations are [*ElicitResult], // [*CreateMessageResult], and [*ListRootsResult]. -type InputResponse interface{ isInputResult() } +type InputResponse interface{ isInputResponse() } // InputResponseMap maps request IDs (from [InputRequestMap]) to [InputResponse] // values. It is used in params types when retrying a call after an @@ -771,8 +771,8 @@ type CreateMessageResult struct { StopReason string `json:"stopReason,omitempty"` } -func (*CreateMessageResult) isResult() {} -func (*CreateMessageResult) isInputResult() {} +func (*CreateMessageResult) isResult() {} +func (*CreateMessageResult) isInputResponse() {} func (r *CreateMessageResult) UnmarshalJSON(data []byte) error { type result CreateMessageResult // avoid recursion var wire struct { @@ -912,10 +912,9 @@ type GetPromptResult struct { func (*GetPromptResult) isResult() {} -func (r *GetPromptResult) setResultType(rt ResultType) { r.resultType = rt } -func (r *GetPromptResult) inputRequests() map[string]InputRequest { return r.InputRequests } -func (r *GetPromptResult) setInputRequest(k string, v InputRequest) { r.inputRequests()[k] = v } -func (r *GetPromptResult) hasContent() bool { return len(r.Messages) > 0 } +func (r *GetPromptResult) setResultType(rt ResultType) { r.resultType = rt } +func (r *GetPromptResult) inputRequests() map[string]InputRequest { return r.InputRequests } +func (r *GetPromptResult) hasContent() bool { return len(r.Messages) > 0 } // NeedsInput reports whether this result requires further client input. // See [CallToolResult.NeedsInput] for details. @@ -1122,8 +1121,8 @@ type ListRootsResult struct { Roots []*Root `json:"roots"` } -func (*ListRootsResult) isResult() {} -func (*ListRootsResult) isInputResult() {} +func (*ListRootsResult) isResult() {} +func (*ListRootsResult) isInputResponse() {} type ListToolsParams struct { // This property is reserved by the protocol to allow clients and servers to @@ -1397,10 +1396,9 @@ type ReadResourceResult struct { func (*ReadResourceResult) isResult() {} -func (r *ReadResourceResult) setResultType(rt ResultType) { r.resultType = rt } -func (r *ReadResourceResult) inputRequests() map[string]InputRequest { return r.InputRequests } -func (r *ReadResourceResult) setInputRequest(k string, v InputRequest) { r.inputRequests()[k] = v } -func (r *ReadResourceResult) hasContent() bool { return len(r.Contents) > 0 } +func (r *ReadResourceResult) setResultType(rt ResultType) { r.resultType = rt } +func (r *ReadResourceResult) inputRequests() map[string]InputRequest { return r.InputRequests } +func (r *ReadResourceResult) hasContent() bool { return len(r.Contents) > 0 } // NeedsInput reports whether this result requires further client input. // See [CallToolResult.NeedsInput] for details. @@ -1842,8 +1840,8 @@ type ElicitResult struct { Content map[string]any `json:"content,omitempty"` } -func (*ElicitResult) isResult() {} -func (*ElicitResult) isInputResult() {} +func (*ElicitResult) isResult() {} +func (*ElicitResult) isInputResponse() {} // ElicitationCompleteParams is sent from the server to the client, informing it that an out-of-band elicitation interaction has completed. type ElicitationCompleteParams struct { From 7507222de431ca747bf7276e5f0cd42a272c6503 Mon Sep 17 00:00:00 2001 From: Yaroslav Shevchuk Date: Tue, 26 May 2026 11:42:51 +0000 Subject: [PATCH 5/6] extra tests --- mcp/mrtr_test.go | 186 +++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 186 insertions(+) diff --git a/mcp/mrtr_test.go b/mcp/mrtr_test.go index 23980261..26bf689a 100644 --- a/mcp/mrtr_test.go +++ b/mcp/mrtr_test.go @@ -315,6 +315,192 @@ func TestMRTR_ServerMiddleware(t *testing.T) { } } +func TestMRTR_GetPrompt_AutoRetry(t *testing.T) { + orig := supportedProtocolVersions + supportedProtocolVersions = append(slices.Clone(orig), protocolVersion20260630) + t.Cleanup(func() { supportedProtocolVersions = orig }) + + ctx := context.Background() + + srv := NewServer(testImpl, nil) + srv.AddPrompt(&Prompt{Name: "review"}, func(_ context.Context, req *GetPromptRequest) (*GetPromptResult, error) { + if len(req.Params.InputResponses) == 0 { + return &GetPromptResult{ + InputRequests: InputRequestMap{"confirm": &ElicitParams{Message: "Include sensitive data?"}}, + RequestState: "prompt-state", + }, nil + } + return &GetPromptResult{ + Description: "Code review prompt", + Messages: []*PromptMessage{{Role: "user", Content: &TextContent{Text: "review this code"}}}, + }, nil + }) + + conn := mustConnectMRTR(t, srv, &ClientOptions{ + ElicitationHandler: func(_ context.Context, _ *ElicitRequest) (*ElicitResult, error) { + return &ElicitResult{Action: "accept"}, nil + }, + }) + + res, err := conn.GetPrompt(ctx, &GetPromptParams{Name: "review"}) + if err != nil { + t.Fatalf("GetPrompt() error = %v", err) + } + if res.NeedsInput() { + t.Fatal("NeedsInput() = true after auto-retry, want false") + } + if len(res.Messages) != 1 { + t.Fatalf("len(res.Messages) = %d, want 1", len(res.Messages)) + } + if got := res.Messages[0].Content.(*TextContent).Text; got != "review this code" { + t.Errorf("message text = %q, want %q", got, "review this code") + } +} + +func TestMRTR_GetPrompt_ManualRetry(t *testing.T) { + orig := supportedProtocolVersions + supportedProtocolVersions = append(slices.Clone(orig), protocolVersion20260630) + t.Cleanup(func() { supportedProtocolVersions = orig }) + + ctx := context.Background() + + srv := NewServer(testImpl, nil) + srv.AddPrompt(&Prompt{Name: "review"}, func(_ context.Context, req *GetPromptRequest) (*GetPromptResult, error) { + if len(req.Params.InputResponses) == 0 { + return &GetPromptResult{ + InputRequests: InputRequestMap{"confirm": &ElicitParams{Message: "Include sensitive data?"}}, + RequestState: "prompt-state", + }, nil + } + return &GetPromptResult{ + Description: "Code review prompt", + Messages: []*PromptMessage{{Role: "user", Content: &TextContent{Text: "review this code"}}}, + }, nil + }) + + conn := mustConnectMRTR(t, srv, &ClientOptions{ + MRTR: &MRTROptions{Disabled: true}, + }) + + res, err := conn.GetPrompt(ctx, &GetPromptParams{Name: "review"}) + if err != nil { + t.Fatalf("GetPrompt() error = %v", err) + } + if !res.NeedsInput() { + t.Fatal("NeedsInput() = false, want true") + } + if _, ok := res.InputRequests["confirm"].(*ElicitParams); !ok { + t.Fatalf("InputRequests[confirm] type = %T, want *ElicitParams", res.InputRequests["confirm"]) + } + + res, err = conn.GetPrompt(ctx, &GetPromptParams{ + Name: "review", + InputResponses: InputResponseMap{"confirm": &ElicitResult{Action: "accept"}}, + RequestState: res.RequestState, + }) + if err != nil { + t.Fatalf("GetPrompt() follow-up error = %v", err) + } + if res.NeedsInput() { + t.Fatal("NeedsInput() = true after follow-up, want false") + } + if len(res.Messages) != 1 { + t.Fatalf("len(res.Messages) = %d, want 1", len(res.Messages)) + } +} + +func TestMRTR_ReadResource_AutoRetry(t *testing.T) { + orig := supportedProtocolVersions + supportedProtocolVersions = append(slices.Clone(orig), protocolVersion20260630) + t.Cleanup(func() { supportedProtocolVersions = orig }) + + ctx := context.Background() + + srv := NewServer(testImpl, nil) + srv.AddResource(&Resource{URI: "test://data", Name: "data"}, func(_ context.Context, req *ReadResourceRequest) (*ReadResourceResult, error) { + if len(req.Params.InputResponses) == 0 { + return &ReadResourceResult{ + InputRequests: InputRequestMap{"auth": &ElicitParams{Message: "Authenticate?"}}, + RequestState: "resource-state", + }, nil + } + return &ReadResourceResult{ + Contents: []*ResourceContents{{URI: "test://data", Text: "resource data"}}, + }, nil + }) + + conn := mustConnectMRTR(t, srv, &ClientOptions{ + ElicitationHandler: func(_ context.Context, _ *ElicitRequest) (*ElicitResult, error) { + return &ElicitResult{Action: "accept"}, nil + }, + }) + + res, err := conn.ReadResource(ctx, &ReadResourceParams{URI: "test://data"}) + if err != nil { + t.Fatalf("ReadResource() error = %v", err) + } + if res.NeedsInput() { + t.Fatal("NeedsInput() = true after auto-retry, want false") + } + if len(res.Contents) != 1 { + t.Fatalf("len(res.Contents) = %d, want 1", len(res.Contents)) + } + if got := res.Contents[0].Text; got != "resource data" { + t.Errorf("resource text = %q, want %q", got, "resource data") + } +} + +func TestMRTR_ReadResource_ManualRetry(t *testing.T) { + orig := supportedProtocolVersions + supportedProtocolVersions = append(slices.Clone(orig), protocolVersion20260630) + t.Cleanup(func() { supportedProtocolVersions = orig }) + + ctx := context.Background() + + srv := NewServer(testImpl, nil) + srv.AddResource(&Resource{URI: "test://data", Name: "data"}, func(_ context.Context, req *ReadResourceRequest) (*ReadResourceResult, error) { + if len(req.Params.InputResponses) == 0 { + return &ReadResourceResult{ + InputRequests: InputRequestMap{"auth": &ElicitParams{Message: "Authenticate?"}}, + RequestState: "resource-state", + }, nil + } + return &ReadResourceResult{ + Contents: []*ResourceContents{{URI: "test://data", Text: "resource data"}}, + }, nil + }) + + conn := mustConnectMRTR(t, srv, &ClientOptions{ + MRTR: &MRTROptions{Disabled: true}, + }) + + res, err := conn.ReadResource(ctx, &ReadResourceParams{URI: "test://data"}) + if err != nil { + t.Fatalf("ReadResource() error = %v", err) + } + if !res.NeedsInput() { + t.Fatal("NeedsInput() = false, want true") + } + if _, ok := res.InputRequests["auth"].(*ElicitParams); !ok { + t.Fatalf("InputRequests[auth] type = %T, want *ElicitParams", res.InputRequests["auth"]) + } + + res, err = conn.ReadResource(ctx, &ReadResourceParams{ + URI: "test://data", + InputResponses: InputResponseMap{"auth": &ElicitResult{Action: "accept"}}, + RequestState: res.RequestState, + }) + if err != nil { + t.Fatalf("ReadResource() follow-up error = %v", err) + } + if res.NeedsInput() { + t.Fatal("NeedsInput() = true after follow-up, want false") + } + if len(res.Contents) != 1 { + t.Fatalf("len(res.Contents) = %d, want 1", len(res.Contents)) + } +} + func mustConnectMRTR(t *testing.T, s *Server, clientOpts *ClientOptions) *ClientSession { t.Helper() st, ct := NewInMemoryTransports() From 60f1c700a0284bcbcce65825e235eab61492fb5e Mon Sep 17 00:00:00 2001 From: Yaroslav Shevchuk Date: Tue, 26 May 2026 12:56:34 +0000 Subject: [PATCH 6/6] use errgroup, handle createMessageWithTools --- go.mod | 1 + go.sum | 2 + mcp/mrtr.go | 125 +++++++++++++++++++++--------------------------- mcp/protocol.go | 19 ++++---- 4 files changed, 67 insertions(+), 80 deletions(-) diff --git a/go.mod b/go.mod index 860d6fa0..3287a957 100644 --- a/go.mod +++ b/go.mod @@ -15,5 +15,6 @@ require ( require ( github.com/segmentio/asm v1.1.3 // indirect + golang.org/x/sync v0.20.0 // indirect golang.org/x/sys v0.41.0 // indirect ) diff --git a/go.sum b/go.sum index 377a7b11..c13454aa 100644 --- a/go.sum +++ b/go.sum @@ -12,6 +12,8 @@ github.com/yosida95/uritemplate/v3 v3.0.2 h1:Ed3Oyj9yrmi9087+NczuL5BwkIc4wvTb5zI github.com/yosida95/uritemplate/v3 v3.0.2/go.mod h1:ILOh0sOhIJR3+L/8afwt/kE++YT040gmv5BQTMR2HP4= golang.org/x/oauth2 v0.35.0 h1:Mv2mzuHuZuY2+bkyWXIHMfhNdJAdwW3FuWeCPYN5GVQ= golang.org/x/oauth2 v0.35.0/go.mod h1:lzm5WQJQwKZ3nwavOZ3IS5Aulzxi68dUSgRHujetwEA= +golang.org/x/sync v0.20.0 h1:e0PTpb7pjO8GAtTs2dQ6jYa5BWYlMuX047Dco/pItO4= +golang.org/x/sync v0.20.0/go.mod h1:9xrNwdLfx4jkKbNva9FpL6vEN7evnE43NNNJQ2LF3+0= golang.org/x/sys v0.41.0 h1:Ivj+2Cp/ylzLiEU89QhWblYnOE9zerudt9Ftecq2C6k= golang.org/x/sys v0.41.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= golang.org/x/time v0.15.0 h1:bbrp8t3bGUeFOx08pvsMYRTCVSMk89u4tKbNOZbp88U= diff --git a/mcp/mrtr.go b/mcp/mrtr.go index 18913700..a3368a83 100644 --- a/mcp/mrtr.go +++ b/mcp/mrtr.go @@ -11,6 +11,7 @@ import ( "sync" "github.com/modelcontextprotocol/go-sdk/jsonrpc" + "golang.org/x/sync/errgroup" ) const defaultMRTRMaxRetries = 3 @@ -37,6 +38,9 @@ type mrtrResult interface { } func handleMRTRResult(ss *ServerSession, logger *slog.Logger, res mrtrResult) error { + if res == nil { + return nil + } hasInputRequests := res.inputRequests() != nil if hasInputRequests && res.hasContent() { @@ -169,28 +173,23 @@ func serverMRTRInputRequests(res Result) InputRequestMap { } func fulfillServerInputRequests(ctx context.Context, ss *ServerSession, requests InputRequestMap) (InputResponseMap, error) { - type result struct { - id string - resp InputResponse - err error - } - results := make(chan result, len(requests)) - var wg sync.WaitGroup + g, ctx := errgroup.WithContext(ctx) + var mu sync.Mutex + responses := make(InputResponseMap, len(requests)) for id, ir := range requests { - wg.Go(func() { + g.Go(func() error { resp, err := fulfillServerInputRequest(ctx, ss, ir) - results <- result{id, resp, err} + if err != nil { + return fmt.Errorf("fulfilling input request %q: %w", id, err) + } + mu.Lock() + responses[id] = resp + mu.Unlock() + return nil }) } - wg.Wait() - close(results) - - responses := make(InputResponseMap, len(requests)) - for r := range results { - if r.err != nil { - return nil, fmt.Errorf("MRTR: fulfilling input request %q: %w", r.id, r.err) - } - responses[r.id] = r.resp + if err := g.Wait(); err != nil { + return nil, fmt.Errorf("MRTR: %w", err) } return responses, nil } @@ -200,7 +199,9 @@ func fulfillServerInputRequest(ctx context.Context, ss *ServerSession, ir InputR case *ElicitParams: return ss.Elicit(ctx, p) case *CreateMessageParams: - return ss.CreateMessage(ctx, p) + return ss.CreateMessageWithTools(ctx, createMessageParamsToWithTools(p)) + case *CreateMessageWithToolsParams: + return ss.CreateMessageWithTools(ctx, p) case *ListRootsParams: return ss.ListRoots(ctx, p) default: @@ -208,6 +209,24 @@ func fulfillServerInputRequest(ctx context.Context, ss *ServerSession, ir InputR } } +func createMessageParamsToWithTools(p *CreateMessageParams) *CreateMessageWithToolsParams { + var msgs []*SamplingMessageV2 + for _, m := range p.Messages { + msgs = append(msgs, &SamplingMessageV2{Content: []Content{m.Content}, Role: m.Role}) + } + return &CreateMessageWithToolsParams{ + Meta: p.Meta, + IncludeContext: p.IncludeContext, + MaxTokens: p.MaxTokens, + Messages: msgs, + Metadata: p.Metadata, + ModelPreferences: p.ModelPreferences, + StopSequences: p.StopSequences, + SystemPrompt: p.SystemPrompt, + Temperature: p.Temperature, + } +} + func mrtrInputRequests(res Result) InputRequestMap { if res == nil { return nil @@ -262,28 +281,23 @@ func setMRTRRetryParams(req Request, responses InputResponseMap, state string) { } func fulfillInputRequests(ctx context.Context, cs *ClientSession, requests InputRequestMap) (InputResponseMap, error) { - type result struct { - id string - resp InputResponse - err error - } - results := make(chan result, len(requests)) - var wg sync.WaitGroup + g, ctx := errgroup.WithContext(ctx) + var mu sync.Mutex + responses := make(InputResponseMap, len(requests)) for id, ir := range requests { - wg.Go(func() { + g.Go(func() error { resp, err := fulfillInputRequest(ctx, cs, ir) - results <- result{id, resp, err} + if err != nil { + return fmt.Errorf("fulfilling input request %q: %w", id, err) + } + mu.Lock() + responses[id] = resp + mu.Unlock() + return nil }) } - wg.Wait() - close(results) - - responses := make(InputResponseMap, len(requests)) - for r := range results { - if r.err != nil { - return nil, fmt.Errorf("MRTR: fulfilling input request %q: %w", r.id, r.err) - } - responses[r.id] = r.resp + if err := g.Wait(); err != nil { + return nil, fmt.Errorf("MRTR: %w", err) } return responses, nil } @@ -293,43 +307,12 @@ func fulfillInputRequest(ctx context.Context, cs *ClientSession, ir InputRequest case *ElicitParams: return cs.client.elicit(ctx, newClientRequest(cs, p)) case *CreateMessageParams: - return fulfillCreateMessage(ctx, cs, p) + return cs.client.createMessage(ctx, &CreateMessageWithToolsRequest{Session: cs, Params: createMessageParamsToWithTools(p)}) + case *CreateMessageWithToolsParams: + return cs.client.createMessage(ctx, &CreateMessageWithToolsRequest{Session: cs, Params: p}) case *ListRootsParams: return cs.client.listRoots(ctx, newClientRequest(cs, p)) default: return nil, fmt.Errorf("unknown input request type: %T", ir) } } - -func fulfillCreateMessage(ctx context.Context, cs *ClientSession, p *CreateMessageParams) (*CreateMessageResult, error) { - var msgs []*SamplingMessageV2 - for _, m := range p.Messages { - msgs = append(msgs, &SamplingMessageV2{Content: []Content{m.Content}, Role: m.Role}) - } - wtp := &CreateMessageWithToolsParams{ - Meta: p.Meta, - IncludeContext: p.IncludeContext, - MaxTokens: p.MaxTokens, - Messages: msgs, - Metadata: p.Metadata, - ModelPreferences: p.ModelPreferences, - StopSequences: p.StopSequences, - SystemPrompt: p.SystemPrompt, - Temperature: p.Temperature, - } - result, err := cs.client.createMessage(ctx, &CreateMessageWithToolsRequest{Session: cs, Params: wtp}) - if err != nil { - return nil, err - } - var content Content - if len(result.Content) > 0 { - content = result.Content[0] - } - return &CreateMessageResult{ - Meta: result.Meta, - Content: content, - Model: result.Model, - Role: result.Role, - StopReason: result.StopReason, - }, nil -} diff --git a/mcp/protocol.go b/mcp/protocol.go index aba3fad6..df5a1a2c 100644 --- a/mcp/protocol.go +++ b/mcp/protocol.go @@ -46,7 +46,7 @@ func (m InputRequestMap) MarshalJSON() ([]byte, error) { switch v.(type) { case *ElicitParams: return methodElicit, nil - case *CreateMessageParams: + case *CreateMessageParams, *CreateMessageWithToolsParams: return methodCreateMessage, nil case *ListRootsParams: return methodListRoots, nil @@ -84,7 +84,7 @@ func (m *InputRequestMap) UnmarshalJSON(data []byte) error { } result[k] = &p case methodCreateMessage: - var p CreateMessageParams + var p CreateMessageWithToolsParams if err := json.Unmarshal(raw.Params, &p); err != nil { return err } @@ -122,7 +122,7 @@ func (m InputResponseMap) MarshalJSON() ([]byte, error) { switch v.(type) { case *ElicitResult: return methodElicit, nil - case *CreateMessageResult: + case *CreateMessageResult, *CreateMessageWithToolsResult: return methodCreateMessage, nil case *ListRootsResult: return methodListRoots, nil @@ -160,7 +160,7 @@ func (m *InputResponseMap) UnmarshalJSON(data []byte) error { } result[k] = &p case methodCreateMessage: - var p CreateMessageResult + var p CreateMessageWithToolsResult if err := json.Unmarshal(raw.Result, &p); err != nil { return err } @@ -172,7 +172,7 @@ func (m *InputResponseMap) UnmarshalJSON(data []byte) error { } result[k] = &p default: - return fmt.Errorf("unsupported InputRequest method: %q", raw.Method) + return fmt.Errorf("unsupported InputResponse method: %q", raw.Method) } } *m = result @@ -345,9 +345,8 @@ func (r *CallToolResult) GetError() error { func (*CallToolResult) isResult() {} -func (r *CallToolResult) setResultType(rt ResultType) { r.resultType = rt } -func (r *CallToolResult) inputRequests() map[string]InputRequest { return r.InputRequests } -func (r *CallToolResult) setInputRequest(k string, v InputRequest) { r.InputRequests[k] = v } +func (r *CallToolResult) setResultType(rt ResultType) { r.resultType = rt } +func (r *CallToolResult) inputRequests() map[string]InputRequest { return r.InputRequests } func (r *CallToolResult) hasContent() bool { return len(r.Content) > 0 || r.StructuredContent != nil } @@ -672,6 +671,7 @@ type CreateMessageWithToolsParams struct { } func (x *CreateMessageWithToolsParams) isParams() {} +func (x *CreateMessageWithToolsParams) isInputRequest() {} func (x *CreateMessageWithToolsParams) GetProgressToken() any { return getProgressToken(x) } func (x *CreateMessageWithToolsParams) SetProgressToken(t any) { setProgressToken(x, t) } @@ -817,7 +817,8 @@ var createMessageWithToolsResultAllow = map[string]bool{ "tool_use": true, } -func (*CreateMessageWithToolsResult) isResult() {} +func (*CreateMessageWithToolsResult) isResult() {} +func (*CreateMessageWithToolsResult) isInputResponse() {} // MarshalJSON marshals the result. When Content has a single element, it is // marshaled as a single object for compatibility with pre-2025-11-25