Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions src/services/mcp/McpHub.ts
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,11 @@ const BaseConfigSchema = z.object({
alwaysAllow: z.array(z.string()).default([]),
watchPaths: z.array(z.string()).optional(), // paths to watch for changes and restart server
disabledTools: z.array(z.string()).default([]),
oauth: z
.object({
clientId: z.string().optional(),
})
.optional(),
})

// Custom error messages for better user feedback
Expand Down Expand Up @@ -807,6 +812,7 @@ export class McpHub {

const authProvider = await McpOAuthClientProvider.create(configInjected.url, this.secretStorage, name, {
skipDiscovery,
clientId: configInjected.oauth?.clientId,
})

// Pre-register the OAuth client so the SDK can skip its own
Expand Down
15 changes: 14 additions & 1 deletion src/services/mcp/McpOAuthClientProvider.ts
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@ export class McpOAuthClientProvider implements OAuthClientProvider {
private readonly _authServerMeta: Record<string, any> | null,
private readonly _resourceIndicator: string | null,
private readonly _clientName: string,
private readonly _staticClientId?: string,
) {}

/**
Expand All @@ -102,7 +103,7 @@ export class McpOAuthClientProvider implements OAuthClientProvider {
serverUrl: string,
secretStorage: SecretStorageService,
serverName?: string,
options?: { skipDiscovery?: boolean },
options?: { skipDiscovery?: boolean; clientId?: string },
): Promise<McpOAuthClientProvider> {
let authServerMeta: Record<string, any> | null = null
let resourceIndicator: string | null = null
Expand Down Expand Up @@ -151,6 +152,7 @@ export class McpOAuthClientProvider implements OAuthClientProvider {
authServerMeta,
resourceIndicator,
serverName || "Roo Code",
options?.clientId,
)
}

Expand Down Expand Up @@ -228,6 +230,17 @@ export class McpOAuthClientProvider implements OAuthClientProvider {
async registerClientIfNeeded(): Promise<void> {
if (this._clientInfo) return // already registered

// If a static clientId was provided (e.g. from mcp.json oauth.clientId),
// use it directly instead of performing Dynamic Client Registration.
// This enables connections to OAuth servers that don't support DCR.
if (this._staticClientId) {
this._clientInfo = {
client_id: this._staticClientId,
redirect_uris: [this.redirectUrl],
}
return
}

// Check if we have a cached client_id from previous registration
const cachedData = await this._secretStorage.getOAuthData(this._serverUrl)
if (cachedData?.client_info) {
Expand Down
45 changes: 45 additions & 0 deletions src/services/mcp/__tests__/McpOAuthClientProvider.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -849,4 +849,49 @@ describe("McpOAuthClientProvider", () => {
await provider.close()
})
})

describe("static clientId support", () => {
it("should use static clientId instead of performing DCR", async () => {
const secretStorage = createMockSecretStorage()
const provider = await McpOAuthClientProvider.create(
"https://example.com/mcp",
secretStorage,
"test-server",
{ clientId: "my-static-client-id" },
)

await provider.registerClientIfNeeded()

const info = await provider.clientInformation()
expect(info?.client_id).toBe("my-static-client-id")
await provider.close()
})

it("should use static clientId even when cached data exists", async () => {
setupCallbackServerMock()
const secretStorage = createMockSecretStorage()

await secretStorage.saveOAuthData("https://example.com/mcp", {
tokens: { access_token: "cached-token", token_type: "Bearer" },
expires_at: Date.now() + 3600_000,
client_info: {
client_id: "cached-client-id",
redirect_uris: ["http://localhost:0/callback"],
},
})

const provider = await McpOAuthClientProvider.create(
"https://example.com/mcp",
secretStorage,
"test-server",
{ clientId: "my-static-client-id" },
)

await provider.registerClientIfNeeded()

const info = await provider.clientInformation()
expect(info?.client_id).toBe("my-static-client-id")
await provider.close()
})
})
})
Loading