From 5edfdb5ae60163287ca645e3377555652c7f0717 Mon Sep 17 00:00:00 2001 From: guanyouhe Date: Fri, 12 Jun 2026 16:49:57 +0800 Subject: [PATCH] fix lora moe --- mbridge/peft/canonical_lora.py | 22 +++- mbridge/peft/lora.py | 209 +++++++++++++++++++++++++++++++-- mbridge/peft/lora_layers.py | 85 ++++++++++++++ 3 files changed, 307 insertions(+), 9 deletions(-) diff --git a/mbridge/peft/canonical_lora.py b/mbridge/peft/canonical_lora.py index cfd8cce..155c78c 100644 --- a/mbridge/peft/canonical_lora.py +++ b/mbridge/peft/canonical_lora.py @@ -7,7 +7,8 @@ import torch from mbridge.peft.adapter_wrapper import AdapterWrapper from mbridge.peft.base import PEFT -from mbridge.peft.lora_layers import LinearAdapter, LoRALinear, LoRATopKRouter +from mbridge.peft.lora_layers import (LinearAdapter, LoRAGroupedLinear, + LoRALinear, LoRATopKRouter) from mbridge.peft.module_matcher import ModuleMatcher from mbridge.peft.utils import (ParallelLinearAdapter, get_adapter_attributes_from_linear, @@ -294,7 +295,7 @@ def transform(self, m: nn.Module, name: Optional[str] = None, prefix: Optional[s """ # Skip already transformed modules - if isinstance(m, (LinearAdapter, LoRALinear, LoRALinearSplitQKV, LoRALinearSplitFC1UpGate, LoRATopKRouter)): + if isinstance(m, (LinearAdapter, LoRALinear, LoRAGroupedLinear, LoRALinearSplitQKV, LoRALinearSplitFC1UpGate, LoRATopKRouter)): return m if (ans := self.match(m, name, prefix)) is not None: @@ -341,6 +342,23 @@ def transform(self, m: nn.Module, name: Optional[str] = None, prefix: Optional[s base_linear_is_parallel=attrs.base_linear_is_parallel, ) + # Per-expert LoRA: each expert gets its own adapter + num_gemms = getattr(m, "num_gemms", 0) + if is_expert and num_gemms > 0: + logger.info( + f"Adding per-expert lora to: {full_name} " + f"(num_local_experts={num_gemms})" + ) + adapters = nn.ModuleList() + for i in range(num_gemms): + adapters.append( + ParallelLinearAdapter( + attrs.in_features, attrs.out_features, + **{**adapter_kwargs, "base_linear_name": f"{full_name}.expert{i}"}, + ) + ) + return LoRAGroupedLinear(m, adapters) + if name == "linear_fc1" and _should_treat_linear_fc1_as_unfused(full_name): logger.info(f"Adding lora to: {full_name} (treating unsupported canonical linear_fc1 as unfused)") adapter = ParallelLinearAdapter(attrs.in_features, attrs.out_features, **adapter_kwargs) diff --git a/mbridge/peft/lora.py b/mbridge/peft/lora.py index ef89b8e..1c5e97b 100644 --- a/mbridge/peft/lora.py +++ b/mbridge/peft/lora.py @@ -14,9 +14,10 @@ # These are only imported here (not in canonical_lora → avoids circular deps). from mbridge.peft.canonical_lora import (LoRALinearSplitFC1UpGate, LoRALinearSplitQKV) -from mbridge.peft.lora_layers import (LinearAdapter, LoRALinear, - LoRATopKRouter, TEFusedLoRALinear, - TELinearAdapter, patch_linear_module) +from mbridge.peft.lora_layers import (LinearAdapter, LoRAGroupedLinear, + LoRALinear, LoRATopKRouter, + TEFusedLoRALinear, TELinearAdapter, + patch_linear_module) from mbridge.peft.module_matcher import ModuleMatcher from mbridge.peft.utils import (ParallelLinearAdapter, get_adapter_attributes_from_linear, @@ -95,7 +96,7 @@ def transform(self, module: nn.Module, name: Optional[str] = None, prefix: Optio nn.Module: The modified module with LoRA applied, or the original module if not a target. """ # Skip already transformed modules - adapter_types = (LinearAdapter, LoRALinear, LoRATopKRouter) + adapter_types = (LinearAdapter, LoRAGroupedLinear, LoRALinear, LoRATopKRouter) adapter_types = adapter_types + (TELinearAdapter,) if isinstance(module, adapter_types): return module @@ -142,6 +143,38 @@ def transform(self, module: nn.Module, name: Optional[str] = None, prefix: Optio is_expert = is_expert_linear(full_name) attrs = get_adapter_attributes_from_linear(module, is_expert=is_expert) + # Per-expert LoRA for TEGroupedLinear: each expert gets its own adapter + num_gemms = getattr(module, "num_gemms", 0) + if is_expert and num_gemms > 0: + logging.info( + f"Adding per-expert lora to: {full_name} " + f"(num_local_experts={num_gemms})" + ) + adapters = nn.ModuleList() + for i in range(num_gemms): + adapters.append( + ParallelLinearAdapter( + attrs.in_features, + attrs.out_features, + self.dim, + base_linear_name=f"{full_name}.expert{i}", + activation="identity", + column_init_method=self.lora_A_init_method, + row_init_method=self.lora_B_init_method, + input_is_parallel=attrs.input_is_parallel, + dropout=self.dropout, + dropout_position=self.dropout_position, + model_parallel_config=getattr(module, "config", None), + alpha=self.alpha, + is_expert=is_expert, + a2a_experimental=self.a2a_experimental, + disable_tensor_parallel_comm=attrs.disable_tensor_parallel_comm, + disable_sequence_parallel_comm=attrs.disable_sequence_parallel_comm, + base_linear_is_parallel=attrs.base_linear_is_parallel, + ) + ) + return LoRAGroupedLinear(module, adapters) + enable_op_fuser = ( hasattr(module, "config") and getattr(module.config, "use_transformer_engine_op_fuser", False) @@ -335,6 +368,51 @@ def mcore_adapter_name_to_hf(mcore_name: str, bridge=None) -> str: """ import re + # --- Per-expert LoRA adapter path --- + # e.g. …experts.linear_fc1.adapter.3.linear_in.weight + m_expert = re.match( + r"(.+)\.adapter\.(\d+)\.linear_(in|out)\.weight$", + mcore_name, + ) + if m_expert is not None: + base_module_path = m_expert.group(1) + global_expert_id = int(m_expert.group(2)) + adapter_type = m_expert.group(3) # "in" or "out" + lora_suffix = _MCORE_TO_HF_LORA_SUFFIX[f"linear_{adapter_type}"] + + if bridge is not None: + # The bridge maps the BASE weight (without expert index) to fused + # HF keys. Query using the base name, then append expert index. + mcore_weight_name = f"{base_module_path}.weight" + try: + hf_names = bridge._weight_name_mapping_mcore_to_hf( + mcore_weight_name + ) + except (KeyError, NotImplementedError): + # Fallback for bridges without this mapping + hf_names = None + + if hf_names is not None: + if len(hf_names) == 1: + hf_name = hf_names[0] + if hf_name.endswith(".weight"): + hf_base = hf_name.rsplit(".", 1)[0] + else: + # Fused 3D format (no .weight suffix): keep full name + hf_base = hf_name + else: + hf_base = _combine_hf_module_names(hf_names) + return ( + f"base_model.model.{hf_base}." + f"{global_expert_id}.{lora_suffix}.weight" + ) + + # Fallback: use expert index in name directly + return ( + f"base_model.model.{base_module_path}." + f"{global_expert_id}.{lora_suffix}.weight" + ) + # --- CanonicalLoRA nested adapter path --- # e.g. …linear_qkv.adapter.adapter_q.linear_in.weight m_canonical = re.match( @@ -408,12 +486,21 @@ def infer_hf_target_modules(adapter_state: Dict[str, torch.Tensor]) -> list: Keys look like ``...layers.0.self_attn.qkv_proj.lora_A.weight``. The module name (``qkv_proj``) is 3 dots from the end. + Per-expert keys look like ``...gate_up_proj.3.lora_A.weight`` + where the module name is 4 dots from the end. """ modules = set() for key in adapter_state: parts = key.rsplit(".", 3) if len(parts) >= 4: - modules.add(parts[-3]) + candidate = parts[-3] + if candidate.isdigit(): + # Per-expert: module name is one level further up + sub_parts = parts[0].rsplit(".", 1) + if len(sub_parts) == 2: + modules.add(sub_parts[1]) + else: + modules.add(candidate) return sorted(modules) @@ -494,6 +581,73 @@ def gather_lora_state_dict(models, bridge=None) -> Dict[str, torch.Tensor]: adapter_state[hf_key] = lin_out_w.cpu() continue + # --- Per-expert LoRA (LoRAGroupedLinear) --- + if isinstance(module, LoRAGroupedLinear): + num_gemms = len(module.adapter) + ep_rank = parallel_state.get_expert_model_parallel_rank() + ep_size = parallel_state.get_expert_model_parallel_world_size() + num_local_experts = num_gemms + ep_group = parallel_state.get_expert_model_parallel_group() + + for i in range(num_local_experts): + adapter_i = module.adapter[i] + lin_in_w = adapter_i.linear_in.weight.data + lin_out_w = adapter_i.linear_out.weight.data + + if ep_size > 1: + # Gather this expert's adapter from all EP ranks + in_list = [ + torch.empty_like(lin_in_w) for _ in range(ep_size) + ] + out_list = [ + torch.empty_like(lin_out_w) for _ in range(ep_size) + ] + dist.all_gather( + in_list, lin_in_w.contiguous(), group=ep_group + ) + dist.all_gather( + out_list, lin_out_w.contiguous(), group=ep_group + ) + # Each EP rank's i-th local expert corresponds to + # global expert (ep_r * num_local_experts + i) + for ep_r in range(ep_size): + global_expert_id = ep_r * num_local_experts + i + expert_name = ( + f"{name}.adapter.{global_expert_id}" + f".linear_in.weight" + ) + hf_key = mcore_adapter_name_to_hf( + expert_name, bridge=bridge + ) + adapter_state[hf_key] = in_list[ep_r].cpu() + + expert_name = ( + f"{name}.adapter.{global_expert_id}" + f".linear_out.weight" + ) + hf_key = mcore_adapter_name_to_hf( + expert_name, bridge=bridge + ) + adapter_state[hf_key] = out_list[ep_r].cpu() + else: + global_expert_id = i + expert_name = ( + f"{name}.adapter.{global_expert_id}.linear_in.weight" + ) + hf_key = mcore_adapter_name_to_hf( + expert_name, bridge=bridge + ) + adapter_state[hf_key] = lin_in_w.cpu() + + expert_name = ( + f"{name}.adapter.{global_expert_id}.linear_out.weight" + ) + hf_key = mcore_adapter_name_to_hf( + expert_name, bridge=bridge + ) + adapter_state[hf_key] = lin_out_w.cpu() + continue + # --- LinearAdapter (nn.Linear with LoRA, for ViT/projector) --- if isinstance(module, LinearAdapter): lin_in_w = module.linear_in.weight.data.cpu() @@ -845,6 +999,29 @@ def transform(self, module: nn.Module, name: Optional[str] = None, prefix: Optio module.to_wrap.weight.data = base_weight + lora_delta return module + # --- Per-expert LoRA (LoRAGroupedLinear) --- + if isinstance(module, LoRAGroupedLinear): + num_gemms = len(module.adapter) + single_grouped = getattr(module.to_wrap, "single_grouped_weight", False) + for i in range(num_gemms): + if single_grouped: + w = module.to_wrap.weight + base_device = w.device + else: + w = getattr(module.to_wrap, f"weight{i}") + base_device = w.device + adapter_i = module.adapter[i] + alpha = adapter_i.alpha + dim = adapter_i.dim + lin_in_w = adapter_i.linear_in.weight.data.to(base_device) + lin_out_w = adapter_i.linear_out.weight.data.to(base_device) + delta_i = (alpha / dim) * (lin_out_w @ lin_in_w) + if single_grouped: + w.data[i] += delta_i + else: + w.data += delta_i + return module + # --- Standard LoRA --- if not isinstance(module, LoRALinear): return module @@ -923,7 +1100,10 @@ def lora_merged(models): module_swaps = [] linear_adapter_backups = [] - _ADAPTER_WRAPPER_TYPES = (LoRALinear, LoRALinearSplitQKV, LoRALinearSplitFC1UpGate, LoRATopKRouter) + _ADAPTER_WRAPPER_TYPES = ( + LoRALinear, LoRAGroupedLinear, LoRALinearSplitQKV, + LoRALinearSplitFC1UpGate, LoRATopKRouter, + ) for model_chunk in models: # Collect (parent, attr_name, lora_module) before modifying structure @@ -949,12 +1129,27 @@ def lora_merged(models): w = lora_module.to_wrap.weight _backup_tensor_to_cpu(w) weight_restore_list.append(w) + elif isinstance(lora_module, LoRAGroupedLinear): + # Per-expert LoRA: backup each expert weight separately + num_gemms = len(lora_module.adapter) + single_grouped = getattr( + lora_module.to_wrap, "single_grouped_weight", False + ) + if single_grouped: + w = lora_module.to_wrap.weight + _backup_tensor_to_cpu(w) + weight_restore_list.append(w) + else: + for i in range(num_gemms): + w = getattr(lora_module.to_wrap, f"weight{i}") + _backup_tensor_to_cpu(w) + weight_restore_list.append(w) elif hasattr(lora_module.to_wrap, "weight"): w = lora_module.to_wrap.weight _backup_tensor_to_cpu(w) weight_restore_list.append(w) else: - # TE Grouped Linear: backup all weights + # TE Grouped Linear (legacy): backup all weights for i in range(lora_module.to_wrap.num_gemms): w = getattr(lora_module.to_wrap, f"weight{i}") _backup_tensor_to_cpu(w) diff --git a/mbridge/peft/lora_layers.py b/mbridge/peft/lora_layers.py index 52cb668..965f2b7 100644 --- a/mbridge/peft/lora_layers.py +++ b/mbridge/peft/lora_layers.py @@ -48,6 +48,91 @@ def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> Tuple[torch.Ten return linear_output + adapter_output, bias +class LoRAGroupedLinear(AdapterWrapper): + """Per-expert LoRA wrapper for TEGroupedLinear. + + Each local expert has its own independent (A_i, B_i) adapter pair. + The adapter attribute is an nn.ModuleList of per-expert adapters. + """ + + def __getattr__(self, name: str): + try: + return super().__getattr__(name) + except AttributeError: + pass + # Delegate unknown attributes to the wrapped TEGroupedLinear + # (needed for single_grouped_weight, num_gemms, in_features, etc.) + to_wrap = self.__dict__.get("to_wrap") or self._modules.get("to_wrap") + if to_wrap is not None: + return getattr(to_wrap, name) + raise AttributeError( + f"'{type(self).__name__}' object has no attribute '{name}'" + ) + + def forward( + self, x: torch.Tensor, tokens_per_expert: "list[int]", *args: Any, **kwargs: Any + ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + linear_output = self.to_wrap(x, tokens_per_expert, *args, **kwargs) + if isinstance(linear_output, tuple): + if len(linear_output) == 2: + linear_output, bias = linear_output + else: + linear_output, bias = linear_output[0], None + else: + bias = None + + if not self._adapter_enabled: + return linear_output, bias + + adapter_outputs = [] + offset = 0 + for i, n_tokens in enumerate(tokens_per_expert): + if n_tokens == 0: + adapter_outputs.append( + x.new_zeros(0, linear_output.shape[-1]) + ) + continue + expert_input = x[offset:offset + n_tokens] + expert_adapter_out = self.adapter[i](expert_input.contiguous()) + adapter_outputs.append(expert_adapter_out) + offset += n_tokens + + if adapter_outputs: + adapter_cat = torch.cat(adapter_outputs, dim=0) + linear_output = linear_output + adapter_cat.reshape(linear_output.shape) + + return linear_output, bias + + def backward_dw(self): + """Delegate backward_dw to the wrapped TEGroupedLinear.""" + if hasattr(self.to_wrap, "backward_dw"): + self.to_wrap.backward_dw() + + def state_dict(self, destination=None, prefix="", keep_vars=False): + if destination is None: + destination = {} + self.to_wrap.state_dict( + destination=destination, prefix=prefix, keep_vars=keep_vars + ) + self.adapter.state_dict( + destination=destination, prefix=f"{prefix}adapter.", keep_vars=keep_vars + ) + return destination + + def sharded_state_dict(self, prefix="", sharded_offsets=(), metadata=None): + sharded_state_dict = {} + sharded_state_dict.update( + self.to_wrap.sharded_state_dict(prefix, sharded_offsets, metadata) + ) + for i, adapter_i in enumerate(self.adapter): + sharded_state_dict.update( + adapter_i.sharded_state_dict( + f"{prefix}adapter.{i}.", sharded_offsets, metadata + ) + ) + return sharded_state_dict + + class LoRATopKRouter(AdapterWrapper): """Adapter wrapper that applies LoRA to router gating logits."""