Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions mbridge/peft/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
# Adapted from NVIDIA Megatron-Bridge

from mbridge.peft.base import PEFT
from mbridge.peft.canonical_lora import CanonicalLoRA
from mbridge.peft.lora import (LoRA, LoRAMerge, gather_lora_state_dict,
infer_hf_target_modules, lora_merged,
mcore_adapter_name_to_hf)
203 changes: 203 additions & 0 deletions mbridge/peft/adapter_wrapper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,203 @@
# Adapted from NVIDIA Megatron-Bridge

from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple

import torch
import torch.nn as nn
from mbridge.peft.utils import ParallelLinearAdapter

if TYPE_CHECKING:
from megatron.core.dist_checkpointing.mapping import ShardedStateDict


def _compute_mamba_dim_info(wrapped_module: nn.Module) -> Dict[str, int]:
"""Compute Mamba dimension information from a wrapped module's config.

This follows the same logic as mamba_mixer.py to derive local tensor parallel
dimensions from the TransformerConfig.

Args:
wrapped_module: The wrapped module (typically a linear projection in a Mamba layer).

Returns:
Dictionary containing d_inner_local_tp, ngroups_local_tp, d_state, and nheads_local_tp.
"""
config = wrapped_module.config

# Get base dimensions from config
d_state = config.mamba_state_dim
headdim = config.mamba_head_dim
ngroups = config.mamba_num_groups

# Compute nheads and d_inner
if config.mamba_num_heads is not None:
nheads = config.mamba_num_heads
d_inner = nheads * headdim
else:
d_inner = wrapped_module.d_inner
nheads = d_inner // headdim

# Get tensor parallel size and compute local dimensions
tp_size = wrapped_module.tp_size

return {
"d_inner_local_tp": d_inner // tp_size,
"ngroups_local_tp": ngroups // tp_size,
"d_state": d_state,
"nheads_local_tp": nheads // tp_size,
}


class AdapterWrapper(nn.Module):
"""Abstract base class for wrapping modules with adapters in Parameter-Efficient Fine-Tuning (PEFT).

This class wraps a module and its associated adapter, providing methods for
managing the state dictionaries of both the main module and the adapter. It does not
implement the forward method, which must be implemented by concrete subclasses.

Attributes:
to_wrap (nn.Module): The main module to be wrapped.
adapter (nn.Module): The adapter module to be applied.

Note:
This class is abstract and cannot be instantiated directly. Subclasses must
implement the forward method.

Example:
class LoRALinear(AdapterWrapper):
def __init__(self, to_wrap, adapter):
super().__init__(to_wrap, adapter)

def forward(self, x):
return self.to_wrap(x) + self.adapter(x)

main_module = nn.Linear(100, 100)
adapter = nn.Linear(100, 100)
parallel_adapter = LoRALinear(main_module, adapter)
"""

def __init__(self, to_wrap: nn.Module, adapter: nn.Module) -> None:
"""Initialize the AdapterWrapper with a main module and adapter.

Args:
to_wrap: The main module to be wrapped.
adapter: The adapter module to be applied.
"""
super(AdapterWrapper, self).__init__()
self.to_wrap = to_wrap
self.adapter = adapter
self._adapter_enabled = True

def enable_adapter_layers(self) -> None:
"""Enable the adapter layers, allowing them to contribute to the forward pass output."""
self._adapter_enabled = True

def disable_adapter_layers(self) -> None:
"""Disable the adapter layers, making the forward pass return only the base module output."""
self._adapter_enabled = False

def base_linear_forward(
self, x: torch.Tensor, *args: Any, **kwargs: Any
) -> Tuple[torch.Tensor, Optional[torch.Tensor], torch.Tensor]:
"""Run the forward method of the linear module `to_wrap`.

This method handles the complex return patterns of Megatron's linear layers,
which can return different combinations of outputs, biases, and layernorm outputs.

The flow is: x -> [layernorm/identity] -> layernorm_output -> [linear] -> linear_output, bias

Args:
x: Input tensor.
*args: Additional positional arguments for the wrapped module.
**kwargs: Additional keyword arguments for the wrapped module.

Returns:
A tuple containing:
- linear_output: The output from the linear layer
- bias: The bias term (if present, otherwise None)
- layernorm_output: The output from layernorm (differs from x only for
LayerNormColumnParallelLinear, otherwise equals x)

Note:
The wrapped module can return values in four different patterns:
1. nothing: (out, None)
2. return_bias: (out, bias)
3. return_layernorm_output: ((out, ln_out), None)
4. both: (out, bias, ln_out)
"""
linear_output = self.to_wrap(x, *args, **kwargs)
assert isinstance(linear_output, tuple), (
f"{self.to_wrap} should return a tuple but instead returns {linear_output}"
)

bias = None
layernorm_output = x

if len(linear_output) == 2:
linear_output, bias = linear_output
if isinstance(linear_output, tuple) and len(linear_output) == 2:
linear_output, layernorm_output = linear_output
elif len(linear_output) == 3:
linear_output, bias, layernorm_output = linear_output

return linear_output, bias, layernorm_output

def state_dict(
self, destination: Optional[Dict[str, Any]] = None, prefix: str = "", keep_vars: bool = False
) -> Dict[str, Any]:
"""Retrieve the state dictionary of the wrapped module and adapter.

This method overrides the default state_dict behavior to include both
the main module's state and the adapter's state under a special 'adapter' prefix.

Args:
destination: A dictionary to store the state. If None, a new
dictionary is created. Defaults to None.
prefix: A prefix added to parameter and buffer names. Defaults to ''.
keep_vars: If True, returns variables instead of tensor values.
Defaults to False.

Returns:
The state dictionary containing both the main module and adapter states.
"""
if destination is None:
destination = {}

# Get state dict of the main module
self.to_wrap.state_dict(destination=destination, prefix=prefix, keep_vars=keep_vars)

# Store adapter state dict under the "adapter" prefix in the destination dict
self.adapter.state_dict(destination=destination, prefix=f"{prefix}adapter.", keep_vars=keep_vars)
return destination

def sharded_state_dict(
self,
prefix: str = "",
sharded_offsets: Tuple[Tuple[int, int, int]] = (),
metadata: Optional[Dict[str, Any]] = None,
) -> "ShardedStateDict":
"""Retrieve the sharded state dictionary of the wrapped module and adapter.

This method is used for distributed checkpointing, combining the sharded states
of both the main module and the adapter.

Args:
prefix: A prefix added to parameter and buffer names. Defaults to ''.
sharded_offsets: Offsets for sharded parameters. Defaults to an empty tuple.
metadata: Additional metadata for the sharded state. Defaults to None.

Returns:
The combined sharded state dictionary.
"""
adapter_sharded_state_dict_kwargs = {}
if isinstance(self.adapter, ParallelLinearAdapter) and "mixer.in_proj" in self.adapter.base_linear_name:
adapter_sharded_state_dict_kwargs["mamba_dim_info"] = _compute_mamba_dim_info(self.to_wrap)

sharded_state_dict = {}
sharded_state_dict.update(self.to_wrap.sharded_state_dict(prefix, sharded_offsets, metadata))
sharded_state_dict.update(
self.adapter.sharded_state_dict(
f"{prefix}adapter.", sharded_offsets, metadata, **adapter_sharded_state_dict_kwargs
)
)
return sharded_state_dict
159 changes: 159 additions & 0 deletions mbridge/peft/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,159 @@
# Adapted from NVIDIA Megatron-Bridge

import logging
from abc import ABC, abstractmethod
from contextlib import contextmanager
from dataclasses import dataclass, field
from typing import Optional, TypeVar, Union

import torch
import torch.nn as nn
from mbridge.peft.recompute import maybe_enable_recompute_inputs_grad
from mbridge.peft.walk_utils import walk
from megatron.core.transformer.module import MegatronModule

logger: logging.Logger = logging.getLogger(__name__)

ModelType = TypeVar("ModelType", bound=Union[nn.Module, list[MegatronModule]])


@dataclass
class PEFT(ABC):
"""Abstract base class for Parameter-Efficient Fine-Tuning (PEFT) methods.

This class defines the interface for PEFT methods, which are used to fine-tune
large language models efficiently by modifying only a small subset of the model's
parameters.
"""

# Runtime state that should not be serialized in checkpoints
params_to_save: set[str] = field(default_factory=set, init=False, repr=False)

@abstractmethod
def transform(self, module: nn.Module, name: Optional[str] = None, prefix: Optional[str] = None) -> nn.Module:
"""Transform a single module according to the PEFT method.

This method is called for each module in the model during the PEFT application process.
It should be implemented by subclasses to define how individual modules are transformed
for the specific PEFT technique.

Args:
module (nn.Module): The individual module to be transformed.
name (Optional[str]): The name of the module within the model structure. Defaults to None.
prefix (Optional[str]): A prefix to be added to the module name, typically used for
nested modules. Defaults to None.

Returns:
nn.Module: The transformed module. This can be the original module with modifications,
a new module replacing the original, or the original module if no
transformation is needed for this specific module.

Note:
This method is automatically called for each module in the model when the PEFT
instance is applied to the model using the __call__ method.
"""
raise NotImplementedError("The transform method should be implemented by subclasses.")

def __call__(self, model: ModelType, training: bool = True) -> ModelType:
"""Apply the PEFT method to the entire model.

This method freezes the model parameters and walks through the model
structure, applying the transform method to each module.

Args:
model: The model to be fine-tuned. Can be a single model or a list of model chunks
(for pipeline parallelism).
training (bool): Whether the model will be used for training. If False,
additional freezing may be applied. Defaults to True.

Returns:
The same type as the input model, transformed with PEFT applied.
"""
self.freeze_model(model, training=training)

self._walk_model(model, self.transform)

if training:
maybe_enable_recompute_inputs_grad(model)

if not training:
self.freeze_model(model, training=training)

# Set model training mode appropriately
if isinstance(model, list):
for model_chunk in model:
model_chunk.train(mode=training)
else:
model.train(mode=training)

return model

def _walk_model(self, model: ModelType, func) -> None:
if isinstance(model, list):
for model_chunk in model:
walk(model_chunk, func)
elif isinstance(model, torch.nn.parallel.DistributedDataParallel):
walk(model.module, func)
else:
walk(model, func)

def enable_adapter_layers(self, model: ModelType) -> None:
"""Enable adapter layers for all PEFT-wrapped modules in the model."""

def enable(module: nn.Module) -> nn.Module:
method = getattr(module, "enable_adapter_layers", None)
if callable(method):
method()
return module

self._walk_model(model, enable)

def disable_adapter_layers(self, model: ModelType) -> None:
"""Disable adapter layers for all PEFT-wrapped modules in the model."""

def disable(module: nn.Module) -> nn.Module:
method = getattr(module, "disable_adapter_layers", None)
if callable(method):
method()
return module

self._walk_model(model, disable)

@contextmanager
def disable_adapter(self, model: ModelType):
"""
Disables the adapter module.
"""
try:
self.disable_adapter_layers(model)
yield
finally:
self.enable_adapter_layers(model)

def freeze_model(self, model: ModelType, training: bool = True) -> None:
"""Apply a default freeze method to the model.

This method freezes all the model parameters. This method can be overridden by subclasses to
implement custom freeze strategies (e.g. freeze only parts of the model)

Args:
model: The model to be fine-tuned.
training (bool): Whether the model is being used for training. Affects training mode handling.
"""

def freeze_parameters(module):
"""Freeze all parameters in a module."""
for param in module.parameters(recurse=False):
param.requires_grad = False
return module

self._walk_model(model, freeze_parameters)

if training:
if isinstance(model, list):
for model_chunk in model:
model_chunk.train(mode=True)
elif isinstance(model, torch.nn.parallel.DistributedDataParallel):
model.module.train(mode=True)
else:
model.train(mode=True)
Loading
Loading