Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 61 additions & 5 deletions mbridge/peft/lora.py
Original file line number Diff line number Diff line change
Expand Up @@ -649,17 +649,19 @@ def gather_lora_state_dict(models, bridge=None) -> Dict[str, torch.Tensor]:
continue

# --- LinearAdapter (nn.Linear with LoRA, for ViT/projector) ---
# CPU init + average_gradients_across_tp_domain guarantees all
# TP ranks hold identical weights, so no broadcast is needed.
if isinstance(module, LinearAdapter):
lin_in_w = module.linear_in.weight.data.cpu()
lin_out_w = module.linear_out.weight.data.cpu()
lin_in_w = module.linear_in.weight.data
lin_out_w = module.linear_out.weight.data
hf_key = mcore_adapter_name_to_hf(
f"{name}.linear_in.weight", bridge=bridge,
)
adapter_state[hf_key] = lin_in_w
adapter_state[hf_key] = lin_in_w.cpu()
hf_key = mcore_adapter_name_to_hf(
f"{name}.linear_out.weight", bridge=bridge,
)
adapter_state[hf_key] = lin_out_w
adapter_state[hf_key] = lin_out_w.cpu()
continue

# --- Standard LoRA ---
Expand Down Expand Up @@ -1029,7 +1031,59 @@ def transform(self, module: nn.Module, name: Optional[str] = None, prefix: Optio
# Detect stride for strided ColumnParallelLinear (gated MLP)
stride = getattr(module.to_wrap, 'stride', 1)

if hasattr(module.to_wrap, "weight"):
# Detect interleaved TP layout for linear-attention in_proj.
# The in_proj ColumnParallelLinear has a non-standard TP layout
# where each component (wq, wk, wv, wz, wb, wa) is sharded
# independently and concatenated per-rank, rather than a simple
# contiguous row split. We detect this by checking if the full
# output dimension matches the in_proj formula:
# in_proj_dim = k_dim*2 + v_dim*2 + num_v_heads*2
config = getattr(module.to_wrap, 'config', None)
has_interleaved_in_proj = False
if config is not None and tp_size > 1:
lnkh = getattr(config, 'linear_num_key_heads', None)
lkhd = getattr(config, 'linear_key_head_dim', None)
lnvh = getattr(config, 'linear_num_value_heads', None)
lvhd = getattr(config, 'linear_value_head_dim', None)
if all(v is not None for v in (lnkh, lkhd, lnvh, lvhd)):
in_proj_dim = (
lnkh * lkhd * 2 + lnvh * lvhd * 2 + lnvh * 2
)
full_out = module.to_wrap.weight.shape[0] * tp_size
if full_out == in_proj_dim:
has_interleaved_in_proj = True

if has_interleaved_in_proj:
# Compute the full (un-sharded) delta using _compute_sub_delta,
# then split by component sizes and reassemble in the
# interleaved per-TP-rank layout.
base_device = module.to_wrap.weight.device
full_delta = self._compute_sub_delta(
module.adapter.linear_out,
module.adapter.linear_in,
module.adapter.alpha,
module.adapter.dim,
base_device,
)
# Component sizes (full, not per-TP-rank)
k_dim = config.linear_num_key_heads * config.linear_key_head_dim
v_dim = config.linear_num_value_heads * config.linear_value_head_dim
num_v_heads = config.linear_num_value_heads
split_shape = [k_dim, k_dim, v_dim, v_dim, num_v_heads, num_v_heads]
# Split full delta into 6 components: [wq, wk, wv, wz, wb, wa]
delta_parts = full_delta.split(split_shape, dim=0)
assert len(delta_parts) == 6, (
f"in_proj expected 6 components, got {len(delta_parts)}"
)
# Split each component by TP size and reassemble per-rank
per_rank_delta = torch.cat(
[part.chunk(tp_size, dim=0)[tp_rank] for part in delta_parts],
dim=0,
)
module.to_wrap.weight.data = (
module.to_wrap.weight.data + per_rank_delta.to(base_device)
)
elif hasattr(module.to_wrap, "weight"):
base_device = module.to_wrap.weight.device
merged_weight = self.merge(
module.to_wrap.weight,
Expand Down Expand Up @@ -1168,6 +1222,8 @@ def lora_merged(models):
continue
_backup_tensor_to_cpu(la_module.weight)
scale = la_module.scale
# CPU init + average_gradients_across_tp_domain guarantees all
# TP ranks hold identical weights, so no broadcast is needed.
delta = la_module.linear_out.weight.data @ la_module.linear_in.weight.data
la_module.weight.data.add_(scale * delta)
# Hide LoRA sub-modules so they don't appear in named_parameters()
Expand Down
57 changes: 53 additions & 4 deletions mbridge/peft/lora_layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,21 @@ def sharded_state_dict(self, prefix="", sharded_offsets=(), metadata=None):
class LoRATopKRouter(AdapterWrapper):
"""Adapter wrapper that applies LoRA to router gating logits."""

def __init__(self, to_wrap: nn.Module, adapter: nn.Module) -> None:
    """Wrap a router and tag the adapter's trainable parameters for sequence parallelism.

    Args:
        to_wrap: Base router module; its ``config.sequence_parallel`` flag is
            copied onto the adapter parameters.
        adapter: A single adapter module, or an ``nn.ModuleList`` of adapter
            modules, whose trainable parameters are tagged.
    """
    super().__init__(to_wrap, adapter)

    # Follows the base router weight convention (router.py line 84):
    #     setattr(self.weight, 'sequence_parallel', self.config.sequence_parallel)
    # The router adapter skips the SP gather for efficiency, so each TP rank
    # only routes its local seq/TP tokens. With SP enabled, the adapter's
    # ColumnParallelLinear weights therefore only see gradients from local
    # tokens. Tagging them lets finalize_model_grads SUM-allreduce across TP
    # ranks and recover the full gradient.
    sp_enabled = to_wrap.config.sequence_parallel

    # Normalise to an iterable of modules: a ModuleList is walked as-is, a
    # lone adapter is wrapped in a one-element tuple.
    adapter_modules = adapter if isinstance(adapter, nn.ModuleList) else (adapter,)

    for adapter_module in adapter_modules:
        trainable_params = (
            param for param in adapter_module.parameters() if param.requires_grad
        )
        for param in trainable_params:
            param.sequence_parallel = sp_enabled

def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any):
"""Forward pass that adds LoRA delta to router logits before routing."""
self.to_wrap._maintain_float32_expert_bias()
Expand Down Expand Up @@ -260,13 +275,30 @@ def _init_adapter(
out_features = obj.out_features
dtype = lora_dtype or obj.weight.dtype

obj.linear_in = nn.Linear(in_features, dim, bias=False, dtype=dtype, device=device)
obj.linear_out = nn.Linear(dim, out_features, bias=False, dtype=dtype, device=device)
# Initialize on CPU first to ensure TP-independent initialization,
# then move to target device. GPU initialization uses per-rank
# model-parallel seeds, producing different LoRA-A weights on
# different TP ranks, which leads to divergent training and
# incorrect checkpoint export when TP > 1.
obj.linear_in = nn.Linear(in_features, dim, bias=False, dtype=dtype, device="cpu")
obj.linear_out = nn.Linear(dim, out_features, bias=False, dtype=dtype, device="cpu")
if lora_A_init_method == "xavier":
torch.nn.init.xavier_uniform_(obj.linear_in.weight.data)
else:
nn.init.kaiming_uniform_(obj.linear_in.weight.data, a=math.sqrt(5))
obj.linear_out.weight.data.fill_(0)
# Move to target device after CPU initialization
obj.linear_in = obj.linear_in.to(device=device)
obj.linear_out = obj.linear_out.to(device=device)

# Mark LoRA parameters for TP gradient averaging.
# TELinearAdapter wraps TE layers that are replicated across TP
# ranks (not sharded like ColumnParallelLinear/RowParallelLinear).
# Without this, _allreduce_non_tensor_model_parallel_grads skips
# these parameters and their gradients diverge across TP ranks.
setattr(obj.linear_in.weight, "average_gradients_across_tp_domain", True)
setattr(obj.linear_out.weight, "average_gradients_across_tp_domain", True)

if dropout > 0.0:
obj.dropout = nn.Dropout(p=dropout)
else:
Expand Down Expand Up @@ -664,13 +696,30 @@ def _init_adapter(
out_features = obj.out_features
dtype = lora_dtype or obj.weight.dtype

obj.linear_in = nn.Linear(in_features, dim, bias=False, dtype=dtype, device=device)
obj.linear_out = nn.Linear(dim, out_features, bias=False, dtype=dtype, device=device)
# Initialize on CPU first to ensure TP-independent initialization,
# then move to target device. GPU initialization uses per-rank
# model-parallel seeds, producing different LoRA-A weights on
# different TP ranks, which leads to divergent training and
# incorrect checkpoint export when TP > 1.
obj.linear_in = nn.Linear(in_features, dim, bias=False, dtype=dtype, device="cpu")
obj.linear_out = nn.Linear(dim, out_features, bias=False, dtype=dtype, device="cpu")
if lora_A_init_method == "xavier":
torch.nn.init.xavier_uniform_(obj.linear_in.weight.data)
else:
nn.init.kaiming_uniform_(obj.linear_in.weight.data, a=math.sqrt(5))
obj.linear_out.weight.data.fill_(0)
# Move to target device after CPU initialization
obj.linear_in = obj.linear_in.to(device=device)
obj.linear_out = obj.linear_out.to(device=device)

# Mark LoRA parameters for TP gradient averaging.
# LinearAdapter wraps nn.Linear layers that are replicated across TP
# ranks (not sharded like ColumnParallelLinear/RowParallelLinear).
# Without this, _allreduce_non_tensor_model_parallel_grads skips
# these parameters and their gradients diverge across TP ranks.
setattr(obj.linear_in.weight, "average_gradients_across_tp_domain", True)
setattr(obj.linear_out.weight, "average_gradients_across_tp_domain", True)

if dropout > 0.0:
obj.dropout = nn.Dropout(p=dropout)
else:
Expand Down
Loading