diff --git a/design/mrtr.md b/design/mrtr.md new file mode 100644 index 00000000..5974e3f3 --- /dev/null +++ b/design/mrtr.md @@ -0,0 +1,271 @@ +## Context + +A proposal for implementing Multi Round-Trip Requests +(MRTR) as defined in [SEP-2322](https://github.com/CaitieM20/modelcontextprotocol/blob/de6d76fba3078eda957dadb3cec51ca8ab851b5c/seps/2322-MRTR.md). + +In the new protocol version servers can't initiate requests to clients, but when a server requires additional input for completing `tools/call`, `prompts/get`, or `resources/read` it can return an incomplete result along with a set of `inputRequests`. The client fulfills them locally and retries the same call with `inputResponses` attached. + +## Goals + +**Must have:** + +* Backward compatibility. +* Correct representation on the wire. + +**Nice to have:** + +* Minimal changes to the exported API surface. +* Hard for server implementers to construct an invalid payload. +* Simple input request handling for clients. +* Protocol-version-independent code. +* Consistency with the rest of the SDK. + +## Proposal + +`ServerSession` methods return an error for new-version protocol connections. + +`InputRequest`/`InputResponse` is introduced as a sealed-interface: +```go +// Implemented by *ElicitParams, *CreateMessageParams, *ListRootsParams +type InputRequest interface{ isInputRequest() } + +type InputRequestMap map[string]InputRequest +// MarshalJSON encodes as map[string]struct{ Method string; Params InputRequest } +func (m InputRequestMap) MarshalJSON() ([]byte, error) { ... } +// UnmarshalJSON decodes from map[string]struct{ Method string; Params InputRequest } +func (m *InputRequestMap) UnmarshalJSON(data []byte) error { ... } + +// Implemented by *ElicitResult, *CreateMessageResult, *ListRootsResult. 
+type InputResponse interface{ isInputResult() } + +type InputResponseMap map[string]InputResponse +// MarshalJSON encodes as map[string]struct{ Method string; Result InputResponse } +func (m InputResponseMap) MarshalJSON() ([]byte, error) { ... } +// UnmarshalJSON decodes from map[string]struct{ Method string; Result InputResponse } +func (m *InputResponseMap) UnmarshalJSON(data []byte) error { ... } +``` + +All affected methods' `*Params` are extended with `InputResponseMap` and `RequestState` fields: +```go +type CallToolParams struct { + ... + InputResponses InputResponseMap `json:"inputResponses,omitempty"` + RequestState string `json:"requestState,omitempty"` +} +// Same for GetPromptParams, ReadResourceParams +``` + +`InputRequests` and `RequestState` fields are added directly to `CallToolResult`, `GetPromptResult`, and `ReadResourceResult` as exported. +Result type discriminator (completed, input_required) is unexported so that SDK users don't need to set it to the correct constant in addition to setting either `Content` or `InputRequests`. Handler execution result is validated and augmented before marshaling: +```go +type CallToolResult struct { + ... + InputRequests InputRequestMap `json:"inputRequests,omitempty"` + RequestState string `json:"requestState,omitempty"` + resultType string // set by the SDK and used in MarshalJSON() +} +// Same for GetPromptResult, ReadResourceResult. +``` +Alternatively, the field could only exist on `wire struct`, but this would make us return `complete` to older clients or empty string to newer clients, because there's no access to negotiated protocol version in `MarshalJSON`. 
+ +Servers request additional input by constructing a correct struct literal: +```go +mcp.AddTool(s, tool, func(ctx context.Context, req *mcp.CallToolRequest, in MyIn) (*mcp.CallToolResult, MyOut, error) { + if !hasConfirmation(in) { + return &mcp.CallToolResult{ + InputRequests: InputRequestMap{"confirm": &mcp.ElicitParams{Message: "Sure?"}}, + RequestState: "state-token", + }, zero, nil + } + return &mcp.CallToolResult{Content: []mcp.Content{&mcp.TextContent{Text: "done"}}}, myOut, nil +}) +``` +The SDK validates at runtime that a handler does not return both content and `InputRequests` — doing so logs a warning and returns a `CodeInternalError` JSON-RPC error. + +An unexported receiving middleware is installed on the server for backward compatibility with older clients. When a handler returns `InputRequests` and the connected client uses a protocol version that does not support MRTR, the middleware fulfills the requests by calling `ServerSession.Elicit`/`CreateMessage`/`ListRoots` on the client directly and reinvokes the handler once with the collected `InputResponses`. If any of these calls fail, the entire request fails. Input requests are fulfilled concurrently. This lets server developers write protocol-version-independent code. + +An unexported sending middleware is installed on the client, which similarly to `urlElicitationMiddleware` will automatically invoke handlers for the corresponding methods on incomplete results and retry the original request. 
`ClientOptions` is extended with configuration knobs: +```go +type MRTROptions struct { + MaxRetries int + Disabled bool +} +client := mcp.NewClient(impl, &mcp.ClientOptions{ + MRTR: &mcp.MRTROptions{MaxRetries: 3}, +}) +``` + +Alternatively, clients have an option to disable it and write a retry loop manually using `NeedsInput()`: +```go +client := mcp.NewClient(impl, &mcp.ClientOptions{MRTR: &mcp.MRTROptions{Disabled: true}}) +result, err := client.CallTool(ctx, &mcp.CallToolParams{Name: "my-tool"}) +if result.NeedsInput() { ... } +``` + +`NeedsInput()` checks the unexported `resultType` field rather than `InputRequests`, correctly handling the load-shedding case where the server returns `input_required` with an empty map. + +**Pros** + +This is arguably the simplest and most transparent approach, which is also the closest to the spec. +What gets explicitly set on the server can be observed on the wire and on the client. +The opt-out client middleware follows the principle of least surprise for app developers. If client method handlers were provided, they will continue to be invoked regardless of the protocol version in use. The `Disabled` option lets "power-users" build any custom handling logic. +The server middleware makes handler code protocol-version-independent — the same handler works for both old and new clients. + +**Cons** + +The biggest downside of the proposal is that server developers can construct incorrect responses (both content and input requests) and this will only be validated at runtime. + +## Alternatives considered + +### Unexported fields + +MRTR fields can be unexported, accessible only through getters, constructible only through constructor functions, and handled explicitly in custom `(Unm|M)arshalJSON`. This will make it impossible for developers to construct incorrect responses and for clients to perform an erroneous `len(result.InputRequests) > 0` check in the load-shedding case. +```go +type CallToolResult struct { + ... 
+ inputRequests InputRequestMap + requestState string + resultType string +} + +func (r *CallToolResult) InputRequests() (InputRequestMap, bool) { ... } + +// InputRequiredResult struct exists for backward-compatibility in case of new fields being needed for input request results. +type InputRequiredResult struct { + InputRequests InputRequestMap + RequestState string +} + +// RequireInput constructs a tool call, prompt or resource result with input requests set. +// mrtrResult provides methods for setting private fields on these types. +func RequireInput[T any, TP interface { *T; mrtrResult }](r InputRequiredResult) TP { ... } +``` + +On the server: +```go +mcp.AddTool(s, tool, func(ctx context.Context, req *mcp.CallToolRequest, in MyIn) (*mcp.CallToolResult, MyOut, error) { + if !hasConfirmation(in) { + return mcp.RequireInput[mcp.CallToolResult](mcp.InputRequiredResult{ + InputRequests: mcp.InputRequestMap{"confirm": &mcp.ElicitParams{Message: "Deploy to production?"}}, + RequestState: "deployment-123", + }), nil, nil + } + return &mcp.CallToolResult{ Content: []mcp.Content{&mcp.TextContent{Text: "done"}}}, myOut, nil +}) +``` + +On the client: +```go +result, err := client.CallTool(ctx, &mcp.CallToolParams{Name: "my-tool"}) +if requests, ok := result.InputRequests(); ok { ... } +``` + +The biggest downside of this approach is the obscure data model with hidden fields. An incomplete `mcp.CallToolResult` looks like an uninitialized struct until `InputRequests` method result is examined. +In addition to this, the verbose `RequireInput` syntax (no auto type inference from assignment target) does not look idiomatic and fits poorly into the existing SDK APIs. + +--- + +### `InputRequiredError` type + +We could explore a different data channel - `error` return value. This would give us the natural "happy path is when all inputs are provided" flow on the server side, and good result interpretability on the client side (impossible to confuse with a successful response). 
+The new error could be converted to the correct wire representation at the marshaling stage. +```go +type InputRequiredError struct { + InputRequests InputRequestMap + RequestState string +} + +func (e *InputRequiredError) Error() string { + return fmt.Sprintf("input required: %d request(s)", len(e.InputRequests)) +} +``` + +On the server: +```go +mcp.AddTool(s, tool, func(ctx context.Context, req *mcp.CallToolRequest, in MyIn) (*mcp.CallToolResult, MyOut, error) { + if !hasConfirmation(in) { + return nil, zero, &mcp.InputRequiredError{ + InputRequests: mcp.InputRequestMap{"confirm": &mcp.ElicitParams{Message: "Sure?"}}, + RequestState: "state-token", + } + } + return &mcp.CallToolResult{ Content: []mcp.Content{&mcp.TextContent{Text: "done"}}}, myOut, nil +}) +``` + +On the client: +```go +result, err := client.CallTool(ctx, &mcp.CallToolParams{Name: "my-tool"}) +var inputReqErr *mcp.InputRequiredError +if errors.As(err, &inputReqErr) { ... } +``` + +The downsides of this approach are: +* The drift from the protocol, where MRTR is not an error flow. +* Obscure "customError -> non-error protocol type on wire -> customError" data lifecycle. +* Things get confusing for error-processing middleware. + +--- + +### New functions + +We could introduce new functions with a different handler signature where the return type is a sealed interface. This would give us compiler-enforced correctness for values constructed by tool handlers and clients would be forced to unpack `mcp.RoundTripCallToolResult` and make a conscious decision for how to handle it. +```go +type RoundTripToolHandler func(context.Context, *CallToolRequest) (RoundTripCallToolResult, error) +type RoundTripToolHandlerFor[In, Out any] func(context.Context, *CallToolRequest, In) (RoundTripCallToolResult, Out, error) + +// RoundTripCallToolResult is implemented by CallToolResult and IncompleteResult +type RoundTripCallToolResult interface { isMRTRResult() } + +type IncompleteResult struct { + ... 
+ InputRequests InputRequestMap `json:"inputRequests,omitempty"` + RequestState string `json:"requestState,omitempty"` +} + +func (s *Server) AddRoundTripTool(t *Tool, h RoundTripToolHandler) +func AddRoundTripTool[In, Out any](s *Server, t *Tool, h RoundTripToolHandlerFor[In, Out]) +``` + +`Server.AddTool` wraps the old `ToolHandler` into a `RoundTripToolHandler` to update its function signature: +```go +mcp.AddRoundTripTool(s, tool, func(ctx context.Context, req *mcp.CallToolRequest, in MyIn) (mcp.RoundTripCallToolResult, MyOut, error) { + if needsInput(in) { + return &mcp.IncompleteResult{ + ResultType: mcp.ResultTypeInputRequired, + InputRequests: InputRequestMap{"confirm": &mcp.ElicitParams{Message: "Sure?"}}, + }, zero, nil + } + return &mcp.CallToolResult{Content: []mcp.Content{&mcp.TextContent{Text: "done"}}}, myOut, nil +}) +``` + +The downsides of this approach are: +* SEP suggests `ResultType` will potentially be extended with new values, `RoundTrip` in new function names will not allow us to cleanly extend the sealed interface with new types. But an overly generic name for new functions will make the API use-case less clear. +* Different code +* SDK takes the same action (puts it on the wire) regardless of the returned type, it exists only for enforcing correctness of the user code. +* Exported API surface bloat: +7 exported functions. + +--- + +### Exported Middleware + +We could flip "unexported MRTR middleware with opt-out option" to "exported middleware with opt-in requirement". +```go +func AutoMRTR(opts *MRTROptions) Middleware { ... } +type MRTROptions struct { + MaxRetries int +} +client := mcp.NewClient(impl, nil) +client.AddSendingMiddleware(mcp.AutoMRTR(&mcp.MRTROptions{ + MaxRetries: 5, +})) +``` +This would change semantics of `*Handler` fields - depending on the protocol version in use, an extra initialization step will be required for them to "take effect". 
+ +--- + +### Server API protocol version bridging + +Converting `ServerSession.Elicit`/`CreateMessage`/`ListRoots` calls into MRTR wire format transparently (suspend the handler, return `input_required`, resume on retry). Rejected because of a significant implementation effort and the fact that it contradicts the design goal of MRTR where servers shouldn't hold resources between round trips, and it should be possible for a retry to arrive on any server instance in a multi-server deployment. + diff --git a/go.mod b/go.mod index 860d6fa0..3287a957 100644 --- a/go.mod +++ b/go.mod @@ -15,5 +15,6 @@ require ( require ( github.com/segmentio/asm v1.1.3 // indirect + golang.org/x/sync v0.20.0 // indirect golang.org/x/sys v0.41.0 // indirect ) diff --git a/go.sum b/go.sum index 377a7b11..c13454aa 100644 --- a/go.sum +++ b/go.sum @@ -12,6 +12,8 @@ github.com/yosida95/uritemplate/v3 v3.0.2 h1:Ed3Oyj9yrmi9087+NczuL5BwkIc4wvTb5zI github.com/yosida95/uritemplate/v3 v3.0.2/go.mod h1:ILOh0sOhIJR3+L/8afwt/kE++YT040gmv5BQTMR2HP4= golang.org/x/oauth2 v0.35.0 h1:Mv2mzuHuZuY2+bkyWXIHMfhNdJAdwW3FuWeCPYN5GVQ= golang.org/x/oauth2 v0.35.0/go.mod h1:lzm5WQJQwKZ3nwavOZ3IS5Aulzxi68dUSgRHujetwEA= +golang.org/x/sync v0.20.0 h1:e0PTpb7pjO8GAtTs2dQ6jYa5BWYlMuX047Dco/pItO4= +golang.org/x/sync v0.20.0/go.mod h1:9xrNwdLfx4jkKbNva9FpL6vEN7evnE43NNNJQ2LF3+0= golang.org/x/sys v0.41.0 h1:Ivj+2Cp/ylzLiEU89QhWblYnOE9zerudt9Ftecq2C6k= golang.org/x/sys v0.41.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= golang.org/x/time v0.15.0 h1:bbrp8t3bGUeFOx08pvsMYRTCVSMk89u4tKbNOZbp88U= diff --git a/mcp/client.go b/mcp/client.go index 6e24c5a3..381be027 100644 --- a/mcp/client.go +++ b/mcp/client.go @@ -58,13 +58,17 @@ func NewClient(impl *Implementation, options *ClientOptions) *Client { opts.Logger = ensureLogger(nil) } - return &Client{ + c := &Client{ impl: impl, opts: opts, roots: newFeatureSet(func(r *Root) string { return r.URI }), sendingMethodHandler_: defaultSendingMethodHandler, 
receivingMethodHandler_: defaultReceivingMethodHandler[*ClientSession], } + if opts.MRTR == nil || !opts.MRTR.Disabled { + c.AddSendingMiddleware(clientMRTRMiddleware(c)) + } + return c } // ClientOptions configures the behavior of the client. @@ -154,6 +158,10 @@ type ClientOptions struct { ResourceUpdatedHandler func(context.Context, *ResourceUpdatedNotificationRequest) LoggingMessageHandler func(context.Context, *LoggingMessageRequest) ProgressNotificationHandler func(context.Context, *ProgressNotificationClientRequest) + // MRTR configures the automatic MRTR (Multi Round-Trip Requests) middleware. + // By default (nil), the middleware is enabled with default settings. + // Set Disabled to true to opt out of automatic MRTR handling. + MRTR *MRTROptions // If non-zero, defines an interval for regular "ping" requests. // If the peer fails to respond to pings originating from the keepalive check, // the session is automatically closed. diff --git a/mcp/content_nil_test.go b/mcp/content_nil_test.go index ba263e6f..86bb86ac 100644 --- a/mcp/content_nil_test.go +++ b/mcp/content_nil_test.go @@ -15,6 +15,7 @@ import ( "testing" "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" "github.com/modelcontextprotocol/go-sdk/mcp" ) @@ -223,4 +224,4 @@ func TestContentUnmarshalNilWithInvalidContent(t *testing.T) { } } -var ctrCmpOpts = []cmp.Option{cmp.AllowUnexported(mcp.CallToolResult{})} +var ctrCmpOpts = []cmp.Option{cmpopts.IgnoreUnexported(mcp.CallToolResult{}, mcp.GetPromptResult{}, mcp.ReadResourceResult{})} diff --git a/mcp/mcp_test.go b/mcp/mcp_test.go index 14173231..bad35086 100644 --- a/mcp/mcp_test.go +++ b/mcp/mcp_test.go @@ -24,6 +24,7 @@ import ( "time" "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" "github.com/google/jsonschema-go/jsonschema" "github.com/modelcontextprotocol/go-sdk/internal/jsonrpc2" "github.com/modelcontextprotocol/go-sdk/jsonrpc" @@ -206,7 +207,7 @@ func TestEndToEnd(t *testing.T) { Role: "user", 
}}, } - if diff := cmp.Diff(wantReview, gotReview); diff != "" { + if diff := cmp.Diff(wantReview, gotReview, ctrCmpOpts...); diff != "" { t.Errorf("prompts/get 'code_review' mismatch (-want +got):\n%s", diff) } @@ -2371,4 +2372,4 @@ func TestSetErrorPreservesContent(t *testing.T) { } } -var ctrCmpOpts = []cmp.Option{cmp.AllowUnexported(CallToolResult{})} +var ctrCmpOpts = []cmp.Option{cmpopts.IgnoreUnexported(CallToolResult{}, GetPromptResult{}, ReadResourceResult{})} diff --git a/mcp/mrtr.go b/mcp/mrtr.go new file mode 100644 index 00000000..a3368a83 --- /dev/null +++ b/mcp/mrtr.go @@ -0,0 +1,318 @@ +// Copyright 2025 The Go MCP SDK Authors. All rights reserved. +// Use of this source code is governed by the license +// that can be found in the LICENSE file. + +package mcp + +import ( + "context" + "fmt" + "log/slog" + "sync" + + "github.com/modelcontextprotocol/go-sdk/jsonrpc" + "golang.org/x/sync/errgroup" +) + +const defaultMRTRMaxRetries = 3 + +// MRTROptions configures the client-side MRTR (Multi Round-Trip Requests) +// middleware. The middleware is enabled by default and automatically fulfills input +// requests from the server by invoking the appropriate client handlers and +// retrying the original call. +type MRTROptions struct { + // MaxRetries is the maximum number of MRTR retry attempts after the + // initial call. Defaults to 3 if the provided value is <= 0. + MaxRetries int + // Disabled prevents the automatic MRTR middleware from being installed. + // When true, the client returns input-required results directly and + // callers must handle the retry loop themselves using [CallToolResult.NeedsInput], + // [GetPromptResult.NeedsInput], or [ReadResourceResult.NeedsInput]. 
+ Disabled bool +} + +type mrtrResult interface { + setResultType(ResultType) + inputRequests() map[string]InputRequest + hasContent() bool +} + +func handleMRTRResult(ss *ServerSession, logger *slog.Logger, res mrtrResult) error { + if res == nil { + return nil + } + hasInputRequests := res.inputRequests() != nil + + if hasInputRequests && res.hasContent() { + logger.Warn("handler returned both content and inputRequests") + return &jsonrpc.Error{ + Code: jsonrpc.CodeInternalError, + Message: "server bug: result has both content and inputRequests", + } + } + + supportsMRTR := sessionSupportsMRTR(ss) + + switch { + case hasInputRequests && supportsMRTR: + res.setResultType(ResultTypeInputRequired) + case supportsMRTR: + res.setResultType(ResultTypeComplete) + } + // For older clients the resultType is left unset. The serverMRTRMiddleware fulfills the + // requests by calling the client directly and retries the handler. + return nil +} + +func sessionSupportsMRTR(ss *ServerSession) bool { + protocolVersion := latestProtocolVersion + if iparams := ss.InitializeParams(); iparams != nil { + protocolVersion = iparams.ProtocolVersion + } + return protocolVersion >= protocolVersion20260630 +} + +func clientMRTRMiddleware(c *Client) Middleware { + return func(next MethodHandler) MethodHandler { + return func(ctx context.Context, method string, req Request) (Result, error) { + if method != methodCallTool && method != methodGetPrompt && method != methodReadResource { + return next(ctx, method, req) + } + + maxRetries := defaultMRTRMaxRetries + if c.opts.MRTR != nil && c.opts.MRTR.MaxRetries > 0 { + maxRetries = c.opts.MRTR.MaxRetries + } + + for retries := 0; ; retries++ { + res, err := next(ctx, method, req) + if err != nil { + return res, err + } + irm := mrtrInputRequests(res) + if len(irm) == 0 { + return res, nil + } + if retries >= maxRetries { + return nil, fmt.Errorf("MRTR: exceeded maximum retries (%d)", maxRetries) + } + cs, ok := req.GetSession().(*ClientSession) + 
if !ok { + return res, nil + } + responses, err := fulfillInputRequests(ctx, cs, irm) + if err != nil { + return nil, err + } + setMRTRRetryParams(req, responses, mrtrRequestState(res)) + } + } + } +} + +// serverMRTRMiddleware is a receiving middleware for servers that transparently +// handles MRTR for clients on older protocol versions. When a handler returns +// InputRequests and the client does not support MRTR, the middleware fulfills +// the requests by calling the client directly and reinvokes the handler once with the responses. +func serverMRTRMiddleware() Middleware { + return func(next MethodHandler) MethodHandler { + return func(ctx context.Context, method string, req Request) (Result, error) { + if method != methodCallTool && method != methodGetPrompt && method != methodReadResource { + return next(ctx, method, req) + } + ss, ok := req.GetSession().(*ServerSession) + if !ok { + return next(ctx, method, req) + } + if sessionSupportsMRTR(ss) { + return next(ctx, method, req) + } + + res, err := next(ctx, method, req) + if err != nil { + return res, err + } + irm := serverMRTRInputRequests(res) + if len(irm) == 0 { + return res, nil + } + responses, err := fulfillServerInputRequests(ctx, ss, irm) + if err != nil { + return nil, err + } + setMRTRRetryParams(req, responses, mrtrRequestState(res)) + return next(ctx, method, req) + } + } +} + +// serverMRTRInputRequests returns input requests from a result for old clients +// where resultType is not set. It checks InputRequests directly. 
+func serverMRTRInputRequests(res Result) InputRequestMap { + if res == nil { + return nil + } + switch r := res.(type) { + case *CallToolResult: + if r == nil { + return nil + } + return r.InputRequests + case *GetPromptResult: + if r == nil { + return nil + } + return r.InputRequests + case *ReadResourceResult: + if r == nil { + return nil + } + return r.InputRequests + } + return nil +} + +func fulfillServerInputRequests(ctx context.Context, ss *ServerSession, requests InputRequestMap) (InputResponseMap, error) { + g, ctx := errgroup.WithContext(ctx) + var mu sync.Mutex + responses := make(InputResponseMap, len(requests)) + for id, ir := range requests { + g.Go(func() error { + resp, err := fulfillServerInputRequest(ctx, ss, ir) + if err != nil { + return fmt.Errorf("fulfilling input request %q: %w", id, err) + } + mu.Lock() + responses[id] = resp + mu.Unlock() + return nil + }) + } + if err := g.Wait(); err != nil { + return nil, fmt.Errorf("MRTR: %w", err) + } + return responses, nil +} + +func fulfillServerInputRequest(ctx context.Context, ss *ServerSession, ir InputRequest) (InputResponse, error) { + switch p := ir.(type) { + case *ElicitParams: + return ss.Elicit(ctx, p) + case *CreateMessageParams: + return ss.CreateMessageWithTools(ctx, createMessageParamsToWithTools(p)) + case *CreateMessageWithToolsParams: + return ss.CreateMessageWithTools(ctx, p) + case *ListRootsParams: + return ss.ListRoots(ctx, p) + default: + return nil, fmt.Errorf("unknown input request type: %T", ir) + } +} + +func createMessageParamsToWithTools(p *CreateMessageParams) *CreateMessageWithToolsParams { + var msgs []*SamplingMessageV2 + for _, m := range p.Messages { + msgs = append(msgs, &SamplingMessageV2{Content: []Content{m.Content}, Role: m.Role}) + } + return &CreateMessageWithToolsParams{ + Meta: p.Meta, + IncludeContext: p.IncludeContext, + MaxTokens: p.MaxTokens, + Messages: msgs, + Metadata: p.Metadata, + ModelPreferences: p.ModelPreferences, + StopSequences: 
p.StopSequences, + SystemPrompt: p.SystemPrompt, + Temperature: p.Temperature, + } +} + +func mrtrInputRequests(res Result) InputRequestMap { + if res == nil { + return nil + } + switch r := res.(type) { + case *CallToolResult: + if r == nil || !r.NeedsInput() { + return nil + } + return r.InputRequests + case *GetPromptResult: + if r == nil || !r.NeedsInput() { + return nil + } + return r.InputRequests + case *ReadResourceResult: + if r == nil || !r.NeedsInput() { + return nil + } + return r.InputRequests + } + return nil +} + +func mrtrRequestState(res Result) string { + switch r := res.(type) { + case *CallToolResult: + return r.RequestState + case *GetPromptResult: + return r.RequestState + case *ReadResourceResult: + return r.RequestState + } + return "" +} + +func setMRTRRetryParams(req Request, responses InputResponseMap, state string) { + switch p := req.GetParams().(type) { + case *CallToolParams: + p.InputResponses = responses + p.RequestState = state + case *CallToolParamsRaw: + p.InputResponses = responses + p.RequestState = state + case *GetPromptParams: + p.InputResponses = responses + p.RequestState = state + case *ReadResourceParams: + p.InputResponses = responses + p.RequestState = state + } +} + +func fulfillInputRequests(ctx context.Context, cs *ClientSession, requests InputRequestMap) (InputResponseMap, error) { + g, ctx := errgroup.WithContext(ctx) + var mu sync.Mutex + responses := make(InputResponseMap, len(requests)) + for id, ir := range requests { + g.Go(func() error { + resp, err := fulfillInputRequest(ctx, cs, ir) + if err != nil { + return fmt.Errorf("fulfilling input request %q: %w", id, err) + } + mu.Lock() + responses[id] = resp + mu.Unlock() + return nil + }) + } + if err := g.Wait(); err != nil { + return nil, fmt.Errorf("MRTR: %w", err) + } + return responses, nil +} + +func fulfillInputRequest(ctx context.Context, cs *ClientSession, ir InputRequest) (InputResponse, error) { + switch p := ir.(type) { + case *ElicitParams: + return 
cs.client.elicit(ctx, newClientRequest(cs, p)) + case *CreateMessageParams: + return cs.client.createMessage(ctx, &CreateMessageWithToolsRequest{Session: cs, Params: createMessageParamsToWithTools(p)}) + case *CreateMessageWithToolsParams: + return cs.client.createMessage(ctx, &CreateMessageWithToolsRequest{Session: cs, Params: p}) + case *ListRootsParams: + return cs.client.listRoots(ctx, newClientRequest(cs, p)) + default: + return nil, fmt.Errorf("unknown input request type: %T", ir) + } +} diff --git a/mcp/mrtr_test.go b/mcp/mrtr_test.go new file mode 100644 index 00000000..26bf689a --- /dev/null +++ b/mcp/mrtr_test.go @@ -0,0 +1,524 @@ +// Copyright 2025 The Go MCP SDK Authors. All rights reserved. +// Use of this source code is governed by the license +// that can be found in the LICENSE file. + +package mcp + +import ( + "context" + "fmt" + "slices" + "testing" + + "github.com/google/go-cmp/cmp" + "github.com/google/jsonschema-go/jsonschema" +) + +func TestMRTR_ManualRetry(t *testing.T) { + type deployResult struct { + Deployed bool `json:"deployed"` + Reason string `json:"reason,omitempty"` + } + + orig := supportedProtocolVersions + supportedProtocolVersions = append(slices.Clone(orig), protocolVersion20260630) + t.Cleanup(func() { supportedProtocolVersions = orig }) + + ctx := context.Background() + + srv := NewServer(testImpl, nil) + AddTool(srv, &Tool{Name: "deploy"}, func(ctx context.Context, req *CallToolRequest, input struct{}) (*CallToolResult, *deployResult, error) { + if len(req.Params.InputResponses) == 0 { + return &CallToolResult{ + InputRequests: InputRequestMap{"confirm": &ElicitParams{Message: "Deploy to production?"}}, + RequestState: "deployment-123", + }, nil, nil + } + + resp, ok := req.Params.InputResponses["confirm"] + if !ok { + return &CallToolResult{ + InputRequests: InputRequestMap{"confirm": &ElicitParams{Message: "Please confirm (retry)"}}, + }, nil, nil + } + + if req.Params.RequestState == "" { + return &CallToolResult{}, 
&deployResult{Deployed: false, Reason: "no_state"}, nil + } + if elicitResult := resp.(*ElicitResult); elicitResult != nil && elicitResult.Action != "accept" { + return &CallToolResult{}, &deployResult{Deployed: false, Reason: "cancelled"}, nil + } + + return &CallToolResult{}, &deployResult{Deployed: true}, nil + }) + + conn := mustConnectMRTR(t, srv, &ClientOptions{ + MRTR: &MRTROptions{Disabled: true}, + }) + + // Round 1: initiate deployment + res, err := conn.CallTool(ctx, &CallToolParams{Name: "deploy"}) + if err != nil { + t.Fatalf("CallTool() error = %v", err) + } + if !res.NeedsInput() { + t.Fatal("NeedsInput() = false, want true") + } + if got := len(res.InputRequests); got != 1 { + t.Fatalf("len(res.InputRequests) = %d, want 1", got) + } + if _, ok := res.InputRequests["confirm"].(*ElicitParams); !ok { + t.Fatalf("res.InputRequests[confirm] type = %T, want *ElicitParams", res.InputRequests["confirm"]) + } + + // Round 2: retry with confirmation + res, err = conn.CallTool(ctx, &CallToolParams{ + Name: "deploy", + InputResponses: InputResponseMap{ + "confirm": &ElicitResult{Action: "accept", Content: map[string]any{"ok": true}}, + }, + RequestState: res.RequestState, + }) + if err != nil { + t.Fatalf("CallTool() follow-up error = %v", err) + } + if res.NeedsInput() { + t.Fatal("NeedsInput() = true after follow-up, want false") + } + + if diff := cmp.Diff(map[string]any{"deployed": true}, res.StructuredContent, ctrCmpOpts...); diff != "" { + t.Errorf("result mismatch (-want +got):\n%s", diff) + } +} + +func TestMRTR_AutoRetry(t *testing.T) { + orig := supportedProtocolVersions + supportedProtocolVersions = append(slices.Clone(orig), protocolVersion20260630) + t.Cleanup(func() { supportedProtocolVersions = orig }) + + tests := []struct { + name string + inputRequests InputRequestMap + wantResult map[string]any + }{ + { + name: "elicit", + inputRequests: InputRequestMap{ + "confirm": &ElicitParams{Message: "Deploy?"}, + }, + wantResult: map[string]any{"ids": 
[]any{"confirm"}}, + }, + { + name: "createMessage", + inputRequests: InputRequestMap{ + "summarize": &CreateMessageParams{ + Messages: []*SamplingMessage{{Role: "user", Content: &TextContent{Text: "summarize"}}}, + MaxTokens: 100, + }, + }, + wantResult: map[string]any{"ids": []any{"summarize"}}, + }, + { + name: "listRoots", + inputRequests: InputRequestMap{ + "roots": &ListRootsParams{}, + }, + wantResult: map[string]any{"ids": []any{"roots"}}, + }, + { + name: "all three", + inputRequests: InputRequestMap{ + "confirm": &ElicitParams{Message: "OK?"}, + "draft": &CreateMessageParams{ + Messages: []*SamplingMessage{{Role: "user", Content: &TextContent{Text: "write"}}}, + MaxTokens: 50, + }, + "roots": &ListRootsParams{}, + }, + wantResult: map[string]any{"ids": []any{"confirm", "draft", "roots"}}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ctx := context.Background() + + srv := NewServer(testImpl, nil) + inputRequests := tt.inputRequests + AddTool(srv, &Tool{Name: "act"}, func(ctx context.Context, req *CallToolRequest, input struct{}) (*CallToolResult, any, error) { + if len(req.Params.InputResponses) == 0 { + return &CallToolResult{ + InputRequests: inputRequests, + RequestState: "state-1", + }, nil, nil + } + // Collect the IDs of fulfilled responses. 
+ var ids []string + for id := range req.Params.InputResponses { + ids = append(ids, id) + } + slices.Sort(ids) + return &CallToolResult{}, map[string]any{"ids": ids}, nil + }) + + conn := mustConnectMRTR(t, srv, &ClientOptions{ + ElicitationHandler: func(_ context.Context, req *ElicitRequest) (*ElicitResult, error) { + return &ElicitResult{Action: "accept"}, nil + }, + CreateMessageHandler: func(_ context.Context, req *CreateMessageRequest) (*CreateMessageResult, error) { + return &CreateMessageResult{ + Model: "test-model", + Role: "assistant", + Content: &TextContent{Text: "response"}, + }, nil + }, + }) + conn.client.AddRoots(&Root{URI: "file:///workspace", Name: "workspace"}) + + res, err := conn.CallTool(ctx, &CallToolParams{Name: "act"}) + if err != nil { + t.Fatalf("CallTool() error = %v", err) + } + if res.NeedsInput() { + t.Fatal("NeedsInput() = true after auto-retry, want false") + } + + // Sort the expected IDs for stable comparison. + if wantIDs, ok := tt.wantResult["ids"].([]any); ok { + slices.SortFunc(wantIDs, func(a, b any) int { + if a.(string) < b.(string) { + return -1 + } + return 1 + }) + } + + if diff := cmp.Diff(tt.wantResult, res.StructuredContent, ctrCmpOpts...); diff != "" { + t.Errorf("result mismatch (-want +got):\n%s", diff) + } + }) + } +} + +func TestMRTR_MaxRetries(t *testing.T) { + orig := supportedProtocolVersions + supportedProtocolVersions = append(slices.Clone(orig), protocolVersion20260630) + t.Cleanup(func() { supportedProtocolVersions = orig }) + + ctx := context.Background() + + srv := NewServer(testImpl, nil) + AddTool(srv, &Tool{Name: "loop"}, func(ctx context.Context, req *CallToolRequest, input struct{}) (*CallToolResult, any, error) { + return &CallToolResult{ + InputRequests: InputRequestMap{"confirm": &ElicitParams{Message: "Again?"}}, + RequestState: "loop-state", + }, nil, nil + }) + + conn := mustConnectMRTR(t, srv, &ClientOptions{ + ElicitationHandler: func(_ context.Context, req *ElicitRequest) (*ElicitResult, 
error) { + return &ElicitResult{Action: "accept"}, nil + }, + MRTR: &MRTROptions{MaxRetries: 2}, + }) + + _, err := conn.CallTool(ctx, &CallToolParams{Name: "loop"}) + if err == nil { + t.Fatal("CallTool() err = nil, want error for exceeded max retries") + } +} + +func TestMRTR_ServerMiddleware(t *testing.T) { + // mrtrToolHandler returns a ToolHandler (plain, non-generic) that requests + // the given inputRequests on the first call and returns the fulfilled + // response IDs on the second. + mrtrToolHandler := func(inputRequests InputRequestMap) ToolHandler { + return func(ctx context.Context, req *CallToolRequest) (*CallToolResult, error) { + if len(req.Params.InputResponses) == 0 { + return &CallToolResult{ + InputRequests: inputRequests, + RequestState: "state-1", + }, nil + } + var ids []string + for id := range req.Params.InputResponses { + ids = append(ids, id) + } + slices.Sort(ids) + content := &TextContent{Text: fmt.Sprintf("%v", ids)} + return &CallToolResult{Content: []Content{content}}, nil + } + } + + tests := []struct { + name string + inputRequests InputRequestMap + wantText string + }{ + { + name: "elicit via ToolHandler", + inputRequests: InputRequestMap{ + "confirm": &ElicitParams{Message: "Sure?"}, + }, + wantText: "[confirm]", + }, + { + name: "elicit and listRoots via ToolHandler", + inputRequests: InputRequestMap{ + "confirm": &ElicitParams{Message: "OK?"}, + "roots": &ListRootsParams{}, + }, + wantText: "[confirm roots]", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ctx := context.Background() + + srv := NewServer(testImpl, nil) + srv.AddTool( + &Tool{Name: "act", InputSchema: &jsonschema.Schema{Type: "object"}}, + mrtrToolHandler(tt.inputRequests), + ) + + // Connect with an OLD protocol version where MRTR is not supported. + // The server middleware should handle it transparently. 
+ st, ct := NewInMemoryTransports() + ss, err := srv.Connect(t.Context(), st, nil) + if err != nil { + t.Fatalf("server.Connect() error = %v", err) + } + t.Cleanup(func() { _ = ss.Close() }) + + c := NewClient(testImpl, &ClientOptions{ + MRTR: &MRTROptions{Disabled: true}, + ElicitationHandler: func(_ context.Context, req *ElicitRequest) (*ElicitResult, error) { + return &ElicitResult{Action: "accept"}, nil + }, + }) + c.AddRoots(&Root{URI: "file:///workspace", Name: "workspace"}) + cs, err := c.Connect(t.Context(), ct, &ClientSessionOptions{}) + if err != nil { + t.Fatalf("client.Connect() error = %v", err) + } + t.Cleanup(func() { _ = cs.Close() }) + + res, err := cs.CallTool(ctx, &CallToolParams{Name: "act"}) + if err != nil { + t.Fatalf("CallTool() error = %v", err) + } + if got := res.Content[0].(*TextContent).Text; got != tt.wantText { + t.Errorf("result text = %q, want %q", got, tt.wantText) + } + }) + } +} + +func TestMRTR_GetPrompt_AutoRetry(t *testing.T) { + orig := supportedProtocolVersions + supportedProtocolVersions = append(slices.Clone(orig), protocolVersion20260630) + t.Cleanup(func() { supportedProtocolVersions = orig }) + + ctx := context.Background() + + srv := NewServer(testImpl, nil) + srv.AddPrompt(&Prompt{Name: "review"}, func(_ context.Context, req *GetPromptRequest) (*GetPromptResult, error) { + if len(req.Params.InputResponses) == 0 { + return &GetPromptResult{ + InputRequests: InputRequestMap{"confirm": &ElicitParams{Message: "Include sensitive data?"}}, + RequestState: "prompt-state", + }, nil + } + return &GetPromptResult{ + Description: "Code review prompt", + Messages: []*PromptMessage{{Role: "user", Content: &TextContent{Text: "review this code"}}}, + }, nil + }) + + conn := mustConnectMRTR(t, srv, &ClientOptions{ + ElicitationHandler: func(_ context.Context, _ *ElicitRequest) (*ElicitResult, error) { + return &ElicitResult{Action: "accept"}, nil + }, + }) + + res, err := conn.GetPrompt(ctx, &GetPromptParams{Name: "review"}) + if err 
!= nil { + t.Fatalf("GetPrompt() error = %v", err) + } + if res.NeedsInput() { + t.Fatal("NeedsInput() = true after auto-retry, want false") + } + if len(res.Messages) != 1 { + t.Fatalf("len(res.Messages) = %d, want 1", len(res.Messages)) + } + if got := res.Messages[0].Content.(*TextContent).Text; got != "review this code" { + t.Errorf("message text = %q, want %q", got, "review this code") + } +} + +func TestMRTR_GetPrompt_ManualRetry(t *testing.T) { + orig := supportedProtocolVersions + supportedProtocolVersions = append(slices.Clone(orig), protocolVersion20260630) + t.Cleanup(func() { supportedProtocolVersions = orig }) + + ctx := context.Background() + + srv := NewServer(testImpl, nil) + srv.AddPrompt(&Prompt{Name: "review"}, func(_ context.Context, req *GetPromptRequest) (*GetPromptResult, error) { + if len(req.Params.InputResponses) == 0 { + return &GetPromptResult{ + InputRequests: InputRequestMap{"confirm": &ElicitParams{Message: "Include sensitive data?"}}, + RequestState: "prompt-state", + }, nil + } + return &GetPromptResult{ + Description: "Code review prompt", + Messages: []*PromptMessage{{Role: "user", Content: &TextContent{Text: "review this code"}}}, + }, nil + }) + + conn := mustConnectMRTR(t, srv, &ClientOptions{ + MRTR: &MRTROptions{Disabled: true}, + }) + + res, err := conn.GetPrompt(ctx, &GetPromptParams{Name: "review"}) + if err != nil { + t.Fatalf("GetPrompt() error = %v", err) + } + if !res.NeedsInput() { + t.Fatal("NeedsInput() = false, want true") + } + if _, ok := res.InputRequests["confirm"].(*ElicitParams); !ok { + t.Fatalf("InputRequests[confirm] type = %T, want *ElicitParams", res.InputRequests["confirm"]) + } + + res, err = conn.GetPrompt(ctx, &GetPromptParams{ + Name: "review", + InputResponses: InputResponseMap{"confirm": &ElicitResult{Action: "accept"}}, + RequestState: res.RequestState, + }) + if err != nil { + t.Fatalf("GetPrompt() follow-up error = %v", err) + } + if res.NeedsInput() { + t.Fatal("NeedsInput() = true after 
follow-up, want false") + } + if len(res.Messages) != 1 { + t.Fatalf("len(res.Messages) = %d, want 1", len(res.Messages)) + } +} + +func TestMRTR_ReadResource_AutoRetry(t *testing.T) { + orig := supportedProtocolVersions + supportedProtocolVersions = append(slices.Clone(orig), protocolVersion20260630) + t.Cleanup(func() { supportedProtocolVersions = orig }) + + ctx := context.Background() + + srv := NewServer(testImpl, nil) + srv.AddResource(&Resource{URI: "test://data", Name: "data"}, func(_ context.Context, req *ReadResourceRequest) (*ReadResourceResult, error) { + if len(req.Params.InputResponses) == 0 { + return &ReadResourceResult{ + InputRequests: InputRequestMap{"auth": &ElicitParams{Message: "Authenticate?"}}, + RequestState: "resource-state", + }, nil + } + return &ReadResourceResult{ + Contents: []*ResourceContents{{URI: "test://data", Text: "resource data"}}, + }, nil + }) + + conn := mustConnectMRTR(t, srv, &ClientOptions{ + ElicitationHandler: func(_ context.Context, _ *ElicitRequest) (*ElicitResult, error) { + return &ElicitResult{Action: "accept"}, nil + }, + }) + + res, err := conn.ReadResource(ctx, &ReadResourceParams{URI: "test://data"}) + if err != nil { + t.Fatalf("ReadResource() error = %v", err) + } + if res.NeedsInput() { + t.Fatal("NeedsInput() = true after auto-retry, want false") + } + if len(res.Contents) != 1 { + t.Fatalf("len(res.Contents) = %d, want 1", len(res.Contents)) + } + if got := res.Contents[0].Text; got != "resource data" { + t.Errorf("resource text = %q, want %q", got, "resource data") + } +} + +func TestMRTR_ReadResource_ManualRetry(t *testing.T) { + orig := supportedProtocolVersions + supportedProtocolVersions = append(slices.Clone(orig), protocolVersion20260630) + t.Cleanup(func() { supportedProtocolVersions = orig }) + + ctx := context.Background() + + srv := NewServer(testImpl, nil) + srv.AddResource(&Resource{URI: "test://data", Name: "data"}, func(_ context.Context, req *ReadResourceRequest) (*ReadResourceResult, 
error) { + if len(req.Params.InputResponses) == 0 { + return &ReadResourceResult{ + InputRequests: InputRequestMap{"auth": &ElicitParams{Message: "Authenticate?"}}, + RequestState: "resource-state", + }, nil + } + return &ReadResourceResult{ + Contents: []*ResourceContents{{URI: "test://data", Text: "resource data"}}, + }, nil + }) + + conn := mustConnectMRTR(t, srv, &ClientOptions{ + MRTR: &MRTROptions{Disabled: true}, + }) + + res, err := conn.ReadResource(ctx, &ReadResourceParams{URI: "test://data"}) + if err != nil { + t.Fatalf("ReadResource() error = %v", err) + } + if !res.NeedsInput() { + t.Fatal("NeedsInput() = false, want true") + } + if _, ok := res.InputRequests["auth"].(*ElicitParams); !ok { + t.Fatalf("InputRequests[auth] type = %T, want *ElicitParams", res.InputRequests["auth"]) + } + + res, err = conn.ReadResource(ctx, &ReadResourceParams{ + URI: "test://data", + InputResponses: InputResponseMap{"auth": &ElicitResult{Action: "accept"}}, + RequestState: res.RequestState, + }) + if err != nil { + t.Fatalf("ReadResource() follow-up error = %v", err) + } + if res.NeedsInput() { + t.Fatal("NeedsInput() = true after follow-up, want false") + } + if len(res.Contents) != 1 { + t.Fatalf("len(res.Contents) = %d, want 1", len(res.Contents)) + } +} + +func mustConnectMRTR(t *testing.T, s *Server, clientOpts *ClientOptions) *ClientSession { + t.Helper() + st, ct := NewInMemoryTransports() + ss, err := s.Connect(t.Context(), st, nil) + if err != nil { + t.Fatalf("server.Connect() error = %v", err) + } + t.Cleanup(func() { + _ = ss.Close() + }) + + c := NewClient(testImpl, clientOpts) + cs, err := c.Connect(t.Context(), ct, &ClientSessionOptions{protocolVersion: protocolVersion20260630}) + if err != nil { + t.Fatalf("client.Connect() error = %v", err) + } + t.Cleanup(func() { + _ = cs.Close() + }) + return cs +} diff --git a/mcp/protocol.go b/mcp/protocol.go index 0af1ec57..df5a1a2c 100644 --- a/mcp/protocol.go +++ b/mcp/protocol.go @@ -13,6 +13,172 @@ import ( 
"github.com/modelcontextprotocol/go-sdk/internal/mcpgodebug" ) +// ResultType indicates whether a result is complete or requires further input +// from the client via the MRTR (Multi Round-Trip Requests) protocol. +type ResultType string + +const ( + // ResultTypeComplete indicates the result is final. + // This is the default when ResultType is empty. + ResultTypeComplete ResultType = "complete" + + // ResultTypeInputRequired indicates the server needs additional client + // input before it can complete the request. The client should fulfill the + // InputRequests and retry the call with the responses. + ResultTypeInputRequired ResultType = "input_required" +) + +// InputRequest is a sealed interface for parameters that a server can include +// in an MRTR input-required result. Implementations are [*ElicitParams], +// [*CreateMessageParams], and [*ListRootsParams]. +type InputRequest interface{ isInputRequest() } + +// InputRequestMap maps server-assigned request IDs to [InputRequest] values. +// It is used in result types to tell the client what input the server needs. 
+type InputRequestMap map[string]InputRequest + +func (m InputRequestMap) MarshalJSON() ([]byte, error) { + type wire struct { + Method string `json:"method"` + Params InputRequest `json:"params,omitempty"` + } + typeToMethod := func(v InputRequest) (string, error) { + switch v.(type) { + case *ElicitParams: + return methodElicit, nil + case *CreateMessageParams, *CreateMessageWithToolsParams: + return methodCreateMessage, nil + case *ListRootsParams: + return methodListRoots, nil + default: + return "", fmt.Errorf("unsupported type: %T", v) + } + } + converted := map[string]*wire{} + for k, v := range m { + method, err := typeToMethod(v) + if err != nil { + return nil, err + } + converted[k] = &wire{Method: method, Params: v} + } + return json.Marshal(converted) +} + +func (m *InputRequestMap) UnmarshalJSON(data []byte) error { + type raw struct { + Method string `json:"method"` + Params json.RawMessage `json:"params"` + } + var rawMap map[string]*raw + if err := json.Unmarshal(data, &rawMap); err != nil { + return err + } + result := make(InputRequestMap, len(rawMap)) + for k, raw := range rawMap { + switch raw.Method { + case methodElicit: + var p ElicitParams + if err := json.Unmarshal(raw.Params, &p); err != nil { + return err + } + result[k] = &p + case methodCreateMessage: + var p CreateMessageWithToolsParams + if err := json.Unmarshal(raw.Params, &p); err != nil { + return err + } + result[k] = &p + case methodListRoots: + var p ListRootsParams + if err := json.Unmarshal(raw.Params, &p); err != nil { + return err + } + result[k] = &p + default: + return fmt.Errorf("unsupported InputRequest method: %q", raw.Method) + } + } + *m = result + return nil +} + +// InputResponse is a sealed interface for results that a client sends back +// when fulfilling an MRTR input request. Implementations are [*ElicitResult], +// [*CreateMessageResult], and [*ListRootsResult]. 
+type InputResponse interface{ isInputResponse() } + +// InputResponseMap maps request IDs (from [InputRequestMap]) to [InputResponse] +// values. It is used in params types when retrying a call after an +// input-required result. +type InputResponseMap map[string]InputResponse + +func (m InputResponseMap) MarshalJSON() ([]byte, error) { + type wire struct { + Method string `json:"method"` + Result InputResponse `json:"result,omitempty"` + } + typeToMethod := func(v InputResponse) (string, error) { + switch v.(type) { + case *ElicitResult: + return methodElicit, nil + case *CreateMessageResult, *CreateMessageWithToolsResult: + return methodCreateMessage, nil + case *ListRootsResult: + return methodListRoots, nil + default: + return "", fmt.Errorf("unsupported type: %T", v) + } + } + converted := map[string]*wire{} + for k, v := range m { + method, err := typeToMethod(v) + if err != nil { + return nil, err + } + converted[k] = &wire{Method: method, Result: v} + } + return json.Marshal(converted) +} + +func (m *InputResponseMap) UnmarshalJSON(data []byte) error { + type raw struct { + Method string `json:"method"` + Result json.RawMessage `json:"result"` + } + var rawMap map[string]*raw + if err := json.Unmarshal(data, &rawMap); err != nil { + return err + } + result := make(InputResponseMap, len(rawMap)) + for k, raw := range rawMap { + switch raw.Method { + case methodElicit: + var p ElicitResult + if err := json.Unmarshal(raw.Result, &p); err != nil { + return err + } + result[k] = &p + case methodCreateMessage: + var p CreateMessageWithToolsResult + if err := json.Unmarshal(raw.Result, &p); err != nil { + return err + } + result[k] = &p + case methodListRoots: + var p ListRootsResult + if err := json.Unmarshal(raw.Result, &p); err != nil { + return err + } + result[k] = &p + default: + return fmt.Errorf("unsupported InputResponse method: %q", raw.Method) + } + } + *m = result + return nil +} + // Optional annotations for the client. 
The client can use annotations to inform // how objects are used or displayed. type Annotations struct { @@ -46,6 +212,14 @@ type CallToolParams struct { // Arguments holds the tool arguments. It can hold any value that can be // marshaled to JSON. Arguments any `json:"arguments,omitempty"` + + // InputResponses maps input request IDs to responses, provided when + // retrying a call after receiving a result with ResultType + // ResultTypeInputRequired. + InputResponses InputResponseMap `json:"inputResponses,omitempty"` + // RequestState is the opaque state from the previous input-required result. + // The client must echo this back when retrying. + RequestState string `json:"requestState,omitempty"` } // CallToolParamsRaw is passed to tool handlers on the server. Its arguments @@ -61,6 +235,14 @@ type CallToolParamsRaw struct { // is the responsibility of the tool handler to unmarshal and validate the // Arguments (see [AddTool]). Arguments json.RawMessage `json:"arguments,omitempty"` + + // InputResponses maps input request IDs to responses, provided when + // retrying a call after receiving a result with ResultType + // ResultTypeInputRequired. + InputResponses InputResponseMap `json:"inputResponses,omitempty"` + // RequestState is the opaque state from the previous input-required result. + // The client must echo this back when retrying. + RequestState string `json:"requestState,omitempty"` } // A CallToolResult is the server's response to a tool call. @@ -107,6 +289,24 @@ type CallToolResult struct { // the Content field. IsError bool `json:"isError,omitempty"` + // InputRequests is a map of server-assigned IDs to input requests. + // Populated only when ResultType is ResultTypeInputRequired. + // The client must fulfill these and echo the IDs back in InputResponses + // when retrying the call. 
+ InputRequests InputRequestMap `json:"inputRequests,omitempty"` + + // RequestState is an opaque string the client must echo back when + // retrying after an input-required result. Servers use this to carry + // context between independent requests. + // + // Unauthenticated servers must encrypt, sign and verify this value. + RequestState string `json:"requestState,omitempty"` + + // ResultType indicates whether this result is complete or requires further + // client input. Empty or ResultTypeComplete means the call succeeded + // normally. ResultTypeInputRequired means the client should fulfill the + // InputRequests and retry the call. + resultType ResultType // The error passed to setError, if any. // It is not marshaled, and therefore it is only visible on the server. // Its only use is in server sending middleware, where it can be accessed @@ -145,13 +345,34 @@ func (r *CallToolResult) GetError() error { func (*CallToolResult) isResult() {} -// UnmarshalJSON handles the unmarshalling of content into the Content -// interface. +func (r *CallToolResult) setResultType(rt ResultType) { r.resultType = rt } +func (r *CallToolResult) inputRequests() map[string]InputRequest { return r.InputRequests } +func (r *CallToolResult) hasContent() bool { + return len(r.Content) > 0 || r.StructuredContent != nil +} + +// NeedsInput reports whether this result requires further client input. +// This is true when the server returned ResultType "input_required". +// When NeedsInput returns true, check InputRequests for the set of +// requests the server needs fulfilled before retrying the call. +// An empty InputRequests with NeedsInput true indicates load-shedding. 
+func (r *CallToolResult) NeedsInput() bool { return r.resultType == ResultTypeInputRequired } + +func (x *CallToolResult) MarshalJSON() ([]byte, error) { + type res CallToolResult // avoid recursion + type wire struct { + res + ResultType ResultType `json:"resultType,omitempty"` + } + return json.Marshal(wire{res: res(*x), ResultType: x.resultType}) +} + func (x *CallToolResult) UnmarshalJSON(data []byte) error { type res CallToolResult // avoid recursion var wire struct { res - Content []*wireContent `json:"content"` + Content []*wireContent `json:"content"` + ResultType ResultType `json:"resultType"` } if err := internaljson.Unmarshal(data, &wire); err != nil { return err @@ -160,6 +381,7 @@ func (x *CallToolResult) UnmarshalJSON(data []byte) error { if wire.res.Content, err = contentsFromWire(wire.Content, nil); err != nil { return err } + wire.res.resultType = wire.ResultType *x = CallToolResult(wire.res) return nil } @@ -422,6 +644,7 @@ type CreateMessageParams struct { } func (x *CreateMessageParams) isParams() {} +func (x *CreateMessageParams) isInputRequest() {} func (x *CreateMessageParams) GetProgressToken() any { return getProgressToken(x) } func (x *CreateMessageParams) SetProgressToken(t any) { setProgressToken(x, t) } @@ -448,6 +671,7 @@ type CreateMessageWithToolsParams struct { } func (x *CreateMessageWithToolsParams) isParams() {} +func (x *CreateMessageWithToolsParams) isInputRequest() {} func (x *CreateMessageWithToolsParams) GetProgressToken() any { return getProgressToken(x) } func (x *CreateMessageWithToolsParams) SetProgressToken(t any) { setProgressToken(x, t) } @@ -547,7 +771,8 @@ type CreateMessageResult struct { StopReason string `json:"stopReason,omitempty"` } -func (*CreateMessageResult) isResult() {} +func (*CreateMessageResult) isResult() {} +func (*CreateMessageResult) isInputResponse() {} func (r *CreateMessageResult) UnmarshalJSON(data []byte) error { type result CreateMessageResult // avoid recursion var wire struct { @@ -592,7 
+817,8 @@ var createMessageWithToolsResultAllow = map[string]bool{ "tool_use": true, } -func (*CreateMessageWithToolsResult) isResult() {} +func (*CreateMessageWithToolsResult) isResult() {} +func (*CreateMessageWithToolsResult) isInputResponse() {} // MarshalJSON marshals the result. When Content has a single element, it is // marshaled as a single object for compatibility with pre-2025-11-25 @@ -651,6 +877,13 @@ type GetPromptParams struct { Arguments map[string]string `json:"arguments,omitempty"` // The name of the prompt or prompt template. Name string `json:"name"` + + // InputResponses maps input request IDs to responses, provided when + // retrying a call after receiving a result with ResultType + // ResultTypeInputRequired. + InputResponses InputResponseMap `json:"inputResponses,omitempty"` + // RequestState is the opaque state from the previous input-required result. + RequestState string `json:"requestState,omitempty"` } func (x *GetPromptParams) isParams() {} @@ -665,10 +898,52 @@ type GetPromptResult struct { // An optional description for the prompt. Description string `json:"description,omitempty"` Messages []*PromptMessage `json:"messages"` + + // InputRequests is populated when ResultType is ResultTypeInputRequired. + // See [CallToolResult.InputRequests]. + InputRequests InputRequestMap `json:"inputRequests,omitempty"` + // RequestState is the opaque state for MRTR retries. + // See [CallToolResult.RequestState]. + RequestState string `json:"requestState,omitempty"` + + // ResultType indicates whether this result is complete or requires further + // client input. See [CallToolResult.ResultType] for details. 
+ resultType ResultType } func (*GetPromptResult) isResult() {} +func (r *GetPromptResult) setResultType(rt ResultType) { r.resultType = rt } +func (r *GetPromptResult) inputRequests() map[string]InputRequest { return r.InputRequests } +func (r *GetPromptResult) hasContent() bool { return len(r.Messages) > 0 } + +// NeedsInput reports whether this result requires further client input. +// See [CallToolResult.NeedsInput] for details. +func (r *GetPromptResult) NeedsInput() bool { return r.resultType == ResultTypeInputRequired } + +func (x *GetPromptResult) MarshalJSON() ([]byte, error) { + type res GetPromptResult + type wire struct { + res + ResultType ResultType `json:"resultType,omitempty"` + } + return json.Marshal(wire{res: res(*x), ResultType: x.resultType}) +} + +func (x *GetPromptResult) UnmarshalJSON(data []byte) error { + type res GetPromptResult + var wire struct { + res + ResultType ResultType `json:"resultType"` + } + if err := internaljson.Unmarshal(data, &wire); err != nil { + return err + } + wire.res.resultType = wire.ResultType + *x = GetPromptResult(wire.res) + return nil +} + // InitializeParams is sent by the client to initialize the session. type InitializeParams struct { // This property is reserved by the protocol to allow clients and servers to @@ -833,6 +1108,7 @@ type ListRootsParams struct { } func (x *ListRootsParams) isParams() {} +func (x *ListRootsParams) isInputRequest() {} func (x *ListRootsParams) GetProgressToken() any { return getProgressToken(x) } func (x *ListRootsParams) SetProgressToken(t any) { setProgressToken(x, t) } @@ -846,7 +1122,8 @@ type ListRootsResult struct { Roots []*Root `json:"roots"` } -func (*ListRootsResult) isResult() {} +func (*ListRootsResult) isResult() {} +func (*ListRootsResult) isInputResponse() {} type ListToolsParams struct { // This property is reserved by the protocol to allow clients and servers to @@ -1086,6 +1363,13 @@ type ReadResourceParams struct { // The URI of the resource to read. 
The URI can use any protocol; it is up to // the server how to interpret it. URI string `json:"uri"` + + // InputResponses maps input request IDs to responses, provided when + // retrying a call after receiving a result with ResultType + // ResultTypeInputRequired. + InputResponses InputResponseMap `json:"inputResponses,omitempty"` + // RequestState is the opaque state from the previous input-required result. + RequestState string `json:"requestState,omitempty"` } func (x *ReadResourceParams) isParams() {} @@ -1098,10 +1382,52 @@ type ReadResourceResult struct { // attach additional metadata to their responses. Meta `json:"_meta,omitempty"` Contents []*ResourceContents `json:"contents"` + + // InputRequests is populated when ResultType is ResultTypeInputRequired. + // See [CallToolResult.InputRequests]. + InputRequests InputRequestMap `json:"inputRequests,omitempty"` + // RequestState is the opaque state for MRTR retries. + // See [CallToolResult.RequestState]. + RequestState string `json:"requestState,omitempty"` + + // ResultType indicates whether this result is complete or requires further + // client input. See [CallToolResult.ResultType] for details. + resultType ResultType } func (*ReadResourceResult) isResult() {} +func (r *ReadResourceResult) setResultType(rt ResultType) { r.resultType = rt } +func (r *ReadResourceResult) inputRequests() map[string]InputRequest { return r.InputRequests } +func (r *ReadResourceResult) hasContent() bool { return len(r.Contents) > 0 } + +// NeedsInput reports whether this result requires further client input. +// See [CallToolResult.NeedsInput] for details. 
+func (r *ReadResourceResult) NeedsInput() bool { return r.resultType == ResultTypeInputRequired } + +func (x *ReadResourceResult) MarshalJSON() ([]byte, error) { + type res ReadResourceResult + type wire struct { + res + ResultType ResultType `json:"resultType,omitempty"` + } + return json.Marshal(wire{res: res(*x), ResultType: x.resultType}) +} + +func (x *ReadResourceResult) UnmarshalJSON(data []byte) error { + type res ReadResourceResult + var wire struct { + res + ResultType ResultType `json:"resultType"` + } + if err := internaljson.Unmarshal(data, &wire); err != nil { + return err + } + wire.res.resultType = wire.ResultType + *x = ReadResourceResult(wire.res) + return nil +} + // A known resource that the server is capable of reading. type Resource struct { // See [specification/2025-06-18/basic/index#general-fields] for notes on _meta @@ -1494,7 +1820,8 @@ type ElicitParams struct { ElicitationID string `json:"elicitationId,omitempty"` } -func (x *ElicitParams) isParams() {} +func (x *ElicitParams) isParams() {} +func (x *ElicitParams) isInputRequest() {} func (x *ElicitParams) GetProgressToken() any { return getProgressToken(x) } func (x *ElicitParams) SetProgressToken(t any) { setProgressToken(x, t) } @@ -1514,7 +1841,8 @@ type ElicitResult struct { Content map[string]any `json:"content,omitempty"` } -func (*ElicitResult) isResult() {} +func (*ElicitResult) isResult() {} +func (*ElicitResult) isInputResponse() {} // ElicitationCompleteParams is sent from the server to the client, informing it that an out-of-band elicitation interaction has completed. 
type ElicitationCompleteParams struct { diff --git a/mcp/server.go b/mcp/server.go index 7526ea7b..c322f23d 100644 --- a/mcp/server.go +++ b/mcp/server.go @@ -187,7 +187,7 @@ func NewServer(impl *Implementation, options *ServerOptions) *Server { opts.Logger = ensureLogger(nil) } - return &Server{ + s := &Server{ impl: impl, opts: opts, prompts: newFeatureSet(func(p *serverPrompt) string { return p.prompt.Name }), @@ -199,6 +199,8 @@ func NewServer(impl *Implementation, options *ServerOptions) *Server { resourceSubscriptions: make(map[string]map[*ServerSession]bool), pendingNotifications: make(map[string]*time.Timer), } + s.AddReceivingMiddleware(serverMRTRMiddleware()) + return s } // AddPrompt adds a [Prompt] to the server, or replaces one with the same name. @@ -370,8 +372,12 @@ func toolForErr[In, Out any](t *Tool, h ToolHandlerFor[In, Out], cache *SchemaCa } // Marshal the output and put the RawMessage in the StructuredContent field. + // Skip when the handler returned input requests (MRTR): content and + // inputRequests are mutually exclusive on the wire. var outval any = out - if elemZero != nil { + if res.InputRequests != nil { + outval = nil + } else if elemZero != nil { // Avoid typed nil, which will serialize as JSON null. // Instead, use the zero value of the unpointered type. 
var z Out @@ -742,7 +748,13 @@ func (s *Server) getPrompt(ctx context.Context, req *GetPromptRequest) (*GetProm Message: fmt.Sprintf("unknown prompt %q", req.Params.Name), } } - return prompt.handler(ctx, req) + res, err := prompt.handler(ctx, req) + if err == nil && res != nil { + if err := handleMRTRResult(req.Session, s.opts.Logger, res); err != nil { + return nil, err + } + } + return res, err } func (s *Server) listTools(_ context.Context, req *ListToolsRequest) (*ListToolsResult, error) { @@ -775,10 +787,15 @@ func (s *Server) callTool(ctx context.Context, req *CallToolRequest) (*CallToolR } } res, err := st.handler(ctx, req) - if err == nil && res != nil && res.Content == nil { - res2 := *res - res2.Content = []Content{} // avoid "null" - res = &res2 + if err == nil && res != nil { + if err := handleMRTRResult(req.Session, s.opts.Logger, res); err != nil { + return nil, err + } + if res.Content == nil && res.resultType != ResultTypeInputRequired { + res2 := *res + res2.Content = []Content{} // avoid "null" + res = &res2 + } } return res, err } @@ -826,6 +843,12 @@ func (s *Server) readResource(ctx context.Context, req *ReadResourceRequest) (*R if err != nil { return nil, err } + if err := handleMRTRResult(req.Session, s.opts.Logger, res); err != nil { + return nil, err + } + if res.resultType == ResultTypeInputRequired { + return res, nil + } if res == nil || res.Contents == nil { return nil, fmt.Errorf("reading resource %s: read handler returned nil information", uri) } diff --git a/mcp/streamable_test.go b/mcp/streamable_test.go index d2e54224..7ca109a6 100644 --- a/mcp/streamable_test.go +++ b/mcp/streamable_test.go @@ -769,6 +769,12 @@ func req(id int64, method string, params any) *jsonrpc.Request { return r } +func completeCallToolResult() *CallToolResult { + r := &CallToolResult{Content: []Content{}} + r.resultType = ResultTypeComplete + return r +} + func resp(id int64, result any, err error) *jsonrpc.Response { return &jsonrpc.Response{ ID: 
jsonrpc2.Int64ID(id), @@ -1968,7 +1974,7 @@ func TestStreamableMcpHeaderValidation(t *testing.T) { }, messages: []jsonrpc.Message{req(2, "tools/call", &CallToolParams{Name: "my-tool"})}, wantStatusCode: http.StatusOK, - wantMessages: []jsonrpc.Message{resp(2, &CallToolResult{Content: []Content{}}, nil)}, + wantMessages: []jsonrpc.Message{resp(2, completeCallToolResult(), nil)}, }, { method: "POST", @@ -2012,7 +2018,7 @@ func TestStreamableMcpHeaderValidation(t *testing.T) { }, messages: []jsonrpc.Message{req(6, "tools/call", &CallToolParams{Name: "my-tool"})}, wantStatusCode: http.StatusOK, - wantMessages: []jsonrpc.Message{resp(6, &CallToolResult{Content: []Content{}}, nil)}, + wantMessages: []jsonrpc.Message{resp(6, completeCallToolResult(), nil)}, }, { method: "POST", @@ -2027,7 +2033,7 @@ func TestStreamableMcpHeaderValidation(t *testing.T) { Arguments: map[string]any{"region": "us-west1", "query": "SELECT 1"}, })}, wantStatusCode: http.StatusOK, - wantMessages: []jsonrpc.Message{resp(7, &CallToolResult{Content: []Content{}}, nil)}, + wantMessages: []jsonrpc.Message{resp(7, completeCallToolResult(), nil)}, }, { method: "POST",