From 10701527aca7ca201e29adccf087fa69b45d6b4f Mon Sep 17 00:00:00 2001 From: xiaotaoliu Date: Wed, 20 May 2026 14:55:09 +0800 Subject: [PATCH] fix: qwen3.6 moe mtp fused expert layout --- example/qwen3_5/test_mtp_logits.py | 14 ++++---- mbridge/models/qwen3_5/base_bridge.py | 34 ++++++++++++++++---- mbridge/models/qwen3_5/qwen3_5_safetensor.py | 5 +++ 3 files changed, 40 insertions(+), 13 deletions(-) diff --git a/example/qwen3_5/test_mtp_logits.py b/example/qwen3_5/test_mtp_logits.py index 55056ce..7dcd0fa 100644 --- a/example/qwen3_5/test_mtp_logits.py +++ b/example/qwen3_5/test_mtp_logits.py @@ -33,18 +33,18 @@ # Test result on Qwen3.5-35B-A3B model # ================================================================ # MTP HEAD 0 — logits[i] predicts token[i+2] -# Top-1 accuracy : 43.08% (28/65 valid positions) +# Top-1 accuracy : 66.15% (43/65 valid positions) # ================================================================ -# --- Spot-check: first 12 positions --- -# pos= 0 | pred= 314 ' of' | gt= 314 ' of' | ✓ +# --- Spot-check: first 12 positions --- +# pos= 0 | pred= 11 ',' | gt= 314 ' of' | ✗ # pos= 1 | pred= 279 ' the' | gt= 9338 ' France' | ✗ # pos= 2 | pred= 369 ' is' | gt= 369 ' is' | ✓ # pos= 3 | pred= 11751 ' Paris' | gt= 11751 ' Paris' | ✓ # pos= 4 | pred= 13 '.' | gt= 13 '.' | ✓ # pos= 5 | pred= 198 '\n' | gt= 561 ' The' | ✗ # pos= 6 | pred= 6511 ' capital' | gt= 242476 ' Eiff' | ✗ -# pos= 7 | pred= 684 'so' | gt= 300 'el' | ✗ +# pos= 7 | pred= 300 'el' | gt= 300 'el' | ✓ # pos= 8 | pred= 21262 ' Tower' | gt= 21262 ' Tower' | ✓ # pos= 9 | pred= 369 ' is' | gt= 557 ' was' | ✗ # pos= 10 | pred= 5617 ' built' | gt= 5617 ' built' | ✓ @@ -52,14 +52,14 @@ # ================================================================ # MAIN HEAD — logits[i] predicts token[i+1] -# Top-1 accuracy : 68.18% (45/66 valid positions) +# Top-1 accuracy : 69.70% (46/66 valid positions) # ================================================================ # ================================================================ # SUMMARY # ================================================================ -# [PASS] MTP head 0 (predicts token[i+2]): top-1 acc = 43.08% (28/65) -# [PASS] Main head (predicts token[i+1]): top-1 acc = 68.18% (45/66) +# [PASS] MTP head 0 (predicts token[i+2]): top-1 acc = 66.15% (43/65) +# [PASS] Main head (predicts token[i+1]): top-1 acc = 69.70% (46/66) # ================================================================ diff --git a/mbridge/models/qwen3_5/base_bridge.py b/mbridge/models/qwen3_5/base_bridge.py index f64084d..a3e4139 100644 --- a/mbridge/models/qwen3_5/base_bridge.py +++ b/mbridge/models/qwen3_5/base_bridge.py @@ -21,6 +21,8 @@ class Qwen3_5VlBaseBridge(VLMBridge): + mtp_fused_experts: bool = False + def _handle_hf_config(self): self.hf_text_config = getattr(self.hf_config, "text_config", self.hf_config) self.hf_vision_config = getattr(self.hf_config, "vision_config", self.hf_config) @@ -134,10 +136,13 @@ def _get_mcore_config_by_name(self, mcore_weights_name: str): return self.config def _get_safetensor_io(self, weights_path: str): - # TODO: MTP layers are not handled yet - return Qwen3_5SafeTensorIO( - self._get_actual_hf_path(weights_path), ignore_mtp=False + mtp_num_layers = getattr(self.config, "mtp_num_layers", None) + + sio = Qwen3_5SafeTensorIO( + self._get_actual_hf_path(weights_path), ignore_mtp=(mtp_num_layers is None) ) + self.mtp_fused_experts = sio.mtp_fused_experts + return sio def _weight_name_mapping_mcore_local_to_global( self, model: torch.nn.Module, consider_ep: bool = True @@ -305,6 +310,15 @@ def _weight_name_mapping_visual(self, name: str) -> list[str]: ], } + MTP_FUSED_EXPERTS_MAPPING = { + "language_model.mtp.layers.0.transformer_layer.mlp.experts.linear_fc1.weight{expert_index}": [ + "mtp.layers.0.mlp.experts.gate_up_proj", + ], + "language_model.mtp.layers.0.transformer_layer.mlp.experts.linear_fc2.weight{expert_index}": [ + "mtp.layers.0.mlp.experts.down_proj", + ], + } + def _convert_mtp_param(self, name: str) -> list[str]: assert self.config.mtp_num_layers == 1, "only support one mtp layer for now" @@ -319,10 +333,13 @@ def _convert_mtp_param(self, name: str) -> list[str]: # e.g. language_model.mtp.layers.0.transformer_layer.mlp.experts.linear_fc1.weight3 # -> key = "...linear_fc1.weight{expert_index}", expert_index = 3 if ".mlp.experts.linear_fc" in name: - # split off the numeric expert_index suffix after ".weight" prefix, expert_index_str = name.split(".weight", 1) expert_index = int(expert_index_str) key = prefix + ".weight{expert_index}" + + if self.mtp_fused_experts: + return self.MTP_FUSED_EXPERTS_MAPPING[key] + mapping_names = self._MTP_MAPPING[key] return [x.format(expert_index=expert_index) for x in mapping_names] @@ -405,7 +422,12 @@ def _weight_to_hf_format( return [hf_names[0]], [mcore_weights[: self.vocab_size]] - if "mtp" in mcore_weights_name: + is_mtp_fused_expert = ( + "mtp" in mcore_weights_name + and ".mlp.experts.linear_fc" in mcore_weights_name + and self.mtp_fused_experts + ) + if "mtp" in mcore_weights_name and not is_mtp_fused_expert: return [hf_names[0]], [mcore_weights] # moe @@ -591,7 +613,7 @@ def _weight_to_mcore_format( # moe if ".mlp.experts.linear_fc" in mcore_weights_name: - if "mtp" in mcore_weights_name: + if "mtp" in mcore_weights_name and not self.mtp_fused_experts: return hf_weights[0] # get export index local_experts_idx = int(mcore_weights_name.split(".weight")[-1]) diff --git a/mbridge/models/qwen3_5/qwen3_5_safetensor.py b/mbridge/models/qwen3_5/qwen3_5_safetensor.py index e33b5f0..f25655f 100644 --- a/mbridge/models/qwen3_5/qwen3_5_safetensor.py +++ b/mbridge/models/qwen3_5/qwen3_5_safetensor.py @@ -50,3 +50,8 @@ def __init__(self, hf_dir: str, ignore_mtp: bool = False): self.index[key] = filename self.hf_dir = hf_dir + + has_mtp = any(k.startswith("mtp.") for k in self.index) + self.mtp_fused_experts = ( + has_mtp and "mtp.layers.0.mlp.experts.gate_up_proj" in self.index + )