feat(loss): support pg_loss aggregation modes#1
Conversation
📝 WalkthroughWalkthroughAdds configurable loss-aggregation modes and prompt-group metadata flow, updates loss reducers and logging, extends nightly CI scheduling, centralizes bridge runtime propagation, and changes DeepSeek-V4 FP8 precision handling. ChangesLoss Aggregation Feature
Nightly CI scheduling
Runtime propagation and FP8 handling
Estimated code review effort🎯 4 (Complex) | ⏱️ ~60 minutes Poem
🚥 Pre-merge checks | ✅ 4 | ❌ 1❌ Failed checks (1 warning)
✅ Passed checks (4 passed)
✨ Finishing Touches📝 Generate docstrings
🧪 Generate unit tests (beta)
Comment |
|
@coderabbitai full review |
✅ Action performedFull review finished. |
98588bd to
4f1a9af
Compare
There was a problem hiding this comment.
Actionable comments posted: 5
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@docs/user-guide/customization.md`:
- Around line 222-224: Update the customization docs wording for the loss
aggregation options so it matches the actual validation in the argument parser.
In the section describing `--loss-aggregation-divisor`, change the current
“ignored for the other modes” phrasing to say it is only valid with `constant`
and will be rejected otherwise, referencing the behavior enforced by
`arguments.py` and the `--loss-aggregation`/`--loss-aggregation-divisor`
options.
In `@miles/backends/training_utils/cp_utils.py`:
- Around line 99-100: `prompt_mean` still uses microbatch-local denominators in
`get_pg_loss_reducer`, which inflates losses when a prompt group is split across
microbatches. Update the `get_sum_of_sample_mean(...)` call in that path to pass
`batch["prompt_mask_sums"]` through the new `sample_denoms` argument instead of
rebuilding from `pg_loss_masks`, using the existing `sample_denoms` hook in
`cp_utils.py`. Add a regression test covering `micro_batch_size <
n_samples_per_prompt` to verify prompt-level normalization stays correct across
microbatch splits.
In `@miles/backends/training_utils/loss_hub/losses.py`:
- Around line 323-330: The shared reducer passed into get_pg_loss_reducer is
letting calculate_per_token_loss affect metric aggregation too. Update the loss
path around get_pg_loss_reducer and the caller in loss.py so pg_loss uses a
dedicated per-token reducer when token_mean is enabled, while pg_clipfrac,
ppo_kl, entropy_loss, and kl_loss keep using a fixed sample-mean reducer. Use
the existing symbols get_pg_loss_reducer, default_reducer, and
calculate_per_token_loss to separate the pg-loss behavior from the metric
reducers.
- Around line 100-121: The prompt_mean reducer in losses.py is recomputing
prompt denominators from the local pg_loss_masks, so it normalizes by per-rank
totals instead of the intended per-step prompt totals. Update the prompt_mean
branch in the reducer setup to consume the step-level prompt_mask_sums
propagated by convert_samples_to_train_data() (or otherwise reconstruct them
after any mask rewrite) rather than calling _prompt_group_mask_sums on the local
shard. Keep the existing get_sum_of_sample_mean and prompt_mean_reducer flow,
but ensure the denominator source matches global prompt grouping semantics
across DP ranks.
In `@miles/utils/arguments.py`:
- Around line 2406-2407: Re-run loss-aggregation validation after the
`--custom-config-path` override in `arguments.py` so YAML-loaded values are
checked too. The current `_validate_loss_aggregation_args(args)` call before the
config merge in the argument parsing flow lets `loss_aggregation` values like
`token_mean`, `constant`, or `prompt_mean` bypass their required checks. Move or
repeat `_validate_loss_aggregation_args(args)` after the YAML values are
applied, keeping the existing validation helper and the argument parsing path
consistent.
🪄 Autofix (Beta)
Fix all unresolved CodeRabbit comments on this PR:
- Push a commit to this branch (recommended)
- Create a new PR with the fixes
ℹ️ Review info
⚙️ Run configuration
Configuration used: Organization UI
Review profile: CHILL
Plan: Pro Plus
Run ID: e5bf12ea-e4d3-4949-b3df-0ed26a236ea9
📒 Files selected for processing (12)
docs/user-guide/cli-reference.mddocs/user-guide/customization.mdmiles/backends/experimental/fsdp_utils/actor.pymiles/backends/megatron_utils/model.pymiles/backends/training_utils/cp_utils.pymiles/backends/training_utils/data.pymiles/backends/training_utils/loss.pymiles/backends/training_utils/loss_hub/losses.pymiles/ray/rollout/train_data_conversion.pymiles/utils/arguments.pytests/fast/backends/training_utils/loss/test_loss_snapshot.pytests/fast/backends/training_utils/test_loss_aggregation.py
67e0503 to
fe7d2bf
Compare
…ixark#1503) Co-authored-by: Shi Dong <shi.dong@radixark.ai>
…dixark#1505) Signed-off-by: zhihaow6 <zhihaow6@illinois.edu>
…xark#1509) Co-authored-by: Claude Opus 4.8 <noreply@anthropic.com>
fe7d2bf to
8d994ad
Compare
There was a problem hiding this comment.
Actionable comments posted: 3
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@miles/backends/training_utils/loss.py`:
- Around line 239-254: The token-normalized outputs in the loss path are using a
per-sample clamped token count instead of the true global active-token count,
which causes zero-mask samples to inflate the denominator. Update the `loss.py`
training loss flow around `num_tokens`, `_build_train_log_dict`, and the
returned normalizer so it uses the actual sum of active tokens across the batch
without clamping each sample to 1, matching cases where `sample.remove_sample`
zeros the mask. Make sure the `token_mean`/per-metric normalizers and the
returned `num_tokens` value all come from the same true global count.
In `@miles/ray/rollout/train_data_conversion.py`:
- Around line 187-194: `convert_samples_to_train_data()` should fail fast when
`args.loss_aggregation == "prompt_mean"` but prompt-group metadata is missing
instead of silently falling back to sample-wise partitioning. Update the branch
around `_prompt_group_partitions` to explicitly require `prompt_group_indices`
(and the related prompt metadata produced by `convert_samples_to_train_data()` /
`custom_convert_samples_to_train_data_func`), and raise a clear error if those
fields are absent; keep the existing `args.balance_data` path only for
non-`prompt_mean` modes.
In `@scripts/run_deepseek_v4.py`:
- Around line 581-582: The TE precision config is being written with
save_to_temp_file() on the launcher node and then passed through execute_train()
as a local path, so worker nodes may not be able to read it. Update the logic
around the --te-precision-config-file handling in run_deepseek_v4.py so the YAML
is placed on shared storage or otherwise distributed to every trainer node
before launch. Use the existing args.extra_args / misc_args path-building flow
and the relevant execute_train() invocation to ensure all ranks can resolve the
same config file path.
🪄 Autofix (Beta)
Fix all unresolved CodeRabbit comments on this PR:
- Push a commit to this branch (recommended)
- Create a new PR with the fixes
ℹ️ Review info
⚙️ Run configuration
Configuration used: Organization UI
Review profile: CHILL
Plan: Pro Plus
Run ID: 7a3dd468-bdd0-4695-8ff9-899016093974
📒 Files selected for processing (30)
.github/workflows/pr-test.ymldocs/ci/00-stage.mddocs/ci/01-label.mddocs/user-guide/cli-reference.mddocs/user-guide/customization.mdmiles/backends/experimental/fsdp_utils/actor.pymiles/backends/megatron_utils/model.pymiles/backends/megatron_utils/model_provider.pymiles/backends/training_utils/cp_utils.pymiles/backends/training_utils/data.pymiles/backends/training_utils/log_utils.pymiles/backends/training_utils/loss.pymiles/backends/training_utils/loss_hub/losses.pymiles/backends/training_utils/loss_hub/math_utils.pymiles/ray/rollout/train_data_conversion.pymiles/utils/arguments.pyscripts/run_deepseek_v4.pytests/ci/test/test_ci_register.pytests/ci/test/test_labels.pytests/ci/test/test_log_groups.pytests/ci/test/test_run_suite.pytests/fast/backends/training_utils/loss/test_loss_snapshot.pytests/fast/backends/training_utils/test_loss_aggregation.pytests/fast/backends/training_utils/test_ppo_ratio_numerics.pytests/manual/__init__.pytests/manual/models/__init__.pytests/manual/models/deepseek_v4/__init__.pytests/manual/models/deepseek_v4/test_v4_tilelang_indexer.pytests/manual/models/deepseek_v4/test_v4_tilelang_sparse_mla.pytools/convert_hf_to_fp8.py
✅ Files skipped from review due to trivial changes (4)
- docs/ci/00-stage.md
- tests/ci/test/test_run_suite.py
- tools/convert_hf_to_fp8.py
- docs/user-guide/cli-reference.md
🚧 Files skipped from review as they are similar to previous changes (8)
- tests/fast/backends/training_utils/loss/test_loss_snapshot.py
- miles/backends/experimental/fsdp_utils/actor.py
- miles/backends/megatron_utils/model.py
- miles/backends/training_utils/data.py
- docs/user-guide/customization.md
- miles/utils/arguments.py
- miles/backends/training_utils/cp_utils.py
- miles/backends/training_utils/loss_hub/losses.py
| log_dict = _build_train_log_dict( | ||
| log, | ||
| num_samples=num_samples, | ||
| num_tokens=num_tokens, | ||
| device=logits.device, | ||
| calculate_per_token_loss=args.calculate_per_token_loss, | ||
| ) | ||
|
|
||
| return ( | ||
| loss, | ||
| torch.tensor(num_tokens if args.calculate_per_token_loss else 1, device=logits.device), | ||
| { | ||
| "keys": list(log.keys()), | ||
| "values": torch.tensor( | ||
| [ | ||
| num_samples if not args.calculate_per_token_loss else num_tokens, | ||
| ] | ||
| + list(log.values()), | ||
| device=logits.device, | ||
| ), | ||
| }, | ||
| ( | ||
| num_tokens.to(device=logits.device) | ||
| if args.calculate_per_token_loss | ||
| else torch.tensor(1, device=logits.device) | ||
| ), | ||
| log_dict, |
There was a problem hiding this comment.
🎯 Functional Correctness | 🟠 Major | ⚡ Quick win
Use the true global active-token count for token-normalized outputs.
These paths now propagate num_tokens into both the returned normalizer and per-metric log normalizers, but num_tokens is still computed as sum(clamp_min(loss_mask.sum(), 1)). A fully masked sample is reachable here (sample.remove_sample zeroes the mask in miles/ray/rollout/train_data_conversion.py, Lines 53-55), so token_mean will overcount the denominator and under-report losses whenever a batch contains zero-loss samples.
Proposed fix
- num_tokens = sum([torch.clamp_min(loss_mask.sum(), 1) for loss_mask in batch["loss_masks"]])
+ num_tokens = sum(loss_mask.sum() for loss_mask in batch["loss_masks"])
+ token_normalizer = torch.clamp_min(num_tokens, 1)
@@
log_dict = _build_train_log_dict(
log,
num_samples=num_samples,
- num_tokens=num_tokens,
+ num_tokens=token_normalizer,
device=logits.device,
calculate_per_token_loss=args.calculate_per_token_loss,
)
@@
loss,
(
- num_tokens.to(device=logits.device)
+ token_normalizer.to(device=logits.device)
if args.calculate_per_token_loss
else torch.tensor(1, device=logits.device)
),
log_dict,📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
| log_dict = _build_train_log_dict( | |
| log, | |
| num_samples=num_samples, | |
| num_tokens=num_tokens, | |
| device=logits.device, | |
| calculate_per_token_loss=args.calculate_per_token_loss, | |
| ) | |
| return ( | |
| loss, | |
| torch.tensor(num_tokens if args.calculate_per_token_loss else 1, device=logits.device), | |
| { | |
| "keys": list(log.keys()), | |
| "values": torch.tensor( | |
| [ | |
| num_samples if not args.calculate_per_token_loss else num_tokens, | |
| ] | |
| + list(log.values()), | |
| device=logits.device, | |
| ), | |
| }, | |
| ( | |
| num_tokens.to(device=logits.device) | |
| if args.calculate_per_token_loss | |
| else torch.tensor(1, device=logits.device) | |
| ), | |
| log_dict, | |
| log_dict = _build_train_log_dict( | |
| log, | |
| num_samples=num_samples, | |
| num_tokens=token_normalizer, | |
| device=logits.device, | |
| calculate_per_token_loss=args.calculate_per_token_loss, | |
| ) | |
| return ( | |
| loss, | |
| ( | |
| token_normalizer.to(device=logits.device) | |
| if args.calculate_per_token_loss | |
| else torch.tensor(1, device=logits.device) | |
| ), | |
| log_dict, |
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
In `@miles/backends/training_utils/loss.py` around lines 239 - 254, The
token-normalized outputs in the loss path are using a per-sample clamped token
count instead of the true global active-token count, which causes zero-mask
samples to inflate the denominator. Update the `loss.py` training loss flow
around `num_tokens`, `_build_train_log_dict`, and the returned normalizer so it
uses the actual sum of active tokens across the batch without clamping each
sample to 1, matching cases where `sample.remove_sample` zeros the mask. Make
sure the `token_mean`/per-metric normalizers and the returned `num_tokens` value
all come from the same true global count.
| if getattr(args, "loss_aggregation", "sample_mean") == "prompt_mean" and "prompt_group_indices" in data: | ||
| partitions = _prompt_group_partitions( | ||
| data["prompt_group_indices"], | ||
| total_lengths, | ||
| dp_size, | ||
| balance_data=args.balance_data, | ||
| ) | ||
| elif args.balance_data: |
There was a problem hiding this comment.
🎯 Functional Correctness | 🟠 Major | ⚡ Quick win
Fail fast if prompt_mean metadata is missing.
convert_samples_to_train_data() can return early via custom_convert_samples_to_train_data_func, so prompt_group_indices / prompt_mask_sums are not guaranteed to exist. This branch currently falls back to sample-wise partitioning instead of rejecting --loss-aggregation prompt_mean, which breaks the new prompt-group contract or defers the failure to a much less obvious downstream error.
Proposed fix
- if getattr(args, "loss_aggregation", "sample_mean") == "prompt_mean" and "prompt_group_indices" in data:
+ if getattr(args, "loss_aggregation", "sample_mean") == "prompt_mean":
+ missing = [key for key in ("prompt_group_indices", "prompt_mask_sums") if key not in data]
+ if missing:
+ raise ValueError(
+ "--loss-aggregation prompt_mean requires train_data to include "
+ f"{', '.join(missing)}. Custom converters must propagate prompt-group metadata."
+ )
partitions = _prompt_group_partitions(
data["prompt_group_indices"],
total_lengths,
dp_size,
balance_data=args.balance_data,
)📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
| if getattr(args, "loss_aggregation", "sample_mean") == "prompt_mean" and "prompt_group_indices" in data: | |
| partitions = _prompt_group_partitions( | |
| data["prompt_group_indices"], | |
| total_lengths, | |
| dp_size, | |
| balance_data=args.balance_data, | |
| ) | |
| elif args.balance_data: | |
| if getattr(args, "loss_aggregation", "sample_mean") == "prompt_mean": | |
| missing = [key for key in ("prompt_group_indices", "prompt_mask_sums") if key not in data] | |
| if missing: | |
| raise ValueError( | |
| "--loss-aggregation prompt_mean requires train_data to include " | |
| f"{', '.join(missing)}. Custom converters must propagate prompt-group metadata." | |
| ) | |
| partitions = _prompt_group_partitions( | |
| data["prompt_group_indices"], | |
| total_lengths, | |
| dp_size, | |
| balance_data=args.balance_data, | |
| ) | |
| elif args.balance_data: |
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
In `@miles/ray/rollout/train_data_conversion.py` around lines 187 - 194,
`convert_samples_to_train_data()` should fail fast when `args.loss_aggregation
== "prompt_mean"` but prompt-group metadata is missing instead of silently
falling back to sample-wise partitioning. Update the branch around
`_prompt_group_partitions` to explicitly require `prompt_group_indices` (and the
related prompt metadata produced by `convert_samples_to_train_data()` /
`custom_convert_samples_to_train_data_func`), and raise a clear error if those
fields are absent; keep the existing `args.balance_data` path only for
non-`prompt_mean` modes.
| if "--te-precision-config-file" not in args.extra_args: | ||
| misc_args += f"--te-precision-config-file " f"{U.save_to_temp_file(_DSV4_TE_PRECISION_CONFIG, 'yaml')} " |
There was a problem hiding this comment.
🩺 Stability & Availability | 🟠 Major | ⚡ Quick win
🧩 Analysis chain
🏁 Script executed:
#!/bin/bash
# Inspect whether execute_train propagates locally generated TE precision config files to all nodes.
rg -n -C4 '\bdef execute_train\b|te-precision-config-file|save_to_temp_file|exec_command_all_ray_node|train_args' \
scripts/run_deepseek_v4.py miles/utils/external_utils/command_utils.py miles/utils/misc.pyRepository: EazyReal/miles
Length of output: 7783
🏁 Script executed:
#!/bin/bash
set -euo pipefail
# Read the relevant launcher code to see how train_args and runtime env are applied.
sed -n '97,230p' miles/utils/external_utils/command_utils.py
# Inspect the helper that creates the temp YAML.
sed -n '250,270p' miles/utils/external_utils/command_utils.py
# Inspect the caller around the injected TE config path.
sed -n '560,610p' scripts/run_deepseek_v4.pyRepository: EazyReal/miles
Length of output: 7958
scripts/run_deepseek_v4.py:581-582 — make the TE config reachable on every trainer node
save_to_temp_file() writes the YAML under /tmp on the submitter node, and execute_train() only forwards that path in train_args; nothing copies or re-materializes it on worker nodes. Multi-node fp8 runs can fail when non-head ranks try to open --te-precision-config-file. Move it to shared storage or broadcast it before launch.
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
In `@scripts/run_deepseek_v4.py` around lines 581 - 582, The TE precision config
is being written with save_to_temp_file() on the launcher node and then passed
through execute_train() as a local path, so worker nodes may not be able to read
it. Update the logic around the --te-precision-config-file handling in
run_deepseek_v4.py so the YAML is placed on shared storage or otherwise
distributed to every trainer node before launch. Use the existing
args.extra_args / misc_args path-building flow and the relevant execute_train()
invocation to ensure all ranks can resolve the same config file path.
Description
Adds
pg_lossaggregation modes to Miles, following the behavior in THUDM/slime#2090 and the related AReaL implementation in areal-project/AReaL#1443.This adds
--loss-aggregation {sample_mean,prompt_mean,token_mean,constant}while keepingsample_meanas the default.--calculate-per-token-lossremains the legacy spelling fortoken_mean, and custom pg-loss reducers still take precedence.The built-in modes are scoped to
pg_loss; metrics such aspg_clipfrac,ppo_kl,entropy_loss, andkl_losskeep the existing sample-mean reducer.sample_meanprompt_meantoken_meanconstant--loss-aggregation-divisorprompt_meanfollows the current slime behavior: rollout conversion emitsprompt_mask_sums, the reducer scales byn_samples_per_prompt, and argument validation requiresglobal_batch_size % n_samples_per_prompt == 0so prompt groups stay whole within a train step.constantimplements the Dr.GRPO-style fixed denominator by dividing the masked token-loss sum by--loss-aggregation-divisorbefore the standard step average.References
Validation
uv run --with pytest --with torch --with numpy --with httpx --with pyyaml --with ray --with huggingface_hub --with transformers --with pydantic pytest --confcutdir=tests/fast/backends/training_utils tests/fast/backends/training_utils/test_loss_aggregation.py tests/fast/backends/training_utils/loss/test_loss_snapshot.py -quv run --with ruff ruff check miles/backends/training_utils/cp_utils.py miles/backends/training_utils/data.py miles/backends/training_utils/loss.py miles/backends/training_utils/loss_hub/losses.py miles/ray/rollout/train_data_conversion.py miles/utils/arguments.py miles/backends/megatron_utils/model.py miles/backends/experimental/fsdp_utils/actor.py tests/fast/backends/training_utils/test_loss_aggregation.py tests/fast/backends/training_utils/loss/test_loss_snapshot.pyuv run --with black black --check miles/backends/training_utils/cp_utils.py miles/backends/training_utils/data.py miles/backends/training_utils/loss.py miles/backends/training_utils/loss_hub/losses.py miles/ray/rollout/train_data_conversion.py miles/utils/arguments.py miles/backends/megatron_utils/model.py miles/backends/experimental/fsdp_utils/actor.py tests/fast/backends/training_utils/test_loss_aggregation.py tests/fast/backends/training_utils/loss/test_loss_snapshot.pygit diff --check upstream/mainSummary by CodeRabbit
pg_lossaggregation:--loss-aggregation {sample_mean,token_mean,prompt_mean,constant}and--loss-aggregation-divisor.prompt_mean) during data preparation and training.