Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 6 additions & 8 deletions src/peft/tuners/lora/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,14 +216,12 @@ class BdLoraConfig:
"Example: ['out_proj', 'down_proj']"
},
)
nblocks: int = (
field(
default=1,
metadata={
"help": "Number of blocks each block-diagonal matrix has. If using BD-LoRA to speed up inference, "
"set it to be equal to the desired sharding degree during serving."
},
),
nblocks: int = field(
default=1,
metadata={
"help": "Number of blocks each block-diagonal matrix has. If using BD-LoRA to speed up inference, "
"set it to be equal to the desired sharding degree during serving."
},
)
Comment on lines +219 to 225

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ignore this

match_strict: bool = field(
default=True,
Expand Down
19 changes: 18 additions & 1 deletion tests/test_initialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@
)
from peft.mapping import PEFT_TYPE_TO_PREFIX_MAPPING
from peft.tuners.lokr.layer import LoKrLayer
from peft.tuners.lora.config import CordaConfig
from peft.tuners.lora.config import BdLoraConfig, CordaConfig
from peft.tuners.lora.corda import preprocess_corda
from peft.tuners.lora.layer import LoraLayer
from peft.utils import infer_device
Expand Down Expand Up @@ -1233,6 +1233,23 @@ def test_bdlora_feature_size_non_divisible_by_blocksize_raises(self):
with pytest.raises(ValueError, match="not divisible by"):
get_peft_model(model, config)

def test_bdlora_default_nblocks_is_int(self):

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We don't need dedicated tests for this. Our normal tests should already surface the error, but they don't, because they all override the default with nblocks=2, so the default value is never exercised:

(
"BD-LoRA A only",
"MLP",
LoraConfig,
{
"target_modules": ["lin0", "lin1"],
"use_bdlora": BdLoraConfig(target_modules_bd_a=["lin0"], nblocks=2, match_strict=False),
},
),
(
"BD-LoRA B only",
"MLP",
LoraConfig,
{
"target_modules": ["lin0", "lin1"],
"use_bdlora": BdLoraConfig(target_modules_bd_b=["lin1"], nblocks=2, match_strict=False),
},
),
(
"BD-LoRA both A and B",
"MLP",
LoraConfig,
{
"target_modules": ["lin0", "lin1"],
"use_bdlora": BdLoraConfig(target_modules_bd_a=["lin0"], target_modules_bd_b=["lin1"], nblocks=2),
},
),

So to trigger the error, we can just remove nblocks=2 from one of these settings.

# The default value of BdLoraConfig.nblocks must be the plain integer 1. A misplaced trailing comma previously
# wrapped the field(...) call in a 1-tuple, so the default became a tuple instead of an int, which broke
# get_peft_model whenever the user did not pass nblocks explicitly.
assert BdLoraConfig().nblocks == 1

def test_bdlora_get_peft_model_with_default_nblocks(self):
    """Check that a BD-LoRA model can be built when ``nblocks`` is not given.

    The BD-LoRA settings are passed as a plain dict without an ``nblocks`` entry, so the
    documented default of 1 has to be picked up. This used to raise a TypeError because
    ``nblocks`` defaulted to a tuple instead of an int.
    """
    base_model = self.get_model()

    # Deliberately no "nblocks" key here: the default must be used.
    lora_config = LoraConfig(
        target_modules=["linear"],
        use_bdlora={"target_modules_bd_a": ["linear"], "target_modules_bd_b": []},
    )
    peft_model = get_peft_model(base_model, lora_config)

    adapter_a = peft_model.linear.lora_A["default"]
    assert adapter_a.nblocks == 1

@pytest.fixture
def mha_cls(self):
class ModelMha(nn.Module):
Expand Down
Loading