diff --git a/src/peft/tuners/lora/config.py b/src/peft/tuners/lora/config.py index a376c0e065..f46e0dd11c 100644 --- a/src/peft/tuners/lora/config.py +++ b/src/peft/tuners/lora/config.py @@ -709,7 +709,8 @@ class LoraConfig(PeftConfig): "magnitude is handled by a separate learnable parameter. This can improve the performance of LoRA, " "especially at low ranks. Right now, DoRA only supports linear and Conv2D layers. DoRA introduces a bigger" "overhead than pure LoRA, so it is recommended to merge weights for inference." - ) + ), + "is_lora_variant": True, }, ) velora_config: Optional[Union[VeloraConfig, dict]] = field( @@ -718,7 +719,8 @@ class LoraConfig(PeftConfig): "help": ( "Enable VeLoRA as a LoRA variant by providing a VeloraConfig. VeLoRA swaps in a custom backward pass " "for the LoRA A projection that stores compressed activations instead of the full input activations." - ) + ), + "is_lora_variant": True, }, ) alora_invocation_tokens: Optional[list[int]] = field( @@ -735,7 +737,8 @@ class LoraConfig(PeftConfig): "operations. Overall adapter inference speedups of an order of magnitude or more can occur on vLLM, " "depending on the length of the shared context. Note that merging is not possible due to the selective " "application of the weights." - ) + ), + "is_lora_variant": True, }, ) use_qalora: bool = field( @@ -769,7 +772,8 @@ class LoraConfig(PeftConfig): "The configuration of Monteclora (Monte Carlo Low-Rank Adaptation). If passed, Monteclora will be " "used to add variational Monte Carlo sampling on top of the LoRA adapters. See `MontecloraConfig` " "for details on the individual hyperparameters." - ) + ), + "is_lora_variant": True, }, ) # Enables replicating layers in a model to expand it to a larger model. @@ -827,11 +831,16 @@ class LoraConfig(PeftConfig): "help": ( "Enable BD-LoRA (Block-Diagonal LoRA) by providing a BdLoraConfig. This technique uses block-diagonal matrices for LoRA-A or LoRA-B " "factors to enable faster multi-LoRA serving by eliminating communication overheads in distributed settings." - ) + ), + "is_lora_variant": True, }, ) arrow_config: Optional[ArrowConfig] = field( - default=None, metadata={"help": "The necessary config to apply arrow routing on the model."} + default=None, + metadata={ + "help": "The necessary config to apply arrow routing on the model.", + "is_lora_variant": True, + }, ) ensure_weight_tying: bool = field( default=False, diff --git a/src/peft/tuners/lora/layer.py b/src/peft/tuners/lora/layer.py index 1a07158308..3ea6433db4 100644 --- a/src/peft/tuners/lora/layer.py +++ b/src/peft/tuners/lora/layer.py @@ -14,6 +14,7 @@ from __future__ import annotations import copy +import dataclasses import math import warnings from collections.abc import Callable @@ -139,19 +140,46 @@ def __init__(self, base_layer: nn.Module, ephemeral_gpu_offload: bool = False, * def _get_in_out_features(self, module: nn.Module) -> tuple[int, int] | tuple[None, None]: return _get_in_out_features(module) + @property + def lora_variants(self): + """ + A dictionary mapping the active LoRA variants to their respective classes. + + To extend this, subclasses should override this property and return a dictionary where the keys are tuples of + variant field names (from LoraConfig) and the values are the specific LoraVariant subclasses. + + Tuples are used as keys because they are immutable and hashable, allowing us to safely map combinations of + active variants (e.g., DoRA + another variant) to a specific composed variant class. + """ + return {(): None} + def resolve_lora_variant(self, *, config: LoraConfig, **kwargs) -> Optional[LoraVariant]: - """Return a matching LoRA variant for this layer type. - Given the init arguments of this layer, return the correct LoRA variant, if any. E.g., if `use_dora=True`, this - method should return the DoRA variant for the given layer. If `use_alora=True`, same for aLoRA. + # Fetch the dictionary of variants + layer_variants = self.lora_variants + lora_variants_configs = [f for f in dataclasses.fields(config) if f.metadata.get("is_lora_variant")] - If there is no fitting variant, return None. + # 1. Gather all valid variant field names from the config + tagged_fields = {f.name for f in lora_variants_configs} - Note: If this layer type does not support the LoRA variant at all, please raise an error during __init__ as is - convention, and not here. + # 2. SANITY CHECK: Ensure all keys in the layer's dictionary actually exist in the config + for variant_keys in layer_variants.keys(): + for variant_name in variant_keys: + if variant_name not in tagged_fields: + raise ValueError( + f"Variant '{variant_name}' found in lora_variants but it is not tagged with " + f"'is_lora_variant' in LoraConfig." + ) - """ - return None + # 3. Figure out which variants are currently active + active_variants = tuple(sorted(f.name for f in lora_variants_configs if getattr(config, f.name))) + + # 4. Route to the correct variant class + if active_variants not in layer_variants: + raise ValueError(f"Invalid or unsupported variant combination: {active_variants}") + + variant_class = layer_variants[active_variants] + return variant_class() if variant_class else None def update_layer( self, @@ -795,36 +823,19 @@ def __init__( ) self.is_target_conv_1d_layer = is_target_conv_1d_layer - def resolve_lora_variant(self, config: LoraConfig, **kwargs) -> Optional[LoraVariant]: - if config.velora_config is not None: - from .variants import VeloraLinearVariant - - return VeloraLinearVariant() - - if config.arrow_config is not None: - from .variants import ArrowLinearVariant - - return ArrowLinearVariant() - - if config.monteclora_config is not None: - from .variants import MontecloraLinearVariant - - return MontecloraLinearVariant() - if config.use_bdlora is not None: - from .variants import BdLoraLinearVariant - - return BdLoraLinearVariant() - - use_alora = config.alora_invocation_tokens is not None - if not config.use_dora and not use_alora: - return None - - from .variants import ALoraLinearVariant, DoraLinearVariant - - if use_alora: - return ALoraLinearVariant() - else: - return DoraLinearVariant() + @property + def lora_variants(self): + from . import variants + + return { + (): None, + ("use_dora",): variants.DoraLinearVariant, + ("arrow_config",): variants.ArrowLinearVariant, + ("use_bdlora",): variants.BdLoraLinearVariant, + ("alora_invocation_tokens",): variants.ALoraLinearVariant, + ("velora_config",): variants.VeloraLinearVariant, + ("monteclora_config",): variants.MontecloraLinearVariant, + } def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = None) -> None: """ @@ -1061,15 +1072,14 @@ def __init__( config=config, ) - def resolve_lora_variant(self, *, config: LoraConfig, **kwargs) -> Optional[LoraVariant]: - if config.velora_config is not None: - raise ValueError("VeLoRA does not support adapting embedding layers.") - if not config.use_dora: - return None - - from .variants import DoraEmbeddingVariant + @property + def lora_variants(self): + from . import variants - return DoraEmbeddingVariant() + return { + (): None, + ("use_dora",): variants.DoraEmbeddingVariant, + } def update_layer( self, @@ -1670,15 +1680,14 @@ def __init__(self, *args, **kwargs): raise ValueError(f"Conv2d layer kernel must have 4 dimensions, not {self._kernel_dim}") self.conv_fn = F.conv2d - def resolve_lora_variant(self, *, config: LoraConfig, **kwargs) -> Optional[LoraVariant]: - if config.velora_config is not None: - raise ValueError("VeLoRA does not support adapting conv layers.") - if not config.use_dora: - return None - - from .variants import DoraConv2dVariant + @property + def lora_variants(self): + from . import variants - return DoraConv2dVariant() + return { + (): None, + ("use_dora",): variants.DoraConv2dVariant, + } class Conv1d(_ConvNd): @@ -1689,15 +1698,14 @@ def __init__(self, *args, **kwargs): raise ValueError(f"Conv1d layer kernel must have 3 dimensions, not {self._kernel_dim}") self.conv_fn = F.conv1d - def resolve_lora_variant(self, *, config: LoraConfig, **kwargs) -> Optional[LoraVariant]: - if config.velora_config is not None: - raise ValueError("VeLoRA does not support adapting conv layers.") - if not config.use_dora: - return None - - from .variants import DoraConv1dVariant + @property + def lora_variants(self): + from . import variants - return DoraConv1dVariant() + return { + (): None, + ("use_dora",): variants.DoraConv1dVariant, + } class Conv3d(_ConvNd): @@ -1708,15 +1716,14 @@ def __init__(self, *args, **kwargs): raise ValueError(f"Conv3d layer kernel must have 5 dimensions, not {self._kernel_dim}") self.conv_fn = F.conv3d - def resolve_lora_variant(self, *, config: LoraConfig, **kwargs) -> Optional[LoraVariant]: - if config.velora_config is not None: - raise ValueError("VeLoRA does not support adapting conv layers.") - if not config.use_dora: - return None - - from .variants import DoraConv3dVariant + @property + def lora_variants(self): + from . import variants - return DoraConv3dVariant() + return { + (): None, + ("use_dora",): variants.DoraConv3dVariant, + } class MultiheadAttention(nn.Module, LoraLayer): diff --git a/tests/test_lora_variants.py b/tests/test_lora_variants.py index 30f2fe53d5..b5c354f8ea 100644 --- a/tests/test_lora_variants.py +++ b/tests/test_lora_variants.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from unittest.mock import PropertyMock, patch + import pytest import torch from torch import nn @@ -174,6 +176,38 @@ def test_dora_params_have_gradients(self): for layer in layer_names: assert getattr(peft_model.base_model.model, layer).lora_magnitude_vector["default"].weight.grad is not None + def test_unregistered_variant_raises_error(self): + # 1. Create a config and dummy linear layer + config = LoraConfig() + base_layer = nn.Linear(10, 10) + layer = LoraLinear(base_layer, "default", config, r=8, lora_alpha=8) + + # 2. Monkey-patch the lora_variants property to include a fake variant + with patch("peft.tuners.lora.layer.Linear.lora_variants", new_callable=PropertyMock) as mock_variants: + mock_variants.return_value = {("fake_unregistered_variant",): None} + + # 3. Assert that the sanity check catches it and throws the right error + with pytest.raises( + ValueError, + match="Variant 'fake_unregistered_variant' found in lora_variants but it is not tagged with 'is_lora_variant' in LoraConfig.", + ): + layer.resolve_lora_variant(config=config) + + def test_invalid_variant_combination_raises_error(self): + # 1. Create a config with no variants active + config = LoraConfig() + base_layer = nn.Linear(10, 10) + layer = LoraLinear(base_layer, "default", config, r=8, lora_alpha=8) + + # 2. Monkey-patch lora_variants to include a valid tagged combo that isn't active + with patch("peft.tuners.lora.layer.Linear.lora_variants", new_callable=PropertyMock) as mock_variants: + mock_variants.return_value = { + ("use_dora",): None, # only use_dora is valid, empty combo not listed + } + # 3. Assert invalid combination error is raised + with pytest.raises(ValueError, match="Invalid or unsupported variant combination"): + layer.resolve_lora_variant(config=config) + class TestActivatedLora: @pytest.mark.parametrize( @@ -274,7 +308,7 @@ def test_num_beams_error(self): input_ids = torch.tensor([[0, 1, 2, 3]]) with pytest.raises(ValueError) as e: with torch.no_grad(): - lora_out = lora_model(X=input_ids, num_beams=2, alora_offsets=[3]) + lora_model(X=input_ids, num_beams=2, alora_offsets=[3]) assert "Beam search not yet supported for aLoRA." in str(e.value) def test_gradient_checkpointing_double_forward_raises(self):