diff --git a/docs/source/developer_guides/lora.md b/docs/source/developer_guides/lora.md index 58a4132506..0fb251ed38 100644 --- a/docs/source/developer_guides/lora.md +++ b/docs/source/developer_guides/lora.md @@ -54,6 +54,19 @@ lora_config = LoraConfig(init_lora_weights="pissa_niter_[number of iters]", ...) ``` For detailed instruction on using PiSSA, please follow [these instructions](https://github.com/huggingface/peft/tree/main/examples/pissa_finetuning). +### MiCA + +[MiCA](https://arxiv.org/abs/2604.01694) (Minor Component Adaptation) is a complement to PiSSA: instead of initializing from the *principal* singular components, MiCA uses the *minor* ones. Concretely, with `W = U Σ V^T`, MiCA sets `B = U[:, -r:]` (the `r` left singular vectors associated with the smallest singular values) and `A = 0`. During training, only `A` is updated; `B` is frozen. The intuition is that the minor singular directions are largely unused by the pretrained model and therefore offer a more "plastic" subspace for injecting new knowledge while preserving pretrained capabilities. + +Because `A == 0` at init, the adapter contribution `B · A == 0` and the model output is preserved exactly at step 0, so no residual subtraction on the base weight is needed (unlike PiSSA). Since only `A` is trainable, each adapted linear layer trains `r * in_features` parameters instead of LoRA's `r * (in_features + out_features)`, which is roughly half for square weights at matching `r`. + +```python +from peft import LoraConfig +config = LoraConfig(init_lora_weights="mica", r=16, target_modules=["q_proj", "v_proj"], ...) +``` + +MiCA currently supports `nn.Linear` and `nn.Embedding` target modules. The chosen rank must satisfy `r <= min(in_features, out_features)` for linear layers and `r <= min(num_embeddings, embedding_dim)` for embedding layers. For detailed usage, see [these instructions](https://github.com/huggingface/peft/tree/main/examples/mica_finetuning). 
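How close the saving is to "roughly half" depends on the layer shape. The following back-of-the-envelope sketch compares the two counts for one adapted linear layer. The helper names and the example shapes are illustrative assumptions, not part of PEFT; the only facts used are that LoRA trains both `A` (`r × in_features`) and `B` (`out_features × r`), while MiCA trains only `A`.

```python
def lora_trainable_params(in_features: int, out_features: int, r: int) -> int:
    """Trainable parameters of one LoRA-adapted linear layer: A (r x in) plus B (out x r)."""
    return r * in_features + out_features * r


def mica_trainable_params(in_features: int, out_features: int, r: int) -> int:
    """Trainable parameters under MiCA: only A (r x in) is updated, B stays frozen."""
    return r * in_features


# Illustrative (assumed) shapes: one square projection and two rectangular ones.
shapes = {
    "square 4096 -> 4096": (4096, 4096),
    "wide 4096 -> 11008": (4096, 11008),
    "narrow 11008 -> 4096": (11008, 4096),
}
for name, (n_in, n_out) in shapes.items():
    lora = lora_trainable_params(n_in, n_out, r=16)
    mica = mica_trainable_params(n_in, n_out, r=16)
    print(f"{name}: LoRA={lora}, MiCA={mica}, ratio={mica / lora:.2f}")
```

For the square case the ratio is exactly one half; for rectangular weights it moves toward `in_features / (in_features + out_features)`, so the saving is larger when the layer expands its input and smaller when it contracts it.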
+ ### CorDA [CorDA](https://huggingface.co/papers/2406.05223) builds task-aware LoRA adapters from weight decomposition oriented by the context of downstream task to learn (instruction-previewed mode, IPM) or world knowledge to maintain (knowledge-preserved mode, KPM). diff --git a/examples/mica_finetuning/README.md b/examples/mica_finetuning/README.md new file mode 100644 index 0000000000..205c9b2799 --- /dev/null +++ b/examples/mica_finetuning/README.md @@ -0,0 +1,80 @@ +# MiCA: Minor Component Adaptation + +## Introduction ([Paper](https://arxiv.org/abs/2604.01694)) + +Minor Component Adaptation (MiCA) is a parameter-efficient fine-tuning method closely related to LoRA. Like LoRA, MiCA inserts a low-rank update `ΔW = (α/r) · B · A` into a pretrained weight `W ∈ R^{out×in}`. Unlike LoRA, MiCA initializes `B` from the singular value decomposition of `W`, sets `A` to zero, and trains only `A`: + +- Compute the SVD `W = U Σ V^T`. +- Initialize `B = U[:, -r:]`, the `r` left singular vectors associated with the **smallest** singular values. +- Initialize `A = 0`. +- During training, optimize only `A`; `W` and `B` remain frozen. + +The motivation is that the *minor* singular directions of a pretrained weight encode subspaces that are largely unused by the original task. Restricting adaptation to these directions provides a more "plastic" subspace for knowledge injection, with less risk of overwriting capabilities encoded in the dominant subspace. According to the paper, MiCA improves knowledge acquisition while reducing the trainable parameter footprint compared with LoRA at the same rank: because only `A` is trained, each adapted layer trains `r · in` parameters instead of `r · (in + out)`, which is roughly half for square weights. + +Because `A == 0` at initialization, the adapter contribution `B · A == 0` and the model's forward output is preserved exactly at step 0, so no residual subtraction is needed on the base weight. 
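The four initialization steps above can be traced by hand on a tiny matrix. The sketch below uses a 2×2 symmetric weight whose SVD is known in closed form (singular values 3 and 1), so it needs no linear-algebra library. It is a toy illustration under that assumption, not the PEFT implementation, and the `matmul` helper and the value chosen for the "trained" `A` are made up for the example.

```python
import math


def matmul(X, Y):
    """Plain-Python matrix product for small nested lists."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*Y)] for row in X]


s = 1 / math.sqrt(2)
W = [[2.0, 1.0], [1.0, 2.0]]

# Closed-form SVD of this symmetric matrix: W = 3 * u1 u1^T + 1 * u2 u2^T.
u1 = [[s], [s]]    # principal left singular vector (sigma = 3)
u2 = [[s], [-s]]   # minor left singular vector (sigma = 1)
recon = [[3 * u1[i][0] * u1[j][0] + u2[i][0] * u2[j][0] for j in range(2)] for i in range(2)]

r = 1
B = u2               # MiCA: the last r left singular vectors, kept frozen
A = [[0.0, 0.0]]     # MiCA: zeros, the only trainable factor

delta_at_init = matmul(B, A)
print(delta_at_init)  # all zeros, so W + delta equals W at step 0

# Pretend an optimizer step moved A to an arbitrary value.
A = [[0.3, -0.7]]
delta = matmul(B, A)

# Every column of delta lies in span(B), so it has no component along u1.
u1_T = [[s, s]]
print(matmul(u1_T, delta))  # zeros up to floating-point rounding
```

The second print shows the practical consequence of freezing `B`: whatever values `A` takes during training, the update `B · A` can only write into the output directions spanned by the minor singular vectors.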
+ +## Quick Start + +```python +import torch +from peft import LoraConfig, get_peft_model +from transformers import AutoTokenizer, AutoModelForCausalLM +from trl import SFTConfig, SFTTrainer +from datasets import load_dataset + +model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf", dtype=torch.bfloat16, device_map="auto") +tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf") +tokenizer.pad_token_id = tokenizer.eos_token_id + +lora_config = LoraConfig( + init_lora_weights="mica", + r=16, + lora_alpha=16, + target_modules=["q_proj", "v_proj"], + task_type="CAUSAL_LM", +) +peft_model = get_peft_model(model, lora_config) +peft_model.print_trainable_parameters() + +dataset = load_dataset("imdb", split="train[:1%]") +training_args = SFTConfig(dataset_text_field="text", max_length=128) +trainer = SFTTrainer( + model=peft_model, + args=training_args, + train_dataset=dataset, + processing_class=tokenizer, +) +trainer.train() +peft_model.save_pretrained("mica-llama-2-7b") +``` + +To reload the trained adapter: + +```python +import torch +from peft import PeftModel +from transformers import AutoModelForCausalLM + +model = AutoModelForCausalLM.from_pretrained( + "meta-llama/Llama-2-7b-hf", dtype=torch.bfloat16, device_map="auto" +) +peft_model = PeftModel.from_pretrained(model, "mica-llama-2-7b") +``` + +## Notes and limitations + +- MiCA currently supports `nn.Linear` and `nn.Embedding` target modules. +- The chosen rank must satisfy `r <= min(in_features, out_features)` for linear layers and `r <= min(num_embeddings, embedding_dim)` for embedding layers; otherwise initialization raises `ValueError`. +- MiCA performs a full SVD per target weight at initialization. For 7B-scale models this is a one-time cost on the order of seconds; a dense SVD scales roughly as `O(m · n · min(m, n))` for an `m × n` weight, so the cost grows quickly for substantially larger weight matrices (e.g. 70B-scale) and with the number of target modules. +- Combining MiCA with `use_dora=True` or other LoRA variants is not supported in this initial integration. 
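The rank limit in the notes follows from the reduced SVD: an `m × n` weight has only `min(m, n)` left singular vectors to choose from. The constraint can be checked before any model is loaded. The helper below is a hypothetical pre-flight check written for this note, not a PEFT function; it only mirrors the documented rule and raises the same exception type.

```python
def max_mica_rank(weight_shape: tuple) -> int:
    """Largest usable MiCA rank for a 2-D weight: the reduced SVD yields min(m, n) vectors."""
    return min(weight_shape)


def check_mica_rank(r: int, weight_shape: tuple) -> None:
    """Raise ValueError if r exceeds what the weight's SVD can provide (hypothetical helper)."""
    max_r = max_mica_rank(weight_shape)
    if r > max_r:
        raise ValueError(
            f"r={r} is too large for weight shape {tuple(weight_shape)}; max usable r is {max_r}"
        )


check_mica_rank(16, (4096, 4096))  # a square projection: accepted
check_mica_rank(3, (7, 5))         # an embedding with 7 rows of dimension 5: accepted

try:
    check_mica_rank(3, (3, 2))     # nn.Linear(2, 3) stores a (3, 2) weight: rejected
except ValueError as err:
    print(err)
```

Running such a check over the shapes of the intended `target_modules` is cheaper than discovering the problem after the base model has been loaded.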
+ +## Citation + +``` +@article{rudiger2026mica, + title={MiCA Learns More Knowledge Than LoRA and Full Fine-Tuning}, + author={R{\"u}diger, Sten and Raschka, Sebastian}, + journal={arXiv preprint arXiv:2604.01694}, + year={2026} +} +``` diff --git a/examples/mica_finetuning/mica_finetuning.py b/examples/mica_finetuning/mica_finetuning.py new file mode 100644 index 0000000000..39e11d3917 --- /dev/null +++ b/examples/mica_finetuning/mica_finetuning.py @@ -0,0 +1,80 @@ +# Copyright 2023-present the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Minimal MiCA fine-tuning example. + +Mirrors `examples/pissa_finetuning/pissa_finetuning.py` in spirit but with the MiCA-specific knobs only. MiCA +initializes `B` from the bottom-r left singular vectors of the base weight and freezes it during training; only +`A` is updated. Because `A == 0` at init, the adapter is a no-op on initialization and no residual subtraction +on the base weight is needed. 
+""" + +from dataclasses import dataclass, field +from typing import Optional + +import torch +from datasets import load_dataset +from transformers import AutoModelForCausalLM, AutoTokenizer, HfArgumentParser +from trl import SFTConfig, SFTTrainer + +from peft import LoraConfig, get_peft_model + + +@dataclass +class ScriptArguments(SFTConfig): + base_model_name_or_path: Optional[str] = field(default=None, metadata={"help": "Name or path of the base model."}) + lora_r: int = field(default=16) + lora_alpha: int = field(default=16) + lora_dropout: float = field(default=0.0) + target_modules: Optional[str] = field( + default="q_proj,v_proj", + metadata={"help": "Comma-separated module names to adapt with MiCA."}, + ) + data_path: str = field(default="imdb", metadata={"help": "HF dataset path."}) + dataset_split: str = field(default="train[:1%]") + dataset_text_field: str = field(default="text") + + +def train(): + parser = HfArgumentParser(ScriptArguments) + args = parser.parse_args_into_dataclasses()[0] + + model = AutoModelForCausalLM.from_pretrained(args.base_model_name_or_path, dtype=torch.bfloat16, device_map="auto") + tokenizer = AutoTokenizer.from_pretrained(args.base_model_name_or_path) + if tokenizer.pad_token_id is None: + tokenizer.pad_token_id = tokenizer.eos_token_id + + lora_config = LoraConfig( + init_lora_weights="mica", + r=args.lora_r, + lora_alpha=args.lora_alpha, + lora_dropout=args.lora_dropout, + target_modules=[m.strip() for m in args.target_modules.split(",")], + task_type="CAUSAL_LM", + ) + peft_model = get_peft_model(model, lora_config) + peft_model.print_trainable_parameters() + + dataset = load_dataset(args.data_path, split=args.dataset_split) + trainer = SFTTrainer( + model=peft_model, + args=args, + train_dataset=dataset, + processing_class=tokenizer, + ) + trainer.train() + peft_model.save_pretrained(args.output_dir) + + +if __name__ == "__main__": + train() diff --git 
a/method_comparison/MetaMathQA/experiments/lora/llama-3.2-3B-rank32-mica/adapter_config.json b/method_comparison/MetaMathQA/experiments/lora/llama-3.2-3B-rank32-mica/adapter_config.json new file mode 100644 index 0000000000..e62097f635 --- /dev/null +++ b/method_comparison/MetaMathQA/experiments/lora/llama-3.2-3B-rank32-mica/adapter_config.json @@ -0,0 +1,30 @@ +{ + "alpha_pattern": {}, + "auto_mapping": null, + "base_model_name_or_path": null, + "bias": "none", + "corda_config": null, + "eva_config": null, + "exclude_modules": null, + "fan_in_fan_out": false, + "inference_mode": false, + "init_lora_weights": "mica", + "layer_replication": null, + "layers_pattern": null, + "layers_to_transform": null, + "loftq_config": {}, + "lora_alpha": 64, + "lora_bias": false, + "lora_dropout": 0.0, + "megatron_config": null, + "megatron_core": "megatron.core", + "modules_to_save": null, + "peft_type": "LORA", + "r": 32, + "rank_pattern": {}, + "revision": null, + "target_modules": null, + "task_type": "CAUSAL_LM", + "use_dora": false, + "use_rslora": false +} diff --git a/src/peft/tuners/lora/config.py b/src/peft/tuners/lora/config.py index a376c0e065..d057a904b8 100644 --- a/src/peft/tuners/lora/config.py +++ b/src/peft/tuners/lora/config.py @@ -408,7 +408,7 @@ class LoraConfig(PeftConfig): use the original default value of `lora_alpha/r`. modules_to_save (`List[str]`): List of modules apart from adapter layers to be set as trainable and saved in the final checkpoint. - init_lora_weights (`bool` | `Literal["gaussian", "eva", "olora", "pissa", "pissa_niter_[number of iters]", "corda", "loftq", "orthogonal"]`): + init_lora_weights (`bool` | `Literal["gaussian", "eva", "olora", "pissa", "pissa_niter_[number of iters]", "corda", "loftq", "orthogonal", "mica"]`): How to initialize the weights of the adapter layers. Passing True (default) results in the default initialization from the reference implementation from Microsoft, with the LoRA B weight being set to 0. 
This means that without further training, the LoRA adapter will be a no-op. Setting the initialization to @@ -430,7 +430,10 @@ class LoraConfig(PeftConfig): converges even more rapidly than PiSSA in Instruction-Previewed Mode, and preserves world knowledge better than LoRA in Knowledge-Preserved Mode. Passing `"orthogonal"` results in LoRA A and B being intialized orthogonally; in this, it resembles `"olora"`, but the base weights are left untouched (requires `r` to be - even, only supported for linear layers for now). + even, only supported for linear layers for now). Passing `"mica"` results in the initialization of Minor Component Adaptation (MiCA), which initializes B from + the r left singular vectors of the base weight associated with the smallest singular values, sets A to + zero, and freezes B during training; only A is updated. Currently supported for linear and embedding layers. layers_to_transform (`Union[List[int], int]`): The layer indices to transform. If a list of ints is passed, it will apply the adapter to the layer indices that are specified in this list. If a single integer is passed, it will apply the transformations on the @@ -566,7 +569,17 @@ class LoraConfig(PeftConfig): ) init_lora_weights: ( bool - | Literal["gaussian", "eva", "olora", "pissa", "pissa_niter_[number of iters]", "corda", "loftq", "orthogonal"] + | Literal[ + "gaussian", + "eva", + "olora", + "pissa", + "pissa_niter_[number of iters]", + "corda", + "loftq", + "orthogonal", + "mica", + ] ) = field( default=True, metadata={ @@ -586,7 +599,10 @@ class LoraConfig(PeftConfig): "nonnegative integer. " "Passing `'corda'` results in CorDA initialization. " "Pass `'loftq'` to use LoftQ initialization. " - "Pass `'orthogonal'` for orthogonal initialization of LoRA A and B." + "Pass `'orthogonal'` for orthogonal initialization of LoRA A and B. 
" + "Pass `'mica'` to use MiCA initialization, where B is set to the r left singular vectors of the " + "base weight associated with the smallest singular values, A is set to zero, and B is frozen during " + "training (only A is updated)." ), }, ) diff --git a/src/peft/tuners/lora/layer.py b/src/peft/tuners/lora/layer.py index 1a07158308..9a2e6f6c02 100644 --- a/src/peft/tuners/lora/layer.py +++ b/src/peft/tuners/lora/layer.py @@ -136,6 +136,10 @@ def __init__(self, base_layer: nn.Module, ephemeral_gpu_offload: bool = False, * self.in_features = in_features self.out_features = out_features + def delete_adapter(self, adapter_name: str) -> None: + super().delete_adapter(adapter_name) + self.lora_variant.pop(adapter_name, None) + def _get_in_out_features(self, module: nn.Module) -> tuple[int, int] | tuple[None, None]: return _get_in_out_features(module) @@ -231,6 +235,9 @@ def update_layer( elif isinstance(init_lora_weights, str) and init_lora_weights.lower() == "olora": with gather_params_ctx(self.get_base_layer().weight): self.olora_init(adapter_name) + elif isinstance(init_lora_weights, str) and init_lora_weights.lower() == "mica": + with gather_params_ctx(self.get_base_layer().weight): + self.mica_init(adapter_name) elif init_lora_weights == "loftq": with gather_params_ctx(self.get_base_layer().weight): self.loftq_init(adapter_name, config) @@ -395,6 +402,41 @@ def pissa_init(self, adapter_name, init_lora_weights): weight = transpose(weight.to(dtype), self.fan_in_fan_out) self.get_base_layer().weight.data = weight + def mica_init(self, adapter_name): + """Minor Component Adaptation (MiCA) initialization (https://arxiv.org/abs/2604.01694). + + Initializes `lora_B` from the `r` left singular vectors of the base weight associated with the smallest + singular values, and sets `lora_A` to zero. The `lora_B` matrix is frozen during training (see + `MiCALinearVariant.init`); only `lora_A` is updated. 
Because `lora_A == 0` at init, the adapter + contribution `B @ A == 0` and the base weight does not need to be modified to preserve the forward output. + """ + # When the adapter is being created under `init_empty_weights` (e.g. low_cpu_mem_usage=True), its parameters + # live on the meta device and will be filled in from a checkpoint after creation. Skip the SVD in that case. + if self.lora_B[adapter_name].weight.device.type == "meta": + return + + weight = self.get_base_layer().weight + dtype = weight.dtype + if dtype not in [torch.float32, torch.float16, torch.bfloat16]: + raise TypeError("Please initialize MiCA under float32, float16, or bfloat16.") + + weight = transpose(weight.to(torch.float32), self.fan_in_fan_out) + # weight has shape (out_features, in_features) once transposed for fan_in_fan_out, matching nn.Linear.weight. + # SVD: weight = U @ diag(S) @ Vh, with U: (out, k), Vh: (k, in), S sorted descending. + # MiCA selects the LAST r left singular vectors (smallest singular values) for B and zeroes A. + r = self.r[adapter_name] + max_r = min(weight.shape) + if r > max_r: + raise ValueError( + f"MiCA requires `r` <= min(in_features, out_features) but got r={r} for a layer with " + f"weight shape {tuple(weight.shape)} (max usable r is {max_r})." 
+ ) + U, _, _ = torch.linalg.svd(weight.data, full_matrices=False) + lora_B = U[:, -r:].contiguous() + lora_A = torch.zeros(r, weight.shape[1], device=weight.device) + self.lora_B[adapter_name].weight.data = lora_B.to(dtype) + self.lora_A[adapter_name].weight.data = lora_A.to(dtype) + def corda_init(self, adapter_name, init_lora_weights): linear = self.get_base_layer() weight = linear.weight @@ -815,6 +857,11 @@ def resolve_lora_variant(self, config: LoraConfig, **kwargs) -> Optional[LoraVar return BdLoraLinearVariant() + if isinstance(config.init_lora_weights, str) and config.init_lora_weights.lower() == "mica": + from .variants import MiCALinearVariant + + return MiCALinearVariant() + use_alora = config.alora_invocation_tokens is not None if not config.use_dora and not use_alora: return None @@ -1064,6 +1111,10 @@ def __init__( def resolve_lora_variant(self, *, config: LoraConfig, **kwargs) -> Optional[LoraVariant]: if config.velora_config is not None: raise ValueError("VeLoRA does not support adapting embedding layers.") + if isinstance(config.init_lora_weights, str) and config.init_lora_weights.lower() == "mica": + from .variants import MiCAEmbeddingVariant + + return MiCAEmbeddingVariant() if not config.use_dora: return None @@ -1116,7 +1167,10 @@ def update_layer( self.use_dora[adapter_name] = config.use_dora - if init_lora_weights == "loftq": + if isinstance(init_lora_weights, str) and init_lora_weights.lower() == "mica": + with gather_params_ctx(self.get_base_layer().weight): + self.mica_init(adapter_name) + elif init_lora_weights == "loftq": self.loftq_init(adapter_name) elif init_lora_weights == "lora_ga": # Embedding layers don't support LoRA-GA, fall back to standard initialization @@ -1145,6 +1199,36 @@ def output_fn(outputs): self.input_fns[adapter_name] = input_fn self.output_fns[adapter_name] = output_fn + def mica_init(self, adapter_name): + """Minor Component Adaptation (MiCA) initialization for embedding layers. 
+ + The effective embedding projection has shape `(embedding_dim, num_embeddings)`, so MiCA initializes + `lora_embedding_B` from the minor left singular vectors of `base_layer.weight.T` and sets + `lora_embedding_A` to zero. + """ + if self.lora_embedding_B[adapter_name].device.type == "meta": + return + + weight = self.get_base_layer().weight + dtype = weight.dtype + if dtype not in [torch.float32, torch.float16, torch.bfloat16]: + raise TypeError("Please initialize MiCA under float32, float16, or bfloat16.") + + weight = weight.to(torch.float32).T + r = self.r[adapter_name] + max_r = min(weight.shape) + if r > max_r: + raise ValueError( + f"MiCA requires `r` <= min(num_embeddings, embedding_dim) but got r={r} for an embedding layer with " + f"weight shape {tuple(self.get_base_layer().weight.shape)} (max usable r is {max_r})." + ) + + U, _, _ = torch.linalg.svd(weight.data, full_matrices=False) + lora_embedding_B = U[:, -r:].contiguous() + lora_embedding_A = torch.zeros(r, weight.shape[1], device=weight.device) + self.lora_embedding_B[adapter_name].data = lora_embedding_B.to(dtype) + self.lora_embedding_A[adapter_name].data = lora_embedding_A.to(dtype) + def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = None) -> None: """ Merge the active adapter weights into the base weights diff --git a/src/peft/tuners/lora/variants.py b/src/peft/tuners/lora/variants.py index 43e465e642..18763c5a98 100644 --- a/src/peft/tuners/lora/variants.py +++ b/src/peft/tuners/lora/variants.py @@ -1216,3 +1216,93 @@ def forward( out_A = F.linear(x_dropped, current_weight_A) result = result + lora_B(out_A) * scaling return result + + +def _register_frozen_peft_weight(module: Linear | Embedding, adapter_name: str, weight_name: str) -> None: + frozen_peft_weight_names = module.frozen_peft_weight_names.copy() + frozen_names = frozen_peft_weight_names.get(adapter_name, ()) + frozen_peft_weight_names[adapter_name] = tuple(dict.fromkeys((*frozen_names, weight_name))) 
+ module.frozen_peft_weight_names = frozen_peft_weight_names + module._freeze_declared_peft_weights(adapter_name) + + +class MiCALinearVariant(LoraVariant): + """Variant for Minor Component Adaptation (MiCA), https://arxiv.org/abs/2604.01694. + + The actual SVD-based initialization is performed in `LoraLayer.mica_init` (called from `update_layer`); this + variant declares `lora_B` as frozen. Forward and merge semantics are identical to vanilla LoRA, since + `delta_W = scaling * B @ A` and only `A` is updated. + """ + + @staticmethod + def init(module: Linear, adapter_name: str, **kwargs: Any) -> None: + _register_frozen_peft_weight(module, adapter_name, "lora_B") + + @staticmethod + def merge_safe(module: Linear, active_adapter: str, orig_weight: torch.Tensor) -> torch.Tensor: + return orig_weight + module.get_delta_weight(active_adapter).to(orig_weight.dtype) + + @staticmethod + def merge_unsafe(module: Linear, active_adapter: str, orig_weight: torch.Tensor) -> None: + orig_weight.data += module.get_delta_weight(active_adapter).to(orig_weight.dtype) + + @staticmethod + def unmerge(module: Linear, active_adapter: str, orig_weight: torch.Tensor) -> torch.Tensor: + return orig_weight - module.get_delta_weight(active_adapter).to(orig_weight.dtype) + + @staticmethod + def forward( + module: Linear, + active_adapter: str, + x: torch.Tensor, + result: torch.Tensor, + **kwargs, + ) -> torch.Tensor: + lora_A = module.lora_A[active_adapter] + lora_B = module.lora_B[active_adapter] + dropout = module.lora_dropout[active_adapter] + scaling = module.scaling[active_adapter] + return result + lora_B(lora_A(dropout(x))) * scaling + + +class MiCAEmbeddingVariant(LoraVariant): + """Embedding variant for Minor Component Adaptation (MiCA), https://arxiv.org/abs/2604.01694.""" + + @staticmethod + def init(module: Embedding, adapter_name: str, **kwargs: Any) -> None: + _register_frozen_peft_weight(module, adapter_name, "lora_embedding_B") + + @staticmethod + def merge_safe(module: 
Embedding, active_adapter: str, orig_weight: torch.Tensor) -> torch.Tensor: + return orig_weight + module.get_delta_weight(active_adapter).to(orig_weight.dtype) + + @staticmethod + def merge_unsafe(module: Embedding, active_adapter: str, orig_weight: torch.Tensor) -> None: + orig_weight.data += module.get_delta_weight(active_adapter).to(orig_weight.dtype) + + @staticmethod + def unmerge(module: Embedding, active_adapter: str, orig_weight: torch.Tensor) -> torch.Tensor: + return orig_weight - module.get_delta_weight(active_adapter).to(orig_weight.dtype) + + @staticmethod + def forward( + module: Embedding, + active_adapter: str, + x: torch.Tensor, + result: torch.Tensor, + **kwargs, + ) -> torch.Tensor: + embedding_A = module.lora_embedding_A[active_adapter].T + embedding_B = module.lora_embedding_B[active_adapter].T + scaling = module.scaling[active_adapter] + input_fn = module.input_fns.get(active_adapter, None) + output_fn = module.output_fns.get(active_adapter, None) + + after_A = module._embed(x, embedding_A, input_fn=input_fn, output_fn=output_fn) + adapter_output = (after_A @ embedding_B) * scaling + + embed_scale = module._get_embed_scale() + if embed_scale is not None: + adapter_output = adapter_output * embed_scale.to(adapter_output.dtype) + + return result + adapter_output diff --git a/src/peft/tuners/tuners_utils.py b/src/peft/tuners/tuners_utils.py index 7227e869dd..84ac173118 100644 --- a/src/peft/tuners/tuners_utils.py +++ b/src/peft/tuners/tuners_utils.py @@ -495,6 +495,10 @@ def _mark_only_adapters_as_trainable(self, model: nn.Module) -> None: else: raise NotImplementedError(f"Requested bias: {bias}, is not implemented.") + for module in model.modules(): + if isinstance(module, self.tuner_layer_cls): + module._freeze_declared_peft_weights() + def _enable_adapter_layers(self, enabled: bool = True) -> None: for module in self.model.modules(): if isinstance(module, (BaseTunerLayer, AuxiliaryTrainingWrapper)): @@ -1385,6 +1389,8 @@ class 
BaseTunerLayer(ABC): adapter_layer_names: tuple[str, ...] = () # All names of other parameters that may contain adapter-related parameters other_param_names: tuple[str, ...] = () + # Mapping from adapter name to adapter layer names that should always stay frozen + frozen_peft_weight_names: dict[str, tuple[str, ...]] = {} # indicates whether all adapters should be disabled _disable_adapters: bool = False @@ -1521,6 +1527,23 @@ def enable_adapters(self, enabled: bool) -> None: _set_layer_requires_grad(layer, False) self._disable_adapters = True + def _freeze_declared_peft_weights(self, adapter_names: str | Sequence[str] | None = None) -> None: + if adapter_names is None: + adapter_names = self.frozen_peft_weight_names.keys() + elif isinstance(adapter_names, str): + adapter_names = [adapter_names] + + for adapter_name in adapter_names: + for layer_name in self.frozen_peft_weight_names.get(adapter_name, ()): + if layer_name not in self.adapter_layer_names: + continue + + module_dict = getattr(self, layer_name) + if adapter_name not in module_dict: + continue + + _set_layer_requires_grad(module_dict[adapter_name], False) + def set_adapter(self, adapter_names: str | list[str], inference_mode: bool = False) -> None: """Set the active adapter(s). 
@@ -1543,6 +1566,7 @@ def set_adapter(self, adapter_names: str | list[str], inference_mode: bool = Fal should_require_grad = (key in adapter_names) and (not inference_mode) _set_layer_requires_grad(layer, should_require_grad) + self._freeze_declared_peft_weights(adapter_names) self._active_adapter = adapter_names def _all_available_adapter_names(self) -> list[str]: @@ -1573,6 +1597,11 @@ def delete_adapter(self, adapter_name: str) -> None: if adapter_name in getattr(self, attr): del getattr(self, attr)[adapter_name] + if adapter_name in self.frozen_peft_weight_names: + frozen_peft_weight_names = self.frozen_peft_weight_names.copy() + del frozen_peft_weight_names[adapter_name] + self.frozen_peft_weight_names = frozen_peft_weight_names + if adapter_name in self.active_adapters: # choose a new active adapter active_adapters = self.active_adapters[:] @@ -1614,6 +1643,9 @@ def set_requires_grad(self, adapter_names: str | Sequence[str], requires_grad: b if key in adapter_names_set: _set_layer_requires_grad(layer, requires_grad) + if requires_grad: + self._freeze_declared_peft_weights(adapter_names_set) + def _get_base_layer_device_and_dtype(self, base_layer): """ Helper function to determine the device and dtype of the base layer. If not possible to determine, return None. 
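The copy-then-reassign handling of `frozen_peft_weight_names` in this file matters because the attribute is declared at class level with a mutable `{}` default. The stripped-down sketch below shows the same pattern in plain Python. The `Layer` class and `register_frozen` function are illustrative stand-ins, not the PEFT classes; they only demonstrate why rebinding a copy on the instance avoids leaking state across layers, and how `dict.fromkeys` keeps the registered names ordered and free of duplicates.

```python
class Layer:
    # Class-level default, shared by every instance until an instance rebinds it.
    frozen_names: dict = {}


def register_frozen(layer: Layer, adapter: str, weight: str) -> None:
    # Copy first, then rebind on the instance: the shared class-level dict is never mutated.
    updated = layer.frozen_names.copy()
    existing = updated.get(adapter, ())
    # dict.fromkeys preserves insertion order and drops duplicates.
    updated[adapter] = tuple(dict.fromkeys((*existing, weight)))
    layer.frozen_names = updated


a, b = Layer(), Layer()
register_frozen(a, "default", "lora_B")
register_frozen(a, "default", "lora_B")  # registering the same name twice is harmless

print(a.frozen_names)      # {'default': ('lora_B',)}
print(b.frozen_names)      # {}
print(Layer.frozen_names)  # {}
```

Had `register_frozen` mutated `layer.frozen_names` in place, the entry would have appeared on `b` and on the class as well, and every layer would then treat `lora_B` of the `default` adapter as frozen.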
diff --git a/tests/test_custom_models.py b/tests/test_custom_models.py index 7a809a6b4c..4cb8f1fc89 100644 --- a/tests/test_custom_models.py +++ b/tests/test_custom_models.py @@ -101,6 +101,12 @@ ), ("Vanilla MLP 7 LoRA with DoRA", "MLP", LoraConfig, {"target_modules": ["lin0"], "use_dora": True}), ("Vanilla MLP 8 LoRA with DoRA", "MLP", LoraConfig, {"target_modules": ["lin0", "lin1"], "use_dora": True}), + ( + "Vanilla MLP 1 LoRA with MiCA", + "MLP", + LoraConfig, + {"target_modules": ["lin0"], "init_lora_weights": "mica", "r": 4}, + ), ( "Vanilla MLP 9 LoRA with DoRA", "MLP", @@ -2542,8 +2548,12 @@ def test_only_params_are_updated(self, test_name, model_id, config_cls, config_k # via Monte Carlo sampling), so we don't include them in the strict allclose check below. continue if (model.prefix in name) or ("modules_to_save" in name) or ("token_adapter.trainable_tokens" in name): - # target_modules, modules_to_save and modules of `NewTokensWrapper` _are_ updated - assert not torch.allclose(param_before, param_after, atol=tol, rtol=tol) + # target_modules, modules_to_save and modules of `NewTokensWrapper` _are_ updated, except for adapter + # parameters that the variant intentionally freezes (e.g. MiCA freezes lora_B). + if not param_after.requires_grad: + assert torch.equal(param_before, param_after) + else: + assert not torch.allclose(param_before, param_after, atol=tol, rtol=tol) else: assert torch.allclose(param_before, param_after, atol=tol, rtol=tol) diff --git a/tests/test_initialization.py b/tests/test_initialization.py index 27d9ba16e2..379e976f58 100644 --- a/tests/test_initialization.py +++ b/tests/test_initialization.py @@ -389,6 +389,135 @@ def test_lora_pissa_linear_init_default(self, data): peft_model = get_peft_model(deepcopy(model), config) assert torch.allclose(output, peft_model(data)[0], atol=1e-06) + def test_lora_mica_linear_init_default(self, data): + # MiCA initializes A=0 and B = bottom-r left singular vectors of W. 
Because A=0, the adapter contribution + # B @ A is zero at init, so the forward output must equal the base model's output exactly. + model = self.get_model() + output = model(data)[0] + + config = LoraConfig(init_lora_weights="mica", target_modules=["linear"], r=8) + peft_model = get_peft_model(deepcopy(model), config) + + weight_A = peft_model.base_model.linear.lora_A["default"].weight + weight_B = peft_model.base_model.linear.lora_B["default"].weight + + # A must be zero + assert torch.all(weight_A == 0) + # B columns must be orthonormal (since they are left singular vectors) + eye = torch.eye(weight_B.shape[1], device=weight_B.device, dtype=weight_B.dtype) + assert torch.allclose(weight_B.t() @ weight_B, eye, atol=1e-4) + # Output at init equals the base output + assert torch.allclose(output, peft_model(data)[0], atol=1e-06) + + def test_lora_mica_embedding_init_default(self): + class EmbeddingModel(nn.Module): + def __init__(self): + super().__init__() + self.embed = nn.Embedding(7, 5) + + def forward(self, x): + return self.embed(x) + + model = EmbeddingModel() + data = torch.arange(7).unsqueeze(0) + output = model(data) + + config = LoraConfig(init_lora_weights="mica", target_modules=["embed"], r=3) + peft_model = get_peft_model(deepcopy(model), config) + + weight_A = peft_model.base_model.embed.lora_embedding_A["default"] + weight_B = peft_model.base_model.embed.lora_embedding_B["default"] + + assert torch.all(weight_A == 0) + eye = torch.eye(weight_B.shape[1], device=weight_B.device, dtype=weight_B.dtype) + assert torch.allclose(weight_B.t() @ weight_B, eye, atol=1e-4) + assert weight_A.requires_grad + assert not weight_B.requires_grad + assert torch.allclose(output, peft_model(data), atol=1e-06) + + def test_lora_mica_uses_minor_components(self): + # Verify B equals the *minor* (smallest singular value) left singular vectors, not the major ones. 
+ torch.manual_seed(0) + model = self.get_model() + r = 8 + + config = LoraConfig(init_lora_weights="mica", target_modules=["linear"], r=r) + peft_model = get_peft_model(deepcopy(model), config) + weight_B = peft_model.base_model.linear.lora_B["default"].weight.detach().cpu() + + # Reference SVD of the original weight + W = model.linear.weight.detach().cpu().to(torch.float32) + U, _S, _ = torch.linalg.svd(W, full_matrices=False) + minor_U = U[:, -r:] + major_U = U[:, :r] + + # B should span the same subspace as `minor_U` (column spans match up to sign/orthogonal mixing within + # equal-singular-value groups). Equality of projectors is the right invariant. + proj_B = weight_B @ weight_B.t() + proj_minor = minor_U @ minor_U.t() + proj_major = major_U @ major_U.t() + assert torch.allclose(proj_B, proj_minor, atol=1e-4) + assert not torch.allclose(proj_B, proj_major, atol=1e-2) + + def test_lora_mica_freezes_B(self): + model = self.get_model() + config = LoraConfig(init_lora_weights="mica", target_modules=["linear"], r=8) + peft_model = get_peft_model(deepcopy(model), config) + + assert peft_model.base_model.linear.lora_A["default"].weight.requires_grad + assert not peft_model.base_model.linear.lora_B["default"].weight.requires_grad + + def test_lora_mica_freezes_B_when_switching_adapters(self): + class SimpleMlp(nn.Module): + def __init__(self): + super().__init__() + self.fc1 = nn.Linear(10, 10) + self.fc2 = nn.Linear(10, 10) + + def forward(self, x): + x = torch.relu(self.fc1(x)) + return self.fc2(x) + + def trainable_parameters(model): + return [name for name, param in model.named_parameters() if param.requires_grad] + + config0 = LoraConfig(target_modules=["fc1"], init_lora_weights="mica", r=4) + model = get_peft_model(SimpleMlp(), config0) + assert trainable_parameters(model) == ["base_model.model.fc1.lora_A.default.weight"] + + config1 = LoraConfig(target_modules=["fc1", "fc2"], init_lora_weights="mica", r=4) + model.add_adapter("other", config1) + 
model.set_adapter("other") + assert trainable_parameters(model) == [ + "base_model.model.fc1.lora_A.other.weight", + "base_model.model.fc2.lora_A.other.weight", + ] + + model.set_adapter("default") + assert trainable_parameters(model) == ["base_model.model.fc1.lora_A.default.weight"] + + model.delete_adapter("other") + config2 = LoraConfig(target_modules=["fc1"], r=4) + model.add_adapter("other", config2) + model.set_adapter("other") + assert trainable_parameters(model) == [ + "base_model.model.fc1.lora_A.other.weight", + "base_model.model.fc1.lora_B.other.weight", + ] + + def test_lora_mica_rank_too_large_raises(self): + class SimpleModel(nn.Module): + def __init__(self): + super().__init__() + self.linear = nn.Linear(2, 3) + + def forward(self, x): + return self.linear(x) + + config = LoraConfig(init_lora_weights="mica", target_modules=["linear"], r=3) + with pytest.raises(ValueError, match="MiCA requires `r` <= min"): + get_peft_model(SimpleModel(), config) + def test_lora_olora_linear_init_default(self, data): model = self.get_model() output = model(data)[0] diff --git a/tests/testing_common.py b/tests/testing_common.py index 16d56d5783..c3330124cb 100644 --- a/tests/testing_common.py +++ b/tests/testing_common.py @@ -1036,7 +1036,11 @@ def _test_training(self, model_id, config_cls, config_kwargs): for n, param in model.named_parameters(): if (model.prefix in n) or ("modules_to_save" in n) or ("token_adapter.trainable_tokens" in n): - assert param.grad is not None + # variants like MiCA intentionally freeze a subset of adapter params, which won't have a grad + if param.requires_grad: + assert param.grad is not None + else: + assert param.grad is None else: assert param.grad is None