From 857f6ff7ce69eb692df42659808030ac5a5a0762 Mon Sep 17 00:00:00 2001
From: robbiebusinessacc <65429016+robbiebusinessacc@users.noreply.github.com>
Date: Tue, 2 Jun 2026 15:48:17 +0300
Subject: [PATCH] ENH Add KaSA as a LoRA variant
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Implement KaSA (Knowledge-aware Singular-value Adaptation,
arXiv:2412.06071) using the LoRA-variant framework, following the
existing LoRA variants (e.g. DoRA) and the SVD-based initializations
(PiSSA/CorDA).

KaSA changes vanilla LoRA in two ways:
- A one-time, destructive SVD truncation of the frozen base weight that
  drops its r smallest singular components, leaving the rank-(k-r)
  approximation as the new frozen base (k = min(in_features,
  out_features)).
- A learnable diagonal of singular values (lora_diag) inserted between
  the LoRA A and B factors, so the update is
  ΔW = scaling * B @ diag(lora_diag) @ A.

Changes:
- New KasaConfig sub-config (beta, gamma) and LoraConfig.kasa_config
  field; selection is driven by kasa_config being non-None via
  resolve_lora_variant, with explicit guards rejecting KaSA on
  embedding/conv/MHA/ParamWrapper and fan_in_fan_out layers.
- KasaLinearVariant implements init (SVD truncation + lora_diag),
  forward, merge_safe/merge_unsafe/unmerge. lora_diag is registered in
  adapter_layer_names so it is saved/loaded.
- get_kasa_regularization_loss helper exposes the paper's two auxiliary
  terms (L2 singular-value penalty + L3 orthogonal regularization),
  since the variant forward has no channel to inject an extra loss into
  the training loop.
- Tests in tests/test_kasa.py (SVD-truncation faithfulness, lora_diag
  shape, zero-init update, merge/unmerge round-trip, delta-weight
  formula, save/load, regularization closed-form checks) plus wiring in
  tests/test_lora_variants.py.

Faithfulness notes:
- The base-weight truncation is destructive; disabling/unloading does
  not restore the original weight and merge/unmerge round-trips to the
  truncated base. This is inherent to the method and documented.
- The paper's L2/L3 regularizers are required for the SVD interpretation to hold but cannot be auto-injected; users must add get_kasa_regularization_loss to their loss. --- src/peft/__init__.py | 4 + src/peft/tuners/__init__.py | 4 + src/peft/tuners/lora/__init__.py | 4 + src/peft/tuners/lora/config.py | 76 ++++++ src/peft/tuners/lora/layer.py | 19 ++ src/peft/tuners/lora/variants.py | 276 ++++++++++++++++++++++ tests/test_kasa.py | 392 +++++++++++++++++++++++++++++++ tests/test_lora_variants.py | 24 +- 8 files changed, 798 insertions(+), 1 deletion(-) create mode 100644 tests/test_kasa.py diff --git a/src/peft/__init__.py b/src/peft/__init__.py index ec12d52583..f82bd093e4 100644 --- a/src/peft/__init__.py +++ b/src/peft/__init__.py @@ -78,6 +78,7 @@ HRAModel, IA3Config, IA3Model, + KasaConfig, LilyConfig, LilyModel, LNTuningConfig, @@ -138,6 +139,7 @@ convert_to_lora, create_arrow_model, get_eva_state_dict, + get_kasa_regularization_loss, initialize_lora_eva_weights, preprocess_loraga, save_as_lora, @@ -208,6 +210,7 @@ "HiraModel", "IA3Config", "IA3Model", + "KasaConfig", "LNTuningConfig", "LNTuningModel", "LilyConfig", @@ -286,6 +289,7 @@ "create_arrow_model", "find_kappa_target_modules", "get_eva_state_dict", + "get_kasa_regularization_loss", "get_layer_status", "get_model_status", "get_peft_config", diff --git a/src/peft/tuners/__init__.py b/src/peft/tuners/__init__.py index 4900a71aa8..4dc5e2ca59 100644 --- a/src/peft/tuners/__init__.py +++ b/src/peft/tuners/__init__.py @@ -34,6 +34,7 @@ ArrowConfig, BdLoraConfig, EvaConfig, + KasaConfig, LoftQConfig, LoraConfig, LoraGAConfig, @@ -44,6 +45,7 @@ convert_to_lora, create_arrow_model, get_eva_state_dict, + get_kasa_regularization_loss, initialize_lora_eva_weights, preprocess_loraga, save_as_lora, @@ -103,6 +105,7 @@ "HiraModel", "IA3Config", "IA3Model", + "KasaConfig", "LNTuningConfig", "LNTuningModel", "LilyConfig", @@ -165,6 +168,7 @@ "convert_to_lora", "create_arrow_model", "get_eva_state_dict", + 
"get_kasa_regularization_loss", "initialize_lora_eva_weights", "preprocess_loraga", "save_as_lora", diff --git a/src/peft/tuners/lora/__init__.py b/src/peft/tuners/lora/__init__.py index 94d6e16c36..63ff86980d 100644 --- a/src/peft/tuners/lora/__init__.py +++ b/src/peft/tuners/lora/__init__.py @@ -20,6 +20,7 @@ ArrowConfig, BdLoraConfig, EvaConfig, + KasaConfig, LoftQConfig, LoraConfig, LoraGAConfig, @@ -33,6 +34,7 @@ from .layer import Conv2d, Conv3d, Embedding, Linear, LoraLayer, ParamWrapper from .loraga import preprocess_loraga from .model import LoraModel +from .variants import get_kasa_regularization_loss __all__ = [ @@ -43,6 +45,7 @@ "Embedding", "EvaConfig", "GPTQLoraLinear", + "KasaConfig", "Linear", "LoftQConfig", "LoraConfig", @@ -56,6 +59,7 @@ "convert_to_lora", "create_arrow_model", "get_eva_state_dict", + "get_kasa_regularization_loss", "initialize_lora_eva_weights", "preprocess_loraga", "save_as_lora", diff --git a/src/peft/tuners/lora/config.py b/src/peft/tuners/lora/config.py index a376c0e065..6fdd679847 100644 --- a/src/peft/tuners/lora/config.py +++ b/src/peft/tuners/lora/config.py @@ -294,6 +294,61 @@ def __post_init__(self): raise ValueError("`tau` must be between 0.0 and 1.0.") +@dataclass +class KasaConfig: + """ + This is the sub-configuration class to store the configuration for KaSA (Knowledge-aware Singular-value Adaptation, + [arXiv:2412.06071](https://huggingface.co/papers/2412.06071)). + + KaSA is a LoRA variant that (1) refines the frozen base weight by discarding its `r` smallest ("noisy"/long-tail) + singular components via a one-time SVD truncation, and (2) parametrizes the trainable update in SVD form, inserting + a learnable diagonal of singular values `lora_diag` (`ΔΣ`) between the LoRA `A` and `B` factors, i.e. `ΔW = scaling + * B @ diag(ΔΣ) @ A`. 
The paper additionally trains with two auxiliary regularizers (a singular-value L2 penalty and + an orthogonal regularization on the adapter factors); these are exposed through + [`~peft.get_kasa_regularization_loss`] and must be added to the task loss by the user, since the PEFT variant API + has no channel to inject an extra loss term into the training loop. + + Args: + beta (`float`): + Coefficient `β` for the L2 (singular-value) regularization `||ΔΣ||_F^2 = sum(lora_diag ** 2)` (paper Eq. + 9-10). Only takes effect if the user adds [`~peft.get_kasa_regularization_loss`] to their loss. Defaults to + `1e-4` (the value used in the reference GLUE configs). + gamma (`float`): + Coefficient `γ` for the orthogonal regularization `||B^T B - I||_F + ||A A^T - I||_F` (paper Eq. 11), which + softly enforces the semi-orthogonality of `ΔU`/`ΔV` assumed by the SVD parametrization. Only takes effect + if the user adds [`~peft.get_kasa_regularization_loss`] to their loss. Defaults to `1e-3` (the reference + GLUE value). + """ + + beta: float = field( + default=1e-4, + metadata={ + "help": ( + "Coefficient `β` for the L2 (singular-value) regularization `sum(lora_diag ** 2)` (KaSA paper " + "Eq. 9-10). Only takes effect if the user adds `get_kasa_regularization_loss` to their training loss. " + "Defaults to 1e-4 (the reference GLUE value)." + ) + }, + ) + gamma: float = field( + default=1e-3, + metadata={ + "help": ( + "Coefficient `γ` for the orthogonal regularization `||B^T B - I||_F + ||A A^T - I||_F` (KaSA paper " + "Eq. 11), which softly enforces the semi-orthogonality of the adapter factors assumed by the SVD " + "parametrization. Only takes effect if the user adds `get_kasa_regularization_loss` to their training " + "loss. Defaults to 1e-3 (the reference GLUE value)." 
+ ) + }, + ) + + def __post_init__(self): + if self.beta < 0: + raise ValueError(f"`beta` must be non-negative, got {self.beta}.") + if self.gamma < 0: + raise ValueError(f"`gamma` must be non-negative, got {self.gamma}.") + + @dataclass class CordaConfig: """ @@ -478,6 +533,10 @@ class LoraConfig(PeftConfig): velora_config (`Optional[VeloraConfig]`): Enable VeLoRA by providing a VeloraConfig. VeLoRA swaps in a custom backward pass for the LoRA A projection that stores compressed activations instead of the full input activations. + kasa_config (`Optional[KasaConfig]`): + Enable KaSA (Knowledge-aware Singular-value Adaptation) by providing a KasaConfig. KaSA truncates the `r` + smallest singular components of the frozen base weight via a one-time SVD and inserts a learnable diagonal + of singular values between the LoRA A and B factors. Currently only linear layers are supported. alora_invocation_tokens (`List[int]`): If not None, enable 'Activated LoRA' (aLoRA), with alora_invocation_tokens being the tokenized invocation string for the adapter (must be present in all model @@ -833,6 +892,18 @@ class LoraConfig(PeftConfig): arrow_config: Optional[ArrowConfig] = field( default=None, metadata={"help": "The necessary config to apply arrow routing on the model."} ) + kasa_config: Optional[KasaConfig] = field( + default=None, + metadata={ + "help": ( + "Enable KaSA (Knowledge-aware Singular-value Adaptation) as a LoRA variant by providing a " + "`KasaConfig`. KaSA truncates the `r` smallest singular components of the frozen base weight via a " + "one-time SVD and parametrizes the trainable update with a learnable diagonal of singular values " + "(`lora_diag`) inserted between the LoRA A and B factors. Currently only linear layers are " + "supported. See `KasaConfig` for details on the individual hyperparameters." 
+ ) + }, + ) ensure_weight_tying: bool = field( default=False, metadata={ @@ -874,6 +945,11 @@ def __post_init__(self): elif self.velora_config is not None and not isinstance(self.velora_config, VeloraConfig): raise TypeError("`velora_config` must be a `VeloraConfig`, a dict, or None.") + if isinstance(self.kasa_config, dict): + self.kasa_config = KasaConfig(**self.kasa_config) + elif self.kasa_config is not None and not isinstance(self.kasa_config, KasaConfig): + raise TypeError("`kasa_config` must be a `KasaConfig`, a dict, or None.") + if isinstance(self.target_parameters, str): raise TypeError("`target_parameters` must be a list of strings or None.") diff --git a/src/peft/tuners/lora/layer.py b/src/peft/tuners/lora/layer.py index 1a07158308..cb87c5e205 100644 --- a/src/peft/tuners/lora/layer.py +++ b/src/peft/tuners/lora/layer.py @@ -796,6 +796,11 @@ def __init__( self.is_target_conv_1d_layer = is_target_conv_1d_layer def resolve_lora_variant(self, config: LoraConfig, **kwargs) -> Optional[LoraVariant]: + if config.kasa_config is not None: + from .variants import KasaLinearVariant + + return KasaLinearVariant() + if config.velora_config is not None: from .variants import VeloraLinearVariant @@ -1064,6 +1069,8 @@ def __init__( def resolve_lora_variant(self, *, config: LoraConfig, **kwargs) -> Optional[LoraVariant]: if config.velora_config is not None: raise ValueError("VeLoRA does not support adapting embedding layers.") + if config.kasa_config is not None: + raise ValueError("KaSA does not support adapting embedding layers.") if not config.use_dora: return None @@ -1379,6 +1386,8 @@ def __init__( raise ValueError("aLoRA does not support adapting conv layers.") if config.velora_config is not None: raise ValueError("VeLoRA does not support adapting conv layers.") + if config.kasa_config is not None: + raise ValueError("KaSA does not support adapting conv layers.") if base_layer.groups > 1: warnings.warn("LoRA adapter added to ConvNd layer with groups > 1. 
Merging is not supported.") @@ -1673,6 +1682,8 @@ def __init__(self, *args, **kwargs): def resolve_lora_variant(self, *, config: LoraConfig, **kwargs) -> Optional[LoraVariant]: if config.velora_config is not None: raise ValueError("VeLoRA does not support adapting conv layers.") + if config.kasa_config is not None: + raise ValueError("KaSA does not support adapting conv layers.") if not config.use_dora: return None @@ -1692,6 +1703,8 @@ def __init__(self, *args, **kwargs): def resolve_lora_variant(self, *, config: LoraConfig, **kwargs) -> Optional[LoraVariant]: if config.velora_config is not None: raise ValueError("VeLoRA does not support adapting conv layers.") + if config.kasa_config is not None: + raise ValueError("KaSA does not support adapting conv layers.") if not config.use_dora: return None @@ -1711,6 +1724,8 @@ def __init__(self, *args, **kwargs): def resolve_lora_variant(self, *, config: LoraConfig, **kwargs) -> Optional[LoraVariant]: if config.velora_config is not None: raise ValueError("VeLoRA does not support adapting conv layers.") + if config.kasa_config is not None: + raise ValueError("KaSA does not support adapting conv layers.") if not config.use_dora: return None @@ -1756,6 +1771,8 @@ def __init__( raise ValueError(f"{self.__class__.__name__} does not support DoRA (yet), please set use_dora to False") if config.velora_config is not None: raise ValueError(f"{self.__class__.__name__} does not support VeLoRA, please set `velora_config=None`.") + if config.kasa_config is not None: + raise ValueError(f"{self.__class__.__name__} does not support KaSA, please set `kasa_config=None`.") if kwargs.get("use_alora", False): raise ValueError(f"{self.__class__.__name__} does not support aLoRA (yet), please set use_alora to False") super().__init__() @@ -2172,6 +2189,8 @@ def __init__( raise ValueError(f"lora.{self.__class__.__name__} does not work with use_dora=True.") if config.velora_config is not None: raise ValueError(f"lora.{self.__class__.__name__} does 
not work when `velora_config` is set.") + if config.kasa_config is not None: + raise ValueError(f"lora.{self.__class__.__name__} does not work when `kasa_config` is set.") if is_target_conv_1d_layer: raise ValueError(f"lora.{self.__class__.__name__} does not work with is_target_conv_1d_layer=True.") diff --git a/src/peft/tuners/lora/variants.py b/src/peft/tuners/lora/variants.py index 43e465e642..c7e05dcb16 100644 --- a/src/peft/tuners/lora/variants.py +++ b/src/peft/tuners/lora/variants.py @@ -24,6 +24,7 @@ from peft.tuners._buffer_dict import BufferDict from peft.tuners.lora.config import BdLoraConfig, MontecloraConfig +from peft.utils.integrations import gather_params_ctx from peft.utils.other import transpose from .arrow import ArrowLoraLinearLayer @@ -1216,3 +1217,278 @@ def forward( out_A = F.linear(x_dropped, current_weight_A) result = result + lora_B(out_A) * scaling return result + + +class KasaLinearVariant(LoraVariant): + """ + KaSA (Knowledge-aware Singular-value Adaptation) variant for linear layers. + + Reference: "KaSA: Knowledge-Aware Singular-Value Adaptation of Large Language Models" + ([arXiv:2412.06071](https://huggingface.co/papers/2412.06071)), reference implementation: + https://github.com/juyongjiang/KaSA. + + KaSA changes vanilla LoRA in two ways: + + 1. **Knowledge-based SVD truncation of the frozen base weight (one-time, non-trainable).** At init the base weight + ``W`` is SVD-factored ``W = U S V^T`` and its ``r`` smallest ("noisy"/long-tail) singular components are + discarded, leaving the rank-``(k - r)`` approximation (``k = min(in_features, out_features)``) as the new frozen + base. The trainable LoRA branch then re-learns in the discarded residual subspace. + + 2. **Knowledge-aware singular-value adaptation (trainable update).** The update is parametrized in SVD form with a + learnable diagonal of singular values ``lora_diag`` (``ΔΣ``) inserted between the LoRA ``A`` and ``B`` factors: + ``ΔW = scaling * B @ diag(ΔΣ) @ A``. 
``lora_diag`` is a learnable ``r``-vector (the only new parameter per + layer); ``B`` is zero-initialized as in vanilla LoRA, so the update is zero at init. + + Important notes on faithfulness and behavior: + + - The base-weight truncation is **destructive**: adding a KaSA adapter permanently changes the layer's clean + forward output (the base is now the truncated ``W_world``, not the original ``W``). Disabling/unloading the + adapter does **not** restore the original weight, and ``merge`` followed by ``unmerge`` round-trips to the + truncated weight, not the original one. This is inherent to the method. + - The KaSA paper trains with two auxiliary regularizers (an L2 penalty on ``lora_diag`` and an orthogonal + regularization on ``A``/``B``). These cannot be injected through the variant ``forward`` (which has no + loss-return channel), so they are exposed via [`~peft.get_kasa_regularization_loss`], which the user must add to + their task loss. Without them, the SVD interpretation of the update is only approximate. + """ + + @staticmethod + def _truncate_base_weight(module: Linear, r: int) -> None: + """Replace the frozen base weight in-place with its rank-``(k - r)`` SVD approximation (drop the ``r`` smallest + singular components). ``k = min(in_features, out_features)``.""" + base_layer = module.get_base_layer() + weight = base_layer.weight + orig_dtype = weight.dtype + # ``nn.Linear.weight`` has shape (out_features, in_features). For Conv1D (transposed storage) the variant is not + # dispatched, so we always deal with a standard linear weight here. + out_features, in_features = weight.shape + k = min(in_features, out_features) + if r >= k: + raise ValueError( + f"KaSA requires `r` ({r}) to be smaller than min(in_features, out_features) ({k}) so that at least one " + "singular component of the base weight is preserved after truncation." 
+ ) + svd_rank = k - r + # SVD must run on a dequantized fp32 weight for numerical stability and to support (b)float16 CPU weights. + weight_fp32 = weight.detach().to(torch.float32) + U, S, Vh = torch.linalg.svd(weight_fp32, full_matrices=False) + # Keep the principal (largest-sigma) ``svd_rank`` components; discard the ``r`` smallest. + U_p = U[:, :svd_rank] + S_p = S[:svd_rank] + Vh_p = Vh[:svd_rank, :] + truncated = (U_p * S_p) @ Vh_p + base_layer.weight.data = truncated.to(orig_dtype) + + @staticmethod + def init(module: Linear, adapter_name: str, config: LoraConfig, **kwargs: Any) -> None: + if getattr(module, "fan_in_fan_out", False): + # The reference implementation is for nn.Linear (out, in) weights only. fan_in_fan_out (e.g. Conv1D) would + # require transposing the weight before SVD; this is not supported to avoid a silently wrong truncation. + raise ValueError( + "KaSA does not support `fan_in_fan_out=True` layers (e.g. Conv1D). Please target nn.Linear layers." + ) + + if not hasattr(module, "lora_diag"): + # First KaSA layer being added: register `lora_diag` as a learnable adapter parameter so it is saved/loaded. + module.lora_diag = nn.ParameterDict({}) + module.adapter_layer_names = module.adapter_layer_names[:] + ("lora_diag",) + if not hasattr(module, "_lora_kasa_config"): + module._lora_kasa_config = {} + module._lora_kasa_config[adapter_name] = config.kasa_config + if not hasattr(module, "_lora_kasa_truncation_deferred"): + # Set of adapters whose destructive base-weight truncation could not run at init (meta device) and must + # therefore be applied lazily at the first forward, once a real base weight is available. + module._lora_kasa_truncation_deferred = set() + + r = module.r[adapter_name] + lora_A = module.lora_A[adapter_name].weight + device = lora_A.device + dtype = lora_A.dtype + + if device.type == "meta": + # With low_cpu_mem_usage=True adapters may be initialized on the meta device. 
We cannot SVD a meta tensor, + # so we only create the (meta) lora_diag parameter and defer the destructive truncation to the first + # forward (see KasaLinearVariant.forward), which mirrors the deferral pattern used by Monteclora. Without + # this re-trigger the SVD truncation would be silently skipped on the low_cpu_mem_usage path and the model + # would compute the wrong thing (full base + an adapter trained against the truncated base). + module.lora_diag[adapter_name] = nn.Parameter(torch.randn(r, device=device, dtype=dtype)) + module._lora_kasa_truncation_deferred.add(adapter_name) + return + + # The learnable diagonal of singular values (ΔΣ). randn is fine because lora_B is zero-init, so the update is + # zero at step 0 regardless of the diag values (see reference reset_lora_parameters). + module.lora_diag[adapter_name] = nn.Parameter(torch.randn(r, device=device, dtype=dtype)) + + # One-time destructive SVD truncation of the frozen base weight. gather_params_ctx is used so the weight is + # materialized when using DeepSpeed ZeRO-3, consistent with the other SVD-based inits (pissa/corda). + with gather_params_ctx(module.get_base_layer().weight): + KasaLinearVariant._truncate_base_weight(module, r) + + @staticmethod + def _get_delta_weight(module: Linear, adapter: str) -> torch.Tensor: + """Compute ΔW = scaling * B @ diag(lora_diag) @ A (transposed for fan_in_fan_out).""" + device = module.lora_B[adapter].weight.device + dtype = module.lora_B[adapter].weight.dtype + cast_to_fp32 = device.type == "cpu" and (dtype == torch.float16 or dtype == torch.bfloat16) + + weight_A = module.lora_A[adapter].weight + weight_B = module.lora_B[adapter].weight + diag = module.lora_diag[adapter] + + if cast_to_fp32: + weight_A = weight_A.float() + weight_B = weight_B.float() + diag = diag.float() + + # (out, r) * (r,) -> scale columns of B by the diagonal, then (out, r) @ (r, in) -> (out, in). 
+ delta = (weight_B * diag) @ weight_A + output_tensor = transpose(delta, module.fan_in_fan_out) * module.scaling[adapter] + + if cast_to_fp32: + output_tensor = output_tensor.to(dtype=dtype) + + return output_tensor + + @staticmethod + def merge_safe(module: Linear, active_adapter: str, orig_weight: torch.Tensor) -> torch.Tensor: + orig_dtype = orig_weight.dtype + delta_weight = KasaLinearVariant._get_delta_weight(module, active_adapter) + return orig_weight + delta_weight.to(orig_dtype) + + @staticmethod + def merge_unsafe(module: Linear, active_adapter: str, orig_weight: torch.Tensor) -> None: + orig_dtype = orig_weight.dtype + delta_weight = KasaLinearVariant._get_delta_weight(module, active_adapter) + orig_weight.data += delta_weight.to(orig_dtype) + + @staticmethod + def unmerge(module: Linear, active_adapter: str, orig_weight: torch.Tensor) -> torch.Tensor: + orig_dtype = orig_weight.dtype + delta_weight = KasaLinearVariant._get_delta_weight(module, active_adapter) + return orig_weight - delta_weight.to(orig_dtype) + + @staticmethod + def forward( + module: Linear, + active_adapter: str, + x: torch.Tensor, + result: torch.Tensor, + **kwargs, + ) -> torch.Tensor: + deferred = getattr(module, "_lora_kasa_truncation_deferred", None) + if deferred is not None and active_adapter in deferred: + # The destructive SVD truncation was deferred at init because the base weight was on the meta device + # (low_cpu_mem_usage=True). Now that a real base weight is materialized, apply it exactly once. We also + # re-create lora_diag on the real device/dtype (it was a meta tensor at init). + base_weight = module.get_base_layer().weight + if base_weight.device.type == "meta": + raise RuntimeError( + "KaSA could not apply its SVD base-weight truncation because the base weight is still on the " + "meta device at forward time. Materialize the base model weights before running a forward pass." 
+ ) + r = module.r[active_adapter] + old_diag = module.lora_diag[active_adapter] + if old_diag.device.type == "meta": + module.lora_diag[active_adapter] = nn.Parameter( + torch.randn(r, device=base_weight.device, dtype=base_weight.dtype) + ) + with gather_params_ctx(base_weight): + KasaLinearVariant._truncate_base_weight(module, r) + deferred.discard(active_adapter) + # `result` was computed by the caller using the *un-truncated* base weight (the base forward runs before + # this variant forward). Now that the base has been truncated, recompute the base contribution so this + # first forward already reflects the truncated weight, matching every subsequent call. + result = module.get_base_layer()(x) + + lora_A = module.lora_A[active_adapter] + lora_B = module.lora_B[active_adapter] + dropout = module.lora_dropout[active_adapter] + scaling = module.scaling[active_adapter] + diag = module.lora_diag[active_adapter] + + # h = h_base + scaling * B( diag(ΔΣ) * A(dropout(x)) ). The diagonal multiply is an elementwise scaling of the + # r-dim intermediate, equivalent to (and cheaper than) constructing torch.diag(diag). + after_A = lora_A(dropout(x)) + result = result + lora_B(after_A * diag) * scaling + return result + + +def _kasa_layer_regularization_loss(module: Linear, adapter_name: str, beta: float, gamma: float) -> torch.Tensor: + """Per-layer KaSA regularization for a single adapter. + + Returns ``beta * L2 + gamma * L3`` where (paper Eq. 9-11): + + - ``L2 = ||lora_diag||_F^2 = sum(lora_diag ** 2)`` (singular-value penalty), and + - ``L3 = ||B^T B - I||_F + ||A A^T - I||_F`` (orthogonal regularization of the adapter factors). + """ + diag = module.lora_diag[adapter_name] + weight_A = module.lora_A[adapter_name].weight # (r, in) + weight_B = module.lora_B[adapter_name].weight # (out, r) + r = diag.shape[0] + + # Compute in fp32 for numerical stability, regardless of the adapter dtype. 
+ diag_f = diag.float() + A = weight_A.float() + B = weight_B.float() + eye = torch.eye(r, device=diag.device, dtype=torch.float32) + + l2 = (diag_f**2).sum() + gram_B = B.transpose(-1, -2) @ B # (r, r) == ΔU^T ΔU + gram_A = A @ A.transpose(-1, -2) # (r, r) == ΔV^T ΔV + l3 = torch.linalg.norm(gram_B - eye) + torch.linalg.norm(gram_A - eye) + + return beta * l2 + gamma * l3 + + +def get_kasa_regularization_loss( + model, beta: Optional[float] = None, gamma: Optional[float] = None, adapter_name: Optional[str] = None +) -> torch.Tensor: + """Compute the KaSA auxiliary regularization loss summed over all KaSA-adapted linear layers of ``model``. + + The KaSA paper ([arXiv:2412.06071](https://huggingface.co/papers/2412.06071)) optimizes the task loss together with + two auxiliary terms (Eq. 9-12): + + - an L2 penalty on the learnable singular values ``sum(lora_diag ** 2)`` (weighted by ``β``), and + - an orthogonal regularization ``||B^T B - I||_F + ||A A^T - I||_F`` on the adapter factors (weighted by ``γ``), + which softly enforces the semi-orthogonality assumed by the SVD parametrization. + + Because the PEFT LoRA-variant API has no channel to inject an additional loss term into the training loop, this + helper must be called explicitly and its result added to the task loss, e.g.:: + + loss = task_loss + get_kasa_regularization_loss(model) + + Args: + model: A PEFT model whose LoRA layers were configured with a `KasaConfig`. + beta (`Optional[float]`): + Override for the L2 coefficient ``β``. If `None`, the per-layer value from each layer's `KasaConfig` is + used. + gamma (`Optional[float]`): + Override for the orthogonal-regularization coefficient ``γ``. If `None`, the per-layer value from each + layer's `KasaConfig` is used. + adapter_name (`Optional[str]`): + If given, only accumulate the loss for this adapter. If `None`, all KaSA adapters found on each layer are + included. 
+ + Returns: + `torch.Tensor`: A scalar tensor (0-dim) with the accumulated regularization loss. Returns ``0.0`` if the model + contains no KaSA layers. + """ + total = None + for module in model.modules(): + if not isinstance(module, Linear): + continue + if not hasattr(module, "lora_diag"): + continue + kasa_configs = getattr(module, "_lora_kasa_config", {}) + adapters = [adapter_name] if adapter_name is not None else list(module.lora_diag.keys()) + for name in adapters: + if name not in module.lora_diag: + continue + cfg = kasa_configs.get(name) + b = beta if beta is not None else (cfg.beta if cfg is not None else 0.0) + g = gamma if gamma is not None else (cfg.gamma if cfg is not None else 0.0) + layer_loss = _kasa_layer_regularization_loss(module, name, b, g) + total = layer_loss if total is None else total + layer_loss + + if total is None: + return torch.zeros(()) + return total diff --git a/tests/test_kasa.py b/tests/test_kasa.py new file mode 100644 index 0000000000..a1c8075695 --- /dev/null +++ b/tests/test_kasa.py @@ -0,0 +1,392 @@ +# Copyright 2026-present the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for the KaSA (Knowledge-aware Singular-value Adaptation) LoRA variant. + +KaSA changes vanilla LoRA in two ways (arXiv:2412.06071): + 1. a one-time, destructive SVD truncation of the frozen base weight (drop the r smallest singular components), and + 2. 
a learnable diagonal of singular values (lora_diag) inserted between the LoRA A and B factors. + +These tests run on tiny random nn.Linear models on CPU; no downloads. +""" + +import copy + +import pytest +import torch +from torch import nn + +from peft import KasaConfig, LoraConfig, PeftType, get_kasa_regularization_loss, get_peft_model +from peft.tuners.lora import KasaConfig as LoraKasaConfig +from peft.tuners.lora.layer import Linear as LoraLinear +from peft.tuners.lora.variants import KasaLinearVariant +from peft.utils import get_peft_model_state_dict + + +class MLP(nn.Module): + def __init__(self, in_features=16, hidden=12, out_features=10, bias=False): + super().__init__() + # in_features >= hidden >= out so min(in,out) - r stays positive at small r for both layers. + self.lin0 = nn.Linear(in_features, hidden, bias=bias) + self.lin1 = nn.Linear(hidden, out_features, bias=bias) + + def forward(self, x): + return self.lin1(torch.relu(self.lin0(x))) + + +def _make_kasa_config(target_modules=("lin0", "lin1"), r=4, lora_alpha=8, beta=1e-4, gamma=1e-3, **kwargs): + return LoraConfig( + target_modules=list(target_modules), + r=r, + lora_alpha=lora_alpha, + kasa_config=KasaConfig(beta=beta, gamma=gamma), + **kwargs, + ) + + +# ---------------------------------------------------------------------------------------------------------------------- +# Config wiring +# ---------------------------------------------------------------------------------------------------------------------- + + +def test_kasa_config_object_triggers_variant(): + torch.manual_seed(0) + model = get_peft_model(copy.deepcopy(MLP()), _make_kasa_config()) + found = 0 + for module in model.modules(): + if isinstance(module, LoraLinear): + assert isinstance(module.lora_variant["default"], KasaLinearVariant) + found += 1 + assert found == 2 + + +def test_kasa_config_dict_round_trip(): + # A dict passed as kasa_config should be coerced to a KasaConfig in __post_init__. 
+ cfg = LoraConfig(target_modules=["lin0"], kasa_config={"beta": 0.5, "gamma": 0.25}) + assert isinstance(cfg.kasa_config, KasaConfig) + assert cfg.kasa_config.beta == 0.5 + assert cfg.kasa_config.gamma == 0.25 + assert cfg.peft_type == PeftType.LORA + + +def test_kasa_config_alias_matches_lora_module_config(): + # The class re-exported from peft.tuners.lora must be the same as the top-level one. + assert LoraKasaConfig is KasaConfig + + +def test_kasa_config_invalid_type_raises(): + with pytest.raises(TypeError, match="`kasa_config` must be a `KasaConfig`"): + LoraConfig(target_modules=["lin0"], kasa_config=123) + + +def test_kasa_config_negative_coeffs_raise(): + with pytest.raises(ValueError, match="`beta` must be non-negative"): + KasaConfig(beta=-1.0) + with pytest.raises(ValueError, match="`gamma` must be non-negative"): + KasaConfig(gamma=-1.0) + + +def test_kasa_rejects_too_large_rank(): + # r must be < min(in, out) for at least one base singular component to survive truncation. + # lin1 is (out=10, in=12) so min=10; r=10 must raise. + with pytest.raises(ValueError, match="KaSA requires `r`"): + get_peft_model(copy.deepcopy(MLP()), _make_kasa_config(r=10)) + + +# ---------------------------------------------------------------------------------------------------------------------- +# Faithfulness: SVD truncation of the base weight + the new lora_diag parameter +# ---------------------------------------------------------------------------------------------------------------------- + + +def test_kasa_lora_diag_shape_and_learnable(): + torch.manual_seed(0) + r = 4 + model = get_peft_model(copy.deepcopy(MLP()), _make_kasa_config(r=r)) + for module in model.modules(): + if isinstance(module, LoraLinear): + diag = module.lora_diag["default"] + assert diag.shape == (r,) + assert diag.requires_grad + + +def test_kasa_lora_b_is_zero_init(): + # The update must be zero at init (output == truncated base), which requires B == 0. 
+ torch.manual_seed(0) + model = get_peft_model(copy.deepcopy(MLP()), _make_kasa_config()) + for module in model.modules(): + if isinstance(module, LoraLinear): + assert torch.allclose(module.lora_B["default"].weight, torch.zeros_like(module.lora_B["default"].weight)) + + +def test_kasa_truncates_base_weight_rank(): + """After init, the frozen base weight must have its rank reduced by exactly r (its r smallest singular values are + dropped).""" + torch.manual_seed(0) + r = 3 + base = MLP(in_features=16, hidden=12, out_features=10) + # Snapshot the original singular values of each targeted weight before adapting. + orig_singulars = {} + for name in ["lin0", "lin1"]: + w = getattr(base, name).weight.detach().clone().float() + orig_singulars[name] = torch.linalg.svdvals(w) + + model = get_peft_model(copy.deepcopy(base), _make_kasa_config(r=r)) + + for name in ["lin0", "lin1"]: + lora_layer = getattr(model.base_model.model, name) + new_weight = lora_layer.get_base_layer().weight.detach().float() + sv = torch.linalg.svdvals(new_weight) + k = min(new_weight.shape) + # Exactly r singular values should be (numerically) zero -> rank dropped by r. + n_zero = int((sv < 1e-5).sum()) + assert n_zero == r, f"{name}: expected {r} zeroed singular values, got {n_zero} (sv={sv})" + # The surviving (largest) singular values should match the original principal ones. + kept_new = torch.sort(sv, descending=True).values[: k - r] + kept_orig = torch.sort(orig_singulars[name], descending=True).values[: k - r] + assert torch.allclose(kept_new, kept_orig, atol=1e-4) + + +def test_kasa_truncation_changes_base_forward(): + """Adding a KaSA adapter destructively edits the base weight, so the clean (adapter-disabled) forward differs from + the original model. 
This documents the (intentional) departure from the usual "disable == base" contract.""" + torch.manual_seed(0) + base = MLP() + x = torch.randn(5, 16) + with torch.no_grad(): + orig_out = base(x) + + model = get_peft_model(copy.deepcopy(base), _make_kasa_config(r=4)) + model.eval() + with torch.no_grad(): + with model.disable_adapter(): + disabled_out = model(x) + + # Because B == 0 at init, the *active* adapter output equals the truncated base output... + with torch.no_grad(): + active_out = model(x) + assert torch.allclose(active_out, disabled_out, atol=1e-6) + # ...but the truncated base is NOT the original weight, so the output differs from the original model. + assert not torch.allclose(disabled_out, orig_out, atol=1e-4) + + +# ---------------------------------------------------------------------------------------------------------------------- +# Merge / unmerge round-trip and forward consistency +# ---------------------------------------------------------------------------------------------------------------------- + + +def _randomize_adapter(model): + """Give the adapter a non-trivial value so merge/forward differences are observable (B and diag both non-zero).""" + with torch.no_grad(): + for module in model.modules(): + if isinstance(module, LoraLinear): + nn.init.normal_(module.lora_B["default"].weight, std=0.1) + module.lora_diag["default"].copy_(torch.randn_like(module.lora_diag["default"])) + + +def test_kasa_merge_unmerge_round_trip(): + torch.manual_seed(0) + model = get_peft_model(copy.deepcopy(MLP()), _make_kasa_config(r=4)) + _randomize_adapter(model) + model.eval() + + x = torch.randn(7, 16) + with torch.no_grad(): + out_unmerged = model(x) + + # Capture the truncated base weights (merge/unmerge must round-trip to THESE, not the original weights). 
+ truncated = { + name: getattr(model.base_model.model, name).get_base_layer().weight.detach().clone() + for name in ["lin0", "lin1"] + } + + model.merge_adapter() + with torch.no_grad(): + out_merged = model(x) + assert torch.allclose(out_unmerged, out_merged, atol=1e-5) + + model.unmerge_adapter() + for name in ["lin0", "lin1"]: + restored = getattr(model.base_model.model, name).get_base_layer().weight.detach() + assert torch.allclose(restored, truncated[name], atol=1e-5) + + with torch.no_grad(): + out_after = model(x) + assert torch.allclose(out_unmerged, out_after, atol=1e-5) + + +def test_kasa_delta_weight_matches_formula(): + """ΔW = scaling * B @ diag(lora_diag) @ A, and merging adds exactly this to the (truncated) base.""" + torch.manual_seed(0) + model = get_peft_model(copy.deepcopy(MLP()), _make_kasa_config(r=4, lora_alpha=8)) + _randomize_adapter(model) + + for name in ["lin0", "lin1"]: + layer = getattr(model.base_model.model, name) + A = layer.lora_A["default"].weight.detach() + B = layer.lora_B["default"].weight.detach() + diag = layer.lora_diag["default"].detach() + scaling = layer.scaling["default"] + expected = scaling * (B @ torch.diag(diag) @ A) + + before = layer.get_base_layer().weight.detach().clone() + layer.merge(safe_merge=True) + after = layer.get_base_layer().weight.detach() + assert torch.allclose(after - before, expected, atol=1e-5) + + +# ---------------------------------------------------------------------------------------------------------------------- +# Save / load +# ---------------------------------------------------------------------------------------------------------------------- + + +def test_kasa_lora_diag_in_state_dict(tmp_path): + torch.manual_seed(0) + model = get_peft_model(copy.deepcopy(MLP()), _make_kasa_config(r=4)) + _randomize_adapter(model) + + sd = get_peft_model_state_dict(model) + diag_keys = [k for k in sd if "lora_diag" in k] + # One lora_diag entry per targeted linear layer. 
+ assert len(diag_keys) == 2 + + model.eval() + x = torch.randn(3, 16) + with torch.no_grad(): + out_before = model(x) + + save_dir = tmp_path / "kasa_adapter" + model.save_pretrained(save_dir) + + # Reload onto a fresh base model that carries the (already truncated) base weights. KaSA mutates the base weight + # in-place at adapter-creation time, so persisting that truncated base with the adapter is one way to reload; we + # emulate that by copying the truncated weights into a clean MLP. Reloading onto the original base is tested below. + from peft import PeftModel + + reloaded_base = MLP() + with torch.no_grad(): + for name in ["lin0", "lin1"]: + truncated_w = getattr(model.base_model.model, name).get_base_layer().weight.detach().clone() + getattr(reloaded_base, name).weight.copy_(truncated_w) + + reloaded = PeftModel.from_pretrained(reloaded_base, save_dir) + reloaded.eval() + with torch.no_grad(): + out_after = reloaded(x) + assert torch.allclose(out_before, out_after, atol=1e-5) + + +@pytest.mark.parametrize("low_cpu_mem_usage", [False, True]) +def test_kasa_reload_onto_original_base_retruncates(tmp_path, low_cpu_mem_usage): + """Reloading the adapter onto the *original* (un-truncated) base must reproduce the trained output, because the + deterministic SVD truncation is re-applied at load time. With low_cpu_mem_usage=True the truncation is deferred to + the first forward; this guards against it being silently skipped on that path.""" + from peft import PeftModel + + torch.manual_seed(0) + base = MLP() + original_state = copy.deepcopy(base.state_dict()) + model = get_peft_model(copy.deepcopy(base), _make_kasa_config(r=4)) + _randomize_adapter(model) + model.eval() + x = torch.randn(3, 16) + with torch.no_grad(): + out_before = model(x) + + save_dir = tmp_path / "kasa_adapter" + model.save_pretrained(save_dir) + + # Fresh base carrying the ORIGINAL (un-truncated) weights - the realistic reload scenario. 
+ fresh_base = MLP() + fresh_base.load_state_dict(original_state) + reloaded = PeftModel.from_pretrained(fresh_base, save_dir, low_cpu_mem_usage=low_cpu_mem_usage) + reloaded.eval() + with torch.no_grad(): + out_after = reloaded(x) + # A second forward must be stable (truncation applied exactly once, no double-truncation). + out_after2 = reloaded(x) + assert torch.allclose(out_before, out_after, atol=1e-5) + assert torch.allclose(out_after, out_after2, atol=1e-6) + + +# ---------------------------------------------------------------------------------------------------------------------- +# Regularization helper (L2 singular-value penalty + L3 orthogonal regularization) +# ---------------------------------------------------------------------------------------------------------------------- + + +def test_kasa_regularization_zero_when_no_kasa_layers(): + torch.manual_seed(0) + model = get_peft_model(copy.deepcopy(MLP()), LoraConfig(target_modules=["lin0"], r=4)) + loss = get_kasa_regularization_loss(model) + assert float(loss) == 0.0 + + +def test_kasa_regularization_l2_matches_closed_form(): + """With gamma=0 the L3 term drops out (A/B are not orthonormal here), leaving beta * sum(lora_diag**2) per layer.""" + torch.manual_seed(0) + beta = 0.3 + model = get_peft_model(copy.deepcopy(MLP()), _make_kasa_config(r=4, beta=beta, gamma=0.0)) + with torch.no_grad(): + for module in model.modules(): + if isinstance(module, LoraLinear): + module.lora_diag["default"].copy_(torch.arange(1.0, 5.0)) # [1,2,3,4] + + expected_per_layer = beta * (1.0**2 + 2.0**2 + 3.0**2 + 4.0**2) # = beta * 30 + expected = 2 * expected_per_layer # two layers + loss = get_kasa_regularization_loss(model, gamma=0.0) + assert pytest.approx(loss.item(), rel=1e-5) == expected + + +def test_kasa_orthogonal_reg_zero_for_orthonormal_factors(): + """L3 = ||B^T B - I|| + ||A A^T - I|| must be ~0 when A and B have orthonormal rows/cols, and > 0 otherwise.""" + torch.manual_seed(0) + # Use square-ish factors so A (r x in) can 
have orthonormal rows and B (out x r) orthonormal columns. + model = get_peft_model(copy.deepcopy(MLP(in_features=16, hidden=12, out_features=12)), _make_kasa_config(r=4)) + + with torch.no_grad(): + for module in model.modules(): + if isinstance(module, LoraLinear): + A = module.lora_A["default"].weight # (r, in) + B = module.lora_B["default"].weight # (out, r) + # orthonormal rows of A + qa, _ = torch.linalg.qr(A.T) # (in, r) with orthonormal columns + module.lora_A["default"].weight.copy_(qa[:, : A.shape[0]].T) + # orthonormal columns of B + qb, _ = torch.linalg.qr(B) # (out, r) with orthonormal columns + module.lora_B["default"].weight.copy_(qb) + module.lora_diag["default"].zero_() # kill L2 so we isolate L3 + + loss_ortho = get_kasa_regularization_loss(model, beta=0.0, gamma=1.0) + assert loss_ortho.item() < 1e-4 + + # Now make B clearly non-orthonormal and confirm the penalty becomes strictly positive. + with torch.no_grad(): + for module in model.modules(): + if isinstance(module, LoraLinear): + module.lora_B["default"].weight.mul_(3.0) + loss_non_ortho = get_kasa_regularization_loss(model, beta=0.0, gamma=1.0) + assert loss_non_ortho.item() > 1e-3 + + +def test_kasa_regularization_has_gradients(): + """The regularization loss must be differentiable w.r.t. 
the KaSA parameters.""" + torch.manual_seed(0) + model = get_peft_model(copy.deepcopy(MLP()), _make_kasa_config(r=4)) + _randomize_adapter(model) + + loss = get_kasa_regularization_loss(model) + loss.backward() + for module in model.modules(): + if isinstance(module, LoraLinear): + assert module.lora_diag["default"].grad is not None + assert module.lora_A["default"].weight.grad is not None diff --git a/tests/test_lora_variants.py b/tests/test_lora_variants.py index 30f2fe53d5..fe6f94b309 100644 --- a/tests/test_lora_variants.py +++ b/tests/test_lora_variants.py @@ -17,7 +17,7 @@ from torch import nn from transformers import AutoModelForCausalLM -from peft import LoraConfig, TaskType, get_peft_model +from peft import KasaConfig, LoraConfig, TaskType, get_peft_model from peft.tuners.lora.layer import Conv1d as LoraConv1d from peft.tuners.lora.layer import Conv2d as LoraConv2d from peft.tuners.lora.layer import Embedding as LoraEmbedding @@ -28,6 +28,7 @@ DoraConv2dVariant, DoraEmbeddingVariant, DoraLinearVariant, + KasaLinearVariant, calculate_alora_offsets, get_alora_offsets_for_forward, get_alora_offsets_for_generate, @@ -112,6 +113,9 @@ def from_pretrained(cls): "alora": { LoraLinear: ALoraLinearVariant, }, + "kasa": { + LoraLinear: KasaLinearVariant, + }, } @@ -126,6 +130,11 @@ def from_pretrained(cls): LoraConfig, {"target_modules": ["linear1", "linear2"], "alora_invocation_tokens": [1]}, ), + ( + "kasa", + LoraConfig, + {"target_modules": ["linear1", "linear2"], "kasa_config": KasaConfig(), "r": 4}, + ), ] @@ -174,6 +183,19 @@ def test_dora_params_have_gradients(self): for layer in layer_names: assert getattr(peft_model.base_model.model, layer).lora_magnitude_vector["default"].weight.grad is not None + def test_kasa_params_have_gradients(self): + """Ensure that the lora_diag parameter added by the KaSA variant participates in the output computation.""" + layer_names = ["linear1", "linear2"] + peft_config = LoraConfig(target_modules=layer_names, 
kasa_config=KasaConfig(), r=4) + _, peft_model = self.custom_model_with_loss_backpropagated(peft_config) + + for layer in layer_names: + lora_diag = getattr(peft_model.base_model.model, layer).lora_diag["default"] + assert lora_diag.requires_grad + assert lora_diag.grad is not None + # lora_diag is the new KaSA parameter of shape (r,). + assert lora_diag.shape == (4,) + class TestActivatedLora: @pytest.mark.parametrize(