Add `nemotron_h` model-type support.

* src/peft/tuners/tuners_utils.py: list "nemotron_h" in `mamba_model_types`
  inside `_check_lora_target_modules_mamba`, next to the existing
  `incompatible_modules = {"out_proj", "conv1d"}` set.
* src/peft/utils/constants.py: register q_proj/k_proj/v_proj/o_proj as the
  default target modules for "nemotron_h" in four target-module mappings.
* tests/test_custom_models.py: add two tests, one for the default target
  modules and one expecting a ValueError when `out_proj` is targeted.

diff --git a/src/peft/tuners/tuners_utils.py b/src/peft/tuners/tuners_utils.py
index 7227e869dd..bb6133c868 100644
--- a/src/peft/tuners/tuners_utils.py
+++ b/src/peft/tuners/tuners_utils.py
@@ -147,7 +147,7 @@ def _check_lora_target_modules_mamba(peft_config: PeftConfig, model: nn.Module,
 
     lora_like_types = {"LORA", "ADALORA", "XLORA", "RANDLORA"}
     incompatible_modules = {"out_proj", "conv1d"}
-    mamba_model_types = {"falcon_h1", "mamba", "mamba2", "falcon_mamba"}
+    mamba_model_types = {"falcon_h1", "mamba", "mamba2", "falcon_mamba", "nemotron_h"}
 
     if (
         peft_config.peft_type in lora_like_types
diff --git a/src/peft/utils/constants.py b/src/peft/utils/constants.py
index 94aa475f0d..f31970a9a4 100644
--- a/src/peft/utils/constants.py
+++ b/src/peft/utils/constants.py
@@ -100,6 +100,7 @@ def starcoder_model_postprocess_past_key_value(past_key_values):
     "gemma4": r".*language_model\..*\.(q_proj|v_proj)",
     "qwen2": ["q_proj", "v_proj"],
     "qwen3": ["q_proj", "v_proj"],
+    "nemotron_h": ["q_proj", "k_proj", "v_proj", "o_proj"],
     "rwkv": ["key", "value", "receptance", "output"],
     "rwkv7": ["r_proj", "k_proj", "v_proj", "o_proj", "key", "value"],
 }
@@ -293,6 +294,7 @@ def starcoder_model_postprocess_past_key_value(past_key_values):
     "gemma4": r".*language_model\..*\.(q_proj|v_proj)",
     "qwen2": ["q_proj", "v_proj"],
     "qwen3": ["q_proj", "v_proj"],
+    "nemotron_h": ["q_proj", "k_proj", "v_proj", "o_proj"],
 }
 
 TRANSFORMERS_MODELS_TO_VBLORA_TARGET_MODULES_MAPPING = {
@@ -318,6 +320,7 @@ def starcoder_model_postprocess_past_key_value(past_key_values):
     "gemma4": r".*language_model\..*\.(q_proj|v_proj)",
     "qwen2": ["q_proj", "v_proj"],
     "qwen3": ["q_proj", "v_proj"],
+    "nemotron_h": ["q_proj", "k_proj", "v_proj", "o_proj"],
 }
 
 TRANSFORMERS_MODELS_TO_OSF_TARGET_MODULES_MAPPING = {
@@ -377,6 +380,7 @@ def starcoder_model_postprocess_past_key_value(past_key_values):
     "gemma4": r".*language_model\..*\.(q_proj|v_proj)",
     "qwen2": ["q_proj", "v_proj"],
     "qwen3": ["q_proj", "v_proj"],
+    "nemotron_h": ["q_proj", "k_proj", "v_proj", "o_proj"],
 }
 
 ##################
diff --git a/tests/test_custom_models.py b/tests/test_custom_models.py
index 7a809a6b4c..e52d4e033f 100644
--- a/tests/test_custom_models.py
+++ b/tests/test_custom_models.py
@@ -7008,3 +7008,49 @@ def test_default_target_modules_osf(self):
         model = get_peft_model(model, config)
         assert model.targeted_module_names == ["lin0", "lin1"]
         assert model.peft_config["default"].target_modules == {"lin0", "lin1"}
+
+    def test_default_target_modules_nemotron_h(self):
+        # Nemotron-H is a hybrid Mamba + MoE + Attention architecture. Defaults
+        # target the attention projections only (q/k/v/o_proj) and must avoid
+        # `out_proj` / `conv1d` which belong to the Mamba mixer and are blocked
+        # by the Mamba compatibility check (PR #2562).
+        from peft.utils.constants import (
+            TRANSFORMERS_MODELS_TO_ADALORA_TARGET_MODULES_MAPPING,
+            TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING,
+            TRANSFORMERS_MODELS_TO_VBLORA_TARGET_MODULES_MAPPING,
+            TRANSFORMERS_MODELS_TO_WAVEFT_TARGET_MODULES_MAPPING,
+        )
+
+        expected = ["q_proj", "k_proj", "v_proj", "o_proj"]
+        for mapping in (
+            TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING,
+            TRANSFORMERS_MODELS_TO_ADALORA_TARGET_MODULES_MAPPING,
+            TRANSFORMERS_MODELS_TO_VBLORA_TARGET_MODULES_MAPPING,
+            TRANSFORMERS_MODELS_TO_WAVEFT_TARGET_MODULES_MAPPING,
+        ):
+            assert mapping.get("nemotron_h") == expected
+
+        forbidden = {"out_proj", "conv1d"}
+        assert set(expected).isdisjoint(forbidden)
+
+    def test_nemotron_h_blocks_mamba_modules(self):
+        # nemotron_h must be in the Mamba forbidden-module check so applying
+        # LoRA to `out_proj` or `conv1d` raises (those belong to the Mamba
+        # mixer, not attention).
+        class FakeNemotronH(nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.out_proj = nn.Linear(8, 8)
+
+            def forward(self, x):
+                return self.out_proj(x)
+
+        model = FakeNemotronH()
+        mock_config = MagicMock()
+        mock_config.to_dict.return_value = {"model_type": "nemotron_h"}
+        mock_config.model_type = "nemotron_h"
+        model.config = mock_config
+
+        peft_config = LoraConfig(target_modules=["out_proj"])
+        with pytest.raises(ValueError, match="incompatible with Mamba-based"):
+            get_peft_model(model, peft_config)