From 301dfe1c039e5d0c0814b93fce2968ec40906ec4 Mon Sep 17 00:00:00 2001 From: Dawei Date: Thu, 12 Jun 2025 17:09:42 -0700 Subject: [PATCH] Refactor ParsedMetadata into generic --- __tests__/cancellation.test.ts | 1 + __tests__/cleanup.test.ts | 4 +- __tests__/e2e.test.ts | 12 ++- __tests__/typescript-stress.test.ts | 29 ++++--- router/client.ts | 4 + router/context.ts | 124 ++++++++++++---------------- router/handshake.ts | 13 +-- router/index.ts | 2 +- router/procedures.ts | 103 +++++++++++++++++++---- router/server.ts | 66 ++++++++++----- router/services.ts | 84 +++++++++++++------ testUtil/fixtures/cleanup.ts | 9 +- testUtil/fixtures/mockTransport.ts | 30 +++++-- testUtil/fixtures/transports.ts | 38 ++++++--- transport/impls/ws/server.ts | 6 +- transport/server.ts | 15 ++-- transport/transport.test.ts | 44 ++++++---- 17 files changed, 385 insertions(+), 199 deletions(-) diff --git a/__tests__/cancellation.test.ts b/__tests__/cancellation.test.ts index 856dc8bd..97b4f796 100644 --- a/__tests__/cancellation.test.ts +++ b/__tests__/cancellation.test.ts @@ -29,6 +29,7 @@ function makeMockHandler( ) { return vi.fn< Procedure< + object, object, object, T, diff --git a/__tests__/cleanup.test.ts b/__tests__/cleanup.test.ts index 190244b2..497a6854 100644 --- a/__tests__/cleanup.test.ts +++ b/__tests__/cleanup.test.ts @@ -504,7 +504,7 @@ describe('request finishing triggers signal onabort', async () => { const clientTransport = getClientTransport('client'); const serverTransport = getServerTransport(); const handler = - vi.fn<(ctx: ProcedureHandlerContext) => void>(); + vi.fn<(ctx: ProcedureHandlerContext) => void>(); const serverId = serverTransport.clientId; const serviceName = 'service'; const procedureName = procedureType; @@ -523,7 +523,7 @@ describe('request finishing triggers signal onabort', async () => { async handler({ ctx, }: { - ctx: ProcedureHandlerContext; + ctx: ProcedureHandlerContext; }) { handler(ctx); diff --git a/__tests__/e2e.test.ts b/__tests__/e2e.test.ts index bd4e60d7..493ac4ab 100644 --- a/__tests__/e2e.test.ts +++ b/__tests__/e2e.test.ts @@ -931,6 +931,12 @@ describe.each(testMatrix())( const requestSchema = Type.Object({ data: Type.String(), }); + + interface ParsedMetadata { + data: string; + extra: number; + } + const clientTransport = getClientTransport( 'client', createClientHandshakeOptions(requestSchema, () => ({ data: 'foobar' })), @@ -949,7 +955,7 @@ describe.each(testMatrix())( }); const services = { - test: createServiceSchema().define({ + test: createServiceSchema().define({ getData: Procedure.rpc({ requestInit: Type.Object({}), responseData: Type.Object({ @@ -957,9 +963,7 @@ describe.each(testMatrix())( extra: Type.Number(), }), handler: async ({ ctx }) => { - // we haven't extended the interface - // eslint-disable-next-line @typescript-eslint/no-explicit-any - return Ok({ ...ctx.metadata } as { data: string; extra: number }); + return Ok({ ...ctx.metadata }); }, }), }), diff --git a/__tests__/typescript-stress.test.ts b/__tests__/typescript-stress.test.ts index 9752d5aa..c1a5cd7b 100644 --- a/__tests__/typescript-stress.test.ts +++ b/__tests__/typescript-stress.test.ts @@ -49,6 +49,7 @@ const fnBody = Procedure.rpc< { db: string; }, + object, typeof requestData, typeof responseData, typeof responseError @@ -429,18 +430,22 @@ describe('Handshake', () => { }, ); - createServer(mockTransportNetwork.getServerTransport(), services, { - handshakeOptions: createServerHandshakeOptions( - schema, - (metadata, _prev) => { - if (metadata.token !== '123') { - return false; - } - - return {}; - }, - ), - }); + createServer( + mockTransportNetwork.getServerTransport(), + services, + { + handshakeOptions: createServerHandshakeOptions( + schema, + (metadata, _prev) => { + if (metadata.token !== '123') { + return false; + } + + return {}; + }, + ), + }, + ); }); }); diff --git a/router/client.ts b/router/client.ts index 804518c8..ed861eb6 100644 --- a/router/client.ts +++ b/router/client.ts @@ -147,11 +147,15 @@ export type Client< // Context is a server-side implementation detail that doesn't affect the client interface // eslint-disable-next-line @typescript-eslint/no-explicit-any any, + // eslint-disable-next-line @typescript-eslint/no-explicit-any + any, Services > = InstantiatedServiceSchemaMap< // Context is a server-side implementation detail that doesn't affect the client interface // eslint-disable-next-line @typescript-eslint/no-explicit-any any, + // eslint-disable-next-line @typescript-eslint/no-explicit-any + any, Services >, > = { diff --git a/router/context.ts b/router/context.ts index e97324af..922c6574 100644 --- a/router/context.ts +++ b/router/context.ts @@ -5,78 +5,60 @@ import { ErrResult } from './result'; import { CancelErrorSchema } from './errors'; import { Static } from '@sinclair/typebox'; -/** - * The parsed metadata schema for a service. This is the - * return value of the {@link ServerHandshakeOptions.validate} - * if the handshake extension is used. - * - * You should use declaration merging to extend this interface - * with the sanitized metadata. - * - * ```ts - * declare module '@replit/river' { - * interface ParsedMetadata { - * userId: number; - * } - * } - * ``` - */ -/* eslint-disable-next-line @typescript-eslint/no-empty-interface */ -export interface ParsedMetadata extends Record {} - /** * This is passed to every procedure handler and contains various context-level * information and utilities. */ -export type ProcedureHandlerContext = Context & { - /** - * State for this service as defined by the service definition. - */ - state: State; - /** - * The span for this procedure call. You can use this to add attributes, events, and - * links to the span. - */ - span: Span; - /** - * Metadata parsed on the server. See {@link ParsedMetadata} - */ - metadata: ParsedMetadata; - /** - * The ID of the session that sent this request. - */ - sessionId: SessionId; - /** - * The ID of the client that sent this request. There may be multiple sessions per client. - */ - from: TransportClientId; - /** - * This is used to cancel the procedure call from the handler and notify the client that the - * call was cancelled. - * - * Cancelling is not the same as closing procedure calls gracefully, please refer to - * the river documentation to understand the difference between the two concepts. - */ - cancel: (message?: string) => ErrResult>; - /** - * This signal is a standard [AbortSignal](https://developer.mozilla.org/en-US/docs/Web/API/AbortSignal) - * triggered when the procedure invocation is done. This signal tracks the invocation/request finishing - * for _any_ reason, for example: - * - client explicit cancellation - * - procedure handler explicit cancellation via {@link cancel} - * - client session disconnect - * - server cancellation due to client invalid payload - * - invocation finishes cleanly, this depends on the type of the procedure (i.e. rpc handler return, or in a stream after the client-side has closed the request writable and the server-side has closed the response writable) - * - * You can use this to pass it on to asynchronous operations (such as fetch). - * - * You may also want to explicitly register callbacks on the - * ['abort' event](https://developer.mozilla.org/en-US/docs/Web/API/AbortSignal/abort_event) - * as a way to cleanup after the request is finished. - * - * Note that (per standard AbortSignals) callbacks registered _after_ the procedure invocation - * is done are not triggered. In such cases, you can check the "aborted" property and cleanup - * immediately if needed. - */ - signal: AbortSignal; -}; +export type ProcedureHandlerContext = + Context & { + /** + * State for this service as defined by the service definition. + */ + state: State; + /** + * The span for this procedure call. You can use this to add attributes, events, and + * links to the span. + */ + span: Span; + /** + * Metadata parsed on the server. See {@link createServerHandshakeOptions} + */ + metadata: ParsedMetadata; + /** + * The ID of the session that sent this request. + */ + sessionId: SessionId; + /** + * The ID of the client that sent this request. There may be multiple sessions per client. + */ + from: TransportClientId; + /** + * This is used to cancel the procedure call from the handler and notify the client that the + * call was cancelled. + * + * Cancelling is not the same as closing procedure calls gracefully, please refer to + * the river documentation to understand the difference between the two concepts. + */ + cancel: (message?: string) => ErrResult>; + /** + * This signal is a standard [AbortSignal](https://developer.mozilla.org/en-US/docs/Web/API/AbortSignal) + * triggered when the procedure invocation is done. This signal tracks the invocation/request finishing + * for _any_ reason, for example: + * - client explicit cancellation + * - procedure handler explicit cancellation via {@link cancel} + * - client session disconnect + * - server cancellation due to client invalid payload + * - invocation finishes cleanly, this depends on the type of the procedure (i.e. rpc handler return, or in a stream after the client-side has closed the request writable and the server-side has closed the response writable) + * + * You can use this to pass it on to asynchronous operations (such as fetch). + * + * You may also want to explicitly register callbacks on the + * ['abort' event](https://developer.mozilla.org/en-US/docs/Web/API/AbortSignal/abort_event) + * as a way to cleanup after the request is finished. + * + * Note that (per standard AbortSignals) callbacks registered _after_ the procedure invocation + * is done are not triggered. In such cases, you can check the "aborted" property and cleanup + * immediately if needed. + */ + signal: AbortSignal; + }; diff --git a/router/handshake.ts b/router/handshake.ts index 636b6936..3e87458d 100644 --- a/router/handshake.ts +++ b/router/handshake.ts @@ -1,12 +1,11 @@ import { Static, TSchema } from '@sinclair/typebox'; -import { ParsedMetadata } from './context'; import { HandshakeErrorCustomHandlerFatalResponseCodes } from '../transport/message'; type ConstructHandshake = () => | Static | Promise>; -type ValidateHandshake = ( +type ValidateHandshake = ( metadata: Static, previousParsedMetadata?: ParsedMetadata, ) => @@ -34,6 +33,7 @@ export interface ClientHandshakeOptions< export interface ServerHandshakeOptions< MetadataSchema extends TSchema = TSchema, + ParsedMetadata extends object = object, > { /** * Schema for the metadata that the server receives from the client @@ -52,7 +52,7 @@ export interface ServerHandshakeOptions< * @param isReconnect - Whether the client is reconnecting to the session, * or if this is a new session. */ - validate: ValidateHandshake; + validate: ValidateHandshake; } export function createClientHandshakeOptions< @@ -66,9 +66,10 @@ export function createClientHandshakeOptions< export function createServerHandshakeOptions< MetadataSchema extends TSchema = TSchema, + ParsedMetadata extends object = object, >( schema: MetadataSchema, - validate: ValidateHandshake, -): ServerHandshakeOptions { - return { schema, validate: validate as ValidateHandshake }; + validate: ValidateHandshake, +): ServerHandshakeOptions { + return { schema, validate }; } diff --git a/router/index.ts b/router/index.ts index ece437cb..8a84a91b 100644 --- a/router/index.ts +++ b/router/index.ts @@ -49,7 +49,7 @@ export type { MiddlewareParam, MiddlewareContext, } from './server'; -export type { ParsedMetadata, ProcedureHandlerContext } from './context'; +export type { ProcedureHandlerContext } from './context'; export { Ok, Err } from './result'; export type { Result, diff --git a/router/procedures.ts b/router/procedures.ts index 2cc43bdd..33289da6 100644 --- a/router/procedures.ts +++ b/router/procedures.ts @@ -52,6 +52,7 @@ export type Cancellable = T | Static; export interface RpcProcedure< Context, State, + ParsedMetadata, RequestInit extends PayloadType, ResponseData extends PayloadType, ResponseErr extends ProcedureErrorSchemaType, @@ -62,7 +63,7 @@ export interface RpcProcedure< responseError: ResponseErr; description?: string; handler(param: { - ctx: ProcedureHandlerContext; + ctx: ProcedureHandlerContext; reqInit: Static; }): Promise, Cancellable>>>; } @@ -80,6 +81,7 @@ export interface RpcProcedure< export interface UploadProcedure< Context, State, + ParsedMetadata, RequestInit extends PayloadType, RequestData extends PayloadType, ResponseData extends PayloadType, @@ -92,7 +94,7 @@ export interface UploadProcedure< responseError: ResponseErr; description?: string; handler(param: { - ctx: ProcedureHandlerContext; + ctx: ProcedureHandlerContext; reqInit: Static; reqReadable: Readable< Static, @@ -112,6 +114,7 @@ export interface UploadProcedure< export interface SubscriptionProcedure< Context, State, + ParsedMetadata, RequestInit extends PayloadType, ResponseData extends PayloadType, ResponseErr extends ProcedureErrorSchemaType, @@ -122,7 +125,7 @@ export interface SubscriptionProcedure< responseError: ResponseErr; description?: string; handler(param: { - ctx: ProcedureHandlerContext; + ctx: ProcedureHandlerContext; reqInit: Static; resWritable: Writable< Result, Cancellable>> @@ -143,6 +146,7 @@ export interface SubscriptionProcedure< export interface StreamProcedure< Context, State, + ParsedMetadata, RequestInit extends PayloadType, RequestData extends PayloadType, ResponseData extends PayloadType, @@ -155,7 +159,7 @@ export interface StreamProcedure< responseError: ResponseErr; description?: string; handler(param: { - ctx: ProcedureHandlerContext; + ctx: ProcedureHandlerContext; reqInit: Static; reqReadable: Readable< Static, @@ -185,6 +189,7 @@ export interface StreamProcedure< export type Procedure< Context, State, + ParsedMetadata, Ty extends ValidProcType, RequestInit extends PayloadType, RequestData extends PayloadType | null, @@ -195,6 +200,7 @@ export type Procedure< ? UploadProcedure< Context, State, + ParsedMetadata, RequestInit, RequestData, ResponseData, @@ -204,6 +210,7 @@ export type Procedure< ? StreamProcedure< Context, State, + ParsedMetadata, RequestInit, RequestData, ResponseData, @@ -211,11 +218,19 @@ export type Procedure< > : never : Ty extends 'rpc' - ? RpcProcedure + ? RpcProcedure< + Context, + State, + ParsedMetadata, + RequestInit, + ResponseData, + ResponseErr + > : Ty extends 'subscription' ? SubscriptionProcedure< Context, State, + ParsedMetadata, RequestInit, ResponseData, ResponseErr @@ -228,9 +243,14 @@ export type Procedure< * @template State - The context state object. You can provide this to constrain * the type of procedures. */ -export type AnyProcedure = Procedure< +export type AnyProcedure< + Context = object, + State = object, + ParsedMetadata = object, +> = Procedure< Context, State, + ParsedMetadata, ValidProcType, PayloadType, PayloadType | null, @@ -244,10 +264,11 @@ export type AnyProcedure = Procedure< * @template State - The context state object. You can provide this to constrain * the type of procedures. */ -export type ProcedureMap = Record< - string, - AnyProcedure ->; +export type ProcedureMap< + Context = object, + State = object, + ParsedMetadata = object, +> = Record>; // typescript is funky so with these upcoming procedure constructors, the overloads // which handle the `init` case _must_ come first, otherwise the `init` property @@ -260,6 +281,7 @@ export type ProcedureMap = Record< function rpc< Context, State, + ParsedMetadata, RequestInit extends PayloadType, ResponseData extends PayloadType, >(def: { @@ -270,16 +292,27 @@ function rpc< handler: RpcProcedure< Context, State, + ParsedMetadata, RequestInit, ResponseData, TNever >['handler']; -}): Branded>; +}): Branded< + RpcProcedure< + Context, + State, + ParsedMetadata, + RequestInit, + ResponseData, + TNever + > +>; // signature: explicit errors function rpc< Context, State, + ParsedMetadata, RequestInit extends PayloadType, ResponseData extends PayloadType, ResponseErr extends ProcedureErrorSchemaType, @@ -291,12 +324,20 @@ function rpc< handler: RpcProcedure< Context, State, + ParsedMetadata, RequestInit, ResponseData, ResponseErr >['handler']; }): Branded< - RpcProcedure + RpcProcedure< + Context, + State, + ParsedMetadata, + RequestInit, + ResponseData, + ResponseErr + > >; // implementation @@ -312,6 +353,7 @@ function rpc({ responseError?: ProcedureErrorSchemaType; description?: string; handler: RpcProcedure< + object, object, object, PayloadType, @@ -336,6 +378,7 @@ function rpc({ function upload< Context, State, + ParsedMetadata, RequestInit extends PayloadType, RequestData extends PayloadType, ResponseData extends PayloadType, @@ -348,6 +391,7 @@ function upload< handler: UploadProcedure< Context, State, + ParsedMetadata, RequestInit, RequestData, ResponseData, @@ -357,6 +401,7 @@ function upload< UploadProcedure< Context, State, + ParsedMetadata, RequestInit, RequestData, ResponseData, @@ -368,6 +413,7 @@ function upload< function upload< Context, State, + ParsedMetadata, RequestInit extends PayloadType, RequestData extends PayloadType, ResponseData extends PayloadType, @@ -381,6 +427,7 @@ function upload< handler: UploadProcedure< Context, State, + ParsedMetadata, RequestInit, RequestData, ResponseData, @@ -390,6 +437,7 @@ function upload< UploadProcedure< Context, State, + ParsedMetadata, RequestInit, RequestData, ResponseData, @@ -412,6 +460,7 @@ function upload({ responseError?: ProcedureErrorSchemaType; description?: string; handler: UploadProcedure< + object, object, object, PayloadType, @@ -438,6 +487,7 @@ function upload({ function subscription< Context, State, + ParsedMetadata, RequestInit extends PayloadType, ResponseData extends PayloadType, >(def: { @@ -448,18 +498,27 @@ function subscription< handler: SubscriptionProcedure< Context, State, + ParsedMetadata, RequestInit, ResponseData, TNever >['handler']; }): Branded< - SubscriptionProcedure + SubscriptionProcedure< + Context, + State, + ParsedMetadata, + RequestInit, + ResponseData, + TNever + > >; // signature: explicit errors function subscription< Context, State, + ParsedMetadata, RequestInit extends PayloadType, ResponseData extends PayloadType, ResponseErr extends ProcedureErrorSchemaType, @@ -471,12 +530,20 @@ function subscription< handler: SubscriptionProcedure< Context, State, + ParsedMetadata, RequestInit, ResponseData, ResponseErr >['handler']; }): Branded< - SubscriptionProcedure + SubscriptionProcedure< + Context, + State, + ParsedMetadata, + RequestInit, + ResponseData, + ResponseErr + > >; // implementation @@ -492,6 +559,7 @@ function subscription({ responseError?: ProcedureErrorSchemaType; description?: string; handler: SubscriptionProcedure< + object, object, object, PayloadType, @@ -516,6 +584,7 @@ function subscription({ function stream< Context, State, + ParsedMetadata, RequestInit extends PayloadType, RequestData extends PayloadType, ResponseData extends PayloadType, @@ -528,6 +597,7 @@ function stream< handler: StreamProcedure< Context, State, + ParsedMetadata, RequestInit, RequestData, ResponseData, @@ -537,6 +607,7 @@ function stream< StreamProcedure< Context, State, + ParsedMetadata, RequestInit, RequestData, ResponseData, @@ -548,6 +619,7 @@ function stream< function stream< Context, State, + ParsedMetadata, RequestInit extends PayloadType, RequestData extends PayloadType, ResponseData extends PayloadType, @@ -561,6 +633,7 @@ function stream< handler: StreamProcedure< Context, State, + ParsedMetadata, RequestInit, RequestData, ResponseData, @@ -570,6 +643,7 @@ function stream< StreamProcedure< Context, State, + ParsedMetadata, RequestInit, RequestData, ResponseData, @@ -592,6 +666,7 @@ function stream({ responseError?: ProcedureErrorSchemaType; description?: string; handler: StreamProcedure< + object, object, object, PayloadType, diff --git a/router/server.ts b/router/server.ts index 786b34fc..87341e44 100644 --- a/router/server.ts +++ b/router/server.ts @@ -1,4 +1,4 @@ -import { Static } from '@sinclair/typebox'; +import { Static, TSchema } from '@sinclair/typebox'; import { PayloadType, AnyProcedure } from './procedures'; import { ReaderErrorSchema, @@ -28,7 +28,7 @@ import { ProtocolVersion, TransportClientId, } from '../transport/message'; -import { ProcedureHandlerContext, ParsedMetadata } from './context'; +import { ProcedureHandlerContext } from './context'; import { Logger } from '../logging/log'; import { Value } from '@sinclair/typebox/value'; import { Err, Result, Ok, ErrResult } from './result'; @@ -55,21 +55,22 @@ type StreamId = string; */ export interface Server< Context extends object, + ParsedMetadata extends object, Services extends AnyServiceSchemaMap, > { /** * Services defined for this server. */ - services: InstantiatedServiceSchemaMap; + services: InstantiatedServiceSchemaMap; /** * A set of stream ids that are currently open. */ - streams: Map; + streams: Map>; close: () => Promise; } -interface StreamInitProps { +interface StreamInitProps { // msg derived streamId: StreamId; procedureName: string; @@ -92,7 +93,7 @@ interface StreamInitProps { passInitAsDataForBackwardsCompat: boolean; } -interface ProcStream { +interface ProcStream { streamId: StreamId; from: TransportClientId; procedureName: string; @@ -105,10 +106,17 @@ interface ProcStream { class RiverServer< Context extends object, + MetadataSchema extends TSchema, + ParsedMetadata extends object, Services extends AnyServiceSchemaMap, -> implements Server +> implements Server { - private transport: ServerTransport; + private transport: ServerTransport< + Connection, + MetadataSchema, + ParsedMetadata + >; + private contextMap: Map; private log?: Logger; private middlewares: Array; @@ -123,15 +131,19 @@ class RiverServer< private serverCancelledStreams: Map>; private maxCancelledStreamTombstonesPerSession: number; - public streams: Map; - public services: InstantiatedServiceSchemaMap; + public streams: Map>; + public services: InstantiatedServiceSchemaMap< + Context, + ParsedMetadata, + Services + >; private unregisterTransportListeners: () => void; constructor( - transport: ServerTransport, + transport: ServerTransport, services: Services, - handshakeOptions?: ServerHandshakeOptions, + handshakeOptions?: ServerHandshakeOptions, extendedContext?: Context, maxCancelledStreamTombstonesPerSession = 200, middlewares: Array = [], @@ -141,6 +153,7 @@ class RiverServer< this.services = instances as InstantiatedServiceSchemaMap< Context, + ParsedMetadata, Services >; this.contextMap = new Map(); @@ -252,7 +265,10 @@ class RiverServer< this.transport.addEventListener('transportStatus', handleTransportStatus); } - private createNewProcStream(span: Span, props: StreamInitProps) { + private createNewProcStream( + span: Span, + props: StreamInitProps, + ) { const { streamId, initialSession, @@ -399,7 +415,7 @@ class RiverServer< }; const finishedController = new AbortController(); - const procStream: ProcStream = { + const procStream: ProcStream = { from: from, streamId, procedureName, @@ -567,7 +583,11 @@ class RiverServer< closeReadable(); } - const handlerContextWithSpan: ProcedureHandlerContext = { + const handlerContextWithSpan: ProcedureHandlerContext< + object, + object, + ParsedMetadata + > = { ...serviceContext, from: from, sessionId, @@ -706,7 +726,7 @@ class RiverServer< private validateNewProcStream( initMessage: OpaqueTransportMessage, - ): StreamInitProps | null { + ): StreamInitProps | null { // lifetime safety: this is a sync function so this session cant transition // to another state before we finish const session = this.transport.sessions.get(initMessage.from); @@ -914,7 +934,7 @@ class RiverServer< serviceName: initMessage.serviceName, tracingCtx: initMessage.tracing, initPayload: initMessage.payload, - sessionMetadata, + sessionMetadata: sessionMetadata, procedure, serviceContext, procClosesWithInit: isStreamCloseBackwardsCompat( @@ -1018,7 +1038,9 @@ function getStreamCloseBackwardsCompat(protocolVersion: ProtocolVersion) { } export interface MiddlewareContext - extends Readonly, 'cancel'>> { + extends Readonly< + Omit, 'cancel'> + > { readonly streamId: StreamId; readonly procedureName: string; readonly serviceName: string; @@ -1047,12 +1069,14 @@ export type Middleware = (param: MiddlewareParam) => void; */ export function createServer< Context extends object, + MetadataSchema extends TSchema, + ParsedMetadata extends object, Services extends AnyServiceSchemaMap, >( - transport: ServerTransport, + transport: ServerTransport, services: Services, providedServerOptions?: Partial<{ - handshakeOptions?: ServerHandshakeOptions; + handshakeOptions?: ServerHandshakeOptions; extendedContext?: Context; /** * Maximum number of cancelled streams to keep track of to avoid @@ -1064,7 +1088,7 @@ export function createServer< */ middlewares?: Array; }>, -): Server { +): Server { return new RiverServer( transport, services, diff --git a/router/services.ts b/router/services.ts index 6c898ca9..e70499df 100644 --- a/router/services.ts +++ b/router/services.ts @@ -20,7 +20,8 @@ import { export interface Service< Context, State extends object, - Procs extends ProcedureMap, + ParsedMetadata, + Procs extends ProcedureMap, > { readonly state: State; readonly procedures: Procs; @@ -30,22 +31,25 @@ export interface Service< /** * Represents any {@link Service} object. */ -export type AnyService = Service; +export type AnyService = Service; /** * Represents any {@link ServiceSchema} object. */ -export type AnyServiceSchema = InstanceType< - ReturnType> +export type AnyServiceSchema< + Context extends object = object, + ParsedMetadata extends object = object, +> = InstanceType< + ReturnType> >; /** * A dictionary of {@link ServiceSchema}s, where the key is the service name. */ -export type AnyServiceSchemaMap = Record< - string, - AnyServiceSchema ->; +export type AnyServiceSchemaMap< + Context extends object = object, + ParsedMetadata extends object = object, +> = Record>; // This has the secret sauce to keep go to definition working, the structure is // somewhat delicate, so be careful when modifying it. Would be nice to add a @@ -56,9 +60,10 @@ export type AnyServiceSchemaMap = Record< */ export type InstantiatedServiceSchemaMap< Context extends object, - T extends AnyServiceSchemaMap, + ParsedMetadata extends object, + T extends AnyServiceSchemaMap, > = { - [K in keyof T]: T[K] extends AnyServiceSchema + [K in keyof T]: T[K] extends AnyServiceSchema ? T[K] extends { initializeState: (ctx: Context) => infer S; procedures: infer P; @@ -66,7 +71,12 @@ export type InstantiatedServiceSchemaMap< ? Service< Context, S extends object ? S : object, - P extends ProcedureMap + ParsedMetadata, + P extends ProcedureMap< + Context, + S extends object ? S : object, + ParsedMetadata + > ? P : ProcedureMap > @@ -142,9 +152,9 @@ export type ProcType< * A list of procedures where every procedure is "branded", as-in the procedure * was created via the {@link Procedure} constructors. */ -type BrandedProcedureMap = Record< +type BrandedProcedureMap = Record< string, - Branded> + Branded> >; type MaybeDisposable = State & { @@ -308,10 +318,13 @@ export function serializeSchema( * * When defining procedures, always use the {@link Procedure} constructors to create them. */ -export function createServiceSchema() { +export function createServiceSchema< + Context extends object = object, + ParsedMetadata extends object = object, +>() { return class ServiceSchema< State extends object, - Procedures extends ProcedureMap, + Procedures extends ProcedureMap, > { /** * Factory function for creating a fresh state. @@ -427,7 +440,7 @@ export function createServiceSchema() { */ static define< State extends object, - Procedures extends BrandedProcedureMap, + Procedures extends BrandedProcedureMap, >( config: ServiceConfiguration, procedures: Procedures, @@ -458,7 +471,9 @@ export function createServiceSchema() { * }); */ - static define>( + static define< + Procedures extends BrandedProcedureMap, + >( procedures: Procedures, ): ServiceSchema< object, @@ -468,11 +483,11 @@ export function createServiceSchema() { static define( configOrProcedures: | ServiceConfiguration - | BrandedProcedureMap, - maybeProcedures?: BrandedProcedureMap, + | BrandedProcedureMap, + maybeProcedures?: BrandedProcedureMap, ): ServiceSchema { let config: ServiceConfiguration; - let procedures: BrandedProcedureMap; + let procedures: BrandedProcedureMap; if ( 'initializeState' in configOrProcedures && @@ -486,7 +501,11 @@ export function createServiceSchema() { procedures = maybeProcedures; } else { config = { initializeState: () => ({}) }; - procedures = configOrProcedures as BrandedProcedureMap; + procedures = configOrProcedures as BrandedProcedureMap< + Context, + object, + ParsedMetadata + >; } return new ServiceSchema(config, procedures); @@ -581,7 +600,9 @@ export function createServiceSchema() { * You probably don't need this, usually the River server will handle this * for you. */ - instantiate(extendedContext: Context): Service { + instantiate( + extendedContext: Context, + ): Service { const state = this.initializeState(extendedContext); const dispose = async () => { await state[Symbol.asyncDispose]?.(); @@ -620,7 +641,11 @@ export function getSerializedProcErrors( * @see {@link ServiceSchema.scaffold} */ // note that this isn't exported -class ServiceScaffold { +class ServiceScaffold< + Context extends object, + State extends object, + ParsedMetadata extends object, +> { /** * The configuration for this service. */ @@ -653,7 +678,9 @@ class ServiceScaffold { * * @param procedures - The procedures for this service. */ - procedures>(procedures: T): T { + procedures>( + procedures: T, + ): T { return procedures; } @@ -675,7 +702,12 @@ class ServiceScaffold { * }); * ``` */ - finalize>(procedures: T) { - return createServiceSchema().define(this.config, procedures); + finalize>( + procedures: T, + ) { + return createServiceSchema().define( + this.config, + procedures, + ); } } diff --git a/testUtil/fixtures/cleanup.ts b/testUtil/fixtures/cleanup.ts index af08e79a..a54e2818 100644 --- a/testUtil/fixtures/cleanup.ts +++ b/testUtil/fixtures/cleanup.ts @@ -85,7 +85,7 @@ export async function ensureTransportBuffersAreEventuallyEmpty( } export async function ensureServerIsClean( - s: Server, + s: Server, ) { return waitFor(() => expect( @@ -112,8 +112,11 @@ export async function testFinishesCleanly({ server, }: Partial<{ clientTransports: Array>; - serverTransport: ServerTransport; - server: Server; + // MetadataSchema and ParsedMetadata are not used in this test, + // so we can safely use any here + // eslint-disable-next-line @typescript-eslint/no-explicit-any + serverTransport: ServerTransport; + server: Server; }>) { // pre-close invariants // invariant check servers first as heartbeats are authoritative on their side diff --git a/testUtil/fixtures/mockTransport.ts b/testUtil/fixtures/mockTransport.ts index 613d01ba..715f474a 100644 --- a/testUtil/fixtures/mockTransport.ts +++ b/testUtil/fixtures/mockTransport.ts @@ -1,4 +1,4 @@ -import { TransportClientId } from '../../transport'; +import { Transport, TransportClientId } from '../../transport'; import { ClientTransport } from '../../transport/client'; import { Connection } from '../../transport/connection'; import { ServerTransport } from '../../transport/server'; @@ -8,6 +8,8 @@ import { TestSetupHelpers, TestTransportOptions } from './transports'; import { Duplex } from 'node:stream'; import { duplexPair } from '../duplex/duplexPair'; import { nanoid } from 'nanoid'; +import { TSchema } from '@sinclair/typebox'; +import { ServerHandshakeOptions } from '../../router/handshake'; export class InMemoryConnection extends Connection { conn: Duplex; @@ -69,7 +71,7 @@ export function createMockTransportNetwork( // conn id -> [client->server, server->client] const connections = new Observable>({}); - const transports: Array = []; + const transports: Array> = []; class MockClientTransport extends ClientTransport { async createNewOutgoingConnection( to: TransportClientId, @@ -94,7 +96,14 @@ export function createMockTransportNetwork( } } - class MockServerTransport extends ServerTransport { + class MockServerTransport< + MetadataSchema extends TSchema = TSchema, + ParsedMetadata extends object = object, + > extends ServerTransport< + InMemoryConnection, + MetadataSchema, + ParsedMetadata + > { subscribeCleanup: () => void; constructor( @@ -136,8 +145,19 @@ export function createMockTransportNetwork( return clientTransport; }, - getServerTransport: (id = 'SERVER', handshakeOptions) => { - const serverTransport = new MockServerTransport(id, opts?.server); + getServerTransport: < + MetadataSchema extends TSchema = TSchema, + ParsedMetadata extends object = object, + >( + id = 'SERVER', + handshakeOptions: + | ServerHandshakeOptions + | undefined, + ) => { + const serverTransport = new MockServerTransport< + MetadataSchema, + ParsedMetadata + >(id, opts?.server); if (handshakeOptions) { serverTransport.extendHandshake(handshakeOptions); } diff --git a/testUtil/fixtures/transports.ts b/testUtil/fixtures/transports.ts index 3578bcaa..500e5e1a 100644 --- a/testUtil/fixtures/transports.ts +++ b/testUtil/fixtures/transports.ts @@ -20,6 +20,7 @@ import { TransportClientId } from '../../transport/message'; import { ClientTransport } from '../../transport/client'; import { Connection } from '../../transport/connection'; import { ServerTransport } from '../../transport/server'; +import { TSchema } from '@sinclair/typebox'; export type ValidTransports = 'ws' | 'mock'; @@ -33,10 +34,13 @@ export interface TestSetupHelpers { id: TransportClientId, handshakeOptions?: ClientHandshakeOptions, ) => ClientTransport; - getServerTransport: ( + getServerTransport: < + MetadataSchema extends TSchema = TSchema, + ParsedMetadata extends object = object, + >( id?: TransportClientId, - handshakeOptions?: ServerHandshakeOptions, - ) => ServerTransport; + handshakeOptions?: ServerHandshakeOptions, + ) => ServerTransport; simulatePhantomDisconnect: () => void; restartServer: () => Promise; cleanup: () => Promise | void; @@ -56,7 +60,8 @@ export const transports: Array = [ let wss = createWebSocketServer(server); const transports: Array< - WebSocketClientTransport | WebSocketServerTransport + // eslint-disable-next-line @typescript-eslint/no-explicit-any + WebSocketClientTransport | WebSocketServerTransport > = []; return { @@ -91,12 +96,19 @@ export const transports: Array = [ return clientTransport; }, - getServerTransport(id = 'SERVER', handshakeOptions) { - const serverTransport = new WebSocketServerTransport( - wss, - id, - opts?.server, - ); + getServerTransport: < + MetadataSchema extends TSchema, + ParsedMetadata extends object, + >( + id = 'SERVER', + handshakeOptions: + | ServerHandshakeOptions + | undefined, + ) => { + const serverTransport = new WebSocketServerTransport< + MetadataSchema, + ParsedMetadata + >(wss, id, opts?.server); serverTransport.bindLogger((msg, ctx, level) => { if (ctx?.tags?.includes('invariant-violation')) { @@ -113,7 +125,11 @@ export const transports: Array = [ transports.push(serverTransport); - return serverTransport as ServerTransport; + return serverTransport as ServerTransport< + Connection, + MetadataSchema, + ParsedMetadata + >; }, async restartServer() { for (const transport of transports) { diff --git a/transport/impls/ws/server.ts b/transport/impls/ws/server.ts index bb01a639..d0b397bc 100644 --- a/transport/impls/ws/server.ts +++ b/transport/impls/ws/server.ts @@ -5,6 +5,7 @@ import { WsLike } from './wslike'; import { ServerTransport } from '../../server'; import { ProvidedServerTransportOptions } from '../../options'; import { type IncomingMessage } from 'http'; +import { TSchema } from '@sinclair/typebox'; function cleanHeaders( headers: IncomingMessage['headers'], @@ -21,7 +22,10 @@ function cleanHeaders( return cleanedHeaders; } -export class WebSocketServerTransport extends ServerTransport { +export class WebSocketServerTransport< + MetadataSchema extends TSchema = TSchema, + ParsedMetadata extends object = object, +> extends ServerTransport { wss: WebSocketServer; constructor( diff --git a/transport/server.ts b/transport/server.ts index 5a53b8a4..8a7f8554 100644 --- a/transport/server.ts +++ b/transport/server.ts @@ -1,5 +1,4 @@ import { SpanStatusCode } from '@opentelemetry/api'; -import { ParsedMetadata } from '../router/context'; import { ServerHandshakeOptions } from '../router/handshake'; import { ControlMessageHandshakeRequestSchema, @@ -19,7 +18,7 @@ import { } from './options'; import { DeleteSessionOptions, Transport } from './transport'; import { coerceErrorString } from './stringifyError'; -import { Static } from '@sinclair/typebox'; +import { Static, TSchema } from '@sinclair/typebox'; import { Value } from '@sinclair/typebox/value'; import { ProtocolError } from './events'; import { Connection } from './connection'; @@ -33,6 +32,8 @@ import { export abstract class ServerTransport< ConnType extends Connection, + MetadataSchema extends TSchema = TSchema, + ParsedMetadata extends object = object, > extends Transport { /** * The options for this transport. @@ -42,7 +43,7 @@ export abstract class ServerTransport< /** * Optional handshake options for the server. */ - handshakeExtensions?: ServerHandshakeOptions; + handshakeExtensions?: ServerHandshakeOptions; /** * A map of session handshake data for each session. @@ -68,7 +69,9 @@ export abstract class ServerTransport< }); } - extendHandshake(options: ServerHandshakeOptions) { + extendHandshake( + options: ServerHandshakeOptions, + ) { this.handshakeExtensions = options; } @@ -262,7 +265,7 @@ export abstract class ServerTransport< } // invariant: must pass custom validation if defined - let parsedMetadata: ParsedMetadata = {}; + let parsedMetadata: ParsedMetadata = {} as ParsedMetadata; if (this.handshakeExtensions) { if (!Value.Check(this.handshakeExtensions.schema, msg.payload.metadata)) { this.rejectHandshakeRequest( @@ -324,7 +327,7 @@ export abstract class ServerTransport< } // success! - parsedMetadata = parsedMetadataOrFailureCode; + parsedMetadata = parsedMetadataOrFailureCode as ParsedMetadata; } // 4 connect cases diff --git a/transport/transport.test.ts b/transport/transport.test.ts index 7850e8ef..9845da02 100644 --- a/transport/transport.test.ts +++ b/transport/transport.test.ts @@ -29,7 +29,6 @@ import { ProvidedClientTransportOptions, ProvidedTransportOptions, } from './options'; -import { ParsedMetadata } from '../router'; describe.each(testMatrix())( 'transport connection behaviour tests ($transport.name transport, $codec.name codec)', @@ -881,7 +880,7 @@ describe.each(testMatrix())( const get = vi.fn(); const parse = vi.fn(() => { - const promise = new Promise(() => { + const promise = new Promise(() => { // noop we never want this to return }); @@ -1492,10 +1491,13 @@ describe.each(testMatrix())( kept: Type.String(), discarded: Type.String(), }); + + interface Metadata { + kept: string; + } + const get = vi.fn(async () => ({ kept: 'kept', discarded: 'discarded' })); - const parse = vi.fn(async (metadata: unknown) => ({ - // @ts-expect-error - we haven't extended the global type here - // eslint-disable-next-line @typescript-eslint/no-unsafe-assignment + const parse = vi.fn(async (metadata: Metadata) => ({ kept: metadata.kept, })); @@ -1537,10 +1539,13 @@ describe.each(testMatrix())( const schema = Type.Object({ foo: Type.String(), }); + + interface Metadata { + foo: string; + } + const get = vi.fn(async () => ({ foo: false })); - const parse = vi.fn(async (metadata: unknown) => ({ - // @ts-expect-error - we haven't extended the global type here - // eslint-disable-next-line @typescript-eslint/no-unsafe-assignment + const parse = vi.fn(async (metadata: Metadata) => ({ foo: metadata.foo, })); @@ -1593,10 +1598,12 @@ describe.each(testMatrix())( foo: Type.Boolean(), }); + interface Metadata { + foo: boolean; + } + const get = vi.fn(async () => ({ foo: 123 })); - const parse = vi.fn(async (metadata: unknown) => ({ - // @ts-expect-error - we haven't extended the global type here - // eslint-disable-next-line @typescript-eslint/no-unsafe-assignment + const parse = vi.fn(async (metadata: Metadata) => ({ foo: metadata.foo, })); @@ -1661,11 +1668,16 @@ describe.each(testMatrix())( discarded: 'discarded', })); - const validate = vi.fn(async (metadata: unknown, _previous: unknown) => ({ - // @ts-expect-error - we haven't extended the global type here - // eslint-disable-next-line @typescript-eslint/no-unsafe-assignment - kept: metadata.kept, - })); + interface Metadata { + kept: string; + discarded: string; + } + + const validate = vi.fn( + async (metadata: Metadata, _previous: unknown) => ({ + kept: metadata.kept, + }), + ); const serverTransport = getServerTransport('SERVER', { schema,