Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions src/peft/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@
HRAModel,
IA3Config,
IA3Model,
KasaConfig,
LilyConfig,
LilyModel,
LNTuningConfig,
Expand Down Expand Up @@ -138,6 +139,7 @@
convert_to_lora,
create_arrow_model,
get_eva_state_dict,
get_kasa_regularization_loss,
initialize_lora_eva_weights,
preprocess_loraga,
save_as_lora,
Expand Down Expand Up @@ -208,6 +210,7 @@
"HiraModel",
"IA3Config",
"IA3Model",
"KasaConfig",
"LNTuningConfig",
"LNTuningModel",
"LilyConfig",
Expand Down Expand Up @@ -286,6 +289,7 @@
"create_arrow_model",
"find_kappa_target_modules",
"get_eva_state_dict",
"get_kasa_regularization_loss",
"get_layer_status",
"get_model_status",
"get_peft_config",
Expand Down
4 changes: 4 additions & 0 deletions src/peft/tuners/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
ArrowConfig,
BdLoraConfig,
EvaConfig,
KasaConfig,
LoftQConfig,
LoraConfig,
LoraGAConfig,
Expand All @@ -44,6 +45,7 @@
convert_to_lora,
create_arrow_model,
get_eva_state_dict,
get_kasa_regularization_loss,
initialize_lora_eva_weights,
preprocess_loraga,
save_as_lora,
Expand Down Expand Up @@ -103,6 +105,7 @@
"HiraModel",
"IA3Config",
"IA3Model",
"KasaConfig",
"LNTuningConfig",
"LNTuningModel",
"LilyConfig",
Expand Down Expand Up @@ -165,6 +168,7 @@
"convert_to_lora",
"create_arrow_model",
"get_eva_state_dict",
"get_kasa_regularization_loss",
"initialize_lora_eva_weights",
"preprocess_loraga",
"save_as_lora",
Expand Down
4 changes: 4 additions & 0 deletions src/peft/tuners/lora/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
ArrowConfig,
BdLoraConfig,
EvaConfig,
KasaConfig,
LoftQConfig,
LoraConfig,
LoraGAConfig,
Expand All @@ -33,6 +34,7 @@
from .layer import Conv2d, Conv3d, Embedding, Linear, LoraLayer, ParamWrapper
from .loraga import preprocess_loraga
from .model import LoraModel
from .variants import get_kasa_regularization_loss


__all__ = [
Expand All @@ -43,6 +45,7 @@
"Embedding",
"EvaConfig",
"GPTQLoraLinear",
"KasaConfig",
"Linear",
"LoftQConfig",
"LoraConfig",
Expand All @@ -56,6 +59,7 @@
"convert_to_lora",
"create_arrow_model",
"get_eva_state_dict",
"get_kasa_regularization_loss",
"initialize_lora_eva_weights",
"preprocess_loraga",
"save_as_lora",
Expand Down
76 changes: 76 additions & 0 deletions src/peft/tuners/lora/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -294,6 +294,61 @@ def __post_init__(self):
raise ValueError("`tau` must be between 0.0 and 1.0.")


@dataclass
class KasaConfig:
"""
This is the sub-configuration class to store the configuration for KaSA (Knowledge-aware Singular-value Adaptation,
[arXiv:2412.06071](https://huggingface.co/papers/2412.06071)).

KaSA is a LoRA variant that (1) refines the frozen base weight by discarding its `r` smallest ("noisy"/long-tail)
singular components via a one-time SVD truncation, and (2) parametrizes the trainable update in SVD form, inserting
a learnable diagonal of singular values `lora_diag` (`ΔΣ`) between the LoRA `A` and `B` factors, i.e. `ΔW = scaling
* B @ diag(ΔΣ) @ A`. The paper additionally trains with two auxiliary regularizers (a singular-value L2 penalty and
an orthogonal regularization on the adapter factors); these are exposed through
[`~peft.get_kasa_regularization_loss`] and must be added to the task loss by the user, since the PEFT variant API
has no channel to inject an extra loss term into the training loop.

Args:
beta (`float`):
Coefficient `β` for the L2 (singular-value) regularization `||ΔΣ||_F^2 = sum(lora_diag ** 2)` (paper Eq.
9-10). Only takes effect if the user adds [`~peft.get_kasa_regularization_loss`] to their loss. Defaults to
`1e-4` (the value used in the reference GLUE configs).
gamma (`float`):
Coefficient `γ` for the orthogonal regularization `||B^T B - I||_F + ||A A^T - I||_F` (paper Eq. 11), which
softly enforces the semi-orthogonality of `ΔU`/`ΔV` assumed by the SVD parametrization. Only takes effect
if the user adds [`~peft.get_kasa_regularization_loss`] to their loss. Defaults to `1e-3` (the reference
GLUE value).
"""

beta: float = field(
default=1e-4,
metadata={
"help": (
"Coefficient `β` for the L2 (singular-value) regularization `sum(lora_diag ** 2)` (KaSA paper "
"Eq. 9-10). Only takes effect if the user adds `get_kasa_regularization_loss` to their training loss. "
"Defaults to 1e-4 (the reference GLUE value)."
)
},
)
gamma: float = field(
default=1e-3,
metadata={
"help": (
"Coefficient `γ` for the orthogonal regularization `||B^T B - I||_F + ||A A^T - I||_F` (KaSA paper "
"Eq. 11), which softly enforces the semi-orthogonality of the adapter factors assumed by the SVD "
"parametrization. Only takes effect if the user adds `get_kasa_regularization_loss` to their training "
"loss. Defaults to 1e-3 (the reference GLUE value)."
)
},
)

def __post_init__(self):
if self.beta < 0:
raise ValueError(f"`beta` must be non-negative, got {self.beta}.")
if self.gamma < 0:
raise ValueError(f"`gamma` must be non-negative, got {self.gamma}.")


@dataclass
class CordaConfig:
"""
Expand Down Expand Up @@ -478,6 +533,10 @@ class LoraConfig(PeftConfig):
velora_config (`Optional[VeloraConfig]`):
Enable VeLoRA by providing a VeloraConfig. VeLoRA swaps in a custom backward pass for the LoRA A projection
that stores compressed activations instead of the full input activations.
kasa_config (`Optional[KasaConfig]`):
Enable KaSA (Knowledge-aware Singular-value Adaptation) by providing a KasaConfig. KaSA truncates the `r`
smallest singular components of the frozen base weight via a one-time SVD and inserts a learnable diagonal
of singular values between the LoRA A and B factors. Currently only linear layers are supported.
alora_invocation_tokens (`List[int]`):
If not None, enable <a href='https://huggingface.co/papers/2504.12397'>'Activated LoRA' (aLoRA)</a>, with
alora_invocation_tokens being the tokenized invocation string for the adapter (must be present in all model
Expand Down Expand Up @@ -833,6 +892,18 @@ class LoraConfig(PeftConfig):
arrow_config: Optional[ArrowConfig] = field(
default=None, metadata={"help": "The necessary config to apply arrow routing on the model."}
)
kasa_config: Optional[KasaConfig] = field(
default=None,
metadata={
"help": (
"Enable KaSA (Knowledge-aware Singular-value Adaptation) as a LoRA variant by providing a "
"`KasaConfig`. KaSA truncates the `r` smallest singular components of the frozen base weight via a "
"one-time SVD and parametrizes the trainable update with a learnable diagonal of singular values "
"(`lora_diag`) inserted between the LoRA A and B factors. Currently only linear layers are "
"supported. See `KasaConfig` for details on the individual hyperparameters."
)
},
)
ensure_weight_tying: bool = field(
default=False,
metadata={
Expand Down Expand Up @@ -874,6 +945,11 @@ def __post_init__(self):
elif self.velora_config is not None and not isinstance(self.velora_config, VeloraConfig):
raise TypeError("`velora_config` must be a `VeloraConfig`, a dict, or None.")

if isinstance(self.kasa_config, dict):
self.kasa_config = KasaConfig(**self.kasa_config)
elif self.kasa_config is not None and not isinstance(self.kasa_config, KasaConfig):
raise TypeError("`kasa_config` must be a `KasaConfig`, a dict, or None.")

if isinstance(self.target_parameters, str):
raise TypeError("`target_parameters` must be a list of strings or None.")

Expand Down
19 changes: 19 additions & 0 deletions src/peft/tuners/lora/layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -796,6 +796,11 @@ def __init__(
self.is_target_conv_1d_layer = is_target_conv_1d_layer

def resolve_lora_variant(self, config: LoraConfig, **kwargs) -> Optional[LoraVariant]:
if config.kasa_config is not None:
from .variants import KasaLinearVariant

return KasaLinearVariant()

if config.velora_config is not None:
from .variants import VeloraLinearVariant

Expand Down Expand Up @@ -1064,6 +1069,8 @@ def __init__(
def resolve_lora_variant(self, *, config: LoraConfig, **kwargs) -> Optional[LoraVariant]:
if config.velora_config is not None:
raise ValueError("VeLoRA does not support adapting embedding layers.")
if config.kasa_config is not None:
raise ValueError("KaSA does not support adapting embedding layers.")
if not config.use_dora:
return None

Expand Down Expand Up @@ -1379,6 +1386,8 @@ def __init__(
raise ValueError("aLoRA does not support adapting conv layers.")
if config.velora_config is not None:
raise ValueError("VeLoRA does not support adapting conv layers.")
if config.kasa_config is not None:
raise ValueError("KaSA does not support adapting conv layers.")
if base_layer.groups > 1:
warnings.warn("LoRA adapter added to ConvNd layer with groups > 1. Merging is not supported.")

Expand Down Expand Up @@ -1673,6 +1682,8 @@ def __init__(self, *args, **kwargs):
def resolve_lora_variant(self, *, config: LoraConfig, **kwargs) -> Optional[LoraVariant]:
if config.velora_config is not None:
raise ValueError("VeLoRA does not support adapting conv layers.")
if config.kasa_config is not None:
raise ValueError("KaSA does not support adapting conv layers.")
if not config.use_dora:
return None

Expand All @@ -1692,6 +1703,8 @@ def __init__(self, *args, **kwargs):
def resolve_lora_variant(self, *, config: LoraConfig, **kwargs) -> Optional[LoraVariant]:
if config.velora_config is not None:
raise ValueError("VeLoRA does not support adapting conv layers.")
if config.kasa_config is not None:
raise ValueError("KaSA does not support adapting conv layers.")
if not config.use_dora:
return None

Expand All @@ -1711,6 +1724,8 @@ def __init__(self, *args, **kwargs):
def resolve_lora_variant(self, *, config: LoraConfig, **kwargs) -> Optional[LoraVariant]:
if config.velora_config is not None:
raise ValueError("VeLoRA does not support adapting conv layers.")
if config.kasa_config is not None:
raise ValueError("KaSA does not support adapting conv layers.")
if not config.use_dora:
return None

Expand Down Expand Up @@ -1756,6 +1771,8 @@ def __init__(
raise ValueError(f"{self.__class__.__name__} does not support DoRA (yet), please set use_dora to False")
if config.velora_config is not None:
raise ValueError(f"{self.__class__.__name__} does not support VeLoRA, please set `velora_config=None`.")
if config.kasa_config is not None:
raise ValueError(f"{self.__class__.__name__} does not support KaSA, please set `kasa_config=None`.")
if kwargs.get("use_alora", False):
raise ValueError(f"{self.__class__.__name__} does not support aLoRA (yet), please set use_alora to False")
super().__init__()
Expand Down Expand Up @@ -2172,6 +2189,8 @@ def __init__(
raise ValueError(f"lora.{self.__class__.__name__} does not work with use_dora=True.")
if config.velora_config is not None:
raise ValueError(f"lora.{self.__class__.__name__} does not work when `velora_config` is set.")
if config.kasa_config is not None:
raise ValueError(f"lora.{self.__class__.__name__} does not work when `kasa_config` is set.")
if is_target_conv_1d_layer:
raise ValueError(f"lora.{self.__class__.__name__} does not work with is_target_conv_1d_layer=True.")

Expand Down
Loading