Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
213 changes: 213 additions & 0 deletions scripts/ci-smoke.mjs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ const results = [
runGridWorldSmoke(),
await runDqnSmoke(),
runDroneTargetSmoke(),
runTrainingRegressionSmoke(),
];

console.log(
Expand Down Expand Up @@ -832,6 +833,142 @@ async function runDqnSmoke() {
});
}

function runTrainingRegressionSmoke() {
return withProjectDir("ignitionrl-training-regression-smoke-", (projectDir) => {
const gridProjectDir = join(projectDir, "grid-world.ignitionrl");
const droneProjectDir = join(projectDir, "drone-target.ignitionrl");

runCli(["init", "grid-world", gridProjectDir, "--json"]);
const gridTrain = runCli([
"train",
gridProjectDir,
"GridWorld-v0",
"--learner",
"tabular-q",
"--episodes",
"12",
"--max-steps",
"20",
"--seed",
"ci-training-regression-grid",
"--run-id",
"grid-world-training-regression",
"--checkpoint-id",
"final",
"--json",
]);
const gridMetrics = runCli([
"metrics",
gridProjectDir,
"totalReward",
"--run-id",
"grid-world-training-regression",
"--json",
]);
const gridMetricPoints = metricPointCount(
gridMetrics,
"grid-world-training-regression",
"totalReward",
);

assertTrainingRun(gridTrain, {
environmentId: "GridWorld-v0",
learner: "tabular-q",
runId: "grid-world-training-regression",
});
assertTrainingMetricListed(gridTrain, "learner.transitions");
assertTrainingMetricAtLeast(
gridTrain,
"summary.successRate",
gridTrain.summary.successRate,
0.8,
);
assertTrainingMetricAtLeast(
gridTrain,
"summary.totalReward",
gridTrain.summary.totalReward,
100,
);
assertTrainingMetricAtLeast(gridTrain, "metricPoints.totalReward", gridMetricPoints, 1);

runCli(["init", "drone-target", droneProjectDir, "--json"]);
const droneTrain = runCli([
"train",
droneProjectDir,
"DroneTarget-v0",
"--learner",
"linear-policy-search",
"--episodes",
"2",
"--max-steps",
"12",
"--seed",
"ci-training-regression-drone",
"--run-id",
"drone-target-training-regression",
"--checkpoint-id",
"final",
"--json",
]);
const droneMetrics = runCli([
"metrics",
droneProjectDir,
"learner.bestReward",
"--run-id",
"drone-target-training-regression",
"--json",
]);
const droneMetricPoints = metricPointCount(
droneMetrics,
"drone-target-training-regression",
"learner.bestReward",
);

assertTrainingRun(droneTrain, {
environmentId: "DroneTarget-v0",
learner: "linear-policy-search",
runId: "drone-target-training-regression",
});
assertTrainingMetricListed(droneTrain, "learner.bestReward");
assertTrainingMetricAtLeast(
droneTrain,
"summary.successRate",
droneTrain.summary.successRate,
0.5,
);
assertTrainingMetricAtLeast(
droneTrain,
"summary.bestReward",
droneTrain.summary.bestReward,
20,
);
assertTrainingMetricAtLeast(
droneTrain,
"metricPoints.learner.bestReward",
droneMetricPoints,
1,
);

return {
name: "TrainingRegression",
runs: [gridTrain.runId, droneTrain.runId],
doctorChecks: 0,
evalChecks: 5,
environments: 2,
environmentRuns: 2,
historyRows: 0,
studioSelectedRun: droneTrain.runId,
runDetailArtifacts: 0,
artifacts: 0,
episodes: gridTrain.summary.episodes + droneTrain.summary.episodes,
metricPoints: gridMetricPoints + droneMetricPoints,
rewardTerms: 0,
checkpoints: 2,
checkpointSource: `${gridTrain.runId}, ${droneTrain.runId}`,
};
});
}

function withProjectDir(prefix, fn) {
const projectDir = mkdtempSync(join(tmpdir(), prefix));

Expand Down Expand Up @@ -879,3 +1016,79 @@ function assert(condition, message) {
throw new Error(message);
}
}

function assertTrainingRun(result, expected) {
assert(
result.command === "train",
"training regression command returned an unexpected payload",
);
assert(
result.environmentId === expected.environmentId,
`training regression selected wrong environment: run=${result.runId} expected=${expected.environmentId} actual=${result.environmentId}`,
);
assert(
result.learner === expected.learner,
`training regression selected wrong learner: run=${result.runId} expected=${expected.learner} actual=${result.learner}`,
);
assert(
result.runId === expected.runId,
`training regression selected wrong run: expected=${expected.runId} actual=${result.runId}`,
);
assert(
result.checkpoint?.id === "final",
`training regression did not persist final checkpoint: run=${result.runId}`,
);
}

function assertTrainingMetricListed(result, metric) {
if (!result.metricNames.includes(metric)) {
throw new Error(
[
"Training regression failed:",
`run=${result.runId}`,
`environment=${result.environmentId}`,
`learner=${result.learner}`,
`metric=${metric}`,
"reason=missing_metric_name",
].join(" "),
);
}
}

function assertTrainingMetricAtLeast(result, metric, actual, expected) {
if (typeof actual !== "number" || !Number.isFinite(actual) || actual < expected) {
throw new Error(
[
"Training regression failed:",
`run=${result.runId}`,
`environment=${result.environmentId}`,
`learner=${result.learner}`,
`metric=${metric}`,
`actual=${String(actual)}`,
`expected>=${String(expected)}`,
].join(" "),
);
}
}

function metricPointCount(result, runId, metric) {
assert(
result.command === "metrics",
"training regression metrics command returned an unexpected payload",
);

const run = result.runs.find((entry) => entry.id === runId);

if (run === undefined) {
throw new Error(
[
"Training regression failed:",
`run=${runId}`,
`metric=${metric}`,
"reason=missing_metric_run",
].join(" "),
);
}

return run.points.length;
}
Loading