Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions src/peft/peft_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
from peft.utils.constants import DUMMY_MODEL_CONFIG
from peft.utils.integrations import init_empty_weights
from peft.utils.other import TrainableTokensWrapper, create_attention_mask, set_additional_trainable_modules
from peft.utils.save_and_load import _validate_lora_adapter_state_dict

from . import __version__
from .config import PeftConfig
Expand Down Expand Up @@ -319,6 +320,14 @@ def save_mutated_as_lora(peft_config, path_initial_model_for_weight_conversion,
adapter_name=adapter_name,
save_embedding_layers=save_embedding_layers,
)

# Refuse to write LoRA adapters whose lora_A / lora_B tensors look like ungathered shards (1-D or zero-sized).
# This is the canonical signature of an export that ran without gathering DeepSpeed ZeRO-3 / FSDP shards; the
# artifact otherwise looks valid on disk but breaks downstream loaders such as vLLM hot-swap with a
# confusing IndexError. Surface the failure here, where the caller can act on it. See
# vllm-project/vllm#28640 and huggingface/transformers#45313 for the historical failure mode.
_validate_lora_adapter_state_dict(output_state_dict, adapter_name=adapter_name)

output_dir = os.path.join(save_directory, adapter_name) if adapter_name != "default" else save_directory
os.makedirs(output_dir, exist_ok=True)

Expand Down
52 changes: 52 additions & 0 deletions src/peft/utils/save_and_load.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,58 @@
from .peft_types import PeftType


def _validate_lora_adapter_state_dict(
    state_dict: dict[str, torch.Tensor],
    adapter_name: str = "default",
) -> None:
    """Refuse to save a LoRA adapter state dict whose ``lora_A`` / ``lora_B`` tensors look like ungathered shards.

    A correctly-saved LoRA adapter has ``lora_A`` weights of shape ``(rank, in_dim)`` and ``lora_B`` weights of shape
    ``(out_dim, rank)`` — at least 2-D and non-empty. If any such tensor is 1-D or has a zero-sized dimension, the
    upstream caller almost always has a partitioned-parameter export under DeepSpeed ZeRO-3 / FSDP that did not gather
    model shards before ``save_pretrained``. The artifact looks structurally valid (filenames +
    ``adapter_config.json`` present), but downstream loaders fail with confusing index errors at the first attempted
    use, e.g.::

        IndexError: too many indices for tensor of dimension 1

    in vLLM's ``slice_lora_b`` during hot-swap (vllm-project/vllm#28640).

    Raise here so the failure surfaces at write time, with an actionable hint, instead of corrupting the artifact and
    deferring the crash to whatever later loads it.

    Only keys containing ``.lora_A`` / ``.lora_B`` are inspected, so parameters stored under other names (for example
    DoRA's ``lora_magnitude_vector``) are never flagged. Bias tensors that live under ``lora_A`` / ``lora_B`` (keys
    ending in ``.bias``, as produced by ``LoraConfig(lora_bias=True)``) are legitimately 1-D; they are only rejected
    when they are zero-sized or 0-D.

    Args:
        state_dict: LoRA-only state dict that will be written to disk. Non-tensor values are ignored.
        adapter_name: Adapter name being saved, used in the error message only.

    Raises:
        ValueError: When at least one ``lora_A`` / ``lora_B`` weight is empty or has fewer than 2 dimensions, or when
            a ``lora_A`` / ``lora_B`` bias is empty or 0-D.
    """
    bad: list[tuple[str, tuple[int, ...]]] = []
    for name, tensor in state_dict.items():
        if not isinstance(tensor, torch.Tensor):
            continue
        if ".lora_A" not in name and ".lora_B" not in name:
            continue
        # A bias under lora_A / lora_B is a genuine 1-D parameter, so demanding 2 dimensions from it would reject
        # every adapter trained with lora_bias=True. It must still be non-empty.
        min_ndim = 1 if name.endswith(".bias") else 2
        shape = tuple(tensor.shape)
        if len(shape) < min_ndim or 0 in shape:
            bad.append((name, shape))

    if not bad:
        return

    # Keep the message readable for large models: show at most three offenders and summarize the rest.
    preview = ", ".join(f"{n}:{s}" for n, s in bad[:3])
    suffix = f" (+{len(bad) - 3} more)" if len(bad) > 3 else ""
    raise ValueError(
        f"Adapter {adapter_name!r}: {len(bad)} LoRA tensor(s) have invalid shape: "
        f"{preview}{suffix}. A 1-D or zero-sized lora_A / lora_B tensor indicates that DeepSpeed ZeRO-3 / FSDP shards "
        "were not gathered before save_pretrained() — downstream loaders such as vLLM hot-swap will then fail with "
        "IndexError at first use. Wrap the save in deepspeed.zero.GatheredParameters([...], modifier_rank=None) for "
        "ZeRO-3, or torch.distributed.fsdp.FullyShardedDataParallel.summon_full_params(model) for FSDP, before "
        "calling save_pretrained()."
    )


def has_valid_embedding_base_layer(layer):
    """Return whether ``layer`` has a ``base_layer`` attribute that is a ``torch.nn.Linear`` or ``torch.nn.Embedding``."""
    if not hasattr(layer, "base_layer"):
        return False
    accepted_types = (torch.nn.Linear, torch.nn.Embedding)
    return isinstance(layer.base_layer, accepted_types)
Expand Down
85 changes: 85 additions & 0 deletions tests/test_initialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -2490,6 +2490,91 @@ def test_no_infinite_recursion(self, cls, model, wrap_init):
cls.__init__ = original_init


class TestSaveValidatesLoraShapes:
    """Regression tests for the write-time shape check on LoRA adapter state dicts.

    Failure mode being guarded against (vllm-project/vllm#28640, huggingface/transformers#45313): when PEFT runs
    save_pretrained() under DeepSpeed ZeRO-3 / FSDP without first gathering shards, lora_A / lora_B end up 1-D or
    shape (0,) on disk. The artifact looks valid but breaks downstream loaders such as vLLM hot-swap with a confusing
    IndexError. _validate_lora_adapter_state_dict raises at write time with an actionable hint instead.
    """

    @staticmethod
    def _make_tiny_lora_model():
        """Build a rank-4 LoRA model around a single ``nn.Linear(8, 8)`` named ``linear``."""

        class _Tiny(nn.Module):
            def __init__(self):
                super().__init__()
                self.linear = nn.Linear(8, 8)

        lora_config = LoraConfig(target_modules=["linear"], r=4)
        return get_peft_model(_Tiny(), lora_config)

    def test_validator_accepts_well_formed_state_dict(self):
        from peft.utils.save_and_load import _validate_lora_adapter_state_dict

        well_formed = {}
        well_formed["base_model.model.q_proj.lora_A.weight"] = torch.zeros(8, 64)
        well_formed["base_model.model.q_proj.lora_B.weight"] = torch.zeros(64, 8)
        # A 1-D DoRA magnitude vector is legitimate and must pass untouched.
        well_formed["base_model.model.q_proj.lora_magnitude_vector"] = torch.zeros(64)
        # Keys that are not lora_A / lora_B are outside the validator's scope.
        well_formed["base_model.model.q_proj.bias"] = torch.zeros(64)

        # Passing means returning without an exception.
        _validate_lora_adapter_state_dict(well_formed, adapter_name="default")

    @pytest.mark.parametrize(
        "bad_shape",
        [
            (0,),  # empty tensor, as emitted when a ZeRO-3 rank does not own the slice
            (8,),  # 1-D rank-local shard
        ],
    )
    def test_validator_rejects_unsharded_lora_a(self, bad_shape):
        from peft.utils.save_and_load import _validate_lora_adapter_state_dict

        malformed = {
            "base_model.model.q_proj.lora_A.weight": torch.zeros(bad_shape),
            "base_model.model.q_proj.lora_B.weight": torch.zeros(64, 8),
        }
        with pytest.raises(ValueError, match=r"DeepSpeed ZeRO-3 / FSDP shards"):
            _validate_lora_adapter_state_dict(malformed, adapter_name="default")

    def test_validator_rejects_zero_sized_lora_b(self):
        from peft.utils.save_and_load import _validate_lora_adapter_state_dict

        # lora_A is fine here; only lora_B carries a zero-sized dimension.
        malformed = {
            "base_model.model.q_proj.lora_A.weight": torch.zeros(8, 64),
            "base_model.model.q_proj.lora_B.weight": torch.zeros(0, 8),
        }
        with pytest.raises(ValueError, match=r"lora_B"):
            _validate_lora_adapter_state_dict(malformed, adapter_name="default")

    def test_validator_error_includes_adapter_name(self):
        from peft.utils.save_and_load import _validate_lora_adapter_state_dict

        malformed = {"base_model.model.q_proj.lora_A.weight": torch.zeros(8)}
        with pytest.raises(ValueError, match=r"'my-adapter'"):
            _validate_lora_adapter_state_dict(malformed, adapter_name="my-adapter")

    def test_save_pretrained_raises_for_unsharded_state_dict(self, tmp_path):
        # End-to-end: a deliberately broken state_dict handed to save_pretrained via its ``state_dict`` kwarg must
        # trip the validator from inside save_pretrained itself.
        peft_model = self._make_tiny_lora_model()

        # lora_B is well-formed; the zero-sized 1-D lora_A is the entry expected to be rejected.
        broken_state_dict = {
            "base_model.model.linear.lora_A.default.weight": torch.zeros(0),
            "base_model.model.linear.lora_B.default.weight": torch.zeros(8, 4),
        }
        target_dir = tmp_path / "broken-adapter"
        with pytest.raises(ValueError, match=r"DeepSpeed ZeRO-3 / FSDP shards"):
            peft_model.save_pretrained(target_dir, state_dict=broken_state_dict)

    def test_save_pretrained_succeeds_for_normal_lora(self, tmp_path):
        # Happy path: an ordinary LoRA model must still save without the validator interfering.
        peft_model = self._make_tiny_lora_model()
        target_dir = tmp_path / "ok-adapter"
        peft_model.save_pretrained(target_dir)
        assert (target_dir / "adapter_config.json").exists()


class TestLoadAdapterOfflineMode:
base_model = "peft-internal-testing/tiny-random-OPTForCausalLM"
peft_model_id = "peft-internal-testing/tiny-OPTForCausalLM-lora"
Expand Down