diff --git a/mbridge/peft/__init__.py b/mbridge/peft/__init__.py
new file mode 100644
index 0000000..d2ee69a
--- /dev/null
+++ b/mbridge/peft/__init__.py
@@ -0,0 +1,7 @@
+# Adapted from NVIDIA Megatron-Bridge
+
+from mbridge.peft.base import PEFT
+from mbridge.peft.canonical_lora import CanonicalLoRA
+from mbridge.peft.lora import (LoRA, LoRAMerge, gather_lora_state_dict,
+                               infer_hf_target_modules, lora_merged,
+                               mcore_adapter_name_to_hf)
diff --git a/mbridge/peft/adapter_wrapper.py b/mbridge/peft/adapter_wrapper.py
new file mode 100644
index 0000000..8493792
--- /dev/null
+++ b/mbridge/peft/adapter_wrapper.py
@@ -0,0 +1,203 @@
+# Adapted from NVIDIA Megatron-Bridge
+
+from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple
+
+import torch
+import torch.nn as nn
+from mbridge.peft.utils import ParallelLinearAdapter
+
+if TYPE_CHECKING:
+    from megatron.core.dist_checkpointing.mapping import ShardedStateDict
+
+
+def _compute_mamba_dim_info(wrapped_module: nn.Module) -> Dict[str, int]:
+    """Compute Mamba dimension information from a wrapped module's config.
+
+    This follows the same logic as mamba_mixer.py to derive local tensor parallel
+    dimensions from the TransformerConfig.
+
+    Args:
+        wrapped_module: The wrapped module (typically a linear projection in a Mamba layer).
+
+    Returns:
+        Dictionary containing d_inner_local_tp, ngroups_local_tp, d_state, and nheads_local_tp.
+    """
+    config = wrapped_module.config
+
+    # Get base dimensions from config
+    d_state = config.mamba_state_dim
+    headdim = config.mamba_head_dim
+    ngroups = config.mamba_num_groups
+
+    # Compute nheads and d_inner
+    if config.mamba_num_heads is not None:
+        nheads = config.mamba_num_heads
+        d_inner = nheads * headdim
+    else:
+        d_inner = wrapped_module.d_inner
+        nheads = d_inner // headdim
+
+    # Get tensor parallel size and compute local dimensions
+    tp_size = wrapped_module.tp_size
+
+    return {
+        "d_inner_local_tp": d_inner // tp_size,
+        "ngroups_local_tp": ngroups // tp_size,
+        "d_state": d_state,
+        "nheads_local_tp": nheads // tp_size,
+    }
+
+
+class AdapterWrapper(nn.Module):
+    """Abstract base class for wrapping modules with adapters in Parameter-Efficient Fine-Tuning (PEFT).
+
+    This class wraps a module and its associated adapter, providing methods for
+    managing the state dictionaries of both the main module and the adapter. It does not
+    implement the forward method, which must be implemented by concrete subclasses.
+
+    Attributes:
+        to_wrap (nn.Module): The main module to be wrapped.
+        adapter (nn.Module): The adapter module to be applied.
+
+    Note:
+        This class is abstract and cannot be instantiated directly. Subclasses must
+        implement the forward method.
+
+    Example:
+        class LoRALinear(AdapterWrapper):
+            def __init__(self, to_wrap, adapter):
+                super().__init__(to_wrap, adapter)
+
+            def forward(self, x):
+                return self.to_wrap(x) + self.adapter(x)
+
+        main_module = nn.Linear(100, 100)
+        adapter = nn.Linear(100, 100)
+        parallel_adapter = LoRALinear(main_module, adapter)
+    """
+
+    def __init__(self, to_wrap: nn.Module, adapter: nn.Module) -> None:
+        """Initialize the AdapterWrapper with a main module and adapter.
+
+        Args:
+            to_wrap: The main module to be wrapped.
+            adapter: The adapter module to be applied.
+        """
+        super(AdapterWrapper, self).__init__()
+        self.to_wrap = to_wrap
+        self.adapter = adapter
+        self._adapter_enabled = True
+
+    def enable_adapter_layers(self) -> None:
+        """Enable the adapter layers, allowing them to contribute to the forward pass output."""
+        self._adapter_enabled = True
+
+    def disable_adapter_layers(self) -> None:
+        """Disable the adapter layers, making the forward pass return only the base module output."""
+        self._adapter_enabled = False
+
+    def base_linear_forward(
+        self, x: torch.Tensor, *args: Any, **kwargs: Any
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], torch.Tensor]:
+        """Run the forward method of the linear module `to_wrap`.
+
+        This method handles the complex return patterns of Megatron's linear layers,
+        which can return different combinations of outputs, biases, and layernorm outputs.
+
+        The flow is: x -> [layernorm/identity] -> layernorm_output -> [linear] -> linear_output, bias
+
+        Args:
+            x: Input tensor.
+            *args: Additional positional arguments for the wrapped module.
+            **kwargs: Additional keyword arguments for the wrapped module.
+
+        Returns:
+            A tuple containing:
+                - linear_output: The output from the linear layer
+                - bias: The bias term (if present, otherwise None)
+                - layernorm_output: The output from layernorm (differs from x only for
+                  LayerNormColumnParallelLinear, otherwise equals x)
+
+        Note:
+            The wrapped module can return values in four different patterns:
+            1. nothing: (out, None)
+            2. return_bias: (out, bias)
+            3. return_layernorm_output: ((out, ln_out), None)
+            4. both: (out, bias, ln_out)
+        """
+        linear_output = self.to_wrap(x, *args, **kwargs)
+        assert isinstance(linear_output, tuple), (
+            f"{self.to_wrap} should return a tuple but instead returns {linear_output}"
+        )
+
+        bias = None
+        layernorm_output = x
+
+        if len(linear_output) == 2:
+            linear_output, bias = linear_output
+            if isinstance(linear_output, tuple) and len(linear_output) == 2:
+                linear_output, layernorm_output = linear_output
+        elif len(linear_output) == 3:
+            linear_output, bias, layernorm_output = linear_output
+
+        return linear_output, bias, layernorm_output
+
+    def state_dict(
+        self, destination: Optional[Dict[str, Any]] = None, prefix: str = "", keep_vars: bool = False
+    ) -> Dict[str, Any]:
+        """Retrieve the state dictionary of the wrapped module and adapter.
+
+        This method overrides the default state_dict behavior to include both
+        the main module's state and the adapter's state under a special 'adapter' prefix.
+
+        Args:
+            destination: A dictionary to store the state. If None, a new
+                dictionary is created. Defaults to None.
+            prefix: A prefix added to parameter and buffer names. Defaults to ''.
+            keep_vars: If True, returns variables instead of tensor values.
+                Defaults to False.
+
+        Returns:
+            The state dictionary containing both the main module and adapter states.
+        """
+        if destination is None:
+            destination = {}
+
+        # Get state dict of the main module
+        self.to_wrap.state_dict(destination=destination, prefix=prefix, keep_vars=keep_vars)
+
+        # Store adapter state dict under the "adapter" prefix in the destination dict
+        self.adapter.state_dict(destination=destination, prefix=f"{prefix}adapter.", keep_vars=keep_vars)
+        return destination
+
+    def sharded_state_dict(
+        self,
+        prefix: str = "",
+        sharded_offsets: Tuple[Tuple[int, int, int]] = (),
+        metadata: Optional[Dict[str, Any]] = None,
+    ) -> "ShardedStateDict":
+        """Retrieve the sharded state dictionary of the wrapped module and adapter.
+
+        This method is used for distributed checkpointing, combining the sharded states
+        of both the main module and the adapter.
+
+        Args:
+            prefix: A prefix added to parameter and buffer names. Defaults to ''.
+            sharded_offsets: Offsets for sharded parameters. Defaults to an empty tuple.
+            metadata: Additional metadata for the sharded state. Defaults to None.
+
+        Returns:
+            The combined sharded state dictionary.
+        """
+        adapter_sharded_state_dict_kwargs = {}
+        if isinstance(self.adapter, ParallelLinearAdapter) and "mixer.in_proj" in self.adapter.base_linear_name:
+            adapter_sharded_state_dict_kwargs["mamba_dim_info"] = _compute_mamba_dim_info(self.to_wrap)
+
+        sharded_state_dict = {}
+        sharded_state_dict.update(self.to_wrap.sharded_state_dict(prefix, sharded_offsets, metadata))
+        sharded_state_dict.update(
+            self.adapter.sharded_state_dict(
+                f"{prefix}adapter.", sharded_offsets, metadata, **adapter_sharded_state_dict_kwargs
+            )
+        )
+        return sharded_state_dict
diff --git a/mbridge/peft/base.py b/mbridge/peft/base.py
new file mode 100644
index 0000000..10c7123
--- /dev/null
+++ b/mbridge/peft/base.py
@@ -0,0 +1,159 @@
+# Adapted from NVIDIA Megatron-Bridge
+
+import logging
+from abc import ABC, abstractmethod
+from contextlib import contextmanager
+from dataclasses import dataclass, field
+from typing import Optional, TypeVar, Union
+
+import torch
+import torch.nn as nn
+from mbridge.peft.recompute import maybe_enable_recompute_inputs_grad
+from mbridge.peft.walk_utils import walk
+from megatron.core.transformer.module import MegatronModule
+
+logger: logging.Logger = logging.getLogger(__name__)
+
+ModelType = TypeVar("ModelType", bound=Union[nn.Module, list[MegatronModule]])
+
+
+@dataclass
+class PEFT(ABC):
+    """Abstract base class for Parameter-Efficient Fine-Tuning (PEFT) methods.
+
+    This class defines the interface for PEFT methods, which are used to fine-tune
+    large language models efficiently by modifying only a small subset of the model's
+    parameters.
+    """
+
+    # Runtime state that should not be serialized in checkpoints
+    params_to_save: set[str] = field(default_factory=set, init=False, repr=False)
+
+    @abstractmethod
+    def transform(self, module: nn.Module, name: Optional[str] = None, prefix: Optional[str] = None) -> nn.Module:
+        """Transform a single module according to the PEFT method.
+
+        This method is called for each module in the model during the PEFT application process.
+        It should be implemented by subclasses to define how individual modules are transformed
+        for the specific PEFT technique.
+
+        Args:
+            module (nn.Module): The individual module to be transformed.
+            name (Optional[str]): The name of the module within the model structure. Defaults to None.
+            prefix (Optional[str]): A prefix to be added to the module name, typically used for
+                nested modules. Defaults to None.
+
+        Returns:
+            nn.Module: The transformed module. This can be the original module with modifications,
+                a new module replacing the original, or the original module if no
+                transformation is needed for this specific module.
+
+        Note:
+            This method is automatically called for each module in the model when the PEFT
+            instance is applied to the model using the __call__ method.
+        """
+        raise NotImplementedError("The transform method should be implemented by subclasses.")
+
+    def __call__(self, model: ModelType, training: bool = True) -> ModelType:
+        """Apply the PEFT method to the entire model.
+
+        This method freezes the model parameters and walks through the model
+        structure, applying the transform method to each module.
+
+        Args:
+            model: The model to be fine-tuned. Can be a single model or a list of model chunks
+                (for pipeline parallelism).
+            training (bool): Whether the model will be used for training. If False,
+                additional freezing may be applied. Defaults to True.
+
+        Returns:
+            The same type as the input model, transformed with PEFT applied.
+        """
+        self.freeze_model(model, training=training)
+
+        self._walk_model(model, self.transform)
+
+        if training:
+            maybe_enable_recompute_inputs_grad(model)
+
+        if not training:
+            self.freeze_model(model, training=training)
+
+        # Set model training mode appropriately
+        if isinstance(model, list):
+            for model_chunk in model:
+                model_chunk.train(mode=training)
+        else:
+            model.train(mode=training)
+
+        return model
+
+    def _walk_model(self, model: ModelType, func) -> None:
+        if isinstance(model, list):
+            for model_chunk in model:
+                walk(model_chunk, func)
+        elif isinstance(model, torch.nn.parallel.DistributedDataParallel):
+            walk(model.module, func)
+        else:
+            walk(model, func)
+
+    def enable_adapter_layers(self, model: ModelType) -> None:
+        """Enable adapter layers for all PEFT-wrapped modules in the model."""
+
+        def enable(module: nn.Module) -> nn.Module:
+            method = getattr(module, "enable_adapter_layers", None)
+            if callable(method):
+                method()
+            return module
+
+        self._walk_model(model, enable)
+
+    def disable_adapter_layers(self, model: ModelType) -> None:
+        """Disable adapter layers for all PEFT-wrapped modules in the model."""
+
+        def disable(module: nn.Module) -> nn.Module:
+            method = getattr(module, "disable_adapter_layers", None)
+            if callable(method):
+                method()
+            return module
+
+        self._walk_model(model, disable)
+
+    @contextmanager
+    def disable_adapter(self, model: ModelType):
+        """
+        Disables the adapter module.
+        """
+        try:
+            self.disable_adapter_layers(model)
+            yield
+        finally:
+            self.enable_adapter_layers(model)
+
+    def freeze_model(self, model: ModelType, training: bool = True) -> None:
+        """Apply a default freeze method to the model.
+
+        This method freezes all the model parameters. This method can be overridden by subclasses to
+        implement custom freeze strategies (e.g. freeze only parts of the model)
+
+        Args:
+            model: The model to be fine-tuned.
+            training (bool): Whether the model is being used for training. Affects training mode handling.
+        """
+
+        def freeze_parameters(module):
+            """Freeze all parameters in a module."""
+            for param in module.parameters(recurse=False):
+                param.requires_grad = False
+            return module
+
+        self._walk_model(model, freeze_parameters)
+
+        if training:
+            if isinstance(model, list):
+                for model_chunk in model:
+                    model_chunk.train(mode=True)
+            elif isinstance(model, torch.nn.parallel.DistributedDataParallel):
+                model.module.train(mode=True)
+            else:
+                model.train(mode=True)
diff --git a/mbridge/peft/canonical_lora.py b/mbridge/peft/canonical_lora.py
new file mode 100644
index 0000000..cfd8cce
--- /dev/null
+++ b/mbridge/peft/canonical_lora.py
@@ -0,0 +1,439 @@
+# Adapted from NVIDIA Megatron-Bridge
+
+import logging
+from dataclasses import dataclass, field
+from typing import Any, List, Literal, Optional, Tuple
+
+import torch
+from mbridge.peft.adapter_wrapper import AdapterWrapper
+from mbridge.peft.base import PEFT
+from mbridge.peft.lora_layers import LinearAdapter, LoRALinear, LoRATopKRouter
+from mbridge.peft.module_matcher import ModuleMatcher
+from mbridge.peft.utils import (ParallelLinearAdapter,
+                                get_adapter_attributes_from_linear,
+                                is_expert_linear)
+from megatron.core.dist_checkpointing.mapping import ShardedStateDict
+from megatron.core.transformer.moe.router import TopKRouter
+from torch import nn
+
+logger = logging.getLogger(__name__)
+
+
+def _should_treat_linear_fc1_as_unfused(full_name: str) -> bool:
+    """Return True when CanonicalLoRA should keep linear_fc1 as a single adapter."""
+
+    return full_name.startswith("vision_model.") or full_name.endswith(".mlp.experts.linear_fc1")
+
+
+class ModuleDict(nn.ModuleDict):
+    """
+    nn.ModuleDict with a sharded_state_dict implementation for checkpointing
+    """
+
+    def sharded_state_dict(
+        self,
+        prefix: str = "",
+        sharded_offsets: Tuple[Tuple[int, int, int]] = (),
+        metadata: Optional[dict] = None,
+    ) -> "ShardedStateDict":
+        """Retrieve the sharded state dictionary of the wrapped module and adapter.
+
+        This method is used for distributed checkpointing, combining the sharded states
+        of both the main module and the adapter.
+
+        Args:
+            prefix (str): A prefix added to parameter and buffer names. Defaults to ''.
+            sharded_offsets (Tuple[Tuple[int, int, int]]): Offsets for sharded parameters.
+                Defaults to an empty tuple.
+            metadata (Optional[dict]): Additional metadata for the sharded state.
+                Defaults to None.
+
+        Returns:
+            ShardedStateDict: The combined sharded state dictionary.
+        """
+        sharded_state_dict = {}
+        for key, layer in self.items():
+            # nn.ModuleDict accepts None entries; CanonicalLoRA stores None for adapters
+            # that were not requested, and those have no state to checkpoint.
+            if layer is None:
+                continue
+            sharded_state_dict.update(layer.sharded_state_dict(f"{prefix}{key}.", sharded_offsets, metadata))
+        return sharded_state_dict
+
+
+class LoRALinearSplitQKV(AdapterWrapper):
+    """An adapter wrapper for `linear_qkv` where q, k, v are three separate adapters.
+    This module adds the output of the adapters to the output of the wrapped module while taking care of shape.
+
+    This class is designed to be used with LoRA (Low-Rank Adaptation) and similar techniques
+    where the adapter's output is added to the main module's output. It extends the AdapterWrapper
+    class to provide a specific implementation of the forward method.
+    """
+
+    def _interleave_qkv(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> torch.Tensor:
+        """Interleave QKV outputs to match Megatron's packed ordering."""
+
+        config = self.to_wrap.config
+        head_num = getattr(config, "num_attention_heads", None)
+        num_query_groups = getattr(config, "num_query_groups", None)
+        head_size = getattr(config, "kv_channels", None)
+
+        if head_size is None:
+            hidden_size = getattr(config, "hidden_size", None)
+            if head_num is not None and hidden_size is not None:
+                head_size = hidden_size // head_num
+            elif num_query_groups:
+                if key.size(-1) % num_query_groups != 0:
+                    raise ValueError("Key projection size must be divisible by num_query_groups.")
+                head_size = key.size(-1) // num_query_groups
+            elif head_num is not None:
+                if query.size(-1) % head_num != 0:
+                    raise ValueError("Query projection size must be divisible by num_attention_heads.")
+                head_size = query.size(-1) // head_num
+            else:
+                raise ValueError(
+                    "Cannot infer head size without kv_channels or hidden_size/num_attention_heads or num_query_groups."
+                )
+
+        if head_num is None:
+            if query.size(-1) % head_size != 0:
+                raise ValueError("Query projection size must be divisible by head_size.")
+            head_num = query.size(-1) // head_size
+
+        if not num_query_groups:
+            if key.size(-1) % head_size != 0:
+                raise ValueError("Key projection size must be divisible by head_size.")
+            num_query_groups = key.size(-1) // head_size
+
+        if head_num % num_query_groups != 0:
+            raise ValueError("num_attention_heads must be divisible by num_query_groups.")
+
+        heads_per_group = head_num // num_query_groups
+
+        leading_shape = query.shape[:-1]
+        query = query.reshape(-1, head_num, head_size)
+        key = key.reshape(-1, num_query_groups, head_size)
+        value = value.reshape(-1, num_query_groups, head_size)
+
+        output_gate = getattr(config, "attention_output_gate", False)
+
+        qkv_chunks = []
+        for i in range(num_query_groups):
+            q_group = query[:, i * heads_per_group : (i + 1) * heads_per_group, :]
+            k_group = key[:, i : i + 1, :]
+            v_group = value[:, i : i + 1, :]
+            qkv_chunks.append(q_group)
+            if output_gate:
+                qkv_chunks.append(torch.zeros_like(q_group))
+            qkv_chunks.append(k_group)
+            qkv_chunks.append(v_group)
+
+        qkv = torch.cat(qkv_chunks, dim=1)
+        return qkv.reshape(*leading_shape, -1)
+
+    def _fill_missing_qkv(
+        self,
+        linear_output: torch.Tensor,
+        query: Optional[torch.Tensor],
+        key: Optional[torch.Tensor],
+        value: Optional[torch.Tensor],
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        """Replace missing (None) Q/K/V adapter outputs with zero tensors.
+
+        Widths of the missing parts are derived from the width of the fused output,
+        assuming the packed layout built by `_interleave_qkv`: Q (plus an equally wide
+        gate slot when `attention_output_gate` is set), K and V, with K as wide as V.
+
+        Args:
+            linear_output: Output of the wrapped fused QKV linear layer.
+            query: Output of the Q adapter, or None when there is no Q adapter.
+            key: Output of the K adapter, or None when there is no K adapter.
+            value: Output of the V adapter, or None when there is no V adapter.
+
+        Returns:
+            The (query, key, value) tensors with every None replaced by zeros.
+            At least one of the three inputs must be a tensor.
+        """
+        present = next(t for t in (query, key, value) if t is not None)
+        fused_width = linear_output.size(-1)
+        # With an output gate the packed layout carries one extra Q-sized slot per group.
+        q_slots = 2 if getattr(self.to_wrap.config, "attention_output_gate", False) else 1
+        kv_ref = key if key is not None else value
+        if kv_ref is not None:
+            kv_width = kv_ref.size(-1)
+            q_width = (fused_width - 2 * kv_width) // q_slots
+        else:
+            q_width = query.size(-1)
+            kv_width = (fused_width - q_slots * q_width) // 2
+
+        def zeros(width: int) -> torch.Tensor:
+            # Match dtype, device and leading dimensions of a real adapter output.
+            return present.new_zeros((*present.shape[:-1], width))
+
+        if query is None:
+            query = zeros(q_width)
+        if key is None:
+            key = zeros(kv_width)
+        if value is None:
+            value = zeros(kv_width)
+        return query, key, value
+
+    def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
+        # pylint: disable=C0115,C0116
+        linear_output, bias, layernorm_output = self.base_linear_forward(x, *args, **kwargs)
+        if not self._adapter_enabled:
+            return linear_output, bias
+
+        # Only compute adapters that exist (not None)
+        query = self.adapter.adapter_q(layernorm_output) if self.adapter.adapter_q is not None else None
+        key = self.adapter.adapter_k(layernorm_output) if self.adapter.adapter_k is not None else None
+        value = self.adapter.adapter_v(layernorm_output) if self.adapter.adapter_v is not None else None
+
+        if query is None and key is None and value is None:
+            return linear_output, bias
+
+        # If only a subset of the Q/K/V adapters exists, stand in zero tensors for the
+        # missing ones so the adapter delta keeps the packed layout and width of
+        # linear_output. Concatenating just the present parts would yield a narrower
+        # tensor that cannot be added to the fused output.
+        query, key, value = self._fill_missing_qkv(linear_output, query, key, value)
+        adapter_output = self._interleave_qkv(query, key, value)
+
+        return linear_output + adapter_output, bias
+
+
+class LoRALinearSplitFC1UpGate(AdapterWrapper):
+    """An adapter wrapper for `linear_fc1` where up_proj and gate_proj are two separate adapters.
+    This module adds the output of the adapters to the output of the wrapped module while taking care of shape.
+
+    This class is designed to be used with LoRA (Low-Rank Adaptation) and similar techniques
+    where the adapter's output is added to the main module's output. It extends the AdapterWrapper
+    class to provide a specific implementation of the forward method.
+    """
+
+    def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
+        # pylint: disable=C0115,C0116
+        linear_output, bias, layernorm_output = self.base_linear_forward(x, *args, **kwargs)
+        if not self._adapter_enabled:
+            return linear_output, bias
+        gate_adapter = self.adapter.adapter_gate
+        up_adapter = self.adapter.adapter_up
+        gate = gate_adapter(layernorm_output) if gate_adapter is not None else None
+        up = up_adapter(layernorm_output) if up_adapter is not None else None
+        if gate is None and up is None:
+            return linear_output, bias
+        # Both halves are built with out_features // 2, so a missing half is a zero
+        # tensor shaped like the present one; concatenating only the present half
+        # would be narrower than linear_output and fail on addition.
+        if gate is None:
+            gate = torch.zeros_like(up)
+        elif up is None:
+            up = torch.zeros_like(gate)
+        adapter_output = torch.cat([gate, up], dim=-1)
+        return linear_output + adapter_output, bias
+
+
+@dataclass
+class CanonicalLoRA(PEFT, ModuleMatcher):
+    """
+    Implements the LoRA (Low-Rank Adaptation) module for parameter-efficient fine-tuning.
+    Canonical LoRA applies LoRA on Q, K, V projection matrices separately, as well as Up and Gate projection
+    matrices separately. This follows more closely with Huggingface's implementation of LoRA.
+
+    Args:
+        target_modules (List[str], optional): A list of module names to apply LoRA to.
+            Defaults to all linear layers ['linear_q', 'linear_k', 'linear_v', 'linear_proj',
+                                           'linear_fc1_up', 'linear_fc1_gate', 'linear_fc2'].
+                - 'linear_q', 'linear_k', 'linear_v': Apply LoRA to the linear layer used for query, key, and value
+                        projections in self-attention. This is fused into one matrix in LoRA, but left as three
+                        separate matrices in Canonical LoRA.
+                - 'linear_proj': Apply LoRA to the linear layer used for projecting the output of self-attention.
+                - 'linear_fc1_up', 'linear_fc1_gate': Apply LoRA to the Up proj and Gate proj layers.
+                        These two together constitute the first fully-connected layer in MLP in LoRA.
+                - 'linear_fc2': Apply LoRA to the second fully-connected layer in MLP.
+            Target modules can also contain wildcards. For example, you can specify
+                target_modules=['*.layers.0.*.linear_q', '*.layers.1.*.linear_q'] to add LoRA to only linear_q
+                on the first two layers.
+        exclude_modules (List[str], optional): A list of module names not to apply LoRA to. It will
+            match all nn.Linear & nn.Linear-adjacent modules whose name does not match any string in
+            exclude_modules. If used, will require target_modules to be empty list or None.
+        dim (int): Dimension of the low-rank projection space. Defaults to 32.
+        alpha (int): Weighting factor for the low-rank projection. Defaults to 32.
+        dropout (float): Dropout rate for the low-rank projection. Defaults to 0.0.
+        dropout_position (Literal['pre', 'post'], optional): Position for applying dropout.
+            Can be 'pre' (before the low-rank projection) or 'post' (after). Defaults to 'pre'.
+        lora_A_init_method (str): Initialization method for LoRA A matrix. Defaults to "xavier".
+        lora_B_init_method (str): Initialization method for LoRA B matrix. Defaults to "zero".
+    """
+
+    target_modules: List[str] = field(
+        default_factory=lambda: [
+            "linear_q",
+            "linear_k",
+            "linear_v",
+            "linear_proj",
+            "linear_fc1_up",
+            "linear_fc1_gate",
+            "linear_fc2",
+        ]
+    )
+    dim: int = 32
+    alpha: int = 32
+    dropout: float = 0.0
+    dropout_position: Literal["pre", "post"] = "pre"
+    lora_A_init_method: str = "xavier"
+    lora_B_init_method: str = "zero"
+
+    def __post_init__(self) -> None:
+        """
+        Initialize the canonical mapping and call the parent post_init.
+
+        Construct a mapping from the target module as supported in LoRA() to the specific parts of the layer for which
+        adapter is applied.
+
+        For example, if user specifies target_module = ['linear_q', 'linear_k', 'linear_proj', 'linear_fc1_up'], then
+        canonical_lora_mapping = {
+            "linear_qkv": {'linear_q', 'linear_k'},
+            "linear_proj": {'linear_proj'},  # the value of this key does not matter
+            "linear_fc1": {'linear_fc1_up'},
+        }
+
+        If user specifies target_module = ['*.layers.0.*.linear_q', '*.layers.1.*.linear_q'], then
+        canonical_lora_mapping = {
+            "'*.layers.0.*.linear_qkv'": {'linear_q'},
+            "'*.layers.1.*.linear_qkv'": {'linear_q'},
+        }
+
+        """
+        for target in self.target_modules:
+            assert not target.endswith("linear_qkv"), (
+                "Canonical LoRA does not support target 'linear_qkv'. Either use 'linear_qkv' with LoRA() or "
+                "use ['linear_q', 'linear_k', 'linear_v'] with Canonical LoRA"
+            )
+            assert not target.endswith("linear_fc1"), (
+                "Canonical LoRA does not support target 'linear_fc1'. Either use 'linear_fc1' with LoRA() or "
+                "use ['linear_fc1_up', 'linear_fc1_gate'] with Canonical LoRA"
+            )
+
+            if target.endswith("linear_q"):
+                self.canonical_mapping[target.replace("linear_q", "linear_qkv")].add("linear_q")
+            elif target.endswith("linear_k"):
+                self.canonical_mapping[target.replace("linear_k", "linear_qkv")].add("linear_k")
+            elif target.endswith("linear_v"):
+                self.canonical_mapping[target.replace("linear_v", "linear_qkv")].add("linear_v")
+            elif target.endswith("linear_fc1_up"):
+                self.canonical_mapping[target.replace("linear_fc1_up", "linear_fc1")].add("linear_fc1_up")
+            elif target.endswith("linear_fc1_gate"):
+                self.canonical_mapping[target.replace("linear_fc1_gate", "linear_fc1")].add("linear_fc1_gate")
+            else:
+                self.canonical_mapping[target].add(target)
+
+    def transform(self, m: nn.Module, name: Optional[str] = None, prefix: Optional[str] = None) -> nn.Module:
+        """
+        Applies LoRA to a specific module within the model architecture.
+
+        Args:
+            m (nn.Module): The module to apply LoRA to.
+            name (Optional[str]): Name of the module (if applicable). Defaults to None.
+            prefix (Optional[str]): Prefix for the module name (if applicable). Defaults to None.
+
+        Returns:
+            nn.Module: The modified module with LoRA applied, or the original module if not a target.
+        """
+
+        # Skip already transformed modules
+        if isinstance(m, (LinearAdapter, LoRALinear, LoRALinearSplitQKV, LoRALinearSplitFC1UpGate, LoRATopKRouter)):
+            return m
+
+        if (ans := self.match(m, name, prefix)) is not None:
+            (match, full_name) = ans
+            if isinstance(m, nn.Linear):
+                return LinearAdapter(
+                    m, dim=self.dim, alpha=self.alpha, dropout=self.dropout, lora_A_init_method=self.lora_A_init_method
+                )
+
+            from megatron.core.tensor_parallel import (
+                ColumnParallelLinear, RowParallelLinear,
+            )
+            from mbridge.peft.utils import TECL, TERL
+            _supported_types = (
+                ColumnParallelLinear, RowParallelLinear, TopKRouter,
+            ) + TECL + TERL
+            if not isinstance(m, _supported_types):
+                logger.warning(
+                    f"CanonicalLoRA target pattern matched module '{full_name}' "
+                    f"of type {type(m).__name__}, but this type is not supported "
+                    f"for LoRA adaptation. Skipping."
+                )
+                return m
+
+            is_expert = is_expert_linear(full_name)
+            attrs = get_adapter_attributes_from_linear(m, is_expert=is_expert)
+
+            adapter_kwargs = dict(
+                dim=self.dim,
+                base_linear_name=full_name,
+                activation="identity",
+                norm_type=None,
+                column_init_method=self.lora_A_init_method,
+                row_init_method=self.lora_B_init_method,
+                gather_output=False,
+                input_is_parallel=attrs.input_is_parallel,
+                dropout=self.dropout,
+                dropout_position=self.dropout_position,
+                model_parallel_config=getattr(m, "config", None),
+                alpha=self.alpha,
+                is_expert=is_expert,
+                disable_tensor_parallel_comm=attrs.disable_tensor_parallel_comm,
+                disable_sequence_parallel_comm=attrs.disable_sequence_parallel_comm,
+                base_linear_is_parallel=attrs.base_linear_is_parallel,
+            )
+
+            if name == "linear_fc1" and _should_treat_linear_fc1_as_unfused(full_name):
+                logger.info(f"Adding lora to: {full_name} (treating unsupported canonical linear_fc1 as unfused)")
+                adapter = ParallelLinearAdapter(attrs.in_features, attrs.out_features, **adapter_kwargs)
+                return LoRALinear(m, adapter)
+
+            canonical_submodules = self.canonical_mapping[match]
+            logger.info(f"Adding lora to: {full_name} ({canonical_submodules})")
+            if name == "linear_qkv":
+                adapter_q, adapter_k, adapter_v = None, None, None
+                kv_out_features = m.config.kv_channels * m.config.num_query_groups
+                q_out_features = m.config.kv_channels * m.config.num_attention_heads
+                if "linear_q" in canonical_submodules:
+                    adapter_q = ParallelLinearAdapter(attrs.in_features, q_out_features, **adapter_kwargs)
+                if "linear_k" in canonical_submodules:
+                    adapter_k = ParallelLinearAdapter(attrs.in_features, kv_out_features, **adapter_kwargs)
+                if "linear_v" in canonical_submodules:
+                    adapter_v = ParallelLinearAdapter(attrs.in_features, kv_out_features, **adapter_kwargs)
+                adapters = ModuleDict({"adapter_q": adapter_q, "adapter_k": adapter_k, "adapter_v": adapter_v})
+                return LoRALinearSplitQKV(m, adapters)
+
+            if name == "linear_fc1":
+                stride = getattr(m, 'stride', 1)
+                if stride <= 1:
+                    # Non-GLU: single adapter with full out_features.
+                    # When gated_linear_unit=False, linear_fc1 has no gate/up split
+                    # and the canonical target_modules 'linear_fc1_up'/'linear_fc1_gate'
+                    # are treated as targeting the sole weight matrix.
+                    logger.info(
+                        f"Adding lora to: {full_name} (non-gated, single adapter, canonical_submodules={canonical_submodules})"
+                    )
+                    adapter = ParallelLinearAdapter(attrs.in_features, attrs.out_features, **adapter_kwargs)
+                    return LoRALinear(m, adapter)
+                # GLU: split gate/up, each with out_features // 2
+                adapter_up, adapter_gate = None, None
+                if "linear_fc1_up" in canonical_submodules:
+                    adapter_up = ParallelLinearAdapter(attrs.in_features, attrs.out_features // 2, **adapter_kwargs)
+                if "linear_fc1_gate" in canonical_submodules:
+                    adapter_gate = ParallelLinearAdapter(attrs.in_features, attrs.out_features // 2, **adapter_kwargs)
+                adapters = ModuleDict({"adapter_up": adapter_up, "adapter_gate": adapter_gate})
+                return LoRALinearSplitFC1UpGate(m, adapters)
+
+            adapter = ParallelLinearAdapter(attrs.in_features, attrs.out_features, **adapter_kwargs)
+            logger.info(f"Adding lora to: {full_name}")
+            if isinstance(m, TopKRouter):
+                return LoRATopKRouter(m, adapter)
+            return LoRALinear(m, adapter)
+
+        return m
diff --git a/mbridge/peft/lora.py b/mbridge/peft/lora.py
new file mode 100644
index 0000000..ef89b8e
--- /dev/null
+++ b/mbridge/peft/lora.py
@@ -0,0 +1,1000 @@
+# Adapted from NVIDIA Megatron-Bridge
+
+import logging
+from contextlib import contextmanager
+from dataclasses import dataclass, field
+from typing import Dict, List, Literal, Optional
+
+import torch
+import torch.distributed as dist
+import torch.nn as nn
+import transformer_engine.pytorch as te
+from mbridge.peft.base import PEFT
+# Import canonical split-adapter wrappers for gather/merge support.
+# These are only imported here (not in canonical_lora → avoids circular deps).
from mbridge.peft.canonical_lora import (LoRALinearSplitFC1UpGate,
                                         LoRALinearSplitQKV)
from mbridge.peft.lora_layers import (LinearAdapter, LoRALinear,
                                      LoRATopKRouter, TEFusedLoRALinear,
                                      TELinearAdapter, patch_linear_module)
from mbridge.peft.module_matcher import ModuleMatcher
from mbridge.peft.utils import (ParallelLinearAdapter,
                                get_adapter_attributes_from_linear,
                                is_expert_linear)
from megatron.core import parallel_state
from megatron.core.tensor_parallel import (ColumnParallelLinear,
                                           RowParallelLinear)
from megatron.core.transformer.moe.router import TopKRouter
from megatron.core.utils import unwrap_model

logger = logging.getLogger(__name__)

try:
    import bitsandbytes

    HAVE_BNB = True
except ImportError:
    HAVE_BNB = False


@dataclass
class LoRA(PEFT, ModuleMatcher):
    """
    Implements the LoRA (Low-Rank Adaptation) module for parameter-efficient fine-tuning.

    LoRA uses a low-rank projection to adapt the weights of a pre-trained model to a new downstream task.
    This class facilitates the application of LoRA to specific modules within the model architecture.

    Args:
        target_modules (List[str], optional): A list of module names to apply LoRA to.
            Defaults to all linear layers ['linear_qkv', 'linear_proj', 'linear_fc1', 'linear_fc2'].
                - 'linear_qkv': Apply LoRA to the fused linear layer used for query, key, and value projections
                                in self-attention.
                - 'linear_proj': Apply LoRA to the linear layer used for projecting the output of self-attention.
                - 'linear_fc1': Apply LoRA to the first fully-connected layer in MLP.
                - 'linear_fc2': Apply LoRA to the second fully-connected layer in MLP.
            Target modules can also contain wildcards. For example, you can specify
                target_modules=['*.layers.0.*.linear_qkv', '*.layers.1.*.linear_qkv'] to add LoRA to only linear_qkv
                on the first two layers.
        exclude_modules (List[str], optional): A list of module names not to apply LoRA to. It will
            match all nn.Linear & nn.Linear-adjacent modules whose name does not match any string in
            exclude_modules. If used, will require target_modules to be empty list or None.
        dim (int): Dimension of the low-rank projection space. Defaults to 32.
        alpha (int): Weighting factor for the low-rank projection. Defaults to 32.
        dropout (float): Dropout rate for the low-rank projection. Defaults to 0.0.
        dropout_position (Literal['pre', 'post'], optional): Position for applying dropout.
            Can be 'pre' (before the low-rank projection) or 'post' (after). Defaults to 'pre'.
        a2a_experimental (bool): Enables the experimental All-to-All (A2A) communication strategy. Defaults to False.
        lora_A_init_method (str): Initialization method for the low-rank matrix A. Defaults to "xavier".
        lora_B_init_method (str): Initialization method for the low-rank matrix B. Defaults to "zero".
        lora_dtype (torch.dtype, optional): Parameter data type for LoRA weights.
            Default None (will use model's dtype).
    """

    target_modules: List[str] = field(
        default_factory=lambda: ["linear_qkv", "linear_proj", "linear_fc1", "linear_fc2"]
    )
    dim: int = 32
    alpha: int = 32
    dropout: float = 0.0
    dropout_position: Literal["pre", "post"] = "pre"
    lora_A_init_method: str = "xavier"
    lora_B_init_method: str = "zero"
    a2a_experimental: bool = False
    lora_dtype: Optional[torch.dtype] = None

    def transform(self, module: nn.Module, name: Optional[str] = None, prefix: Optional[str] = None) -> nn.Module:
        """
        Applies LoRA to a specific module within the model architecture.

        Args:
            module (nn.Module): The module to apply LoRA to.
            name (str, optional): Name of the module (if applicable). Defaults to None.
            prefix (str, optional): Prefix for the module name (if applicable). Defaults to None.

        Returns:
            nn.Module: The modified module with LoRA applied, or the original module if it is
            already adapted, is not a target, or is of an unsupported type.
        """
        # Skip already transformed modules
        already_adapted = (LinearAdapter, LoRALinear, LoRATopKRouter, TELinearAdapter)
        if isinstance(module, already_adapted):
            return module

        ans = self.match(module, name, prefix)
        if ans is None:
            return module
        _, full_name = ans

        # Exact-class check for te.Linear: TE subclasses take the parallel path below.
        if isinstance(module, nn.Linear) or (module.__class__ == te.Linear):
            # Will use the `patch_linear_module` function if:
            # - is FSDP v1
            # - is DTensor (has _local_tensor attribute)
            # - has quant_state attribute
            if hasattr(module.weight.data, "_local_tensor") or (
                HAVE_BNB
                and getattr(module, "quant_state", None) is not None
                and module.quant_state.__class__ == bitsandbytes.functional.QuantState
            ):
                lora_cls = patch_linear_module
            elif module.__class__ == te.Linear:
                lora_cls = TELinearAdapter
            else:
                lora_cls = LinearAdapter

            return lora_cls(
                module,
                dim=self.dim,
                alpha=self.alpha,
                dropout=self.dropout,
                lora_A_init_method=self.lora_A_init_method,
                lora_dtype=self.lora_dtype,
            )

        from mbridge.peft.utils import TECL, TERL
        _supported_parallel_types = (
            ColumnParallelLinear, RowParallelLinear, TopKRouter,
        ) + TECL + TERL + (te.Linear,)
        if not isinstance(module, _supported_parallel_types):
            # Use the module-level logger so records carry this module's name.
            logger.warning(
                f"LoRA target pattern matched module '{full_name}' of type "
                f"{type(module).__name__}, but this type is not supported for "
                f"LoRA adaptation. Skipping."
            )
            return module

        is_expert = is_expert_linear(full_name)
        attrs = get_adapter_attributes_from_linear(module, is_expert=is_expert)

        enable_op_fuser = (
            hasattr(module, "config")
            and getattr(module.config, "use_transformer_engine_op_fuser", False)
            # TP not yet supported
            and parallel_state.get_tensor_model_parallel_world_size() == 1
        )

        logger.info(f"Adding lora to: {full_name}")
        adapter = ParallelLinearAdapter(
            attrs.in_features,
            attrs.out_features,
            self.dim,
            base_linear_name=full_name,
            activation="identity",
            column_init_method=self.lora_A_init_method,
            row_init_method=self.lora_B_init_method,
            input_is_parallel=attrs.input_is_parallel,
            dropout=self.dropout,
            dropout_position=self.dropout_position,
            model_parallel_config=getattr(module, "config", None),
            alpha=self.alpha,
            is_expert=is_expert,
            a2a_experimental=self.a2a_experimental,
            disable_tensor_parallel_comm=attrs.disable_tensor_parallel_comm,
            disable_sequence_parallel_comm=attrs.disable_sequence_parallel_comm,
            base_linear_is_parallel=attrs.base_linear_is_parallel,
        )
        if isinstance(module, TopKRouter):
            return LoRATopKRouter(module, adapter)
        if enable_op_fuser:
            return TEFusedLoRALinear(module, adapter)
        return LoRALinear(module, adapter)
+ """ + tp_size = parallel_state.get_tensor_model_parallel_world_size() + if tp_size <= 1: + return weight + + tp_group = parallel_state.get_tensor_model_parallel_group() + gathered = [torch.empty_like(weight) for _ in range(tp_size)] + dist.all_gather(gathered, weight.contiguous(), group=tp_group) + + if isinstance(module, RowParallelLinear): + return torch.cat(gathered, dim=1) + else: + return torch.cat(gathered, dim=0) + + +def _deinterleave_gathered_lora_b( + gathered_b: torch.Tensor, stride: int, tp_size: int +) -> torch.Tensor: + """Permute a gathered LoRA-B (linear_out) from interleaved to sequential layout. + + When TP > 1 and the base layer has stride > 1 (e.g. SwiGLU linear_fc1), + each rank's B_local has rows that alternate between stride components + (gate and up for stride=2). After naive concatenation across TP ranks, + the layout is interleaved: + + [rank0_gate, rank0_up, rank1_gate, rank1_up, ...] + + This function permutes to sequential layout: + + [gate_all, up_all] + + which is the correct layout for HF export and for computing the full + delta matrix during merge. + """ + if stride <= 1 or tp_size <= 1: + return gathered_b + + total_rows = gathered_b.shape[0] + per_rank = total_rows // tp_size + per_stride_per_rank = per_rank // stride + + parts = [] + for s in range(stride): + for r in range(tp_size): + start = r * per_rank + s * per_stride_per_rank + end = start + per_stride_per_rank + parts.append(gathered_b[start:end]) + return torch.cat(parts, dim=0) + + +# --------------------------------------------------------------------------- +# Megatron-Core → HuggingFace PEFT name mapping +# --------------------------------------------------------------------------- + +_MCORE_TO_HF_LORA_SUFFIX = { + "linear_in": "lora_A", + "linear_out": "lora_B", +} + +# Mapping from CanonicalLoRA sub-adapter name → index into the bridge's +# fused weight-name list for the parent megatron module. +# e.g. 
bridge maps linear_qkv.weight → [q_proj.weight, k_proj.weight, v_proj.weight] +# → adapter_q → index 0, adapter_k → index 1, adapter_v → index 2. +_CANONICAL_ADAPTER_TO_HF_INDEX = { + "adapter_q": 0, + "adapter_k": 1, + "adapter_v": 2, + "adapter_gate": 0, + "adapter_up": 1, +} + +_CANONICAL_ADAPTER_TO_HF_SUFFIX = { + "adapter_q": "q_proj", + "adapter_k": "k_proj", + "adapter_v": "v_proj", + "adapter_gate": "gate_proj", + "adapter_up": "up_proj", +} + + +def _combine_hf_module_names(hf_weight_names: List[str]) -> str: + """Derive a single fused adapter module path from multiple HF weight names. + + For fused MCore layers (e.g. ``linear_qkv`` → q/k/v, ``linear_fc1`` → + gate/up), the bridge returns multiple HF weight names. This function + combines them into a single name suitable for the adapter key. + + Examples:: + + ["model.layers.0.self_attn.q_proj.weight", + "model.layers.0.self_attn.k_proj.weight", + "model.layers.0.self_attn.v_proj.weight"] + → "model.layers.0.self_attn.qkv_proj" + + ["model.layers.0.mlp.gate_proj.weight", + "model.layers.0.mlp.up_proj.weight"] + → "model.layers.0.mlp.gate_up_proj" + """ + import os + + bases = [n.rsplit(".", 1)[0] for n in hf_weight_names] + + common_prefix = os.path.commonprefix(bases) + if common_prefix and not common_prefix.endswith("."): + common_prefix = common_prefix[: common_prefix.rfind(".") + 1] + + suffixes = [b[len(common_prefix):] for b in bases] + + reversed_suffixes = [s[::-1] for s in suffixes] + common_suffix = os.path.commonprefix(reversed_suffixes)[::-1] + + strip_len = len(common_suffix) + unique_parts = [s[:-strip_len] if strip_len else s for s in suffixes] + + if all(len(p) <= 1 for p in unique_parts): + combined = "".join(unique_parts) + common_suffix + else: + combined = "_".join(unique_parts) + common_suffix + + return common_prefix + combined + + +def mcore_adapter_name_to_hf(mcore_name: str, bridge=None) -> str: + """Convert a Megatron-Core adapter parameter name to HF PEFT format. 
def mcore_adapter_name_to_hf(mcore_name: str, bridge=None) -> str:
    """Convert a Megatron-Core adapter parameter name to HF PEFT format.

    When *bridge* is provided the mapping is derived dynamically via
    ``bridge._weight_name_mapping_mcore_to_hf``. Without a bridge the module
    path is kept as-is: only ``linear_in``/``linear_out`` are renamed to
    ``lora_A``/``lora_B``, the ``adapter`` component is dropped and the
    ``base_model.model.`` prefix is added. Names that do not look like
    adapter weights are returned with the prefix only.

    Supports both standard LoRA (single adapter per fused layer) and
    CanonicalLoRA (multiple sub-adapters per fused layer, e.g.
    ``adapter.adapter_q.linear_in.weight``).

    Parameters
    ----------
    mcore_name : str
        Full Megatron-Core adapter parameter name, e.g.
        ``decoder.layers.0.self_attention.linear_qkv.adapter.linear_in.weight``
        or ``decoder.layers.0.self_attention.linear_qkv.adapter.adapter_q.linear_in.weight``
    bridge : optional
        An mbridge ``Bridge`` instance.

    Returns
    -------
    str
        HF PEFT parameter name, e.g.
        ``base_model.model.model.layers.0.self_attn.qkv_proj.lora_A.weight``
    """
    import re

    # --- CanonicalLoRA nested adapter path ---
    # e.g. …linear_qkv.adapter.adapter_q.linear_in.weight
    m_canonical = re.match(
        r"(.+)\.adapter\.(adapter_\w+)\.linear_(in|out)\.weight$",
        mcore_name,
    )
    if m_canonical is not None:
        base_module_path = m_canonical.group(1)
        sub_adapter = m_canonical.group(2)   # e.g. "adapter_q"
        adapter_type = m_canonical.group(3)  # "in" or "out"
        lora_suffix = _MCORE_TO_HF_LORA_SUFFIX[f"linear_{adapter_type}"]

        if bridge is None:
            return f"base_model.model.{base_module_path}.{sub_adapter}.{lora_suffix}.weight"

        hf_names = bridge._weight_name_mapping_mcore_to_hf(f"{base_module_path}.weight")
        # No default index: an unknown sub-adapter must not be silently mapped
        # onto the first HF weight (it would collide with adapter_q / adapter_gate).
        idx = _CANONICAL_ADAPTER_TO_HF_INDEX.get(sub_adapter)
        if idx is not None and len(hf_names) > 1 and idx < len(hf_names):
            # Each sub-adapter maps to a distinct HF weight (e.g. LLM q/k/v)
            hf_base = hf_names[idx].rsplit(".", 1)[0]
        elif len(hf_names) == 1:
            # Fused target (e.g. ViT qkv): append sub-adapter suffix
            # to produce unique keys per sub-adapter component.
            fused_base = hf_names[0].rsplit(".", 1)[0]
            sub_suffix = _CANONICAL_ADAPTER_TO_HF_SUFFIX.get(sub_adapter, sub_adapter)
            hf_base = f"{fused_base}.{sub_suffix}"
        else:
            # No one-to-one HF weight for this sub-adapter: use the fused
            # module name and keep the sub-adapter name to stay unique.
            fused_base = _combine_hf_module_names(hf_names)
            hf_base = f"{fused_base}_{sub_adapter}"
        return f"base_model.model.{hf_base}.{lora_suffix}.weight"

    # --- Standard LoRA adapter path ---
    # e.g. …linear_qkv.adapter.linear_in.weight
    m = re.match(
        r"(.+)\.(adapter\.linear_(in|out)\.weight)$",
        mcore_name,
    )
    # --- LinearAdapter path (no .adapter. prefix) ---
    # e.g. …vision_model.blocks.0.attn.proj.linear_in.weight
    if m is None:
        m = re.match(
            r"(.+)\.(linear_(in|out)\.weight)$",
            mcore_name,
        )
    if m is None:
        # Not an adapter weight name: pass through with the PEFT prefix only.
        return f"base_model.model.{mcore_name}"

    base_module_path = m.group(1)
    adapter_type = m.group(3)  # "in" or "out"
    lora_suffix = _MCORE_TO_HF_LORA_SUFFIX[f"linear_{adapter_type}"]

    if bridge is None:
        return f"base_model.model.{base_module_path}.{lora_suffix}.weight"

    hf_names = bridge._weight_name_mapping_mcore_to_hf(f"{base_module_path}.weight")
    if len(hf_names) == 1:
        hf_base = hf_names[0].rsplit(".", 1)[0]
    else:
        hf_base = _combine_hf_module_names(hf_names)
    return f"base_model.model.{hf_base}.{lora_suffix}.weight"


def infer_hf_target_modules(adapter_state: Dict[str, torch.Tensor]) -> List[str]:
    """Infer HF ``target_modules`` from adapter weight names.

    Keys look like ``...layers.0.self_attn.qkv_proj.lora_A.weight``; the
    module name (``qkv_proj``) is the third dotted component from the end.
    Keys with fewer than three components carry no module name and are
    ignored. Keys with exactly three components (a top-level module such as
    ``proj.lora_A.weight``) are included.

    Args:
        adapter_state: Mapping keyed by HF PEFT parameter names; only the
            keys are inspected.

    Returns:
        Sorted list of the distinct module names.
    """
    modules = set()
    for key in adapter_state:
        parts = key.split(".")
        if len(parts) >= 3:
            modules.add(parts[-3])
    return sorted(modules)
def _store_gathered_adapter(adapter_state, adapter, mcore_prefix, bridge, out_stride=1):
    """Gather one adapter's A/B weights and store them on CPU under HF PEFT keys.

    ``mcore_prefix`` is the Megatron-Core path of the adapter (without the
    trailing ``.linear_in``/``.linear_out``). ``out_stride`` is the stride of
    the wrapped base layer; when it and the TP size are both > 1 the gathered
    ``linear_out`` weight is de-interleaved before being stored.
    """
    lin_in_w = _gather_parallel_weight(adapter.linear_in.weight.data, adapter.linear_in)
    hf_key = mcore_adapter_name_to_hf(f"{mcore_prefix}.linear_in.weight", bridge=bridge)
    adapter_state[hf_key] = lin_in_w.cpu()

    lin_out_w = _gather_parallel_weight(adapter.linear_out.weight.data, adapter.linear_out)
    if out_stride > 1:
        tp_size = parallel_state.get_tensor_model_parallel_world_size()
        if tp_size > 1:
            # For strided base layers (e.g. SwiGLU FC1), the gathered B has
            # interleaved layout; permute to sequential for correct HF export.
            lin_out_w = _deinterleave_gathered_lora_b(lin_out_w, out_stride, tp_size)
    hf_key = mcore_adapter_name_to_hf(f"{mcore_prefix}.linear_out.weight", bridge=bridge)
    adapter_state[hf_key] = lin_out_w.cpu()


@torch.no_grad()
def gather_lora_state_dict(models, bridge=None) -> Dict[str, torch.Tensor]:
    """Gather full (un-sharded) LoRA adapter weights in HF PEFT format.

    When TP > 1, the adapter's ``linear_in`` and ``linear_out`` are
    parallel linear layers whose weights are sharded across TP ranks.
    This function performs ``all_gather`` to reconstruct the full tensors
    and converts parameter names to HF PEFT convention.

    Handles canonical LoRA (``LoRALinearSplitQKV``,
    ``LoRALinearSplitFC1UpGate``), ``LoRATopKRouter``, ``LinearAdapter``
    (weights copied without gathering) and standard ``LoRALinear``.
    Modules of any other type are skipped.

    Parameters
    ----------
    models : list[nn.Module]
        Unwrapped model chunks (as returned by ``unwrap_model``).
    bridge : optional
        An mbridge ``Bridge`` instance.  When provided the mcore → HF name
        mapping is derived dynamically from the bridge, supporting any model
        architecture.  When *None*, adapter names are passed through with a
        ``base_model.model.`` prefix.

    Returns
    -------
    dict[str, torch.Tensor]
        Mapping from HF PEFT parameter names
        (e.g. ``base_model.model.model.layers.0.self_attn.qkv_proj.lora_A.weight``)
        to full-size tensors on CPU.
    """
    adapter_state: Dict[str, torch.Tensor] = {}

    for model_chunk in models:
        for name, module in model_chunk.named_modules():
            # --- CanonicalLoRA split adapters (Q/K/V, gate/up) ---
            if isinstance(module, (LoRALinearSplitQKV, LoRALinearSplitFC1UpGate)):
                for sub_name, sub_adapter in module.adapter.items():  # ModuleDict
                    if sub_adapter is None:
                        continue
                    _store_gathered_adapter(
                        adapter_state, sub_adapter, f"{name}.adapter.{sub_name}", bridge,
                    )

            # --- LoRATopKRouter ---
            elif isinstance(module, LoRATopKRouter):
                _store_gathered_adapter(adapter_state, module.adapter, f"{name}.adapter", bridge)

            # --- LinearAdapter (nn.Linear with LoRA, for ViT/projector) ---
            elif isinstance(module, LinearAdapter):
                # The LoRA matrices live directly on the module and are copied as-is.
                for weight_name in ("linear_in", "linear_out"):
                    hf_key = mcore_adapter_name_to_hf(
                        f"{name}.{weight_name}.weight", bridge=bridge,
                    )
                    adapter_state[hf_key] = getattr(module, weight_name).weight.data.cpu()

            # --- Standard LoRA ---
            elif isinstance(module, LoRALinear):
                _store_gathered_adapter(
                    adapter_state,
                    module.adapter,
                    f"{name}.adapter",
                    bridge,
                    out_stride=getattr(module.to_wrap, 'stride', 1),
                )

    return adapter_state
class LoRAMerge(PEFT):
    """
    Implements the LoRA weight merge for parameter-efficient fine-tuning.

    ``transform`` adds the scaled low-rank product ``alpha / dim * (B @ A)``
    of a wrapped module's adapter into the wrapped module's weight, in
    place, and returns the wrapper. Modules that are not LoRA wrappers are
    returned untouched.
    """

    @staticmethod
    def _compute_sub_delta(linear_out, linear_in, alpha, dim, base_device):
        """Compute the full (un-sharded) LoRA delta for a single sub-adapter.

        Gathers TP-sharded weights if TP > 1. Both ``linear_in`` and
        ``linear_out`` are gathered along dim 0.
        """
        tp_size = parallel_state.get_tensor_model_parallel_world_size()
        tp_group = parallel_state.get_tensor_model_parallel_group()

        lin_in_w = linear_in.weight.data.to(base_device)
        lin_out_w = linear_out.weight.data.to(base_device)

        if tp_size == 1:
            return (alpha / dim) * (lin_out_w @ lin_in_w)

        # Gather linear_in along dim 0 (ColumnParallel case)
        lin_in_list = [torch.empty_like(lin_in_w) for _ in range(tp_size)]
        dist.all_gather(lin_in_list, lin_in_w.contiguous(), group=tp_group)
        lin_in_full = torch.cat(lin_in_list, dim=0)

        # Gather linear_out along dim 0
        lin_out_list = [torch.empty_like(lin_out_w) for _ in range(tp_size)]
        dist.all_gather(lin_out_list, lin_out_w.contiguous(), group=tp_group)
        lin_out_full = torch.cat(lin_out_list, dim=0)

        return (alpha / dim) * (lin_out_full @ lin_in_full)

    def _sub_delta_or_none(self, sub_adapter, base_device):
        """Return the full delta of ``sub_adapter``, or None when it is absent."""
        if sub_adapter is None:
            return None
        return self._compute_sub_delta(
            sub_adapter.linear_out, sub_adapter.linear_in,
            sub_adapter.alpha, sub_adapter.dim, base_device,
        )

    @staticmethod
    def _local_shard(full_delta, tp_size, tp_rank, stride=1):
        """Select this TP rank's rows of a full delta.

        With ``stride > 1`` (and TP > 1) the delta is split into ``stride``
        parts, each part into ``tp_size`` chunks, and the ``tp_rank``-th
        chunk of every part is concatenated. Otherwise the rank takes one
        contiguous block of rows.
        """
        if tp_size > 1 and stride > 1:
            stride_chunks = full_delta.chunk(stride, dim=0)
            return torch.cat(
                [part.chunk(tp_size, dim=0)[tp_rank] for part in stride_chunks], dim=0,
            )
        total_rows = full_delta.shape[0]
        per_rank = total_rows // tp_size if tp_size > 1 else total_rows
        start = tp_rank * per_rank
        return full_delta[start:start + per_rank]

    @staticmethod
    def _interleave_qkv_full_delta(q_delta, k_delta, v_delta, config):
        """Interleave Q, K, V full deltas into Megatron QKV packed weight order.

        The fused QKV layout (from Megatron) is:
          for each head group i (without output gate):
            [Q_heads_per_group, K_1_head, V_1_head]
          for each head group i (with output gate):
            [Q_heads_per_group, G_heads_per_group(zeros), K_1_head, V_1_head]
        """
        head_num = config.num_attention_heads
        num_query_groups = config.num_query_groups
        head_size = config.kv_channels
        heads_per_group = head_num // num_query_groups
        output_gate = getattr(config, "attention_output_gate", False)

        q_reshaped = q_delta.reshape(head_num, head_size, -1)
        k_reshaped = k_delta.reshape(num_query_groups, head_size, -1)
        v_reshaped = v_delta.reshape(num_query_groups, head_size, -1)

        interleaved_parts = []
        for g in range(num_query_groups):
            q_group = q_reshaped[g * heads_per_group: (g + 1) * heads_per_group]
            k_group = k_reshaped[g: g + 1]
            v_group = v_reshaped[g: g + 1]
            interleaved_parts.append(q_group.reshape(-1, q_delta.shape[1]))
            if output_gate:
                # The gate rows receive no LoRA update: pad with zeros.
                interleaved_parts.append(torch.zeros_like(q_group.reshape(-1, q_delta.shape[1])))
            interleaved_parts.append(k_group.reshape(-1, q_delta.shape[1]))
            interleaved_parts.append(v_group.reshape(-1, q_delta.shape[1]))

        return torch.cat(interleaved_parts, dim=0)

    def merge(
        self,
        base_weight: torch.Tensor,
        linear_out: torch.Tensor,
        linear_in: torch.Tensor,
        alpha: int,
        dim: int,
        stride: int = 1,
    ) -> torch.Tensor:
        """
        Merges the LoRA adapter weights with the base model weights.
        Handles tensor parallelism by gathering sharded dimensions.

        For ColumnParallelLinear (e.g., linear_qkv, linear_fc1):
            - base_weight: (out_features/TP, in_features)
            - linear_in: (dim/TP, in_features)  <- Need to gather this
            - linear_out: (out_features/TP, dim)
            - Target: (out_features/TP, dim) @ (dim, in_features) = (out_features/TP, in_features)

        For RowParallelLinear (e.g., linear_proj, linear_fc2):
            - base_weight: (out_features, in_features/TP)
            - linear_in: (dim, in_features/TP)
            - linear_out: (out_features/TP, dim)  <- Need to gather this
            - Target: (out_features, dim) @ (dim, in_features/TP) = (out_features, in_features/TP)

        For strided ColumnParallelLinear (gated MLP linear_fc1 with stride > 1):
            The base weight has an interleaved layout across TP ranks (due to stride).
            The adapter's linear_out is a *non-strided* ColumnParallelLinear, so its
            TP sharding is a simple contiguous chunk — which does NOT match the
            interleaved layout of the base weight. This function handles this by
            gathering both linear_in and linear_out, computing the full delta, and
            then interleaving the delta chunks to match the base weight's layout.

        Args:
            base_weight (torch.Tensor): The base model weights.
            linear_out (torch.Tensor): LoRA's B matrix.
            linear_in (torch.Tensor): LoRA's A matrix.
            alpha (int): Weighting factor for the low-rank projection.
            dim (int): Dimension of the low-rank projection space.
            stride (int): Stride of the base ColumnParallelLinear (default: 1).
                Use stride=2 for gated MLP (linear_fc1 with GLU).

        Returns:
            torch.Tensor: The merged weights (a new tensor; inputs are not modified).
        """
        tp_size = parallel_state.get_tensor_model_parallel_world_size()
        tp_rank = parallel_state.get_tensor_model_parallel_rank()

        if tp_size == 1:
            # No tensor parallelism, simple multiplication
            lora_weight = alpha / dim * (linear_out @ linear_in)
            return base_weight + lora_weight

        tp_group = parallel_state.get_tensor_model_parallel_group()

        # Case 1: ColumnParallelLinear - linear_in is sharded on dim 0
        # linear_in: (dim/TP, in_features), linear_out: (out_features/TP, dim)
        if linear_in.shape[0] * tp_size == dim and linear_out.shape[1] == dim:
            # Gather linear_in along dimension 0 to get full dim
            linear_in_list = [torch.empty_like(linear_in) for _ in range(tp_size)]
            dist.all_gather(linear_in_list, linear_in, group=tp_group)
            linear_in_full = torch.cat(linear_in_list, dim=0)

            # adapter linear_out is non-strided ColumnParallel (contiguous chunk);
            # base weight may be strided (interleaved). For stride>1, we need to
            # gather linear_out fully and interleave the delta for this TP rank.
            if stride > 1:
                # Gather linear_out across TP to get the full B matrix
                linear_out_list = [torch.empty_like(linear_out) for _ in range(tp_size)]
                dist.all_gather(linear_out_list, linear_out, group=tp_group)
                linear_out_full = torch.cat(linear_out_list, dim=0)

                # The gathered B has interleaved layout because each rank's local
                # B contains rows for ALL stride components (gate+up). Permute to
                # sequential [gate_all, up_all] before computing the full delta.
                linear_out_full = _deinterleave_gathered_lora_b(
                    linear_out_full, stride, tp_size,
                )

                # Full delta in sequential layout: [gate_delta, up_delta]
                full_delta = alpha / dim * (linear_out_full @ linear_in_full)

                # For strided layout, this rank takes the tp_rank-th chunk from
                # each stride part and concatenates them
                lora_weight = self._local_shard(full_delta, tp_size, tp_rank, stride)
            else:
                # Non-strided: simple (out_features/TP, dim) @ (dim, in_features)
                lora_weight = alpha / dim * (linear_out @ linear_in_full)

        # Case 2: RowParallelLinear - linear_out is sharded on dim 0
        # linear_in: (dim, in_features/TP), linear_out: (out_features/TP, dim)
        elif linear_out.shape[0] * tp_size == base_weight.shape[0]:
            # Gather linear_out along dimension 0 to get full out_features
            linear_out_list = [torch.empty_like(linear_out) for _ in range(tp_size)]
            dist.all_gather(linear_out_list, linear_out, group=tp_group)
            linear_out_full = torch.cat(linear_out_list, dim=0)

            # Multiply: (out_features, dim) @ (dim, in_features/TP)
            lora_weight = alpha / dim * (linear_out_full @ linear_in)

        else:
            # Fallback: no gathering needed or already full-size
            lora_weight = alpha / dim * (linear_out @ linear_in)

        return base_weight + lora_weight

    @torch.no_grad()
    def transform(self, module: nn.Module, name: Optional[str] = None, prefix: Optional[str] = None) -> nn.Module:
        """
        Merges the LoRA adapter with the base model weights.

        Supports standard LoRA (``LoRALinear``), ``LoRATopKRouter`` and
        canonical LoRA (``LoRALinearSplitQKV``, ``LoRALinearSplitFC1UpGate``).
        For canonical LoRA, a sub-adapter that is ``None`` contributes a zero
        delta so the fused delta keeps the layout of the fused base weight;
        when every sub-adapter is ``None`` the module is left unchanged.

        Args:
            module (nn.Module): The module to apply LoRA merge to.
            name (str, optional): Name of the module to merge. Unused. Defaults to None.
            prefix (str, optional): Prefix for the module name. Unused. Defaults to None.

        Returns:
            nn.Module: The same module, with the LoRA delta added to the
            wrapped weight(s) when it is a supported LoRA wrapper.
        """
        tp_size = parallel_state.get_tensor_model_parallel_world_size()
        tp_rank = parallel_state.get_tensor_model_parallel_rank()

        # --- CanonicalLoRA: LoRALinearSplitQKV ---
        if isinstance(module, LoRALinearSplitQKV):
            base_device = module.to_wrap.weight.device
            config = module.to_wrap.config

            q_delta = self._sub_delta_or_none(module.adapter.adapter_q, base_device)
            k_delta = self._sub_delta_or_none(module.adapter.adapter_k, base_device)
            v_delta = self._sub_delta_or_none(module.adapter.adapter_v, base_device)

            present = [d for d in (q_delta, k_delta, v_delta) if d is not None]
            if not present:
                # Nothing to merge.
                return module

            # Zero-fill missing projections so the interleaved result has the
            # row count and ordering of the fused QKV weight.
            reference = present[0]
            q_rows = config.num_attention_heads * config.kv_channels
            kv_rows = config.num_query_groups * config.kv_channels
            if q_delta is None:
                q_delta = reference.new_zeros((q_rows, reference.shape[1]))
            if k_delta is None:
                k_delta = reference.new_zeros((kv_rows, reference.shape[1]))
            if v_delta is None:
                v_delta = reference.new_zeros((kv_rows, reference.shape[1]))

            # Interleave into fused Megatron QKV layout
            full_qkv_delta = self._interleave_qkv_full_delta(q_delta, k_delta, v_delta, config)

            # Take TP-rank's contiguous shard of the full fused delta
            per_rank_delta = self._local_shard(full_qkv_delta, tp_size, tp_rank)

            module.to_wrap.weight.data = module.to_wrap.weight.data + per_rank_delta.to(base_device)
            return module

        # --- CanonicalLoRA: LoRALinearSplitFC1UpGate ---
        if isinstance(module, LoRALinearSplitFC1UpGate):
            base_device = module.to_wrap.weight.device
            stride = getattr(module.to_wrap, 'stride', 1)

            gate_delta = self._sub_delta_or_none(module.adapter.adapter_gate, base_device)
            up_delta = self._sub_delta_or_none(module.adapter.adapter_up, base_device)

            if gate_delta is None and up_delta is None:
                # Nothing to merge.
                return module
            # Gate and up halves have the same shape, so the present one
            # serves as the template for a missing one.
            if gate_delta is None:
                gate_delta = torch.zeros_like(up_delta)
            if up_delta is None:
                up_delta = torch.zeros_like(gate_delta)

            # Stack gate + up → (2*ffn_hidden_size, in_features)
            full_fc1_delta = torch.cat([gate_delta, up_delta], dim=0)

            # Apply stride interleaving for the fused gate/up layout when needed
            per_rank_delta = self._local_shard(full_fc1_delta, tp_size, tp_rank, stride)

            module.to_wrap.weight.data = module.to_wrap.weight.data + per_rank_delta.to(base_device)
            return module

        # --- LoRATopKRouter ---
        if isinstance(module, LoRATopKRouter):
            base_weight = module.to_wrap.weight
            base_device = base_weight.device
            adapter = module.adapter
            alpha = adapter.alpha
            dim = adapter.dim

            lin_in_w = adapter.linear_in.weight.to(base_device)
            lin_out_w = adapter.linear_out.weight.to(base_device)

            if tp_size > 1:
                tp_group = parallel_state.get_tensor_model_parallel_group()
                lin_in_list = [torch.empty_like(lin_in_w) for _ in range(tp_size)]
                dist.all_gather(lin_in_list, lin_in_w, group=tp_group)
                lin_in_full = torch.cat(lin_in_list, dim=0)

                lin_out_list = [torch.empty_like(lin_out_w) for _ in range(tp_size)]
                dist.all_gather(lin_out_list, lin_out_w, group=tp_group)
                lin_out_full = torch.cat(lin_out_list, dim=0)
            else:
                lin_in_full = lin_in_w
                lin_out_full = lin_out_w

            lora_delta = alpha / dim * (lin_out_full @ lin_in_full)
            module.to_wrap.weight.data = base_weight + lora_delta
            return module

        # --- Standard LoRA ---
        if not isinstance(module, LoRALinear):
            return module

        # Detect stride for strided ColumnParallelLinear (gated MLP)
        stride = getattr(module.to_wrap, 'stride', 1)

        if hasattr(module.to_wrap, "weight"):
            base_device = module.to_wrap.weight.device
            merged_weight = self.merge(
                module.to_wrap.weight,
                module.adapter.linear_out.weight.to(base_device),
                module.adapter.linear_in.weight.to(base_device),
                module.adapter.alpha,
                module.adapter.dim,
                stride=stride,
            )
            module.to_wrap.weight.data = merged_weight
        else:  # TE Grouped Linear
            # One weight per GEMM (weight0, weight1, ...); the same adapter
            # delta is merged into each of them.
            for i in range(module.to_wrap.num_gemms):
                base_device = getattr(module.to_wrap, f"weight{i}").device
                merged_weight = self.merge(
                    getattr(module.to_wrap, f"weight{i}"),
                    module.adapter.linear_out.weight.to(base_device),
                    module.adapter.linear_in.weight.to(base_device),
                    module.adapter.alpha,
                    module.adapter.dim,
                    stride=stride,
                )
                getattr(module.to_wrap, f"weight{i}").data = merged_weight
        return module
+ """ + if not hasattr(tensor, "mbridge_cpu_data"): + tensor.mbridge_cpu_data = torch.empty_like( + tensor.data, device="cpu", pin_memory=True + ) + tensor.mbridge_cpu_data.copy_(tensor.data, non_blocking=True) + + +def _restore_tensor_from_cpu(tensor: torch.Tensor) -> None: + """Restore a GPU tensor from its pinned CPU backup.""" + tensor.data.copy_(tensor.mbridge_cpu_data, non_blocking=True) + + +@contextmanager +def lora_merged(models): + """Context manager that temporarily merges LoRA into base weights. + + On enter: for each LoRA-wrapped module (``LoRALinear``, + ``LoRALinearSplitQKV``, ``LoRALinearSplitFC1UpGate``, ``LinearAdapter``, + ``LoRATopKRouter``): + (1) backs up the base weight to pinned CPU memory, (2) merges the LoRA + delta in-place on GPU, and (3) swaps the wrapper with its ``to_wrap`` in + the parent so that ``named_parameters()`` yields clean names. + On exit, restores weights from CPU backup (no floating-point drift). + + The CPU backup buffer is cached on each tensor as ``tensor.mbridge_cpu_data`` + (pinned memory) so that repeated saves reuse the same allocation. + + Parameters + ---------- + models : list[nn.Module] + Unwrapped model chunks (as returned by ``unwrap_model``). 
@contextmanager
def lora_merged(models):
    """Context manager that temporarily merges LoRA into base weights.

    On enter: for each LoRA-wrapped module (``LoRALinear``,
    ``LoRALinearSplitQKV``, ``LoRALinearSplitFC1UpGate``, ``LinearAdapter``,
    ``LoRATopKRouter``):
    (1) backs up the base weight to CPU memory, (2) merges the LoRA
    delta in-place, and (3) swaps the wrapper with its ``to_wrap`` in
    the parent so that ``named_parameters()`` yields clean names.
    ``LinearAdapter`` has no ``to_wrap``; its ``linear_in``/``linear_out``
    sub-modules are hidden instead. On exit, restores weights from the CPU
    backup (no floating-point drift) and puts wrappers and sub-modules back.

    Setup runs inside the same ``try``/``finally`` as the body, so if
    merging fails part-way, everything done up to that point is rolled back
    before the exception propagates.

    The CPU backup buffer is cached on each tensor as ``tensor.mbridge_cpu_data``
    so that repeated saves reuse the same allocation.

    Parameters
    ----------
    models : list[nn.Module]
        Unwrapped model chunks (as returned by ``unwrap_model``).
    """
    merger = LoRAMerge()
    weight_restore_list = []  # tensors to restore from mbridge_cpu_data
    module_swaps = []  # (parent, attr_name, wrapper) to put back on exit
    linear_adapter_backups = []  # (LinearAdapter, {sub-module name: module})

    _ADAPTER_WRAPPER_TYPES = (LoRALinear, LoRALinearSplitQKV, LoRALinearSplitFC1UpGate, LoRATopKRouter)

    def _wait_for_copies():
        # Backup/restore copies are issued non-blocking; wait for them.
        if torch.cuda.is_available():
            torch.cuda.current_stream().synchronize()

    try:
        for model_chunk in models:
            # Collect (parent, attr_name, lora_module) before modifying structure
            swap_list = []
            linear_adapter_list = []
            all_modules = dict(model_chunk.named_modules())
            for name, module in all_modules.items():
                if isinstance(module, _ADAPTER_WRAPPER_TYPES):
                    parent_name, dot, attr_name = name.rpartition(".")
                    parent = all_modules[parent_name] if dot else model_chunk
                    swap_list.append((parent, attr_name, module))
                elif isinstance(module, LinearAdapter):
                    linear_adapter_list.append(module)

            for parent, attr_name, lora_module in swap_list:
                base = lora_module.to_wrap
                if isinstance(lora_module, LoRATopKRouter) or hasattr(base, "weight"):
                    base_weights = [base.weight]
                else:
                    # TE Grouped Linear: backup all weights
                    base_weights = [getattr(base, f"weight{i}") for i in range(base.num_gemms)]

                # Backup the original base weight(s) before merging
                for w in base_weights:
                    _backup_tensor_to_cpu(w)
                    weight_restore_list.append(w)

                # Merge using LoRAMerge.transform() which handles all adapter types
                merger.transform(lora_module)

                # Replace wrapper with to_wrap in parent so parameter names are clean
                setattr(parent, attr_name, lora_module.to_wrap)
                module_swaps.append((parent, attr_name, lora_module))

            # Handle LinearAdapter (extends nn.Linear directly, no to_wrap)
            for la_module in linear_adapter_list:
                if not getattr(la_module, '_adapter_enabled', True):
                    continue
                _backup_tensor_to_cpu(la_module.weight)
                weight_restore_list.append(la_module.weight)
                scale = la_module.scale
                delta = la_module.linear_out.weight.data @ la_module.linear_in.weight.data
                la_module.weight.data.add_(scale * delta)
                # Hide LoRA sub-modules so they don't appear in named_parameters().
                # Register the record first so a partial hide is still undone.
                hidden = {}
                linear_adapter_backups.append((la_module, hidden))
                for sub_name in ('linear_in', 'linear_out'):
                    hidden[sub_name] = la_module._modules.pop(sub_name)

        # Ensure all async CPU copies finish before entering the save section
        _wait_for_copies()

        yield
    finally:
        # Restore wrapper modules in parent
        for parent, attr_name, lora_module in module_swaps:
            setattr(parent, attr_name, lora_module)
        # Restore original weight data from the CPU backup
        for w in weight_restore_list:
            _restore_tensor_from_cpu(w)
        # Restore LinearAdapter sub-modules
        for la_module, hidden in linear_adapter_backups:
            la_module._modules.update(hidden)
        _wait_for_copies()
class TELinearAdapter(te.Linear):
    """
    TELinear + LoRA, maintains ckpts structure (i.e. Linear's weight/bias remain at the same FQN)

    ``_init_adapter`` is a static method so that the same code can be used both
    when constructing a TELinearAdapter and when monkey-patching an existing
    module (see ``patch_linear_module``).

    Args:
        orig_linear: The linear module to augment.
        dim: LoRA's dimension (in_features -> dim -> out_features).
        alpha: LoRA's scaling alpha.
        dropout: Dropout probability (default: 0.0).
        dropout_position: Where to apply dropout relative to LoRA (choices: ['pre', 'post'], default='pre').
        lora_A_init_method: Initialization method for lora_A (choices: ['xavier', 'uniform']).
        lora_dtype: Weight's dtype, by default will use orig_linear's but if they
            are quantized weights (e.g. 4bit) needs to be specified explicitly.
    """

    def __init__(
        self,
        orig_linear: "te.Linear",
        dim: int = 8,
        alpha: int = 32,
        dropout: float = 0.0,
        dropout_position: Literal["pre", "post"] = "pre",
        lora_A_init_method: Literal["xavier", "uniform"] = "xavier",
        lora_dtype: Optional[torch.dtype] = None,
    ) -> None:
        """Initialize TELinearAdapter by copying from original TELinear and adding LoRA components.

        Args:
            orig_linear: The original TELinear module to adapt.
            dim: LoRA rank dimension.
            alpha: LoRA scaling factor.
            dropout: Dropout probability.
            dropout_position: When to apply dropout ('pre' or 'post' LoRA computation).
            lora_A_init_method: Initialization method for LoRA matrix A.
            lora_dtype: Data type for LoRA weights.
        """
        # Exact class match on purpose: subclasses of te.Linear are rejected.
        assert orig_linear.__class__ == te.Linear
        # TELinear has bias set to empty tensor
        has_bias = orig_linear.bias is not None and orig_linear.bias.shape[0] != 0
        super(TELinearAdapter, self).__init__(
            in_features=orig_linear.in_features,
            out_features=orig_linear.out_features,
            bias=has_bias,
            device=orig_linear.weight.device,
            params_dtype=orig_linear.weight.dtype,
        )
        # copy weights
        self.weight.data.copy_(orig_linear.weight.data)
        if has_bias:
            self.bias.data.copy_(orig_linear.bias.data)
        # initialize the adapter
        TELinearAdapter._init_adapter(
            self,
            dim=dim,
            alpha=alpha,
            dropout=dropout,
            dropout_position=dropout_position,
            lora_A_init_method=lora_A_init_method,
            lora_dtype=lora_dtype,
        )
        self._adapter_enabled = True

    def enable_adapter_layers(self) -> None:
        """Enable the adapter layers, allowing them to contribute to the forward pass output."""
        self._adapter_enabled = True

    def disable_adapter_layers(self) -> None:
        """Disable the adapter layers, making the forward pass return only the base module output."""
        self._adapter_enabled = False

    # NOTE: ``@staticmethod`` must be the outermost decorator.  With the order
    # reversed, ``torch.no_grad`` would wrap the staticmethod object itself
    # (not callable before Python 3.10) and the class attribute would end up
    # being a plain function instead of a static method.
    @staticmethod
    @torch.no_grad()
    def _init_adapter(
        obj: Union["TELinearAdapter", nn.Module],
        dim: int = 8,
        alpha: int = 32,
        dropout: float = 0.0,
        dropout_position: Literal["pre", "post"] = "pre",
        lora_A_init_method: Literal["xavier", "uniform"] = "xavier",
        lora_dtype: Optional[torch.dtype] = None,
    ) -> None:
        """Add LoRA weights to obj. The obj is either a TELinearAdapter or an nn.Module (when monkey-patching).

        Args:
            obj: Input module to adapt (TELinearAdapter or nn.Module).
            dim: LoRA's dimension (in_features -> dim -> out_features).
            alpha: LoRA's scaling alpha.
            dropout: Dropout probability (default: 0.0).
            dropout_position: Where to apply dropout relative to LoRA (choices: ['pre', 'post'], default='pre').
            lora_A_init_method: Initialization method for lora_A; 'xavier' uses
                Xavier-uniform, any other value uses Kaiming-uniform.
            lora_dtype: Weight's dtype, by default will use obj.weight's but if they
                are quantized weights (e.g. 4bit) needs to be specified explicitly.
        """
        obj.dim = dim
        obj.alpha = alpha
        obj.scale = alpha / dim

        # Freeze original weights
        device = obj.weight.device
        obj.weight.requires_grad = False
        if obj.bias is not None:
            obj.bias.requires_grad = False

        in_features = obj.in_features
        out_features = obj.out_features
        dtype = lora_dtype or obj.weight.dtype

        obj.linear_in = nn.Linear(in_features, dim, bias=False, dtype=dtype, device=device)
        obj.linear_out = nn.Linear(dim, out_features, bias=False, dtype=dtype, device=device)
        if lora_A_init_method == "xavier":
            torch.nn.init.xavier_uniform_(obj.linear_in.weight.data)
        else:
            nn.init.kaiming_uniform_(obj.linear_in.weight.data, a=math.sqrt(5))
        # Zero-init of the output projection makes the initial LoRA delta zero.
        obj.linear_out.weight.data.fill_(0)
        if dropout > 0.0:
            obj.dropout = nn.Dropout(p=dropout)
        else:
            obj.dropout = nn.Identity()
        assert dropout_position in ["pre", "post"], dropout_position
        obj.dropout_position = dropout_position

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass combining TELinear output with LoRA adaptation.

        Args:
            x: Input tensor.

        Returns:
            Base linear output when the adapter is disabled, otherwise the base
            output plus the scaled LoRA output.
        """
        # pylint: disable=C0115,C0116
        res = super(TELinearAdapter, self).forward(x)

        if not self._adapter_enabled:
            return res

        if self.dropout_position == "pre":
            x = self.dropout(x)
        # LoRA fwd is performed in original precision regardless of FP8 enabled
        lora_res = self.linear_out(self.linear_in(x))
        lora_res = lora_res * self.scale
        if self.dropout_position == "post":
            lora_res = self.dropout(lora_res)
        return res + lora_res
kwargs["accumulate_into_main_grad"] = accumulate_into_main_grad + + # Construct fused branches + main_branch = self._make_main_branch(**kwargs) + lora_branch = self._make_lora_branch(**kwargs) + + # Get submodule forward hooks + forward_pre_hooks = [] + forward_post_hooks = [] + for submodule in self.modules(): + for hook in submodule._forward_pre_hooks.values(): + forward_pre_hooks.append((submodule, hook)) + for hook in submodule._forward_hooks.values(): + forward_post_hooks.append((submodule, hook)) + + # Attempt to emulate submodule forward hooks if needed + # Note: Assume hooks do not interact with submodule inputs + # or outputs since they are internal to the op fuser. + if forward_pre_hooks: + + def forward_pre_hook(module, *_) -> None: + for submodule, hook in forward_pre_hooks: + # Assume that hook does not interact with + # input + hook(submodule, None) + + main_branch.register_forward_pre_hook(forward_pre_hook) + if forward_post_hooks: + + def forward_post_hook(module, *_) -> None: + for submodule, hook in forward_post_hooks: + # Assume that hook does not interact with + # input or output + hook(submodule, None, None) + + lora_branch.register_forward_hook(forward_post_hook) + + return main_branch, lora_branch + + def _make_main_branch( + self, + *, + in_features: int, + out_features: int, + tensor_parallel_mode: Optional[str], + tensor_parallel_group: Optional[torch.distributed.ProcessGroup], + sequence_parallel: bool, + accumulate_into_main_grad: bool, + ) -> te.ops.Sequential: + """Construct fused module for main branch (norm + fork + linear)""" + + # Check wrapped linear class + if not isinstance(self.to_wrap, (te.Linear, te.LayerNormLinear, torch.nn.Linear)): + raise ValueError(f"Unsupported class for wrapped linear ({self.to_wrap.__class__.__name__})") + + # Ops in main branch + main_branch = te.ops.Sequential() + + # Norm op + if isinstance(self.to_wrap, te.LayerNormLinear): + norm_type = self.to_wrap.normalization + kwargs = { + "eps": 
self.to_wrap.eps, + "device": "meta", + "dtype": self.to_wrap.layer_norm_weight.dtype, + "zero_centered_gamma": self.to_wrap.zero_centered_gamma, + } + op = None + if norm_type == "LayerNorm": + op = te.ops.LayerNorm(in_features, **kwargs) + op.weight = self.to_wrap.layer_norm_weight + op.bias = self.to_wrap.layer_norm_bias + elif norm_type == "RMSNorm": + op = te.ops.RMSNorm(in_features, **kwargs) + op.weight = self.to_wrap.layer_norm_weight + else: + raise ValueError(f"Unsupported normalization ({norm_type})") + main_branch.append(op) + main_branch.append(te.ops.Quantize(forward=True, backward=False)) + + # Fork to LoRA branch + # Note: GEMM with beta=1 in backward pass + main_branch.append(te.ops.MakeExtraOutput(in_place=True)) + + # Linear op + weight = self.to_wrap.weight + bias = self.to_wrap.bias + if isinstance(bias, torch.Tensor) and bias.numel() == 0: + bias = None + op = te.ops.Linear( + in_features, + out_features, + bias=bias is not None, + device="meta", + dtype=weight.dtype, + tensor_parallel_mode=tensor_parallel_mode, + tensor_parallel_group=tensor_parallel_group, + sequence_parallel=sequence_parallel, + accumulate_into_main_grad=accumulate_into_main_grad, + ) + op.weight = weight + op.bias = bias + main_branch.append(op) + + return main_branch + + def _make_lora_branch( + self, + *, + in_features: int, + out_features: int, + tensor_parallel_mode: Optional[str], + tensor_parallel_group: Optional[torch.distributed.ProcessGroup], + sequence_parallel: bool, + accumulate_into_main_grad: bool, + ) -> te.ops.Sequential: + """Construct fused module for LoRA branch (linear_in + linear_out + add)""" + + from mbridge.peft.utils import ParallelLinearAdapter + + # Extract params from LoRA adapter + linear_in_weight = None + linear_out_weight = None + lora_dim = None + dropout = 0 + dropout_position = None + scale = None + if isinstance(self.adapter, (LinearAdapter, TELinearAdapter)): + linear_in_weight = self.adapter.linear_in.weight + linear_out_weight = 
self.adapter.linear_out.weight + lora_dim = linear_out_weight.size(1) + dropout = getattr(self.adapter.dropout, "p", 0.0) + dropout_position = self.adapter.dropout_position + scale = self.adapter.scale + elif isinstance(self.adapter, ParallelLinearAdapter): + linear_in_weight = self.adapter.linear_in.weight + linear_out_weight = self.adapter.linear_out.weight + lora_dim = linear_out_weight.size(1) + dropout = getattr(self.adapter.dropout, "p", 0.0) + dropout_position = self.adapter.dropout_position + scale = self.adapter.alpha / self.adapter.dim + else: + raise ValueError(f"Unsupported class for LoRA adapter ({self.adapter.__class__.__name__})") + + # Ops in LoRA branch + lora_branch = te.ops.Sequential() + + # LoRA pre-processing + if dropout > 0 and dropout_position == "pre": + lora_branch.append(te.ops.Dropout(dropout)) + + # LoRA A linear op + op = te.ops.Linear( + in_features, + lora_dim, + bias=False, + device="meta", + dtype=linear_in_weight.dtype, + tensor_parallel_mode=tensor_parallel_mode, + tensor_parallel_group=tensor_parallel_group, + sequence_parallel=sequence_parallel, + accumulate_into_main_grad=accumulate_into_main_grad, + ) + op.weight = linear_in_weight + lora_branch.append(op) + + # LoRA B linear op + if tensor_parallel_mode == "column": + # All-gather along dim -1 + raise NotImplementedError("Column tensor parallelism is not yet supported") + op = te.ops.Linear( + lora_dim, + out_features, + bias=False, + device="meta", + dtype=linear_out_weight.dtype, + tensor_parallel_mode=None if tensor_parallel_mode is None else "column", + tensor_parallel_group=tensor_parallel_group, + sequence_parallel=False, + accumulate_into_main_grad=accumulate_into_main_grad, + ) + op.weight = linear_out_weight + lora_branch.append(op) + + # LoRA post-processing + if scale != 1: + lora_branch.append(te.ops.ConstantScale(scale)) + if dropout > 0 and dropout_position == "post": + lora_branch.append(te.ops.Dropout(dropout)) + if tensor_parallel_mode == "row": + # 
class LinearAdapter(nn.Linear):
    """
    Linear + LoRA, maintains ckpts structure (i.e. Linear's weight/bias remain at the same FQN)

    ``_init_adapter`` is a static method so that the same code can be used both
    when constructing a LinearAdapter and when monkey-patching an existing
    ``nn.Linear`` (see ``patch_linear_module``).

    Args:
        orig_linear: The linear module to augment.
        dim: LoRA's dimension (in_features -> dim -> out_features).
        alpha: LoRA's scaling alpha.
        dropout: Dropout probability (default: 0.0).
        dropout_position: Where to apply dropout relative to LoRA (choices: ['pre', 'post'], default='pre').
        lora_A_init_method: Initialization method for lora_A (choices: ['xavier', 'uniform']).
        lora_dtype: Weight's dtype, by default will use orig_linear's but if they
            are quantized weights (e.g. 4bit) needs to be specified explicitly.
    """

    def __init__(
        self,
        orig_linear: nn.Linear,
        dim: int = 8,
        alpha: int = 32,
        dropout: float = 0.0,
        dropout_position: Literal["pre", "post"] = "pre",
        lora_A_init_method: Literal["xavier", "uniform"] = "xavier",
        lora_dtype: Optional[torch.dtype] = None,
    ) -> None:
        """Initialize LinearAdapter by copying from original Linear and adding LoRA components.

        Args:
            orig_linear: The original Linear module to adapt.
            dim: LoRA rank dimension.
            alpha: LoRA scaling factor.
            dropout: Dropout probability.
            dropout_position: When to apply dropout ('pre' or 'post' LoRA computation).
            lora_A_init_method: Initialization method for LoRA matrix A.
            lora_dtype: Data type for LoRA weights.
        """
        assert isinstance(orig_linear, nn.Linear)
        super(LinearAdapter, self).__init__(
            in_features=orig_linear.in_features,
            out_features=orig_linear.out_features,
            bias=orig_linear.bias is not None,
            device=orig_linear.weight.device,
            dtype=orig_linear.weight.dtype,
        )
        # copy weights
        self.weight.data.copy_(orig_linear.weight.data)
        if orig_linear.bias is not None:
            self.bias.data.copy_(orig_linear.bias.data)
        # initialize the adapter
        LinearAdapter._init_adapter(
            self,
            dim=dim,
            alpha=alpha,
            dropout=dropout,
            dropout_position=dropout_position,
            lora_A_init_method=lora_A_init_method,
            lora_dtype=lora_dtype,
        )
        self._adapter_enabled = True

    def enable_adapter_layers(self) -> None:
        """Enable the adapter layers, allowing them to contribute to the forward pass output."""
        self._adapter_enabled = True

    def disable_adapter_layers(self) -> None:
        """Disable the adapter layers, making the forward pass return only the base module output."""
        self._adapter_enabled = False

    # NOTE: ``@staticmethod`` must be the outermost decorator.  With the order
    # reversed, ``torch.no_grad`` would wrap the staticmethod object itself
    # (not callable before Python 3.10) and the class attribute would end up
    # being a plain function instead of a static method.
    @staticmethod
    @torch.no_grad()
    def _init_adapter(
        obj: Union["LinearAdapter", nn.Module],
        dim: int = 8,
        alpha: int = 32,
        dropout: float = 0.0,
        dropout_position: Literal["pre", "post"] = "pre",
        lora_A_init_method: Literal["xavier", "uniform"] = "xavier",
        lora_dtype: Optional[torch.dtype] = None,
    ) -> None:
        """Add LoRA weights to obj. The obj is either a LinearAdapter or an nn.Module (when monkey-patching).

        Args:
            obj: Input module to adapt (LinearAdapter or nn.Module).
            dim: LoRA's dimension (in_features -> dim -> out_features).
            alpha: LoRA's scaling alpha.
            dropout: Dropout probability (default: 0.0).
            dropout_position: Where to apply dropout relative to LoRA (choices: ['pre', 'post'], default='pre').
            lora_A_init_method: Initialization method for lora_A; 'xavier' uses
                Xavier-uniform, any other value uses Kaiming-uniform.
            lora_dtype: Weight's dtype, by default will use obj.weight's but if they
                are quantized weights (e.g. 4bit) needs to be specified explicitly.
        """
        obj.dim = dim
        obj.alpha = alpha
        obj.scale = alpha / dim

        # Freeze original weights
        device = obj.weight.device
        obj.weight.requires_grad = False
        if obj.bias is not None:
            obj.bias.requires_grad = False

        in_features = obj.in_features
        out_features = obj.out_features
        dtype = lora_dtype or obj.weight.dtype

        obj.linear_in = nn.Linear(in_features, dim, bias=False, dtype=dtype, device=device)
        obj.linear_out = nn.Linear(dim, out_features, bias=False, dtype=dtype, device=device)
        if lora_A_init_method == "xavier":
            torch.nn.init.xavier_uniform_(obj.linear_in.weight.data)
        else:
            nn.init.kaiming_uniform_(obj.linear_in.weight.data, a=math.sqrt(5))
        # Zero-init of the output projection makes the initial LoRA delta zero.
        obj.linear_out.weight.data.fill_(0)
        if dropout > 0.0:
            obj.dropout = nn.Dropout(p=dropout)
        else:
            obj.dropout = nn.Identity()
        assert dropout_position in ["pre", "post"], dropout_position
        obj.dropout_position = dropout_position

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass combining Linear output with LoRA adaptation.

        Args:
            x: Input tensor.

        Returns:
            Base linear output when the adapter is disabled, otherwise the base
            output plus the scaled LoRA output.
        """
        # pylint: disable=C0115,C0116
        # If LinearAdapter is used to monkey-patch a nn.Linear module, we want to use nn.Linear's
        # forward in the case where it uses quantized weights. We store a reference to nn.Linear's
        # forward in `super_fwd` attribute. If the attribute does not exist we do the usual linear.
        if (fwd := getattr(self, "super_fwd", None)) is not None:
            assert fwd != self.forward
            res = fwd(x)
        else:
            res = torch.nn.functional.linear(x, self.weight, self.bias)

        if not self._adapter_enabled:
            return res

        if self.dropout_position == "pre":
            x = self.dropout(x)
        lora_res = self.linear_out(self.linear_in(x))
        lora_res = lora_res * self.scale
        if self.dropout_position == "post":
            lora_res = self.dropout(lora_res)
        return res + lora_res
+ dropout: Dropout probability. Defaults to 0.0. + dropout_position: Location to apply dropout wrt LoRA. + Defaults to 'pre' (choices: 'pre', 'post'). + lora_A_init_method: LoRA_A initialization method. Defaults to 'xavier'. + lora_dtype: LoRA weights' dtype. By default will use orig_linear's dtype + but orig_linear might use non-trainable dtype (e.g., 4bit), in which case the user must + specify the dtype manually. Defaults to None. + + Returns: + The monkey-patched (nn.Linear + LoRA) nn.Module. + + Raises: + NotImplementedError: If orig_linear is not nn.Linear or te.Linear. + AssertionError: If orig_linear already has super_fwd attribute. + """ + assert isinstance(orig_linear, nn.Linear) or (orig_linear.__class__ == te.Linear) + assert not hasattr(orig_linear, "super_fwd"), orig_linear.super_fwd + + if isinstance(orig_linear, nn.Linear): + LinearAdapter._init_adapter(orig_linear, dim, alpha, dropout, dropout_position, lora_A_init_method, lora_dtype) + cls = orig_linear.__class__ + new_cls = type("PatchedLinearAdapter", (LinearAdapter, cls), {}) + elif orig_linear.__class__ == te.Linear: + TELinearAdapter._init_adapter( + orig_linear, dim, alpha, dropout, dropout_position, lora_A_init_method, lora_dtype + ) + cls = orig_linear.__class__ + new_cls = type("PatchedTELinearAdapter", (TELinearAdapter, cls), {}) + else: + raise NotImplementedError("Expected isinstance(orig_linear, (nn.Linear, te.Linear))") + + # If the model uses quantized weights, we want to use orig_linear's forward + if ( + HAVE_BNB + and getattr(orig_linear, "quant_state", None) is not None + and orig_linear.quant_state.__class__ == bitsandbytes.functional.QuantState + ): + orig_linear.super_fwd = orig_linear.forward + + orig_linear.__class__ = new_cls + return orig_linear diff --git a/mbridge/peft/module_matcher.py b/mbridge/peft/module_matcher.py new file mode 100644 index 0000000..cd48a40 --- /dev/null +++ b/mbridge/peft/module_matcher.py @@ -0,0 +1,105 @@ +# Adapted from NVIDIA Megatron-Bridge + 
@dataclass
class ModuleMatcher:
    """
    Decides which modules of a model an adapter (e.g. LoRA) should be applied to.

    Args:
        target_modules (List[str], optional): Module names or wildcard patterns to match.
            Defaults to ['linear_qkv', 'linear_proj', 'linear_fc1', 'linear_fc2'].
            Patterns may contain wildcards, e.g.
            target_modules=['*.layers.0.*.linear_qkv', '*.layers.1.*.linear_qkv'] selects
            linear_qkv on the first two layers only.
        exclude_modules (List[str], optional): Names or wildcard patterns to skip. Only
            consulted when both `canonical_mapping` and `target_modules` are empty.
        canonical_mapping (Dict[str, Set], optional): Patterns (the dict keys) that take
            precedence over `target_modules` when non-empty.
    """

    target_modules: List[str] = field(
        default_factory=lambda: ["linear_qkv", "linear_proj", "linear_fc1", "linear_fc2"]
    )
    exclude_modules: List[str] = field(default_factory=list)
    canonical_mapping: Dict[str, Set] = field(default_factory=lambda: defaultdict(set))

    def match(
        self, m: nn.Module, name: Optional[str] = None, prefix: Optional[str] = None
    ) -> Optional[tuple[str, str]]:
        """
        Check whether module `m` is selected by this matcher.

        Args:
            m (nn.Module): The module being checked.
            name (str, optional): The module's name.
            prefix (str, optional): Prefix used to build `full_name` as "<prefix>.<name>".

        Returns:
            Optional[Tuple[str, str]]: (matching_pattern, full_name) on a match, else `None`.

        Matching Logic:
            1) If `canonical_mapping` is non-empty, its keys are the patterns: a pattern
               matches when it equals `name` or wildcard-matches `full_name`.
            2) Otherwise, if `target_modules` is non-empty, the same rule is applied to it.
            3) Otherwise the module matches (with `name` as the pattern) when `name` is not
               in `exclude_modules`, `full_name` wildcard-matches no `exclude_modules`
               pattern, and `m` is one of the supported linear layer types.

        Notes:
            - In cases 1) and 2) the method asserts that `exclude_modules` is empty.
        """
        qualified_name = f"{prefix}.{name}" if prefix else name

        # Pattern sources in priority order; the first non-empty one decides the outcome.
        pattern_sources = (
            (self.canonical_mapping, "exclude_modules should be empty when using canonical_mapping"),
            (self.target_modules, "exclude_modules should be empty when using target_modules"),
        )
        for candidates, assert_msg in pattern_sources:
            if len(candidates or []) == 0:
                continue
            assert len(self.exclude_modules) == 0, assert_msg
            for candidate in candidates:
                if name == candidate or wildcard_match(candidate, qualified_name):
                    return (candidate, qualified_name)
            # A non-empty source with no hit means "no match"; do not fall through.
            return None

        # No explicit targets: accept every supported linear layer that is not excluded.
        if name in self.exclude_modules:
            return None
        if any(wildcard_match(excluded, qualified_name) for excluded in self.exclude_modules):
            return None
        supported_linear_types = (
            ColumnParallelLinear,
            RowParallelLinear,
            nn.Linear,
            TEColumnParallelLinear,
            TELayerNormColumnParallelLinear,
            TERowParallelLinear,
        )
        if not isinstance(m, supported_linear_types):
            return None
        return (name, qualified_name)
def maybe_enable_recompute_inputs_grad(model, peft_recompute_patched: Set[int] | None = None) -> Set[int]:
    """Enable grad on TransformerBlock inputs when only adapters are trainable.

    Root cause analysis:

    - Megatron's CheckpointFunction.backward() is only invoked by PyTorch autograd
      when at least one input tensor requires grad.
    - With PP>1, received tensors from other stages have requires_grad=True, so
      checkpoint backward is always called.
    - With PP=1 and frozen base model, embedding outputs have requires_grad=False.
      This means CheckpointFunction.backward() is never called, and LoRA gradients
      inside the checkpoint are never computed.

    Solution: Hook TransformerBlock.forward to ensure hidden_states.requires_grad=True
    before it enters checkpointed computation. This doesn't unfreeze any parameters;
    it just ensures the autograd machinery calls checkpoint's backward.

    Borrowed (with modifications) from
    https://github.com/HollowMan6/verl/blob/4285f0601028aee7ddcb9ec5a15198ebfc69bba3/verl/utils/megatron_peft_utils.py

    Args:
        model: A model or list of model chunks; it is unwrapped with
            ``_iter_unwrapped_models`` before inspection.
        peft_recompute_patched: Optional registry of ``id()`` values of models that
            were already patched. When ``None`` the module-level
            ``PEFT_RECOMPUTE_PATCHED`` set is used. A caller-supplied set is used
            (and updated in place) even when it is empty.

    Returns:
        The registry that was consulted and updated.

    Note:
        Failures while patching are logged and swallowed on purpose (best effort);
        the function never raises because of them.
    """

    from megatron.core.transformer.transformer_block import TransformerBlock

    # Use an identity check rather than ``or``: an empty set is falsy, so ``or``
    # would silently replace a fresh caller-owned registry with the global one.
    patched_registry = PEFT_RECOMPUTE_PATCHED if peft_recompute_patched is None else peft_recompute_patched

    def _patch_transformer_block(module: torch.nn.Module) -> bool:
        """Wrap ``module.forward`` if ``module`` is a TransformerBlock; report whether it was."""
        if not isinstance(module, TransformerBlock):
            return False

        original_forward = module.forward

        @wraps(original_forward)
        def patched_forward(hidden_states, *args, _original_forward=original_forward, **kwargs):
            # Ensure hidden_states requires grad so checkpoint backward is called
            if (
                torch.is_tensor(hidden_states)
                and not hidden_states.requires_grad
                and hidden_states.is_floating_point()
            ):
                hidden_states = hidden_states.detach().requires_grad_(True)
            return _original_forward(hidden_states, *args, **kwargs)

        module.forward = patched_forward
        return True

    try:
        for unwrapped_model in _iter_unwrapped_models(model):
            cfg = getattr(unwrapped_model, "config", None)
            if cfg is None or getattr(cfg, "recompute_method", None) is None:
                continue  # No activation recompute configured, nothing to fix

            if id(unwrapped_model) in patched_registry:
                continue  # Already patched; avoid stacking wrappers

            params = list(unwrapped_model.named_parameters())
            trainable_adapter = any(p.requires_grad and ".adapter." in n.lower() for n, p in params)
            trainable_base = any(
                p.requires_grad and (".to_wrap." not in n.lower() and ".adapter." not in n.lower()) for n, p in params
            )

            if not trainable_adapter or trainable_base:
                continue  # Not adapter-only training, no fix needed

            patched = False
            for module in unwrapped_model.modules():
                if _patch_transformer_block(module):
                    patched = True
            if patched:
                patched_registry.add(id(unwrapped_model))
                _print_rank_0(
                    "[PEFT+Recompute] Patched TransformerBlock.forward to enable grad on "
                    "hidden_states input. This ensures checkpoint backward is called when "
                    "only adapters are trainable (PP=1 with frozen base model).",
                )
    except Exception as exc:  # pragma: no cover - best effort logging
        # Log but don't fail - user will see grad_norm=0 and can debug
        _print_rank_0(f"[PEFT+Recompute] Warning: Failed to patch TransformerBlock: {exc}")

    return patched_registry
base_linear_is_parallel: bool + + +def get_adapter_attributes_from_linear(m: nn.Module, is_expert: bool = False) -> AdapterAttributes: + """Returns attributes from the base layer as an AdapterAttributes dataclass. + + input_is_parallel, in_features, out_features, disable_tensor_parallel_comm, + disable_sequence_parallel_comm, base_linear_is_parallel + + This function analyzes a linear module and extracts key attributes needed for adapter configuration, + particularly for PEFT adapters in distributed training scenarios. + + Args: + m: The linear module to analyze (should have a config attribute). + + Returns: + AdapterAttributes containing: + - input_is_parallel: Whether the input is already parallelized + - in_features: Input feature dimension + - out_features: Output feature dimension + - disable_tensor_parallel_comm: Whether to disable tensor parallel communication + - disable_sequence_parallel_comm: Whether to disable sequence parallel communication + - base_linear_is_parallel: Whether the base linear layer uses parallelization + + Raises: + NotImplementedError: If the layer type is not recognized for LoRA adaptation. + """ + disable_sequence_parallel_comm = not m.config.sequence_parallel + base_linear_is_parallel = True + + # In some modules (notably MoE shared_experts when moe_shared_expert_overlap is enabled), + # Megatron disables TP-related communications on the base linear layer by + # setting `parallel_mode=None` (TE) or `explicit_expert_comm=True` (legacy). 
+ # https://github.com/NVIDIA/Megatron-LM/blob/5b1ef0703184299fbf71f6131bf2f9a5331e7238/megatron/core/transformer/moe/shared_experts.py#L95-L104 + # The weights are still TP-sharded though, so we must keep using the real TP size + disable_tensor_parallel_comm = getattr(m, "parallel_mode", "") is None or getattr(m, "explicit_expert_comm", False) + if disable_tensor_parallel_comm: + disable_sequence_parallel_comm = True + + if is_expert: + tp_size = parallel_state.get_expert_tensor_parallel_world_size() + else: + tp_size = parallel_state.get_tensor_model_parallel_world_size() + if isinstance(m, TopKRouter): + input_is_parallel = False + in_features = m.weight.shape[1] + out_features = m.weight.shape[0] + base_linear_is_parallel = False + disable_sequence_parallel_comm = True + elif any(isinstance(m, te_column_parallel) for te_column_parallel in TECL): + input_is_parallel = False + # m.in_features and m.out_features are divided by tp_size already, + # but in_features and out_features passed to ParallelLinearAdapter are not. 
+ in_features = m.in_features + out_features = m.out_features * tp_size + + if isinstance(m, TELayerNormColumnParallelLinear): + # LoRA is applied after layernorm, so layernorm output must be returned + m.return_layernorm_output = True + # perf optimization for LoRA + SP + if hasattr(m, "ub_overlap_ag"): + ub_overlap_ag = m.ub_overlap_ag + elif hasattr(m, "ub_overlap_ag_fprop"): + ub_overlap_ag = m.ub_overlap_ag_fprop + else: + ub_overlap_ag = False + if hasattr(m, "config") and m.config.sequence_parallel and not ub_overlap_ag: + m.return_layernorm_output_gathered = True + te_version = packaging.version.Version(version("transformer-engine")) + if te_version >= packaging.version.Version("1.5.0dev") and ( + not getattr(m.config, "tp_comm_overlap", False) + or getattr(m.config, "tp_comm_overlap_disable_qkv", False) + ): + # TE 1.5 introduces the option `return_layernorm_output_gathered`, so the all gather + # in the forward method is not needed, so disable sp communications + # unless TP communication overlap is used + disable_sequence_parallel_comm = True + elif any(isinstance(m, te_row_parallel) for te_row_parallel in TERL): + input_is_parallel = True + in_features = m.in_features * tp_size + out_features = m.out_features + elif isinstance(m, TELinear): # parallel_mode="duplicated" + input_is_parallel = False + in_features = m.in_features + out_features = m.out_features + base_linear_is_parallel = False + elif isinstance(m, ColumnParallelLinear): + input_is_parallel = False + in_features = m.input_size + out_features = m.output_size + elif isinstance(m, RowParallelLinear): + input_is_parallel = True + in_features = m.input_size + out_features = m.output_size + else: + raise NotImplementedError(f"Layer type is unrecognized for LoRA: {type(m)}") + + return AdapterAttributes( + input_is_parallel=input_is_parallel, + in_features=in_features, + out_features=out_features, + disable_tensor_parallel_comm=disable_tensor_parallel_comm, + 
def wildcard_match(pattern: str, key: Optional[str]) -> Optional[bool]:
    """Return whether the pattern (target module to add LoRA) matches the key (model weight name).

    This function performs wildcard matching using '*' as a placeholder for any substring.
    Every other character of ``pattern`` is matched literally; in particular the dots that
    separate module names only match a dot, and characters such as ``[``, ``(`` or ``+``
    are not interpreted as regular-expression syntax.

    Args:
        pattern: Pattern string with wildcards (*) to match against.
        key: Key string to test against the pattern.

    Returns:
        True if the pattern matches the key, False if it doesn't, None if key is None.

    Example:
        >>> wildcard_match("*.layers.0.*.linear_qkv", "decoder.layers.0.self_attention.linear_qkv")
        True
        >>> wildcard_match("*.layers.0.*.linear_qkv", "decoder.layers.1.self_attention.linear_qkv")
        False
        >>> wildcard_match("a.b", "axb")
        False
    """
    if key is None:
        return None
    # Escape each literal fragment between wildcards so that only '*' carries
    # special meaning, then stitch the fragments back together with '(.*)'.
    literal_fragments = (re.escape(fragment) for fragment in pattern.split("*"))
    regex_pattern = re.compile("^" + "(.*)".join(literal_fragments) + "$")
    return regex_pattern.match(key) is not None
def pad_seq_to_mult(x: torch.Tensor, mult: int) -> Tuple[torch.Tensor, int]:
    """Pad sequence length to be a multiple of mult.

    This function zero-pads the tail of the first dimension of the tensor so that its
    length is divisible by mult. Used primarily for MoE (Mixture of Experts) operations
    that require specific sequence lengths.

    The padding is a regular differentiable op: gradients flow back to the un-padded
    rows of ``x`` and the padded rows simply receive none.

    Args:
        x: Input tensor to pad (any rank >= 1; the first dimension is padded).
        mult: Multiple that the sequence length should be divisible by.

    Returns:
        A tuple containing:
            - Padded tensor (``x`` itself when no padding is needed)
            - Number of padding elements added
    """
    remainder = x.shape[0] % mult
    if remainder == 0:
        return x, 0
    pad_len = mult - remainder
    # F.pad takes (left, right) pairs starting from the LAST dimension, so leave
    # every trailing dimension untouched and pad only the tail of dimension 0.
    # For a 2-D input this is the familiar (0, 0, 0, pad_len).
    pad_spec = (0, 0) * (x.dim() - 1) + (0, pad_len)
    # Deliberately NOT wrapped in torch.no_grad(): that would detach the result
    # and cut the gradient path back to ``x``.
    return nn.functional.pad(x, pad_spec), pad_len


def unpad_seq_to_mult(x: torch.Tensor, pad_len: int) -> torch.Tensor:
    """Remove sequence padding that was added by pad_seq_to_mult.

    The slice is taken with autograd enabled so the returned tensor stays attached
    to the graph that produced ``x``.

    Args:
        x: Padded tensor to unpad.
        pad_len: Number of padding elements to remove from the end.

    Returns:
        Unpadded tensor with pad_len elements removed from the first dimension
        (``x`` itself when ``pad_len <= 0``).
    """
    if pad_len <= 0:
        return x
    # prune tail padding; slicing under torch.no_grad() would return a tensor
    # with requires_grad=False and silently drop the adapter's gradients.
    return x[:-pad_len]
class ParallelLinearAdapter(nn.Module):
    """Parallel Linear Adapter for Parameter-Efficient Fine-Tuning (PEFT) in distributed settings.

    This adapter implements a low-rank adaptation pattern using two linear layers with configurable
    parallelization strategies. It supports both tensor and sequence parallelism patterns used in
    large language model training.

    The adapter follows the pattern: input -> linear_in -> activation -> linear_out -> scaling
    where linear_in and linear_out are parallelized according to the base layer configuration.
    The final scaling factor is ``alpha / dim``.

    Args:
        in_features: Input feature dimension.
        out_features: Output feature dimension.
        dim: Adapter bottleneck dimension (rank).
        base_linear_name: Name of the base linear layer being adapted.
        activation: Activation function name (default: 'swish').
        column_init_method: Initialization method for column parallel layer (default: 'xavier').
        row_init_method: Initialization method for row parallel layer (default: 'zero').
        input_is_parallel: Whether input is already parallelized (default: False).
        dropout: Dropout probability (default: 0.0).
        model_parallel_config: Configuration for model parallelism (default: None).
        alpha: Scaling factor for adapter output (default: None, uses dim).
        dropout_position: Where to apply dropout ('pre' or 'post', default: 'pre').
        a2a_experimental: Whether to use experimental all-to-all communication (default: False).
        is_expert: Whether this adapter is for expert layers in MoE (default: False).
        disable_tensor_parallel_comm: Whether tensor parallel communication is disabled on the base
            layer; only consulted when choosing ``gather_output`` for ``linear_out`` (default: False).
        disable_sequence_parallel_comm: Whether to disable sequence parallel communication (default: True).
        base_linear_is_parallel: Whether the base linear layer uses parallelization (default: True).
        **kwargs: Accepted for call-site compatibility; not used by this constructor.
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        dim: int,
        base_linear_name: str,
        activation: str = "swish",
        column_init_method: str = "xavier",
        row_init_method: str = "zero",
        input_is_parallel: bool = False,
        dropout: float = 0.0,
        model_parallel_config: Optional[ModelParallelConfig] = None,
        alpha: Optional[float] = None,
        dropout_position: str = "pre",
        a2a_experimental: bool = False,
        is_expert: bool = False,
        disable_tensor_parallel_comm: bool = False,
        disable_sequence_parallel_comm: bool = True,
        base_linear_is_parallel: bool = True,
        **kwargs,
    ) -> None:
        """Build ``linear_in``/``linear_out`` and record the communication flags.

        The passed ``model_parallel_config`` is mutated temporarily while the two
        linear layers are constructed (see inline comments) and is also stored as
        ``self.config``; it is the same object, not a copy.
        """
        super().__init__()
        self.base_linear_name = base_linear_name
        self.activation = self._get_activation_fn(activation)
        self.dim = dim
        self.in_features = in_features  # stored for TP-independent init
        self.alpha = alpha if alpha is not None else self.dim
        self.input_is_parallel = input_is_parallel
        self.dropout_position = dropout_position
        self.use_a2a = a2a_experimental
        self.is_expert = is_expert

        # megatron_gpt_peft_models will provide this arg, but deprecated ones do not.
        # in case this arg is not provided, use the dummy default config.
        if model_parallel_config is None:
            model_parallel_config = ModelParallelConfig()
        # Remember the caller's SP setting; it is restored at the end of __init__.
        _sequence_parallel = model_parallel_config.sequence_parallel
        model_parallel_config.sequence_parallel = False  # SP is irrelevant for the lora linear layer
        self.config = model_parallel_config

        # Ensure adapter parameters are initialized when creating adapter layers.
        # In some flows (e.g., after import), perform_initialization may be False to skip heavy init.
        # NOTE(review): unlike sequence_parallel and use_cpu_initialization, this flag is
        # not restored at the end of __init__ — confirm that leaving it True is intended.
        if hasattr(model_parallel_config, "perform_initialization"):
            model_parallel_config.perform_initialization = True

        # Force CPU initialization so lora_A (xavier) is TP-independent:
        # GPU init uses per-rank model-parallel seeds (seed + tp_rank), producing
        # different full weights for different TP degrees. CPU init generates
        # one master weight on all ranks with the same RNG and scatters the
        # matching shard, making the result identical regardless of TP.
        _use_cpu_init = getattr(model_parallel_config, 'use_cpu_initialization', False)
        model_parallel_config.use_cpu_initialization = True

        if input_is_parallel:
            # Base layer consumes an already-sharded input, so the down-projection is row parallel.
            self.linear_in = RowParallelLinear(
                in_features,
                dim,
                config=model_parallel_config,
                input_is_parallel=True,
                skip_bias_add=True,
                bias=False,
                init_method=self._get_init_fn(column_init_method),
                is_expert=is_expert,
            )
        else:
            self.linear_in = ColumnParallelLinear(
                in_features,
                dim,
                config=model_parallel_config,
                bias=False,
                gather_output=True,
                init_method=self._get_init_fn(column_init_method),
                disable_grad_reduce=_sequence_parallel,
                is_expert=is_expert,
            )

        # (@adithyare) we use this option to mirror the behavior
        # a column parallel layer with two low-rank column parallel layers
        # if the original column parallel layer uses gather_output=False,
        # then we will use the self.linear_out layer defined below.
        lin_out_gather_output = True if input_is_parallel else False
        # `and` binds tighter than `or`, so this condition reads as
        # (use_a2a and input_is_parallel and _sequence_parallel)
        #   or (disable_tensor_parallel_comm and not input_is_parallel).
        if (
            self.use_a2a
            and input_is_parallel
            and _sequence_parallel
            or (disable_tensor_parallel_comm and not input_is_parallel)
        ):
            lin_out_gather_output = False

        # A non-parallel base layer always gets a gathered adapter output;
        # this overrides the decision made above.
        if not base_linear_is_parallel:
            lin_out_gather_output = True

        self.linear_out = ColumnParallelLinear(
            dim,
            out_features,
            config=model_parallel_config,
            bias=False,
            gather_output=lin_out_gather_output,
            init_method=self._get_init_fn(row_init_method),
            is_expert=is_expert,
        )

        if dropout > 0.0:
            self.dropout = nn.Dropout(dropout)
        else:
            self.dropout = nn.Identity()

        # cast all parameters when using amp O2 training
        if model_parallel_config.bf16:
            self.bfloat16()
        elif model_parallel_config.fp16:
            self.half()

        # revert config changes in case they are read elsewhere
        model_parallel_config.use_cpu_initialization = _use_cpu_init
        model_parallel_config.sequence_parallel = _sequence_parallel
        self.disable_sequence_parallel_comm = disable_sequence_parallel_comm
        # Without sequence parallelism there is nothing to gather/scatter in forward().
        if not _sequence_parallel:
            self.disable_sequence_parallel_comm = True

        if not base_linear_is_parallel:
            self.disable_sequence_parallel_comm = True

    def _get_activation_fn(self, activation: str) -> nn.Module:
        """Get activation function by name.

        Args:
            activation: Name of the activation function.

        Returns:
            PyTorch activation module.

        Note:
            Defaults to Identity if activation name is not recognized.
        """
        # "swish" and "silu" are two names for the same module (nn.SiLU).
        activation_map = {
            "identity": nn.Identity(),
            "relu": nn.ReLU(),
            "gelu": nn.GELU(),
            "swish": nn.SiLU(),
            "silu": nn.SiLU(),
            "tanh": nn.Tanh(),
            "sigmoid": nn.Sigmoid(),
        }
        return activation_map.get(activation, nn.Identity())

    def _get_init_fn(self, init_method: str) -> Callable[[torch.Tensor], torch.Tensor]:
        """Get initialization function by method name.

        Args:
            init_method: Name of the initialization method
                ('xavier', 'normal', 'kaiming' or 'zero').

        Returns:
            Initialization function.

        Raises:
            NotImplementedError: If init_method is not supported.

        Note:
            For ``xavier`` init on ``linear_in`` (lora_A), we use the full (unsharded)
            ``dim`` and ``in_features`` to compute the xavier std, making initialization
            independent of tensor_model_parallel_size. Otherwise, with TP>1 the per-rank
            weight shape is ``(dim/TP, in_features)`` and xavier computes a different std
            (because ``fan_out = dim/TP``), leading to diverging training trajectories.
        """
        if init_method == "xavier":
            # Use the full (unsharded) dim to make init TP-independent.
            # xavier_normal_ uses std = gain * sqrt(2 / (fan_in + fan_out)).
            # fan_in = self.in_features, fan_out = self.dim (full, not dim/TP).
            # NOTE(review): the same std is produced when this is requested for
            # linear_out (row_init_method="xavier"), whose shape is different — confirm.
            std = math.sqrt(2.0 / (self.in_features + self.dim))
            init_fn = init_method_normal(std)
        elif init_method == "normal":
            init_fn = init_method_normal(0.2)
        elif init_method == "kaiming":
            init_fn = init_method_kaiming_uniform(math.sqrt(5))
        elif init_method == "zero":
            init_fn = init_method_const(0.0)
        else:
            raise NotImplementedError("out_init_method should be zero, normal, kaiming or xavier")
        return init_fn

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass of the parallel linear adapter.

        Performs the adaptation computation with proper handling of parallel communication
        patterns, dropout, and expert routing for MoE scenarios.

        Args:
            x: Input tensor.

        Returns:
            Adapted output tensor with scaling applied.
        """
        if self.dropout_position == "pre":
            x = self.dropout(x)

        pad_len = 0
        if self.is_expert:
            # Expert inputs are padded along dim 0 to a multiple of the expert TP size
            # and un-padded again just before returning.
            x, pad_len = pad_seq_to_mult(x, self.config.expert_tensor_parallel_size)

        if not self.disable_sequence_parallel_comm and not self.input_is_parallel and not self.is_expert:
            # for attention_qkv and linear_fc1
            # layernorm before lora is impacted by sequence parallel,
            # hence seq dim need to be gathered right before lora linear layers
            # this function also handles the backward pass correctly
            x = gather_from_sequence_parallel_region(x)

        # Tag the activation for CPU offloading when the config asks for it.
        if self.config.cpu_offloading and self.config.cpu_offloading_activations:
            x.activation_offloading = True
        x, _ = self.linear_in(x)  # (@adithyare) ColumnLinear returns output and bias, we are ignoring the bias term.

        x = self.activation(x)

        if self.config.cpu_offloading and self.config.cpu_offloading_activations:
            x.activation_offloading = True
        x, _ = self.linear_out(x)

        if not self.disable_sequence_parallel_comm and self.input_is_parallel and not self.is_expert:
            # for attention_dense and linear_fc2
            # layernorm after lora is impacted by sequence parallel,
            # hence seq dim need to be scattered right after lora linear layers
            # this function also handles the backward pass correctly
            if self.use_a2a:
                # all2all hidden_size / TP to seq_len / TP
                x = all2all_hp2sp(x)
            else:
                x = scatter_to_sequence_parallel_region(x)

        # Add dropout if available
        if self.dropout_position == "post":
            x = self.dropout(x)

        # LoRA scaling: alpha / rank.
        x = x * (self.alpha / self.dim)

        if pad_len > 0:
            # Remove MoE padding.
            x = unpad_seq_to_mult(x, pad_len)

        return x

    def sharded_state_dict(
        self,
        prefix: str = "",
        sharded_offsets: Tuple = (),
        metadata: Optional[Dict] = None,
        mamba_dim_info: Optional[Dict] = None,
    ) -> ShardedStateDict:
        """Create sharded state dictionary for distributed checkpointing.

        Special treatment is given to the linear_fc1 adapter since tensor parallelism is
        sharded separately for the two logical matrices (gate and up) in SwiGLU.

        Args:
            prefix: Prefix for parameter names.
            sharded_offsets: Offsets for sharded parameters.
            metadata: Additional metadata for sharding.
            mamba_dim_info: Optional dict with the keys ``d_inner_local_tp``,
                ``ngroups_local_tp``, ``d_state`` and ``nheads_local_tp``. When given,
                ``linear_out.weight`` is split into the five Mamba in_proj parts
                (z, x, B, C, dt) if its first dimension matches the expected local size.

        Returns:
            Sharded state dictionary for distributed checkpointing.
        """
        sharded_state_dict = {}
        linear_in_sd = self.linear_in.sharded_state_dict(f"{prefix}linear_in.", sharded_offsets, metadata)
        linear_out_sd = self.linear_out.sharded_state_dict(f"{prefix}linear_out.", sharded_offsets, metadata)

        # The experts.py code in Megatron-LM set replica_id = (PP, ETP, EDP),
        # but it will cause errors as mentioned in https://github.com/volcengine/verl/issues/4303,
        # since adapter weights are not EP sharded and it assumes that it will
        # replicate along DP modulo EP (sharded by EP)
        if self.is_expert:
            from megatron.core import parallel_state

            ep_rank = parallel_state.get_expert_model_parallel_rank()
            edp_rank = parallel_state.get_expert_data_parallel_rank()
            dp_size = parallel_state.get_data_parallel_world_size()
            # TODO: This modification logic is in question and needs further verification.
            # Parsed as: ((ep_rank + 1) * (edp_rank + 1) - 1) if dp_size == 1 else ep_rank
            rank = (ep_rank + 1) * (edp_rank + 1) - 1 if dp_size == 1 else ep_rank
            # Only the middle element of replica_id is rewritten; the first and last are kept.
            for sd in [linear_in_sd, linear_out_sd]:
                for v in sd.values():
                    if hasattr(v, "replica_id"):
                        old_rid = v.replica_id
                        v.replica_id = (old_rid[0], rank, old_rid[2])

        if "linear_fc1" in self.base_linear_name:
            for k, v in linear_out_sd.items():
                if k in (f"{prefix}linear_out.weight", f"{prefix}linear_out.bias"):
                    linear_out_sd[k] = apply_swiglu_sharded_factory(v, sharded_offsets)

        # Special handling for Mamba in_proj layer which needs to be split into 5 tensors
        if mamba_dim_info is not None:
            from megatron.core.ssm.mamba_mixer import _split_tensor_factory

            # Split linear_out.weight into 5 parts: z, x, B, C, dt
            # The in_proj output dimension is: d_inner * 2 + 2 * ngroups * d_state + nheads
            # After TP sharding: d_inner_local_tp * 2 + 2 * ngroups_local_tp * d_state + nheads_local_tp
            for k, v in linear_out_sd.items():
                if k == f"{prefix}linear_out.weight" and isinstance(v, ShardedTensor):
                    in_proj_dim_local = (
                        mamba_dim_info["d_inner_local_tp"] * 2
                        + 2 * mamba_dim_info["ngroups_local_tp"] * mamba_dim_info["d_state"]
                        + mamba_dim_info["nheads_local_tp"]
                    )
                    # Verify the dimension matches; on a mismatch the entry is left unsplit.
                    if v.data.size(0) == in_proj_dim_local:
                        linear_out_sd[k] = _split_tensor_factory(
                            v,
                            [
                                mamba_dim_info["d_inner_local_tp"],  # z
                                mamba_dim_info["d_inner_local_tp"],  # x
                                mamba_dim_info["ngroups_local_tp"] * mamba_dim_info["d_state"],  # B
                                mamba_dim_info["ngroups_local_tp"] * mamba_dim_info["d_state"],  # C
                                mamba_dim_info["nheads_local_tp"],  # dt
                            ],
                            ["z", "x", "B", "C", "dt"],
                            0,  # split along dimension 0
                        )

        sharded_state_dict.update(linear_in_sd)
        sharded_state_dict.update(linear_out_sd)
        return sharded_state_dict
def map(  # noqa: A001
    module: _TModule,
    func: ModuleFunc,
    leaf_only: bool = False,
    **kwargs,
) -> _TModule:
    """Apply ``func`` to a module, or to every module in a list-like or dict-like collection.

    Dispatch order:

    1. ``None`` is returned untouched.
    2. An object exposing its own ``map`` attribute handles the call itself,
       unless the caller passes ``_skip_map=True``.
    3. An iterable with ``items``/``values``/``keys`` is treated as a module dict,
       any other iterable as a module list.
    4. Everything else is treated as a single module.

    Args:
        module: The module or collection of modules to transform.
        func: Callable taking a module (plus any matching keyword arguments) and
            returning the transformed module.
        leaf_only: If True, ``func`` is only applied to modules that directly own
            parameters. Defaults to False.
        **kwargs: Extra keyword arguments forwarded to the selected handler.
            ``_skip_map`` is consumed here and never forwarded.

    Returns:
        The transformed module or collection of modules.

    Examples:
        >>> import torch.nn as nn
        >>> from mbridge.peft.walk_utils import map

        # Example: Adding a custom attribute to all modules
        >>> model = nn.Sequential(nn.Linear(10, 20), nn.ReLU(), nn.Linear(20, 10))
        >>> def add_id(m, module_id=0):
        ...     m.custom_id = module_id
        ...     return m
        >>> model = map(model, add_id, module_id=42)
    """
    if module is None:
        return module

    # Always strip the private flag so it never reaches a handler or ``func``.
    skip_custom_map = kwargs.pop("_skip_map", False)
    if not skip_custom_map and hasattr(module, "map"):
        return module.map(func, leaf_only=leaf_only, **kwargs)

    if not isinstance(module, Iterable):
        return _map_module(module, func, leaf_only=leaf_only, **kwargs)

    looks_like_mapping = all(hasattr(module, attr) for attr in ("items", "values", "keys"))
    handler = _map_module_dict if looks_like_mapping else _map_module_list
    return handler(module, func, leaf_only=leaf_only, **kwargs)
def _map_module(
    module: _TModule, func: ModuleFunc, recurse=False, leaf_only=False, transformed_modules=None, **kwargs
) -> _TModule:
    """
    Applies a transformation function to a module and to its direct children.

    Parameters:
        module : nn.Module
            The module to which the function will be applied.
        func : ModuleFunc
            The function that will be applied to the module. It only receives the
            keyword arguments that its signature names.
        recurse : bool, optional
            Forwarded unchanged to the ``map`` call made for every child; children
            are visited regardless of its value.
        leaf_only : bool, optional
            Whether to apply the function only to modules that directly own parameters.
        transformed_modules : set, optional
            A set of ``id()`` values of modules that have already been transformed.
        **kwargs : dict
            Additional keyword arguments that will be passed to the transformation
            function. ``i``, ``name`` and ``prefix`` describe this module's position
            and are recomputed for every child.

    Returns:
        nn.Module
            The transformed module.
    """
    seen = set() if transformed_modules is None else transformed_modules
    if id(module) in seen:
        return module

    # Filter the kwargs for ``func`` before the positional keys are stripped below.
    call_kwargs = _get_func_kwargs(func, **kwargs)

    result = module
    if not leaf_only or list(module.parameters(recurse=False)):
        result = func(result, **call_kwargs)

    # Children are addressed as "<prefix>.<name>" of this module; without an
    # incoming prefix this module's own name (if any) starts the chain.
    if kwargs.get("prefix", ""):
        prefix = f"{kwargs['prefix']}.{kwargs['name']}"
    else:
        prefix = kwargs.get("name", "")
    for positional_key in ("i", "name", "prefix"):
        kwargs.pop(positional_key, None)

    for index, (child_name, child) in enumerate(module.named_children()):
        mapped_child = map(
            child,
            func,
            recurse=recurse,
            leaf_only=leaf_only,
            transformed_modules=seen,
            i=index,
            name=child_name,
            prefix=prefix,
            **kwargs,
        )
        setattr(result, child_name, mapped_child)

    seen.add(id(result))

    return result
modules. + + Parameters: + module_dict : nn.ModuleDict + The ModuleDict of modules to which the function will be applied. + func : ModuleFunc + The function that will be applied to the modules. + recurse : bool, optional + Whether to apply the function recursively to child modules. + leaf_only : bool, optional + Whether to apply the function only to modules without parameters. + **kwargs : dict + Additional keyword arguments that will be passed to the transformation function. + + Returns: + nn.ModuleDict + The ModuleDict of transformed modules. + """ + if transformed_modules is None: + transformed_modules = set() + + f_kwargs = _get_func_kwargs(func, **kwargs) + if not leaf_only: + module_dict = func(module_dict, **f_kwargs) + + mapped_modules = {} + prefix = kwargs.get("name", "") if not kwargs.get("prefix", "") else f"{kwargs['prefix']}.{kwargs['name']}" + kwargs.pop("i", None) + kwargs.pop("name", None) + kwargs.pop("prefix", None) + + for i, (name, module) in enumerate(module_dict.items()): + mapped_modules[name] = map( + module, + func, + recurse=recurse, + leaf_only=leaf_only, + transformed_modules=transformed_modules, + i=i, + name=name, + prefix=prefix, + **kwargs, + ) + + return type(module_dict)(mapped_modules) + + +def _create_list_wrapper(module_list, to_add): + """Create a wrapper for a list of modules, preserving the original type.""" + # Check the signature of the type constructor + sig = inspect.signature(type(module_list).__init__) + if "args" in sig.parameters: + return type(module_list)(*to_add) # Unpack new_modules + + return type(module_list)(to_add) # Don't unpack new_modules + + +def _get_func_kwargs(func, **kwargs): + """Extract kwargs that match the function signature.""" + sig = inspect.signature(func) + return {kwarg: value for kwarg, value in kwargs.items() if kwarg in sig.parameters}