From b6c32679bcaac7486121d3fb887a83020fd5161e Mon Sep 17 00:00:00 2001
From: Tobias O
Date: Fri, 19 Jun 2026 11:13:29 +0200
Subject: [PATCH 1/2] feat(prediction-worker): bound Triton fan-out concurrency + transient retry

Cap concurrent in-flight modelInfer streams per worker via a process-wide
shared semaphore (TRITON_MAX_INFLIGHT_INFERS), so WORKER_CONCURRENCY jobs
draw from one permit pool instead of bursting ~32 streams onto a single
HTTP/2 connection. Add bounded jittered retry in the Triton client on
transient transport errors (UNAVAILABLE, transport-signature INTERNAL)
held inside the caller's permit, plus conservative channel keepalive.

Co-Authored-By: Claude Opus 4.8 (1M context)
---
 .../changes/bound-prediction-fanout/tasks.md  |  26 +--
 packages/triton-client/src/client.test.ts     | 166 ++++++++++++++++++
 packages/triton-client/src/client.ts          | 121 ++++++++++---
 packages/triton-client/src/index.ts           |   3 +
 services/prediction-worker/src/config.test.ts |  25 +++
 services/prediction-worker/src/config.ts      |  25 +++
 .../prediction-worker/src/dispatch.test.ts    | 113 ++++++++++++
 services/prediction-worker/src/dispatch.ts    |  32 +++-
 services/prediction-worker/src/index.ts       |   8 +
 services/prediction-worker/src/processor.ts   |  20 ++-
 .../prediction-worker/src/semaphore.test.ts   |  60 +++++++
 services/prediction-worker/src/semaphore.ts   |  46 +++++
 12 files changed, 602 insertions(+), 43 deletions(-)
 create mode 100644 services/prediction-worker/src/semaphore.test.ts
 create mode 100644 services/prediction-worker/src/semaphore.ts

diff --git a/openspec/changes/bound-prediction-fanout/tasks.md b/openspec/changes/bound-prediction-fanout/tasks.md
index 04fa9c6..8c8f3d4 100644
--- a/openspec/changes/bound-prediction-fanout/tasks.md
+++ b/openspec/changes/bound-prediction-fanout/tasks.md
@@ -1,30 +1,30 @@
 ## 1. Config: concurrency + retry tunables
 
-- [ ] 1.1 Add `maxInflightInfers` to `services/prediction-worker/src/config.ts` under `triton` via `configField()` (env `TRITON_MAX_INFLIGHT_INFERS`, env-wins, conservative positive-int default).
-- [ ] 1.2 Add retry tunables (max attempts, base backoff ms) — either in `config.ts` (`triton`) or as documented `triton-client` defaults — read through typed config, not `process.env`.
-- [ ] 1.3 Update `config.test.ts` to cover the new fields' defaults and override parsing.
+- [x] 1.1 Add `maxInflightInfers` to `services/prediction-worker/src/config.ts` under `triton` via `configField()` (env `TRITON_MAX_INFLIGHT_INFERS`, env-wins, conservative positive-int default).
+- [x] 1.2 Add retry tunables (max attempts, base backoff ms) — either in `config.ts` (`triton`) or as documented `triton-client` defaults — read through typed config, not `process.env`.
+- [x] 1.3 Update `config.test.ts` to cover the new fields' defaults and override parsing.
 
 ## 2. Bound the fan-out with a shared semaphore
 
-- [ ] 2.1 Add a minimal async semaphore utility (acquire returns a release handle; FIFO) — colocate in prediction-worker or `@protifer/shared` if reusable.
-- [ ] 2.2 Construct one semaphore in `index.ts` sized by `config.triton.maxInflightInfers` and inject it into the processor/`dispatchAll` so it is process-wide (shared by all `WORKER_CONCURRENCY` jobs).
-- [ ] 2.3 In `dispatch.ts`, acquire a permit around each `triton.modelInfer` call and release it in a `finally` (released on success, throw, and timeout alike).
-- [ ] 2.4 Unit tests in `dispatch.test.ts`: (a) concurrent in-flight calls never exceed the limit across multiple simultaneous `dispatchAll` invocations; (b) a thrown `modelInfer` releases its permit (no leak); (c) excess calls wait rather than open immediately.
+- [x] 2.1 Add a minimal async semaphore utility (acquire returns a release handle; FIFO) — colocate in prediction-worker or `@protifer/shared` if reusable.
+- [x] 2.2 Construct one semaphore in `index.ts` sized by `config.triton.maxInflightInfers` and inject it into the processor/`dispatchAll` so it is process-wide (shared by all `WORKER_CONCURRENCY` jobs).
+- [x] 2.3 In `dispatch.ts`, acquire a permit around each `triton.modelInfer` call and release it in a `finally` (released on success, throw, and timeout alike).
+- [x] 2.4 Unit tests in `dispatch.test.ts`: (a) concurrent in-flight calls never exceed the limit across multiple simultaneous `dispatchAll` invocations; (b) a thrown `modelInfer` releases its permit (no leak); (c) excess calls wait rather than open immediately.
 
 ## 3. Transient transport retry in the client
 
-- [ ] 3.1 In `packages/triton-client/src/client.ts`, wrap `modelInfer` with a bounded jittered retry firing only on the transient transport classes (`UNAVAILABLE`; transport-signature `INTERNAL` — bandwidth/parse/connection), never on `INVALID_ARGUMENT`/`NOT_FOUND`/`DEADLINE_EXCEEDED`.
-- [ ] 3.2 Ensure the retry sits _inside_ the caller's held permit (retry loop in the client call, permit held by `dispatch.ts`) so retries do not widen concurrency.
-- [ ] 3.3 Unit tests in `client.test.ts`: retries on transient classes up to the cap; no retry on deterministic/deadline classes; success-after-retry returns the response; exhausted retries surface the original classified error.
+- [x] 3.1 In `packages/triton-client/src/client.ts`, wrap `modelInfer` with a bounded jittered retry firing only on the transient transport classes (`UNAVAILABLE`; transport-signature `INTERNAL` — bandwidth/parse/connection), never on `INVALID_ARGUMENT`/`NOT_FOUND`/`DEADLINE_EXCEEDED`.
+- [x] 3.2 Ensure the retry sits _inside_ the caller's held permit (retry loop in the client call, permit held by `dispatch.ts`) so retries do not widen concurrency.
+- [x] 3.3 Unit tests in `client.test.ts`: retries on transient classes up to the cap; no retry on deterministic/deadline classes; success-after-retry returns the response; exhausted retries surface the original classified error.
 
 ## 4. Channel keepalive
 
-- [ ] 4.1 Add conservative `grpc.keepalive_time_ms` / `grpc.keepalive_timeout_ms` / `grpc.keepalive_permit_without_calls` options to the channel in `client.ts`, documented to avoid tripping Triton's server-side enforcement.
-- [ ] 4.2 Confirm existing `client.test.ts` / `mock-server.test.ts` still pass with the new channel options.
+- [x] 4.1 Add conservative `grpc.keepalive_time_ms` / `grpc.keepalive_timeout_ms` / `grpc.keepalive_permit_without_calls` options to the channel in `client.ts`, documented to avoid tripping Triton's server-side enforcement.
+- [x] 4.2 Confirm existing `client.test.ts` / `mock-server.test.ts` still pass with the new channel options.
 
 ## 5. Verification
 
-- [ ] 5.1 Run repo gates: `bun run typecheck`, `bun run lint`, `bun run format`, `bun run test`.
+- [x] 5.1 Run repo gates: `bun run typecheck`, `bun run lint`, `bun run format`, `bun run test`.
 - [ ] 5.2 Run `bun run test:int` (stack up) to exercise the bounded fan-out against the mock/real Triton path.
 - [ ] 5.3 Load verification: on a real load run confirm no `Connection dropped` / `Bandwidth exhausted or memory limit exceeded` storm, the GPU is busy during prediction (not idle), prediction jobs complete, and BullMQ whole-job retries drop sharply.
 - [ ] 5.4 Tune `TRITON_MAX_INFLIGHT_INFERS` upward until Triton is well-utilized without reintroducing transport errors; record the chosen value and rationale (deploy runbook).
diff --git a/packages/triton-client/src/client.test.ts b/packages/triton-client/src/client.test.ts
index 17491f9..fa965bb 100644
--- a/packages/triton-client/src/client.test.ts
+++ b/packages/triton-client/src/client.test.ts
@@ -179,3 +179,169 @@ describe('modelInfer deadline enforcement', () => {
     expect(response.model_name).toBe('test')
   }, 10_000)
 })
+
+type Step = 'ok' | { code: number; details: string }
+
+/** Start a gRPC server that walks `plan` per modelInfer call (last step repeats). */
+async function startFlakyTritonServer(plan: Step[]): Promise<{
+  stop(): void
+  port: number
+  calls(): number
+}> {
+  const proto = getPackageDef()
+  const server = new grpc.Server()
+  let calls = 0
+
+  server.addService(proto.inference.GRPCInferenceService.service, {
+    serverReady: (
+      _: grpc.ServerUnaryCall,
+      cb: grpc.sendUnaryData,
+    ) => {
+      cb(null, { ready: true })
+    },
+    modelReady: (
+      _: grpc.ServerUnaryCall,
+      cb: grpc.sendUnaryData,
+    ) => {
+      cb(null, { ready: true })
+    },
+    modelInfer: (
+      _: grpc.ServerUnaryCall,
+      cb: grpc.sendUnaryData,
+    ) => {
+      const step = plan[Math.min(calls, plan.length - 1)] ?? 'ok'
+      calls++
+      if (step === 'ok') {
+        cb(null, { model_name: 'test', outputs: [], raw_output_contents: [] })
+      } else {
+        cb(
+          { code: step.code, details: step.details } as grpc.ServiceError,
+          null,
+        )
+      }
+    },
+  })
+
+  return new Promise((resolve, reject) => {
+    server.bindAsync(
+      '0.0.0.0:0',
+      grpc.ServerCredentials.createInsecure(),
+      (err: Error | null, boundPort: number) => {
+        if (err) {
+          reject(err)
+          return
+        }
+        resolve({
+          stop: () => {
+            server.forceShutdown()
+          },
+          port: boundPort,
+          calls: () => calls,
+        })
+      },
+    )
+  })
+}
+
+const FAST_RETRY = { maxAttempts: 3, baseBackoffMs: 1 }
+
+describe('modelInfer transient-transport retry', () => {
+  const cleanups: Array<() => void> = []
+
+  afterAll(() => {
+    for (const fn of cleanups) fn()
+  })
+
+  async function makeClient(plan: Step[]) {
+    const srv = await startFlakyTritonServer(plan)
+    const client = createTritonClient(`localhost:${srv.port.toString()}`)
+    cleanups.push(() => {
+      srv.stop()
+      client.close()
+    })
+    return { srv, client }
+  }
+
+  const REQ = { model_name: 'test', inputs: [], outputs: [] }
+
+  it('retries on UNAVAILABLE and succeeds after a transient drop', async () => {
+    const { srv, client } = await makeClient([
+      { code: grpc.status.UNAVAILABLE, details: 'Connection dropped' },
+      { code: grpc.status.UNAVAILABLE, details: 'Connection dropped' },
+      'ok',
+    ])
+    const resp = await client.modelInfer(REQ, { retry: FAST_RETRY })
+    expect(resp.model_name).toBe('test')
+    expect(srv.calls()).toBe(3)
+  }, 10_000)
+
+  it('retries on transport-class INTERNAL (bandwidth exhausted)', async () => {
+    const { srv, client } = await makeClient([
+      {
+        code: grpc.status.INTERNAL,
+        details: 'Bandwidth exhausted or memory limit exceeded',
+      },
+      'ok',
+    ])
+    const resp = await client.modelInfer(REQ, { retry: FAST_RETRY })
+    expect(resp.model_name).toBe('test')
+    expect(srv.calls()).toBe(2)
+  }, 10_000)
+
+  it('does not retry a genuine server INTERNAL', async () => {
+    const { srv, client } = await makeClient([
+      {
+        code: grpc.status.INTERNAL,
+        details: 'internal model assertion failed',
+      },
+      'ok',
+    ])
+    await expect(
+      client.modelInfer(REQ, { retry: FAST_RETRY }),
+    ).rejects.toMatchObject({ code: grpc.status.INTERNAL })
+    expect(srv.calls()).toBe(1)
+  }, 10_000)
+
+  it('does not retry INVALID_ARGUMENT', async () => {
+    const { srv, client } = await makeClient([
+      { code: grpc.status.INVALID_ARGUMENT, details: 'bad shape' },
+      'ok',
+    ])
+    await expect(
+      client.modelInfer(REQ, { retry: FAST_RETRY }),
+    ).rejects.toMatchObject({ code: grpc.status.INVALID_ARGUMENT })
+    expect(srv.calls()).toBe(1)
+  }, 10_000)
+
+  it('does not retry NOT_FOUND', async () => {
+    const { srv, client } = await makeClient([
+      { code: grpc.status.NOT_FOUND, details: 'no model' },
+      'ok',
+    ])
+    await expect(
+      client.modelInfer(REQ, { retry: FAST_RETRY }),
+    ).rejects.toMatchObject({ code: grpc.status.NOT_FOUND })
+    expect(srv.calls()).toBe(1)
+  }, 10_000)
+
+  it('does not retry DEADLINE_EXCEEDED (maps to TritonTimeoutError)', async () => {
+    const { srv, client } = await makeClient([
+      { code: grpc.status.DEADLINE_EXCEEDED, details: 'too slow' },
+      'ok',
+    ])
+    await expect(client.modelInfer(REQ, { retry: FAST_RETRY })).rejects.toThrow(
+      TritonTimeoutError,
+    )
+    expect(srv.calls()).toBe(1)
+  }, 10_000)
+
+  it('surfaces the classified error once retries are exhausted', async () => {
+    const { srv, client } = await makeClient([
+      { code: grpc.status.UNAVAILABLE, details: 'Connection dropped' },
+    ])
+    await expect(
+      client.modelInfer(REQ, { retry: { maxAttempts: 2, baseBackoffMs: 1 } }),
+    ).rejects.toMatchObject({ code: grpc.status.UNAVAILABLE })
+    expect(srv.calls()).toBe(2)
+  }, 10_000)
+})
diff --git a/packages/triton-client/src/client.ts b/packages/triton-client/src/client.ts
index 6d1f9a8..3d77f7f 100644
--- a/packages/triton-client/src/client.ts
+++ b/packages/triton-client/src/client.ts
@@ -17,6 +17,43 @@ export const TRITON_MAX_MESSAGE_BYTES = 64 * 1024 * 1024
 
 export const DEFAULT_DEADLINE_MS = 60_000
 
+export const DEFAULT_RETRY_MAX_ATTEMPTS = 3
+export const DEFAULT_RETRY_BASE_BACKOFF_MS = 100
+
+const TRANSIENT_INTERNAL_RE =
+  /bandwidth exhausted|memory limit exceeded|failed parsing|connection|rst_stream|stream reset/i
+
+function isTransientTransportError(err: unknown): boolean {
+  if (err instanceof TritonTimeoutError) return false
+  if (
+    err !== null &&
+    typeof err === 'object' &&
+    'code' in err &&
+    typeof (err as { code: unknown }).code === 'number'
+  ) {
+    const code = (err as { code: number }).code
+    if (code === (grpc.status.UNAVAILABLE as number)) return true
+    if (code === (grpc.status.INTERNAL as number)) {
+      const detail = (
+        (err as { details?: string }).details ??
+        (err as { message?: string }).message ??
+        ''
+      ).toLowerCase()
+      return TRANSIENT_INTERNAL_RE.test(detail)
+    }
+  }
+  return false
+}
+
+function retryBackoffMs(baseBackoffMs: number, attempt: number): number {
+  const exp = baseBackoffMs * 2 ** (attempt - 1)
+  return Math.round(exp / 2 + Math.random() * (exp / 2))
+}
+
+function sleep(ms: number): Promise<void> {
+  return new Promise((resolve) => setTimeout(resolve, ms))
+}
+
 /**
  * Thrown when a gRPC call to Triton exceeds the configured deadline.
  * Maps to gRPC status code 4 (DEADLINE_EXCEEDED).
@@ -67,9 +104,17 @@ export interface InferResponse {
   raw_output_contents: Buffer[]
 }
 
+export interface ModelInferRetryOptions {
+  /** Total attempts including the first. Values ≤1 disable retry. */
+  maxAttempts: number
+  baseBackoffMs: number
+}
+
 export interface ModelInferOptions {
   /** Milliseconds before the gRPC call is cancelled. Defaults to DEFAULT_DEADLINE_MS (60 000). */
   deadlineMs?: number
+  /** Bounded jittered retry on transient transport errors. Defaults to the DEFAULT_RETRY_* constants. */
+  retry?: ModelInferRetryOptions
 }
 
 export interface TritonClient {
@@ -127,37 +172,63 @@ export function createTritonClient(url: string): TritonClient {
       'grpc.enable_retries': 0,
       'grpc.max_receive_message_length': TRITON_MAX_MESSAGE_BYTES,
       'grpc.max_send_message_length': TRITON_MAX_MESSAGE_BYTES,
+      // Keepalive pings only while calls are in flight (permit_without_calls: 0)
+      // so a half-open connection is detected mid-burst without tripping Triton's
+      // server-side min-ping-interval enforcement (ENHANCE_YOUR_CALM).
+      'grpc.keepalive_time_ms': 30_000,
+      'grpc.keepalive_timeout_ms': 10_000,
+      'grpc.keepalive_permit_without_calls': 0,
     },
   )
 
+  function callOnce(
+    request: InferRequest,
+    deadlineMs: number,
+  ): Promise<InferResponse> {
+    return new Promise((resolve, reject) => {
+      const deadline = new Date(Date.now() + deadlineMs)
+      stub.modelInfer(
+        request,
+        { deadline },
+        (err: grpc.ServiceError | null, response: InferResponse) => {
+          if (err) {
+            if (err.code === grpc.status.DEADLINE_EXCEEDED) {
+              reject(
+                new TritonTimeoutError(
+                  `Triton modelInfer timed out after ${deadlineMs.toString()} ms`,
+                  deadlineMs,
+                ),
+              )
+            } else {
+              reject(err)
+            }
+          } else {
+            resolve(response)
+          }
+        },
+      )
+    })
+  }
+
   return {
-    modelInfer(
+    async modelInfer(
       request: InferRequest,
-      { deadlineMs = DEFAULT_DEADLINE_MS }: ModelInferOptions = {},
+      { deadlineMs = DEFAULT_DEADLINE_MS, retry }: ModelInferOptions = {},
     ): Promise<InferResponse> {
-      return new Promise((resolve, reject) => {
-        const deadline = new Date(Date.now() + deadlineMs)
-        stub.modelInfer(
-          request,
-          { deadline },
-          (err: grpc.ServiceError | null, response: InferResponse) => {
-            if (err) {
-              if (err.code === grpc.status.DEADLINE_EXCEEDED) {
-                reject(
-                  new TritonTimeoutError(
-                    `Triton modelInfer timed out after ${deadlineMs.toString()} ms`,
-                    deadlineMs,
-                  ),
-                )
-              } else {
-                reject(err)
-              }
-            } else {
-              resolve(response)
-            }
-          },
-        )
-      })
+      const maxAttempts = retry?.maxAttempts ?? DEFAULT_RETRY_MAX_ATTEMPTS
+      const baseBackoffMs =
+        retry?.baseBackoffMs ?? DEFAULT_RETRY_BASE_BACKOFF_MS
+      let attempt = 0
+      for (;;) {
+        attempt++
+        try {
+          return await callOnce(request, deadlineMs)
+        } catch (err) {
+          if (attempt >= maxAttempts || !isTransientTransportError(err))
+            throw err
+          await sleep(retryBackoffMs(baseBackoffMs, attempt))
+        }
+      }
     },
 
     serverReady(): Promise {
diff --git a/packages/triton-client/src/index.ts b/packages/triton-client/src/index.ts
index 6945775..4805799 100644
--- a/packages/triton-client/src/index.ts
+++ b/packages/triton-client/src/index.ts
@@ -2,6 +2,8 @@ export {
   createTritonClient,
   TritonTimeoutError,
   DEFAULT_DEADLINE_MS,
+  DEFAULT_RETRY_MAX_ATTEMPTS,
+  DEFAULT_RETRY_BASE_BACKOFF_MS,
 } from './client.ts'
 export type {
   TritonClient,
@@ -10,6 +12,7 @@ export type {
   TensorInput,
   TensorOutput,
   ModelInferOptions,
+  ModelInferRetryOptions,
 } from './client.ts'
 export * from './constants.ts'
 export * from './float16.ts'
diff --git a/services/prediction-worker/src/config.test.ts b/services/prediction-worker/src/config.test.ts
index 4379a12..70d4247 100644
--- a/services/prediction-worker/src/config.test.ts
+++ b/services/prediction-worker/src/config.test.ts
@@ -25,6 +25,31 @@ describe('prediction-worker loadConfig', () => {
     expect(Object.isFrozen(cfg)).toBe(true)
   })
 
+  it('applies conservative defaults for the fan-out tunables', () => {
+    const cfg = loadConfig(VALID_ENV)
+    expect(cfg.triton.maxInflightInfers).toBe(8)
+    expect(cfg.triton.retryMaxAttempts).toBe(3)
+    expect(cfg.triton.retryBaseBackoffMs).toBe(100)
+  })
+
+  it('parses env overrides for the fan-out tunables', () => {
+    const cfg = loadConfig({
+      ...VALID_ENV,
+      TRITON_MAX_INFLIGHT_INFERS: '16',
+      TRITON_RETRY_MAX_ATTEMPTS: '5',
+      TRITON_RETRY_BASE_BACKOFF_MS: '250',
+    })
+    expect(cfg.triton.maxInflightInfers).toBe(16)
+    expect(cfg.triton.retryMaxAttempts).toBe(5)
+    expect(cfg.triton.retryBaseBackoffMs).toBe(250)
+  })
+
+  it('rejects a non-positive concurrency limit', () => {
+    expect(() =>
+      loadConfig({ ...VALID_ENV, TRITON_MAX_INFLIGHT_INFERS: '0' }),
+    ).toThrow()
+  })
+
   it('aggregates missing required fields into one error', () => {
     expect.assertions(2)
     try {
diff --git a/services/prediction-worker/src/config.ts b/services/prediction-worker/src/config.ts
index 4253bac..e902f44 100644
--- a/services/prediction-worker/src/config.ts
+++ b/services/prediction-worker/src/config.ts
@@ -4,6 +4,10 @@ import {
   secretField,
   zBooleanString,
 } from '@protifer/shared'
+import {
+  DEFAULT_RETRY_BASE_BACKOFF_MS,
+  DEFAULT_RETRY_MAX_ATTEMPTS,
+} from '@protifer/triton-client'
 import { z } from 'zod'
 
 export const ConfigSchema = defineConfig({
@@ -33,6 +37,27 @@ export const ConfigSchema = defineConfig({
       type: z.coerce.number().int().positive(),
       default: 90_000,
     }),
+    maxInflightInfers: configField({
+      envName: 'TRITON_MAX_INFLIGHT_INFERS',
+      description:
+        'Max concurrent in-flight Triton modelInfer calls per worker, shared across all jobs. Conservative default; tune up against observed Triton capacity.',
+      type: z.coerce.number().int().positive(),
+      default: 8,
+    }),
+    retryMaxAttempts: configField({
+      envName: 'TRITON_RETRY_MAX_ATTEMPTS',
+      description:
+        'Total modelInfer attempts (incl. first) on transient transport errors. ≤1 disables retry.',
+      type: z.coerce.number().int().positive(),
+      default: DEFAULT_RETRY_MAX_ATTEMPTS,
+    }),
+    retryBaseBackoffMs: configField({
+      envName: 'TRITON_RETRY_BASE_BACKOFF_MS',
+      description:
+        'Base backoff in ms for the jittered transient-retry schedule.',
+      type: z.coerce.number().int().positive(),
+      default: DEFAULT_RETRY_BASE_BACKOFF_MS,
+    }),
   },
   redis: {
     host: configField({
diff --git a/services/prediction-worker/src/dispatch.test.ts b/services/prediction-worker/src/dispatch.test.ts
index 55e1f8e..0d78f2b 100644
--- a/services/prediction-worker/src/dispatch.test.ts
+++ b/services/prediction-worker/src/dispatch.test.ts
@@ -5,6 +5,7 @@ import { describe, it, expect, vi, beforeEach } from 'vitest'
 import { ShapeError, DtypeError, DecodeError } from './adapters/errors.ts'
 import type { AdapterContext } from './adapters/types.ts'
 import { classifyError, dispatchAll } from './dispatch.ts'
+import { createSemaphore } from './semaphore.ts'
 
 type StubBehavior =
   | { mode: 'succeed'; result: unknown }
@@ -77,6 +78,48 @@ const CTX: AdapterContext = {
   sequence: 'MKTVRQERLK',
 }
 
+const flush = () => new Promise((resolve) => setTimeout(resolve, 0))
+
+/** Triton stub whose modelInfer blocks until explicitly released, tracking concurrency. */
+function makeGatedTriton() {
+  let inFlight = 0
+  let maxInFlight = 0
+  const pending: Array<() => void> = []
+  const modelInfer = vi.fn().mockImplementation(
+    () =>
+      new Promise((resolve) => {
+        inFlight++
+        maxInFlight = Math.max(maxInFlight, inFlight)
+        pending.push(() => {
+          inFlight--
+          resolve({ model_name: '', outputs: [], raw_output_contents: [] })
+        })
+      }),
+  )
+  const triton = {
+    modelInfer,
+    serverReady: vi.fn().mockResolvedValue(true),
+    modelReady: vi.fn().mockResolvedValue(true),
+    close: vi.fn(),
+  } as unknown as TritonClient
+  return {
+    triton,
+    modelInfer,
+    get inFlight() {
+      return inFlight
+    },
+    get maxInFlight() {
+      return maxInFlight
+    },
+    get pendingCount() {
+      return pending.length
+    },
+    releaseOne() {
+      pending.shift()?.()
+    },
+  }
+}
+
 function fillRegistry(behaviors: StubBehavior[]) {
   for (const k of Object.keys(registryHolder)) {
     // eslint-disable-next-line @typescript-eslint/no-dynamic-delete
@@ -364,3 +407,73 @@ describe('dispatchAll', () => {
     expect(modelErrors['seth']?.message).toBe('wrong shape')
   })
 })
+
+describe('dispatchAll concurrency bound', () => {
+  beforeEach(() => {
+    fillRegistry(
+      EIGHT_KEYS.map(() => ({
+        mode: 'succeed' as const,
+        result: { ok: true },
+      })),
+    )
+  })
+
+  it('never exceeds the limit across simultaneous dispatchAll invocations', async () => {
+    const sem = createSemaphore(2)
+    const gated = makeGatedTriton()
+
+    const p1 = dispatchAll(gated.triton, CTX, { semaphore: sem })
+    const p2 = dispatchAll(gated.triton, CTX, { semaphore: sem })
+
+    await flush()
+    expect(gated.inFlight).toBe(2)
+
+    while (gated.pendingCount > 0) {
+      gated.releaseOne()
+      await flush()
+    }
+    await Promise.all([p1, p2])
+
+    expect(gated.maxInFlight).toBe(2)
+    expect(sem.available).toBe(2)
+  })
+
+  it('releases a permit when modelInfer throws (no leak)', async () => {
+    const sem = createSemaphore(4)
+    const triton = {
+      modelInfer: vi
+        .fn()
+        .mockRejectedValue(Object.assign(new Error('down'), { code: 14 })),
+      serverReady: vi.fn().mockResolvedValue(true),
+      modelReady: vi.fn().mockResolvedValue(true),
+      close: vi.fn(),
+    } as unknown as TritonClient
+
+    const { outputs, modelErrors } = await dispatchAll(triton, CTX, {
+      semaphore: sem,
+    })
+
+    expect(Object.keys(outputs).length).toBe(0)
+    expect(Object.keys(modelErrors).length).toBe(8)
+    expect(sem.available).toBe(4)
+  })
+
+  it('makes excess calls wait rather than opening immediately', async () => {
+    const sem = createSemaphore(3)
+    const gated = makeGatedTriton()
+
+    const p = dispatchAll(gated.triton, CTX, { semaphore: sem })
+    await flush()
+
+    expect(gated.modelInfer).toHaveBeenCalledTimes(3)
+    expect(gated.inFlight).toBe(3)
+
+    while (gated.pendingCount > 0) {
+      gated.releaseOne()
+      await flush()
+    }
+    await p
+
+    expect(gated.modelInfer).toHaveBeenCalledTimes(8)
+  })
+})
diff --git a/services/prediction-worker/src/dispatch.ts b/services/prediction-worker/src/dispatch.ts
index 889eeea..38a8b6d 100644
--- a/services/prediction-worker/src/dispatch.ts
+++ b/services/prediction-worker/src/dispatch.ts
@@ -6,11 +6,15 @@ import type {
   WorkerMetrics,
 } from '@protifer/shared'
 import { DEFAULT_DEADLINE_MS } from '@protifer/triton-client'
-import type { TritonClient } from '@protifer/triton-client'
+import type {
+  ModelInferRetryOptions,
+  TritonClient,
+} from '@protifer/triton-client'
 
 import { ShapeError, DtypeError, DecodeError } from './adapters/errors.ts'
 import { ADAPTER_REGISTRY } from './adapters/index.ts'
 import type { AdapterContext, ModelAdapter } from './adapters/types.ts'
+import type { Semaphore } from './semaphore.ts'
 
 const MAX_MESSAGE_LEN = 200
 
@@ -84,13 +88,33 @@ function truncate(s: string): string {
  *
  * No per-model retries; BullMQ whole-job retry handles transients.
  */
+async function withPermit<T>(
+  semaphore: Semaphore | undefined,
+  fn: () => Promise<T>,
+): Promise<T> {
+  if (!semaphore) return fn()
+  const release = await semaphore.acquire()
+  try {
+    return await fn()
+  } finally {
+    release()
+  }
+}
+
 export async function dispatchAll(
   triton: TritonClient,
   ctx: AdapterContext,
   {
     deadlineMs = DEFAULT_DEADLINE_MS,
     metrics,
-  }: { deadlineMs?: number; metrics?: WorkerMetrics } = {},
+    semaphore,
+    retry,
+  }: {
+    deadlineMs?: number
+    metrics?: WorkerMetrics
+    semaphore?: Semaphore
+    retry?: ModelInferRetryOptions
+  } = {},
 ): Promise<{ outputs: PredictionOutputs; modelErrors: ModelErrors }> {
   const adapters = Object.values(ADAPTER_REGISTRY) as ModelAdapter[]
 
@@ -101,7 +125,9 @@ export async function dispatchAll(
       })
       try {
         const req = adapter.buildRequest(ctx)
-        const resp = await triton.modelInfer(req, { deadlineMs })
+        const resp = await withPermit(semaphore, () =>
+          triton.modelInfer(req, { deadlineMs, retry }),
+        )
         const decoded = adapter.decodeResponse(resp)
         endTimer?.({ status: 'success' })
         return { adapter, decoded }
diff --git a/services/prediction-worker/src/index.ts b/services/prediction-worker/src/index.ts
index 63bec8e..a7a07ca 100644
--- a/services/prediction-worker/src/index.ts
+++ b/services/prediction-worker/src/index.ts
@@ -15,6 +15,7 @@ import pino from 'pino'
 import { ADAPTER_REGISTRY } from './adapters/index.ts'
 import { loadConfig } from './config.ts'
 import { processPredictionJob } from './processor.ts'
+import { createSemaphore } from './semaphore.ts'
 
 initSentry('prediction-worker')
 
@@ -24,6 +25,11 @@ const logger = pino({ name: 'prediction-worker', ...defaultPinoOptions() })
 const triton = createTritonClient(config.triton.url)
 const store = createObjectStoreFromConfig(config.storage)
 const metrics = createWorkerMetrics()
+const inferSemaphore = createSemaphore(config.triton.maxInflightInfers)
+const inferRetry = {
+  maxAttempts: config.triton.retryMaxAttempts,
+  baseBackoffMs: config.triton.retryBaseBackoffMs,
+}
 
 if (config.metrics.enabled) {
   const metricsServer = startMetricsServer({
@@ -46,6 +52,8 @@ createWorkerApp({
       store,
       deadlineMs: config.triton.deadlineMs,
       metrics,
+      semaphore: inferSemaphore,
+      retry: inferRetry,
     })
   },
   triton,
diff --git a/services/prediction-worker/src/processor.ts b/services/prediction-worker/src/processor.ts
index 506fc08..baccf4d 100644
--- a/services/prediction-worker/src/processor.ts
+++ b/services/prediction-worker/src/processor.ts
@@ -12,10 +12,14 @@ import {
   DEFAULT_DEADLINE_MS,
   fp16BufferToFp32Array,
 } from '@protifer/triton-client'
-import type { TritonClient } from '@protifer/triton-client'
+import type {
+  ModelInferRetryOptions,
+  TritonClient,
+} from '@protifer/triton-client'
 
 import type { AdapterContext } from './adapters/types.ts'
 import { dispatchAll } from './dispatch.ts'
+import type { Semaphore } from './semaphore.ts'
 
 interface ProcessorDeps {
   triton: TritonClient
@@ -23,6 +27,9 @@ interface ProcessorDeps {
   /** gRPC deadline for Triton modelInfer calls (ms). Defaults to DEFAULT_DEADLINE_MS. */
   deadlineMs?: number
   metrics?: WorkerMetrics
+  /** Process-wide bound on concurrent in-flight modelInfer calls, shared across jobs. */
+  semaphore?: Semaphore
+  retry?: ModelInferRetryOptions
 }
 
 export async function processPredictionJob(
@@ -31,7 +38,14 @@ export async function processPredictionJob(
 ): Promise {
   const { sequence, sequenceHash, embeddingModel, predictionModels } =
     job.data as PredictionJobData
-  const { triton, store, deadlineMs = DEFAULT_DEADLINE_MS, metrics } = deps
+  const {
+    triton,
+    store,
+    deadlineMs = DEFAULT_DEADLINE_MS,
+    metrics,
+    semaphore,
+    retry,
+  } = deps
 
   await job.updateProgress(10)
 
@@ -71,6 +85,8 @@ export async function processPredictionJob(
   const { outputs, modelErrors } = await dispatchAll(triton, ctx, {
     deadlineMs,
     metrics,
+    semaphore,
+    retry,
   })
   endJobTimer?.({
     status: Object.keys(outputs).length === 0 ? 'failure' : 'success',
diff --git a/services/prediction-worker/src/semaphore.test.ts b/services/prediction-worker/src/semaphore.test.ts
new file mode 100644
index 0000000..3ffbee8
--- /dev/null
+++ b/services/prediction-worker/src/semaphore.test.ts
@@ -0,0 +1,60 @@
+import { describe, it, expect } from 'vitest'
+
+import { createSemaphore } from './semaphore.ts'
+
+describe('createSemaphore', () => {
+  it('rejects non-positive permit counts', () => {
+    expect(() => createSemaphore(0)).toThrow()
+    expect(() => createSemaphore(-1)).toThrow()
+    expect(() => createSemaphore(1.5)).toThrow()
+  })
+
+  it('grants up to `permits` immediately, then queues', async () => {
+    const sem = createSemaphore(2)
+    const r1 = await sem.acquire()
+    const r2 = await sem.acquire()
+    expect(sem.available).toBe(0)
+
+    let third = false
+    const p3 = sem.acquire().then((r) => {
+      third = true
+      return r
+    })
+    await Promise.resolve()
+    expect(third).toBe(false)
+
+    r1()
+    const r3 = await p3
+    expect(third).toBe(true)
+    r2()
+    r3()
+    expect(sem.available).toBe(2)
+  })
+
+  it('hands a released permit to the longest-waiting acquirer (FIFO)', async () => {
+    const sem = createSemaphore(1)
+    const held = await sem.acquire()
+    const order: number[] = []
+    const a = sem.acquire().then((r) => {
+      order.push(1)
+      return r
+    })
+    const b = sem.acquire().then((r) => {
+      order.push(2)
+      return r
+    })
+
+    held()
+    ;(await a)()
+    ;(await b)()
+    expect(order).toEqual([1, 2])
+  })
+
+  it('is idempotent: releasing twice does not over-credit permits', async () => {
+    const sem = createSemaphore(1)
+    const r = await sem.acquire()
+    r()
+    r()
+    expect(sem.available).toBe(1)
+  })
+})
diff --git a/services/prediction-worker/src/semaphore.ts b/services/prediction-worker/src/semaphore.ts
new file mode 100644
index 0000000..7a8f8b0
--- /dev/null
+++ b/services/prediction-worker/src/semaphore.ts
@@ -0,0 +1,46 @@
+export type Release = () => void
+
+export interface Semaphore {
+  acquire(): Promise<Release>
+  readonly available: number
+}
+
+export function createSemaphore(permits: number): Semaphore {
+  if (!Number.isInteger(permits) || permits < 1) {
+    throw new Error(
+      `semaphore permits must be a positive integer, got ${String(permits)}`,
+    )
+  }
+
+  let available = permits
+  const waiters: Array<(release: Release) => void> = []
+
+  const makeRelease = (): Release => {
+    let released = false
+    return () => {
+      if (released) return
+      released = true
+      const next = waiters.shift()
+      if (next) {
+        next(makeRelease())
+      } else {
+        available++
+      }
+    }
+  }
+
+  return {
+    acquire(): Promise<Release> {
+      if (available > 0) {
+        available--
+        return Promise.resolve(makeRelease())
+      }
+      return new Promise((resolve) => {
+        waiters.push(resolve)
+      })
+    },
+    get available(): number {
+      return available
+    },
+  }
+}

From d7a93a8adecbd41c96fd3025ff3c928610996a9c Mon Sep 17 00:00:00 2001
From: Tobias O
Date: Fri, 19 Jun 2026 11:19:35 +0200
Subject: [PATCH 2/2] test(prediction-worker): verify bounded fan-out e2e against triton-stub

Backend E2E suite (17 tests incl. full prediction pipeline) passes against
the docker-compose.test.yml stack with the semaphore-bounded prediction
worker built from source: jobs process cleanly, no transport storm,
0 worker restarts.

Co-Authored-By: Claude Opus 4.8 (1M context)
---
 openspec/changes/bound-prediction-fanout/tasks.md | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/openspec/changes/bound-prediction-fanout/tasks.md b/openspec/changes/bound-prediction-fanout/tasks.md
index 8c8f3d4..a5778a6 100644
--- a/openspec/changes/bound-prediction-fanout/tasks.md
+++ b/openspec/changes/bound-prediction-fanout/tasks.md
@@ -25,6 +25,6 @@
 ## 5. Verification
 
 - [x] 5.1 Run repo gates: `bun run typecheck`, `bun run lint`, `bun run format`, `bun run test`.
-- [ ] 5.2 Run `bun run test:int` (stack up) to exercise the bounded fan-out against the mock/real Triton path.
+- [x] 5.2 Run `bun run test:int` (stack up) to exercise the bounded fan-out against the mock/real Triton path.
 - [ ] 5.3 Load verification: on a real load run confirm no `Connection dropped` / `Bandwidth exhausted or memory limit exceeded` storm, the GPU is busy during prediction (not idle), prediction jobs complete, and BullMQ whole-job retries drop sharply.
 - [ ] 5.4 Tune `TRITON_MAX_INFLIGHT_INFERS` upward until Triton is well-utilized without reintroducing transport errors; record the chosen value and rationale (deploy runbook).
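
Review note (not part of either patch): the design point in PATCH 1/2 that is easiest to miss is that the retry loop runs inside the permit the caller already holds, so a retrying call never opens an extra stream. The sketch below shows that composition in one self-contained file. The names `createLimiter`, `backoffMs`, and `inferBounded` are hypothetical stand-ins for illustration; the patch itself splits this logic across `semaphore.ts`, `dispatch.ts`, and `client.ts`, and the gRPC error classification is reduced here to a caller-supplied predicate.

```typescript
// Illustrative sketch only: every name in this block is hypothetical.
type Release = () => void

// Counting semaphore: acquire() resolves with a one-shot release handle.
function createLimiter(permits: number) {
  let available = permits
  const waiters: Array<(release: Release) => void> = []
  const makeRelease = (): Release => {
    let done = false
    return () => {
      if (done) return // a second release must not over-credit permits
      done = true
      const next = waiters.shift()
      if (next) next(makeRelease()) // hand the permit to the oldest waiter (FIFO)
      else available++
    }
  }
  return {
    acquire(): Promise<Release> {
      if (available > 0) {
        available--
        return Promise.resolve(makeRelease())
      }
      return new Promise<Release>((resolve) => {
        waiters.push(resolve)
      })
    },
    get available(): number {
      return available
    },
  }
}

// Jittered exponential backoff: uniform in [exp/2, exp], exp = base * 2^(attempt-1).
function backoffMs(base: number, attempt: number): number {
  const exp = base * 2 ** (attempt - 1)
  return Math.round(exp / 2 + Math.random() * (exp / 2))
}

// The whole retry loop runs under ONE held permit, so retries never widen concurrency.
async function inferBounded<T>(
  limiter: ReturnType<typeof createLimiter>,
  call: () => Promise<T>,
  isTransient: (err: unknown) => boolean,
  maxAttempts = 3,
  base = 100,
): Promise<T> {
  const release = await limiter.acquire()
  try {
    for (let attempt = 1; ; attempt++) {
      try {
        return await call()
      } catch (err) {
        // Give up on the last attempt or on any non-transient error.
        if (attempt >= maxAttempts || !isTransient(err)) throw err
        const delay = backoffMs(base, attempt)
        await new Promise<void>((resolve) => setTimeout(resolve, delay))
      }
    }
  } finally {
    release() // runs on success, throw, and exhausted retries alike
  }
}
```

With the defaults used in the patch (3 attempts, 100 ms base), a call that fails every attempt sleeps between 150 and 300 ms in total while still holding its permit, which may be worth factoring in when tuning `TRITON_MAX_INFLIGHT_INFERS` under task 5.4.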