[bugfix] fix grpo target_parameters & chord device#9525
Code Review
This pull request refactors the device alignment logic in patch_lora_merge and patch_lora_unmerge by utilizing PEFT's type-agnostic _move_adapter_to_device_of_base_layer method instead of manual device handling. It also updates compute_chord_loss to move SFT inputs to the CPU before collation. The review feedback suggests adding fallback logic in both merge and unmerge patches to ensure backward compatibility with older PEFT versions (prior to 0.12.0) that do not implement _move_adapter_to_device_of_base_layer.
    for active_adapter in check_adapters_to_merge(self, adapter_names) or []:
        # Align adapter sublayers (lora_A/B, DoRA magnitude, ...) to the base device.
        # Type-agnostic: ParamWrapper overrides this to use get_param().device.
        self._move_adapter_to_device_of_base_layer(active_adapter)
To ensure compatibility with older versions of PEFT (prior to 0.12.0), we should check if _move_adapter_to_device_of_base_layer is available on the LoraLayer instance before calling it. If it is not present, we can fall back to manually aligning the adapter sublayers to the base device.
Suggested change
- for active_adapter in check_adapters_to_merge(self, adapter_names) or []:
-     # Align adapter sublayers (lora_A/B, DoRA magnitude, ...) to the base device.
-     # Type-agnostic: ParamWrapper overrides this to use get_param().device.
-     self._move_adapter_to_device_of_base_layer(active_adapter)
+ for active_adapter in check_adapters_to_merge(self, adapter_names) or []:
+     # Align adapter sublayers (lora_A/B, DoRA magnitude, ...) to the base device.
+     # Type-agnostic: ParamWrapper overrides this to use get_param().device.
+     if hasattr(self, '_move_adapter_to_device_of_base_layer'):
+         self._move_adapter_to_device_of_base_layer(active_adapter)
+     else:
+         base_layer = self.get_base_layer()
+         device = None
+         if hasattr(base_layer, 'get_param'):
+             device = base_layer.get_param().device
+         elif hasattr(base_layer, 'weight'):
+             device = base_layer.weight.device
+         if device is not None:
+             for attr in ['lora_A', 'lora_B', 'lora_embedding_A', 'lora_embedding_B', 'lora_magnitude_vector']:
+                 if hasattr(self, attr):
+                     module_dict = getattr(self, attr)
+                     if active_adapter in module_dict:
+                         module_dict[active_adapter].to(device)
    for adapter in list(self.merged_adapters):
        if self.use_dora.get(adapter, False):
            self.lora_magnitude_vector[adapter].weight.data = \
                self.lora_magnitude_vector[adapter].weight.data.to(self.base_layer.weight.device)
        self._move_adapter_to_device_of_base_layer(adapter)
Similarly, to ensure compatibility with older versions of PEFT, we should check if _move_adapter_to_device_of_base_layer is available before calling it in unmerge_patched.
Suggested change
- for adapter in list(self.merged_adapters):
-     if self.use_dora.get(adapter, False):
-         self.lora_magnitude_vector[adapter].weight.data = \
-             self.lora_magnitude_vector[adapter].weight.data.to(self.base_layer.weight.device)
-     self._move_adapter_to_device_of_base_layer(adapter)
+ for adapter in list(self.merged_adapters):
+     if hasattr(self, '_move_adapter_to_device_of_base_layer'):
+         self._move_adapter_to_device_of_base_layer(adapter)
+     else:
+         base_layer = self.get_base_layer()
+         device = None
+         if hasattr(base_layer, 'get_param'):
+             device = base_layer.get_param().device
+         elif hasattr(base_layer, 'weight'):
+             device = base_layer.weight.device
+         if device is not None:
+             for attr in ['lora_A', 'lora_B', 'lora_embedding_A', 'lora_embedding_B', 'lora_magnitude_vector']:
+                 if hasattr(self, attr):
+                     module_dict = getattr(self, attr)
+                     if adapter in module_dict:
+                         module_dict[adapter].to(device)
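The merge and unmerge suggestions repeat the same fallback, so it could be factored into one helper. The following is a minimal sketch of that idea, not the PR's code: `align_adapter_device`, `FakeModule`, `FakeParam`, `OldLayer` and `NewLayer` are hypothetical names, and the stand-in classes replace torch and PEFT objects so the example runs on its own. The caller is assumed to resolve `base_device` the way the suggestions do. The sketch also reassigns the result of `.to()`, on the assumption that some entries behave like tensors, whose `.to()` returns a new object instead of moving in place.

```python
# Stand-ins (hypothetical) so the sketch runs without torch or PEFT.
class FakeModule:
    """Mimics a module: .to() moves in place and returns self."""
    def __init__(self, device):
        self.device = device

    def to(self, device):
        self.device = device
        return self


class FakeParam:
    """Mimics a tensor-like entry: .to() returns a new object."""
    def __init__(self, device):
        self.device = device

    def to(self, device):
        return FakeParam(device)


ADAPTER_ATTRS = ('lora_A', 'lora_B', 'lora_embedding_A',
                 'lora_embedding_B', 'lora_magnitude_vector')


def align_adapter_device(layer, adapter, base_device):
    """Hypothetical helper shared by the merge and unmerge patches."""
    native = getattr(layer, '_move_adapter_to_device_of_base_layer', None)
    if callable(native):
        native(adapter)  # preferred path when the layer provides the method
        return 'native'
    for attr in ADAPTER_ATTRS:
        container = getattr(layer, attr, None)
        if container is not None and adapter in container:
            # Reassign so entries whose .to() is not in place are moved too.
            container[adapter] = container[adapter].to(base_device)
    return 'fallback'


class OldLayer:
    """Stand-in for a layer that lacks the native method."""
    def __init__(self):
        self.lora_A = {'default': FakeModule('cpu')}
        self.lora_embedding_A = {'default': FakeParam('cpu')}


class NewLayer(OldLayer):
    """Stand-in for a layer that provides the native method."""
    def _move_adapter_to_device_of_base_layer(self, adapter):
        self.moved = adapter


old, new = OldLayer(), NewLayer()
path_old = align_adapter_device(old, 'default', 'cuda:0')
path_new = align_adapter_device(new, 'default', 'cuda:0')
```

With a helper like this, both patched loops would reduce to a single call per adapter, and the fallback would only need to be maintained in one place.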
Fixes #9466 and #9521.