diff --git a/src/stdio/proxy.ts b/src/stdio/proxy.ts index de3ba05..4dbb368 100644 --- a/src/stdio/proxy.ts +++ b/src/stdio/proxy.ts @@ -147,6 +147,11 @@ const getSecurityErrorMessage = (error: unknown): string => { return error instanceof Error ? error.message : 'Fail-Closed: security validation failed.'; }; +const BACKPRESSURE_BYTES_PER_SECOND = parseIntEnv( + process.env['MCP_STDIO_MAX_THROUGHPUT_BPS'], + { fallback: 50 * 1024 * 1024, min: 1024, max: 1024 * 1024 * 1024 }, +); + export const createStdioFirewallProxy = (options: StdioFirewallOptions) => { const input: Readable = options.input ?? process.stdin; const output: Writable = options.output ?? process.stdout; @@ -186,26 +191,75 @@ export const createStdioFirewallProxy = (options: StdioFirewallOptions) => { let stopped = false; let draining = false; let stdioSessionColor: SessionColor = null; + // Throughput limiter state + let stdinBytesInWindow = 0; + let stdinWindowStartMs = Date.now(); + let stdoutBytesInWindow = 0; + let stdoutWindowStartMs = Date.now(); + + const checkThroughputLimit = ( + bytesWritten: number, + direction: 'stdin' | 'stdout', + ): boolean => { + const now = Date.now(); + if (direction === 'stdin') { + if (now - stdinWindowStartMs >= 1000) { + stdinBytesInWindow = 0; + stdinWindowStartMs = now; + } + stdinBytesInWindow += bytesWritten; + return stdinBytesInWindow > BACKPRESSURE_BYTES_PER_SECOND; + } else { + if (now - stdoutWindowStartMs >= 1000) { + stdoutBytesInWindow = 0; + stdoutWindowStartMs = now; + } + stdoutBytesInWindow += bytesWritten; + return stdoutBytesInWindow > BACKPRESSURE_BYTES_PER_SECOND; + } + }; + + const waitForOutputDrain = (): Promise => { + return new Promise((resolve) => { + output.once('drain', () => { + targetInterface?.resume(); + resolve(); + }); + }); + }; const writeRawJson = (message: unknown): void => { try { - output.write(JSON.stringify(message) + '\n'); + const data = JSON.stringify(message) + '\n'; + const flushed = output.write(data); + if (!flushed) { + targetInterface?.pause(); + waitForOutputDrain().catch(() => {}); + } + const byteLen = Buffer.byteLength(data, 'utf8'); + if (checkThroughputLimit(byteLen, 'stdout')) { + targetInterface?.pause(); + setTimeout(() => { + targetInterface?.resume(); + }, 1000 - (Date.now() - stdoutWindowStartMs)); + } } catch {} }; const clearPendingRequest = (requestId: string): PendingRequest | undefined => { const pending = pendingRequests.get(requestId); + pendingRequests.delete(requestId); if (pending) { clearTimeout(pending.timeout); - pendingRequests.delete(requestId); } return pending; }; const failAllPending = (code: number, message: string, data?: unknown): void => { - for (const [key, pending] of pendingRequests.entries()) { + const entries = [...pendingRequests.entries()]; + pendingRequests.clear(); + for (const [, pending] of entries) { clearTimeout(pending.timeout); - pendingRequests.delete(key); writeRawJson(buildJsonRpcErrorResponse(pending.id, code, message, data)); } }; @@ -366,15 +420,27 @@ export const createStdioFirewallProxy = (options: StdioFirewallOptions) => { try { const serializedMessage = JSON.stringify(message) + '\n'; - if (!targetProcess.stdin.write(serializedMessage)) { + const flushed = targetProcess.stdin.write(serializedMessage); + if (!flushed) { auditLog('STDIO_TARGET_BACKPRESSURE', { code: 'STDIO_TARGET_BACKPRESSURE', reason: 'Target stdin reported backpressure.', toolName: tool?.name, pendingRequests: pendingRequests.size, }); + clientInterface?.pause(); + targetProcess.stdin.once('drain', () => { + clientInterface?.resume(); + }); + } + const byteLen = Buffer.byteLength(serializedMessage, 'utf8'); + if (checkThroughputLimit(byteLen, 'stdin')) { + clientInterface?.pause(); + setTimeout(() => { + clientInterface?.resume(); + }, 1000 - (Date.now() - stdinWindowStartMs)); } - } catch (error) { + } catch { if (requestId !== null) clearPendingRequest(String(requestId)); writeRawJson(buildJsonRpcErrorResponse(requestId, -32004, 'Fail-Closed: target process is unavailable.', { code: 'TARGET_UNAVAILABLE' })); }