diff --git a/README.md b/README.md index 459068e..99819f1 100644 --- a/README.md +++ b/README.md @@ -102,6 +102,7 @@ bun run demo:drone The `init:*` commands create template projects without running experiments. The `demo:*` commands create timestamped projects under `.ignitionrl/demos/` and write export artifacts under each project `exports/` directory. Existing projects can be inspected without rerunning experiments through `bun run --cwd packages/cli start inspect --json`. +Registered environments can be trained without the demo shortcut through `bun run --cwd packages/cli start train --learner tabular-q --episodes 30 --checkpoint-id final --json`, which persists traces, metrics and a learner checkpoint. Persisted runs can be compared without rerunning experiments through `bun run --cwd packages/cli start compare --score-by summary.totalReward --json`. Metric series can be inspected through `bun run --cwd packages/cli start metrics totalReward --json`. The offline Studio bootstrap also embeds metric chart summaries for persisted metrics, so the future metrics panel can render chart selectors and latest values immediately. diff --git a/packages/cli/README.md b/packages/cli/README.md index f3978d0..ffceb3b 100644 --- a/packages/cli/README.md +++ b/packages/cli/README.md @@ -24,6 +24,12 @@ bun run --cwd packages/cli start demo drone-target ./drone-target-demo.ignitionr --learner-episodes 12 \ --inference-episodes 3 +bun run --cwd packages/cli start train ./grid-world-project.ignitionrl GridWorld-v0 \ + --learner tabular-q \ + --episodes 30 \ + --checkpoint-id final \ + --json + bun run --cwd packages/cli start inspect ./target-2d-demo.ignitionrl \ --score-by summary.totalReward \ --json @@ -124,6 +130,8 @@ If the project directory is omitted, the CLI creates a timestamped project under Each demo command runs the environment, writes traces, metrics and learner checkpoints when available, then exports JSON artifacts under the project `exports/` directory. The DroneTarget demo also replays the final continuous-control checkpoint as a separate inference run so the learned policy is directly comparable and replayable. +`train` opens an existing local project, verifies that the selected environment is registered, runs one learner, and writes traces, metrics and a checkpoint. Supported first-class learner selections are `tabular-q` and `dqn` for `GridWorld-v0`/`Target2D-v0`, and `linear-policy-search` for `DroneTarget-v0`. Use `--json` for automation and `--checkpoint-id` to choose the checkpoint name. + - project report; - Studio project view; - Studio environment view when `environment --export` is used; diff --git a/packages/cli/src/cli.ts b/packages/cli/src/cli.ts index da74f3f..b2e8a09 100644 --- a/packages/cli/src/cli.ts +++ b/packages/cli/src/cli.ts @@ -28,6 +28,7 @@ import { createIgnitionProject, openIgnitionProject, runCheckpointPolicyExperiment, + runLearnerExperiment, selectReplayFrame, type ActionInspectorDistributionEntry, type ComparedRun, @@ -44,6 +45,7 @@ import { type ProjectEnvironmentReference, type RunComparisonDirection, type RunEvaluationCheck, + type RunLearnerExperimentResult, type RunManifest, type RunPolicyExperimentResult, type RunSummary, @@ -128,6 +130,20 @@ export type InferCliResult = { readonly exports: readonly ExportedArtifactSummary[]; }; +export type TrainCliResult = { + readonly command: "train"; + readonly projectDir: string; + readonly projectId: string; + readonly projectName: string; + readonly environmentId: string; + readonly learner: TrainLearnerName; + readonly runId: string; + readonly algorithm: string; + readonly summary: RunSummary; + readonly checkpoint: CheckpointCliSummary; + readonly metricNames: readonly string[]; +}; + export type CheckpointCliSummary = { readonly id: string; readonly path: string; @@ -569,6 +585,7 @@ export type CliResult = | InitCliResult | InspectCliResult | ReplayCliResult + | TrainCliResult | RewardsCliResult | InferCliResult | CheckpointDetailCliResult @@ -614,6 +631,20 @@ type ParsedInferOptions = { readonly json: boolean; }; +type TrainLearnerName = "tabular-q" | "dqn" | "linear-policy-search"; + +type ParsedTrainOptions = { + readonly projectDir: string; + readonly environmentId: string; + readonly learner: TrainLearnerName; + readonly episodes?: number; + readonly maxSteps?: number; + readonly seed?: string | number; + readonly runId?: string; + readonly checkpointId: string; + readonly json: boolean; +}; + type ParsedCheckpointsOptions = { readonly projectDir: string; readonly runId?: string; @@ -790,35 +821,37 @@ export async function runCli( ? await handleReplay(rest, io) : command === "infer" ? await handleInfer(rest, io) - : command === "rewards" - ? await handleRewards(rest, io) - : command === "checkpoint" - ? await handleCheckpoint(rest, io) - : command === "checkpoints" - ? await handleCheckpoints(rest, io) - : command === "episodes" - ? await handleEpisodes(rest, io) - : command === "artifacts" - ? await handleArtifacts(rest, io) - : command === "environments" - ? await handleEnvironments(rest, io) - : command === "environment" - ? await handleEnvironment(rest, io) - : command === "studio" - ? await handleStudio(rest, io) - : command === "run" - ? await handleRun(rest, io) - : command === "eval" - ? await handleEval(rest, io) - : command === "compare" - ? await handleCompare(rest, io) - : command === "metrics" - ? await handleMetrics(rest, io) - : command === "history" - ? await handleHistory(rest, io) - : command === "doctor" - ? await handleDoctor(rest, io) - : undefined; + : command === "train" + ? await handleTrain(rest, io) + : command === "rewards" + ? await handleRewards(rest, io) + : command === "checkpoint" + ? await handleCheckpoint(rest, io) + : command === "checkpoints" + ? await handleCheckpoints(rest, io) + : command === "episodes" + ? await handleEpisodes(rest, io) + : command === "artifacts" + ? await handleArtifacts(rest, io) + : command === "environments" + ? await handleEnvironments(rest, io) + : command === "environment" + ? await handleEnvironment(rest, io) + : command === "studio" + ? await handleStudio(rest, io) + : command === "run" + ? await handleRun(rest, io) + : command === "eval" + ? await handleEval(rest, io) + : command === "compare" + ? await handleCompare(rest, io) + : command === "metrics" + ? await handleMetrics(rest, io) + : command === "history" + ? await handleHistory(rest, io) + : command === "doctor" + ? await handleDoctor(rest, io) + : undefined; if (result === undefined) { throw new CliError(`Unknown command: ${command}`); @@ -873,6 +906,18 @@ async function handleInfer( return runInferCommand(args); } +async function handleTrain( + args: readonly string[], + io: CliIo, +): Promise<{ readonly json: boolean; readonly output: TrainCliResult } | "help"> { + if (args.length === 0 || args[0] === "--help" || args[0] === "-h") { + io.stdout(trainHelpText()); + return "help"; + } + + return runTrainCommand(args); +} + async function handleRewards( args: readonly string[], io: CliIo, @@ -1241,6 +1286,43 @@ async function runInferCommand( }; } +async function runTrainCommand( + args: readonly string[], +): Promise<{ readonly json: boolean; readonly output: TrainCliResult }> { + const options = parseTrainOptions(args); + const projectDir = resolve(options.projectDir); + const project = await openIgnitionProject(projectDir); + const manifest = project.getProjectManifest(); + const environment = manifest.environments.find((entry) => + entry.id === options.environmentId); + + if (environment === undefined) { + throw new CliError( + `Environment ${options.environmentId} is not registered in this project.`, + ); + } + + const result = await runKnownEnvironmentTraining(project, options); + const analysis = await project.readRunAnalysis(result.run.id); + + return { + json: options.json, + output: { + command: "train", + projectDir, + projectId: manifest.id, + projectName: manifest.name, + environmentId: result.run.envId, + learner: options.learner, + runId: result.run.id, + algorithm: result.run.algorithm, + summary: result.summary, + checkpoint: checkpointSummary(result.checkpoint), + metricNames: analysis.metricNames, + }, + }; +} + async function runRewardsCommand( args: readonly string[], ): Promise<{ readonly json: boolean; readonly output: RewardsCliResult }> { @@ -1956,6 +2038,84 @@ function parseInferOptions(args: readonly string[]): ParsedInferOptions { }; } +function parseTrainOptions(args: readonly string[]): ParsedTrainOptions { + const positional: string[] = []; + const values: { + learner?: TrainLearnerName; + episodes?: number; + maxSteps?: number; + seed?: string | number; + runId?: string; + checkpointId: string; + json: boolean; + } = { + checkpointId: "final", + json: false, + }; + + for (let index = 0; index < args.length; index += 1) { + const arg = args[index]; + + if (arg === undefined) { + continue; + } + + if (arg === "--json") { + values.json = true; + continue; + } + + if (arg.startsWith("--")) { + const option = parseLongOption(arg, args[index + 1]); + index += option.consumedValue ? 1 : 0; + + if (option.name === "learner") values.learner = parseTrainLearner(option.value); + else if (option.name === "episodes") values.episodes = parsePositiveInteger(option.value, option.name); + else if (option.name === "max-steps") values.maxSteps = parsePositiveInteger(option.value, option.name); + else if (option.name === "seed") values.seed = option.value; + else if (option.name === "run-id") values.runId = option.value; + else if (option.name === "checkpoint-id") values.checkpointId = option.value; + else throw new CliError(`Unknown option: --${option.name}`); + + continue; + } + + positional.push(arg); + } + + if (positional[0] === undefined) { + throw new CliError("Missing projectDir for train command."); + } + + if (positional[1] === undefined) { + throw new CliError("Missing environmentId for train command."); + } + + if (positional.length > 2) { + throw new CliError(`Unexpected positional argument: ${positional[2]}`); + } + + if (values.learner === undefined) { + throw new CliError("Missing --learner for train command."); + } + + if (values.checkpointId.trim().length === 0) { + throw new CliError("Option --checkpoint-id must be a non-empty string."); + } + + return { + projectDir: positional[0], + environmentId: positional[1], + learner: values.learner, + ...(values.episodes !== undefined ? { episodes: values.episodes } : {}), + ...(values.maxSteps !== undefined ? { maxSteps: values.maxSteps } : {}), + ...(values.seed !== undefined ? { seed: values.seed } : {}), + ...(values.runId !== undefined ? { runId: values.runId } : {}), + checkpointId: values.checkpointId, + json: values.json, + }; +} + function parseCheckpointsOptions(args: readonly string[]): ParsedCheckpointsOptions { const positional: string[] = []; const values: { @@ -3014,6 +3174,16 @@ function parseDirection(value: string): RunComparisonDirection { throw new CliError("Option --direction must be asc or desc."); } +function parseTrainLearner(value: string): TrainLearnerName { + if (value === "tabular-q" || value === "dqn" || value === "linear-policy-search") { + return value; + } + + throw new CliError( + "Option --learner must be one of: tabular-q, dqn, linear-policy-search.", + ); +} + function parseArtifactKind(value: string): StoredExportedArtifactKind { if (ARTIFACT_KIND_NAMES.includes(value as StoredExportedArtifactKind)) { return value as StoredExportedArtifactKind; @@ -3511,6 +3681,168 @@ async function runKnownCheckpointInference( ); } +async function runKnownEnvironmentTraining( + project: LocalProjectStore, + options: ParsedTrainOptions, +): Promise { + if (options.environmentId === GridWorld.id) { + if (options.learner === "tabular-q") { + return requireTrainingCheckpoint(await runLearnerExperiment({ + ...trainingCommon(project, options, "tabular-q"), + env: GridWorld, + learner: createTabularQLearner({ + seed: options.seed ?? 0, + epsilon: 0.2, + learningRate: 0.5, + discount: 0.95, + observationPrecision: 0, + }), + algorithm: "tabular-q-learning", + })); + } + + if (options.learner === "dqn") { + return requireTrainingCheckpoint(await runLearnerExperiment({ + ...trainingCommon(project, options, "dqn"), + env: GridWorld, + learner: createDqnLearner({ + seed: options.seed ?? 0, + }), + algorithm: "dqn", + })); + } + + throw unsupportedTrainLearner(options.environmentId, options.learner, [ + "tabular-q", + "dqn", + ]); + } + + if (options.environmentId === Target2D.id) { + if (options.learner === "tabular-q") { + return requireTrainingCheckpoint(await runLearnerExperiment({ + ...trainingCommon(project, options, "tabular-q"), + env: Target2D, + learner: createTabularQLearner({ + seed: options.seed ?? 0, + epsilon: 0.25, + learningRate: 0.4, + discount: 0.95, + observationPrecision: 1, + }), + algorithm: "tabular-q-learning", + })); + } + + if (options.learner === "dqn") { + return requireTrainingCheckpoint(await runLearnerExperiment({ + ...trainingCommon(project, options, "dqn"), + env: Target2D, + learner: createDqnLearner({ + seed: options.seed ?? 0, + }), + algorithm: "dqn", + })); + } + + throw unsupportedTrainLearner(options.environmentId, options.learner, [ + "tabular-q", + "dqn", + ]); + } + + if (options.environmentId === DroneTarget.id) { + if (options.learner === "linear-policy-search") { + return requireTrainingCheckpoint(await runLearnerExperiment({ + ...trainingCommon(project, options, "linear-policy-search"), + env: DroneTarget, + learner: createLinearPolicySearchLearner({ + seed: options.seed ?? 0, + populationSize: 4, + eliteCount: 2, + sigma: 0.08, + actionNoise: 0.01, + initialWeightScale: 0, + initialWeightMap: { + throttle: { + "target.dy": 0.35, + "agent.velocity.y": -0.45, + }, + yaw: { + "target.right": 0.18, + }, + pitch: { + "target.forward": 0.25, + }, + roll: { + "target.right": 0.25, + }, + }, + }), + algorithm: "linear-policy-search", + })); + } + + throw unsupportedTrainLearner(options.environmentId, options.learner, [ + "linear-policy-search", + ]); + } + + throw new CliError( + `Unsupported registered environment for train command: ${options.environmentId}.`, + ); +} + +function trainingCommon( + project: LocalProjectStore, + options: ParsedTrainOptions, + checkpointKind: TrainLearnerName, +) { + return { + store: project, + ...(options.episodes !== undefined ? { episodes: options.episodes } : {}), + ...(options.maxSteps !== undefined ? { maxSteps: options.maxSteps } : {}), + ...(options.seed !== undefined ? { seed: options.seed } : {}), + ...(options.runId !== undefined ? { runId: options.runId } : {}), + metadata: { + source: "cli-train", + learner: options.learner, + }, + checkpoint: { + id: options.checkpointId, + metadata: { + kind: checkpointKind, + }, + serialize: (learner: { + readonly toCheckpoint: () => unknown; + }) => learner.toCheckpoint(), + }, + }; +} + +function requireTrainingCheckpoint( + result: RunLearnerExperimentResult, +): RunLearnerExperimentResult & { readonly checkpoint: StoredCheckpoint } { + if (result.checkpoint === undefined) { + throw new CliError("Training completed without writing a checkpoint."); + } + + return { + ...result, + checkpoint: result.checkpoint, + }; +} + +function unsupportedTrainLearner( + environmentId: string, + learner: TrainLearnerName, + supportedLearners: readonly TrainLearnerName[], +): CliError { + return new CliError( + `Learner ${learner} is not supported for ${environmentId}. Supported learners: ${supportedLearners.join(", ")}.`, + ); +} + function replayResultFromView( projectDir: string, runId: string, @@ -3741,6 +4073,10 @@ function formatDemoResult(result: CliResult): string { return formatInferResult(result); } + if (result.command === "train") { + return formatTrainResult(result); + } + if (result.command === "replay") { return formatReplayResult(result); } @@ -4122,6 +4458,27 @@ function formatInferResult(result: InferCliResult): string { return lines.join("\n"); } +function formatTrainResult(result: TrainCliResult): string { + const lines = [ + `IgnitionRL training complete: ${result.runId}`, + `Project: ${result.projectDir}`, + `Environment: ${result.environmentId}`, + `Learner: ${result.learner}`, + `Algorithm: ${result.algorithm}`, + `Episodes: ${result.summary.episodes}`, + `Total reward: ${formatNumber(result.summary.totalReward)}`, + `Success rate: ${formatNumber(result.summary.successRate)}`, + `Checkpoint: ${result.checkpoint.id} (${result.checkpoint.path})`, + "Metrics:", + ...result.metricNames.slice(0, 12).map((name) => `- ${name}`), + ...(result.metricNames.length > 12 + ? [`- ... ${result.metricNames.length - 12} more`] + : []), + ]; + + return lines.join("\n"); +} + function formatReplayResult(result: ReplayCliResult): string { const lines = [ `IgnitionRL replay: ${result.runId}`, @@ -4200,6 +4557,7 @@ function helpText(): string { " ignitionrl init [projectDir] [--json]", " ignitionrl demo [projectDir] [options]", " ignitionrl inspect [options]", + " ignitionrl train --learner [options]", " ignitionrl replay [options]", " ignitionrl infer [options]", " ignitionrl rewards [options]", @@ -4220,6 +4578,7 @@ function helpText(): string { "Run `ignitionrl init --help` for template options.", "Run `ignitionrl demo --help` for demo options.", "Run `ignitionrl inspect --help` for inspect options.", + "Run `ignitionrl train --help` for headless training options.", "Run `ignitionrl replay --help` for replay options.", "Run `ignitionrl infer --help` for checkpoint inference options.", "Run `ignitionrl rewards --help` for reward debugger options.", @@ -4298,6 +4657,19 @@ function inferHelpText(): string { ].join("\n"); } +function trainHelpText(): string { + return [ + "IgnitionRL train command", + "", + "Usage:", + " ignitionrl train --learner tabular-q [--episodes n] [--max-steps n] [--seed value] [--run-id id] [--checkpoint-id id] [--json]", + " ignitionrl train --learner dqn [--episodes n] [--max-steps n] [--seed value] [--run-id id] [--checkpoint-id id] [--json]", + " ignitionrl train --learner linear-policy-search [--episodes n] [--max-steps n] [--seed value] [--run-id id] [--checkpoint-id id] [--json]", + "", + "Opens an existing IgnitionRL project, verifies the environment is registered, runs the selected learner, and persists traces, metrics and a checkpoint.", + ].join("\n"); +} + function rewardsHelpText(): string { return [ "IgnitionRL rewards command", diff --git a/packages/cli/test/cli.test.ts b/packages/cli/test/cli.test.ts index 344202d..697c585 100644 --- a/packages/cli/test/cli.test.ts +++ b/packages/cli/test/cli.test.ts @@ -183,6 +183,90 @@ describe("IgnitionRL CLI", () => { }); }); + test("trains a registered environment headlessly and writes a checkpoint", async () => { + await withProjectDir(async (dir) => { + const init = createIo(); + const train = createIo(); + + await runCli([ + "init", + "grid-world", + dir, + "--json", + ], init.io); + + const code = await runCli([ + "train", + dir, + "GridWorld-v0", + "--learner", + "tabular-q", + "--episodes", + "4", + "--max-steps", + "8", + "--seed", + "cli-train", + "--run-id", + "grid-world-train", + "--checkpoint-id", + "trained", + "--json", + ], train.io); + const parsed = JSON.parse(train.stdout[0] ?? "{}") as { + readonly command: string; + readonly projectId: string; + readonly environmentId: string; + readonly learner: string; + readonly runId: string; + readonly algorithm: string; + readonly summary: { readonly episodes: number }; + readonly checkpoint: { + readonly id: string; + readonly algorithm: string; + readonly path: string; + }; + readonly metricNames: readonly string[]; + }; + const run = JSON.parse( + await readFile(join(dir, "runs", "grid-world-train", "run.json"), "utf8"), + ) as { + readonly status: string; + readonly config?: { readonly learnerConfig?: unknown }; + }; + const checkpoint = JSON.parse( + await readFile(join(dir, parsed.checkpoint.path), "utf8"), + ) as { + readonly algorithm: string; + }; + + expect(code).toBe(0); + expect(train.stderr).toEqual([]); + expect(parsed).toMatchObject({ + command: "train", + projectId: "grid-world-project", + environmentId: "GridWorld-v0", + learner: "tabular-q", + runId: "grid-world-train", + algorithm: "tabular-q-learning", + summary: { + episodes: 4, + }, + checkpoint: { + id: "trained", + algorithm: "tabular-q-learning", + }, + }); + expect(parsed.metricNames).toContain("learner.transitions"); + expect(run.status).toBe("completed"); + expect(run.config?.learnerConfig).toMatchObject({ + seed: "cli-train", + observationPrecision: 0, + }); + expect(checkpoint.algorithm).toBe("tabular-q-learning"); + }); + }); + test("inspects an existing project and refreshes exports", async () => { await withProjectDir(async (dir) => { const demo = createIo(); @@ -1746,6 +1830,36 @@ describe("IgnitionRL CLI", () => { expect(stderr[0]).toContain("--max-steps"); }); + test("reports invalid train arguments", async () => { + await withProjectDir(async (dir) => { + const init = createIo(); + const missingLearner = createIo(); + const unsupported = createIo(); + + await runCli(["init", "grid-world", dir, "--json"], init.io); + + const missingLearnerCode = await runCli([ + "train", + dir, + "GridWorld-v0", + ], missingLearner.io); + const unsupportedCode = await runCli([ + "train", + dir, + "GridWorld-v0", + "--learner", + "linear-policy-search", + ], unsupported.io); + + expect(missingLearnerCode).toBe(1); + expect(missingLearner.stdout).toEqual([]); + expect(missingLearner.stderr[0]).toContain("--learner"); + expect(unsupportedCode).toBe(1); + expect(unsupported.stdout).toEqual([]); + expect(unsupported.stderr[0]).toContain("not supported"); + }); + }); + test("reports invalid inspect direction", async () => { const { stdout, stderr, io } = createIo(); const code = await runCli(["inspect", "./missing.ignitionrl", "--direction", "sideways"], io);