diff --git a/src/peft/tuners/tuners_utils.py b/src/peft/tuners/tuners_utils.py index 85345b6781..6af166e52c 100644 --- a/src/peft/tuners/tuners_utils.py +++ b/src/peft/tuners/tuners_utils.py @@ -1083,6 +1083,15 @@ def strip_base_layer_from_name(module_name): def create_and_replace_param(module_name, key, param_name): # helper function to avoid duplication + if module_name == "": + # nn.Parameters that are registered directly on the top-level module (i.e. the module passed to + # get_peft_model) cannot be targeted. Wrapping the parameter would require replacing the module that + # holds it with lora.ParamWrapper, but that module is its own parent, so the wrapper ends up registered + # as a submodule of the very module it wraps. This creates a cyclic module graph, resulting in an error. + raise ValueError( + f"Targeting an nn.Parameter on the top-level module is not supported (parameter '{param_name}'). " + ) + parent, target, target_name = _get_submodules(model, module_name) unwrapped_module_name = strip_base_layer_from_name(module_name) unwrapped_module = model.get_submodule(unwrapped_module_name) diff --git a/tests/test_target_parameters.py b/tests/test_target_parameters.py index 6a31e8aa78..af97074f98 100644 --- a/tests/test_target_parameters.py +++ b/tests/test_target_parameters.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +import re + import pytest import torch from torch import nn @@ -598,3 +600,20 @@ def test_target_parameter_init_does_not_warn_about_unknown_layer_type(self, recw warn_messages = (w.message.args[0] for w in recwarn.list) msg_start = "Unsupported layer type" assert not any(msg.startswith(msg_start) for msg in warn_messages) + + def test_target_parameter_on_top_level_module_raises(self): + # nn.Parameters that are registered directly on the top-level module (i.e. the module passed to get_peft_model) + # cannot be targeted. Wrapping the parameter would require replacing the module that holds it with + # lora.ParamWrapper, but that module is its own parent, so the wrapper ends up registered as a submodule of the + # very module it wraps. This creates a cyclic module graph, resulting in an error. + + class MyModule(nn.Module): + # module with a 2d and a 3d nn.Parameter registered directly on the top-level module + def __init__(self): + super().__init__() + self.param = nn.Parameter(torch.zeros(10, 10)) + + config = LoraConfig(target_parameters=["param"]) + msg = re.escape("Targeting an nn.Parameter on the top-level module is not supported (parameter 'param')") + with pytest.raises(ValueError, match=msg): + get_peft_model(MyModule(), config)