From ed33a9951303ffef3848ea787ad946de76bf83ce Mon Sep 17 00:00:00 2001 From: Tai An Date: Sun, 7 Jun 2026 06:07:08 -0700 Subject: [PATCH] fix(save_and_load): include base_layer.bias for bias=lora_only/boft_only When a tuner targets a layer, the original module is wrapped and its bias lives at .base_layer.bias. The previous key reconstruction produced .bias, so get_peft_model_state_dict and save_pretrained silently dropped the trained bias for bias="lora_only" (and bias="boft_only"), breaking adapter round-trips. Check the base_layer.bias name as well, keeping the legacy .bias name for backward compatibility. Fixes #3306 --- src/peft/utils/save_and_load.py | 17 +++++++++------ tests/test_initialization.py | 38 +++++++++++++++++++++++++++++++++ 2 files changed, 49 insertions(+), 6 deletions(-) diff --git a/src/peft/utils/save_and_load.py b/src/peft/utils/save_and_load.py index a366e7c4d2..1f323622fe 100644 --- a/src/peft/utils/save_and_load.py +++ b/src/peft/utils/save_and_load.py @@ -145,9 +145,13 @@ def get_peft_model_state_dict( for k in state_dict: if "lora_" in k: to_return[k] = state_dict[k] - bias_name = k.split("lora_")[0] + "bias" - if bias_name in state_dict: - to_return[bias_name] = state_dict[bias_name] + prefix = k.split("lora_")[0] + # The trainable bias of a targeted layer lives under the wrapped `base_layer` + # (current tuner-layer structure); older state dicts may store it directly on + # the module, so check both names. + for bias_name in (prefix + "base_layer.bias", prefix + "bias"): + if bias_name in state_dict: + to_return[bias_name] = state_dict[bias_name] else: raise NotImplementedError to_return = {k: v for k, v in to_return.items() if (("lora_" in k and adapter_name in k) or ("bias" in k))} @@ -182,9 +186,10 @@ def renamed_dora_weights(k): for k in state_dict: if "boft_" in k: to_return[k] = state_dict[k] - bias_name = k.split("boft_")[0] + "bias" - if bias_name in state_dict: - to_return[bias_name] = state_dict[bias_name] + prefix = k.split("boft_")[0] + for bias_name in (prefix + "base_layer.bias", prefix + "bias"): + if bias_name in state_dict: + to_return[bias_name] = state_dict[bias_name] else: raise NotImplementedError diff --git a/tests/test_initialization.py b/tests/test_initialization.py index 3be97bfab8..fa0344338e 100644 --- a/tests/test_initialization.py +++ b/tests/test_initialization.py @@ -1371,6 +1371,44 @@ def test_lora_with_bias_argument(self, bias): # only layers targeted with target_modules assert param.requires_grad is ("linear" in name) or ("conv2d" in name) + def test_lora_only_bias_is_saved_and_reloaded(self, tmp_path): + # See https://github.com/huggingface/peft/issues/3306: with bias="lora_only" the trained + # base_layer bias of targeted modules lives at ".base_layer.bias" in the current + # tuner-layer structure. It must be included in the saved adapter, otherwise reloading does + # not reproduce the trained outputs. + from peft.utils import get_peft_model_state_dict + + bias_key = "base_model.model.linear.base_layer.bias" + + torch.manual_seed(0) + model = self.get_model() + config = LoraConfig(target_modules=["linear"], bias="lora_only") + model = get_peft_model(model, config) + + # simulate training by perturbing every trainable parameter, including the base_layer bias + with torch.no_grad(): + for _, param in model.named_parameters(): + if param.requires_grad: + param.add_(torch.randn_like(param) * 0.1) + + state_dict = get_peft_model_state_dict(model) + assert bias_key in state_dict + + data = torch.rand(10, 1000).to(self.torch_device) + with torch.no_grad(): + output_before = model(data)[0] + + model.save_pretrained(tmp_path) + saved = load_file(tmp_path / "adapter_model.safetensors") + assert bias_key in saved + + # reload into a freshly initialized base model with identical base weights (same seed) + torch.manual_seed(0) + reloaded = PeftModel.from_pretrained(self.get_model(), tmp_path) + with torch.no_grad(): + output_after = reloaded(data)[0] + assert torch.allclose(output_before, output_after) + def test_lora_with_bias_extra_params(self): # lora with lora_bias=True model = self.get_model()