Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/pr-test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -372,7 +372,7 @@ jobs:
strategy:
fail-fast: false
matrix:
info: [{"num_gpus": 0, "test_file": "test_megatron_argument_validation.py"}, {"num_gpus": 0, "test_file": "utils/test_megatron_server_arguments.py"}, {"num_gpus": 0, "test_file": "test_dp_schedule.py"}, {"num_gpus": 0, "test_file": "test_cp_utils.py"}, {"num_gpus": 0, "test_file": "test_metric_report.py"}, {"num_gpus": 0, "test_file": "test_metric_report_dist.py"}, {"num_gpus": 0, "test_file": "test_loss_cp_invariance.py"}, {"num_gpus": 0, "test_file": "test_logprob_response_spans.py"}, {"num_gpus": 0, "test_file": "test_value_temperature.py"}, {"num_gpus": 0, "test_file": "test_cispo_loss.py"}, {"num_gpus": 0, "test_file": "test_rm_f1.py"}, {"num_gpus": 0, "test_file": "test_rm_gpqa.py"}, {"num_gpus": 0, "test_file": "test_rm_math.py"}, {"num_gpus": 0, "test_file": "test_rm_math_dapo.py"}, {"num_gpus": 0, "test_file": "test_rm_deepscaler.py"}, {"num_gpus": 0, "test_file": "test_sample.py"}, {"num_gpus": 0, "test_file": "test_rollout_validation.py"}, {"num_gpus": 0, "test_file": "test_placement_group.py"}, {"num_gpus": 0, "test_file": "test_external_sglang_engines.py"}, {"num_gpus": 0, "test_file": "utils/test_hf_checkpoint_saver.py"}, {"num_gpus": 0, "test_file": "plugin_contracts/test_plugin_rollout_contracts.py"}, {"num_gpus": 0, "test_file": "plugin_contracts/test_plugin_runtime_hook_contracts.py"}, {"num_gpus": 0, "test_file": "plugin_contracts/test_plugin_path_loading_contracts.py"}, {"num_gpus": 0, "test_file": "plugin_contracts/test_plugin_generate_contracts.py"}]
info: [{"num_gpus": 0, "test_file": "test_megatron_argument_validation.py"}, {"num_gpus": 0, "test_file": "utils/test_megatron_server_arguments.py"}, {"num_gpus": 0, "test_file": "test_dp_schedule.py"}, {"num_gpus": 0, "test_file": "test_cp_utils.py"}, {"num_gpus": 0, "test_file": "test_metric_report.py"}, {"num_gpus": 0, "test_file": "test_metric_report_dist.py"}, {"num_gpus": 0, "test_file": "test_loss_cp_invariance.py"}, {"num_gpus": 0, "test_file": "test_logprob_response_spans.py"}, {"num_gpus": 0, "test_file": "test_value_temperature.py"}, {"num_gpus": 0, "test_file": "test_cispo_loss.py"}, {"num_gpus": 0, "test_file": "test_rm_f1.py"}, {"num_gpus": 0, "test_file": "test_rm_gpqa.py"}, {"num_gpus": 0, "test_file": "test_rm_math.py"}, {"num_gpus": 0, "test_file": "test_rm_math_dapo.py"}, {"num_gpus": 0, "test_file": "test_rm_deepscaler.py"}, {"num_gpus": 0, "test_file": "test_sample.py"}, {"num_gpus": 0, "test_file": "test_rollout_validation.py"}, {"num_gpus": 0, "test_file": "test_placement_group.py"}, {"num_gpus": 0, "test_file": "test_external_sglang_engines.py"}, {"num_gpus": 0, "test_file": "test_empty_colocated_weight_bucket.py"}, {"num_gpus": 0, "test_file": "utils/test_hf_checkpoint_saver.py"}, {"num_gpus": 0, "test_file": "plugin_contracts/test_plugin_rollout_contracts.py"}, {"num_gpus": 0, "test_file": "plugin_contracts/test_plugin_runtime_hook_contracts.py"}, {"num_gpus": 0, "test_file": "plugin_contracts/test_plugin_path_loading_contracts.py"}, {"num_gpus": 0, "test_file": "plugin_contracts/test_plugin_generate_contracts.py"}]
defaults:
run:
working-directory: ${{ github.workspace }}
Expand Down
1 change: 1 addition & 0 deletions .github/workflows/pr-test.yml.j2
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@
{'test_file': 'test_rollout_validation.py', 'num_gpus': 0},
{'test_file': 'test_placement_group.py', 'num_gpus': 0},
{'test_file': 'test_external_sglang_engines.py', 'num_gpus': 0},
{'test_file': 'test_empty_colocated_weight_bucket.py', 'num_gpus': 0},
{'test_file': 'utils/test_hf_checkpoint_saver.py', 'num_gpus': 0},
{'test_file': 'plugin_contracts/test_plugin_rollout_contracts.py', 'num_gpus': 0},
{'test_file': 'plugin_contracts/test_plugin_runtime_hook_contracts.py', 'num_gpus': 0},
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -231,7 +231,7 @@ def _send_to_colocated_engine(
long_live_tensors = []

if getattr(FlattenedTensorBucket, "supports_multi_dtypes", False):
converted_named_tensors_by_dtypes = {"dtype": hf_named_tensors}
converted_named_tensors_by_dtypes = {"dtype": hf_named_tensors} if hf_named_tensors else {}
else:
converted_named_tensors_by_dtypes = {}
for name, tensor in hf_named_tensors:
Expand Down Expand Up @@ -263,14 +263,33 @@ def _send_to_colocated_engine(

refs = []
if dist.get_rank() == ipc_gather_src:
# TODO: here we assume all ranks have the same number of dtypes, not sure if that is correct.
num_dtypes = len(serialized_named_tensors[0])
for i in range(num_dtypes):
num_buckets = max(len(tensors) for tensors in serialized_named_tensors)
empty_serialized_tensor = None
for i in range(num_buckets):
serialized_tensors_for_dtype = []
for tensors in serialized_named_tensors:
if i < len(tensors):
serialized_tensors_for_dtype.append(tensors[i])
continue

if empty_serialized_tensor is None:
empty_tensor_data = _empty_flattened_tensor_data()
long_live_tensors.append(empty_tensor_data)
empty_serialized_tensor = MultiprocessingSerializer.serialize(empty_tensor_data, output_str=True)
serialized_tensors_for_dtype.append(empty_serialized_tensor)

kwargs = {
"serialized_named_tensors": [tensors[i] for tensors in serialized_named_tensors],
"serialized_named_tensors": serialized_tensors_for_dtype,
"load_format": "flattened_bucket",
"weight_version": str(weight_version),
}
refs.append(ipc_engine.update_weights_from_tensor.remote(**kwargs))

return refs, long_live_tensors


def _empty_flattened_tensor_data():
    """Build a zero-length flattened-bucket payload.

    Returns:
        A dict with a zero-element ``uint8`` tensor allocated on the current
        CUDA device under ``"flattened_tensor"`` and an empty list under
        ``"metadata"``.
    """
    placeholder = torch.empty(0, dtype=torch.uint8, device=torch.cuda.current_device())
    payload = {"flattened_tensor": placeholder, "metadata": []}
    return payload
204 changes: 204 additions & 0 deletions tests/test_empty_colocated_weight_bucket.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,204 @@
import importlib.util
import sys
import types
from pathlib import Path

import pytest

# GPU count declared for this test module (presumably read by the CI harness
# -- confirm against the workflow template).
NUM_GPUS = 0

# Repository root: the directory one level above the one containing this file.
REPO_ROOT = Path(__file__).resolve().parents[1]

# Put the repository root at the front of the import path, but only once.
if sys.path.count(str(REPO_ROOT)) == 0:
    sys.path.insert(0, str(REPO_ROOT))


class _FakeFlattenedTensorBucket:
    """Test double for the flattened tensor bucket.

    It can be built either from ``named_tensors`` (an iterable of
    ``(name, tensor)`` pairs) or from an already flattened tensor plus its
    metadata. No real tensor work happens: only the names are recorded.
    """

    supports_multi_dtypes = True

    def __init__(self, *, named_tensors=None, flattened_tensor=None, metadata=None):
        """Store either the derived or the explicitly supplied bucket contents.

        Raises:
            ValueError: if ``named_tensors`` is given but empty.
        """
        if named_tensors is None:
            # Reconstruction path: keep whatever the caller handed over.
            self._flattened_tensor = flattened_tensor
            self._metadata = metadata
            return

        if not named_tensors:
            raise ValueError("Cannot create empty tensor bucket")

        # Both fields are derived from the tensor names only.
        self._flattened_tensor = ("flattened", tuple(label for label, _ in named_tensors))
        self._metadata = tuple(label for label, _ in named_tensors)

    def get_flattened_tensor(self):
        """Return the stored flattened tensor stand-in."""
        return self._flattened_tensor

    def get_metadata(self):
        """Return the stored metadata."""
        return self._metadata


class _FakeMultiprocessingSerializer:
    """Serializer double that hands the value back unchanged."""

    @staticmethod
    def serialize(value, output_str):
        """Return ``value`` as-is after checking that string output was requested."""
        # The fake only accepts the literal ``True`` for ``output_str``.
        assert output_str is True
        return value


class _FakeRemoteMethod:
    """Records keyword arguments of every ``remote`` call and returns fake refs."""

    def __init__(self):
        # Keyword-argument dicts, one per ``remote`` invocation, in call order.
        self.calls = []

    def remote(self, **kwargs):
        """Record ``kwargs`` and return ``"ref-<n>"`` where n is the 1-based call count."""
        self.calls.append(kwargs)
        call_number = len(self.calls)
        return f"ref-{call_number}"


class _FakeEngine:
    """Engine double whose ``update_weights_from_tensor`` records remote calls."""

    def __init__(self):
        # A fresh recorder per engine, so each test inspects only its own calls.
        self.update_weights_from_tensor = _FakeRemoteMethod()


def _install_fake_deps(monkeypatch):
    """Replace heavyweight dependencies in ``sys.modules`` with lightweight fakes.

    Stub modules are registered for ``torch``, ``torch.distributed``, ``ray``,
    ``ray.actor``, ``megatron``, ``megatron.core``, ``megatron.core.mpu`` and
    for several ``slime`` packages and submodules. Every registration goes
    through ``monkeypatch.setitem``.

    Args:
        monkeypatch: pytest's ``monkeypatch`` fixture.

    Returns:
        A ``types.SimpleNamespace`` with ``rank``, ``world_size``, ``gathered``
        and ``local_object`` attributes. The fake ``torch.distributed``
        functions read ``rank`` and ``world_size`` from it, and
        ``gather_object`` stores the object it was given in ``local_object``.
        Tests assign ``gathered`` a callable that maps the local object to the
        list written into the gather output.
    """
    dist_state = types.SimpleNamespace(rank=0, world_size=2, gathered=None, local_object=None)

    def make_package(dotted_name):
        # A module carrying ``__path__`` pointing at the matching directory
        # under the repository root.
        package = types.ModuleType(dotted_name)
        package.__path__ = [str(REPO_ROOT.joinpath(*dotted_name.split(".")))]
        return package

    slime_pkg = make_package("slime")
    slime_backends_pkg = make_package("slime.backends")
    megatron_utils_pkg = make_package("slime.backends.megatron_utils")
    update_weight_pkg = make_package("slime.backends.megatron_utils.update_weight")
    slime_utils_pkg = make_package("slime.utils")

    # --- torch.distributed -------------------------------------------------
    def gather_object(obj, object_gather_list, dst, group):
        dist_state.local_object = obj
        if object_gather_list is None:
            return
        # Overwrite the caller's list in place with the scripted gather result.
        object_gather_list[:] = dist_state.gathered(obj)

    def get_rank():
        return dist_state.rank

    def get_world_size(group=None):
        return dist_state.world_size

    dist_mod = types.ModuleType("torch.distributed")
    dist_mod.get_rank = get_rank
    dist_mod.get_world_size = get_world_size
    dist_mod.gather_object = gather_object

    # --- torch -------------------------------------------------------------
    def fake_empty(size, dtype, device):
        # Describe the requested tensor instead of allocating one.
        return {"size": size, "dtype": dtype, "device": device}

    def fake_no_grad():
        # Usable as ``@torch.no_grad()``: the decorator returns the function unchanged.
        def passthrough(fn):
            return fn

        return passthrough

    torch_mod = types.ModuleType("torch")
    torch_mod.Tensor = object
    torch_mod.uint8 = "uint8"
    torch_mod.distributed = dist_mod
    torch_mod.empty = fake_empty
    torch_mod.no_grad = fake_no_grad
    torch_mod.cuda = types.SimpleNamespace(current_device=lambda: "cuda:0", ipc_collect=lambda: None)
    torch_mod.nn = types.SimpleNamespace(Module=object)

    # --- ray ---------------------------------------------------------------
    ray_mod = types.ModuleType("ray")
    ray_mod.ObjectRef = object
    ray_actor_mod = types.ModuleType("ray.actor")
    ray_actor_mod.ActorHandle = object

    # --- megatron ----------------------------------------------------------
    mpu_mod = types.ModuleType("megatron.core.mpu")
    megatron_mod = types.ModuleType("megatron")
    megatron_core_mod = types.ModuleType("megatron.core")
    megatron_core_mod.mpu = mpu_mod

    # --- slime submodules --------------------------------------------------
    sglang_mod = types.ModuleType("slime.backends.megatron_utils.sglang")
    sglang_mod.FlattenedTensorBucket = _FakeFlattenedTensorBucket
    sglang_mod.MultiprocessingSerializer = _FakeMultiprocessingSerializer

    distributed_utils_mod = types.ModuleType("slime.utils.distributed_utils")
    distributed_utils_mod.get_gloo_group = lambda: object()

    def return_none(*args, **kwargs):
        return None

    def return_empty_list(*args, **kwargs):
        return []

    update_from_distributed_mod = types.ModuleType(
        "slime.backends.megatron_utils.update_weight.update_weight_from_distributed"
    )
    update_from_distributed_mod.connect_rollout_engines_from_distributed = return_none
    update_from_distributed_mod.disconnect_rollout_engines_from_distributed = return_none
    update_from_distributed_mod.post_process_weights = return_none
    update_from_distributed_mod.update_weights_from_distributed = return_empty_list

    # Registration table; dict order is the order in which entries are patched in.
    fake_modules = {
        "slime": slime_pkg,
        "slime.backends": slime_backends_pkg,
        "slime.backends.megatron_utils": megatron_utils_pkg,
        "slime.backends.megatron_utils.update_weight": update_weight_pkg,
        "slime.utils": slime_utils_pkg,
        "torch": torch_mod,
        "torch.distributed": dist_mod,
        "ray": ray_mod,
        "ray.actor": ray_actor_mod,
        "megatron": megatron_mod,
        "megatron.core": megatron_core_mod,
        "megatron.core.mpu": mpu_mod,
        "slime.backends.megatron_utils.sglang": sglang_mod,
        "slime.utils.distributed_utils": distributed_utils_mod,
        "slime.backends.megatron_utils.update_weight.update_weight_from_distributed": update_from_distributed_mod,
    }
    for module_name, fake_module in fake_modules.items():
        monkeypatch.setitem(sys.modules, module_name, fake_module)

    return dist_state


def _load_update_weight_module(monkeypatch):
    """Execute ``update_weight_from_tensor.py`` from the repo against the fake deps.

    The fakes are installed first, any cached copy of the module is dropped
    from ``sys.modules``, and the source file is then loaded directly from its
    path under ``REPO_ROOT``.

    Args:
        monkeypatch: pytest's ``monkeypatch`` fixture.

    Returns:
        A ``(module, dist_state)`` tuple: the freshly executed module and the
        namespace returned by ``_install_fake_deps``.
    """
    fake_dist_state = _install_fake_deps(monkeypatch)

    qualified_name = "slime.backends.megatron_utils.update_weight.update_weight_from_tensor"
    # Discard a previously imported copy so the file is executed again.
    sys.modules.pop(qualified_name, None)

    source_file = REPO_ROOT.joinpath(
        "slime", "backends", "megatron_utils", "update_weight", "update_weight_from_tensor.py"
    )
    module_spec = importlib.util.spec_from_file_location(qualified_name, source_file)
    loaded = importlib.util.module_from_spec(module_spec)
    # Register before executing, under the module's fully qualified name.
    monkeypatch.setitem(sys.modules, qualified_name, loaded)
    assert module_spec.loader is not None
    module_spec.loader.exec_module(loaded)
    return loaded, fake_dist_state


def test_empty_colocated_bucket_still_participates_in_gather(monkeypatch):
    """No local tensors and an empty peer contribution: gather runs, nothing is sent."""
    module, dist_state = _load_update_weight_module(monkeypatch)
    engine = _FakeEngine()
    # Script the gather: this rank's own object followed by an empty list from the peer.
    dist_state.gathered = lambda local: [local, []]

    result = module._send_to_colocated_engine(
        [],
        ipc_engine=engine,
        ipc_gather_src=0,
        ipc_gather_group=object(),
        weight_version=3,
    )
    refs, long_lived_tensors = result

    # The rank still handed an (empty) object to gather_object.
    assert dist_state.local_object == []
    # No remote update was issued and nothing needs to be kept alive.
    assert refs == []
    assert long_lived_tensors == []
    assert engine.update_weights_from_tensor.calls == []


def test_source_rank_pads_empty_colocated_bucket_entries(monkeypatch):
    """The source rank fills its missing slot with an empty placeholder bucket."""
    module, dist_state = _load_update_weight_module(monkeypatch)
    engine = _FakeEngine()
    remote_serialized_bucket = {"flattened_tensor": ("remote",), "metadata": ("remote_weight",)}
    # Script the gather: this rank's own object, then one bucket from the peer.
    dist_state.gathered = lambda local: [local, [remote_serialized_bucket]]

    refs, long_lived_tensors = module._send_to_colocated_engine(
        [],
        ipc_engine=engine,
        ipc_gather_src=0,
        ipc_gather_group=object(),
        weight_version=7,
    )

    # Exactly one remote update, and exactly one placeholder kept alive.
    assert refs == ["ref-1"]
    assert len(long_lived_tensors) == 1

    (empty_bucket,) = long_lived_tensors
    assert empty_bucket["metadata"] == []
    assert empty_bucket["flattened_tensor"] == {"size": 0, "dtype": "uint8", "device": "cuda:0"}

    # The placeholder takes this rank's position, ahead of the peer's bucket.
    expected_call = {
        "serialized_named_tensors": [empty_bucket, remote_serialized_bucket],
        "load_format": "flattened_bucket",
        "weight_version": "7",
    }
    assert engine.update_weights_from_tensor.calls == [expected_call]


if __name__ == "__main__":
    # Allow running this file directly; pytest's return value becomes the exit status.
    exit_status = pytest.main([__file__])
    raise SystemExit(exit_status)
Loading