feat(loss): add --loss-aggregation pg_loss modes#1350
Conversation
There was a problem hiding this comment.
Code Review
This pull request introduces a unified --loss-aggregation configuration to support multiple loss aggregation modes (sample_mean, prompt_mean, token_mean, and constant) across the training pipeline, including support for Context Parallelism (CP). Key changes include updating the CP-aware reducer to accept custom denominators, resolving loss aggregation modes dynamically, precomputing step-level prompt group denominators for DAPO (prompt_mean), and adding validation and comprehensive unit tests. The review feedback suggests adding input length validation checks in both cp_utils.py and train_data_conversion.py to prevent silent mismatches, ensuring sample_denoms are on the correct device, and asserting that sample.group_index is not None during prompt-mean calculations to avoid silent grouping errors.
Important
The consumer version of Gemini Code Assist on GitHub is being sunset. Starting June 18, 2026, new organization installations will be blocked, and all code review activity will officially cease on July 17, 2026.
For more details on the timeline and next steps, please review the Help Documentation.
| if sample_denoms is None: | ||
| sample_denoms = [loss_mask.sum() for loss_mask in loss_masks] |
There was a problem hiding this comment.
To prevent potential device mismatch errors or slow host-to-device copies during loss computation, ensure that the provided sample_denoms are always converted to PyTorch tensors on the same device as their corresponding loss_masks using torch.as_tensor. Additionally, validate that the list lengths match up front to prevent silent mismatches or broadcast failures.
if sample_denoms is not None and len(sample_denoms) != len(loss_masks):
raise ValueError(f"Length mismatch: sample_denoms ({len(sample_denoms)}) and loss_masks ({len(loss_masks)}) must have the same length.")
if sample_denoms is None:
sample_denoms = [loss_mask.sum() for loss_mask in loss_masks]
else:
sample_denoms = [
torch.as_tensor(denom, dtype=torch.float32, device=loss_mask.device)
for denom, loss_mask in zip(sample_denoms, loss_masks)
]References
- When processing lists of sequence-level or token-level tensors in RL or distillation pipelines, validate that the list lengths match up front to prevent silent mismatches or broadcast failures.
There was a problem hiding this comment.
Pushing back here. sample_denoms are already guaranteed on the correct device at reduction time: in production they are built in data.py:get_rollout_data via torch.tensor(d, dtype=torch.float32, device=torch.cuda.current_device()) — the same call used for loss_masks — before any reducer runs; the constant path builds them with device=loss_mask.device; and the sample_denoms=None default uses loss_mask.sum() (already on-device). So no H2D copy happens in the hot loop, and adding a per-sample torch.as_tensor(..., device=...) here would only mask a genuine device bug rather than surface it. On the length check: the reducer's zip(..., strict=True) already raises ValueError on any length mismatch with zero extra code, so an explicit len() guard is redundant.
| if getattr(args, "loss_aggregation", "sample_mean") == "prompt_mean": | ||
| group_mask_totals: dict[int, int] = {} | ||
| for sample, loss_mask in zip(samples, loss_masks, strict=True): | ||
| group_mask_totals[sample.group_index] = group_mask_totals.get(sample.group_index, 0) + sum(loss_mask) | ||
| train_data["sample_denoms"] = [group_mask_totals[sample.group_index] for sample in samples] |
There was a problem hiding this comment.
If sample.group_index is None or missing, all such samples will be silently grouped together under group_mask_totals[None], leading to incorrect denominator calculations and corrupted training loss. Adding an explicit assertion prevents this silent failure. Additionally, validate that the list lengths of samples and loss_masks match up front to prevent silent mismatches.
| if getattr(args, "loss_aggregation", "sample_mean") == "prompt_mean": | |
| group_mask_totals: dict[int, int] = {} | |
| for sample, loss_mask in zip(samples, loss_masks, strict=True): | |
| group_mask_totals[sample.group_index] = group_mask_totals.get(sample.group_index, 0) + sum(loss_mask) | |
| train_data["sample_denoms"] = [group_mask_totals[sample.group_index] for sample in samples] | |
| if len(samples) != len(loss_masks): | |
| raise ValueError(f"Length mismatch: samples ({len(samples)}) and loss_masks ({len(loss_masks)}) must have the same length.") | |
| if getattr(args, "loss_aggregation", "sample_mean") == "prompt_mean": | |
| group_mask_totals: dict[int, int] = {} | |
| for sample, loss_mask in zip(samples, loss_masks): | |
| assert sample.group_index is not None, "sample.group_index must not be None when using prompt_mean" | |
| group_mask_totals[sample.group_index] = group_mask_totals.get(sample.group_index, 0) + sum(loss_mask) | |
| train_data["sample_denoms"] = [group_mask_totals[sample.group_index] for sample in samples] |
References
- When processing lists of sequence-level or token-level tensors in RL or distillation pipelines, validate that the list lengths match up front to prevent silent mismatches or broadcast failures.
There was a problem hiding this comment.
Good catch — accepted. Sample.group_index is typed int | None, so a missing value would silently merge unrelated prompts under group_mask_totals[None] and corrupt the denominator. Added a fail-loud raise ValueError("--loss-aggregation prompt_mean requires every Sample.group_index to be set.") inside the prompt_mean loop (used raise rather than assert so it survives python -O), with a load-bearing test. The length-mismatch part is already covered by the loop's zip(..., strict=True), so I did not add a separate len() check.
b175b59 to
d55c05a
Compare
d55c05a to
8027362
Compare
8027362 to
228307f
Compare
228307f to
24a6f76
Compare
24a6f76 to
ed42c2f
Compare
ed42c2f to
53ddaa3
Compare
Add first-class pg_loss aggregation modes: sample_mean, prompt_mean, token_mean, and constant. The default remains sample_mean; --calculate-per-token-loss is reconciled as the legacy spelling of token_mean. This syncs the current THUDM/slime#2090 behavior for prompt_mean: rollout conversion emits prompt_mask_sums, the reducer scales by n_samples_per_prompt, and the final scalar is the direct mean over prompt groups. Startup validation requires global_batch_size to be a multiple of n_samples_per_prompt so prompt groups stay whole within a train step. constant requires --loss-aggregation-divisor and remains incompatible with the per-token path. The custom pg_loss reducer hook keeps precedence, and non-pg metrics keep the default sample-mean reducer. Validation: uv run --with pytest --with torch --with numpy --with httpx --with pyyaml --with ray --with huggingface_hub --with transformers --with pydantic pytest --confcutdir=tests/fast/backends/training_utils tests/fast/backends/training_utils/test_loss_aggregation.py -q Validation: uv run --with ruff ruff check miles/backends/training_utils/cp_utils.py miles/backends/training_utils/data.py miles/backends/training_utils/loss.py miles/backends/training_utils/loss_hub/losses.py miles/ray/rollout/train_data_conversion.py miles/utils/arguments.py miles/backends/megatron_utils/model.py miles/backends/experimental/fsdp_utils/actor.py tests/fast/backends/training_utils/test_loss_aggregation.py tests/fast/backends/training_utils/loss/test_loss_snapshot.py Validation: uv run --with black black --check miles/backends/training_utils/cp_utils.py miles/backends/training_utils/data.py miles/backends/training_utils/loss.py miles/backends/training_utils/loss_hub/losses.py miles/ray/rollout/train_data_conversion.py miles/utils/arguments.py miles/backends/megatron_utils/model.py miles/backends/experimental/fsdp_utils/actor.py tests/fast/backends/training_utils/test_loss_aggregation.py tests/fast/backends/training_utils/loss/test_loss_snapshot.py
53ddaa3 to
d0be4cc
Compare
|
Hi @yueming-yuan, resurfacing this for visibility since the source slime PR has moved forward. I synced the Miles port with the current THUDM/slime#2090 behavior: Would appreciate your review when you have a chance. |
|
Closing in favor of a fresh PR from |
Port of THUDM/slime#2090 for miles.
Summary:
--loss-aggregation {sample_mean,prompt_mean,token_mean,constant}for pg_losssample_meanas the default and reconcile legacy--calculate-per-token-losswithtoken_meanprompt_meanbehavior: rollout conversion emitsprompt_mask_sums, the reducer scales byn_samples_per_prompt, and the final scalar is the direct mean over prompt groupsprompt_meanconfigs withglobal_batch_size % n_samples_per_prompt == 0, so each prompt group stays whole within a train stepValidation:
uv run --with pytest --with torch --with numpy --with httpx --with pyyaml --with ray --with huggingface_hub --with transformers --with pydantic pytest --confcutdir=tests/fast/backends/training_utils tests/fast/backends/training_utils/test_loss_aggregation.py -q-> 27 passeduv run --with ruff ruff check miles/backends/training_utils/cp_utils.py miles/backends/training_utils/data.py miles/backends/training_utils/loss.py miles/backends/training_utils/loss_hub/losses.py miles/ray/rollout/train_data_conversion.py miles/utils/arguments.py miles/backends/megatron_utils/model.py miles/backends/experimental/fsdp_utils/actor.py tests/fast/backends/training_utils/test_loss_aggregation.py tests/fast/backends/training_utils/loss/test_loss_snapshot.py-> passeduv run --with black black --check miles/backends/training_utils/cp_utils.py miles/backends/training_utils/data.py miles/backends/training_utils/loss.py miles/backends/training_utils/loss_hub/losses.py miles/ray/rollout/train_data_conversion.py miles/utils/arguments.py miles/backends/megatron_utils/model.py miles/backends/experimental/fsdp_utils/actor.py tests/fast/backends/training_utils/test_loss_aggregation.py tests/fast/backends/training_utils/loss/test_loss_snapshot.py-> passedgit diff --check upstream/main-> passed