Forward the optional hf_state_dict argument through the CPU-offloading wrapper
around GPTOSSBridge.maybe_modify_converted_hf_weight: the wrapper now accepts
hf_state_dict (default None) and passes it on to the original method.

diff --git a/slime/backends/megatron_utils/update_weight/hf_weight_iterator_bridge.py b/slime/backends/megatron_utils/update_weight/hf_weight_iterator_bridge.py
index edfd690482..8db37eb9b2 100644
--- a/slime/backends/megatron_utils/update_weight/hf_weight_iterator_bridge.py
+++ b/slime/backends/megatron_utils/update_weight/hf_weight_iterator_bridge.py
@@ -26,9 +26,9 @@ def _patch_bridge_expert_cache_to_cpu():
 
     _orig = GPTOSSBridge.maybe_modify_converted_hf_weight
 
-    def _patched(self, task, converted_weights_dict):
+    def _patched(self, task, converted_weights_dict, hf_state_dict=None):
         cpu_dict = {k: v.cpu() for k, v in converted_weights_dict.items()}
-        result = _orig(self, task, cpu_dict)
+        result = _orig(self, task, cpu_dict, hf_state_dict)
         # Move merged result back to GPU for CUDA IPC serialization
         return {k: v.cuda() for k, v in result.items()} if result else result
 