Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
109 changes: 109 additions & 0 deletions docs/spikes/ppo-drone-target-v0.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
# PPO Path for DroneTarget-v0

Status: accepted spike

Issue: #52

## Decision

Implement PPO as a TypeScript-first learner behind the existing neural adapter contract, but keep the checkpoint and backend descriptor native-ready from day one.

The next implementation should not start in Rust. DroneTarget-v0 still needs fast iteration on observation labels, action bounds, rollout metrics, Studio panels and checkpoint replay. TypeScript is the shortest path to validate the environment-facing contract. Native Burn/Candle becomes the right move after the TS learner proves the rollout shape, metrics and replay semantics.

## Rollout Batch

DroneTarget-v0 is bounded continuous control, so PPO should use the existing continuous action spec:

- observations: `Float32Array` vectors using the environment observation shape;
- actions: bounded continuous vectors matching `actions.shape`;
- policy distribution: diagonal Gaussian with means from the actor head and log standard deviations in config or trainable parameters;
- sampled actions: clamp or tanh-rescale to the environment low/high bounds;
- rollout fields: `observation`, `action`, `logProbability`, `reward`, `done`, `value`, `nextValue`, `advantage`, `return`, `episode`, `step`.

Batching should start single-process and sequential, then move to vectorized environments later:

- collect `rolloutSteps` transitions, default `2048`;
- reset on terminal episodes and continue filling the same rollout;
- compute generalized advantage estimation after collection;
- normalize advantages per rollout before update;
- emit per-rollout metrics before optimizer updates.

## Advantage Estimation

Use GAE(lambda):

```txt
delta_t = reward_t + gamma * value_{t+1} * (1 - done_t) - value_t
advantage_t = delta_t + gamma * lambda * (1 - done_t) * advantage_{t+1}
return_t = advantage_t + value_t
```

Initial defaults:

- `gamma: 0.99`;
- `gaeLambda: 0.95`;
- `rolloutSteps: 2048`;
- `epochs: 10`;
- `minibatches: 32`;
- `clipRatio: 0.2`;
- `valueLossCoefficient: 0.5`;
- `entropyCoefficient: 0.01`;
- `maxGradientNorm: 0.5`.

## Update Shape

The TypeScript PPO learner should use a small shared MLP trunk with separate actor and value heads:

- actor mean head: one output per continuous action dimension;
- actor log standard deviation: vector matching action size;
- value head: scalar;
- loss: clipped policy objective plus value loss minus entropy bonus;
- metrics: `policyLoss`, `valueLoss`, `entropy`, `approxKl`, `clipFraction`, `explainedVariance`, `rolloutRewardMean`, `rolloutLengthMean`, `successRate`.

The native backend should receive the same batch fields and return the same metric names. That keeps Studio and CI independent from whether PPO runs in TS, Burn, Candle or a remote trainer.

## Checkpoint Shape

Use `createNeuralCheckpointEnvelope()` with:

- `algorithm: "ppo"`;
- `actionSpace` copied from `defineNeuralLearnerAdapterContract(DroneTarget.getSpec(), { algorithm: "ppo" })`;
- `updateCadence.type: "rollout"`;
- payload version;
- config;
- actor/value network weights;
- action normalization bounds;
- observation normalization stats when added;
- optimizer state when needed;
- metrics snapshot.

Inference replay should load the envelope, rebuild the deterministic actor policy, and run with mean actions instead of sampling.

## Minimal Experiment

The script `scripts/spike-ppo-drone-target.mjs` collects deterministic DroneTarget-v0 rollout metrics for random and heuristic policies, then computes GAE-shaped rollout statistics with zero value estimates. It does not train PPO; it validates the metric payload and rollout accounting needed before implementation.

Run:

```sh
bun run build:packages
bun scripts/spike-ppo-drone-target.mjs
```

Useful output fields:

- `baselines.random.successRate`;
- `baselines.heuristic.successRate`;
- `rollout.steps`;
- `rollout.rewardMean`;
- `rollout.advantageStdDev`;
- `rollout.actionMean`;
- `rollout.actionStdDev`.

## Blockers Before Full PPO

- Choose a TS autograd/tensor dependency or implement a tiny local MLP optimizer first.
- Decide whether first PPO clamps actions or uses tanh-squashed Gaussian log-prob correction.
- Add observation normalization before comparing PPO quality across seeds.
- Add checkpoint replay for continuous neural policies in CLI/Studio.
- Define CI thresholds that are stable across machines without requiring long training runs.
165 changes: 165 additions & 0 deletions scripts/spike-ppo-drone-target.mjs
Original file line number Diff line number Diff line change
@@ -0,0 +1,165 @@
import { randomPolicy } from "../packages/core/dist/index.js";
import {
DroneTarget,
droneTargetHeuristicAction,
} from "../packages/examples/dist/drone-target.js";

const EPISODES = 6;
const MAX_STEPS = 120;
const GAMMA = 0.99;
const GAE_LAMBDA = 0.95;

const result = await runSpike();

console.log(JSON.stringify(result, null, 2));

async function runSpike() {
const random = await collectPolicy(
randomPolicy(DroneTarget.actions, {
seed: "ppo-spike:random-policy",
name: "random",
}),
"random",
);
const heuristic = await collectPolicy({
name: "heuristic",
act: ({ state }) => droneTargetHeuristicAction(state),
}, "heuristic");
const rollout = rolloutMetrics(heuristic.traces);

return {
environment: DroneTarget.id,
recommendation: "typescript-first-hybrid-ready",
config: {
episodes: EPISODES,
maxSteps: MAX_STEPS,
gamma: GAMMA,
gaeLambda: GAE_LAMBDA,
},
baselines: {
random: random.summary,
heuristic: heuristic.summary,
},
rollout,
};
}

async function collectPolicy(policy, label) {
const runner = DroneTarget.createRunner({
seed: `ppo-spike:${label}:0`,
maxSteps: MAX_STEPS,
collectTrace: true,
});
const traces = [];

for (let episode = 0; episode < EPISODES; episode += 1) {
if (episode > 0) {
runner.reset({
seed: `ppo-spike:${label}:${episode}`,
episodeId: `${label}:episode:${episode}`,
});
}

traces.push(await runner.runEpisode(policy, {
reset: false,
maxSteps: MAX_STEPS,
}));
}

return {
traces,
summary: summarizeTraces(traces),
};
}

function summarizeTraces(traces) {
const rewards = traces.map((trace) => trace.summary.totalReward);
const lengths = traces.map((trace) => trace.summary.length);

return {
episodes: traces.length,
steps: sum(lengths),
rewardMean: mean(rewards),
rewardStdDev: standardDeviation(rewards),
lengthMean: mean(lengths),
successRate: traces.filter((trace) => trace.summary.success).length / traces.length,
};
}

function rolloutMetrics(traces) {
const steps = traces.flatMap((trace, episode) =>
trace.steps.map((step) => ({
episode,
reward: step.reward,
done: step.done,
action: step.action,
})));
const advantages = generalizedAdvantages(steps);
const rewards = steps.map((step) => step.reward);
const lengths = traces.map((trace) => trace.summary.length);
const actions = steps.map((step) => step.action);

return {
episodes: traces.length,
steps: steps.length,
rewardMean: mean(rewards),
rewardStdDev: standardDeviation(rewards),
lengthMean: mean(lengths),
lengthStdDev: standardDeviation(lengths),
successRate: traces.filter((trace) => trace.summary.success).length / traces.length,
advantageMean: mean(advantages),
advantageStdDev: standardDeviation(advantages),
returnMean: mean(advantages),
actionMean: vectorMean(actions),
actionStdDev: vectorStdDev(actions),
};
}

function generalizedAdvantages(steps) {
const advantages = Array.from({ length: steps.length }, () => 0);
let nextAdvantage = 0;

for (let index = steps.length - 1; index >= 0; index -= 1) {
const step = steps[index];
const mask = step.done ? 0 : 1;
const delta = step.reward;
const advantage = delta + GAMMA * GAE_LAMBDA * mask * nextAdvantage;

advantages[index] = advantage;
nextAdvantage = advantage;
}

return advantages;
}

function vectorMean(vectors) {
if (vectors.length === 0) return [];

return Array.from({ length: vectors[0].length }, (_, index) =>
mean(vectors.map((vector) => vector[index] ?? 0)));
}

function vectorStdDev(vectors) {
if (vectors.length === 0) return [];

const means = vectorMean(vectors);

return means.map((meanValue, index) =>
Math.sqrt(mean(vectors.map((vector) => ((vector[index] ?? 0) - meanValue) ** 2))));
}

function mean(values) {
return values.length === 0 ? 0 : sum(values) / values.length;
}

function standardDeviation(values) {
if (values.length === 0) return 0;

const meanValue = mean(values);

return Math.sqrt(mean(values.map((value) => (value - meanValue) ** 2)));
}

function sum(values) {
return values.reduce((total, value) => total + value, 0);
}
Loading