Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 29 additions & 0 deletions verl/utils/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import os
import re
import warnings
from collections.abc import Iterable
from dataclasses import dataclass
from typing import Optional

Expand Down Expand Up @@ -262,6 +263,34 @@ def convert_weight_keys(state_dict: dict[str, torch.Tensor], model: PreTrainedMo
return original_weights


def iter_vllm_compatible_moe_params(name: str, tensor: torch.Tensor) -> Iterable[tuple[str, torch.Tensor]]:
"""Expand Transformers 5 packed MoE expert tensors to vLLM checkpoint keys.

Transformers 5 stores Qwen-style MoE experts as packed 3D parameters:
``mlp.experts.gate_up_proj`` with shape
``[num_experts, 2 * intermediate_size, hidden_size]`` and
``mlp.experts.down_proj`` with shape
``[num_experts, hidden_size, intermediate_size]``. vLLM's Qwen MoE reload
path still accepts the original per-expert checkpoint keys during live
weight sync, so stream those keys without materializing a full dict.
"""
if name.endswith(".mlp.experts.gate_up_proj") and tensor.dim() == 3:
gate, up = tensor.chunk(2, dim=1)
base = name.removesuffix(".gate_up_proj")
for expert_id in range(tensor.size(0)):
yield f"{base}.{expert_id}.gate_proj.weight", gate[expert_id].contiguous()
yield f"{base}.{expert_id}.up_proj.weight", up[expert_id].contiguous()
return

if name.endswith(".mlp.experts.down_proj") and tensor.dim() == 3:
base = name.removesuffix(".down_proj")
for expert_id in range(tensor.size(0)):
yield f"{base}.{expert_id}.down_proj.weight", tensor[expert_id].contiguous()
return

yield name, tensor
Comment on lines +266 to +291

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

Robustness and Performance Improvements

  1. Robustness of Parameter Name Matching: The current implementation checks name.endswith(".mlp.experts.gate_up_proj"). Depending on how the model is loaded or wrapped, the parameter name in the state dict might have a .weight suffix (e.g., ...mlp.experts.gate_up_proj.weight). Stripping .weight first makes the matching much more robust.
  2. Redundant .contiguous() Calls: Since tensor is gathered from FSDP or is a model parameter, it is contiguous. Slicing it along the first dimension (e.g., gate[expert_id]) produces a slice that is also contiguous because the remaining dimensions have contiguous strides. Therefore, calling .contiguous() is redundant and can be omitted to avoid unnecessary overhead.
def iter_vllm_compatible_moe_params(name: str, tensor: torch.Tensor) -> Iterable[tuple[str, torch.Tensor]]:
    """Expand Transformers 5 packed MoE expert tensors to vLLM checkpoint keys.

    Transformers 5 stores Qwen-style MoE experts as packed 3D parameters:
    mlp.experts.gate_up_proj with shape
    [num_experts, 2 * intermediate_size, hidden_size] and
    mlp.experts.down_proj with shape
    [num_experts, hidden_size, intermediate_size]. vLLM's Qwen MoE reload
    path still accepts the original per-expert checkpoint keys during live
    weight sync, so stream those keys without materializing a full dict.
    """
    name_stripped = name.removesuffix(".weight")

    if name_stripped.endswith(".mlp.experts.gate_up_proj") and tensor.dim() == 3:
        gate, up = tensor.chunk(2, dim=1)
        base = name_stripped.removesuffix(".gate_up_proj")
        for expert_id in range(tensor.size(0)):
            yield f"{base}.{expert_id}.gate_proj.weight", gate[expert_id]
            yield f"{base}.{expert_id}.up_proj.weight", up[expert_id]
        return

    if name_stripped.endswith(".mlp.experts.down_proj") and tensor.dim() == 3:
        base = name_stripped.removesuffix(".down_proj")
        for expert_id in range(tensor.size(0)):
            yield f"{base}.{expert_id}.down_proj.weight", tensor[expert_id]
        return

    yield name, tensor



def check_exclude_modules(config, key: str) -> bool:
"""
A helper method to check if the passed module's key name matches any of the exclude modules in the adapter_config.
Expand Down
21 changes: 11 additions & 10 deletions verl/workers/engine/fsdp/transformer_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@
offload_fsdp_optimizer,
replace_lora_wrapper,
)
from verl.utils.model import convert_weight_keys, extract_multi_modal_inputs
from verl.utils.model import convert_weight_keys, extract_multi_modal_inputs, iter_vllm_compatible_moe_params
from verl.utils.py_functional import convert_to_regular_types
from verl.utils.torch_functional import logprobs_from_logits
from verl.utils.ulysses import (
Expand Down Expand Up @@ -858,15 +858,16 @@ def get_per_tensor_param(self, layered_summon=False, base_sync_done=False, **kwa
else:
device = get_device_id() # used when fsdp2 set cpu_offload_policy
# TODO: cast fp32 to bf16 to reduce weight sync overhead, need more fine-grained control, e.g MoE gate
per_tensor_param = (
(
name,
param.to(device, non_blocking=True).full_tensor().to(torch.bfloat16, non_blocking=True)
if isinstance(param, DTensor)
else param,
)
for name, param in params.items()
)
def param_generator():
for name, param in params.items():
tensor = (
param.to(device, non_blocking=True).full_tensor().to(torch.bfloat16, non_blocking=True)
if isinstance(param, DTensor)
else param
)
yield from iter_vllm_compatible_moe_params(name, tensor)

per_tensor_param = param_generator()

if self._qat_enabled:
from verl.utils.qat.quantizer import QATQuantizer
Expand Down