Skip to content

Fix MLite microbatch loss and forward-only output contracts#68

Open
ISEEKYAN wants to merge 4 commits into
mainfrom
mlite-dapo-loss-micro-forward-only
Open

Fix MLite microbatch loss and forward-only output contracts#68
ISEEKYAN wants to merge 4 commits into
mainfrom
mlite-dapo-loss-micro-forward-only

Conversation

@ISEEKYAN

@ISEEKYAN ISEEKYAN commented Jun 28, 2026

Copy link
Copy Markdown
Owner

Summary

  • Mirror VERL’s Megatron loss-reduction hook: pass logical_loss * num_microbatches to the schedule while retaining MLite’s standard schedule-side microbatch averaging.
  • Keep backward loss separate from per-microbatch reporting, preserving every original loss, model output, Metric accumulator, and plain metric through the reduction store.
  • Preserve loss-context propagation across PP/VPP and PP1 forward-only token log probabilities with optional entropy.

Why

VERL PPO/SFT losses are already contributions normalized against the logical global batch. Megatron does not change its runtime API for this case: its VERL postprocess hook compensates for the schedule’s fixed microbatch averaging and reports the unscaled reduction payload separately. MLite now follows the same contract, keeping connector-specific normalization out of the public runtime interface.

Scope

This PR contains only MLite runtime/connector code and focused tests. It does not include launch scripts, training configurations, or changes to the external VERL repository.

Validation

  • Focused pytest: 59 passed (test_loss_microbatch_contract, test_ops_data_trainstep_unit, test_runtime_backend_unit, test_bridge_backend, and test_mlite_engine_forward_only).
  • Slurm validation job 13202007: COMPLETED, exit code 0:0.
  • Local focused pytest: 17 passed.
  • Ruff checks, Python compile checks, and git diff --check passed.
  • Commit history, filenames, and full diff passed mechanical internal-identifier scans; the branch contains one commit on current main.

Mirror VERL’s Megatron loss-reduction hook so schedules retain standard microbatch averaging while logical-batch PPO gradients and per-micro reporting remain correct. Preserve loss context propagation, all-micro metric aggregation, and PP1 forward-only outputs.
@ISEEKYAN ISEEKYAN force-pushed the mlite-dapo-loss-micro-forward-only branch from 8a5d864 to 5d2f0c9 Compare June 28, 2026 06:15
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant