From 7117b7a930344bb69b98a98939ef5d7f12c42d5f Mon Sep 17 00:00:00 2001 From: EazyReal <8047065+EazyReal@users.noreply.github.com> Date: Fri, 26 Jun 2026 10:27:07 +0000 Subject: [PATCH] fix: handle empty colocated weight buckets --- .github/workflows/pr-test.yml | 2 +- .github/workflows/pr-test.yml.j2 | 1 + .../update_weight_from_tensor.py | 29 ++- tests/test_empty_colocated_weight_bucket.py | 204 ++++++++++++++++++ 4 files changed, 230 insertions(+), 6 deletions(-) create mode 100644 tests/test_empty_colocated_weight_bucket.py diff --git a/.github/workflows/pr-test.yml b/.github/workflows/pr-test.yml index 2d6e8ce5b2..3e332d7351 100644 --- a/.github/workflows/pr-test.yml +++ b/.github/workflows/pr-test.yml @@ -372,7 +372,7 @@ jobs: strategy: fail-fast: false matrix: - info: [{"num_gpus": 0, "test_file": "test_megatron_argument_validation.py"}, {"num_gpus": 0, "test_file": "utils/test_megatron_server_arguments.py"}, {"num_gpus": 0, "test_file": "test_dp_schedule.py"}, {"num_gpus": 0, "test_file": "test_cp_utils.py"}, {"num_gpus": 0, "test_file": "test_metric_report.py"}, {"num_gpus": 0, "test_file": "test_metric_report_dist.py"}, {"num_gpus": 0, "test_file": "test_loss_cp_invariance.py"}, {"num_gpus": 0, "test_file": "test_logprob_response_spans.py"}, {"num_gpus": 0, "test_file": "test_value_temperature.py"}, {"num_gpus": 0, "test_file": "test_cispo_loss.py"}, {"num_gpus": 0, "test_file": "test_rm_f1.py"}, {"num_gpus": 0, "test_file": "test_rm_gpqa.py"}, {"num_gpus": 0, "test_file": "test_rm_math.py"}, {"num_gpus": 0, "test_file": "test_rm_math_dapo.py"}, {"num_gpus": 0, "test_file": "test_rm_deepscaler.py"}, {"num_gpus": 0, "test_file": "test_sample.py"}, {"num_gpus": 0, "test_file": "test_rollout_validation.py"}, {"num_gpus": 0, "test_file": "test_placement_group.py"}, {"num_gpus": 0, "test_file": "test_external_sglang_engines.py"}, {"num_gpus": 0, "test_file": "utils/test_hf_checkpoint_saver.py"}, {"num_gpus": 0, "test_file": "plugin_contracts/test_plugin_rollout_contracts.py"}, {"num_gpus": 0, "test_file": "plugin_contracts/test_plugin_runtime_hook_contracts.py"}, {"num_gpus": 0, "test_file": "plugin_contracts/test_plugin_path_loading_contracts.py"}, {"num_gpus": 0, "test_file": "plugin_contracts/test_plugin_generate_contracts.py"}] + info: [{"num_gpus": 0, "test_file": "test_megatron_argument_validation.py"}, {"num_gpus": 0, "test_file": "utils/test_megatron_server_arguments.py"}, {"num_gpus": 0, "test_file": "test_dp_schedule.py"}, {"num_gpus": 0, "test_file": "test_cp_utils.py"}, {"num_gpus": 0, "test_file": "test_metric_report.py"}, {"num_gpus": 0, "test_file": "test_metric_report_dist.py"}, {"num_gpus": 0, "test_file": "test_loss_cp_invariance.py"}, {"num_gpus": 0, "test_file": "test_logprob_response_spans.py"}, {"num_gpus": 0, "test_file": "test_value_temperature.py"}, {"num_gpus": 0, "test_file": "test_cispo_loss.py"}, {"num_gpus": 0, "test_file": "test_rm_f1.py"}, {"num_gpus": 0, "test_file": "test_rm_gpqa.py"}, {"num_gpus": 0, "test_file": "test_rm_math.py"}, {"num_gpus": 0, "test_file": "test_rm_math_dapo.py"}, {"num_gpus": 0, "test_file": "test_rm_deepscaler.py"}, {"num_gpus": 0, "test_file": "test_sample.py"}, {"num_gpus": 0, "test_file": "test_rollout_validation.py"}, {"num_gpus": 0, "test_file": "test_placement_group.py"}, {"num_gpus": 0, "test_file": "test_external_sglang_engines.py"}, {"num_gpus": 0, "test_file": "test_empty_colocated_weight_bucket.py"}, {"num_gpus": 0, "test_file": "utils/test_hf_checkpoint_saver.py"}, {"num_gpus": 0, "test_file": "plugin_contracts/test_plugin_rollout_contracts.py"}, {"num_gpus": 0, "test_file": "plugin_contracts/test_plugin_runtime_hook_contracts.py"}, {"num_gpus": 0, "test_file": "plugin_contracts/test_plugin_path_loading_contracts.py"}, {"num_gpus": 0, "test_file": "plugin_contracts/test_plugin_generate_contracts.py"}] defaults: run: working-directory: ${{ github.workspace }} diff --git a/.github/workflows/pr-test.yml.j2 b/.github/workflows/pr-test.yml.j2 index 30cf386421..107c7536cf 100644 --- a/.github/workflows/pr-test.yml.j2 +++ b/.github/workflows/pr-test.yml.j2 @@ -81,6 +81,7 @@ {'test_file': 'test_rollout_validation.py', 'num_gpus': 0}, {'test_file': 'test_placement_group.py', 'num_gpus': 0}, {'test_file': 'test_external_sglang_engines.py', 'num_gpus': 0}, + {'test_file': 'test_empty_colocated_weight_bucket.py', 'num_gpus': 0}, {'test_file': 'utils/test_hf_checkpoint_saver.py', 'num_gpus': 0}, {'test_file': 'plugin_contracts/test_plugin_rollout_contracts.py', 'num_gpus': 0}, {'test_file': 'plugin_contracts/test_plugin_runtime_hook_contracts.py', 'num_gpus': 0}, diff --git a/slime/backends/megatron_utils/update_weight/update_weight_from_tensor.py b/slime/backends/megatron_utils/update_weight/update_weight_from_tensor.py index 724e05355b..f8efd69c0a 100644 --- a/slime/backends/megatron_utils/update_weight/update_weight_from_tensor.py +++ b/slime/backends/megatron_utils/update_weight/update_weight_from_tensor.py @@ -231,7 +231,7 @@ def _send_to_colocated_engine( long_live_tensors = [] if getattr(FlattenedTensorBucket, "supports_multi_dtypes", False): - converted_named_tensors_by_dtypes = {"dtype": hf_named_tensors} + converted_named_tensors_by_dtypes = {"dtype": hf_named_tensors} if hf_named_tensors else {} else: converted_named_tensors_by_dtypes = {} for name, tensor in hf_named_tensors: @@ -263,14 +263,33 @@ def _send_to_colocated_engine( refs = [] if dist.get_rank() == ipc_gather_src: - # TODO: here we assume all ranks have the same number of dtypes, not sure if that is correct. - num_dtypes = len(serialized_named_tensors[0]) - for i in range(num_dtypes): + num_buckets = max(len(tensors) for tensors in serialized_named_tensors) + empty_serialized_tensor = None + for i in range(num_buckets): + serialized_tensors_for_dtype = [] + for tensors in serialized_named_tensors: + if i < len(tensors): + serialized_tensors_for_dtype.append(tensors[i]) + continue + + if empty_serialized_tensor is None: + empty_tensor_data = _empty_flattened_tensor_data() + long_live_tensors.append(empty_tensor_data) + empty_serialized_tensor = MultiprocessingSerializer.serialize(empty_tensor_data, output_str=True) + serialized_tensors_for_dtype.append(empty_serialized_tensor) + kwargs = { - "serialized_named_tensors": [tensors[i] for tensors in serialized_named_tensors], + "serialized_named_tensors": serialized_tensors_for_dtype, "load_format": "flattened_bucket", "weight_version": str(weight_version), } refs.append(ipc_engine.update_weights_from_tensor.remote(**kwargs)) return refs, long_live_tensors + + +def _empty_flattened_tensor_data(): + return { + "flattened_tensor": torch.empty(0, dtype=torch.uint8, device=torch.cuda.current_device()), + "metadata": [], + } diff --git a/tests/test_empty_colocated_weight_bucket.py b/tests/test_empty_colocated_weight_bucket.py new file mode 100644 index 0000000000..39afe74f23 --- /dev/null +++ b/tests/test_empty_colocated_weight_bucket.py @@ -0,0 +1,204 @@ +import importlib.util +import sys +import types +from pathlib import Path + +import pytest + +NUM_GPUS = 0 + +REPO_ROOT = Path(__file__).resolve().parents[1] +if str(REPO_ROOT) not in sys.path: + sys.path.insert(0, str(REPO_ROOT)) + + +class _FakeFlattenedTensorBucket: + supports_multi_dtypes = True + + def __init__(self, *, named_tensors=None, flattened_tensor=None, metadata=None): + if named_tensors is not None: + if not named_tensors: + raise ValueError("Cannot create empty tensor bucket") + self._flattened_tensor = ("flattened", tuple(name for name, _ in named_tensors)) + self._metadata = tuple(name for name, _ in named_tensors) + return + + self._flattened_tensor = flattened_tensor + self._metadata = metadata + + def get_flattened_tensor(self): + return self._flattened_tensor + + def get_metadata(self): + return self._metadata + + +class _FakeMultiprocessingSerializer: + @staticmethod + def serialize(value, output_str): + assert output_str is True + return value + + +class _FakeRemoteMethod: + def __init__(self): + self.calls = [] + + def remote(self, **kwargs): + self.calls.append(kwargs) + return f"ref-{len(self.calls)}" + + +class _FakeEngine: + def __init__(self): + self.update_weights_from_tensor = _FakeRemoteMethod() + + +def _install_fake_deps(monkeypatch): + dist_state = types.SimpleNamespace(rank=0, world_size=2, gathered=None, local_object=None) + + slime_pkg = types.ModuleType("slime") + slime_pkg.__path__ = [str(REPO_ROOT / "slime")] + slime_backends_pkg = types.ModuleType("slime.backends") + slime_backends_pkg.__path__ = [str(REPO_ROOT / "slime" / "backends")] + megatron_utils_pkg = types.ModuleType("slime.backends.megatron_utils") + megatron_utils_pkg.__path__ = [str(REPO_ROOT / "slime" / "backends" / "megatron_utils")] + update_weight_pkg = types.ModuleType("slime.backends.megatron_utils.update_weight") + update_weight_pkg.__path__ = [str(REPO_ROOT / "slime" / "backends" / "megatron_utils" / "update_weight")] + slime_utils_pkg = types.ModuleType("slime.utils") + slime_utils_pkg.__path__ = [str(REPO_ROOT / "slime" / "utils")] + + dist_mod = types.ModuleType("torch.distributed") + + def gather_object(obj, object_gather_list, dst, group): + dist_state.local_object = obj + if object_gather_list is not None: + object_gather_list[:] = dist_state.gathered(obj) + + dist_mod.get_rank = lambda: dist_state.rank + dist_mod.get_world_size = lambda group=None: dist_state.world_size + dist_mod.gather_object = gather_object + + torch_mod = types.ModuleType("torch") + torch_mod.Tensor = object + torch_mod.uint8 = "uint8" + torch_mod.distributed = dist_mod + torch_mod.empty = lambda size, dtype, device: {"size": size, "dtype": dtype, "device": device} + torch_mod.no_grad = lambda: (lambda fn: fn) + torch_mod.cuda = types.SimpleNamespace(current_device=lambda: "cuda:0", ipc_collect=lambda: None) + torch_mod.nn = types.SimpleNamespace(Module=object) + + ray_mod = types.ModuleType("ray") + ray_mod.ObjectRef = object + ray_actor_mod = types.ModuleType("ray.actor") + ray_actor_mod.ActorHandle = object + + mpu_mod = types.ModuleType("megatron.core.mpu") + megatron_mod = types.ModuleType("megatron") + megatron_core_mod = types.ModuleType("megatron.core") + megatron_core_mod.mpu = mpu_mod + + sglang_mod = types.ModuleType("slime.backends.megatron_utils.sglang") + sglang_mod.FlattenedTensorBucket = _FakeFlattenedTensorBucket + sglang_mod.MultiprocessingSerializer = _FakeMultiprocessingSerializer + + distributed_utils_mod = types.ModuleType("slime.utils.distributed_utils") + distributed_utils_mod.get_gloo_group = lambda: object() + + update_from_distributed_mod = types.ModuleType( + "slime.backends.megatron_utils.update_weight.update_weight_from_distributed" + ) + update_from_distributed_mod.connect_rollout_engines_from_distributed = lambda *args, **kwargs: None + update_from_distributed_mod.disconnect_rollout_engines_from_distributed = lambda *args, **kwargs: None + update_from_distributed_mod.post_process_weights = lambda *args, **kwargs: None + update_from_distributed_mod.update_weights_from_distributed = lambda *args, **kwargs: [] + + monkeypatch.setitem(sys.modules, "slime", slime_pkg) + monkeypatch.setitem(sys.modules, "slime.backends", slime_backends_pkg) + monkeypatch.setitem(sys.modules, "slime.backends.megatron_utils", megatron_utils_pkg) + monkeypatch.setitem(sys.modules, "slime.backends.megatron_utils.update_weight", update_weight_pkg) + monkeypatch.setitem(sys.modules, "slime.utils", slime_utils_pkg) + monkeypatch.setitem(sys.modules, "torch", torch_mod) + monkeypatch.setitem(sys.modules, "torch.distributed", dist_mod) + monkeypatch.setitem(sys.modules, "ray", ray_mod) + monkeypatch.setitem(sys.modules, "ray.actor", ray_actor_mod) + monkeypatch.setitem(sys.modules, "megatron", megatron_mod) + monkeypatch.setitem(sys.modules, "megatron.core", megatron_core_mod) + monkeypatch.setitem(sys.modules, "megatron.core.mpu", mpu_mod) + monkeypatch.setitem(sys.modules, "slime.backends.megatron_utils.sglang", sglang_mod) + monkeypatch.setitem(sys.modules, "slime.utils.distributed_utils", distributed_utils_mod) + monkeypatch.setitem( + sys.modules, + "slime.backends.megatron_utils.update_weight.update_weight_from_distributed", + update_from_distributed_mod, + ) + + return dist_state + + +def _load_update_weight_module(monkeypatch): + dist_state = _install_fake_deps(monkeypatch) + + module_name = "slime.backends.megatron_utils.update_weight.update_weight_from_tensor" + sys.modules.pop(module_name, None) + module_path = ( + REPO_ROOT / "slime" / "backends" / "megatron_utils" / "update_weight" / "update_weight_from_tensor.py" + ) + spec = importlib.util.spec_from_file_location(module_name, module_path) + module = importlib.util.module_from_spec(spec) + monkeypatch.setitem(sys.modules, module_name, module) + assert spec.loader is not None + spec.loader.exec_module(module) + return module, dist_state + + +def test_empty_colocated_bucket_still_participates_in_gather(monkeypatch): + module, dist_state = _load_update_weight_module(monkeypatch) + dist_state.gathered = lambda local: [local, []] + engine = _FakeEngine() + + refs, long_lived_tensors = module._send_to_colocated_engine( + [], + ipc_engine=engine, + ipc_gather_src=0, + ipc_gather_group=object(), + weight_version=3, + ) + + assert dist_state.local_object == [] + assert refs == [] + assert long_lived_tensors == [] + assert engine.update_weights_from_tensor.calls == [] + + +def test_source_rank_pads_empty_colocated_bucket_entries(monkeypatch): + module, dist_state = _load_update_weight_module(monkeypatch) + remote_serialized_bucket = {"flattened_tensor": ("remote",), "metadata": ("remote_weight",)} + dist_state.gathered = lambda local: [local, [remote_serialized_bucket]] + engine = _FakeEngine() + + refs, long_lived_tensors = module._send_to_colocated_engine( + [], + ipc_engine=engine, + ipc_gather_src=0, + ipc_gather_group=object(), + weight_version=7, + ) + + assert refs == ["ref-1"] + assert len(long_lived_tensors) == 1 + empty_bucket = long_lived_tensors[0] + assert empty_bucket["metadata"] == [] + assert empty_bucket["flattened_tensor"] == {"size": 0, "dtype": "uint8", "device": "cuda:0"} + + assert engine.update_weights_from_tensor.calls == [ + { + "serialized_named_tensors": [empty_bucket, remote_serialized_bucket], + "load_format": "flattened_bucket", + "weight_version": "7", + } + ] + + +if __name__ == "__main__": + raise SystemExit(pytest.main([__file__]))