Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/peft/tuners/tuners_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ def _check_lora_target_modules_mamba(peft_config: PeftConfig, model: nn.Module,

lora_like_types = {"LORA", "ADALORA", "XLORA", "RANDLORA"}
incompatible_modules = {"out_proj", "conv1d"}
mamba_model_types = {"falcon_h1", "mamba", "mamba2", "falcon_mamba"}
mamba_model_types = {"falcon_h1", "mamba", "mamba2", "falcon_mamba", "nemotron_h"}

if (
peft_config.peft_type in lora_like_types
Expand Down
4 changes: 4 additions & 0 deletions src/peft/utils/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,7 @@ def starcoder_model_postprocess_past_key_value(past_key_values):
"gemma4": r".*language_model\..*\.(q_proj|v_proj)",
"qwen2": ["q_proj", "v_proj"],
"qwen3": ["q_proj", "v_proj"],
"nemotron_h": ["q_proj", "k_proj", "v_proj", "o_proj"],
"rwkv": ["key", "value", "receptance", "output"],
"rwkv7": ["r_proj", "k_proj", "v_proj", "o_proj", "key", "value"],
}
Expand Down Expand Up @@ -293,6 +294,7 @@ def starcoder_model_postprocess_past_key_value(past_key_values):
"gemma4": r".*language_model\..*\.(q_proj|v_proj)",
"qwen2": ["q_proj", "v_proj"],
"qwen3": ["q_proj", "v_proj"],
"nemotron_h": ["q_proj", "k_proj", "v_proj", "o_proj"],
}

TRANSFORMERS_MODELS_TO_VBLORA_TARGET_MODULES_MAPPING = {
Expand All @@ -318,6 +320,7 @@ def starcoder_model_postprocess_past_key_value(past_key_values):
"gemma4": r".*language_model\..*\.(q_proj|v_proj)",
"qwen2": ["q_proj", "v_proj"],
"qwen3": ["q_proj", "v_proj"],
"nemotron_h": ["q_proj", "k_proj", "v_proj", "o_proj"],
}

TRANSFORMERS_MODELS_TO_OSF_TARGET_MODULES_MAPPING = {
Expand Down Expand Up @@ -377,6 +380,7 @@ def starcoder_model_postprocess_past_key_value(past_key_values):
"gemma4": r".*language_model\..*\.(q_proj|v_proj)",
"qwen2": ["q_proj", "v_proj"],
"qwen3": ["q_proj", "v_proj"],
"nemotron_h": ["q_proj", "k_proj", "v_proj", "o_proj"],
}

##################
Expand Down
46 changes: 46 additions & 0 deletions tests/test_custom_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -7008,3 +7008,49 @@ def test_default_target_modules_osf(self):
model = get_peft_model(model, config)
assert model.targeted_module_names == ["lin0", "lin1"]
assert model.peft_config["default"].target_modules == {"lin0", "lin1"}

def test_default_target_modules_nemotron_h(self):
    """Verify the default target modules registered for ``nemotron_h``.

    Nemotron-H is a hybrid Mamba + MoE + Attention architecture. Its
    defaults should cover only the attention projections (q/k/v/o_proj)
    and must steer clear of ``out_proj`` / ``conv1d``, which belong to the
    Mamba mixer and are rejected by the Mamba compatibility check
    (PR #2562).
    """
    from peft.utils.constants import (
        TRANSFORMERS_MODELS_TO_ADALORA_TARGET_MODULES_MAPPING,
        TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING,
        TRANSFORMERS_MODELS_TO_VBLORA_TARGET_MODULES_MAPPING,
        TRANSFORMERS_MODELS_TO_WAVEFT_TARGET_MODULES_MAPPING,
    )

    attention_projections = ["q_proj", "k_proj", "v_proj", "o_proj"]
    default_mappings = (
        TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING,
        TRANSFORMERS_MODELS_TO_ADALORA_TARGET_MODULES_MAPPING,
        TRANSFORMERS_MODELS_TO_VBLORA_TARGET_MODULES_MAPPING,
        TRANSFORMERS_MODELS_TO_WAVEFT_TARGET_MODULES_MAPPING,
    )

    # Every mapping must carry the same attention-only default list.
    for default_mapping in default_mappings:
        registered = default_mapping.get("nemotron_h")
        assert registered == attention_projections

    # None of the defaults may collide with the Mamba-mixer module names.
    mamba_mixer_modules = {"out_proj", "conv1d"}
    assert not (set(attention_projections) & mamba_mixer_modules)

def test_nemotron_h_blocks_mamba_modules(self):
    """Verify that LoRA on a Mamba-mixer module of ``nemotron_h`` is rejected.

    ``nemotron_h`` must take part in the Mamba forbidden-module check, so
    targeting ``out_proj`` (a Mamba mixer module, not an attention
    projection) is expected to raise a ``ValueError``.
    """

    class FakeNemotronH(nn.Module):
        # Minimal stand-in exposing a single ``out_proj`` linear layer.
        def __init__(self):
            super().__init__()
            self.out_proj = nn.Linear(8, 8)

        def forward(self, x):
            return self.out_proj(x)

    # Fake model config that reports the nemotron_h model type both as an
    # attribute and through ``to_dict()``.
    fake_hf_config = MagicMock(model_type="nemotron_h")
    fake_hf_config.to_dict.return_value = {"model_type": "nemotron_h"}

    fake_model = FakeNemotronH()
    fake_model.config = fake_hf_config

    lora_on_out_proj = LoraConfig(target_modules=["out_proj"])
    with pytest.raises(ValueError, match="incompatible with Mamba-based"):
        get_peft_model(fake_model, lora_on_out_proj)