TST Add tests for hotswapping targeted LoRA parameters #3304
The question recently came up whether hotswapping works with target_parameters, so I added a test to check. It turns out that it does work. Its usefulness is somewhat reduced, however, because targeting parameters while using torch.compile (compiled models are a main use case for hotswapping) leads to re-compilation and/or graph breaks. This is a fundamental limitation of how targeting nn.Parameters is implemented: nn.utils.parametrize is used to dynamically update the targeted nn.Parameter. We can't update it statically, since that would break all kinds of things (e.g. accessing the parameter with model.foo.bar would return the parameter *after* applying the LoRA delta weight). Therefore, we must undo the parametrization after the forward step, and this breaks compilation. This PR additionally documents the fundamental problem with torch.compile and target_parameters. It also removes an unused argument in a test and an incorrect comment.
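To make the parametrization point concrete, here is a minimal sketch using only plain PyTorch. It is not PEFT's actual implementation: the `AddDelta` module and its `lora_A`/`lora_B` shapes are hypothetical stand-ins for a LoRA delta. It only illustrates that, while a parametrization is registered, attribute access returns the modified weight, and that removing it with `leave_parametrized=False` restores the original.

```python
import torch
from torch import nn
from torch.nn.utils import parametrize


class AddDelta(nn.Module):
    """Hypothetical stand-in for a LoRA delta: returns W + B @ A."""

    def __init__(self, out_features, in_features, rank=2):
        super().__init__()
        self.lora_A = nn.Parameter(torch.randn(rank, in_features))
        self.lora_B = nn.Parameter(torch.randn(out_features, rank))

    def forward(self, weight):
        return weight + self.lora_B @ self.lora_A


torch.manual_seed(0)
layer = nn.Linear(4, 3)
original = layer.weight.detach().clone()

# Attach the parametrization, roughly what would happen around a forward pass.
parametrize.register_parametrization(layer, "weight", AddDelta(3, 4))
is_parametrized_during = parametrize.is_parametrized(layer, "weight")
# While attached, layer.weight is recomputed on access and includes the delta.
patched = layer.weight.detach().clone()

# Detach it again, keeping the original (unparametrized) weight.
parametrize.remove_parametrizations(layer, "weight", leave_parametrized=False)
restored = layer.weight.detach().clone()
```

Registering a parametrization also swaps the module's class for a dynamically generated subclass, and removing the last one swaps it back. This attach/detach cycle on every forward step is the kind of dynamic change that plausibly does not sit well with a compiled graph, consistent with the re-compilation and graph breaks described above.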
BenjaminBossan
left a comment
I marked the changes that are not strictly related to the PR.
| disabled. The main use case for this is when the LoRA weights were extracted from fully fine-tuned | ||
| parameters so the bias of those parameters can be taken into account. | ||
| target_parameters (`List[str]`, *optional*) | ||
| List of parameter names or regex expression of the parameter names to replace with LoRA. This argument |
Unrelated change: Regex is not supported here.
| default=None, | ||
| metadata={ | ||
| "help": ( | ||
| "List of module names or regex expression of the module names to replace with LoRA. " |
Unrelated change: Regex is not supported here.
| torch.manual_seed(0) | ||
| return ConvModel().to(self.torch_device) | ||
|
|
||
| # this works with all adapters except prompt learning, but we don't test all |
Unrelated change: The comment is just not true.
| @pytest.mark.parametrize( | ||
| "config", | ||
| [ | ||
| LoraConfig(init_lora_weights=0, target_modules=["lin0"]), |
Unrelated change: init_lora_weights=0 and init_lora_weights=False have the same effect, but since we expect a bool here, let's not pass 0.
| # Ensure that using LoRA directly with a v5 model still works | ||
| inputs = torch.arange(10).view(1, -1).to(device=self.torch_device) | ||
| model_id = "hf-internal-testing/Mixtral-tiny" | ||
| lora_id = "peft-internal-testing/mixtral-pre-v5-lora" |
Unrelated change: This argument was unused.
The question recently came up if hotswapping works with `target_parameters`. Therefore, I added a test to check it. It turns out that it works indeed.

The usefulness is, however, somewhat reduced because targeting parameters while using `torch.compile` (compiled models are a main use case for hotswapping) leads to re-compilation and/or graph breaks. This is a fundamental limitation of how targeting `nn.Parameter`s is implemented, using `nn.utils.parametrize` to dynamically update the targeted `nn.Parameter`. We can't update it statically, since that would break all kinds of things (e.g. accessing the parameter with `model.foo.bar` would return the parameter after applying the LoRA delta weight). Therefore, we must undo the parametrization after the forward step, and this breaks compilation.

This PR additionally documents the fundamental problem with `torch.compile` and `target_parameters`. It also changes a wrong argument type in a test, removes an unused argument in another test, as well as an incorrect comment.