Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 20 additions & 2 deletions mbridge/peft/canonical_lora.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@
import torch
from mbridge.peft.adapter_wrapper import AdapterWrapper
from mbridge.peft.base import PEFT
from mbridge.peft.lora_layers import LinearAdapter, LoRALinear, LoRATopKRouter
from mbridge.peft.lora_layers import (LinearAdapter, LoRAGroupedLinear,
LoRALinear, LoRATopKRouter)
from mbridge.peft.module_matcher import ModuleMatcher
from mbridge.peft.utils import (ParallelLinearAdapter,
get_adapter_attributes_from_linear,
Expand Down Expand Up @@ -294,7 +295,7 @@ def transform(self, m: nn.Module, name: Optional[str] = None, prefix: Optional[s
"""

# Skip already transformed modules
if isinstance(m, (LinearAdapter, LoRALinear, LoRALinearSplitQKV, LoRALinearSplitFC1UpGate, LoRATopKRouter)):
if isinstance(m, (LinearAdapter, LoRALinear, LoRAGroupedLinear, LoRALinearSplitQKV, LoRALinearSplitFC1UpGate, LoRATopKRouter)):
return m

if (ans := self.match(m, name, prefix)) is not None:
Expand Down Expand Up @@ -341,6 +342,23 @@ def transform(self, m: nn.Module, name: Optional[str] = None, prefix: Optional[s
base_linear_is_parallel=attrs.base_linear_is_parallel,
)

# Per-expert LoRA: each expert gets its own adapter
num_gemms = getattr(m, "num_gemms", 0)
if is_expert and num_gemms > 0:
logger.info(
f"Adding per-expert lora to: {full_name} "
f"(num_local_experts={num_gemms})"
)
adapters = nn.ModuleList()
for i in range(num_gemms):
adapters.append(
ParallelLinearAdapter(
attrs.in_features, attrs.out_features,
**{**adapter_kwargs, "base_linear_name": f"{full_name}.expert{i}"},
)
)
return LoRAGroupedLinear(m, adapters)

if name == "linear_fc1" and _should_treat_linear_fc1_as_unfused(full_name):
logger.info(f"Adding lora to: {full_name} (treating unsupported canonical linear_fc1 as unfused)")
adapter = ParallelLinearAdapter(attrs.in_features, attrs.out_features, **adapter_kwargs)
Expand Down
209 changes: 202 additions & 7 deletions mbridge/peft/lora.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,10 @@
# These are only imported here (not in canonical_lora → avoids circular deps).
from mbridge.peft.canonical_lora import (LoRALinearSplitFC1UpGate,
LoRALinearSplitQKV)
from mbridge.peft.lora_layers import (LinearAdapter, LoRALinear,
LoRATopKRouter, TEFusedLoRALinear,
TELinearAdapter, patch_linear_module)
from mbridge.peft.lora_layers import (LinearAdapter, LoRAGroupedLinear,
LoRALinear, LoRATopKRouter,
TEFusedLoRALinear, TELinearAdapter,
patch_linear_module)
from mbridge.peft.module_matcher import ModuleMatcher
from mbridge.peft.utils import (ParallelLinearAdapter,
get_adapter_attributes_from_linear,
Expand Down Expand Up @@ -95,7 +96,7 @@ def transform(self, module: nn.Module, name: Optional[str] = None, prefix: Optio
nn.Module: The modified module with LoRA applied, or the original module if not a target.
"""
# Skip already transformed modules
adapter_types = (LinearAdapter, LoRALinear, LoRATopKRouter)
adapter_types = (LinearAdapter, LoRAGroupedLinear, LoRALinear, LoRATopKRouter)
adapter_types = adapter_types + (TELinearAdapter,)
if isinstance(module, adapter_types):
return module
Expand Down Expand Up @@ -142,6 +143,38 @@ def transform(self, module: nn.Module, name: Optional[str] = None, prefix: Optio
is_expert = is_expert_linear(full_name)
attrs = get_adapter_attributes_from_linear(module, is_expert=is_expert)

# Per-expert LoRA for TEGroupedLinear: each expert gets its own adapter
num_gemms = getattr(module, "num_gemms", 0)
if is_expert and num_gemms > 0:
logging.info(
f"Adding per-expert lora to: {full_name} "
f"(num_local_experts={num_gemms})"
)
adapters = nn.ModuleList()
for i in range(num_gemms):
adapters.append(
ParallelLinearAdapter(
attrs.in_features,
attrs.out_features,
self.dim,
base_linear_name=f"{full_name}.expert{i}",
activation="identity",
column_init_method=self.lora_A_init_method,
row_init_method=self.lora_B_init_method,
input_is_parallel=attrs.input_is_parallel,
dropout=self.dropout,
dropout_position=self.dropout_position,
model_parallel_config=getattr(module, "config", None),
alpha=self.alpha,
is_expert=is_expert,
a2a_experimental=self.a2a_experimental,
disable_tensor_parallel_comm=attrs.disable_tensor_parallel_comm,
disable_sequence_parallel_comm=attrs.disable_sequence_parallel_comm,
base_linear_is_parallel=attrs.base_linear_is_parallel,
)
)
return LoRAGroupedLinear(module, adapters)

enable_op_fuser = (
hasattr(module, "config")
and getattr(module.config, "use_transformer_engine_op_fuser", False)
Expand Down Expand Up @@ -335,6 +368,51 @@ def mcore_adapter_name_to_hf(mcore_name: str, bridge=None) -> str:
"""
import re

# --- Per-expert LoRA adapter path ---
# e.g. …experts.linear_fc1.adapter.3.linear_in.weight
m_expert = re.match(
r"(.+)\.adapter\.(\d+)\.linear_(in|out)\.weight$",
mcore_name,
)
if m_expert is not None:
base_module_path = m_expert.group(1)
global_expert_id = int(m_expert.group(2))
adapter_type = m_expert.group(3) # "in" or "out"
lora_suffix = _MCORE_TO_HF_LORA_SUFFIX[f"linear_{adapter_type}"]

if bridge is not None:
# The bridge maps the BASE weight (without expert index) to fused
# HF keys. Query using the base name, then append expert index.
mcore_weight_name = f"{base_module_path}.weight"
try:
hf_names = bridge._weight_name_mapping_mcore_to_hf(
mcore_weight_name
)
except (KeyError, NotImplementedError):
# Fallback for bridges without this mapping
hf_names = None

if hf_names is not None:
if len(hf_names) == 1:
hf_name = hf_names[0]
if hf_name.endswith(".weight"):
hf_base = hf_name.rsplit(".", 1)[0]
else:
# Fused 3D format (no .weight suffix): keep full name
hf_base = hf_name
else:
hf_base = _combine_hf_module_names(hf_names)
return (
f"base_model.model.{hf_base}."
f"{global_expert_id}.{lora_suffix}.weight"
)

# Fallback: use expert index in name directly
return (
f"base_model.model.{base_module_path}."
f"{global_expert_id}.{lora_suffix}.weight"
)

# --- CanonicalLoRA nested adapter path ---
# e.g. …linear_qkv.adapter.adapter_q.linear_in.weight
m_canonical = re.match(
Expand Down Expand Up @@ -408,12 +486,21 @@ def infer_hf_target_modules(adapter_state: Dict[str, torch.Tensor]) -> list:

Keys look like ``...layers.0.self_attn.qkv_proj.lora_A.weight``.
The module name (``qkv_proj``) is 3 dots from the end.
Per-expert keys look like ``...gate_up_proj.3.lora_A.weight``
where the module name is 4 dots from the end.
"""
modules = set()
for key in adapter_state:
parts = key.rsplit(".", 3)
if len(parts) >= 4:
modules.add(parts[-3])
candidate = parts[-3]
if candidate.isdigit():
# Per-expert: module name is one level further up
sub_parts = parts[0].rsplit(".", 1)
if len(sub_parts) == 2:
modules.add(sub_parts[1])
else:
modules.add(candidate)
return sorted(modules)


Expand Down Expand Up @@ -494,6 +581,73 @@ def gather_lora_state_dict(models, bridge=None) -> Dict[str, torch.Tensor]:
adapter_state[hf_key] = lin_out_w.cpu()
continue

# --- Per-expert LoRA (LoRAGroupedLinear) ---
if isinstance(module, LoRAGroupedLinear):
num_gemms = len(module.adapter)
ep_rank = parallel_state.get_expert_model_parallel_rank()
ep_size = parallel_state.get_expert_model_parallel_world_size()
num_local_experts = num_gemms
ep_group = parallel_state.get_expert_model_parallel_group()

for i in range(num_local_experts):
adapter_i = module.adapter[i]
lin_in_w = adapter_i.linear_in.weight.data
lin_out_w = adapter_i.linear_out.weight.data

if ep_size > 1:
# Gather this expert's adapter from all EP ranks
in_list = [
torch.empty_like(lin_in_w) for _ in range(ep_size)
]
out_list = [
torch.empty_like(lin_out_w) for _ in range(ep_size)
]
dist.all_gather(
in_list, lin_in_w.contiguous(), group=ep_group
)
dist.all_gather(
out_list, lin_out_w.contiguous(), group=ep_group
)
# Each EP rank's i-th local expert corresponds to
# global expert (ep_r * num_local_experts + i)
for ep_r in range(ep_size):
global_expert_id = ep_r * num_local_experts + i
expert_name = (
f"{name}.adapter.{global_expert_id}"
f".linear_in.weight"
)
hf_key = mcore_adapter_name_to_hf(
expert_name, bridge=bridge
)
adapter_state[hf_key] = in_list[ep_r].cpu()

expert_name = (
f"{name}.adapter.{global_expert_id}"
f".linear_out.weight"
)
hf_key = mcore_adapter_name_to_hf(
expert_name, bridge=bridge
)
adapter_state[hf_key] = out_list[ep_r].cpu()
else:
global_expert_id = i
expert_name = (
f"{name}.adapter.{global_expert_id}.linear_in.weight"
)
hf_key = mcore_adapter_name_to_hf(
expert_name, bridge=bridge
)
adapter_state[hf_key] = lin_in_w.cpu()

expert_name = (
f"{name}.adapter.{global_expert_id}.linear_out.weight"
)
hf_key = mcore_adapter_name_to_hf(
expert_name, bridge=bridge
)
adapter_state[hf_key] = lin_out_w.cpu()
continue

# --- LinearAdapter (nn.Linear with LoRA, for ViT/projector) ---
if isinstance(module, LinearAdapter):
lin_in_w = module.linear_in.weight.data.cpu()
Expand Down Expand Up @@ -845,6 +999,29 @@ def transform(self, module: nn.Module, name: Optional[str] = None, prefix: Optio
module.to_wrap.weight.data = base_weight + lora_delta
return module

# --- Per-expert LoRA (LoRAGroupedLinear) ---
if isinstance(module, LoRAGroupedLinear):
num_gemms = len(module.adapter)
single_grouped = getattr(module.to_wrap, "single_grouped_weight", False)
for i in range(num_gemms):
if single_grouped:
w = module.to_wrap.weight
base_device = w.device
else:
w = getattr(module.to_wrap, f"weight{i}")
base_device = w.device
adapter_i = module.adapter[i]
alpha = adapter_i.alpha
dim = adapter_i.dim
lin_in_w = adapter_i.linear_in.weight.data.to(base_device)
lin_out_w = adapter_i.linear_out.weight.data.to(base_device)
delta_i = (alpha / dim) * (lin_out_w @ lin_in_w)
if single_grouped:
w.data[i] += delta_i
else:
w.data += delta_i
return module

# --- Standard LoRA ---
if not isinstance(module, LoRALinear):
return module
Expand Down Expand Up @@ -923,7 +1100,10 @@ def lora_merged(models):
module_swaps = []
linear_adapter_backups = []

_ADAPTER_WRAPPER_TYPES = (LoRALinear, LoRALinearSplitQKV, LoRALinearSplitFC1UpGate, LoRATopKRouter)
_ADAPTER_WRAPPER_TYPES = (
LoRALinear, LoRAGroupedLinear, LoRALinearSplitQKV,
LoRALinearSplitFC1UpGate, LoRATopKRouter,
)

for model_chunk in models:
# Collect (parent, attr_name, lora_module) before modifying structure
Expand All @@ -949,12 +1129,27 @@ def lora_merged(models):
w = lora_module.to_wrap.weight
_backup_tensor_to_cpu(w)
weight_restore_list.append(w)
elif isinstance(lora_module, LoRAGroupedLinear):
# Per-expert LoRA: backup each expert weight separately
num_gemms = len(lora_module.adapter)
single_grouped = getattr(
lora_module.to_wrap, "single_grouped_weight", False
)
if single_grouped:
w = lora_module.to_wrap.weight
_backup_tensor_to_cpu(w)
weight_restore_list.append(w)
else:
for i in range(num_gemms):
w = getattr(lora_module.to_wrap, f"weight{i}")
_backup_tensor_to_cpu(w)
weight_restore_list.append(w)
elif hasattr(lora_module.to_wrap, "weight"):
w = lora_module.to_wrap.weight
_backup_tensor_to_cpu(w)
weight_restore_list.append(w)
else:
# TE Grouped Linear: backup all weights
# TE Grouped Linear (legacy): backup all weights
for i in range(lora_module.to_wrap.num_gemms):
w = getattr(lora_module.to_wrap, f"weight{i}")
_backup_tensor_to_cpu(w)
Expand Down
Loading
Loading