From 65db52e390ffc7f8dc6c9ed078bdbbfa6be35532 Mon Sep 17 00:00:00 2001 From: Ziyi <1034337098@qq.com> Date: Fri, 26 Jun 2026 05:43:42 +0000 Subject: [PATCH] fix(mtp): support multi-head MTP loss logging (mtp-num-layers > 1) slime's per-step MTP-loss logging called .item() on tracker["values"], which is a [num_mtp_layers] vector from Megatron's MTPLossLoggingHelper. This crashes for any multi-head MTP model (--mtp-num-layers > 1) under --enable-mtp-training with: RuntimeError: a Tensor with N elements cannot be converted to Scalar. Align with Megatron's track_mtp_metrics: keep the per-layer vector and log one scalar per layer plus a summed total. Co-Authored-By: Claude --- slime/backends/megatron_utils/model.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/slime/backends/megatron_utils/model.py b/slime/backends/megatron_utils/model.py index 1ad6cd7957..bf6bd676f2 100644 --- a/slime/backends/megatron_utils/model.py +++ b/slime/backends/megatron_utils/model.py @@ -854,15 +854,15 @@ def train( torch.distributed.all_reduce(values, group=tracker.get("reduce_group")) if tracker.get("avg_group") is not None: torch.distributed.all_reduce(values, group=tracker["avg_group"], op=torch.distributed.ReduceOp.AVG) - # here we assume only one mtp layer - mtp_losses = (tracker["values"] * mtp_loss_scale).item() + # Multi-head MTP: tracker["values"] is [num_mtp_layers]; aggregate below. + mtp_losses = tracker["values"] * mtp_loss_scale MTPLossLoggingHelper.clean_loss_in_tracker() # CI check: verify MTP loss is within expected bounds if args.ci_test: from slime.backends.megatron_utils.ci_utils import check_mtp_loss - check_mtp_loss(mtp_losses) + check_mtp_loss(mtp_losses.sum().item()) # per train step log. if ( @@ -879,7 +879,9 @@ def train( } log_dict[f"train/{role_tag}grad_norm"] = grad_norm if args.enable_mtp_training: - log_dict[f"train/{role_tag}mtp_loss"] = mtp_losses + for _i in range(mtp_losses.shape[0]): + log_dict[f"train/{role_tag}mtp_{_i + 1}_loss"] = mtp_losses[_i].item() + log_dict[f"train/{role_tag}mtp_loss"] = mtp_losses.sum().item() for param_group_id, param_group in enumerate(optimizer.param_groups): log_dict[f"train/{role_tag}lr-pg_{param_group_id}"] = opt_param_scheduler.get_lr(param_group)