diff --git a/docs/1.docs/50.tasks.md b/docs/1.docs/50.tasks.md index 11b0f4fa04..c1556b13ef 100644 --- a/docs/1.docs/50.tasks.md +++ b/docs/1.docs/50.tasks.md @@ -51,11 +51,18 @@ The `defineTask` helper accepts an object with the following properties: - **`meta`** (optional): An object with optional `name` and `description` string fields used for display in the dev server and CLI. - **`run`** (required): A function that receives a [`TaskEvent`](#taskevent) and returns (or resolves to) an object with an optional `result` property. +- **`concurrency`** (optional): Controls how concurrent calls are handled. Defaults to `{ mode: "dedupe" }`. ```ts interface Task { meta?: { name?: string; description?: string }; run(event: TaskEvent): { result?: RT } | Promise<{ result?: RT }>; + concurrency?: + | { mode: "parallel" } + | { + mode: "dedupe" | "serial"; + key?: (event: TaskEvent) => string; + }; } ``` @@ -276,7 +283,24 @@ The `--payload` flag accepts a JSON string that will be parsed and passed to the ### Concurrency -Each task can have **one running instance**. Calling a task of same name multiple times in parallel, results in calling it once and all callers will get the same return value. +By default, each task can have **one running instance for the same payload**. Calling a task of the same name and payload multiple times in parallel runs it once, and all callers receive the same return value. -> [!NOTE] -> Nitro tasks can be running multiple times and in parallel. +You can customize this behavior with the `concurrency` option: + +- **`dedupe`**: Coalesces concurrent calls to the same task and key into one execution. All callers wait for the same result. This is the default mode. +- **`parallel`**: Allows every call to run as an independent task instance. +- **`serial`**: Queues concurrent calls to the same task and key so they run one after another. + +For `dedupe` and `serial`, you can provide a `key` function to derive the execution key from the task event. If no key function is provided, Nitro uses the hash of the task payload. + +```ts [tasks/report/generate.ts] +export default defineTask({ + concurrency: { + mode: "serial", + key: ({ payload }) => payload.tenant, + }, + async run({ payload }) { + return { result: await generateTenantReport(payload.tenant) }; + }, +}); +``` diff --git a/src/runtime/internal/task.ts b/src/runtime/internal/task.ts index f119cefd3d..322e73eca2 100644 --- a/src/runtime/internal/task.ts +++ b/src/runtime/internal/task.ts @@ -1,6 +1,14 @@ import { Cron } from "croner"; import { HTTPError } from "h3"; -import type { Task, TaskContext, TaskEvent, TaskPayload, TaskResult } from "nitro/types"; +import { hash } from "ohash"; +import type { + Task, + TaskContext, + TaskConcurrency, + TaskEvent, + TaskPayload, + TaskResult, +} from "nitro/types"; import { scheduledTasks, tasks } from "#nitro/virtual/tasks"; /** @experimental */ @@ -13,17 +21,14 @@ export function defineTask(def: Task): Task { return def; } -const __runningTasks__: { [name: string]: ReturnType["run"]> } = {}; +const __runningTasks__ = new Map>(); +const __serialQueues__ = new Map>(); /** @experimental */ export async function runTask( name: string, { payload = {}, context = {} }: { payload?: TaskPayload; context?: TaskContext } = {} ): Promise> { - if (__runningTasks__[name]) { - return __runningTasks__[name]; - } - if (!(name in tasks)) { throw new HTTPError({ message: `Task \`${name}\` is not available!`, @@ -40,13 +45,24 @@ export async function runTask( const handler = (await tasks[name].resolve!()) as Task; const taskEvent: TaskEvent = { name, payload, context }; - __runningTasks__[name] = handler.run(taskEvent); + const concurrency: TaskConcurrency = handler.concurrency ?? { mode: "dedupe" }; - try { - const res = await __runningTasks__[name]; - return res; - } finally { - delete __runningTasks__[name]; + switch (concurrency.mode) { + case "parallel": { + return _callTask(handler, taskEvent); + } + case "dedupe": { + const key = _getTaskConcurrencyKey(concurrency, taskEvent); + return _runTaskOnce(key, () => _callTask(handler, taskEvent)); + } + case "serial": { + const key = _getTaskConcurrencyKey(concurrency, taskEvent); + return _runTaskSerially(key, () => _callTask(handler, taskEvent)); + } + default: { + const mode = (concurrency as { mode: string }).mode; + throw new Error(`Task \`${name}\` has an invalid concurrency mode: "${mode}"`); + } } } @@ -92,3 +108,52 @@ export function runCronTasks( ): Promise { return Promise.all(getCronTasks(cron).map((name) => runTask(name, ctx))); } + +async function _callTask(handler: Task, taskEvent: TaskEvent): Promise> { + return await handler.run(taskEvent); +} + +function _getTaskConcurrencyKey( + concurrency: Exclude, + taskEvent: TaskEvent +): string { + const key = concurrency.key ? concurrency.key(taskEvent) : hash(taskEvent.payload); + return `${taskEvent.name}:${key}`; +} + +function _runTaskOnce( + key: string, + run: () => Promise> +): Promise> { + const running = __runningTasks__.get(key); + if (running) { + return running as Promise>; + } + + const promise = run().finally(() => { + if (__runningTasks__.get(key) === promise) { + __runningTasks__.delete(key); + } + }); + __runningTasks__.set(key, promise); + + return promise; +} + +function _runTaskSerially( + key: string, + run: () => Promise> +): Promise> { + const previous = __serialQueues__.get(key) ?? Promise.resolve(); + const promise = previous.then(run); + const queue = promise + .catch(() => {}) + .then(() => { + if (__serialQueues__.get(key) === queue) { + __serialQueues__.delete(key); + } + }); + __serialQueues__.set(key, queue); + + return promise; +} diff --git a/src/types/runtime/task.ts b/src/types/runtime/task.ts index 2d997f746b..d439c1f362 100644 --- a/src/types/runtime/task.ts +++ b/src/types/runtime/task.ts @@ -26,10 +26,35 @@ export interface TaskResult { result?: RT; } +/** + * Controls how concurrent calls to the same task are handled. + * + * - `"parallel"`: Allow multiple instances of the same task to run concurrently. + * - `"dedupe"`: Coalesce concurrent calls with the same key into a single execution. + * All callers await the same promise and receive the same result. (default) + * - `"serial"`: Queue concurrent calls with the same key so they run one after another. + * + * @experimental + * @default { mode: "dedupe" } + */ +export type TaskConcurrency = + | { mode: "parallel" } + | { + mode: "dedupe" | "serial"; + /** + * Derives the dedupe or serial queue key from the task event. + * If omitted, the task payload hash is used. + * + * @default (event) => hash(event.payload) + */ + key?: (event: TaskEvent) => string; + }; + /** @experimental */ export interface Task { meta?: TaskMeta; run(event: TaskEvent): MaybePromise<{ result?: RT }>; + concurrency?: TaskConcurrency; } /** @experimental */ diff --git a/test/unit/task-concurrency.test.ts b/test/unit/task-concurrency.test.ts new file mode 100644 index 0000000000..b1d8a670e8 --- /dev/null +++ b/test/unit/task-concurrency.test.ts @@ -0,0 +1,268 @@ +import { beforeEach, describe, expect, it, vi } from "vitest"; +import type { Task, TaskEvent } from "../../src/types/runtime/task.ts"; +import { runTask } from "../../src/runtime/internal/task.ts"; + +type VirtualTask = { + meta: NonNullable; + resolve?: () => Promise; +}; + +const mockTasks: Record = {}; + +vi.mock("#nitro/virtual/tasks", () => ({ + get tasks() { + return mockTasks; + }, + scheduledTasks: false, +})); + +describe("task concurrency", () => { + beforeEach(() => { + for (const key of Object.keys(mockTasks)) { + delete mockTasks[key]; + } + }); + + it("dedupes concurrent calls by task name and payload by default", async () => { + let calls = 0; + registerTask("default", { + run: vi.fn(async () => { + calls += 1; + return { result: calls }; + }), + }); + + const results = await Promise.all([ + runTask("default", { payload: { id: 1 } }), + runTask("default", { payload: { id: 1 } }), + runTask("default", { payload: { id: 2 } }), + ]); + const next = await runTask("default"); + + expect(results.map((result) => result.result)).toEqual([1, 1, 2]); + expect(next.result).toBe(3); + }); + + it("dedupes concurrent calls with the same custom key", async () => { + const run = vi.fn(async (event: TaskEvent) => ({ + result: event.payload.userId, + })); + registerTask("by-user", { + run, + concurrency: { + mode: "dedupe", + key: (event) => String(event.payload.userId), + }, + }); + + const results = await Promise.all([ + runTask("by-user", { payload: { userId: "a" } }), + runTask("by-user", { payload: { userId: "a" } }), + runTask("by-user", { payload: { userId: "b" } }), + ]); + + expect(run).toHaveBeenCalledTimes(2); + expect(results.map((result) => result.result)).toEqual(["a", "a", "b"]); + }); + + it("scopes custom keys by task name", async () => { + const concurrency = { + mode: "dedupe" as const, + key: () => "shared", + }; + const runA = vi.fn(async () => ({ result: "a" })); + const runB = vi.fn(async () => ({ result: "b" })); + + registerTask("task-a", { run: runA, concurrency }); + registerTask("task-b", { run: runB, concurrency }); + + const [a, b] = await Promise.all([runTask("task-a"), runTask("task-b")]); + + expect(runA).toHaveBeenCalledTimes(1); + expect(runB).toHaveBeenCalledTimes(1); + expect([a.result, b.result]).toEqual(["a", "b"]); + }); + + it("passes the full task event to custom key functions", async () => { + const key = vi.fn((event: TaskEvent) => String((event.context as { tag: string }).tag)); + registerTask("event-key", { + run: vi.fn(async () => ({ result: "ok" })), + concurrency: { mode: "dedupe", key }, + }); + + await runTask("event-key", { + payload: { id: 1 }, + context: { tag: "alpha" }, + }); + + expect(key).toHaveBeenCalledWith({ + name: "event-key", + payload: { id: 1 }, + context: { tag: "alpha" }, + }); + }); + + it("cleans up deduped calls after rejection", async () => { + let attempts = 0; + registerTask("flaky", { + run: vi.fn(async () => { + attempts += 1; + if (attempts === 1) { + throw new Error("transient"); + } + return { result: "ok" }; + }), + }); + + const rejected = await Promise.allSettled(callMany("flaky", 3)); + const retry = await runTask("flaky"); + + expect(rejected.every((result) => result.status === "rejected")).toBe(true); + expect(attempts).toBe(2); + expect(retry.result).toBe("ok"); + }); + + it("runs parallel tasks independently", async () => { + let calls = 0; + registerTask("parallel", { + run: vi.fn(async () => { + calls += 1; + return { result: calls }; + }), + concurrency: { mode: "parallel" }, + }); + + const results = await runMany("parallel", 3); + + expect(new Set(results.map((result) => result.result))).toEqual(new Set([1, 2, 3])); + }); + + it("serializes calls with the same key", async () => { + const firstRun = withResolvers(); + const events: string[] = []; + + registerTask("serial", { + run: vi.fn(async () => { + events.push("start"); + await firstRun.promise; + events.push("end"); + return { result: "ok" }; + }), + concurrency: { + mode: "serial", + key: () => "same", + }, + }); + + const first = runTask("serial"); + const second = runTask("serial"); + + await new Promise((resolve) => setTimeout(resolve, 0)); + expect(events).toEqual(["start"]); + + firstRun.resolve(); + await Promise.all([first, second]); + + expect(events).toEqual(["start", "end", "start", "end"]); + }); + + it("does not block serial calls with different keys", async () => { + const blockedRun = withResolvers(); + const events: string[] = []; + + registerTask("serial-by-key", { + run: vi.fn(async (event: TaskEvent) => { + const key = String(event.payload.key); + events.push(`start:${key}`); + if (key === "x") { + await blockedRun.promise; + } + events.push(`end:${key}`); + return { result: key }; + }), + concurrency: { + mode: "serial", + key: (event) => String(event.payload.key), + }, + }); + + const x = runTask("serial-by-key", { payload: { key: "x" } }); + const y = await runTask("serial-by-key", { payload: { key: "y" } }); + + expect(y.result).toBe("y"); + expect(events).toEqual(["start:x", "start:y", "end:y"]); + + blockedRun.resolve(); + await x; + + expect(events).toEqual(["start:x", "start:y", "end:y", "end:x"]); + }); + + it("continues a serial queue after rejection", async () => { + let attempts = 0; + registerTask("serial-flaky", { + run: vi.fn(async () => { + attempts += 1; + if (attempts === 1) { + throw new Error("first failed"); + } + return { result: "second ok" }; + }), + concurrency: { mode: "serial" }, + }); + + const results = await Promise.allSettled(callMany("serial-flaky", 2)); + + expect(results[0]).toMatchObject({ status: "rejected" }); + expect(results[1]).toMatchObject({ + status: "fulfilled", + value: { result: "second ok" }, + }); + }); + + it("throws for unknown concurrency modes", async () => { + const run = vi.fn(async () => ({ result: "should not run" })); + registerTask("invalid-mode", { + run, + concurrency: { mode: "batched" } as unknown as Task["concurrency"], + }); + + await expect(runTask("invalid-mode")).rejects.toThrow( + 'Task `invalid-mode` has an invalid concurrency mode: "batched"' + ); + expect(run).not.toHaveBeenCalled(); + }); + + it("throws for unknown or unresolved tasks", async () => { + mockTasks["no-handler"] = { meta: {} }; + + await expect(runTask("missing")).rejects.toThrow("Task `missing` is not available!"); + await expect(runTask("no-handler")).rejects.toThrow("Task `no-handler` is not implemented!"); + }); +}); + +function registerTask(name: string, task: Task) { + mockTasks[name] = { + meta: task.meta ?? {}, + resolve: () => Promise.resolve(task), + }; +} + +function runMany(name: string, count: number) { + return Promise.all(callMany(name, count)); +} + +function callMany(name: string, count: number) { + return Array.from({ length: count }, () => runTask(name)); +} + +// TODO: replace with Promise.withResolvers when targeting ES2024 +function withResolvers() { + let resolve!: (value: T | PromiseLike) => void; + let reject!: (reason?: unknown) => void; + const promise = new Promise((resolvePromise, rejectPromise) => { + resolve = resolvePromise; + reject = rejectPromise; + }); + return { promise, resolve, reject }; +}