From 1a0827f314c9274ff4833521e8ba36a2c3204094 Mon Sep 17 00:00:00 2001 From: "LU, Xinbo" Date: Sat, 27 Jun 2026 00:38:12 +0800 Subject: [PATCH] Fix Qwen3 MoE FSDP weight sync for vLLM rollout Transformers 5 stores Qwen-style MoE expert weights as packed 3D `mlp.experts.gate_up_proj` and `mlp.experts.down_proj` tensors. During live FSDP-to-vLLM rollout weight sync, those packed keys were sent directly, but vLLM's Qwen3 MoE reload path expects the original per-expert checkpoint keys. Expand packed MoE expert tensors during FSDP parameter streaming so vLLM receives per-expert `gate_proj`, `up_proj`, and `down_proj` weights. Dense models and non-packed tensors continue to pass through unchanged. --- verl/utils/model.py | 29 ++++++++++++++++++++ verl/workers/engine/fsdp/transformer_impl.py | 21 +++++++------- 2 files changed, 40 insertions(+), 10 deletions(-) diff --git a/verl/utils/model.py b/verl/utils/model.py index 9b6a430f7ab..65637954b19 100644 --- a/verl/utils/model.py +++ b/verl/utils/model.py @@ -19,6 +19,7 @@ import os import re import warnings +from collections.abc import Iterable from dataclasses import dataclass from typing import Optional @@ -262,6 +263,34 @@ def convert_weight_keys(state_dict: dict[str, torch.Tensor], model: PreTrainedMo return original_weights +def iter_vllm_compatible_moe_params(name: str, tensor: torch.Tensor) -> Iterable[tuple[str, torch.Tensor]]: + """Expand Transformers 5 packed MoE expert tensors to vLLM checkpoint keys. + + Transformers 5 stores Qwen-style MoE experts as packed 3D parameters: + ``mlp.experts.gate_up_proj`` with shape + ``[num_experts, 2 * intermediate_size, hidden_size]`` and + ``mlp.experts.down_proj`` with shape + ``[num_experts, hidden_size, intermediate_size]``. vLLM's Qwen MoE reload + path still accepts the original per-expert checkpoint keys during live + weight sync, so stream those keys without materializing a full dict. + """ + if name.endswith(".mlp.experts.gate_up_proj") and tensor.dim() == 3: + gate, up = tensor.chunk(2, dim=1) + base = name.removesuffix(".gate_up_proj") + for expert_id in range(tensor.size(0)): + yield f"{base}.{expert_id}.gate_proj.weight", gate[expert_id].contiguous() + yield f"{base}.{expert_id}.up_proj.weight", up[expert_id].contiguous() + return + + if name.endswith(".mlp.experts.down_proj") and tensor.dim() == 3: + base = name.removesuffix(".down_proj") + for expert_id in range(tensor.size(0)): + yield f"{base}.{expert_id}.down_proj.weight", tensor[expert_id].contiguous() + return + + yield name, tensor + + def check_exclude_modules(config, key: str) -> bool: """ A helper method to check if the passed module's key name matches any of the exclude modules in the adapter_config. diff --git a/verl/workers/engine/fsdp/transformer_impl.py b/verl/workers/engine/fsdp/transformer_impl.py index bc01062b826..b52ff1575ea 100644 --- a/verl/workers/engine/fsdp/transformer_impl.py +++ b/verl/workers/engine/fsdp/transformer_impl.py @@ -59,7 +59,7 @@ offload_fsdp_optimizer, replace_lora_wrapper, ) -from verl.utils.model import convert_weight_keys, extract_multi_modal_inputs +from verl.utils.model import convert_weight_keys, extract_multi_modal_inputs, iter_vllm_compatible_moe_params from verl.utils.py_functional import convert_to_regular_types from verl.utils.torch_functional import logprobs_from_logits from verl.utils.ulysses import ( @@ -858,15 +858,16 @@ def get_per_tensor_param(self, layered_summon=False, base_sync_done=False, **kwa else: device = get_device_id() # used when fsdp2 set cpu_offload_policy # TODO: cast fp32 to bf16 to reduce weight sync overhead, need more fine-grained control, e.g MoE gate - per_tensor_param = ( - ( - name, - param.to(device, non_blocking=True).full_tensor().to(torch.bfloat16, non_blocking=True) - if isinstance(param, DTensor) - else param, - ) - for name, param in params.items() - ) + def param_generator(): + for name, param in params.items(): + tensor = ( + param.to(device, non_blocking=True).full_tensor().to(torch.bfloat16, non_blocking=True) + if isinstance(param, DTensor) + else param + ) + yield from iter_vllm_compatible_moe_params(name, tensor) + + per_tensor_param = param_generator() if self._qat_enabled: from verl.utils.qat.quantizer import QATQuantizer