feat(loss): support pg_loss aggregation modes #1498
Hi @yueming-yuan, reopening this fresh PR for visibility after syncing the loss aggregation behavior with THUDM/slime#2090. This adds Please review when you have a chance.
Code Review
This pull request introduces the --loss-aggregation command-line option to support multiple aggregation modes for pg_loss (sample_mean, prompt_mean, token_mean, and constant), along with validation logic, documentation, and a comprehensive test suite. The review feedback suggests appending new parameters to the end of get_sum_of_sample_mean to preserve backward compatibility for positional arguments, avoiding redundant computation of sample_denoms when constant_divisor is active, and optimizing the GPU tensor conversion of prompt_mask_sums by avoiding a loop over individual scalar tensors.
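The first two review suggestions can be illustrated with a small sketch. The factory below is a hypothetical stand-in, not the real get_sum_of_sample_mean signature from Miles; the names make_reducer, loss_masks, and constant_divisor are assumptions chosen for illustration. It shows a new option added as a keyword-only parameter at the end, so existing positional callers keep working, and per-sample denominators computed only on the path that needs them.

```python
from typing import Callable, List, Optional

Reducer = Callable[[List[List[float]]], float]


def make_reducer(
    loss_masks: List[List[float]],
    *,  # everything after this marker must be passed by keyword
    constant_divisor: Optional[float] = None,  # hypothetical new option
) -> Reducer:
    """Hypothetical stand-in for a loss-reducer factory (not the Miles API)."""
    if constant_divisor is None:
        # Only the per-sample path needs per-sample denominators.
        # Clamp to 1 so a fully masked sample contributes 0 instead of dividing by 0.
        sample_denoms = [max(sum(mask), 1.0) for mask in loss_masks]

        def reduce(per_token: List[List[float]]) -> float:
            # Sum of per-sample masked means.
            return sum(
                sum(x * m for x, m in zip(xs, mask)) / denom
                for xs, mask, denom in zip(per_token, loss_masks, sample_denoms)
            )

    else:
        divisor = constant_divisor

        def reduce(per_token: List[List[float]]) -> float:
            # Masked sum over all tokens, divided by one fixed constant.
            total = sum(
                sum(x * m for x, m in zip(xs, mask))
                for xs, mask in zip(per_token, loss_masks)
            )
            return total / divisor

    return reduce


masks = [[1.0, 1.0], [1.0, 0.0]]
losses = [[2.0, 4.0], [6.0, 8.0]]

legacy = make_reducer(masks)  # pre-existing call style is unaffected
constant = make_reducer(masks, constant_divisor=4.0)

print(legacy(losses))    # (2 + 4) / 2 + 6 / 1
print(constant(losses))  # (2 + 4 + 6) / 4
```

Making the new parameter keyword-only is one way to get the same protection as appending it to the end: a caller cannot accidentally bind it positionally.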
Quick follow-up after rebasing onto current
Description
Adds built-in pg_loss aggregation modes to Miles. Related implementations for comparison:

This adds --loss-aggregation {sample_mean,prompt_mean,token_mean,constant} while keeping sample_mean as the default. --calculate-per-token-loss remains the legacy spelling for token_mean, and custom pg-loss reducers still take precedence.

- sample_mean
- prompt_mean
- token_mean
- constant (--loss-aggregation-divisor)

The implementation keeps denominator ownership explicit:
- token_mean policy loss rejects nonzero entropy or KL-loss coefficients that would mix token-normalized pg_loss with sample-normalized auxiliary loss terms.
- pg_loss uses the modified mask for loss reduction while mismatch metrics stay on the original mask.
- prompt_mean DP splitting keeps prompt groups whole on each DP shard and rejects train steps whose prompt-group count cannot be distributed evenly.
- global_batch_size before checking aggregation constraints.

Validation
- uv run --with pytest --with torch --with numpy --with httpx --with pyyaml --with ray --with huggingface_hub --with transformers --with pydantic --with psutil pytest --confcutdir=tests/fast/backends/training_utils tests/fast/backends/training_utils/test_loss_aggregation.py tests/fast/backends/training_utils/loss/test_loss_snapshot.py -q
- uv run --with ruff ruff check miles/backends/training_utils/cp_utils.py miles/backends/training_utils/data.py miles/backends/training_utils/log_utils.py miles/backends/training_utils/loss.py miles/backends/training_utils/loss_hub/losses.py miles/ray/rollout/train_data_conversion.py miles/utils/arguments.py miles/backends/megatron_utils/model.py miles/backends/experimental/fsdp_utils/actor.py tests/fast/backends/training_utils/test_loss_aggregation.py tests/fast/backends/training_utils/loss/test_loss_snapshot.py
- uv run --with black black --check miles/backends/training_utils/cp_utils.py miles/backends/training_utils/data.py miles/backends/training_utils/log_utils.py miles/backends/training_utils/loss.py miles/backends/training_utils/loss_hub/losses.py miles/ray/rollout/train_data_conversion.py miles/utils/arguments.py miles/backends/megatron_utils/model.py miles/backends/experimental/fsdp_utils/actor.py tests/fast/backends/training_utils/test_loss_aggregation.py tests/fast/backends/training_utils/loss/test_loss_snapshot.py
- git diff --check upstream/main
- python3 train.py --help is not runnable in my local environment because sglang is not installed; the parser path for the new flags is covered by the focused tests.
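For reviewers comparing the four --loss-aggregation modes, here is a rough single-process sketch of how their denominators could differ. It is not taken from the PR's code: the function aggregate_pg_loss and its parameters are hypothetical, and the exact definitions are assumptions (prompt_mean is read here as pooling tokens within a prompt group and then averaging across prompts; constant is read as dividing the masked sum by the value of --loss-aggregation-divisor). Data-parallel and context-parallel scaling are ignored.

```python
from collections import defaultdict
from typing import Dict, Optional, Sequence


def aggregate_pg_loss(
    token_losses: Sequence[Sequence[float]],
    loss_masks: Sequence[Sequence[float]],
    prompt_ids: Sequence[int],
    mode: str = "sample_mean",
    divisor: Optional[float] = None,
) -> float:
    """Illustrative reduction of per-token losses (assumed semantics, not the Miles code)."""
    # Masked loss sum and count of unmasked tokens for every sample.
    sums = [
        sum(l * m for l, m in zip(ls, ms))
        for ls, ms in zip(token_losses, loss_masks)
    ]
    counts = [sum(ms) for ms in loss_masks]

    if mode == "sample_mean":
        # Average within each sample, then across samples.
        return sum(s / max(c, 1.0) for s, c in zip(sums, counts)) / len(sums)

    if mode == "prompt_mean":
        # Pool tokens of samples that share a prompt, then average across prompts.
        group_sum: Dict[int, float] = defaultdict(float)
        group_count: Dict[int, float] = defaultdict(float)
        for pid, s, c in zip(prompt_ids, sums, counts):
            group_sum[pid] += s
            group_count[pid] += c
        per_prompt = [group_sum[p] / max(group_count[p], 1.0) for p in group_sum]
        return sum(per_prompt) / len(per_prompt)

    if mode == "token_mean":
        # One denominator for the whole batch: every token weighs the same.
        return sum(sums) / max(sum(counts), 1.0)

    if mode == "constant":
        if divisor is None or divisor <= 0:
            raise ValueError("constant mode needs a positive divisor")
        return sum(sums) / divisor

    raise ValueError(f"unknown aggregation mode: {mode}")


# Three samples; the first two share prompt 0.
token_losses = [[1.0, 3.0], [2.0, 2.0, 2.0, 2.0], [8.0]]
loss_masks = [[1.0, 1.0], [1.0, 1.0, 1.0, 1.0], [1.0]]
prompt_ids = [0, 0, 1]

results = {
    mode: aggregate_pg_loss(token_losses, loss_masks, prompt_ids, mode, divisor=10.0)
    for mode in ("sample_mean", "prompt_mean", "token_mean", "constant")
}
print(results)
```

On this toy batch the short, high-loss third sample dominates sample_mean and prompt_mean but is diluted under token_mean, which is the kind of difference the denominator-ownership rules in the description are meant to keep explicit.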