diff --git a/slime/backends/megatron_utils/model.py b/slime/backends/megatron_utils/model.py index 1ad6cd7957..bf6bd676f2 100644 --- a/slime/backends/megatron_utils/model.py +++ b/slime/backends/megatron_utils/model.py @@ -854,15 +854,15 @@ def train( torch.distributed.all_reduce(values, group=tracker.get("reduce_group")) if tracker.get("avg_group") is not None: torch.distributed.all_reduce(values, group=tracker["avg_group"], op=torch.distributed.ReduceOp.AVG) - # here we assume only one mtp layer - mtp_losses = (tracker["values"] * mtp_loss_scale).item() + # Multi-head MTP: tracker["values"] is [num_mtp_layers]; aggregate below. + mtp_losses = tracker["values"] * mtp_loss_scale MTPLossLoggingHelper.clean_loss_in_tracker() # CI check: verify MTP loss is within expected bounds if args.ci_test: from slime.backends.megatron_utils.ci_utils import check_mtp_loss - check_mtp_loss(mtp_losses) + check_mtp_loss(mtp_losses.sum().item()) # per train step log. if ( @@ -879,7 +879,9 @@ def train( } log_dict[f"train/{role_tag}grad_norm"] = grad_norm if args.enable_mtp_training: - log_dict[f"train/{role_tag}mtp_loss"] = mtp_losses + for _i in range(mtp_losses.shape[0]): + log_dict[f"train/{role_tag}mtp_{_i + 1}_loss"] = mtp_losses[_i].item() + log_dict[f"train/{role_tag}mtp_loss"] = mtp_losses.sum().item() for param_group_id, param_group in enumerate(optimizer.param_groups): log_dict[f"train/{role_tag}lr-pg_{param_group_id}"] = opt_param_scheduler.get_lr(param_group)