Skip to content

feat(loss): add --loss-aggregation pg_loss modes#1350

Closed
EazyReal wants to merge 1 commit into
radixark:mainfrom
EazyReal:upstream-pr/loss-aggregation-modes
Closed

feat(loss): add --loss-aggregation pg_loss modes#1350
EazyReal wants to merge 1 commit into
radixark:mainfrom
EazyReal:upstream-pr/loss-aggregation-modes

Conversation

@EazyReal

@EazyReal EazyReal commented Jun 16, 2026

Copy link
Copy Markdown

Port of THUDM/slime#2090 for miles.

Summary:

  • add --loss-aggregation {sample_mean,prompt_mean,token_mean,constant} for pg_loss
  • keep sample_mean as the default and reconcile legacy --calculate-per-token-loss with token_mean
  • sync the current slime prompt_mean behavior: rollout conversion emits prompt_mask_sums, the reducer scales by n_samples_per_prompt, and the final scalar is the direct mean over prompt groups
  • validate prompt_mean configs with global_batch_size % n_samples_per_prompt == 0, so each prompt group stays whole within a train step
  • add constant-divisor aggregation for Dr.GRPO while keeping non-pg metrics on the default reducer
  • include docs and focused CPU coverage for reducer math, argument validation, and prompt_mean rollout metadata

Validation:

  • uv run --with pytest --with torch --with numpy --with httpx --with pyyaml --with ray --with huggingface_hub --with transformers --with pydantic pytest --confcutdir=tests/fast/backends/training_utils tests/fast/backends/training_utils/test_loss_aggregation.py -q -> 27 passed
  • uv run --with ruff ruff check miles/backends/training_utils/cp_utils.py miles/backends/training_utils/data.py miles/backends/training_utils/loss.py miles/backends/training_utils/loss_hub/losses.py miles/ray/rollout/train_data_conversion.py miles/utils/arguments.py miles/backends/megatron_utils/model.py miles/backends/experimental/fsdp_utils/actor.py tests/fast/backends/training_utils/test_loss_aggregation.py tests/fast/backends/training_utils/loss/test_loss_snapshot.py -> passed
  • uv run --with black black --check miles/backends/training_utils/cp_utils.py miles/backends/training_utils/data.py miles/backends/training_utils/loss.py miles/backends/training_utils/loss_hub/losses.py miles/ray/rollout/train_data_conversion.py miles/utils/arguments.py miles/backends/megatron_utils/model.py miles/backends/experimental/fsdp_utils/actor.py tests/fast/backends/training_utils/test_loss_aggregation.py tests/fast/backends/training_utils/loss/test_loss_snapshot.py -> passed
  • git diff --check upstream/main -> passed

@gemini-code-assist gemini-code-assist Bot left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request introduces a unified --loss-aggregation configuration to support multiple loss aggregation modes (sample_mean, prompt_mean, token_mean, and constant) across the training pipeline, including support for Context Parallelism (CP). Key changes include updating the CP-aware reducer to accept custom denominators, resolving loss aggregation modes dynamically, precomputing step-level prompt group denominators for DAPO (prompt_mean), and adding validation and comprehensive unit tests. The review feedback suggests adding input length validation checks in both cp_utils.py and train_data_conversion.py to prevent silent mismatches, ensuring sample_denoms are on the correct device, and asserting that sample.group_index is not None during prompt-mean calculations to avoid silent grouping errors.

Important

The consumer version of Gemini Code Assist on GitHub is being sunset. Starting June 18, 2026, new organization installations will be blocked, and all code review activity will officially cease on July 17, 2026.
For more details on the timeline and next steps, please review the Help Documentation.

Comment on lines +117 to +118
if sample_denoms is None:
sample_denoms = [loss_mask.sum() for loss_mask in loss_masks]

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

To prevent potential device mismatch errors or slow host-to-device copies during loss computation, ensure that the provided sample_denoms are always converted to PyTorch tensors on the same device as their corresponding loss_masks using torch.as_tensor. Additionally, validate that the list lengths match up front to prevent silent mismatches or broadcast failures.

    if sample_denoms is not None and len(sample_denoms) != len(loss_masks):
        raise ValueError(f"Length mismatch: sample_denoms ({len(sample_denoms)}) and loss_masks ({len(loss_masks)}) must have the same length.")
    if sample_denoms is None:
        sample_denoms = [loss_mask.sum() for loss_mask in loss_masks]
    else:
        sample_denoms = [
            torch.as_tensor(denom, dtype=torch.float32, device=loss_mask.device)
            for denom, loss_mask in zip(sample_denoms, loss_masks)
        ]
References
  1. When processing lists of sequence-level or token-level tensors in RL or distillation pipelines, validate that the list lengths match up front to prevent silent mismatches or broadcast failures.

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pushing back here. sample_denoms are already guaranteed on the correct device at reduction time: in production they are built in data.py:get_rollout_data via torch.tensor(d, dtype=torch.float32, device=torch.cuda.current_device()) — the same call used for loss_masks — before any reducer runs; the constant path builds them with device=loss_mask.device; and the sample_denoms=None default uses loss_mask.sum() (already on-device). So no H2D copy happens in the hot loop, and adding a per-sample torch.as_tensor(..., device=...) here would only mask a genuine device bug rather than surface it. On the length check: the reducer's zip(..., strict=True) already raises ValueError on any length mismatch with zero extra code, so an explicit len() guard is redundant.

Comment on lines +72 to +76
if getattr(args, "loss_aggregation", "sample_mean") == "prompt_mean":
group_mask_totals: dict[int, int] = {}
for sample, loss_mask in zip(samples, loss_masks, strict=True):
group_mask_totals[sample.group_index] = group_mask_totals.get(sample.group_index, 0) + sum(loss_mask)
train_data["sample_denoms"] = [group_mask_totals[sample.group_index] for sample in samples]

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

If sample.group_index is None or missing, all such samples will be silently grouped together under group_mask_totals[None], leading to incorrect denominator calculations and corrupted training loss. Adding an explicit assertion prevents this silent failure. Additionally, validate that the list lengths of samples and loss_masks match up front to prevent silent mismatches.

Suggested change
if getattr(args, "loss_aggregation", "sample_mean") == "prompt_mean":
group_mask_totals: dict[int, int] = {}
for sample, loss_mask in zip(samples, loss_masks, strict=True):
group_mask_totals[sample.group_index] = group_mask_totals.get(sample.group_index, 0) + sum(loss_mask)
train_data["sample_denoms"] = [group_mask_totals[sample.group_index] for sample in samples]
if len(samples) != len(loss_masks):
raise ValueError(f"Length mismatch: samples ({len(samples)}) and loss_masks ({len(loss_masks)}) must have the same length.")
if getattr(args, "loss_aggregation", "sample_mean") == "prompt_mean":
group_mask_totals: dict[int, int] = {}
for sample, loss_mask in zip(samples, loss_masks):
assert sample.group_index is not None, "sample.group_index must not be None when using prompt_mean"
group_mask_totals[sample.group_index] = group_mask_totals.get(sample.group_index, 0) + sum(loss_mask)
train_data["sample_denoms"] = [group_mask_totals[sample.group_index] for sample in samples]
References
  1. When processing lists of sequence-level or token-level tensors in RL or distillation pipelines, validate that the list lengths match up front to prevent silent mismatches or broadcast failures.

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good catch — accepted. Sample.group_index is typed int | None, so a missing value would silently merge unrelated prompts under group_mask_totals[None] and corrupt the denominator. Added a fail-loud raise ValueError("--loss-aggregation prompt_mean requires every Sample.group_index to be set.") inside the prompt_mean loop (used raise rather than assert so it survives python -O), with a load-bearing test. The length-mismatch part is already covered by the loop's zip(..., strict=True), so I did not add a separate len() check.

@EazyReal EazyReal force-pushed the upstream-pr/loss-aggregation-modes branch from b175b59 to d55c05a Compare June 16, 2026 07:05
@EazyReal EazyReal marked this pull request as draft June 16, 2026 07:12
@EazyReal EazyReal force-pushed the upstream-pr/loss-aggregation-modes branch from d55c05a to 8027362 Compare June 16, 2026 07:53
@EazyReal EazyReal marked this pull request as ready for review June 16, 2026 17:10
@EazyReal EazyReal force-pushed the upstream-pr/loss-aggregation-modes branch from 8027362 to 228307f Compare June 17, 2026 08:17
@EazyReal EazyReal marked this pull request as draft June 17, 2026 08:36
@EazyReal EazyReal force-pushed the upstream-pr/loss-aggregation-modes branch from 228307f to 24a6f76 Compare June 17, 2026 08:39
@EazyReal EazyReal marked this pull request as ready for review June 17, 2026 08:54
@EazyReal EazyReal force-pushed the upstream-pr/loss-aggregation-modes branch from 24a6f76 to ed42c2f Compare June 27, 2026 02:11
@EazyReal EazyReal changed the title Add --loss-aggregation: sample_mean / prompt_mean / token_mean / constant feat(loss): add --loss-aggregation pg_loss modes Jun 27, 2026
@EazyReal EazyReal force-pushed the upstream-pr/loss-aggregation-modes branch from ed42c2f to 53ddaa3 Compare June 27, 2026 02:33
Add first-class pg_loss aggregation modes: sample_mean, prompt_mean, token_mean, and constant. The default remains sample_mean; --calculate-per-token-loss is reconciled as the legacy spelling of token_mean.

This syncs the current THUDM/slime#2090 behavior for prompt_mean: rollout conversion emits prompt_mask_sums, the reducer scales by n_samples_per_prompt, and the final scalar is the direct mean over prompt groups. Startup validation requires global_batch_size to be a multiple of n_samples_per_prompt so prompt groups stay whole within a train step.

constant requires --loss-aggregation-divisor and remains incompatible with the per-token path. The custom pg_loss reducer hook keeps precedence, and non-pg metrics keep the default sample-mean reducer.

Validation: uv run --with pytest --with torch --with numpy --with httpx --with pyyaml --with ray --with huggingface_hub --with transformers --with pydantic pytest --confcutdir=tests/fast/backends/training_utils tests/fast/backends/training_utils/test_loss_aggregation.py -q

Validation: uv run --with ruff ruff check miles/backends/training_utils/cp_utils.py miles/backends/training_utils/data.py miles/backends/training_utils/loss.py miles/backends/training_utils/loss_hub/losses.py miles/ray/rollout/train_data_conversion.py miles/utils/arguments.py miles/backends/megatron_utils/model.py miles/backends/experimental/fsdp_utils/actor.py tests/fast/backends/training_utils/test_loss_aggregation.py tests/fast/backends/training_utils/loss/test_loss_snapshot.py

Validation: uv run --with black black --check miles/backends/training_utils/cp_utils.py miles/backends/training_utils/data.py miles/backends/training_utils/loss.py miles/backends/training_utils/loss_hub/losses.py miles/ray/rollout/train_data_conversion.py miles/utils/arguments.py miles/backends/megatron_utils/model.py miles/backends/experimental/fsdp_utils/actor.py tests/fast/backends/training_utils/test_loss_aggregation.py tests/fast/backends/training_utils/loss/test_loss_snapshot.py
@EazyReal EazyReal force-pushed the upstream-pr/loss-aggregation-modes branch from 53ddaa3 to d0be4cc Compare June 27, 2026 04:25
@EazyReal

EazyReal commented Jun 27, 2026

Copy link
Copy Markdown
Author

Hi @yueming-yuan, resurfacing this for visibility since the source slime PR has moved forward.

I synced the Miles port with the current THUDM/slime#2090 behavior: prompt_mean now uses prompt_mask_sums, scales the reducer by n_samples_per_prompt, and validates that global_batch_size is a multiple of n_samples_per_prompt, so the final scalar is the direct mean over prompt groups. The branch is rebased as one commit and the focused pytest/ruff/black checks are listed in the PR body.

Would appreciate your review when you have a chance.

@EazyReal

Copy link
Copy Markdown
Author

Closing in favor of a fresh PR from upstream-pr/loss-aggregation-modes-v2 for visibility. The new branch carries the same cleaned, slime-synced commit.

@EazyReal EazyReal closed this Jun 27, 2026
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant