Skip to content

ENH Add KaSA (Knowledge-aware Singular-value Adaptation) as a LoRA variant#3298

Draft
robbiebusinessacc wants to merge 1 commit into
huggingface:mainfrom
robbiebusinessacc:contrib/add-kasa-lora-variant
Draft

ENH Add KaSA (Knowledge-aware Singular-value Adaptation) as a LoRA variant#3298
robbiebusinessacc wants to merge 1 commit into
huggingface:mainfrom
robbiebusinessacc:contrib/add-kasa-lora-variant

Conversation

@robbiebusinessacc

Copy link
Copy Markdown

Closes part of #2516 (Call for contribution: KaSA).

Implements KaSA (Knowledge-aware Singular-value Adaptation, arXiv:2412.06071) using the LoRA-variant framework from #2443, following the SVD-based variants (CorDA/DoRA) and without adding if-branches to core LoRA logic. Reference
implementation: https://github.com/juyongjiang/KaSA.

Method

KaSA changes vanilla LoRA in two ways:

  1. Knowledge-based SVD truncation of the frozen base weight (one-time, destructive). At init the base weight W is SVD-factored and its r smallest ("noisy"/long-tail) singular components are discarded, leaving the rank-(k-r) approximation
    (k = min(in_features, out_features)) as the new frozen base. The trainable branch re-learns in the discarded residual subspace.
  2. Knowledge-aware singular-value adaptation (trainable update). A learnable diagonal of singular values lora_diag (ΔΣ) is inserted between LoRA A and B: ΔW = scaling * B @ diag(ΔΣ) @ A. lora_diag is the only new per-layer parameter
    (an r-vector); B stays zero-init so the update is 0 at step 0.

The paper additionally trains with two auxiliary regularizers — an L2 penalty on the singular values (sum(lora_diag**2)) and an orthogonal regularization ||B^T B - I||_F + ||A A^T - I||_F that softly enforces the semi-orthogonality the SVD
parametrization assumes. The PEFT variant forward has no channel to inject a scalar into the training loss, so these are exposed via a get_kasa_regularization_loss(model) helper that the user adds to their task loss.

Integration

  • New KasaConfig sub-config (beta, gamma, both validated non-negative; reference GLUE defaults 1e-4 / 1e-3) and LoraConfig.kasa_config field. Selection is driven by kasa_config is not None in resolve_lora_variant (config-object
    pattern, mirrors velora/monteclora/arrow), with dict-coercion + TypeError guard in __post_init__.
  • KasaLinearVariant(LoraVariant) implements init (SVD truncation + lora_diag), forward, merge_safe/merge_unsafe/unmerge, get_delta_weight. lora_diag is registered in adapter_layer_names so it is saved/loaded.
  • Explicit guards reject KaSA on Embedding / Conv / MultiheadAttention / ParamWrapper / fan_in_fan_out (Conv1D) layers, consistent with the VeLoRA guard pattern.
  • Generalizes the reference's svd_rank = in_features - r to min(in_features, out_features) - r so wide layers are handled correctly; raises ValueError if r >= min(in, out).
  • Top-level exports of KasaConfig and get_kasa_regularization_loss.

Tests (tests/test_kasa.py, CPU-only, tiny random nn.Linear, no downloads)

Config dispatch/dict-round-trip/alias/error cases; lora_diag shape+learnable; B zero-init; SVD-truncation rank check (exactly r singular values zeroed, principal values preserved); truncation changes base forward; merge/unmerge round-trip to
the truncated base; delta-weight formula; lora_diag in state_dict + save/load forward equivalence; reload onto the original base re-truncates deterministically (parametrized over low_cpu_mem_usage=False/True, idempotent across forwards);
regularization closed-form (L2/L3), orthonormal vs non-orthonormal, gradients. Plus wiring in tests/test_lora_variants.py.

Open questions for maintainers (honest)

  • Destructive base mutation breaks the usual "disable adapter == base" contract. Adding/disabling/unloading does not restore W0, and merge/unmerge round-trips to the truncated weight. This is inherent to KaSA. Do you want the original
    weight stashed to allow a true unload, or is this semantics acceptable as-is (documented)?
  • Regularization location. The LoRA-variant API has no loss-return channel, so L2/L3 are exposed as get_kasa_regularization_loss for the user to add to their loss (the reference computes them in external training scripts). Is a
    free-function helper the API you want, or a method on the PEFT model? Without these terms the SVD interpretation is only approximate, so they are implemented and unit-tested rather than dropped — this was the correctness gap behind
    Add KaSA implementation to layer.py #2543/[WIP] Update LoraConfig for KaSA implementation #2698.
  • Save/load story. Reloading the adapter onto the original base re-applies the deterministic truncation (verified for both the default and low_cpu_mem_usage paths). Reloading onto an already-truncated/merged base would double-truncate.
    Want a docs note / explicit guard?
  • lora_diag uses randn(r) per the reference; the paper says "randomly initialized without bias". Safe because B is zero-init. Confirm if ones is preferred.
  • Scope is nn.Linear only for now (quantized/Conv/MHA explicitly rejected). OK to land Linear-first?

No docs page added yet; happy to add one if you'd like it in this PR.

Implement KaSA (Knowledge-aware Singular-value Adaptation, arXiv:2412.06071)
using the LoRA-variant framework, following the SVD-based variants (CorDA/DoRA).

KaSA changes vanilla LoRA in two ways:
- A one-time, destructive SVD truncation of the frozen base weight that drops
  its r smallest singular components, leaving the rank-(k-r) approximation as
  the new frozen base (k = min(in_features, out_features)).
- A learnable diagonal of singular values (lora_diag) inserted between the LoRA
  A and B factors, so the update is ΔW = scaling * B @ diag(lora_diag) @ A.

- New KasaConfig sub-config (beta, gamma) and LoraConfig.kasa_config field;
  selection is driven by kasa_config being non-None via resolve_lora_variant,
  with explicit guards rejecting KaSA on embedding/conv/MHA/ParamWrapper and
  fan_in_fan_out layers.
- KasaLinearVariant implements init (SVD truncation + lora_diag), forward,
  merge_safe/merge_unsafe/unmerge. lora_diag is registered in
  adapter_layer_names so it is saved/loaded.
- get_kasa_regularization_loss helper exposes the paper's two auxiliary terms
  (L2 singular-value penalty + L3 orthogonal regularization), since the variant
  forward has no channel to inject an extra loss into the training loop.
- Tests in tests/test_kasa.py (SVD-truncation faithfulness, lora_diag shape,
  zero-init update, merge/unmerge round-trip, delta-weight formula, save/load,
  regularization closed-form checks) plus wiring in tests/test_lora_variants.py.

Faithfulness notes:
- The base-weight truncation is destructive; disabling/unloading does not
  restore the original weight and merge/unmerge round-trips to the truncated
  base. This is inherent to the method and documented.
- The paper's L2/L3 regularizers are required for the SVD interpretation to
  hold but cannot be auto-injected; users must add get_kasa_regularization_loss
  to their loss.
@BenjaminBossan

Copy link
Copy Markdown
Member

Thanks for opening this PR to add KaSA to PEFT. As @iambogeumkim has already put a lot of time into implementing it, I would like to first give them the opportunity to pick up the work. So if you're reading this, please let me know if you plan to resume your work.

Let's wait for ~2 weeks for a reply. If there is none, we can continue with this PR. Is that okay for you @robbiebusinessacc?

@robbiebusinessacc

Copy link
Copy Markdown
Author

Thanks @BenjaminBossan, that's completely fair and totally fine with me — let's wait for @iambogeumkim. They started this and should get the first chance to finish it. Happy to continue if they don't return, and either way they're welcome to use anything from this PR.

@iambogeumkim

Copy link
Copy Markdown

Hi @BenjaminBossan, thank you so much for reaching out and waiting for me!

I hadn't been able to work on this for several months due to personal reasons. Looking back at my work log, it seems I had stopped at the testing stage just before merge, where I ran into some errors.

I'd like to pick this back up and aim to get it merged by the end of June. Let's discuss whether it's just a matter of fixing the test errors, or whether there are additional parts of the code that need to be updated.

Also, a big thank you to @robbiebusinessacc for bringing this PR back up — I really appreciate it 🙌

@BenjaminBossan

Copy link
Copy Markdown
Member

@iambogeumkim Yes, it was also my impression that not much is missing. So if you plan to resume your work, I expect that it won't take too long to finish it.

@robbiebusinessacc If @iambogeumkim resumes their work, feel free to provide your feedback there too based on what you learned. If you make substantial contributions and there are no objections from @iambogeumkim, I would add you as a co-author when merging the PR.

@robbiebusinessacc

robbiebusinessacc commented Jun 5, 2026

Copy link
Copy Markdown
Author

Welcome back @iambogeumkim! To your question — beyond the test errors, the main missing piece in my attempt was the two paper regularizers (the L2 on the singular values + the orthogonal term on the A/B factors), since without them the SVD parametrization is only approximate. It's all in this PR, exposed via a get_kasa_regularization_loss helper, alongside a CPU test suite for the truncation / merge-unmerge / save-load / regularizers — so feel free to lift anything useful, and happy to help port it to yours.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants