diff --git a/packages/cli/src/cli.ts b/packages/cli/src/cli.ts index e215dcd..da74f3f 100644 --- a/packages/cli/src/cli.ts +++ b/packages/cli/src/cli.ts @@ -20,6 +20,7 @@ import { type Target2DDemoOptions, } from "@ignitionrl/examples"; import { + createDqnLearner, createLinearPolicySearchLearner, createTabularQLearner, } from "@ignitionrl/learning"; @@ -3460,6 +3461,17 @@ async function runKnownCheckpointInference( }); } + if (sourceRun.envId === GridWorld.id && sourceRun.algorithm === "dqn") { + return runCheckpointPolicyExperiment({ + ...common, + env: GridWorld, + learner: createDqnLearner(), + policyOptions: { + explore: false, + }, + }); + } + if (sourceRun.envId === Target2D.id && sourceRun.algorithm === "tabular-q-learning") { return runCheckpointPolicyExperiment({ ...common, @@ -3471,6 +3483,17 @@ async function runKnownCheckpointInference( }); } + if (sourceRun.envId === Target2D.id && sourceRun.algorithm === "dqn") { + return runCheckpointPolicyExperiment({ + ...common, + env: Target2D, + learner: createDqnLearner(), + policyOptions: { + explore: false, + }, + }); + } + if (sourceRun.envId === DroneTarget.id && sourceRun.algorithm === "linear-policy-search") { return runCheckpointPolicyExperiment({ ...common, diff --git a/packages/learning/README.md b/packages/learning/README.md index a3dbcc4..cbf9782 100644 --- a/packages/learning/README.md +++ b/packages/learning/README.md @@ -2,9 +2,10 @@ Baseline learning helpers for IgnitionRL. -This package intentionally starts small. It does not implement PPO, DQN, SAC, tensors or native backends. The first learners are deterministic baselines for validating environment loops, transitions, metrics and checkpoint plumbing: +This package intentionally starts small. It does not implement PPO, SAC, native tensor backends or high-throughput training. 
The first learners validate environment loops, transitions, metrics and checkpoint plumbing: - tabular Q-learning for discrete environments; +- DQN for small deterministic discrete debug environments; - linear policy search for bounded continuous environments. ## Tabular Q-Learning @@ -27,6 +28,32 @@ const action = result.learner.selectAction([0, 0, 3, 3], { explore: false }) console.log(action) ``` +## DQN + +```ts +import { trainDqn } from "@ignitionrl/learning" +import { GridWorld } from "@ignitionrl/examples" + +const result = await trainDqn(GridWorld, { + episodes: 200, + maxSteps: 50, + seed: 42, + learnerOptions: { + hiddenSize: 32, + epsilonStart: 0.4, + epsilonEnd: 0.05, + }, +}) + +const action = result.learner.selectAction( + GridWorld.createRunner({ seed: 42 }).getObservation(), + { explore: false } +) +console.log(action) +``` + +`DqnLearner` is the first neural learner path. It is intentionally compact: a seeded one-hidden-layer MLP, replay buffer, epsilon decay and target network for discrete toy/debug environments. It is not the final high-throughput backend, but it proves that neural learners can train, emit stable metrics and checkpoint behind the same environment contract. + ## Linear Policy Search ```ts @@ -82,6 +109,7 @@ const metrics = learnerMetricSpecsForAlgorithm("tabular-q-learning") Default baseline configs are: - `tabular-q-learning`: `learningRate: 0.2`, `discount: 0.95`, `epsilon: 0.1`, `initialQ: 0`, `observationPrecision: 2`, `seed: 0`; +- `dqn`: `learningRate: 0.01`, `discount: 0.95`, `epsilonStart: 0.4`, `epsilonEnd: 0.05`, `epsilonDecaySteps: 500`, `batchSize: 16`, `replayCapacity: 5000`, `minReplaySize: 16`, `trainEverySteps: 1`, `targetUpdateInterval: 50`, `hiddenSize: 32`, `gradientClip: 5`, `weightScale: 0.05`, `seed: 0`; - `linear-policy-search`: `sigma: 0.2`, `actionNoise: 0.03`, `initialWeightScale: 0.05`, `populationSize: 6`, `eliteCount: 2`, `seed: 0`. 
Default neural adapter cadences are: @@ -112,7 +140,7 @@ Checkpoints are JSON-serializable and include: - action shape and bounds, or discrete action values; - learner config; - metrics; -- learned Q-table or policy weights. +- learned Q-table, DQN network weights or policy weights. `linear-policy-search` keeps v1 checkpoint loading backward-compatible when newer diagnostic metrics are missing. Loading normalizes those fields before inference, so older demo artifacts can still be replayed through the current learner. @@ -146,7 +174,7 @@ Built-in support profiles are intentionally conservative: Unsupported combinations fail before a run starts with an algorithm-specific error. Custom neural adapters must declare their supported action spaces explicitly. -Current `TabularQLearner` and `LinearPolicySearchLearner` remain direct TypeScript learners. Future DQN/PPO/SAC implementations can sit behind this contract whether the backend is TypeScript, Rust/Burn, Rust/Candle or another native process. Environment authors still only implement `defineEnvironment()`. +Current `TabularQLearner`, `DqnLearner` and `LinearPolicySearchLearner` remain direct TypeScript learners. Future PPO/SAC implementations, and heavier DQN backends, can sit behind this contract whether the backend is TypeScript, Rust/Burn, Rust/Candle or another native process. Environment authors still only implement `defineEnvironment()`. 
## Scope diff --git a/packages/learning/src/dqn.ts b/packages/learning/src/dqn.ts new file mode 100644 index 0000000..952ffd9 --- /dev/null +++ b/packages/learning/src/dqn.ts @@ -0,0 +1,880 @@ +import { + createSeededRng, + validateActionSpec, + type DiscreteActionSpec, + type DiscreteActionValue, + type EnvironmentSpec, + type JsonObject, + type Learner, + type LearnerConfig, + type Metrics, + type Policy, + type PolicyActContext, + type Seed, + type SeededRng, + type Transition, +} from "@ignitionrl/core"; +import { + assertNeuralCheckpointEnvelope, + createNeuralCheckpointEnvelope, + defineNeuralLearnerAdapterContract, + type NeuralCheckpointEnvelope, +} from "./neural-adapter.js"; + +export const DQN_ALGORITHM = "dqn"; +export const DQN_CHECKPOINT_PAYLOAD_VERSION = 1; + +export type DqnConfig = { + readonly seed: Seed; + readonly learningRate: number; + readonly discount: number; + readonly epsilonStart: number; + readonly epsilonEnd: number; + readonly epsilonDecaySteps: number; + readonly batchSize: number; + readonly replayCapacity: number; + readonly minReplaySize: number; + readonly trainEverySteps: number; + readonly targetUpdateInterval: number; + readonly hiddenSize: number; + readonly gradientClip: number; + readonly weightScale: number; +}; + +export type DqnOptions = Partial & { + readonly metadata?: JsonObject; +}; + +export type DqnMetrics = Readonly> & { + readonly transitions: number; + readonly episodes: number; + readonly updates: number; + readonly epsilon: number; + readonly explorationRate: number; + readonly exploratoryActions: number; + readonly greedyActions: number; + readonly replayBufferSize: number; + readonly tdLoss: number; + readonly meanAbsTdError: number; + readonly meanQ: number; + readonly lastEpisodeReward: number; + readonly lastEpisodeLength: number; +}; + +export type DqnNetwork = { + readonly inputSize: number; + readonly hiddenSize: number; + readonly outputSize: number; + readonly hiddenWeights: readonly (readonly 
number[])[]; + readonly outputWeights: readonly (readonly number[])[]; +}; + +export type DqnPolicyOptions = { + readonly explore?: boolean; +}; + +export type DqnCheckpoint = NeuralCheckpointEnvelope; + +type MutableDqnNetwork = { + inputSize: number; + hiddenSize: number; + outputSize: number; + hiddenWeights: number[][]; + outputWeights: number[][]; +}; + +type DqnReplayTransition = { + readonly observation: readonly number[]; + readonly action: TAction; + readonly actionIndex: number; + readonly reward: number; + readonly nextObservation: readonly number[]; + readonly done: boolean; + readonly terminated: boolean; + readonly truncated: boolean; +}; + +type ForwardPass = { + readonly observation: readonly number[]; + readonly hiddenRaw: readonly number[]; + readonly hidden: readonly number[]; + readonly qValues: readonly number[]; +}; + +export const DEFAULT_DQN_CONFIG: DqnConfig = Object.freeze({ + seed: 0, + learningRate: 0.01, + discount: 0.95, + epsilonStart: 0.4, + epsilonEnd: 0.05, + epsilonDecaySteps: 500, + batchSize: 16, + replayCapacity: 5_000, + minReplaySize: 16, + trainEverySteps: 1, + targetUpdateInterval: 50, + hiddenSize: 32, + gradientClip: 5, + weightScale: 0.05, +}); + +export class DqnLearner< + TAction extends DiscreteActionValue = DiscreteActionValue, +> implements Learner { + readonly name = DQN_ALGORITHM; + + private config: DqnConfig; + private metadata: JsonObject | undefined; + private rng: SeededRng; + private envSpec: EnvironmentSpec | undefined; + private actions: TAction[] = []; + private network: MutableDqnNetwork | undefined; + private targetNetwork: MutableDqnNetwork | undefined; + private replay: DqnReplayTransition[] = []; + private transitions = 0; + private episodes = 0; + private updates = 0; + private exploratoryActions = 0; + private greedyActions = 0; + private currentEpisodeReward = 0; + private currentEpisodeLength = 0; + private lastEpisodeReward = 0; + private lastEpisodeLength = 0; + private tdLoss = 0; + private 
meanAbsTdError = 0; + private meanQ = 0; + + constructor(options: DqnOptions = {}) { + this.config = defineDqnConfig(options); + this.metadata = options.metadata; + this.rng = createSeededRng(this.config.seed); + } + + async init(spec: EnvironmentSpec, config: LearnerConfig = {}): Promise { + const nextConfig = defineDqnConfig({ + ...this.config, + ...config, + }); + + validateActionSpec(spec.actions); + if (spec.actions.type !== "discrete") { + throw new Error("[IgnitionRL] DqnLearner only supports discrete action spaces."); + } + + const observationSize = spec.observation.shape[0]; + if (!Number.isInteger(observationSize) || observationSize <= 0) { + throw new Error("[IgnitionRL] DqnLearner requires a positive vector observation size."); + } + + this.config = nextConfig; + this.rng = createSeededRng(this.config.seed); + this.envSpec = spec as EnvironmentSpec; + this.actions = discreteActionsFromSpec(spec.actions) as TAction[]; + this.network = createNetwork( + observationSize, + this.config.hiddenSize, + this.actions.length, + this.rng, + this.config.weightScale, + ); + this.targetNetwork = cloneNetwork(this.network); + this.replay = []; + this.transitions = 0; + this.episodes = 0; + this.updates = 0; + this.exploratoryActions = 0; + this.greedyActions = 0; + this.currentEpisodeReward = 0; + this.currentEpisodeLength = 0; + this.lastEpisodeReward = 0; + this.lastEpisodeLength = 0; + this.tdLoss = 0; + this.meanAbsTdError = 0; + this.meanQ = 0; + } + + async act( + input: Float32Array | PolicyActContext, + ): Promise { + return this.selectAction(observationFromInput(input), { explore: true }); + } + + async observe(transition: Transition): Promise { + this.assertInitialized(); + + const record: DqnReplayTransition = { + observation: this.assertObservation(transition.observation), + action: transition.action, + actionIndex: this.actionIndexFromTransition(transition), + reward: finiteNumber(transition.reward, "transition.reward"), + nextObservation: 
this.assertObservation(transition.nextObservation), + done: transition.done, + terminated: transition.terminated, + truncated: transition.truncated, + }; + + this.pushReplay(record); + this.transitions += 1; + this.currentEpisodeReward += record.reward; + this.currentEpisodeLength += 1; + + if ( + this.replay.length >= this.config.minReplaySize + && this.transitions % this.config.trainEverySteps === 0 + ) { + this.trainBatch(); + } + + if (record.done) { + this.episodes += 1; + this.lastEpisodeReward = this.currentEpisodeReward; + this.lastEpisodeLength = this.currentEpisodeLength; + this.currentEpisodeReward = 0; + this.currentEpisodeLength = 0; + } + } + + async update(): Promise { + return this.getMetrics(); + } + + async save(path: string): Promise { + const { writeFile } = await import("node:fs/promises"); + + await writeFile(path, `${JSON.stringify(this.toCheckpoint(), null, 2)}\n`, "utf8"); + } + + async load(path: string): Promise { + const { readFile } = await import("node:fs/promises"); + const raw = await readFile(path, "utf8"); + + this.loadCheckpoint(JSON.parse(raw) as unknown); + } + + policy(options: DqnPolicyOptions = {}): Policy { + return { + name: this.name, + act: (context) => this.selectAction(context.observation, { + explore: options.explore ?? false, + }), + }; + } + + selectAction( + observation: readonly number[] | Float32Array, + options: DqnPolicyOptions = {}, + ): TAction { + const network = this.requireNetwork(); + const values = this.assertObservation(observation); + const explore = options.explore ?? false; + const epsilon = this.currentEpsilon(); + + if (explore && this.rng.bool(epsilon)) { + this.exploratoryActions += 1; + return this.rng.pick(this.actions); + } + + this.greedyActions += 1; + + const qValues = forward(network, values).qValues; + const bestValue = Math.max(...qValues); + const bestIndexes = qValues.flatMap((value, index) => + value === bestValue ? [index] : []); + const index = explore ? 
this.rng.pick(bestIndexes) : (bestIndexes[0] as number); + + return this.actions[index] as TAction; + } + + qValuesForObservation(observation: readonly number[] | Float32Array): readonly number[] { + return [...forward(this.requireNetwork(), this.assertObservation(observation)).qValues]; + } + + getMetrics(): DqnMetrics { + const totalActions = this.exploratoryActions + this.greedyActions; + + return { + transitions: this.transitions, + episodes: this.episodes, + updates: this.updates, + epsilon: this.currentEpsilon(), + explorationRate: totalActions === 0 ? 0 : this.exploratoryActions / totalActions, + exploratoryActions: this.exploratoryActions, + greedyActions: this.greedyActions, + replayBufferSize: this.replay.length, + tdLoss: this.tdLoss, + meanAbsTdError: this.meanAbsTdError, + meanQ: this.meanQ, + lastEpisodeReward: this.lastEpisodeReward, + lastEpisodeLength: this.lastEpisodeLength, + }; + } + + getConfig(): DqnConfig { + return { ...this.config }; + } + + toCheckpoint(): DqnCheckpoint { + const spec = this.requireInitializedSpec(); + const network = this.requireNetwork(); + const targetNetwork = this.requireTargetNetwork(); + const contract = defineNeuralLearnerAdapterContract(spec, { + algorithm: DQN_ALGORITHM, + updateCadence: { + type: "step", + everySteps: this.config.trainEverySteps, + warmupSteps: this.config.minReplaySize, + batchSize: this.config.batchSize, + }, + }); + + return createNeuralCheckpointEnvelope(contract, { + payloadVersion: DQN_CHECKPOINT_PAYLOAD_VERSION, + actions: [...this.actions], + config: { ...this.config }, + network: networkToJson(network), + targetNetwork: networkToJson(targetNetwork), + }, { + metrics: this.getMetrics(), + ...(this.metadata !== undefined ? 
{ metadata: this.metadata } : {}), + }); + } + + loadCheckpoint(checkpoint: unknown, spec?: EnvironmentSpec): void { + const envelope = normalizeDqnCheckpoint(checkpoint); + const payload = envelope.payload; + const actions = actionsFromPayload(payload.actions); + const config = defineDqnConfig( + recordFromPayload(payload.config, "payload.config") as Partial, + ); + const network = networkFromPayload(payload.network, "payload.network"); + const targetNetwork = networkFromPayload(payload.targetNetwork, "payload.targetNetwork"); + const resolvedSpec = spec ?? specFromCheckpoint(envelope, actions); + const contract = defineNeuralLearnerAdapterContract(resolvedSpec, { + algorithm: DQN_ALGORITHM, + updateCadence: envelope.updateCadence, + }); + + assertNeuralCheckpointEnvelope(envelope, contract); + if (resolvedSpec.actions.type !== "discrete") { + throw new Error("[IgnitionRL] DQN checkpoints require a discrete action spec."); + } + + const expectedActions = discreteActionsFromSpec(resolvedSpec.actions); + if (!sameDiscreteActions(actions, expectedActions)) { + throw new Error("[IgnitionRL] DQN checkpoint actions do not match the environment action spec."); + } + + if ( + network.inputSize !== resolvedSpec.observation.shape[0] + || targetNetwork.inputSize !== resolvedSpec.observation.shape[0] + || network.outputSize !== actions.length + || targetNetwork.outputSize !== actions.length + ) { + throw new Error("[IgnitionRL] DQN checkpoint network shape does not match the environment spec."); + } + + this.config = config; + this.rng = createSeededRng(this.config.seed); + this.envSpec = resolvedSpec as EnvironmentSpec; + this.actions = [...actions] as TAction[]; + this.network = network; + this.targetNetwork = targetNetwork; + this.replay = []; + this.restoreMetrics(envelope.metrics); + this.metadata = envelope.metadata; + } + + private trainBatch(): void { + const network = this.requireNetwork(); + const targetNetwork = this.requireTargetNetwork(); + const batchSize = 
Math.min(this.config.batchSize, this.replay.length); + let totalLoss = 0; + let totalAbsTdError = 0; + let totalQ = 0; + + for (let index = 0; index < batchSize; index += 1) { + const transition = this.rng.pick(this.replay); + const pass = forward(network, transition.observation); + const prediction = pass.qValues[transition.actionIndex] as number; + const nextQ = transition.terminated || (transition.done && !transition.truncated) /* zero the bootstrap only on true termination; a time-limit truncation still has future value, so keep bootstrapping from the target network */ + ? 0 + : Math.max(...forward(targetNetwork, transition.nextObservation).qValues); + const target = transition.reward + this.config.discount * nextQ; + const tdError = prediction - target; + + updateNetwork(network, pass, transition.actionIndex, tdError, this.config); + totalLoss += tdError ** 2; + totalAbsTdError += Math.abs(tdError); + totalQ += prediction; + } + + this.updates += 1; + this.tdLoss = totalLoss / batchSize; + this.meanAbsTdError = totalAbsTdError / batchSize; + this.meanQ = totalQ / batchSize; + + if (this.updates % this.config.targetUpdateInterval === 0) { + this.targetNetwork = cloneNetwork(network); + } + } + + private pushReplay(transition: DqnReplayTransition): void { + this.replay.push(transition); + + while (this.replay.length > this.config.replayCapacity) { + this.replay.shift(); + } + } + + private currentEpsilon(): number { + const progress = Math.min(1, this.transitions / this.config.epsilonDecaySteps); + + return this.config.epsilonStart + + (this.config.epsilonEnd - this.config.epsilonStart) * progress; + } + + private assertObservation(observation: readonly number[] | Float32Array): number[] { + const spec = this.requireInitializedSpec(); + const values = Array.from(observation); + + if (values.length !== spec.observation.shape[0]) { + throw new Error( + `[IgnitionRL] Observation length must be ${spec.observation.shape[0]}, got ${values.length}.`, + ); + } + + values.forEach((value, index) => finiteNumber(value, `observation[${index}]`)); + + return values; + } + + private actionIndexFromTransition(transition: Transition): number { + if (typeof 
transition.encodedAction !== "number") { + throw new Error("[IgnitionRL] DqnLearner requires scalar encoded discrete actions."); + } + + const index = transition.encodedAction; + if (!Number.isInteger(index) || index < 0 || index >= this.actions.length) { + throw new Error(`[IgnitionRL] Encoded action index is out of range: ${index}.`); + } + + const expected = this.actions[index]; + if (!Object.is(expected, transition.action)) { + throw new Error("[IgnitionRL] Transition action does not match its encoded action index."); + } + + return index; + } + + private restoreMetrics(metrics: Readonly>): void { + this.transitions = integerMetric(metrics.transitions, "metrics.transitions"); + this.episodes = integerMetric(metrics.episodes, "metrics.episodes"); + this.updates = integerMetric(metrics.updates, "metrics.updates"); + this.exploratoryActions = integerMetric( + metrics.exploratoryActions, + "metrics.exploratoryActions", + ); + this.greedyActions = integerMetric(metrics.greedyActions, "metrics.greedyActions"); + this.tdLoss = finiteNumber(metrics.tdLoss ?? 0, "metrics.tdLoss"); + this.meanAbsTdError = finiteNumber( + metrics.meanAbsTdError ?? 0, + "metrics.meanAbsTdError", + ); + this.meanQ = finiteNumber(metrics.meanQ ?? 0, "metrics.meanQ"); + this.lastEpisodeReward = finiteNumber( + metrics.lastEpisodeReward ?? 0, + "metrics.lastEpisodeReward", + ); + this.lastEpisodeLength = integerMetric( + metrics.lastEpisodeLength ?? 
0, + "metrics.lastEpisodeLength", + ); + this.currentEpisodeReward = 0; + this.currentEpisodeLength = 0; + } + + private assertInitialized(): void { + this.requireInitializedSpec(); + this.requireNetwork(); + this.requireTargetNetwork(); + } + + private requireInitializedSpec(): EnvironmentSpec { + if (this.envSpec === undefined) { + throw new Error("[IgnitionRL] DqnLearner must be initialized before use."); + } + + return this.envSpec; + } + + private requireNetwork(): MutableDqnNetwork { + if (this.network === undefined) { + throw new Error("[IgnitionRL] DqnLearner must be initialized before use."); + } + + return this.network; + } + + private requireTargetNetwork(): MutableDqnNetwork { + if (this.targetNetwork === undefined) { + throw new Error("[IgnitionRL] DqnLearner must be initialized before use."); + } + + return this.targetNetwork; + } +} + +export function createDqnLearner< + TAction extends DiscreteActionValue = DiscreteActionValue, +>(options?: DqnOptions): DqnLearner { + return new DqnLearner(options); +} + +export function defineDqnConfig(options: Partial = {}): DqnConfig { + const config = { + seed: options.seed ?? DEFAULT_DQN_CONFIG.seed, + learningRate: options.learningRate ?? DEFAULT_DQN_CONFIG.learningRate, + discount: options.discount ?? DEFAULT_DQN_CONFIG.discount, + epsilonStart: options.epsilonStart ?? DEFAULT_DQN_CONFIG.epsilonStart, + epsilonEnd: options.epsilonEnd ?? DEFAULT_DQN_CONFIG.epsilonEnd, + epsilonDecaySteps: options.epsilonDecaySteps ?? DEFAULT_DQN_CONFIG.epsilonDecaySteps, + batchSize: options.batchSize ?? DEFAULT_DQN_CONFIG.batchSize, + replayCapacity: options.replayCapacity ?? DEFAULT_DQN_CONFIG.replayCapacity, + minReplaySize: options.minReplaySize ?? DEFAULT_DQN_CONFIG.minReplaySize, + trainEverySteps: options.trainEverySteps ?? DEFAULT_DQN_CONFIG.trainEverySteps, + targetUpdateInterval: options.targetUpdateInterval + ?? DEFAULT_DQN_CONFIG.targetUpdateInterval, + hiddenSize: options.hiddenSize ?? 
DEFAULT_DQN_CONFIG.hiddenSize, + gradientClip: options.gradientClip ?? DEFAULT_DQN_CONFIG.gradientClip, + weightScale: options.weightScale ?? DEFAULT_DQN_CONFIG.weightScale, + }; + + if (typeof config.seed !== "number" && typeof config.seed !== "string") { + throw new Error("[IgnitionRL] config.seed must be a number or string."); + } + + assertPositiveNumber(config.learningRate, "config.learningRate"); + assertUnitInterval(config.discount, "config.discount"); + assertUnitInterval(config.epsilonStart, "config.epsilonStart"); + assertUnitInterval(config.epsilonEnd, "config.epsilonEnd"); + assertPositiveInteger(config.epsilonDecaySteps, "config.epsilonDecaySteps"); + assertPositiveInteger(config.batchSize, "config.batchSize"); + assertPositiveInteger(config.replayCapacity, "config.replayCapacity"); + assertPositiveInteger(config.minReplaySize, "config.minReplaySize"); + assertPositiveInteger(config.trainEverySteps, "config.trainEverySteps"); + assertPositiveInteger(config.targetUpdateInterval, "config.targetUpdateInterval"); + assertPositiveInteger(config.hiddenSize, "config.hiddenSize"); + assertPositiveNumber(config.gradientClip, "config.gradientClip"); + assertPositiveNumber(config.weightScale, "config.weightScale"); + + if (config.epsilonEnd > config.epsilonStart) { + throw new Error("[IgnitionRL] config.epsilonEnd cannot exceed config.epsilonStart."); + } + + if (config.minReplaySize > config.replayCapacity) { + throw new Error("[IgnitionRL] config.minReplaySize cannot exceed config.replayCapacity."); + } + + return config; +} + +export function assertDqnCheckpoint(value: unknown): asserts value is DqnCheckpoint { + normalizeDqnCheckpoint(value); +} + +function normalizeDqnCheckpoint(value: unknown): DqnCheckpoint { + assertNeuralCheckpointEnvelope(value); + + if (value.algorithm !== DQN_ALGORITHM) { + throw new Error(`[IgnitionRL] Unsupported DQN checkpoint algorithm: ${value.algorithm}.`); + } + + const payload = value.payload; + if (payload.payloadVersion !== 
DQN_CHECKPOINT_PAYLOAD_VERSION) { + throw new Error( + `[IgnitionRL] Unsupported DQN checkpoint payload version: ${String(payload.payloadVersion)}.`, + ); + } + + actionsFromPayload(payload.actions); + defineDqnConfig(recordFromPayload(payload.config, "payload.config") as Partial); + networkFromPayload(payload.network, "payload.network"); + networkFromPayload(payload.targetNetwork, "payload.targetNetwork"); + + return value; +} + +function createNetwork( + inputSize: number, + hiddenSize: number, + outputSize: number, + rng: SeededRng, + weightScale: number, +): MutableDqnNetwork { + return { + inputSize, + hiddenSize, + outputSize, + hiddenWeights: Array.from({ length: hiddenSize }, () => + Array.from({ length: inputSize + 1 }, () => rng.float(-weightScale, weightScale))), + outputWeights: Array.from({ length: outputSize }, () => + Array.from({ length: hiddenSize + 1 }, () => rng.float(-weightScale, weightScale))), + }; +} + +function cloneNetwork(network: MutableDqnNetwork): MutableDqnNetwork { + return { + inputSize: network.inputSize, + hiddenSize: network.hiddenSize, + outputSize: network.outputSize, + hiddenWeights: network.hiddenWeights.map((row) => [...row]), + outputWeights: network.outputWeights.map((row) => [...row]), + }; +} + +function networkToJson(network: MutableDqnNetwork): JsonObject { + return { + inputSize: network.inputSize, + hiddenSize: network.hiddenSize, + outputSize: network.outputSize, + hiddenWeights: network.hiddenWeights.map((row) => [...row]), + outputWeights: network.outputWeights.map((row) => [...row]), + }; +} + +function networkFromPayload(value: unknown, label: string): MutableDqnNetwork { + const record = recordFromPayload(value, label); + const inputSize = integerMetric(record.inputSize, `${label}.inputSize`); + const hiddenSize = integerMetric(record.hiddenSize, `${label}.hiddenSize`); + const outputSize = integerMetric(record.outputSize, `${label}.outputSize`); + const hiddenWeights = matrixFromPayload(record.hiddenWeights, 
`${label}.hiddenWeights`); + const outputWeights = matrixFromPayload(record.outputWeights, `${label}.outputWeights`); + + if (hiddenWeights.length !== hiddenSize) { + throw new Error(`[IgnitionRL] ${label}.hiddenWeights must have ${hiddenSize} rows.`); + } + + for (const row of hiddenWeights) { + if (row.length !== inputSize + 1) { + throw new Error(`[IgnitionRL] ${label}.hiddenWeights row width is invalid.`); + } + } + + if (outputWeights.length !== outputSize) { + throw new Error(`[IgnitionRL] ${label}.outputWeights must have ${outputSize} rows.`); + } + + for (const row of outputWeights) { + if (row.length !== hiddenSize + 1) { + throw new Error(`[IgnitionRL] ${label}.outputWeights row width is invalid.`); + } + } + + return { + inputSize, + hiddenSize, + outputSize, + hiddenWeights, + outputWeights, + }; +} + +function forward( + network: MutableDqnNetwork, + observation: readonly number[], +): ForwardPass { + const hiddenRaw = network.hiddenWeights.map((row) => + dotWithBias(row, observation)); + const hidden = hiddenRaw.map(leakyRelu); + const qValues = network.outputWeights.map((row) => + dotWithBias(row, hidden)); + + return { + observation, + hiddenRaw, + hidden, + qValues, + }; +} + +function updateNetwork( + network: MutableDqnNetwork, + pass: ForwardPass, + actionIndex: number, + tdError: number, + config: DqnConfig, +): void { + const outputRow = network.outputWeights[actionIndex] as number[]; + const outputBeforeUpdate = [...outputRow]; + const gradient = clamp(tdError, -config.gradientClip, config.gradientClip); + const outputBiasIndex = network.hiddenSize; + + for (let hiddenIndex = 0; hiddenIndex < network.hiddenSize; hiddenIndex += 1) { + outputRow[hiddenIndex] = (outputRow[hiddenIndex] as number) + - config.learningRate * gradient * (pass.hidden[hiddenIndex] as number); + } + outputRow[outputBiasIndex] = (outputRow[outputBiasIndex] as number) + - config.learningRate * gradient; + + for (let hiddenIndex = 0; hiddenIndex < network.hiddenSize; 
hiddenIndex += 1) { + const hiddenDerivative = leakyReluDerivative(pass.hiddenRaw[hiddenIndex] as number); + const hiddenGradient = clamp( + gradient * (outputBeforeUpdate[hiddenIndex] as number) * hiddenDerivative, + -config.gradientClip, + config.gradientClip, + ); + const hiddenRow = network.hiddenWeights[hiddenIndex] as number[]; + const hiddenBiasIndex = network.inputSize; + + for (let inputIndex = 0; inputIndex < network.inputSize; inputIndex += 1) { + hiddenRow[inputIndex] = (hiddenRow[inputIndex] as number) + - config.learningRate * hiddenGradient * (pass.observation[inputIndex] as number); + } + hiddenRow[hiddenBiasIndex] = (hiddenRow[hiddenBiasIndex] as number) + - config.learningRate * hiddenGradient; + } +} + +function dotWithBias(weights: readonly number[], inputs: readonly number[]): number { + let value = weights[inputs.length] as number; + + for (let index = 0; index < inputs.length; index += 1) { + value += (weights[index] as number) * (inputs[index] as number); + } + + return value; +} + +function leakyRelu(value: number): number { + return value >= 0 ? value : 0.01 * value; +} + +function leakyReluDerivative(value: number): number { + return value >= 0 ? 1 : 0.01; +} + +function discreteActionsFromSpec(spec: DiscreteActionSpec): DiscreteActionValue[] { + if ("values" in spec && spec.values !== undefined) { + return [...spec.values]; + } + + return Array.from({ length: spec.n }, (_, index) => index); +} + +function specFromCheckpoint( + checkpoint: DqnCheckpoint, + actions: readonly DiscreteActionValue[], +): EnvironmentSpec { + return { + id: checkpoint.envId, + observation: { + type: "vector", + shape: checkpoint.observation.shape, + dtype: "float32", + }, + actions: { + type: "discrete", + values: actions, + }, + }; +} + +function observationFromInput( + input: Float32Array | PolicyActContext, +): readonly number[] | Float32Array { + return input instanceof Float32Array ? 
input : input.observation; +} + +function sameDiscreteActions( + left: readonly DiscreteActionValue[], + right: readonly DiscreteActionValue[], +): boolean { + return left.length === right.length && left.every((value, index) => Object.is(value, right[index])); +} + +function actionsFromPayload(value: unknown): DiscreteActionValue[] { + if (!Array.isArray(value) || value.length === 0) { + throw new Error("[IgnitionRL] DQN checkpoint payload.actions must be a non-empty array."); + } + + return value.map((action) => { + if ( + typeof action !== "string" + && typeof action !== "number" + && typeof action !== "boolean" + ) { + throw new Error("[IgnitionRL] DQN checkpoint payload.actions must contain discrete values."); + } + + return action; + }); +} + +function matrixFromPayload(value: unknown, label: string): number[][] { + if (!Array.isArray(value) || value.length === 0) { + throw new Error(`[IgnitionRL] ${label} must be a non-empty matrix.`); + } + + return value.map((row, rowIndex) => { + if (!Array.isArray(row) || row.length === 0) { + throw new Error(`[IgnitionRL] ${label}[${rowIndex}] must be a non-empty array.`); + } + + return row.map((entry, columnIndex) => + finiteNumber(entry, `${label}[${rowIndex}][${columnIndex}]`)); + }); +} + +function recordFromPayload(value: unknown, label: string): Record { + if (typeof value !== "object" || value === null || Array.isArray(value)) { + throw new Error(`[IgnitionRL] ${label} must be an object.`); + } + + return value as Record; +} + +function integerMetric(value: unknown, label: string): number { + const number = finiteNumber(value, label); + + if (!Number.isInteger(number) || number < 0) { + throw new Error(`[IgnitionRL] ${label} must be a non-negative integer.`); + } + + return number; +} + +function assertPositiveInteger(value: unknown, label: string): asserts value is number { + const number = finiteNumber(value, label); + + if (!Number.isInteger(number) || number <= 0) { + throw new Error(`[IgnitionRL] ${label} 
must be a positive integer.`); + } +} + +function assertPositiveNumber(value: unknown, label: string): asserts value is number { + const number = finiteNumber(value, label); + + if (number <= 0) { + throw new Error(`[IgnitionRL] ${label} must be positive.`); + } +} + +function assertUnitInterval(value: unknown, label: string): asserts value is number { + const number = finiteNumber(value, label); + + if (number < 0 || number > 1) { + throw new Error(`[IgnitionRL] ${label} must be in [0, 1].`); + } +} + +function finiteNumber(value: unknown, label: string): number { + if (typeof value !== "number" || !Number.isFinite(value)) { + throw new Error(`[IgnitionRL] ${label} must be finite.`); + } + + return value; +} + +function clamp(value: number, min: number, max: number): number { + return Math.max(min, Math.min(max, value)); +} diff --git a/packages/learning/src/index.ts b/packages/learning/src/index.ts index 8561f0c..3f159ff 100644 --- a/packages/learning/src/index.ts +++ b/packages/learning/src/index.ts @@ -38,6 +38,21 @@ export { type SelectActionOptions, type TabularQOptions, } from "./tabular-q.js"; +export { + DEFAULT_DQN_CONFIG, + DQN_ALGORITHM, + DQN_CHECKPOINT_PAYLOAD_VERSION, + DqnLearner, + assertDqnCheckpoint, + createDqnLearner, + defineDqnConfig, + type DqnCheckpoint, + type DqnConfig, + type DqnMetrics, + type DqnNetwork, + type DqnOptions, + type DqnPolicyOptions, +} from "./dqn.js"; export { EPISODE_METRIC_SPECS, LINEAR_POLICY_SEARCH_METRIC_SPECS, @@ -86,8 +101,11 @@ export { } from "./neural-adapter.js"; export { trainLinearPolicySearch, + trainDqn, trainTabularQ, trainingSummaryToJson, + type TrainDqnOptions, + type TrainDqnResult, type TrainEpisodeSummary, type TrainLinearPolicySearchOptions, type TrainLinearPolicySearchResult, diff --git a/packages/learning/src/neural-adapter.ts b/packages/learning/src/neural-adapter.ts index 6d9b356..697623f 100644 --- a/packages/learning/src/neural-adapter.ts +++ b/packages/learning/src/neural-adapter.ts @@ 
-363,6 +363,7 @@ export function defaultMetricSpecs(algorithm: NeuralLearnerAlgorithm): readonly ...common, { name: "tdLoss", scope: "update", direction: "minimize", reducer: "mean" }, { name: "epsilon", scope: "update", direction: "none", reducer: "last" }, + { name: "explorationRate", scope: "run", direction: "none", reducer: "mean" }, { name: "replayBufferSize", scope: "update", direction: "none", reducer: "last" }, ]; } diff --git a/packages/learning/src/train.ts b/packages/learning/src/train.ts index 6666d52..e72bbbf 100644 --- a/packages/learning/src/train.ts +++ b/packages/learning/src/train.ts @@ -8,6 +8,12 @@ import type { Seed, Transition, } from "@ignitionrl/core"; +import { + createDqnLearner, + type DqnLearner, + type DqnMetrics, + type DqnOptions, +} from "./dqn.js"; import { createLinearPolicySearchLearner, type LinearPolicySearchLearner, @@ -38,6 +44,16 @@ export type TrainTabularQOptions = { readonly onEpisode?: (summary: TrainEpisodeSummary) => void | Promise; }; +export type TrainDqnOptions = { + readonly learner?: DqnLearner; + readonly learnerOptions?: DqnOptions; + readonly episodes?: number; + readonly maxSteps?: number; + readonly seed?: Seed; + readonly runId?: string; + readonly onEpisode?: (summary: TrainEpisodeSummary) => void | Promise; +}; + export type TrainEpisodeSummary = { readonly episode: number; readonly totalReward: number; @@ -54,6 +70,12 @@ export type TrainTabularQResult = { readonly metrics: TabularQMetrics; }; +export type TrainDqnResult = { + readonly learner: DqnLearner; + readonly episodes: readonly TrainEpisodeSummary[]; + readonly metrics: DqnMetrics; +}; + export type TrainLinearPolicySearchOptions = { readonly learner?: LinearPolicySearchLearner; readonly learnerOptions?: LinearPolicySearchOptions; @@ -135,6 +157,73 @@ export async function trainTabularQ< }; } +export async function trainDqn< + TState, + const TActions extends DiscreteActionSpec, +>( + env: DefinedEnvironment, + options: TrainDqnOptions> = {}, +): 
Promise>> { + const episodes = options.episodes ?? 100; + + if (!Number.isInteger(episodes) || episodes <= 0) { + throw new Error("[IgnitionRL] episodes must be a positive integer."); + } + + const seed = options.seed ?? 0; + const learner = options.learner ?? createDqnLearner< + DiscreteEnvAction + >(options.learnerOptions); + const spec = env.getSpec({ seed }); + + await learner.init(spec, {}); + + const runnerOptions: EnvironmentRunnerOptions = { + seed, + ...(options.runId !== undefined ? { runId: options.runId } : {}), + ...(options.maxSteps !== undefined ? { maxSteps: options.maxSteps } : {}), + }; + const runner = env.createRunner(runnerOptions); + const summaries: TrainEpisodeSummary[] = []; + + for (let episode = 0; episode < episodes; episode += 1) { + if (episode > 0) { + runner.reset({ seed: `${String(seed)}:${episode}` }); + } + + while (!runner.isDone) { + const action = await learner.act(new Float32Array(runner.getObservation())); + const step = runner.step(action, options.maxSteps === undefined ? {} : { maxSteps: options.maxSteps }); + + await learner.observe( + transitionFromStep(step) as Transition>, + ); + } + + await learner.update(); + + const trace = runner.getTrace(); + const summary: TrainEpisodeSummary = { + episode, + totalReward: trace.summary.totalReward, + length: trace.summary.length, + success: trace.summary.success, + terminated: trace.summary.terminated, + truncated: trace.summary.truncated, + ...(trace.summary.reason !== undefined ? 
{ reason: trace.summary.reason } : {}), + }; + + summaries.push(summary); + await options.onEpisode?.(summary); + } + + return { + learner, + episodes: summaries, + metrics: learner.getMetrics(), + }; +} + export function trainingSummaryToJson(summary: TrainEpisodeSummary): JsonObject { return { episode: summary.episode, diff --git a/packages/learning/test/config-metrics.test.ts b/packages/learning/test/config-metrics.test.ts index b4f7560..2a4bfec 100644 --- a/packages/learning/test/config-metrics.test.ts +++ b/packages/learning/test/config-metrics.test.ts @@ -97,6 +97,7 @@ describe("learner metric catalogs", () => { expect(metricNames).toContain("success"); expect(metricNames).toContain("learner.tdLoss"); expect(metricNames).toContain("learner.epsilon"); + expect(metricNames).toContain("learner.explorationRate"); expect(new Set(metricNames).size).toBe(metricNames.length); }); }); diff --git a/packages/learning/test/dqn.test.ts b/packages/learning/test/dqn.test.ts new file mode 100644 index 0000000..1e18934 --- /dev/null +++ b/packages/learning/test/dqn.test.ts @@ -0,0 +1,188 @@ +import { describe, expect, test } from "bun:test"; +import { mkdtemp, rm } from "node:fs/promises"; +import { tmpdir } from "node:os"; +import { join } from "node:path"; +import { + defineEnvironment, + reward, +} from "@ignitionrl/core"; +import { + assertDqnCheckpoint, + createDqnLearner, + defineDqnConfig, + trainDqn, + transitionFromStep, +} from "../src/index.js"; + +type LineWorldAction = "left" | "right"; + +const LineWorld = defineEnvironment({ + id: "DqnLineWorld-v0", + metadata: { + maxSteps: 5, + }, + + createInitialState: () => ({ + position: 0, + steps: 0, + }), + + observations: ({ state }) => [state.position], + + actions: { + type: "discrete", + values: ["left", "right"], + }, + + step: ({ state, action }) => ({ + position: Math.max(0, Math.min(2, state.position + (action === "right" ? 
1 : -1))), + steps: state.steps + 1, + }), + + reward: ({ state, nextState, action }) => + reward() + .add("progress", nextState.position - state.position) + .add("target_reached", nextState.position >= 2, 4) + .add("left_penalty", action === "left", -0.2) + .add("step_penalty", true, -0.01), + + done: ({ nextState }) => + nextState.position >= 2 + ? { done: true, reason: "target_reached", success: true } + : nextState.steps >= 5, +}); + +describe("DqnLearner", () => { + test("trains a discrete neural learner through the public environment contract", async () => { + const result = await trainDqn(LineWorld, { + episodes: 100, + maxSteps: 5, + seed: 7, + learnerOptions: { + seed: 11, + learningRate: 0.03, + epsilonStart: 0.4, + epsilonEnd: 0.02, + epsilonDecaySteps: 180, + hiddenSize: 16, + batchSize: 8, + minReplaySize: 8, + targetUpdateInterval: 20, + }, + }); + + expect(result.episodes).toHaveLength(100); + expect(result.metrics.episodes).toBe(100); + expect(result.metrics.transitions).toBeGreaterThan(0); + expect(result.metrics.updates).toBeGreaterThan(0); + expect(result.metrics.epsilon).toBeLessThan(0.35); + expect(result.metrics.explorationRate).toBeGreaterThan(0); + expect(result.metrics.tdLoss).toBeGreaterThanOrEqual(0); + expect(result.learner.selectAction([0], { explore: false })).toBe("right"); + expect(result.learner.selectAction([1], { explore: false })).toBe("right"); + }); + + test("updates Q values from transition replay records", async () => { + const learner = createDqnLearner({ + seed: 17, + learningRate: 0.04, + epsilonStart: 0, + epsilonEnd: 0, + hiddenSize: 8, + batchSize: 2, + minReplaySize: 2, + }); + const runner = LineWorld.createRunner({ seed: 1 }); + + await learner.init(LineWorld.getSpec(), {}); + + const first = runner.step("right"); + await learner.observe(transitionFromStep(first)); + const second = runner.step("right"); + await learner.observe(transitionFromStep(second)); + + expect(learner.getMetrics()).toMatchObject({ + 
transitions: 2, + episodes: 1, + replayBufferSize: 2, + }); + expect(learner.getMetrics().updates).toBeGreaterThan(0); + expect(learner.qValuesForObservation([1])).toHaveLength(2); + }); + + test("serializes, restores and validates neural checkpoint envelopes", async () => { + const result = await trainDqn(LineWorld, { + episodes: 80, + maxSteps: 5, + seed: 3, + learnerOptions: { + seed: 19, + learningRate: 0.03, + epsilonStart: 0.35, + epsilonEnd: 0.01, + epsilonDecaySteps: 140, + hiddenSize: 14, + batchSize: 6, + minReplaySize: 6, + }, + }); + const checkpoint = result.learner.toCheckpoint(); + + assertDqnCheckpoint(checkpoint); + + const restored = createDqnLearner(); + restored.loadCheckpoint(checkpoint, LineWorld.getSpec()); + + expect(restored.getMetrics().transitions).toBe(result.metrics.transitions); + expect(restored.toCheckpoint().envId).toBe("DqnLineWorld-v0"); + expect(restored.selectAction([0], { explore: false })).toBe("right"); + }); + + test("saves and loads checkpoint JSON from disk", async () => { + const result = await trainDqn(LineWorld, { + episodes: 80, + maxSteps: 5, + seed: 5, + learnerOptions: { + seed: 23, + learningRate: 0.03, + epsilonStart: 0.35, + epsilonEnd: 0.01, + epsilonDecaySteps: 140, + batchSize: 4, + minReplaySize: 4, + hiddenSize: 12, + }, + }); + const dir = await mkdtemp(join(tmpdir(), "ignitionrl-dqn-")); + const path = join(dir, "line-world.dqn.json"); + + try { + await result.learner.save(path); + + const loaded = createDqnLearner(); + await loaded.load(path); + + expect(loaded.toCheckpoint().algorithm).toBe("dqn"); + expect(loaded.selectAction([0], { explore: false })).toBe("right"); + } finally { + await rm(dir, { recursive: true, force: true }); + } + }); + + test("rejects unsupported action spaces and invalid hyperparameters", async () => { + const learner = createDqnLearner(); + + await expect( + learner.init({ + id: "Continuous-v0", + observation: { type: "vector", shape: [1], dtype: "float32" }, + actions: { type: 
"continuous", shape: [1] }, + }, {}), + ).rejects.toThrow("only supports discrete"); + + expect(() => defineDqnConfig({ epsilonEnd: 0.5, epsilonStart: 0.1 })) + .toThrow("epsilonEnd"); + expect(() => learner.loadCheckpoint({})).toThrow("checkpoint"); + }); +}); diff --git a/packages/learning/test/type-inference.ts b/packages/learning/test/type-inference.ts index 8c03b74..3fc514e 100644 --- a/packages/learning/test/type-inference.ts +++ b/packages/learning/test/type-inference.ts @@ -3,11 +3,14 @@ import { createLinearPolicySearchLearner, createNeuralCheckpointEnvelope, createTabularQLearner, + defineDqnConfig, defineTabularQConfig, defineNeuralLearnerAdapterContract, learnerMetricSpecsForAlgorithm, + trainDqn, trainLinearPolicySearch, trainTabularQ, + type DqnMetrics, type LinearPolicySearchCheckpoint, type NeuralLearnerAdapterContract, type TabularQCheckpoint, @@ -60,6 +63,10 @@ const tabularConfig = defineTabularQConfig({ epsilon: 0.2, observationPrecision: 1, }); +const dqnConfig = defineDqnConfig({ + epsilonStart: 0.3, + epsilonEnd: 0.05, +}); const tabularMetricNames = learnerMetricSpecsForAlgorithm("tabular-q-learning") .map((metric) => metric.name); const checkpoint: TabularQCheckpoint = { @@ -139,12 +146,26 @@ async function smoke(): Promise { acceptsTinyAction(action); const acceptsNeuralCheckpointEnv = (_envId: "Tiny-v0") => undefined; const acceptsTabularConfigSeed = (_seed: number | string) => undefined; + const acceptsDqnMetrics = (_metrics: DqnMetrics) => undefined; acceptsNeuralCheckpointEnv(neuralCheckpoint.envId as "Tiny-v0"); acceptsTabularConfigSeed(tabularConfig.seed); + acceptsTabularConfigSeed(dqnConfig.seed); tabularMetricNames.includes("learner.transitions"); learner.loadCheckpoint(checkpoint, TinyEnv.getSpec()); + const dqnTrained = await trainDqn(TinyEnv, { + episodes: 2, + learnerOptions: { + batchSize: 1, + minReplaySize: 1, + epsilonStart: 0.1, + epsilonEnd: 0, + }, + }); + acceptsTinyAction(dqnTrained.learner.selectAction([0], { 
explore: false })); + acceptsDqnMetrics(dqnTrained.metrics); + await continuousLearner.init(TinyContinuousEnv.getSpec(), {}); await trainLinearPolicySearch(TinyContinuousEnv, { learner: continuousLearner, diff --git a/scripts/ci-smoke.mjs b/scripts/ci-smoke.mjs index 366f6ab..fcad0c3 100644 --- a/scripts/ci-smoke.mjs +++ b/scripts/ci-smoke.mjs @@ -2,9 +2,19 @@ import { spawnSync } from "node:child_process"; import { mkdtempSync, rmSync } from "node:fs"; import { tmpdir } from "node:os"; import { join } from "node:path"; +import { randomPolicy } from "../packages/core/dist/index.js"; +import { GridWorld } from "../packages/examples/dist/grid-world.js"; +import { createDqnLearner } from "../packages/learning/dist/index.js"; +import { + createIgnitionProject, + runCheckpointPolicyExperiment, + runLearnerExperiment, + runPolicyExperiment, +} from "../packages/sdk/dist/index.js"; const results = [ runGridWorldSmoke(), + await runDqnSmoke(), runDroneTargetSmoke(), ]; @@ -706,6 +716,122 @@ function runDroneTargetSmoke() { }); } +async function runDqnSmoke() { + return withProjectDirAsync("ignitionrl-dqn-smoke-", async (projectDir) => { + const store = await createIgnitionProject(projectDir, { + id: "dqn-smoke", + name: "DQN Smoke", + environments: [{ + id: GridWorld.id, + packageName: "@ignitionrl/examples", + exportName: "GridWorld", + }], + }); + const random = await runPolicyExperiment({ + store, + env: GridWorld, + policy: randomPolicy(GridWorld.actions, { + seed: "dqn-smoke:random-policy", + }), + algorithm: "random", + runId: "grid-world-dqn-random", + episodes: 5, + maxSteps: 50, + seed: "dqn-smoke:random-env", + }); + const learner = createDqnLearner({ + seed: "dqn-smoke:learner", + learningRate: 0.01, + epsilonStart: 0.55, + epsilonEnd: 0.05, + epsilonDecaySteps: 1_200, + hiddenSize: 48, + batchSize: 32, + minReplaySize: 32, + targetUpdateInterval: 80, + }); + const trained = await runLearnerExperiment({ + store, + env: GridWorld, + learner, + algorithm: "dqn", + 
runId: "grid-world-dqn", + episodes: 180, + maxSteps: 50, + seed: "dqn-smoke:learner-env", + checkpoint: { + id: "final", + metadata: { + kind: "dqn", + }, + serialize: (trainedLearner) => trainedLearner.toCheckpoint(), + }, + }); + const inference = await runCheckpointPolicyExperiment({ + store, + env: GridWorld, + learner: createDqnLearner(), + sourceRunId: trained.run.id, + checkpointId: "final", + algorithm: "dqn-inference", + runId: "grid-world-dqn-inference", + episodes: 5, + maxSteps: 50, + seed: "dqn-smoke:inference-env", + policyOptions: { + explore: false, + }, + }); + const report = await store.createProjectReport({ + generatedAt: "ci-smoke", + includeMetrics: true, + includeTraceReferences: true, + includeCheckpoints: true, + scoreBy: "summary.totalReward", + }); + const metrics = await store.readMetrics("grid-world-dqn"); + const checkpoints = await store.listCheckpoints("grid-world-dqn"); + const checkpoint = await store.readCheckpoint("grid-world-dqn", "final"); + + assert( + inference.summary.successRate > random.summary.successRate, + "DQN inference did not improve success rate over random GridWorld", + ); + assert( + inference.summary.totalReward > random.summary.totalReward, + "DQN inference did not improve total reward over random GridWorld", + ); + assert( + metrics.some((record) => + record.values["learner.epsilon"] !== undefined + && record.values["learner.explorationRate"] !== undefined + && record.values["learner.tdLoss"] !== undefined + ), + "DQN run did not persist epsilon, exploration and loss metrics", + ); + assert(checkpoints.length === 1, "DQN checkpoint was not listed"); + assert(checkpoint.algorithm === "dqn", "DQN checkpoint payload used the wrong algorithm"); + + return { + name: "GridWorldDQN", + runs: report.runs.map((run) => run.run.id), + doctorChecks: 0, + evalChecks: 2, + environments: 1, + environmentRuns: report.runs.length, + historyRows: report.runs.length, + studioSelectedRun: inference.run.id, + runDetailArtifacts: 0, 
+ artifacts: 0, + episodes: inference.episodes.length, + metricPoints: metrics.length, + rewardTerms: 0, + checkpoints: checkpoints.length, + checkpointSource: trained.run.id, + }; + }); +} + function withProjectDir(prefix, fn) { const projectDir = mkdtempSync(join(tmpdir(), prefix)); @@ -716,6 +842,16 @@ function withProjectDir(prefix, fn) { } } +async function withProjectDirAsync(prefix, fn) { + const projectDir = mkdtempSync(join(tmpdir(), prefix)); + + try { + return await fn(projectDir); + } finally { + rmSync(projectDir, { recursive: true, force: true }); + } +} + function runCli(args) { const result = spawnSync("bun", ["packages/cli/src/index.ts", ...args], { cwd: new URL("..", import.meta.url),