From 24ff02557cb0626ab7030e14a92bf8a1fb2aa9b8 Mon Sep 17 00:00:00 2001 From: Bri <34875062+Monkatraz@users.noreply.github.com> Date: Sun, 31 May 2026 15:56:12 -0700 Subject: [PATCH] Check for malformed ctx.from ids in clients --- __tests__/e2e.test.ts | 56 +++++++++++++++++++ package-lock.json | 4 +- package.json | 2 +- protobuf/handshake.ts | 10 +++- router/handshake.ts | 21 ++++--- transport/server.ts | 1 + .../sessionStateMachine/SessionConnected.ts | 9 +++ .../sessionStateMachine/stateMachine.test.ts | 45 ++++++++++++--- transport/transport.test.ts | 2 + 9 files changed, 128 insertions(+), 22 deletions(-) diff --git a/__tests__/e2e.test.ts b/__tests__/e2e.test.ts index d65cefa3..63454e32 100644 --- a/__tests__/e2e.test.ts +++ b/__tests__/e2e.test.ts @@ -1314,5 +1314,61 @@ describe.each(testMatrix())( server, }); }); + + test('validate receives the connecting client id', async () => { + const requestSchema = Type.Object({}); + + interface ParsedMetadata { + seenFrom: string; + } + + const clientTransport = getClientTransport( + 'client', + createClientHandshakeOptions(requestSchema, () => ({})), + ); + const serverTransport = getServerTransport( + 'SERVER', + createServerHandshakeOptions( + requestSchema, + (_metadata, _prev, from) => ({ + seenFrom: from ?? '', + }), + ), + ); + addPostTestCleanup(async () => { + await cleanupTransports([clientTransport, serverTransport]); + }); + + const ServiceSchema = createServiceSchema< + MaybeDisposable, + ParsedMetadata + >(); + const services = { + test: ServiceSchema.define({ + whoami: Procedure.rpc({ + requestInit: Type.Object({}), + responseData: Type.Object({ seenFrom: Type.String() }), + handler: async ({ ctx }) => Ok({ seenFrom: ctx.metadata.seenFrom }), + }), + }), + }; + const server = createServer(serverTransport, services); + const client = createClient( + clientTransport, + serverTransport.clientId, + ); + + const result = await client.test.whoami.rpc({}); + expect(result).toStrictEqual({ + ok: true, + payload: { seenFrom: 'client' }, + }); + + await testFinishesCleanly({ + clientTransports: [clientTransport], + serverTransport, + server, + }); + }); }, ); diff --git a/package-lock.json b/package-lock.json index f3ae81ab..0274cf1f 100644 --- a/package-lock.json +++ b/package-lock.json @@ -1,12 +1,12 @@ { "name": "@replit/river", - "version": "0.217.1", + "version": "0.217.2", "lockfileVersion": 2, "requires": true, "packages": { "": { "name": "@replit/river", - "version": "0.217.1", + "version": "0.217.2", "license": "MIT", "dependencies": { "@bufbuild/protobuf": "^2.11.0", diff --git a/package.json b/package.json index 3c73edf1..78e9e0c6 100644 --- a/package.json +++ b/package.json @@ -1,7 +1,7 @@ { "name": "@replit/river", "description": "It's like tRPC but... with JSON Schema Support, duplex streaming and support for service multiplexing. Transport agnostic!", - "version": "0.217.1", + "version": "0.217.2", "type": "module", "exports": { ".": "./dist/router/index.js", diff --git a/protobuf/handshake.ts b/protobuf/handshake.ts index c474d550..686d5eb8 100644 --- a/protobuf/handshake.ts +++ b/protobuf/handshake.ts @@ -10,7 +10,10 @@ import { type ClientHandshakeOptions, type ServerHandshakeOptions, } from '../router/handshake'; -import { HandshakeErrorCustomHandlerFatalResponseCodes } from '../transport/message'; +import { + HandshakeErrorCustomHandlerFatalResponseCodes, + type TransportClientId, +} from '../transport/message'; import { decodeMessageBytes, encodeMessageBytes } from './shared'; import { Uint8ArrayType } from '../customSchemas'; @@ -27,6 +30,7 @@ type ConstructHandshake = () => type ValidateHandshake = ( metadata: MessageShape, previousParsedMetadata?: ParsedMetadata, + from?: TransportClientId, ) => | ParsedMetadata | ProtobufHandshakeFailureCode @@ -61,7 +65,7 @@ export function createServerHandshakeOptions< ): ServerHandshakeOptions { return createTransportServerHandshakeOptions( HandshakeBytesSchema, - async (metadata, previousParsedMetadata) => { + async (metadata, previousParsedMetadata, from) => { let decoded; try { decoded = decodeMessageBytes(schema, metadata); @@ -69,7 +73,7 @@ export function createServerHandshakeOptions< return 'REJECTED_BY_CUSTOM_HANDLER' as ProtobufHandshakeFailureCode; } - return await validate(decoded, previousParsedMetadata); + return await validate(decoded, previousParsedMetadata, from); }, ); } diff --git a/router/handshake.ts b/router/handshake.ts index 12de3071..821f2af4 100644 --- a/router/handshake.ts +++ b/router/handshake.ts @@ -1,5 +1,8 @@ import type { Static, TSchema } from 'typebox'; -import { HandshakeErrorCustomHandlerFatalResponseCodes } from '../transport/message'; +import { + HandshakeErrorCustomHandlerFatalResponseCodes, + type TransportClientId, +} from '../transport/message'; type ConstructHandshake = () => | Static @@ -8,6 +11,7 @@ type ConstructHandshake = () => type ValidateHandshake = ( metadata: Static, previousParsedMetadata?: ParsedMetadata, + from?: TransportClientId, ) => | Static | ParsedMetadata @@ -42,15 +46,16 @@ export interface ServerHandshakeOptions< schema: MetadataSchema; /** - * Parses the {@link HandshakeRequestMetadata} sent by the client, transforming - * it into {@link ParsedHandshakeMetadata}. - * - * May return `false` if the client should be rejected. + * Parses the metadata sent by the client during the handshake into the + * server-side {@link ParsedMetadata}, or returns a handshake failure code to + * reject the connection. * * @param metadata - The metadata sent by the client. - * @param session - The session that the client would be associated with. - * @param isReconnect - Whether the client is reconnecting to the session, - * or if this is a new session. + * @param previousParsedMetadata - The parsed metadata from the previous + * connection on this session, if any (e.g. on reconnect). + * @param from - The client id the peer presented in its handshake. Use it to + * confirm the presented id is the one the metadata authorizes before + * returning parsed metadata. */ validate: ValidateHandshake; } diff --git a/transport/server.ts b/transport/server.ts index c672ffcc..8bb015b5 100644 --- a/transport/server.ts +++ b/transport/server.ts @@ -297,6 +297,7 @@ export abstract class ServerTransport< parsedMetadataOrFailureCode = await this.handshakeExtensions.validate( msg.payload.metadata, previousParsedMetadata, + msg.from, ); } catch (err) { this.rejectHandshakeRequest( diff --git a/transport/sessionStateMachine/SessionConnected.ts b/transport/sessionStateMachine/SessionConnected.ts index 41da1d31..e4875532 100644 --- a/transport/sessionStateMachine/SessionConnected.ts +++ b/transport/sessionStateMachine/SessionConnected.ts @@ -192,6 +192,15 @@ export class SessionConnected< const parsedMsg = parsedMsgRes.value; + // messages must originate from this session's peer + if (parsedMsg.from !== this.to) { + this.listeners.onInvalidMessage( + `received message with 'from' (${parsedMsg.from}) that does not match the session peer (${this.to})`, + ); + + return; + } + // check message ordering here if (parsedMsg.seq !== this.ack) { if (parsedMsg.seq < this.ack) { diff --git a/transport/sessionStateMachine/stateMachine.test.ts b/transport/sessionStateMachine/stateMachine.test.ts index 34a05d71..e2e55206 100644 --- a/transport/sessionStateMachine/stateMachine.test.ts +++ b/transport/sessionStateMachine/stateMachine.test.ts @@ -1967,11 +1967,17 @@ describe('session state machine', () => { expect(onConnectionClosed).not.toHaveBeenCalled(); expect(onConnectionErrored).not.toHaveBeenCalled(); - const encodeResult = session.encodeMsg( - payloadToTransportMessage('hello'), + // an incoming frame carries the peer's id in `from` + session.conn.emitData( + session.options.codec.toBuffer({ + id: 'msgid', + from: session.to, + to: session.from, + seq: 0, + ack: 0, + ...payloadToTransportMessage('hello'), + }), ); - assert(encodeResult.ok); - session.conn.emitData(encodeResult.value.data); await waitFor(async () => { expect(onMessage).toHaveBeenCalledTimes(1); @@ -2021,8 +2027,8 @@ describe('session state machine', () => { conn.onData( session.options.codec.toBuffer({ id: 'msgid', - to: 'SERVER', - from: 'client', + to: session.from, + from: session.to, seq: 0, ack: 0, streamId: 'heartbeat', @@ -2048,8 +2054,8 @@ describe('session state machine', () => { conn.onData( session.options.codec.toBuffer({ id: 'msgid', - to: 'SERVER', - from: 'client', + to: session.from, + from: session.to, seq: 0, ack: 0, streamId: 'heartbeat', @@ -2062,5 +2068,28 @@ describe('session state machine', () => { expect(sessionHandle.onMessage).not.toHaveBeenCalled(); }); + + test('rejects a message whose from does not match the session peer', async () => { + const sessionHandle = await createSessionConnected(); + const session = sessionHandle.session; + const conn = session.conn; + + // a frame whose `from` isn't this session's peer is rejected + conn.onData( + session.options.codec.toBuffer({ + id: 'msgid', + to: session.from, + from: 'someone-else', + seq: 0, + ack: 0, + streamId: 'stream', + controlFlags: 0, + payload: { type: 'ACK' }, + }), + ); + + expect(sessionHandle.onInvalidMessage).toHaveBeenCalledTimes(1); + expect(sessionHandle.onMessage).not.toHaveBeenCalled(); + }); }); }); diff --git a/transport/transport.test.ts b/transport/transport.test.ts index c0ab69f5..2af14770 100644 --- a/transport/transport.test.ts +++ b/transport/transport.test.ts @@ -1762,6 +1762,7 @@ describe.each(testMatrix())( discarded: 'discarded', }, undefined, + clientTransport.clientId, ); const session = serverTransport.sessions.get(clientTransport.clientId); @@ -1791,6 +1792,7 @@ describe.each(testMatrix())( { kept: 'kept', }, + clientTransport.clientId, ); await testFinishesCleanly({