ENH Add KaSA (Knowledge-aware Singular-value Adaptation) as a LoRA variant#3298
ENH Add KaSA (Knowledge-aware Singular-value Adaptation) as a LoRA variant#3298robbiebusinessacc wants to merge 1 commit into
Conversation
Implement KaSA (Knowledge-aware Singular-value Adaptation, arXiv:2412.06071) using the LoRA-variant framework, following the SVD-based variants (CorDA/DoRA). KaSA changes vanilla LoRA in two ways: - A one-time, destructive SVD truncation of the frozen base weight that drops its r smallest singular components, leaving the rank-(k-r) approximation as the new frozen base (k = min(in_features, out_features)). - A learnable diagonal of singular values (lora_diag) inserted between the LoRA A and B factors, so the update is ΔW = scaling * B @ diag(lora_diag) @ A. - New KasaConfig sub-config (beta, gamma) and LoraConfig.kasa_config field; selection is driven by kasa_config being non-None via resolve_lora_variant, with explicit guards rejecting KaSA on embedding/conv/MHA/ParamWrapper and fan_in_fan_out layers. - KasaLinearVariant implements init (SVD truncation + lora_diag), forward, merge_safe/merge_unsafe/unmerge. lora_diag is registered in adapter_layer_names so it is saved/loaded. - get_kasa_regularization_loss helper exposes the paper's two auxiliary terms (L2 singular-value penalty + L3 orthogonal regularization), since the variant forward has no channel to inject an extra loss into the training loop. - Tests in tests/test_kasa.py (SVD-truncation faithfulness, lora_diag shape, zero-init update, merge/unmerge round-trip, delta-weight formula, save/load, regularization closed-form checks) plus wiring in tests/test_lora_variants.py. Faithfulness notes: - The base-weight truncation is destructive; disabling/unloading does not restore the original weight and merge/unmerge round-trips to the truncated base. This is inherent to the method and documented. - The paper's L2/L3 regularizers are required for the SVD interpretation to hold but cannot be auto-injected; users must add get_kasa_regularization_loss to their loss.
|
Thanks for opening this PR to add KaSA to PEFT. As @iambogeumkim has already put a lot of time into implementing it, I would like to first give them the opportunity to pick up the work. So if you're reading this, please let me know if you plan to resume your work. Let's wait for ~2 weeks for a reply. If there is none, we can continue with this PR. Is that okay for you @robbiebusinessacc? |
|
Thanks @BenjaminBossan, that's completely fair and totally fine with me — let's wait for @iambogeumkim. They started this and should get the first chance to finish it. Happy to continue if they don't return, and either way they're welcome to use anything from this PR. |
|
Hi @BenjaminBossan, thank you so much for reaching out and waiting for me! I hadn't been able to work on this for several months due to personal reasons. Looking back at my work log, it seems I had stopped at the testing stage just before merge, where I ran into some errors. I'd like to pick this back up and aim to get it merged by the end of June. Let's discuss whether it's just a matter of fixing the test errors, or whether there are additional parts of the code that need to be updated. Also, a big thank you to @robbiebusinessacc for bringing this PR back up — I really appreciate it 🙌 |
|
@iambogeumkim Yes, it was also my impression that not much is missing. So if you plan to resume your work, I expect that it won't take too long to finish it. @robbiebusinessacc If @iambogeumkim resumes their work, feel free to provide your feedback there too based on what you learned. If you make substantial contributions and there are no objections from @iambogeumkim, I would add you as a co-author when merging the PR. |
|
Welcome back @iambogeumkim! To your question — beyond the test errors, the main missing piece in my attempt was the two paper regularizers (the L2 on the singular values + the orthogonal term on the A/B factors), since without them the SVD parametrization is only approximate. It's all in this PR, exposed via a get_kasa_regularization_loss helper, alongside a CPU test suite for the truncation / merge-unmerge / save-load / regularizers — so feel free to lift anything useful, and happy to help port it to yours. |
Closes part of #2516 (Call for contribution: KaSA).
Implements KaSA (Knowledge-aware Singular-value Adaptation, arXiv:2412.06071) using the LoRA-variant framework from #2443, following the SVD-based variants (CorDA/DoRA) and without adding if-branches to core LoRA logic. Reference
implementation: https://github.com/juyongjiang/KaSA.
Method
KaSA changes vanilla LoRA in two ways:
(k = min(in_features, out_features)) as the new frozen base. The trainable branch re-learns in the discarded residual subspace.
lora_diag(ΔΣ) is inserted between LoRA A and B: ΔW = scaling * B @ diag(ΔΣ) @ A.lora_diagis the only new per-layer parameter(an r-vector); B stays zero-init so the update is 0 at step 0.
The paper additionally trains with two auxiliary regularizers — an L2 penalty on the singular values (sum(lora_diag**2)) and an orthogonal regularization ||B^T B - I||_F + ||A A^T - I||_F that softly enforces the semi-orthogonality the SVD
parametrization assumes. The PEFT variant
forwardhas no channel to inject a scalar into the training loss, so these are exposed via aget_kasa_regularization_loss(model)helper that the user adds to their task loss.Integration
KasaConfigsub-config (beta,gamma, both validated non-negative; reference GLUE defaults 1e-4 / 1e-3) andLoraConfig.kasa_configfield. Selection is driven bykasa_config is not Noneinresolve_lora_variant(config-objectpattern, mirrors velora/monteclora/arrow), with dict-coercion + TypeError guard in
__post_init__.KasaLinearVariant(LoraVariant)implements init (SVD truncation + lora_diag), forward, merge_safe/merge_unsafe/unmerge, get_delta_weight.lora_diagis registered inadapter_layer_namesso it is saved/loaded.svd_rank = in_features - rtomin(in_features, out_features) - rso wide layers are handled correctly; raises ValueError if r >= min(in, out).KasaConfigandget_kasa_regularization_loss.Tests (tests/test_kasa.py, CPU-only, tiny random nn.Linear, no downloads)
Config dispatch/dict-round-trip/alias/error cases; lora_diag shape+learnable; B zero-init; SVD-truncation rank check (exactly r singular values zeroed, principal values preserved); truncation changes base forward; merge/unmerge round-trip to
the truncated base; delta-weight formula; lora_diag in state_dict + save/load forward equivalence; reload onto the original base re-truncates deterministically (parametrized over low_cpu_mem_usage=False/True, idempotent across forwards);
regularization closed-form (L2/L3), orthonormal vs non-orthonormal, gradients. Plus wiring in tests/test_lora_variants.py.
Open questions for maintainers (honest)
weight stashed to allow a true unload, or is this semantics acceptable as-is (documented)?
get_kasa_regularization_lossfor the user to add to their loss (the reference computes them in external training scripts). Is afree-function helper the API you want, or a method on the PEFT model? Without these terms the SVD interpretation is only approximate, so they are implemented and unit-tested rather than dropped — this was the correctness gap behind
Add KaSA implementation to layer.py #2543/[WIP] Update
LoraConfigfor KaSA implementation #2698.Want a docs note / explicit guard?
No docs page added yet; happy to add one if you'd like it in this PR.