diff --git a/mbridge/peft/lora.py b/mbridge/peft/lora.py index 1c5e97b..a8cc51f 100644 --- a/mbridge/peft/lora.py +++ b/mbridge/peft/lora.py @@ -649,17 +649,19 @@ def gather_lora_state_dict(models, bridge=None) -> Dict[str, torch.Tensor]: continue # --- LinearAdapter (nn.Linear with LoRA, for ViT/projector) --- + # CPU init + average_gradients_across_tp_domain guarantees all + # TP ranks hold identical weights, so no broadcast is needed. if isinstance(module, LinearAdapter): - lin_in_w = module.linear_in.weight.data.cpu() - lin_out_w = module.linear_out.weight.data.cpu() + lin_in_w = module.linear_in.weight.data + lin_out_w = module.linear_out.weight.data hf_key = mcore_adapter_name_to_hf( f"{name}.linear_in.weight", bridge=bridge, ) - adapter_state[hf_key] = lin_in_w + adapter_state[hf_key] = lin_in_w.cpu() hf_key = mcore_adapter_name_to_hf( f"{name}.linear_out.weight", bridge=bridge, ) - adapter_state[hf_key] = lin_out_w + adapter_state[hf_key] = lin_out_w.cpu() continue # --- Standard LoRA --- @@ -1029,7 +1031,59 @@ def transform(self, module: nn.Module, name: Optional[str] = None, prefix: Optio # Detect stride for strided ColumnParallelLinear (gated MLP) stride = getattr(module.to_wrap, 'stride', 1) - if hasattr(module.to_wrap, "weight"): + # Detect interleaved TP layout for linear-attention in_proj. + # The in_proj ColumnParallelLinear has a non-standard TP layout + # where each component (wq, wk, wv, wz, wb, wa) is sharded + # independently and concatenated per-rank, rather than a simple + # contiguous row split. We detect this by checking if the full + # output dimension matches the in_proj formula: + # in_proj_dim = k_dim*2 + v_dim*2 + num_v_heads*2 + config = getattr(module.to_wrap, 'config', None) + has_interleaved_in_proj = False + if config is not None and tp_size > 1: + lnkh = getattr(config, 'linear_num_key_heads', None) + lkhd = getattr(config, 'linear_key_head_dim', None) + lnvh = getattr(config, 'linear_num_value_heads', None) + lvhd = getattr(config, 'linear_value_head_dim', None) + if all(v is not None for v in (lnkh, lkhd, lnvh, lvhd)): + in_proj_dim = ( + lnkh * lkhd * 2 + lnvh * lvhd * 2 + lnvh * 2 + ) + full_out = module.to_wrap.weight.shape[0] * tp_size + if full_out == in_proj_dim: + has_interleaved_in_proj = True + + if has_interleaved_in_proj: + # Compute the full (un-sharded) delta using _compute_sub_delta, + # then split by component sizes and reassemble in the + # interleaved per-TP-rank layout. + base_device = module.to_wrap.weight.device + full_delta = self._compute_sub_delta( + module.adapter.linear_out, + module.adapter.linear_in, + module.adapter.alpha, + module.adapter.dim, + base_device, + ) + # Component sizes (full, not per-TP-rank) + k_dim = config.linear_num_key_heads * config.linear_key_head_dim + v_dim = config.linear_num_value_heads * config.linear_value_head_dim + num_v_heads = config.linear_num_value_heads + split_shape = [k_dim, k_dim, v_dim, v_dim, num_v_heads, num_v_heads] + # Split full delta into 6 components: [wq, wk, wv, wz, wb, wa] + delta_parts = full_delta.split(split_shape, dim=0) + assert len(delta_parts) == 6, ( + f"in_proj expected 6 components, got {len(delta_parts)}" + ) + # Split each component by TP size and reassemble per-rank + per_rank_delta = torch.cat( + [part.chunk(tp_size, dim=0)[tp_rank] for part in delta_parts], + dim=0, + ) + module.to_wrap.weight.data = ( + module.to_wrap.weight.data + per_rank_delta.to(base_device) + ) + elif hasattr(module.to_wrap, "weight"): base_device = module.to_wrap.weight.device merged_weight = self.merge( module.to_wrap.weight, @@ -1168,6 +1222,8 @@ def lora_merged(models): continue _backup_tensor_to_cpu(la_module.weight) scale = la_module.scale + # CPU init + average_gradients_across_tp_domain guarantees all + # TP ranks hold identical weights, so no broadcast is needed. delta = la_module.linear_out.weight.data @ la_module.linear_in.weight.data la_module.weight.data.add_(scale * delta) # Hide LoRA sub-modules so they don't appear in named_parameters() diff --git a/mbridge/peft/lora_layers.py b/mbridge/peft/lora_layers.py index 965f2b7..930e2f2 100644 --- a/mbridge/peft/lora_layers.py +++ b/mbridge/peft/lora_layers.py @@ -136,6 +136,21 @@ def sharded_state_dict(self, prefix="", sharded_offsets=(), metadata=None): class LoRATopKRouter(AdapterWrapper): """Adapter wrapper that applies LoRA to router gating logits.""" + def __init__(self, to_wrap: nn.Module, adapter: nn.Module) -> None: + super().__init__(to_wrap, adapter) + # Mirror base router weight behavior (router.py line 84): + # setattr(self.weight, 'sequence_parallel', self.config.sequence_parallel) + # The router adapter skips the SP gather for efficiency (each TP rank + # only routes its local seq/TP tokens). When SP is enabled, the + # adapter's ColumnParallelLinear weights receive gradients from local + # tokens only. Mark them so finalize_model_grads will SUM-allreduce + # across TP ranks, producing the correct full gradient. + seq_parallel = to_wrap.config.sequence_parallel + for sub in (adapter,) if not isinstance(adapter, nn.ModuleList) else adapter: + for p in sub.parameters(): + if p.requires_grad: + setattr(p, "sequence_parallel", seq_parallel) + def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any): """Forward pass that adds LoRA delta to router logits before routing.""" self.to_wrap._maintain_float32_expert_bias() @@ -260,13 +275,30 @@ def _init_adapter( out_features = obj.out_features dtype = lora_dtype or obj.weight.dtype - obj.linear_in = nn.Linear(in_features, dim, bias=False, dtype=dtype, device=device) - obj.linear_out = nn.Linear(dim, out_features, bias=False, dtype=dtype, device=device) + # Initialize on CPU first to ensure TP-independent initialization, + # then move to target device. GPU initialization uses per-rank + # model-parallel seeds, producing different LoRA-A weights on + # different TP ranks, which leads to divergent training and + # incorrect checkpoint export when TP > 1. + obj.linear_in = nn.Linear(in_features, dim, bias=False, dtype=dtype, device="cpu") + obj.linear_out = nn.Linear(dim, out_features, bias=False, dtype=dtype, device="cpu") if lora_A_init_method == "xavier": torch.nn.init.xavier_uniform_(obj.linear_in.weight.data) else: nn.init.kaiming_uniform_(obj.linear_in.weight.data, a=math.sqrt(5)) obj.linear_out.weight.data.fill_(0) + # Move to target device after CPU initialization + obj.linear_in = obj.linear_in.to(device=device) + obj.linear_out = obj.linear_out.to(device=device) + + # Mark LoRA parameters for TP gradient averaging. + # TELinearAdapter wraps TE layers that are replicated across TP + # ranks (not sharded like ColumnParallelLinear/RowParallelLinear). + # Without this, _allreduce_non_tensor_model_parallel_grads skips + # these parameters and their gradients diverge across TP ranks. + setattr(obj.linear_in.weight, "average_gradients_across_tp_domain", True) + setattr(obj.linear_out.weight, "average_gradients_across_tp_domain", True) + if dropout > 0.0: obj.dropout = nn.Dropout(p=dropout) else: @@ -664,13 +696,30 @@ def _init_adapter( out_features = obj.out_features dtype = lora_dtype or obj.weight.dtype - obj.linear_in = nn.Linear(in_features, dim, bias=False, dtype=dtype, device=device) - obj.linear_out = nn.Linear(dim, out_features, bias=False, dtype=dtype, device=device) + # Initialize on CPU first to ensure TP-independent initialization, + # then move to target device. GPU initialization uses per-rank + # model-parallel seeds, producing different LoRA-A weights on + # different TP ranks, which leads to divergent training and + # incorrect checkpoint export when TP > 1. + obj.linear_in = nn.Linear(in_features, dim, bias=False, dtype=dtype, device="cpu") + obj.linear_out = nn.Linear(dim, out_features, bias=False, dtype=dtype, device="cpu") if lora_A_init_method == "xavier": torch.nn.init.xavier_uniform_(obj.linear_in.weight.data) else: nn.init.kaiming_uniform_(obj.linear_in.weight.data, a=math.sqrt(5)) obj.linear_out.weight.data.fill_(0) + # Move to target device after CPU initialization + obj.linear_in = obj.linear_in.to(device=device) + obj.linear_out = obj.linear_out.to(device=device) + + # Mark LoRA parameters for TP gradient averaging. + # LinearAdapter wraps nn.Linear layers that are replicated across TP + # ranks (not sharded like ColumnParallelLinear/RowParallelLinear). + # Without this, _allreduce_non_tensor_model_parallel_grads skips + # these parameters and their gradients diverge across TP ranks. + setattr(obj.linear_in.weight, "average_gradients_across_tp_domain", True) + setattr(obj.linear_out.weight, "average_gradients_across_tp_domain", True) + if dropout > 0.0: obj.dropout = nn.Dropout(p=dropout) else: