feat(loss): support pg_loss aggregation modes#1

Open
EazyReal wants to merge 9 commits into main from upstream-pr/loss-aggregation-modes-v2

@EazyReal EazyReal commented Jun 27, 2026

Description

Adds pg_loss aggregation modes to Miles, following the behavior in THUDM/slime#2090 and the related AReaL implementation in areal-project/AReaL#1443.

This adds --loss-aggregation {sample_mean,prompt_mean,token_mean,constant} while keeping sample_mean as the default. --calculate-per-token-loss remains the legacy spelling for token_mean, and custom pg-loss reducers still take precedence.

The built-in modes are scoped to pg_loss; metrics such as pg_clipfrac, ppo_kl, entropy_loss, and kl_loss keep the existing sample-mean reducer.

| Mode | Denominator |
| --- | --- |
| sample_mean | per-sample active-token count |
| prompt_mean | per-prompt-group active-token count, then mean over prompt groups |
| token_mean | global active-token count |
| constant | fixed --loss-aggregation-divisor |

prompt_mean follows the current slime behavior: rollout conversion emits prompt_mask_sums, the reducer scales by n_samples_per_prompt, and argument validation requires global_batch_size % n_samples_per_prompt == 0 so prompt groups stay whole within a train step.

constant implements the Dr.GRPO-style fixed denominator by dividing the masked token-loss sum by --loss-aggregation-divisor before the standard step average.
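
As a rough illustration of the four denominators, the sketch below reduces per-token losses in plain Python. The name `aggregate_pg_loss`, its argument layout, and the choice to treat the "standard step average" as a mean over samples are assumptions made for illustration. This is not the Miles reducer API, and the PR's scaling by n_samples_per_prompt may differ in detail.

```python
from collections import defaultdict


def aggregate_pg_loss(token_losses, masks, prompt_ids, mode, divisor=None):
    """Reduce per-token losses to one scalar (illustrative only).

    token_losses and masks hold one list per sample; prompt_ids holds the
    prompt-group id of each sample.
    """
    # Masked loss sum and active-token count for each sample.
    sums = [sum(l * m for l, m in zip(ls, ms)) for ls, ms in zip(token_losses, masks)]
    counts = [sum(ms) for ms in masks]

    if mode == "sample_mean":
        # Mean over tokens inside each sample, then mean over samples.
        return sum(s / max(c, 1) for s, c in zip(sums, counts)) / len(sums)
    if mode == "token_mean":
        # One global denominator, so every active token has equal weight.
        return sum(sums) / max(sum(counts), 1)
    if mode == "prompt_mean":
        # Pool tokens inside each prompt group, then mean over groups.
        group_sum, group_cnt = defaultdict(float), defaultdict(int)
        for s, c, g in zip(sums, counts, prompt_ids):
            group_sum[g] += s
            group_cnt[g] += c
        return sum(group_sum[g] / max(group_cnt[g], 1) for g in group_sum) / len(group_sum)
    if mode == "constant":
        # Fixed divisor per sample (Dr.GRPO style), then mean over samples.
        if divisor is None or divisor <= 0:
            raise ValueError("constant mode needs a positive divisor")
        return sum(s / divisor for s in sums) / len(sums)
    raise ValueError(f"unknown mode: {mode}")


# Three samples; the first two share prompt group 0.
losses = [[1.0, 1.0, 1.0, 1.0], [4.0, 9.0], [2.0, 2.0]]
masks = [[1, 1, 1, 1], [1, 0], [1, 1]]
prompt_ids = [0, 0, 1]
results = {
    mode: aggregate_pg_loss(losses, masks, prompt_ids, mode, divisor=4)
    for mode in ("sample_mean", "token_mean", "prompt_mean", "constant")
}
```

With these inputs every sample has a masked loss sum of 4, so the modes differ only in their denominators: sample_mean gives 7/3, token_mean gives 12/7, prompt_mean gives 1.8, and constant with a divisor of 4 gives 1.0.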

References

Validation

  • uv run --with pytest --with torch --with numpy --with httpx --with pyyaml --with ray --with huggingface_hub --with transformers --with pydantic pytest --confcutdir=tests/fast/backends/training_utils tests/fast/backends/training_utils/test_loss_aggregation.py tests/fast/backends/training_utils/loss/test_loss_snapshot.py -q
  • uv run --with ruff ruff check miles/backends/training_utils/cp_utils.py miles/backends/training_utils/data.py miles/backends/training_utils/loss.py miles/backends/training_utils/loss_hub/losses.py miles/ray/rollout/train_data_conversion.py miles/utils/arguments.py miles/backends/megatron_utils/model.py miles/backends/experimental/fsdp_utils/actor.py tests/fast/backends/training_utils/test_loss_aggregation.py tests/fast/backends/training_utils/loss/test_loss_snapshot.py
  • uv run --with black black --check miles/backends/training_utils/cp_utils.py miles/backends/training_utils/data.py miles/backends/training_utils/loss.py miles/backends/training_utils/loss_hub/losses.py miles/ray/rollout/train_data_conversion.py miles/utils/arguments.py miles/backends/megatron_utils/model.py miles/backends/experimental/fsdp_utils/actor.py tests/fast/backends/training_utils/test_loss_aggregation.py tests/fast/backends/training_utils/loss/test_loss_snapshot.py
  • git diff --check upstream/main

Summary by CodeRabbit

  • New Features
    • Added new training CLI controls for pg_loss aggregation: --loss-aggregation {sample_mean,token_mean,prompt_mean,constant} and --loss-aggregation-divisor.
    • Introduced prompt-group-aware aggregation support (prompt_mean) during data preparation and training.
  • Bug Fixes
    • Improved compatibility and validation between legacy per-token loss settings and the new aggregation modes.
    • Refined loss aggregation/reduction behavior to keep normalization and logging consistent across modes.
  • Documentation
    • Updated the CLI reference and customization guide with detailed mode semantics, constraints, and legacy behavior.
  • Tests / CI
    • Expanded automated coverage for aggregation correctness, validation, DP partitioning, and numerical edge cases.
    • Updated CI workflow behavior for nightly scheduled runs and label-based gating.

coderabbitai Bot commented Jun 27, 2026


📝 Walkthrough


Adds configurable loss-aggregation modes and prompt-group metadata flow, updates loss reducers and logging, extends nightly CI scheduling, centralizes bridge runtime propagation, and changes DeepSeek-V4 FP8 precision handling.

Changes

Loss Aggregation Feature

| Layer | File(s) | Summary |
| --- | --- | --- |
| CLI, validation, and prompt-group data | miles/utils/arguments.py, miles/ray/rollout/train_data_conversion.py, miles/backends/training_utils/data.py, miles/backends/experimental/fsdp_utils/actor.py, miles/backends/megatron_utils/model.py | Registers aggregation flags, validates them against legacy loss settings, and propagates prompt-group fields through conversion, rollout data, and batch loading. |
| Reducers and logging | miles/backends/training_utils/cp_utils.py, miles/backends/training_utils/loss.py, miles/backends/training_utils/loss_hub/losses.py, miles/backends/training_utils/log_utils.py | get_sum_of_sample_mean gains cached denominators and constant-divisor support, loss_function updates reducer wiring and logging construction, get_pg_loss_reducer selects PG-loss aggregation modes, and aggregate_train_losses handles normalizers explicitly. |
| Numerical stability updates | miles/backends/training_utils/loss_hub/math_utils.py | Safe log-ratio helpers are added and used in ESS, low-var KL, and PPO policy-loss ratio calculations. |
| Tests and documentation | docs/user-guide/cli-reference.md, docs/user-guide/customization.md, tests/fast/backends/training_utils/* | Docs describe the new aggregation modes and prompt-group fields, and tests cover reducer behavior, validation, logging contracts, data conversion, and PPO numerics. |

Nightly CI scheduling

| Layer | File(s) | Summary |
| --- | --- | --- |
| Workflow triggers and gating | .github/workflows/pr-test.yml | pr-test.yml adds a nightly schedule and widens job and stage gating, label handling, and continue-on-error behavior for scheduled and nightly runs. |
| CI docs and tests | docs/ci/00-stage.md, docs/ci/01-label.md, tests/ci/test/test_ci_register.py, tests/ci/test/test_run_suite.py | The CI docs explain nightly/full-suite behavior, and the CI suite tests update repository-root discovery and expected test-path names. |

Runtime propagation and FP8 handling

| Layer | File(s) | Summary |
| --- | --- | --- |
| Bridge runtime propagation | miles/backends/megatron_utils/model_provider.py | _apply_bridge_runtime_config copies bridge-mode runtime settings onto the Megatron provider and replaces the inline assignments in the bridge path. |
| FP8 precision and conversion | scripts/run_deepseek_v4.py, tools/convert_hf_to_fp8.py | The DeepSeek-V4 script adds a TE precision config for FP8 training, and the HF-to-FP8 conversion tool excludes additional parameter key prefixes from quantization. |

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60 minutes

Poem

🐇 I hop through prompts and nightly light,
With loss modes twinkling left and right.
The bridge now hums, the FP8 sprinkles gleam,
And CI dreams run by moonbeam stream.

🚥 Pre-merge checks | ✅ 4 | ❌ 1

❌ Failed checks (1 warning)

| Check name | Status | Explanation | Resolution |
| --- | --- | --- | --- |
| Docstring Coverage | ⚠️ Warning | Docstring coverage is 14.16% which is insufficient. The required threshold is 80.00%. | Write docstrings for the functions missing them to satisfy the coverage threshold. |

✅ Passed checks (4 passed)

| Check name | Status | Explanation |
| --- | --- | --- |
| Description Check | ✅ Passed | Check skipped - CodeRabbit’s high-level summary is enabled. |
| Title check | ✅ Passed | The title clearly states the main change: adding pg_loss aggregation modes. |
| Linked Issues check | ✅ Passed | Check skipped because no linked issues were found for this pull request. |
| Out of Scope Changes check | ✅ Passed | Check skipped because no linked issues were found for this pull request. |

@EazyReal


@coderabbitai full review

coderabbitai Bot commented Jun 27, 2026

✅ Action performed

Full review finished.

@EazyReal force-pushed the upstream-pr/loss-aggregation-modes-v2 branch 2 times, most recently from 98588bd to 4f1a9af (June 27, 2026 20:12)

@coderabbitai coderabbitai Bot left a comment


Actionable comments posted: 5

🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Inline comments:
In `@docs/user-guide/customization.md`:
- Around line 222-224: Update the customization docs wording for the loss
aggregation options so it matches the actual validation in the argument parser.
In the section describing `--loss-aggregation-divisor`, change the current
“ignored for the other modes” phrasing to say it is only valid with `constant`
and will be rejected otherwise, referencing the behavior enforced by
`arguments.py` and the `--loss-aggregation`/`--loss-aggregation-divisor`
options.

In `@miles/backends/training_utils/cp_utils.py`:
- Around line 99-100: `prompt_mean` still uses microbatch-local denominators in
`get_pg_loss_reducer`, which inflates losses when a prompt group is split across
microbatches. Update the `get_sum_of_sample_mean(...)` call in that path to pass
`batch["prompt_mask_sums"]` through the new `sample_denoms` argument instead of
rebuilding from `pg_loss_masks`, using the existing `sample_denoms` hook in
`cp_utils.py`. Add a regression test covering `micro_batch_size <
n_samples_per_prompt` to verify prompt-level normalization stays correct across
microbatch splits.

In `@miles/backends/training_utils/loss_hub/losses.py`:
- Around line 323-330: The shared reducer passed into get_pg_loss_reducer is
letting calculate_per_token_loss affect metric aggregation too. Update the loss
path around get_pg_loss_reducer and the caller in loss.py so pg_loss uses a
dedicated per-token reducer when token_mean is enabled, while pg_clipfrac,
ppo_kl, entropy_loss, and kl_loss keep using a fixed sample-mean reducer. Use
the existing symbols get_pg_loss_reducer, default_reducer, and
calculate_per_token_loss to separate the pg-loss behavior from the metric
reducers.
- Around line 100-121: The prompt_mean reducer in losses.py is recomputing
prompt denominators from the local pg_loss_masks, so it normalizes by per-rank
totals instead of the intended per-step prompt totals. Update the prompt_mean
branch in the reducer setup to consume the step-level prompt_mask_sums
propagated by convert_samples_to_train_data() (or otherwise reconstruct them
after any mask rewrite) rather than calling _prompt_group_mask_sums on the local
shard. Keep the existing get_sum_of_sample_mean and prompt_mean_reducer flow,
but ensure the denominator source matches global prompt grouping semantics
across DP ranks.

In `@miles/utils/arguments.py`:
- Around line 2406-2407: Re-run loss-aggregation validation after the
`--custom-config-path` override in `arguments.py` so YAML-loaded values are
checked too. The current `_validate_loss_aggregation_args(args)` call before the
config merge in the argument parsing flow lets `loss_aggregation` values like
`token_mean`, `constant`, or `prompt_mean` bypass their required checks. Move or
repeat `_validate_loss_aggregation_args(args)` after the YAML values are
applied, keeping the existing validation helper and the argument parsing path
consistent.
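
The `prompt_mean` finding for `cp_utils.py` in the list above can be reproduced with small numbers. The sketch below is not Miles code: `prompt_mean_over_microbatches` and its arguments are hypothetical, and it assumes each sample contributes its masked loss sum divided by its prompt group's denominator, with the result averaged over prompt groups.

```python
def prompt_mean_over_microbatches(microbatches, num_groups, step_denoms=None):
    """Accumulate a prompt-mean loss across microbatches (illustrative only).

    microbatches: list of lists of (group_id, masked_loss_sum, active_tokens).
    step_denoms: optional {group_id: step-level active-token count}. When it is
    omitted, each microbatch rebuilds denominators from its own samples.
    """
    total = 0.0
    for mb in microbatches:
        if step_denoms is None:
            # Microbatch-local denominators: only tokens visible in this microbatch.
            denoms = {}
            for gid, _, n_tokens in mb:
                denoms[gid] = denoms.get(gid, 0) + n_tokens
        else:
            denoms = step_denoms
        for gid, loss_sum, _ in mb:
            total += loss_sum / max(denoms[gid], 1)
    return total / num_groups


# One prompt group with two samples (4 and 1 active tokens).
samples = [(0, 4.0, 4), (0, 4.0, 1)]
split = [[samples[0]], [samples[1]]]  # micro_batch_size < n_samples_per_prompt
whole = [samples]                     # the group stays in one microbatch

local = prompt_mean_over_microbatches(split, num_groups=1)
fixed = prompt_mean_over_microbatches(split, num_groups=1, step_denoms={0: 5})
reference = prompt_mean_over_microbatches(whole, num_groups=1)
```

Here `local` is 5.0 while `fixed` and `reference` are both 1.6, which is the inflation the comment describes: once the group is split, each microbatch divides by its own token count instead of the group's step-level count.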

ℹ️ Review info
⚙️ Run configuration

Configuration used: Organization UI

Review profile: CHILL

Plan: Pro Plus

Run ID: e5bf12ea-e4d3-4949-b3df-0ed26a236ea9

📥 Commits

Reviewing files that changed from the base of the PR and between 39a1580 and 98588bd.

📒 Files selected for processing (12)
  • docs/user-guide/cli-reference.md
  • docs/user-guide/customization.md
  • miles/backends/experimental/fsdp_utils/actor.py
  • miles/backends/megatron_utils/model.py
  • miles/backends/training_utils/cp_utils.py
  • miles/backends/training_utils/data.py
  • miles/backends/training_utils/loss.py
  • miles/backends/training_utils/loss_hub/losses.py
  • miles/ray/rollout/train_data_conversion.py
  • miles/utils/arguments.py
  • tests/fast/backends/training_utils/loss/test_loss_snapshot.py
  • tests/fast/backends/training_utils/test_loss_aggregation.py

@EazyReal force-pushed the upstream-pr/loss-aggregation-modes-v2 branch 2 times, most recently from 67e0503 to fe7d2bf (June 27, 2026 22:48)
@EazyReal force-pushed the upstream-pr/loss-aggregation-modes-v2 branch from fe7d2bf to 8d994ad (June 30, 2026 08:38)

@coderabbitai coderabbitai Bot left a comment


Actionable comments posted: 3

🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Inline comments:
In `@miles/backends/training_utils/loss.py`:
- Around line 239-254: The token-normalized outputs in the loss path are using a
per-sample clamped token count instead of the true global active-token count,
which causes zero-mask samples to inflate the denominator. Update the `loss.py`
training loss flow around `num_tokens`, `_build_train_log_dict`, and the
returned normalizer so it uses the actual sum of active tokens across the batch
without clamping each sample to 1, matching cases where `sample.remove_sample`
zeros the mask. Make sure the `token_mean`/per-metric normalizers and the
returned `num_tokens` value all come from the same true global count.

In `@miles/ray/rollout/train_data_conversion.py`:
- Around line 187-194: `convert_samples_to_train_data()` should fail fast when
`args.loss_aggregation == "prompt_mean"` but prompt-group metadata is missing
instead of silently falling back to sample-wise partitioning. Update the branch
around `_prompt_group_partitions` to explicitly require `prompt_group_indices`
(and the related prompt metadata produced by `convert_samples_to_train_data()` /
`custom_convert_samples_to_train_data_func`), and raise a clear error if those
fields are absent; keep the existing `args.balance_data` path only for
non-`prompt_mean` modes.

In `@scripts/run_deepseek_v4.py`:
- Around line 581-582: The TE precision config is being written with
save_to_temp_file() on the launcher node and then passed through execute_train()
as a local path, so worker nodes may not be able to read it. Update the logic
around the --te-precision-config-file handling in run_deepseek_v4.py so the YAML
is placed on shared storage or otherwise distributed to every trainer node
before launch. Use the existing args.extra_args / misc_args path-building flow
and the relevant execute_train() invocation to ensure all ranks can resolve the
same config file path.

ℹ️ Review info
⚙️ Run configuration

Configuration used: Organization UI

Review profile: CHILL

Plan: Pro Plus

Run ID: 7a3dd468-bdd0-4695-8ff9-899016093974

📥 Commits

Reviewing files that changed from the base of the PR and between 98588bd and 8d994ad.

📒 Files selected for processing (30)
  • .github/workflows/pr-test.yml
  • docs/ci/00-stage.md
  • docs/ci/01-label.md
  • docs/user-guide/cli-reference.md
  • docs/user-guide/customization.md
  • miles/backends/experimental/fsdp_utils/actor.py
  • miles/backends/megatron_utils/model.py
  • miles/backends/megatron_utils/model_provider.py
  • miles/backends/training_utils/cp_utils.py
  • miles/backends/training_utils/data.py
  • miles/backends/training_utils/log_utils.py
  • miles/backends/training_utils/loss.py
  • miles/backends/training_utils/loss_hub/losses.py
  • miles/backends/training_utils/loss_hub/math_utils.py
  • miles/ray/rollout/train_data_conversion.py
  • miles/utils/arguments.py
  • scripts/run_deepseek_v4.py
  • tests/ci/test/test_ci_register.py
  • tests/ci/test/test_labels.py
  • tests/ci/test/test_log_groups.py
  • tests/ci/test/test_run_suite.py
  • tests/fast/backends/training_utils/loss/test_loss_snapshot.py
  • tests/fast/backends/training_utils/test_loss_aggregation.py
  • tests/fast/backends/training_utils/test_ppo_ratio_numerics.py
  • tests/manual/__init__.py
  • tests/manual/models/__init__.py
  • tests/manual/models/deepseek_v4/__init__.py
  • tests/manual/models/deepseek_v4/test_v4_tilelang_indexer.py
  • tests/manual/models/deepseek_v4/test_v4_tilelang_sparse_mla.py
  • tools/convert_hf_to_fp8.py
✅ Files skipped from review due to trivial changes (4)
  • docs/ci/00-stage.md
  • tests/ci/test/test_run_suite.py
  • tools/convert_hf_to_fp8.py
  • docs/user-guide/cli-reference.md
🚧 Files skipped from review as they are similar to previous changes (8)
  • tests/fast/backends/training_utils/loss/test_loss_snapshot.py
  • miles/backends/experimental/fsdp_utils/actor.py
  • miles/backends/megatron_utils/model.py
  • miles/backends/training_utils/data.py
  • docs/user-guide/customization.md
  • miles/utils/arguments.py
  • miles/backends/training_utils/cp_utils.py
  • miles/backends/training_utils/loss_hub/losses.py

Comment on lines +239 to +254
+    log_dict = _build_train_log_dict(
+        log,
+        num_samples=num_samples,
+        num_tokens=num_tokens,
+        device=logits.device,
+        calculate_per_token_loss=args.calculate_per_token_loss,
+    )
+
     return (
         loss,
-        torch.tensor(num_tokens if args.calculate_per_token_loss else 1, device=logits.device),
-        {
-            "keys": list(log.keys()),
-            "values": torch.tensor(
-                [
-                    num_samples if not args.calculate_per_token_loss else num_tokens,
-                ]
-                + list(log.values()),
-                device=logits.device,
-            ),
-        },
+        (
+            num_tokens.to(device=logits.device)
+            if args.calculate_per_token_loss
+            else torch.tensor(1, device=logits.device)
+        ),
+        log_dict,


🎯 Functional Correctness | 🟠 Major | ⚡ Quick win

Use the true global active-token count for token-normalized outputs.

These paths now propagate num_tokens into both the returned normalizer and the per-metric log normalizers, but num_tokens is still computed as sum(clamp_min(loss_mask.sum(), 1)). A fully masked sample is reachable here (sample.remove_sample zeroes the mask in miles/ray/rollout/train_data_conversion.py, lines 53-55), and each such sample still adds 1 to the count. As a result, token_mean overcounts the denominator and under-reports losses whenever a batch contains fully masked samples.

Proposed fix
-    num_tokens = sum([torch.clamp_min(loss_mask.sum(), 1) for loss_mask in batch["loss_masks"]])
+    num_tokens = sum(loss_mask.sum() for loss_mask in batch["loss_masks"])
+    token_normalizer = torch.clamp_min(num_tokens, 1)
@@
     log_dict = _build_train_log_dict(
         log,
         num_samples=num_samples,
-        num_tokens=num_tokens,
+        num_tokens=token_normalizer,
         device=logits.device,
         calculate_per_token_loss=args.calculate_per_token_loss,
     )
@@
         loss,
         (
-            num_tokens.to(device=logits.device)
+            token_normalizer.to(device=logits.device)
             if args.calculate_per_token_loss
             else torch.tensor(1, device=logits.device)
         ),
         log_dict,
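
The overcount is easy to see with plain integers. The sketch below uses Python lists in place of tensors; the two helper names are hypothetical and only mirror the before and after lines of the proposed fix.

```python
def num_tokens_per_sample_clamp(loss_masks):
    # Before: each sample's count is clamped to at least 1, so a fully
    # masked sample still adds 1 to the denominator.
    return sum(max(sum(mask), 1) for mask in loss_masks)


def num_tokens_global_clamp(loss_masks):
    # After: sum the true active-token counts, then clamp once so an
    # all-masked batch cannot cause a division by zero.
    return max(sum(sum(mask) for mask in loss_masks), 1)


loss_masks = [[1, 1, 1], [0, 0, 0], [1, 0]]  # the second sample is fully masked
before = num_tokens_per_sample_clamp(loss_masks)  # 3 + 1 + 1 = 5
after = num_tokens_global_clamp(loss_masks)       # 3 + 0 + 1 = 4
```

Dividing the same masked loss sum by 5 instead of 4 is the under-reporting the comment refers to. The single clamp only matters when every sample in the batch is masked.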

Comment on lines +187 to +194
    if getattr(args, "loss_aggregation", "sample_mean") == "prompt_mean" and "prompt_group_indices" in data:
        partitions = _prompt_group_partitions(
            data["prompt_group_indices"],
            total_lengths,
            dp_size,
            balance_data=args.balance_data,
        )
    elif args.balance_data:


🎯 Functional Correctness | 🟠 Major | ⚡ Quick win

Fail fast if prompt_mean metadata is missing.

convert_samples_to_train_data() can return early via custom_convert_samples_to_train_data_func, so prompt_group_indices / prompt_mask_sums are not guaranteed to exist. This branch currently falls back to sample-wise partitioning instead of rejecting --loss-aggregation prompt_mean, which breaks the new prompt-group contract or defers the failure to a much less obvious downstream error.

Proposed fix
-    if getattr(args, "loss_aggregation", "sample_mean") == "prompt_mean" and "prompt_group_indices" in data:
+    if getattr(args, "loss_aggregation", "sample_mean") == "prompt_mean":
+        missing = [key for key in ("prompt_group_indices", "prompt_mask_sums") if key not in data]
+        if missing:
+            raise ValueError(
+                "--loss-aggregation prompt_mean requires train_data to include "
+                f"{', '.join(missing)}. Custom converters must propagate prompt-group metadata."
+            )
         partitions = _prompt_group_partitions(
             data["prompt_group_indices"],
             total_lengths,
             dp_size,
             balance_data=args.balance_data,
         )
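
The same check can be factored into a small helper and exercised in isolation. This is a sketch under assumptions: `require_prompt_group_metadata` is a hypothetical name, and `data` is assumed to be the plain dict returned by the converter.

```python
REQUIRED_PROMPT_KEYS = ("prompt_group_indices", "prompt_mask_sums")


def require_prompt_group_metadata(data, loss_aggregation):
    """Raise if prompt_mean is requested without prompt-group metadata."""
    if loss_aggregation != "prompt_mean":
        return  # Other modes do not need the prompt-group fields.
    missing = [key for key in REQUIRED_PROMPT_KEYS if key not in data]
    if missing:
        raise ValueError(
            "--loss-aggregation prompt_mean requires train_data to include "
            f"{', '.join(missing)}. Custom converters must propagate prompt-group metadata."
        )


# A custom converter that forgot the prompt-group fields is rejected up front.
try:
    require_prompt_group_metadata({"tokens": []}, "prompt_mean")
    rejected = False
except ValueError:
    rejected = True
```

Calling the helper before partitioning keeps the error close to its cause, so a custom converter that omits the fields fails here and not later inside the reducer.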

Comment on lines +581 to +582
if "--te-precision-config-file" not in args.extra_args:
misc_args += f"--te-precision-config-file " f"{U.save_to_temp_file(_DSV4_TE_PRECISION_CONFIG, 'yaml')} "


🩺 Stability & Availability | 🟠 Major | ⚡ Quick win

🧩 Analysis chain

🏁 Script executed:

#!/bin/bash
# Inspect whether execute_train propagates locally generated TE precision config files to all nodes.
rg -n -C4 '\bdef execute_train\b|te-precision-config-file|save_to_temp_file|exec_command_all_ray_node|train_args' \
  scripts/run_deepseek_v4.py miles/utils/external_utils/command_utils.py miles/utils/misc.py

Repository: EazyReal/miles

Length of output: 7783


🏁 Script executed:

#!/bin/bash
set -euo pipefail

# Read the relevant launcher code to see how train_args and runtime env are applied.
sed -n '97,230p' miles/utils/external_utils/command_utils.py

# Inspect the helper that creates the temp YAML.
sed -n '250,270p' miles/utils/external_utils/command_utils.py

# Inspect the caller around the injected TE config path.
sed -n '560,610p' scripts/run_deepseek_v4.py

Repository: EazyReal/miles

Length of output: 7958


scripts/run_deepseek_v4.py:581-582 — make the TE config reachable on every trainer node

save_to_temp_file() writes the YAML under /tmp on the submitter node, and execute_train() only forwards that path in train_args; nothing copies or re-materializes it on worker nodes. Multi-node fp8 runs can fail when non-head ranks try to open --te-precision-config-file. Move it to shared storage or broadcast it before launch.
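
One possible shape for the shared-storage option is sketched below. It rests on assumptions: `save_to_shared_file` is a hypothetical helper, not part of Miles' `command_utils`, and the caller is assumed to supply a directory that every trainer node mounts at the same path (for example an NFS mount).

```python
import hashlib
import tempfile
from pathlib import Path


def save_to_shared_file(text, shared_dir, suffix="yaml"):
    """Write text under shared_dir and return the absolute path.

    The file name is derived from a content hash, so repeated launches with
    the same config reuse one file and different configs never collide.
    """
    directory = Path(shared_dir)
    directory.mkdir(parents=True, exist_ok=True)
    digest = hashlib.sha256(text.encode("utf-8")).hexdigest()[:16]
    path = directory / f"te_precision_{digest}.{suffix}"
    path.write_text(text, encoding="utf-8")
    return str(path.resolve())


# Usage: a temporary directory stands in for the shared mount, and the YAML
# body is placeholder text, not the real TE precision config.
demo_dir = tempfile.mkdtemp()
config_path = save_to_shared_file("placeholder: true\n", demo_dir)
```

The returned path could then be passed to --te-precision-config-file in place of the `/tmp` path. This only helps if the directory really is visible on every node, which the launcher would still need to guarantee.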

