From 4ee00782744d29447963700c1c1e70f3c68dcd7d Mon Sep 17 00:00:00 2001 From: Cruz Zhao Date: Mon, 11 May 2026 23:17:55 +0800 Subject: [PATCH 1/4] feat: add Mooncake DataProto rollout transfer Add an optional mooncake_dataproto transfer backend that publishes rollout tensor fields through Mooncake while preserving slime's existing rollout data layout and Ray default path. Co-Authored-By: Claude Opus 4.6 --- .../advanced/mooncake-dataproto-transfer.md | 30 ++ docs/en/index.rst | 1 + slime/ray/rollout.py | 6 + slime/utils/arguments.py | 32 ++ slime/utils/data.py | 10 +- slime/utils/remote_batch.py | 322 +++++++++++++++++ slime/utils/rollout_dataproto.py | 335 ++++++++++++++++++ tests/utils/test_dataproto_transfer.py | 163 +++++++++ train.py | 21 +- train_async.py | 24 +- 10 files changed, 927 insertions(+), 17 deletions(-) create mode 100644 docs/en/advanced/mooncake-dataproto-transfer.md create mode 100644 slime/utils/remote_batch.py create mode 100644 slime/utils/rollout_dataproto.py create mode 100644 tests/utils/test_dataproto_transfer.py diff --git a/docs/en/advanced/mooncake-dataproto-transfer.md b/docs/en/advanced/mooncake-dataproto-transfer.md new file mode 100644 index 0000000000..572146e43b --- /dev/null +++ b/docs/en/advanced/mooncake-dataproto-transfer.md @@ -0,0 +1,30 @@ +# Mooncake DataProto Rollout Transfer + +slime can transfer rollout data through Mooncake instead of Ray object references. This is useful when the rollout producer and actor consumer run on different nodes and Mooncake Store is configured for the cluster transport. + +The default transfer backend remains Ray. Enable Mooncake DataProto transfer explicitly: + +```bash +python3 train.py \ + --transfer-backend mooncake_dataproto \ + --mooncake-dataproto-store-init-kwargs '{"setup_method":"setup"}' +``` + +## What is transferred + +The Mooncake path keeps slime's rollout data layout unchanged: + +- per-rank rollout partitions are still selected by slime before actor consumption; +- tensor fields such as `tokens` and `loss_masks` are stored as Mooncake remote tensor batches; +- non-tensor rollout fields and metadata stay in the `DataProto` wrapper; +- cleanup keys are tracked in metadata and removed after actor-side materialization. + +## Options + +| Option | Default | Meaning | +| --- | --- | --- | +| `--transfer-backend` | `ray` | Set to `mooncake_dataproto` to enable Mooncake rollout transfer. | +| `--mooncake-dataproto-store-init-kwargs` | `null` | JSON arguments used to initialize the Mooncake store. Use `{"setup_method":"setup"}` for real Mooncake Store setup and `{"setup_method":"setup_dummy"}` for local unit tests. | +| `--mooncake-dataproto-hard-pin` | `true` | Hard-pin remote tensor data to the producer segment when publishing tensor batches. | + +For performance runs, configure Mooncake Store with the production transport, for example RDMA, and keep buffer registration or prewarm costs separate from online transfer latency. diff --git a/docs/en/index.rst b/docs/en/index.rst index f3401ce3a7..4eff4a4764 100644 --- a/docs/en/index.rst +++ b/docs/en/index.rst @@ -84,6 +84,7 @@ Start by Use Case advanced/pd-disaggregation.md advanced/external-rollout-engines.md advanced/delta-weight-sync.md + advanced/mooncake-dataproto-transfer.md advanced/sglang-config.md advanced/megatron-config.md advanced/arch-support-beyond-megatron.md diff --git a/slime/ray/rollout.py b/slime/ray/rollout.py index ff571101af..078e5a15ca 100644 --- a/slime/ray/rollout.py +++ b/slime/ray/rollout.py @@ -850,6 +850,12 @@ def _split_train_data_by_dp(self, data): rollout_indices=data["rollout_ids"], ) + if getattr(self.args, "transfer_backend", "ray") == "mooncake_dataproto": + from slime.utils.rollout_dataproto import split_rollout_data_by_dp_dataproto + + dynamic_global_batch_size = getattr(self, "_dynamic_global_batch_size", None) + return split_rollout_data_by_dp_dataproto(self.args, data, dp_size, partitions, dynamic_global_batch_size) + # Package per-rank rollout_data rollout_data_refs = [] for r in range(dp_size): diff --git a/slime/utils/arguments.py b/slime/utils/arguments.py index d5cac9d44b..8ff28eded1 100644 --- a/slime/utils/arguments.py +++ b/slime/utils/arguments.py @@ -402,6 +402,24 @@ def add_rollout_arguments(parser): "This is used to shuffle the prompts and also for the random sampling of the prompts." ), ) + parser.add_argument( + "--transfer-backend", + choices=["ray", "mooncake_dataproto"], + default="ray", + help="Rollout data transfer backend. Keep ray as the default; mooncake_dataproto is experimental.", + ) + parser.add_argument( + "--mooncake-dataproto-hard-pin", + action=argparse.BooleanOptionalAction, + default=True, + help="Hard-pin Mooncake rollout tensors to the producer segment for mooncake_dataproto transfer.", + ) + parser.add_argument( + "--mooncake-dataproto-store-init-kwargs", + type=json.loads, + default=None, + help="JSON kwargs used to initialize MooncakeDistributedStore for mooncake_dataproto transfer.", + ) # sampling parser.add_argument( @@ -1748,6 +1766,20 @@ def _validate_update_weight_args(args) -> None: def slime_validate_args(args): args.eval_datasets = _resolve_eval_datasets(args) + if getattr(args, "transfer_backend", "ray") == "mooncake_dataproto": + from slime.utils.remote_batch import normalize_store_init_kwargs + + args.mooncake_dataproto_store_init_kwargs = normalize_store_init_kwargs( + args.mooncake_dataproto_store_init_kwargs + ) + + if args.use_slime_router: + logger.warning( + "--use-slime-router is deprecated and ignored. slime now always uses sglang_router " + "built from https://github.com/zhuzilin/sgl-router." + ) + args.use_slime_router = False + if args.kl_coef != 0 or args.use_kl_loss: if not os.path.exists(args.ref_load): raise FileNotFoundError(f"ref_load {args.ref_load} does not exist, please check the path.") diff --git a/slime/utils/data.py b/slime/utils/data.py index 0d26b6dda5..8cfb622401 100644 --- a/slime/utils/data.py +++ b/slime/utils/data.py @@ -4,7 +4,6 @@ import os import random import re - import numpy as np import ray @@ -291,7 +290,14 @@ def __len__(self): def process_rollout_data(args, rollout_data_ref, dp_rank, dp_size): assert len(rollout_data_ref) == dp_size - rollout_data = ray.get(rollout_data_ref[dp_rank].inner) + if getattr(args, "transfer_backend", "ray") == "mooncake_dataproto": + from slime.utils.rollout_dataproto import DataProto, dataproto_to_rollout_data + + proto = rollout_data_ref[dp_rank] + assert isinstance(proto, DataProto), f"expected DataProto, got {type(proto)}" + rollout_data = dataproto_to_rollout_data(proto, preserve_remote_tensors=True) + else: + rollout_data = ray.get(rollout_data_ref[dp_rank].inner) partition = rollout_data.pop("partition") total_lengths = rollout_data["total_lengths"] diff --git a/slime/utils/remote_batch.py b/slime/utils/remote_batch.py new file mode 100644 index 0000000000..77d70a05ed --- /dev/null +++ b/slime/utils/remote_batch.py @@ -0,0 +1,322 @@ +from __future__ import annotations + +import ctypes +import os +import re +from dataclasses import dataclass, field +from typing import Any + +import numpy as np + +import torch +from tensordict._td import TensorDict + +ALLOWED_SETUP_METHODS = {"setup", "setup_dummy"} +_STORE_CACHE: dict[tuple[tuple[str, str], ...], Any] = {} +_FIELD_NAME_RE = re.compile(r"^[A-Za-z0-9_.-]{1,128}$") + + +def normalize_store_init_kwargs(store_init_kwargs: dict[str, Any] | None) -> dict[str, Any]: + if store_init_kwargs is None: + raise ValueError("mooncake_dataproto requires --mooncake-dataproto-store-init-kwargs") + if not store_init_kwargs: + return {"setup_method": "setup"} + setup_method = store_init_kwargs.get("setup_method", "setup") + if setup_method not in ALLOWED_SETUP_METHODS: + raise ValueError(f"unsupported Mooncake store setup_method {setup_method!r}; allowed: {sorted(ALLOWED_SETUP_METHODS)}") + return dict(store_init_kwargs) + + +def create_mooncake_store(store_init_kwargs: dict[str, Any] | None = None) -> Any: + kwargs = normalize_store_init_kwargs(store_init_kwargs or {}) + setup_method = kwargs.get("setup_method", "setup") + if setup_method == "setup_dummy": + try: + from mooncake.structured_object_store import InMemoryMooncakeStore + except ImportError: + pass + else: + return InMemoryMooncakeStore() + + from mooncake.store import MooncakeDistributedStore # type: ignore + + store = MooncakeDistributedStore() + setup_kwargs = {key: val for key, val in kwargs.items() if key != "setup_method"} + setup = getattr(store, setup_method) + try: + ret = setup(**setup_kwargs) + except TypeError: + if setup_method != "setup": + raise + ret = setup(_env_store_config() | setup_kwargs) + if ret != 0: + raise RuntimeError(f"Mooncake store {setup_method} failed with retcode {ret}") + return store + + +def get_cached_mooncake_store(store_init_kwargs: dict[str, Any] | None = None) -> Any: + kwargs = normalize_store_init_kwargs(store_init_kwargs) + cache_key = tuple(sorted((key, repr(val)) for key, val in kwargs.items())) + if cache_key not in _STORE_CACHE: + _STORE_CACHE[cache_key] = create_mooncake_store(kwargs) + return _STORE_CACHE[cache_key] + + +def remove_mooncake_keys(store: Any, keys: list[str]) -> None: + errors = [] + for key in sorted(set(keys)): + ret = store.remove(key, True) + if ret != 0: + errors.append((key, ret)) + if errors: + raise RuntimeError(f"Mooncake key cleanup failed: {errors}") + + +def _env_store_config() -> dict[str, Any]: + return { + "local_hostname": os.getenv("MOONCAKE_LOCAL_HOSTNAME", "localhost"), + "metadata_server": os.getenv("MOONCAKE_TE_META_DATA_SERVER", "P2PHANDSHAKE"), + "global_segment_size": int(os.getenv("MOONCAKE_GLOBAL_SEGMENT_SIZE", str(16 * 1024 * 1024 * 1024))), + "local_buffer_size": int(os.getenv("MOONCAKE_LOCAL_BUFFER_SIZE", str(16 * 1024 * 1024 * 1024))), + "protocol": os.getenv("MOONCAKE_PROTOCOL", "rdma"), + "rdma_devices": os.getenv("MOONCAKE_DEVICE", ""), + "master_server_addr": os.getenv("MOONCAKE_MASTER", "127.0.0.1:50051"), + } + + +def _import_mooncake_helpers(): + try: + from mooncake.structured_object_store import ( # type: ignore + RemoteTensorBatch, + TensorFieldRef, + normalize_dtype_name, + payload_to_buffer, + ) + except ImportError as exc: + raise ImportError("Mooncake structured object helpers are required for mooncake_dataproto transfer") from exc + return RemoteTensorBatch, TensorFieldRef, normalize_dtype_name, payload_to_buffer + + +@dataclass +class MooncakeRemoteBatch: + remote: Any + store_init_kwargs: dict[str, Any] = field(default_factory=dict) + keys_to_cleanup: tuple[str, ...] = () + use_reusable_buffer: bool = True + + @classmethod + def from_tensors( + cls, + tensors: dict[str, torch.Tensor], + store: Any, + prefix: str, + store_init_kwargs: dict[str, Any] | None = None, + use_hard_pin: bool = True, + use_reusable_buffer: bool = True, + ) -> MooncakeRemoteBatch: + _validate_prefix(prefix) + for name in tensors: + _validate_field_name(name) + RemoteTensorBatch, TensorFieldRef, normalize_dtype_name, payload_to_buffer = _import_mooncake_helpers() + fields = {} + batch_size = None + config = _hard_pin_config(store) if use_hard_pin else None + written_keys = [] + try: + for name, tensor in tensors.items(): + cpu_tensor = tensor.detach().contiguous().cpu() + if batch_size is None: + batch_size = int(cpu_tensor.shape[0]) + elif int(cpu_tensor.shape[0]) != batch_size: + raise ValueError(f"tensor {name} batch size {cpu_tensor.shape[0]} != {batch_size}") + key = f"{prefix}/{name}" + buffer, owner, _ = payload_to_buffer(cpu_tensor) + ret = _pub_tensor_from(store, key, buffer, config) + if ret != 0: + raise RuntimeError(f"Mooncake put failed for {key} with retcode {ret}") + written_keys.append(key) + fields[name] = TensorFieldRef( + key=key, + shape=tuple(cpu_tensor.shape), + dtype=normalize_dtype_name(cpu_tensor.dtype), + data_offset=0, + ) + del owner + except Exception: + remove_mooncake_keys(store, written_keys) + raise + if batch_size is None: + raise ValueError("MooncakeRemoteBatch.from_tensors requires non-empty tensors") + return cls( + remote=RemoteTensorBatch(fields=fields, batch_size=batch_size), + store_init_kwargs=store_init_kwargs or {}, + keys_to_cleanup=tuple(written_keys), + use_reusable_buffer=use_reusable_buffer, + ) + + def __len__(self) -> int: + return len(self.remote) + + def keys(self) -> list[str]: + return self.remote.keys() + + def materialize(self, fields: list[str] | None = None) -> TensorDict: + store = get_cached_mooncake_store(self.store_init_kwargs) + try: + tensors = _materialize_remote_tensors(store, self.remote, fields, self.use_reusable_buffer) + return TensorDict(source=tensors, batch_size=(len(self),)) + except Exception as exc: + requested = self.remote.keys() if fields is None else fields + raise RuntimeError(f"MooncakeRemoteBatch materialize failed for fields={list(requested)}") from exc + + def cleanup(self) -> None: + if not self.keys_to_cleanup: + return + store = get_cached_mooncake_store(self.store_init_kwargs) + remove_mooncake_keys(store, list(self.keys_to_cleanup)) + + +def _materialize_remote_tensors( + store: Any, + remote: Any, + fields: list[str] | None, + use_reusable_buffer: bool, +) -> dict[str, torch.Tensor]: + if use_reusable_buffer: + return _materialize_remote_tensors_with_pool(store, remote, fields) + return _materialize_remote_tensors_without_pool(store, remote, fields) + + +def _materialize_remote_tensors_without_pool(store: Any, remote: Any, fields: list[str] | None) -> dict[str, torch.Tensor]: + requests = remote.read_requests(fields) + regions = {request.name: _WritableRegion(bytearray(request.output_nbytes())) for request in requests} + try: + for request in requests: + _materialize_request_into_region(store, request, regions[request.name], register=True) + return {request.name: _region_to_tensor(regions[request.name], request) for request in requests} + finally: + for region in regions.values(): + region.close() + + +def _materialize_remote_tensors_with_pool(store: Any, remote: Any, fields: list[str] | None) -> dict[str, torch.Tensor]: + from mooncake.structured_object_store import get_registered_buffer_pool + + pool = get_registered_buffer_pool(store) + requests = remote.read_requests(fields) + leases = {request.name: pool.buffer(request.output_nbytes()) for request in requests} + try: + for request in requests: + _materialize_request_into_region(store, request, leases[request.name], register=False) + return {request.name: _region_to_tensor(_lease_region(leases[request.name]), request) for request in requests} + finally: + for lease in leases.values(): + lease.release() + + +def _materialize_request_into_region(store: Any, request: Any, region: Any, register: bool) -> None: + from mooncake.structured_object_store import normalize_dtype_name + + required_size = request.output_nbytes() + if required_size == 0: + return + if register: + register_ret = store.register_buffer(region.ptr, region.size) + if register_ret != 0: + raise RuntimeError(f"register_buffer failed with retcode {register_ret}") + try: + if hasattr(store, "get_tensor_dim_selection_into"): + ret = store.get_tensor_dim_selection_into( + request.ref.key, + region.ptr, + required_size, + list(request.ref.shape), + normalize_dtype_name(request.ref.dtype), + request.dim, + request.store_selections(), + request.ref.data_offset, + ) + if ret < 0: + raise RuntimeError(f"get_tensor_dim_selection_into failed with retcode {ret}") + elif request.store_selections() or request.ref.data_offset: + raise RuntimeError("store.get_tensor_dim_selection_into is required for selected remote tensors") + else: + ret = store.get_into(request.ref.key, region.ptr, required_size) + if ret != required_size: + raise RuntimeError(f"get_into failed for {request.ref.key}: expected {required_size}, got {ret}") + finally: + if register: + unregister_ret = store.unregister_buffer(region.ptr) + if unregister_ret != 0: + raise RuntimeError(f"unregister_buffer failed with retcode {unregister_ret}") + + +def _lease_region(lease: Any) -> Any: + return lease.view_region() if hasattr(lease, "view_region") else lease + + +def _region_to_tensor(region: Any, request: Any) -> torch.Tensor: + dtype_name = str(request.ref.dtype).removeprefix("torch.").lower() + if not hasattr(torch, dtype_name): + raise ValueError(f"unsupported Mooncake tensor dtype: {request.ref.dtype!r}") + torch_dtype = getattr(torch, dtype_name) + shape = request.output_shape() + count = int(np.prod(shape, dtype=np.int64)) + return torch.frombuffer(region.buffer, dtype=torch_dtype, count=count).reshape(shape).clone() + + +def _validate_prefix(prefix: str) -> None: + if not prefix or len(prefix) > 256 or ".." in prefix or any(ord(ch) < 32 for ch in prefix): + raise ValueError(f"invalid Mooncake key prefix: {prefix!r}") + + +def _validate_field_name(name: str) -> None: + if _FIELD_NAME_RE.fullmatch(name) is None: + raise ValueError(f"invalid Mooncake tensor field name: {name!r}") + + +def _pub_tensor_from(store: Any, key: str, buffer: memoryview, config: Any) -> int: + region = _WritableRegion(buffer) + try: + register_ret = store.register_buffer(region.ptr, region.size) + if register_ret != 0: + raise RuntimeError(f"register_buffer failed for Mooncake put_from key={key} retcode={register_ret}") + try: + return store.put_from(key=key, buffer_ptr=region.ptr, size=region.size, config=config) + finally: + unregister_ret = store.unregister_buffer(region.ptr) + if unregister_ret != 0: + raise RuntimeError(f"unregister_buffer failed for Mooncake put_from key={key} retcode={unregister_ret}") + finally: + region.close() + + +class _WritableRegion: + def __init__(self, buffer: Any) -> None: + self.buffer = buffer + self.view = memoryview(buffer) + if self.view.readonly: + self.view.release() + raise ValueError("buffer must be writable") + if self.view.format != "B": + cast_view = self.view.cast("B") + self.view.release() + self.view = cast_view + self.c_buffer = (ctypes.c_ubyte * self.view.nbytes).from_buffer(self.view) + self.ptr = ctypes.addressof(self.c_buffer) + self.size = self.view.nbytes + + def close(self) -> None: + self.c_buffer = None + self.view.release() + + +def _hard_pin_config(store: Any) -> Any: + try: + from mooncake.store import ReplicateConfig # type: ignore + except ImportError as exc: + raise ImportError("Mooncake ReplicateConfig is required for hard-pin transfer") from exc + config = ReplicateConfig() + config.preferred_segments = [store.get_hostname()] + config.with_hard_pin = True + return config diff --git a/slime/utils/rollout_dataproto.py b/slime/utils/rollout_dataproto.py new file mode 100644 index 0000000000..28721569b5 --- /dev/null +++ b/slime/utils/rollout_dataproto.py @@ -0,0 +1,335 @@ +from __future__ import annotations + +import io +import uuid +from dataclasses import dataclass, field +from typing import Any, Protocol + +import numpy as np +import torch +from tensordict._td import TensorDict +from torch.nn.utils.rnn import pad_sequence + +from slime.utils.remote_batch import ( + MooncakeRemoteBatch, + get_cached_mooncake_store, + normalize_store_init_kwargs, + remove_mooncake_keys, +) + +REMOTE_TENSOR_KEYS = ("tokens", "loss_masks") +PARTITIONED_KEYS = ( + "tokens", + "multimodal_train_inputs", + "response_lengths", + "rewards", + "truncated", + "loss_masks", + "round_number", + "sample_indices", + "rollout_log_probs", + "rollout_routed_experts", + "prompt", + "teacher_log_probs", +) +GLOBAL_KEYS = ("raw_reward", "total_lengths") + + +class RemoteBatch(Protocol): + def __len__(self) -> int: ... + + def keys(self) -> list[str]: ... + + def materialize(self, fields: list[str] | None = None) -> TensorDict: ... + + +@dataclass +class DataProto: + batch: TensorDict | None = None + non_tensor_batch: dict[str, np.ndarray] = field(default_factory=dict) + meta_info: dict = field(default_factory=dict) + remote_batch: RemoteBatch | None = None + + def __post_init__(self): + self.check_consistency() + + def __len__(self): + if self.batch is not None: + return self.batch.batch_size[0] + if self.non_tensor_batch: + return len(next(iter(self.non_tensor_batch.values()))) + if self.remote_batch is not None: + return len(self.remote_batch) + return 0 + + def __getstate__(self): + buffer = io.BytesIO() + batch = self.batch.contiguous().consolidate() if self.batch is not None else None + torch.save(batch, buffer) + return buffer.getvalue(), self.non_tensor_batch, self.meta_info, self.remote_batch + + def __setstate__(self, data): + batch_bytes, self.non_tensor_batch, self.meta_info, self.remote_batch = data + self.batch = torch.load(io.BytesIO(batch_bytes), weights_only=False, map_location="cpu") + + def check_consistency(self): + if self.batch is not None: + assert len(self.batch.batch_size) == 1, "only support num_batch_dims=1" + if self.non_tensor_batch: + batch_size = len(self) + for key, val in self.non_tensor_batch.items(): + assert isinstance(val, np.ndarray), f"non_tensor_batch[{key}] must be np.ndarray, got {type(val)}" + assert val.shape[0] == batch_size, f"key {key} length {val.shape[0]} != batch size {batch_size}" + if self.batch is not None and self.remote_batch is not None: + assert len(self.batch) == len(self.remote_batch), "local and remote batch sizes must match" + if self.non_tensor_batch and self.remote_batch is not None: + assert len(next(iter(self.non_tensor_batch.values()))) == len(self.remote_batch) + + @classmethod + def from_dict( + cls, + tensors: dict[str, torch.Tensor] | None = None, + non_tensors: dict[str, Any] | None = None, + meta_info: dict | None = None, + num_batch_dims: int = 1, + ): + assert num_batch_dims > 0, "num_batch_dims must be greater than zero" + if non_tensors is not None: + assert num_batch_dims == 1, "only support num_batch_dims=1 when non_tensors is not None" + tensors = tensors or {} + non_tensors = non_tensors or {} + meta_info = meta_info or {} + + batch_size = None + pivot_key = None + for key, tensor in tensors.items(): + current_batch = tuple(tensor.shape[:num_batch_dims]) + if batch_size is None: + batch_size = current_batch + pivot_key = key + else: + assert current_batch == batch_size, ( + f"Not all tensors have the same batch size. {pivot_key} has {batch_size}, " + f"{key} has {current_batch}" + ) + + normalized_non_tensors = {} + for key, val in non_tensors.items(): + if not isinstance(val, np.ndarray): + val = np.array(val, dtype=object) + normalized_non_tensors[key] = val + if batch_size is None: + batch_size = (val.shape[0],) + else: + assert val.shape[0] == batch_size[0], ( + f"non_tensor {key} length {val.shape[0]} != batch size {batch_size[0]}" + ) + + tensor_dict = TensorDict(source=tensors, batch_size=batch_size) if tensors else None + return cls(batch=tensor_dict, non_tensor_batch=normalized_non_tensors, meta_info=meta_info) + + @classmethod + def from_remote( + cls, + remote_batch: RemoteBatch, + batch: TensorDict | None = None, + non_tensors: dict[str, np.ndarray] | None = None, + meta_info: dict | None = None, + ): + return cls( + batch=batch, + non_tensor_batch=non_tensors or {}, + meta_info=meta_info or {}, + remote_batch=remote_batch, + ) + + def materialize_remote_batch(self): + if self.remote_batch is None: + return self + fetched = self.remote_batch.materialize() + if self.batch is not None: + assert self.batch.batch_size == fetched.batch_size, ( + f"TensorDict batch size mismatch: {self.batch.batch_size} != {fetched.batch_size}" + ) + for key, val in fetched.items(): + assert key not in self.batch.keys() or self.batch[key].equal(val), ( + f"{key} exists in both TensorDicts with different values" + ) + self.batch[key] = val + else: + self.batch = fetched + self.remote_batch = None + self.check_consistency() + return self + + +def split_rollout_data_by_dp_dataproto( + args: Any, + data: dict, + dp_size: int, + partitions: list, + dynamic_global_batch_size: int | None = None, +) -> list[DataProto]: + if len(partitions) != dp_size: + raise ValueError(f"expected {dp_size} partitions, got {len(partitions)}") + return _split_rollout_data_by_dp_remote_batch(args, data, partitions, dynamic_global_batch_size) + + +def _split_rollout_data_by_dp_remote_batch( + args: Any, + data: dict, + partitions: list, + dynamic_global_batch_size: int | None = None, +) -> list[DataProto]: + store_init_kwargs = _store_init_kwargs(args) + store = get_cached_mooncake_store(store_init_kwargs) + transfer_id = uuid.uuid4().hex + refs = [] + try: + for dp_rank, partition in enumerate(partitions): + indices = [int(idx) for idx in partition] + shard = _slice_partitioned_data(data, indices) + shard["partition"] = np.asarray(indices, dtype=np.int64) + meta_info = {key: data[key] for key in GLOBAL_KEYS if key in data} + if dynamic_global_batch_size is not None: + meta_info["dynamic_global_batch_size"] = dynamic_global_batch_size + + remote_tensors, remote_lengths = _extract_remote_tensors(shard) + meta_info.update(remote_lengths) + remote_batch = None + if remote_tensors: + remote_batch = MooncakeRemoteBatch.from_tensors( + remote_tensors, + store, + prefix=f"slime-rollout/{transfer_id}/dp{dp_rank}", + store_init_kwargs=store_init_kwargs, + use_hard_pin=getattr(args, "mooncake_dataproto_hard_pin", True), + ) + _attach_cleanup_info(meta_info, remote_batch, store_init_kwargs) + + try: + proto = ( + DataProto.from_remote(remote_batch, non_tensors=_dict_to_non_tensors(shard), meta_info=meta_info) + if remote_batch is not None + else DataProto.from_dict(non_tensors=shard, meta_info=meta_info) + ) + except Exception: + if remote_batch is not None: + remote_batch.cleanup() + raise + refs.append(proto) + except Exception: + cleanup_dataproto_refs(refs) + raise + return refs + + +def maybe_cleanup_dataproto_refs(args: Any, refs: list[DataProto], suppress_errors: bool = False) -> None: + if getattr(args, "transfer_backend", "ray") != "mooncake_dataproto": + return + if not suppress_errors: + cleanup_dataproto_refs(refs) + return + try: + cleanup_dataproto_refs(refs) + except Exception: + return + + +def cleanup_dataproto_refs(refs: list[DataProto]) -> None: + keys = set() + store_init_kwargs = None + for proto in refs: + keys.update(proto.meta_info.get("mooncake_cleanup_keys", [])) + if store_init_kwargs is None and "mooncake_cleanup_store_kwargs" in proto.meta_info: + store_init_kwargs = dict(proto.meta_info["mooncake_cleanup_store_kwargs"]) + if not keys or store_init_kwargs is None: + return + store = get_cached_mooncake_store(store_init_kwargs) + remove_mooncake_keys(store, sorted(keys)) + + +def _attach_cleanup_info( + meta_info: dict, + remote_batch: MooncakeRemoteBatch, + store_init_kwargs: dict[str, Any], +) -> None: + meta_info["mooncake_cleanup_keys"] = list(remote_batch.keys_to_cleanup) + meta_info["mooncake_cleanup_store_kwargs"] = dict(store_init_kwargs) + + +def dataproto_to_rollout_data(proto: DataProto, preserve_remote_tensors: bool = True) -> dict: + """Materialize a transfer DataProto into slime's legacy rollout dict.""" + if proto.remote_batch is not None: + proto.materialize_remote_batch() + rollout_data = {key: val.tolist() for key, val in proto.non_tensor_batch.items()} + rollout_data.update({key: val for key, val in proto.meta_info.items() if not key.startswith("mooncake_cleanup_")}) + if proto.batch is not None: + for key, tensor in proto.batch.items(): + lengths = proto.meta_info.get(f"{key}_lengths") + if preserve_remote_tensors and key in REMOTE_TENSOR_KEYS: + rollout_data[key] = _tensor_to_row_tensors(tensor, lengths) + else: + rollout_data[key] = _tensor_to_list(tensor, lengths) + return rollout_data + + +def _slice_partitioned_data(data: dict, indices: list[int]) -> dict: + shard = {} + for key in PARTITIONED_KEYS: + if key in data: + shard[key] = [data[key][idx] for idx in indices] + return shard + + +def _extract_remote_tensors(shard: dict) -> tuple[dict[str, torch.Tensor], dict[str, list[int]]]: + tensors = {} + lengths = {} + for key in REMOTE_TENSOR_KEYS: + if key not in shard: + continue + values = shard.pop(key) + tensor, field_lengths = _list_to_padded_tensor(values, torch.long if key == "tokens" else torch.int) + tensors[key] = tensor + lengths[f"{key}_lengths"] = field_lengths + return tensors, lengths + + +def _list_to_padded_tensor(values: list, dtype: torch.dtype) -> tuple[torch.Tensor, list[int]]: + if not values: + return torch.empty((0, 0), dtype=dtype), [] + tensors = [torch.as_tensor(value, dtype=dtype).reshape(-1) for value in values] + lengths = [int(tensor.numel()) for tensor in tensors] + return pad_sequence(tensors, batch_first=True, padding_value=0), lengths + + +def _tensor_to_row_tensors(tensor: torch.Tensor, lengths: list[int] | None = None) -> list[torch.Tensor]: + if tensor.ndim == 2: + if lengths is not None: + return [tensor[idx, : int(length)] for idx, length in zip(range(tensor.shape[0]), lengths, strict=True)] + return [tensor[idx] for idx in range(tensor.shape[0])] + return [tensor[idx] for idx in range(tensor.shape[0])] + + +def _tensor_to_list(tensor: torch.Tensor, lengths: list[int] | None = None) -> list: + if tensor.ndim == 2: + rows = tensor.cpu().tolist() + if lengths is not None: + return [row[: int(length)] for row, length in zip(rows, lengths, strict=True)] + return rows + return tensor.cpu().numpy().tolist() + + +def _dict_to_non_tensors(data: dict) -> dict[str, np.ndarray]: + return {key: val if isinstance(val, np.ndarray) else np.asarray(val, dtype=_infer_numpy_dtype(val)) for key, val in data.items()} + + +def _infer_numpy_dtype(val: Any) -> Any: + if isinstance(val, list) and all(isinstance(item, (bool, int, float, np.number)) for item in val): + return None + return object + + +def _store_init_kwargs(args: Any) -> dict[str, Any]: + kwargs = getattr(args, "mooncake_dataproto_store_init_kwargs", None) + return normalize_store_init_kwargs(kwargs) diff --git a/tests/utils/test_dataproto_transfer.py b/tests/utils/test_dataproto_transfer.py new file mode 100644 index 0000000000..6a8914a2da --- /dev/null +++ b/tests/utils/test_dataproto_transfer.py @@ -0,0 +1,163 @@ +import sys +import types + +import numpy as np +import torch + + +class TensorDict(dict): + def __init__(self, source=None, batch_size=None, device=None): + super().__init__(source or {}) + self.batch_size = torch.Size(batch_size or []) + self.device = device + + def __len__(self): + return self.batch_size[0] if len(self.batch_size) > 0 else dict.__len__(self) + + def clone(self): + return TensorDict({key: val.clone() for key, val in self.items()}, self.batch_size, self.device) + + def select(self, *keys): + return TensorDict({key: self[key] for key in keys}, self.batch_size, self.device) + + +tensordict_module = types.ModuleType("tensordict") +tensordict_td_module = types.ModuleType("tensordict._td") +tensordict_module.TensorDict = TensorDict +tensordict_td_module.TensorDict = TensorDict +sys.modules.setdefault("tensordict", tensordict_module) +sys.modules.setdefault("tensordict._td", tensordict_td_module) + +from slime.utils.remote_batch import MooncakeRemoteBatch, create_mooncake_store, normalize_store_init_kwargs +from slime.utils.rollout_dataproto import DataProto, dataproto_to_rollout_data, split_rollout_data_by_dp_dataproto + + +class FakeRemoteBatch: + def __init__(self, tensors, indices=None): + self.tensors = tensors + self.indices = list(range(len(next(iter(tensors.values()))))) if indices is None else indices + + def __len__(self): + return len(self.indices) + + @property + def batch_size(self): + return torch.Size([len(self.indices)]) + + def keys(self): + return list(self.tensors.keys()) + + def materialize(self, fields=None): + selected = self.keys() if fields is None else fields + return TensorDict({key: self.tensors[key][self.indices] for key in selected}, batch_size=(len(self),)) + + +def test_dataproto_remote_materialize(): + remote = FakeRemoteBatch({"tokens": torch.arange(12).reshape(4, 3), "loss_masks": torch.ones(4, 3, dtype=torch.int32)}) + proto = DataProto.from_remote(remote, non_tensors={"response_lengths": np.asarray([1, 2, 3, 4])}) + + proto.materialize_remote_batch() + + assert proto.remote_batch is None + assert proto.batch["tokens"].tolist() == [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10, 11]] + assert proto.non_tensor_batch["response_lengths"].tolist() == [1, 2, 3, 4] + + +def test_dataproto_to_rollout_data_preserves_remote_tensor_rows(): + remote = FakeRemoteBatch( + { + "tokens": torch.tensor([[1, 2, 0], [3, 4, 5]]), + "loss_masks": torch.tensor([[1, 1, 0], [1, 1, 1]], dtype=torch.int), + } + ) + proto = DataProto.from_remote( + remote, + non_tensors={"partition": np.asarray([0, 1])}, + meta_info={ + "total_lengths": [2, 3], + "tokens_lengths": [2, 3], + "loss_masks_lengths": [2, 3], + }, + ) + + rollout_data = dataproto_to_rollout_data(proto) + + assert all(isinstance(row, torch.Tensor) for row in rollout_data["tokens"]) + assert all(isinstance(row, torch.Tensor) for row in rollout_data["loss_masks"]) + assert [row.tolist() for row in rollout_data["tokens"]] == [[1, 2], [3, 4, 5]] + assert [row.tolist() for row in rollout_data["loss_masks"]] == [[1, 1], [1, 1, 1]] + assert "_remote_tensor_owners" not in rollout_data + assert rollout_data["partition"] == [0, 1] + assert rollout_data["total_lengths"] == [2, 3] + + +def test_dataproto_to_rollout_data_legacy_tensor_list_fallback(): + remote = FakeRemoteBatch({"tokens": torch.tensor([[1, 2, 0], [3, 4, 5]])}) + proto = DataProto.from_remote( + remote, + non_tensors={"partition": np.asarray([0, 1])}, + meta_info={"total_lengths": [2, 3], "tokens_lengths": [2, 3]}, + ) + + rollout_data = dataproto_to_rollout_data(proto, preserve_remote_tensors=False) + + assert rollout_data["tokens"] == [[1, 2], [3, 4, 5]] + assert "_remote_tensor_owners" not in rollout_data + + +def test_dataproto_to_rollout_data_keeps_non_remote_tensors_legacy(): + proto = DataProto.from_dict(tensors={"other": torch.tensor([[1, 2], [3, 4]])}) + + rollout_data = dataproto_to_rollout_data(proto, preserve_remote_tensors=True) + + assert rollout_data["other"] == [[1, 2], [3, 4]] + + +def test_rollout_transfer_rejects_partition_mismatch(): + args = types.SimpleNamespace(mooncake_dataproto_store_init_kwargs={"setup_method": "setup_dummy"}) + try: + split_rollout_data_by_dp_dataproto(args, {}, 2, [[]]) + except ValueError as exc: + assert "expected 2 partitions" in str(exc) + else: + raise AssertionError("partition mismatch should be rejected") + + +def test_normalizes_empty_mooncake_setup_kwargs_to_setup(): + assert normalize_store_init_kwargs({}) == {"setup_method": "setup"} + + +def test_rejects_unsafe_mooncake_setup_method(): + try: + normalize_store_init_kwargs({"setup_method": "remove"}) + except ValueError as exc: + assert "unsupported Mooncake store setup_method" in str(exc) + else: + raise AssertionError("unsafe setup_method should be rejected") + + +def test_create_store_normalizes_none_for_default_call(monkeypatch): + class Store: + def setup(self): + return 0 + + mooncake_module = types.ModuleType("mooncake") + store_module = types.ModuleType("mooncake.store") + store_module.MooncakeDistributedStore = Store + monkeypatch.setitem(sys.modules, "mooncake", mooncake_module) + monkeypatch.setitem(sys.modules, "mooncake.store", store_module) + + assert isinstance(create_mooncake_store(), Store) + + +def test_rejects_invalid_remote_field_name(): + class Store: + def get_hostname(self): + return "localhost" + + try: + MooncakeRemoteBatch.from_tensors({"../tokens": torch.ones(1, 1, dtype=torch.int64)}, Store(), "prefix") + except ValueError as exc: + assert "invalid Mooncake tensor field name" in str(exc) + else: + raise AssertionError("invalid tensor field name should be rejected") diff --git a/train.py b/train.py index 620f7e8d70..02392d0fe5 100644 --- a/train.py +++ b/train.py @@ -1,5 +1,7 @@ import ray +from slime.utils.rollout_dataproto import maybe_cleanup_dataproto_refs + from slime.ray.placement_group import create_placement_groups, create_rollout_manager, create_training_models from slime.utils.arguments import parse_args from slime.utils.logging_utils import configure_logger, finish_tracking, init_tracking @@ -71,14 +73,19 @@ def save(rollout_id): actor_trains_this_step = (not args.use_critic) or rollout_id >= args.num_critic_only_steps - if args.use_critic: - value_refs = critic_model.async_train(rollout_id, rollout_data_ref) - if actor_trains_this_step: - ray.get(actor_model.async_train(rollout_id, rollout_data_ref, external_data=value_refs)) + train_succeeded = False + try: + if args.use_critic: + value_refs = critic_model.async_train(rollout_id, rollout_data_ref) + if actor_trains_this_step: + ray.get(actor_model.async_train(rollout_id, rollout_data_ref, external_data=value_refs)) + else: + ray.get(value_refs) else: - ray.get(value_refs) - else: - ray.get(actor_model.async_train(rollout_id, rollout_data_ref)) + ray.get(actor_model.async_train(rollout_id, rollout_data_ref)) + train_succeeded = True + finally: + maybe_cleanup_dataproto_refs(args, rollout_data_ref, suppress_errors=not train_succeeded) if should_run_periodic_action(rollout_id, args.save_interval, num_rollout_per_epoch, args.num_rollout): save(rollout_id) diff --git a/train_async.py b/train_async.py index 9d4c9b6473..b29a7c1741 100644 --- a/train_async.py +++ b/train_async.py @@ -1,5 +1,7 @@ import ray +from slime.utils.rollout_dataproto import maybe_cleanup_dataproto_refs + from slime.ray.placement_group import create_placement_groups, create_rollout_manager, create_training_models from slime.utils.arguments import parse_args from slime.utils.logging_utils import configure_logger, finish_tracking, init_tracking @@ -38,15 +40,21 @@ def train(args): if rollout_id + 1 < args.num_rollout: rollout_data_next_future = rollout_manager.generate.remote(rollout_id + 1) - if args.use_critic: - actor_trains_this_step = rollout_id >= args.num_critic_only_steps - value_refs = critic_model.async_train(rollout_id, rollout_data_curr_ref) - if actor_trains_this_step: - ray.get(actor_model.async_train(rollout_id, rollout_data_curr_ref, external_data=value_refs)) + train_succeeded = False + try: + if args.use_critic: + actor_trains_this_step = rollout_id >= args.num_critic_only_steps + value_refs = critic_model.async_train(rollout_id, rollout_data_curr_ref) + if actor_trains_this_step: + ray.get(actor_model.async_train(rollout_id, rollout_data_curr_ref, external_data=value_refs)) + else: + ray.get(value_refs) else: - ray.get(value_refs) - else: - ray.get(actor_model.async_train(rollout_id, rollout_data_curr_ref)) + actor_trains_this_step = True + ray.get(actor_model.async_train(rollout_id, rollout_data_curr_ref)) + train_succeeded = True + finally: + maybe_cleanup_dataproto_refs(args, rollout_data_curr_ref, suppress_errors=not train_succeeded) if should_run_periodic_action(rollout_id, args.save_interval, num_rollout_per_epoch, args.num_rollout): if (not args.use_critic) or rollout_id >= args.num_critic_only_steps: From 5e4072f166ea61d661f06c47b6f5aafb59ad0f2e Mon Sep 17 00:00:00 2001 From: Cruz Zhao Date: Thu, 25 Jun 2026 16:57:10 +0800 Subject: [PATCH 2/4] Align Mooncake rollout transfer with structured DataProto API Route remote rollout batches through MooncakeBundleTransfer put/get/cleanup DataProto helpers so slime matches the refactored PR2050 interface. Co-Authored-By: Claude Opus 4.6 --- slime/utils/remote_batch.py | 269 +++++++++--------------------------- 1 file changed, 69 insertions(+), 200 deletions(-) diff --git a/slime/utils/remote_batch.py b/slime/utils/remote_batch.py index 77d70a05ed..9a230dc41e 100644 --- a/slime/utils/remote_batch.py +++ b/slime/utils/remote_batch.py @@ -1,13 +1,10 @@ from __future__ import annotations -import ctypes import os import re from dataclasses import dataclass, field from typing import Any -import numpy as np - import torch from tensordict._td import TensorDict @@ -31,12 +28,7 @@ def create_mooncake_store(store_init_kwargs: dict[str, Any] | None = None) -> An kwargs = normalize_store_init_kwargs(store_init_kwargs or {}) setup_method = kwargs.get("setup_method", "setup") if setup_method == "setup_dummy": - try: - from mooncake.structured_object_store import InMemoryMooncakeStore - except ImportError: - pass - else: - return InMemoryMooncakeStore() + return InMemoryMooncakeStore() from mooncake.store import MooncakeDistributedStore # type: ignore @@ -84,185 +76,109 @@ def _env_store_config() -> dict[str, Any]: } -def _import_mooncake_helpers(): +class InMemoryMooncakeStore: + def __init__(self) -> None: + self.objects: dict[str, bytes] = {} + self.tensors: dict[str, torch.Tensor] = {} + + def put(self, key: str, value: Any) -> int: + self.objects[key] = bytes(value) + return 0 + + def get(self, key: str) -> bytes: + return self.objects[key] + + def remove(self, key: str, force: bool = False) -> int: + self.objects.pop(key, None) + self.tensors.pop(key, None) + return 0 + + def put_tensor(self, key: str, tensor: torch.Tensor) -> int: + self.tensors[key] = tensor.detach().cpu().clone() + return 0 + + def get_tensor(self, key: str) -> torch.Tensor: + return self.tensors[key].clone() + + +def _import_mooncake_transfer(): try: - from mooncake.structured_object_store import ( # type: ignore - RemoteTensorBatch, - TensorFieldRef, - normalize_dtype_name, - payload_to_buffer, - ) + from mooncake.structured_object_store import MooncakeBundleTransfer except ImportError as exc: - raise ImportError("Mooncake structured object helpers are required for mooncake_dataproto transfer") from exc - return RemoteTensorBatch, TensorFieldRef, normalize_dtype_name, payload_to_buffer + raise ImportError("Mooncake structured object DataProto helpers are required for mooncake_dataproto transfer") from exc + return MooncakeBundleTransfer @dataclass class MooncakeRemoteBatch: - remote: Any + ref: Any store_init_kwargs: dict[str, Any] = field(default_factory=dict) - keys_to_cleanup: tuple[str, ...] = () - use_reusable_buffer: bool = True + key_prefix: str = "slime-rollout" @classmethod - def from_tensors( + def from_dataproto( cls, - tensors: dict[str, torch.Tensor], + proto: Any, store: Any, prefix: str, store_init_kwargs: dict[str, Any] | None = None, use_hard_pin: bool = True, use_reusable_buffer: bool = True, ) -> MooncakeRemoteBatch: + del use_hard_pin, use_reusable_buffer _validate_prefix(prefix) - for name in tensors: + for name in (proto.batch or {}).keys(): _validate_field_name(name) - RemoteTensorBatch, TensorFieldRef, normalize_dtype_name, payload_to_buffer = _import_mooncake_helpers() - fields = {} - batch_size = None - config = _hard_pin_config(store) if use_hard_pin else None - written_keys = [] - try: - for name, tensor in tensors.items(): - cpu_tensor = tensor.detach().contiguous().cpu() - if batch_size is None: - batch_size = int(cpu_tensor.shape[0]) - elif int(cpu_tensor.shape[0]) != batch_size: - raise ValueError(f"tensor {name} batch size {cpu_tensor.shape[0]} != {batch_size}") - key = f"{prefix}/{name}" - buffer, owner, _ = payload_to_buffer(cpu_tensor) - ret = _pub_tensor_from(store, key, buffer, config) - if ret != 0: - raise RuntimeError(f"Mooncake put failed for {key} with retcode {ret}") - written_keys.append(key) - fields[name] = TensorFieldRef( - key=key, - shape=tuple(cpu_tensor.shape), - dtype=normalize_dtype_name(cpu_tensor.dtype), - data_offset=0, - ) - del owner - except Exception: - remove_mooncake_keys(store, written_keys) - raise - if batch_size is None: - raise ValueError("MooncakeRemoteBatch.from_tensors requires non-empty tensors") - return cls( - remote=RemoteTensorBatch(fields=fields, batch_size=batch_size), - store_init_kwargs=store_init_kwargs or {}, - keys_to_cleanup=tuple(written_keys), + transfer = _import_mooncake_transfer()(store, key_prefix=prefix) + ref = transfer.put_dataproto(proto, namespace="slime", partition="rollout", stage="batch") + return cls(ref=ref, store_init_kwargs=store_init_kwargs or {}, key_prefix=prefix) + + @classmethod + def from_tensors( + cls, + tensors: dict[str, torch.Tensor], + store: Any, + prefix: str, + store_init_kwargs: dict[str, Any] | None = None, + use_hard_pin: bool = True, + use_reusable_buffer: bool = True, + ) -> MooncakeRemoteBatch: + from slime.utils.rollout_dataproto import DataProto + + proto = DataProto.from_dict(tensors={key: tensor.detach().cpu().contiguous() for key, tensor in tensors.items()}) + return cls.from_dataproto( + proto, + store, + prefix, + store_init_kwargs=store_init_kwargs, + use_hard_pin=use_hard_pin, use_reusable_buffer=use_reusable_buffer, ) def __len__(self) -> int: - return len(self.remote) + return int(self.ref.batch_size) def keys(self) -> list[str]: - return self.remote.keys() + return [name for name, location in self.ref.field_index.items() if location.section == "batch"] def materialize(self, fields: list[str] | None = None) -> TensorDict: store = get_cached_mooncake_store(self.store_init_kwargs) + transfer = _import_mooncake_transfer()(store, key_prefix=self.key_prefix) try: - tensors = _materialize_remote_tensors(store, self.remote, fields, self.use_reusable_buffer) - return TensorDict(source=tensors, batch_size=(len(self),)) + result = transfer.get_dataproto(self.ref, batch_fields=fields, non_tensor_fields=[]) + return TensorDict(source=result["batch"], batch_size=(len(self),)) except Exception as exc: - requested = self.remote.keys() if fields is None else fields + requested = self.keys() if fields is None else fields raise RuntimeError(f"MooncakeRemoteBatch materialize failed for fields={list(requested)}") from exc def cleanup(self) -> None: - if not self.keys_to_cleanup: - return store = get_cached_mooncake_store(self.store_init_kwargs) - remove_mooncake_keys(store, list(self.keys_to_cleanup)) + transfer = _import_mooncake_transfer()(store, key_prefix=self.key_prefix) + transfer.cleanup_dataproto(self.ref) -def _materialize_remote_tensors( - store: Any, - remote: Any, - fields: list[str] | None, - use_reusable_buffer: bool, -) -> dict[str, torch.Tensor]: - if use_reusable_buffer: - return _materialize_remote_tensors_with_pool(store, remote, fields) - return _materialize_remote_tensors_without_pool(store, remote, fields) - - -def _materialize_remote_tensors_without_pool(store: Any, remote: Any, fields: list[str] | None) -> dict[str, torch.Tensor]: - requests = remote.read_requests(fields) - regions = {request.name: _WritableRegion(bytearray(request.output_nbytes())) for request in requests} - try: - for request in requests: - _materialize_request_into_region(store, request, regions[request.name], register=True) - return {request.name: _region_to_tensor(regions[request.name], request) for request in requests} - finally: - for region in regions.values(): - region.close() - - -def _materialize_remote_tensors_with_pool(store: Any, remote: Any, fields: list[str] | None) -> dict[str, torch.Tensor]: - from mooncake.structured_object_store import get_registered_buffer_pool - - pool = get_registered_buffer_pool(store) - requests = remote.read_requests(fields) - leases = {request.name: pool.buffer(request.output_nbytes()) for request in requests} - try: - for request in requests: - _materialize_request_into_region(store, request, leases[request.name], register=False) - return {request.name: _region_to_tensor(_lease_region(leases[request.name]), request) for request in requests} - finally: - for lease in leases.values(): - lease.release() - - -def _materialize_request_into_region(store: Any, request: Any, region: Any, register: bool) -> None: - from mooncake.structured_object_store import normalize_dtype_name - - required_size = request.output_nbytes() - if required_size == 0: - return - if register: - register_ret = store.register_buffer(region.ptr, region.size) - if register_ret != 0: - raise RuntimeError(f"register_buffer failed with retcode {register_ret}") - try: - if hasattr(store, "get_tensor_dim_selection_into"): - ret = store.get_tensor_dim_selection_into( - request.ref.key, - region.ptr, - required_size, - list(request.ref.shape), - normalize_dtype_name(request.ref.dtype), - request.dim, - request.store_selections(), - request.ref.data_offset, - ) - if ret < 0: - raise RuntimeError(f"get_tensor_dim_selection_into failed with retcode {ret}") - elif request.store_selections() or request.ref.data_offset: - raise RuntimeError("store.get_tensor_dim_selection_into is required for selected remote tensors") - else: - ret = store.get_into(request.ref.key, region.ptr, required_size) - if ret != required_size: - raise RuntimeError(f"get_into failed for {request.ref.key}: expected {required_size}, got {ret}") - finally: - if register: - unregister_ret = store.unregister_buffer(region.ptr) - if unregister_ret != 0: - raise RuntimeError(f"unregister_buffer failed with retcode {unregister_ret}") - - -def _lease_region(lease: Any) -> Any: - return lease.view_region() if hasattr(lease, "view_region") else lease - - -def _region_to_tensor(region: Any, request: Any) -> torch.Tensor: - dtype_name = str(request.ref.dtype).removeprefix("torch.").lower() - if not hasattr(torch, dtype_name): - raise ValueError(f"unsupported Mooncake tensor dtype: {request.ref.dtype!r}") - torch_dtype = getattr(torch, dtype_name) - shape = request.output_shape() - count = int(np.prod(shape, dtype=np.int64)) - return torch.frombuffer(region.buffer, dtype=torch_dtype, count=count).reshape(shape).clone() +# Backward-compatible name for callers that conceptually store a DataProto ref. +MooncakeRemoteDataProto = MooncakeRemoteBatch def _validate_prefix(prefix: str) -> None: @@ -273,50 +189,3 @@ def _validate_prefix(prefix: str) -> None: def _validate_field_name(name: str) -> None: if _FIELD_NAME_RE.fullmatch(name) is None: raise ValueError(f"invalid Mooncake tensor field name: {name!r}") - - -def _pub_tensor_from(store: Any, key: str, buffer: memoryview, config: Any) -> int: - region = _WritableRegion(buffer) - try: - register_ret = store.register_buffer(region.ptr, region.size) - if register_ret != 0: - raise RuntimeError(f"register_buffer failed for Mooncake put_from key={key} retcode={register_ret}") - try: - return store.put_from(key=key, buffer_ptr=region.ptr, size=region.size, config=config) - finally: - unregister_ret = store.unregister_buffer(region.ptr) - if unregister_ret != 0: - raise RuntimeError(f"unregister_buffer failed for Mooncake put_from key={key} retcode={unregister_ret}") - finally: - region.close() - - -class _WritableRegion: - def __init__(self, buffer: Any) -> None: - self.buffer = buffer - self.view = memoryview(buffer) - if self.view.readonly: - self.view.release() - raise ValueError("buffer must be writable") - if self.view.format != "B": - cast_view = self.view.cast("B") - self.view.release() - self.view = cast_view - self.c_buffer = (ctypes.c_ubyte * self.view.nbytes).from_buffer(self.view) - self.ptr = ctypes.addressof(self.c_buffer) - self.size = self.view.nbytes - - def close(self) -> None: - self.c_buffer = None - self.view.release() - - -def _hard_pin_config(store: Any) -> Any: - try: - from mooncake.store import ReplicateConfig # type: ignore - except ImportError as exc: - raise ImportError("Mooncake ReplicateConfig is required for hard-pin transfer") from exc - config = ReplicateConfig() - config.preferred_segments = [store.get_hostname()] - config.with_hard_pin = True - return config From a86be28bc15bbf22297edff3d2ea94a2d970b6bf Mon Sep 17 00:00:00 2001 From: Cruz Zhao Date: Thu, 25 Jun 2026 19:39:14 +0800 Subject: [PATCH 3/4] Align Mooncake rollout transfer with DataProto handles Use Mooncake structured DataProto handles directly for rollout dict transport so slime no longer carries a local DataProto/RemoteBatch wrapper. Co-Authored-By: Claude Opus 4.6 --- slime/ray/rollout.py | 11 +- slime/utils/data.py | 6 +- slime/utils/remote_batch.py | 137 +++----- slime/utils/rollout_dataproto.py | 446 ++++++++++--------------- tests/utils/test_dataproto_transfer.py | 188 +++++------ 5 files changed, 319 insertions(+), 469 deletions(-) diff --git a/slime/ray/rollout.py b/slime/ray/rollout.py index 078e5a15ca..0ea0979882 100644 --- a/slime/ray/rollout.py +++ b/slime/ray/rollout.py @@ -854,7 +854,16 @@ def _split_train_data_by_dp(self, data): from slime.utils.rollout_dataproto import split_rollout_data_by_dp_dataproto dynamic_global_batch_size = getattr(self, "_dynamic_global_batch_size", None) - return split_rollout_data_by_dp_dataproto(self.args, data, dp_size, partitions, dynamic_global_batch_size) + return split_rollout_data_by_dp_dataproto( + self.args, + data, + dp_size, + partitions, + dynamic_global_batch_size, + micro_batch_indices, + num_microbatches, + global_batch_sizes, + ) # Package per-rank rollout_data rollout_data_refs = [] diff --git a/slime/utils/data.py b/slime/utils/data.py index 8cfb622401..248070049a 100644 --- a/slime/utils/data.py +++ b/slime/utils/data.py @@ -291,11 +291,9 @@ def __len__(self): def process_rollout_data(args, rollout_data_ref, dp_rank, dp_size): assert len(rollout_data_ref) == dp_size if getattr(args, "transfer_backend", "ray") == "mooncake_dataproto": - from slime.utils.rollout_dataproto import DataProto, dataproto_to_rollout_data + from slime.utils.rollout_dataproto import materialize_dataproto_rollout_data - proto = rollout_data_ref[dp_rank] - assert isinstance(proto, DataProto), f"expected DataProto, got {type(proto)}" - rollout_data = dataproto_to_rollout_data(proto, preserve_remote_tensors=True) + rollout_data = materialize_dataproto_rollout_data(args, rollout_data_ref[dp_rank]) else: rollout_data = ray.get(rollout_data_ref[dp_rank].inner) diff --git a/slime/utils/remote_batch.py b/slime/utils/remote_batch.py index 9a230dc41e..bafdc7cc2c 100644 --- a/slime/utils/remote_batch.py +++ b/slime/utils/remote_batch.py @@ -1,16 +1,12 @@ from __future__ import annotations import os -import re -from dataclasses import dataclass, field -from typing import Any +from typing import Any, Mapping import torch -from tensordict._td import TensorDict ALLOWED_SETUP_METHODS = {"setup", "setup_dummy"} _STORE_CACHE: dict[tuple[tuple[str, str], ...], Any] = {} -_FIELD_NAME_RE = re.compile(r"^[A-Za-z0-9_.-]{1,128}$") def normalize_store_init_kwargs(store_init_kwargs: dict[str, Any] | None) -> dict[str, Any]: @@ -54,14 +50,37 @@ def get_cached_mooncake_store(store_init_kwargs: dict[str, Any] | None = None) - return _STORE_CACHE[cache_key] -def remove_mooncake_keys(store: Any, keys: list[str]) -> None: - errors = [] - for key in sorted(set(keys)): - ret = store.remove(key, True) - if ret != 0: - errors.append((key, ret)) - if errors: - raise RuntimeError(f"Mooncake key cleanup failed: {errors}") +def put_mooncake_dataproto( + data: Mapping[str, Any], + store: Any, + *, + key_prefix: str, + namespace: str = "slime", + partition: str = "rollout", +) -> dict[str, Any]: + transfer_cls, export_ref, _is_ref_handle = _import_mooncake_dataproto_helpers() + transfer = transfer_cls(store, key_prefix=key_prefix) + ref = transfer.put_dataproto(data, namespace=namespace, partition=partition, stage="rollout") + handle = export_ref(ref) + handle["slime_key_prefix"] = key_prefix + return handle + + +def get_mooncake_dataproto(handle: Mapping[str, Any], store: Any) -> dict[str, Any]: + transfer_cls, _export_ref, _is_ref_handle = _import_mooncake_dataproto_helpers() + transfer = transfer_cls(store, key_prefix=handle.get("slime_key_prefix", "")) + return transfer.get_dataproto(handle) + + +def cleanup_mooncake_dataproto(handle: Mapping[str, Any], store: Any) -> None: + transfer_cls, _export_ref, _is_ref_handle = _import_mooncake_dataproto_helpers() + transfer = transfer_cls(store, key_prefix=handle.get("slime_key_prefix", "")) + transfer.cleanup_dataproto(handle) + + +def is_mooncake_dataproto_handle(value: Any) -> bool: + _transfer_cls, _export_ref, is_ref_handle = _import_mooncake_dataproto_helpers() + return is_ref_handle(value) def _env_store_config() -> dict[str, Any]: @@ -101,91 +120,13 @@ def get_tensor(self, key: str) -> torch.Tensor: return self.tensors[key].clone() -def _import_mooncake_transfer(): +def _import_mooncake_dataproto_helpers(): try: - from mooncake.structured_object_store import MooncakeBundleTransfer + from mooncake.structured_object_store import ( + MooncakeBundleTransfer, + export_dataproto_ref, + is_dataproto_ref_handle, + ) except ImportError as exc: raise ImportError("Mooncake structured object DataProto helpers are required for mooncake_dataproto transfer") from exc - return MooncakeBundleTransfer - - -@dataclass -class MooncakeRemoteBatch: - ref: Any - store_init_kwargs: dict[str, Any] = field(default_factory=dict) - key_prefix: str = "slime-rollout" - - @classmethod - def from_dataproto( - cls, - proto: Any, - store: Any, - prefix: str, - store_init_kwargs: dict[str, Any] | None = None, - use_hard_pin: bool = True, - use_reusable_buffer: bool = True, - ) -> MooncakeRemoteBatch: - del use_hard_pin, use_reusable_buffer - _validate_prefix(prefix) - for name in (proto.batch or {}).keys(): - _validate_field_name(name) - transfer = _import_mooncake_transfer()(store, key_prefix=prefix) - ref = transfer.put_dataproto(proto, namespace="slime", partition="rollout", stage="batch") - return cls(ref=ref, store_init_kwargs=store_init_kwargs or {}, key_prefix=prefix) - - @classmethod - def from_tensors( - cls, - tensors: dict[str, torch.Tensor], - store: Any, - prefix: str, - store_init_kwargs: dict[str, Any] | None = None, - use_hard_pin: bool = True, - use_reusable_buffer: bool = True, - ) -> MooncakeRemoteBatch: - from slime.utils.rollout_dataproto import DataProto - - proto = DataProto.from_dict(tensors={key: tensor.detach().cpu().contiguous() for key, tensor in tensors.items()}) - return cls.from_dataproto( - proto, - store, - prefix, - store_init_kwargs=store_init_kwargs, - use_hard_pin=use_hard_pin, - use_reusable_buffer=use_reusable_buffer, - ) - - def __len__(self) -> int: - return int(self.ref.batch_size) - - def keys(self) -> list[str]: - return [name for name, location in self.ref.field_index.items() if location.section == "batch"] - - def materialize(self, fields: list[str] | None = None) -> TensorDict: - store = get_cached_mooncake_store(self.store_init_kwargs) - transfer = _import_mooncake_transfer()(store, key_prefix=self.key_prefix) - try: - result = transfer.get_dataproto(self.ref, batch_fields=fields, non_tensor_fields=[]) - return TensorDict(source=result["batch"], batch_size=(len(self),)) - except Exception as exc: - requested = self.keys() if fields is None else fields - raise RuntimeError(f"MooncakeRemoteBatch materialize failed for fields={list(requested)}") from exc - - def cleanup(self) -> None: - store = get_cached_mooncake_store(self.store_init_kwargs) - transfer = _import_mooncake_transfer()(store, key_prefix=self.key_prefix) - transfer.cleanup_dataproto(self.ref) - - -# Backward-compatible name for callers that conceptually store a DataProto ref. -MooncakeRemoteDataProto = MooncakeRemoteBatch - - -def _validate_prefix(prefix: str) -> None: - if not prefix or len(prefix) > 256 or ".." in prefix or any(ord(ch) < 32 for ch in prefix): - raise ValueError(f"invalid Mooncake key prefix: {prefix!r}") - - -def _validate_field_name(name: str) -> None: - if _FIELD_NAME_RE.fullmatch(name) is None: - raise ValueError(f"invalid Mooncake tensor field name: {name!r}") + return MooncakeBundleTransfer, export_dataproto_ref, is_dataproto_ref_handle diff --git a/slime/utils/rollout_dataproto.py b/slime/utils/rollout_dataproto.py index 28721569b5..c5c258afe7 100644 --- a/slime/utils/rollout_dataproto.py +++ b/slime/utils/rollout_dataproto.py @@ -1,23 +1,20 @@ from __future__ import annotations -import io import uuid -from dataclasses import dataclass, field -from typing import Any, Protocol +from typing import Any, Mapping import numpy as np import torch -from tensordict._td import TensorDict -from torch.nn.utils.rnn import pad_sequence from slime.utils.remote_batch import ( - MooncakeRemoteBatch, + cleanup_mooncake_dataproto, get_cached_mooncake_store, + get_mooncake_dataproto, + is_mooncake_dataproto_handle, normalize_store_init_kwargs, - remove_mooncake_keys, + put_mooncake_dataproto, ) -REMOTE_TENSOR_KEYS = ("tokens", "loss_masks") PARTITIONED_KEYS = ( "tokens", "multimodal_train_inputs", @@ -27,140 +24,32 @@ "loss_masks", "round_number", "sample_indices", + "rollout_ids", + "rollout_mask_sums", "rollout_log_probs", + "rollout_top_p_token_ids", + "rollout_top_p_token_offsets", "rollout_routed_experts", "prompt", "teacher_log_probs", ) -GLOBAL_KEYS = ("raw_reward", "total_lengths") - - -class RemoteBatch(Protocol): - def __len__(self) -> int: ... - - def keys(self) -> list[str]: ... - - def materialize(self, fields: list[str] | None = None) -> TensorDict: ... - - -@dataclass -class DataProto: - batch: TensorDict | None = None - non_tensor_batch: dict[str, np.ndarray] = field(default_factory=dict) - meta_info: dict = field(default_factory=dict) - remote_batch: RemoteBatch | None = None - - def __post_init__(self): - self.check_consistency() - - def __len__(self): - if self.batch is not None: - return self.batch.batch_size[0] - if self.non_tensor_batch: - return len(next(iter(self.non_tensor_batch.values()))) - if self.remote_batch is not None: - return len(self.remote_batch) - return 0 - - def __getstate__(self): - buffer = io.BytesIO() - batch = self.batch.contiguous().consolidate() if self.batch is not None else None - torch.save(batch, buffer) - return buffer.getvalue(), self.non_tensor_batch, self.meta_info, self.remote_batch - - def __setstate__(self, data): - batch_bytes, self.non_tensor_batch, self.meta_info, self.remote_batch = data - self.batch = torch.load(io.BytesIO(batch_bytes), weights_only=False, map_location="cpu") - - def check_consistency(self): - if self.batch is not None: - assert len(self.batch.batch_size) == 1, "only support num_batch_dims=1" - if self.non_tensor_batch: - batch_size = len(self) - for key, val in self.non_tensor_batch.items(): - assert isinstance(val, np.ndarray), f"non_tensor_batch[{key}] must be np.ndarray, got {type(val)}" - assert val.shape[0] == batch_size, f"key {key} length {val.shape[0]} != batch size {batch_size}" - if self.batch is not None and self.remote_batch is not None: - assert len(self.batch) == len(self.remote_batch), "local and remote batch sizes must match" - if self.non_tensor_batch and self.remote_batch is not None: - assert len(next(iter(self.non_tensor_batch.values()))) == len(self.remote_batch) - - @classmethod - def from_dict( - cls, - tensors: dict[str, torch.Tensor] | None = None, - non_tensors: dict[str, Any] | None = None, - meta_info: dict | None = None, - num_batch_dims: int = 1, - ): - assert num_batch_dims > 0, "num_batch_dims must be greater than zero" - if non_tensors is not None: - assert num_batch_dims == 1, "only support num_batch_dims=1 when non_tensors is not None" - tensors = tensors or {} - non_tensors = non_tensors or {} - meta_info = meta_info or {} - - batch_size = None - pivot_key = None - for key, tensor in tensors.items(): - current_batch = tuple(tensor.shape[:num_batch_dims]) - if batch_size is None: - batch_size = current_batch - pivot_key = key - else: - assert current_batch == batch_size, ( - f"Not all tensors have the same batch size. {pivot_key} has {batch_size}, " - f"{key} has {current_batch}" - ) - - normalized_non_tensors = {} - for key, val in non_tensors.items(): - if not isinstance(val, np.ndarray): - val = np.array(val, dtype=object) - normalized_non_tensors[key] = val - if batch_size is None: - batch_size = (val.shape[0],) - else: - assert val.shape[0] == batch_size[0], ( - f"non_tensor {key} length {val.shape[0]} != batch size {batch_size[0]}" - ) - - tensor_dict = TensorDict(source=tensors, batch_size=batch_size) if tensors else None - return cls(batch=tensor_dict, non_tensor_batch=normalized_non_tensors, meta_info=meta_info) - - @classmethod - def from_remote( - cls, - remote_batch: RemoteBatch, - batch: TensorDict | None = None, - non_tensors: dict[str, np.ndarray] | None = None, - meta_info: dict | None = None, - ): - return cls( - batch=batch, - non_tensor_batch=non_tensors or {}, - meta_info=meta_info or {}, - remote_batch=remote_batch, - ) - - def materialize_remote_batch(self): - if self.remote_batch is None: - return self - fetched = self.remote_batch.materialize() - if self.batch is not None: - assert self.batch.batch_size == fetched.batch_size, ( - f"TensorDict batch size mismatch: {self.batch.batch_size} != {fetched.batch_size}" - ) - for key, val in fetched.items(): - assert key not in self.batch.keys() or self.batch[key].equal(val), ( - f"{key} exists in both TensorDicts with different values" - ) - self.batch[key] = val - else: - self.batch = fetched - self.remote_batch = None - self.check_consistency() - return self +GLOBAL_KEYS = ( + "raw_reward", + "total_lengths", + "global_batch_sizes", + "num_microbatches", + "micro_batch_indices", + "dynamic_global_batch_size", +) +_ROLLOUT_DATA_TENSOR_DTYPES = { + "tokens": torch.long, + "loss_masks": torch.int, + "rollout_log_probs": torch.float32, + "rollout_top_p_token_ids": torch.int32, + "rollout_top_p_token_offsets": torch.int32, + "teacher_log_probs": torch.float32, + "rollout_routed_experts": None, +} def split_rollout_data_by_dp_dataproto( @@ -169,165 +58,194 @@ def split_rollout_data_by_dp_dataproto( dp_size: int, partitions: list, dynamic_global_batch_size: int | None = None, -) -> list[DataProto]: + micro_batch_indices: list | None = None, + num_microbatches: list | None = None, + global_batch_sizes: list | None = None, +) -> list[dict[str, Any]]: if len(partitions) != dp_size: raise ValueError(f"expected {dp_size} partitions, got {len(partitions)}") - return _split_rollout_data_by_dp_remote_batch(args, data, partitions, dynamic_global_batch_size) - - -def _split_rollout_data_by_dp_remote_batch( - args: Any, - data: dict, - partitions: list, - dynamic_global_batch_size: int | None = None, -) -> list[DataProto]: store_init_kwargs = _store_init_kwargs(args) store = get_cached_mooncake_store(store_init_kwargs) transfer_id = uuid.uuid4().hex refs = [] try: for dp_rank, partition in enumerate(partitions): - indices = [int(idx) for idx in partition] - shard = _slice_partitioned_data(data, indices) - shard["partition"] = np.asarray(indices, dtype=np.int64) - meta_info = {key: data[key] for key in GLOBAL_KEYS if key in data} - if dynamic_global_batch_size is not None: - meta_info["dynamic_global_batch_size"] = dynamic_global_batch_size - - remote_tensors, remote_lengths = _extract_remote_tensors(shard) - meta_info.update(remote_lengths) - remote_batch = None - if remote_tensors: - remote_batch = MooncakeRemoteBatch.from_tensors( - remote_tensors, - store, - prefix=f"slime-rollout/{transfer_id}/dp{dp_rank}", - store_init_kwargs=store_init_kwargs, - use_hard_pin=getattr(args, "mooncake_dataproto_hard_pin", True), - ) - _attach_cleanup_info(meta_info, remote_batch, store_init_kwargs) - - try: - proto = ( - DataProto.from_remote(remote_batch, non_tensors=_dict_to_non_tensors(shard), meta_info=meta_info) - if remote_batch is not None - else DataProto.from_dict(non_tensors=shard, meta_info=meta_info) - ) - except Exception: - if remote_batch is not None: - remote_batch.cleanup() - raise - refs.append(proto) + rollout_data = _build_rank_rollout_data( + data, + [int(idx) for idx in partition], + micro_batch_indices[dp_rank] if micro_batch_indices is not None else None, + num_microbatches, + global_batch_sizes, + dynamic_global_batch_size, + ) + _tensorize_rollout_data_for_training(rollout_data) + ref = put_mooncake_dataproto( + _rollout_data_to_dataproto_envelope(rollout_data), + store, + key_prefix=f"slime-rollout/{transfer_id}/dp{dp_rank}", + ) + refs.append(ref) except Exception: - cleanup_dataproto_refs(refs) + cleanup_dataproto_refs(refs, store_init_kwargs=store_init_kwargs) raise return refs -def maybe_cleanup_dataproto_refs(args: Any, refs: list[DataProto], suppress_errors: bool = False) -> None: +def materialize_dataproto_rollout_data(args: Any, ref: Mapping[str, Any]) -> dict: + store_init_kwargs = _store_init_kwargs(args) + store = get_cached_mooncake_store(store_init_kwargs) + if not is_mooncake_dataproto_handle(ref): + raise TypeError(f"expected Mooncake DataProto handle, got {type(ref).__name__}") + envelope = get_mooncake_dataproto(ref, store) + rollout_data = dict(envelope.get("batch", {})) + rollout_data.update( + { + key: _non_tensor_value_to_legacy(value) + for key, value in envelope.get("non_tensor_batch", {}).items() + } + ) + rollout_data.update(envelope.get("meta_info", {})) + _tensorize_rollout_data_for_training(rollout_data) + return rollout_data + + +def maybe_cleanup_dataproto_refs(args: Any, refs: list[Mapping[str, Any]], suppress_errors: bool = False) -> None: if getattr(args, "transfer_backend", "ray") != "mooncake_dataproto": return + store_init_kwargs = _store_init_kwargs(args) if not suppress_errors: - cleanup_dataproto_refs(refs) + cleanup_dataproto_refs(refs, store_init_kwargs=store_init_kwargs) return try: - cleanup_dataproto_refs(refs) + cleanup_dataproto_refs(refs, store_init_kwargs=store_init_kwargs) except Exception: return -def cleanup_dataproto_refs(refs: list[DataProto]) -> None: - keys = set() - store_init_kwargs = None - for proto in refs: - keys.update(proto.meta_info.get("mooncake_cleanup_keys", [])) - if store_init_kwargs is None and "mooncake_cleanup_store_kwargs" in proto.meta_info: - store_init_kwargs = dict(proto.meta_info["mooncake_cleanup_store_kwargs"]) - if not keys or store_init_kwargs is None: - return - store = get_cached_mooncake_store(store_init_kwargs) - remove_mooncake_keys(store, sorted(keys)) - - -def _attach_cleanup_info( - meta_info: dict, - remote_batch: MooncakeRemoteBatch, - store_init_kwargs: dict[str, Any], +def cleanup_dataproto_refs( + refs: list[Mapping[str, Any]], + store_init_kwargs: dict[str, Any] | None = None, ) -> None: - meta_info["mooncake_cleanup_keys"] = list(remote_batch.keys_to_cleanup) - meta_info["mooncake_cleanup_store_kwargs"] = dict(store_init_kwargs) - - -def dataproto_to_rollout_data(proto: DataProto, preserve_remote_tensors: bool = True) -> dict: - """Materialize a transfer DataProto into slime's legacy rollout dict.""" - if proto.remote_batch is not None: - proto.materialize_remote_batch() - rollout_data = {key: val.tolist() for key, val in proto.non_tensor_batch.items()} - rollout_data.update({key: val for key, val in proto.meta_info.items() if not key.startswith("mooncake_cleanup_")}) - if proto.batch is not None: - for key, tensor in proto.batch.items(): - lengths = proto.meta_info.get(f"{key}_lengths") - if preserve_remote_tensors and key in REMOTE_TENSOR_KEYS: - rollout_data[key] = _tensor_to_row_tensors(tensor, lengths) - else: - rollout_data[key] = _tensor_to_list(tensor, lengths) - return rollout_data + if not refs: + return + store = get_cached_mooncake_store(store_init_kwargs or {"setup_method": "setup"}) + for ref in refs: + cleanup_mooncake_dataproto(ref, store) -def _slice_partitioned_data(data: dict, indices: list[int]) -> dict: - shard = {} +def _build_rank_rollout_data( + data: dict, + partition: list[int], + micro_batch_indices: list | None, + num_microbatches: list | None, + global_batch_sizes: list | None, + dynamic_global_batch_size: int | None, +) -> dict: + rollout_data = {"partition": partition} for key in PARTITIONED_KEYS: if key in data: - shard[key] = [data[key][idx] for idx in indices] - return shard - - -def _extract_remote_tensors(shard: dict) -> tuple[dict[str, torch.Tensor], dict[str, list[int]]]: - tensors = {} - lengths = {} - for key in REMOTE_TENSOR_KEYS: - if key not in shard: - continue - values = shard.pop(key) - tensor, field_lengths = _list_to_padded_tensor(values, torch.long if key == "tokens" else torch.int) - tensors[key] = tensor - lengths[f"{key}_lengths"] = field_lengths - return tensors, lengths - - -def _list_to_padded_tensor(values: list, dtype: torch.dtype) -> tuple[torch.Tensor, list[int]]: - if not values: - return torch.empty((0, 0), dtype=dtype), [] - tensors = [torch.as_tensor(value, dtype=dtype).reshape(-1) for value in values] - lengths = [int(tensor.numel()) for tensor in tensors] - return pad_sequence(tensors, batch_first=True, padding_value=0), lengths - - -def _tensor_to_row_tensors(tensor: torch.Tensor, lengths: list[int] | None = None) -> list[torch.Tensor]: - if tensor.ndim == 2: - if lengths is not None: - return [tensor[idx, : int(length)] for idx, length in zip(range(tensor.shape[0]), lengths, strict=True)] - return [tensor[idx] for idx in range(tensor.shape[0])] - return [tensor[idx] for idx in range(tensor.shape[0])] - + rollout_data[key] = [data[key][idx] for idx in partition] + for key in ("raw_reward", "total_lengths"): + if key in data: + rollout_data[key] = data[key] + if global_batch_sizes is not None: + rollout_data["global_batch_sizes"] = global_batch_sizes + if num_microbatches is not None: + rollout_data["num_microbatches"] = num_microbatches + if micro_batch_indices is not None: + rollout_data["micro_batch_indices"] = micro_batch_indices + if dynamic_global_batch_size is not None: + rollout_data["dynamic_global_batch_size"] = dynamic_global_batch_size + return rollout_data -def _tensor_to_list(tensor: torch.Tensor, lengths: list[int] | None = None) -> list: - if tensor.ndim == 2: - rows = tensor.cpu().tolist() - if lengths is not None: - return [row[: int(length)] for row, length in zip(rows, lengths, strict=True)] - return rows - return tensor.cpu().numpy().tolist() +def _rollout_data_to_dataproto_envelope(rollout_data: dict) -> dict[str, dict[str, Any]]: + batch_size = len(rollout_data["partition"]) + batch = {} + non_tensor_batch = {} + meta_info = {} + for key, value in rollout_data.items(): + if key in GLOBAL_KEYS: + meta_info[key] = _json_safe_metadata(value) + elif isinstance(value, torch.Tensor) and _is_row_aligned(value, batch_size): + batch[key] = value + elif _is_row_aligned(value, batch_size): + non_tensor_batch[key] = _to_object_array(value) + else: + meta_info[key] = _json_safe_metadata(value) + return {"batch": batch, "non_tensor_batch": non_tensor_batch, "meta_info": meta_info} + + +def _is_row_aligned(value: Any, batch_size: int) -> bool: + if isinstance(value, torch.Tensor): + return value.ndim > 0 and value.shape[0] == batch_size + if isinstance(value, np.ndarray): + return value.ndim > 0 and value.shape[0] == batch_size + if isinstance(value, (list, tuple)): + return len(value) == batch_size + return False + + +def _to_object_array(value: Any) -> np.ndarray: + if isinstance(value, np.ndarray): + return value + if isinstance(value, torch.Tensor): + return np.asarray([item.detach().cpu() for item in value], dtype=object) + return np.asarray(value, dtype=object) + + +def _non_tensor_value_to_legacy(value: Any) -> Any: + if isinstance(value, np.ndarray): + return value.tolist() + return value + + +def _tensorize_rollout_data_for_training(rollout_data: dict[str, Any]) -> None: + for key, dtype in _ROLLOUT_DATA_TENSOR_DTYPES.items(): + if key in rollout_data: + rollout_data[key] = [_cpu_tensor(value, dtype=dtype) for value in rollout_data[key]] + + if "multimodal_train_inputs" in rollout_data: + rollout_data["multimodal_train_inputs"] = [ + ( + { + key: _cpu_tensor(value) if isinstance(value, (np.ndarray, torch.Tensor)) else value + for key, value in mm_dict.items() + } + if mm_dict is not None + else None + ) + for mm_dict in rollout_data["multimodal_train_inputs"] + ] -def _dict_to_non_tensors(data: dict) -> dict[str, np.ndarray]: - return {key: val if isinstance(val, np.ndarray) else np.asarray(val, dtype=_infer_numpy_dtype(val)) for key, val in data.items()} + if "rollout_mask_sums" in rollout_data: + rollout_data["rollout_mask_sums"] = _cpu_tensor( + rollout_data["rollout_mask_sums"], + dtype=torch.float32, + ) -def _infer_numpy_dtype(val: Any) -> Any: - if isinstance(val, list) and all(isinstance(item, (bool, int, float, np.number)) for item in val): - return None - return object +def _cpu_tensor(value: Any, dtype: torch.dtype | None = None) -> torch.Tensor: + if isinstance(value, np.ndarray) and not value.flags.writeable: + value = value.copy() + tensor = torch.as_tensor(value, dtype=dtype) if dtype is not None else torch.as_tensor(value) + return tensor.detach().cpu().contiguous() + + +def _json_safe_metadata(value: Any) -> Any: + if value is None or isinstance(value, (str, int, float, bool)): + return value + if isinstance(value, np.generic): + return value.item() + if isinstance(value, torch.Tensor): + return value.detach().cpu().tolist() + if isinstance(value, np.ndarray): + return value.tolist() + if isinstance(value, Mapping): + return {str(key): _json_safe_metadata(val) for key, val in value.items()} + if isinstance(value, (list, tuple)): + return [_json_safe_metadata(item) for item in value] + return value def _store_init_kwargs(args: Any) -> dict[str, Any]: diff --git a/tests/utils/test_dataproto_transfer.py b/tests/utils/test_dataproto_transfer.py index 6a8914a2da..ce6fba823c 100644 --- a/tests/utils/test_dataproto_transfer.py +++ b/tests/utils/test_dataproto_transfer.py @@ -1,119 +1,116 @@ +import pickle import sys import types import numpy as np import torch +from slime.utils import remote_batch +from slime.utils.remote_batch import create_mooncake_store, normalize_store_init_kwargs +from slime.utils.rollout_dataproto import ( + cleanup_dataproto_refs, + materialize_dataproto_rollout_data, + split_rollout_data_by_dp_dataproto, +) -class TensorDict(dict): - def __init__(self, source=None, batch_size=None, device=None): - super().__init__(source or {}) - self.batch_size = torch.Size(batch_size or []) - self.device = device - def __len__(self): - return self.batch_size[0] if len(self.batch_size) > 0 else dict.__len__(self) +class FakeRef: + def __init__(self, manifest_key, batch_size): + self.manifest_key = manifest_key + self.batch_size = batch_size - def clone(self): - return TensorDict({key: val.clone() for key, val in self.items()}, self.batch_size, self.device) - def select(self, *keys): - return TensorDict({key: self[key] for key in keys}, self.batch_size, self.device) +class FakeMooncakeBundleTransfer: + def __init__(self, store, key_prefix=""): + self.store = store + self.key_prefix = key_prefix or "default" + def put_dataproto(self, data, namespace="default", partition="default", stage="default"): + key = f"{self.key_prefix}/{namespace}/{partition}/{stage}/manifest" + self.store.put(key, pickle.dumps(data)) + if data["non_tensor_batch"]: + batch_size = len(next(iter(data["non_tensor_batch"].values()))) + else: + batch_size = len(next(iter(data["batch"].values()))) + return FakeRef(key, batch_size) -tensordict_module = types.ModuleType("tensordict") -tensordict_td_module = types.ModuleType("tensordict._td") -tensordict_module.TensorDict = TensorDict -tensordict_td_module.TensorDict = TensorDict -sys.modules.setdefault("tensordict", tensordict_module) -sys.modules.setdefault("tensordict._td", tensordict_td_module) + def get_dataproto(self, handle): + return pickle.loads(self.store.get(handle["manifest_key"])) -from slime.utils.remote_batch import MooncakeRemoteBatch, create_mooncake_store, normalize_store_init_kwargs -from slime.utils.rollout_dataproto import DataProto, dataproto_to_rollout_data, split_rollout_data_by_dp_dataproto + def cleanup_dataproto(self, handle): + self.store.remove(handle["manifest_key"], True) -class FakeRemoteBatch: - def __init__(self, tensors, indices=None): - self.tensors = tensors - self.indices = list(range(len(next(iter(tensors.values()))))) if indices is None else indices +def fake_export_dataproto_ref(ref): + return { + "type": "mooncake_dataproto_ref", + "version": 1, + "kind": "bundle_stages", + "manifest_key": ref.manifest_key, + "batch_size": ref.batch_size, + } - def __len__(self): - return len(self.indices) - @property - def batch_size(self): - return torch.Size([len(self.indices)]) +def fake_is_dataproto_ref_handle(value): + return isinstance(value, dict) and value.get("type") == "mooncake_dataproto_ref" - def keys(self): - return list(self.tensors.keys()) - def materialize(self, fields=None): - selected = self.keys() if fields is None else fields - return TensorDict({key: self.tensors[key][self.indices] for key in selected}, batch_size=(len(self),)) - - -def test_dataproto_remote_materialize(): - remote = FakeRemoteBatch({"tokens": torch.arange(12).reshape(4, 3), "loss_masks": torch.ones(4, 3, dtype=torch.int32)}) - proto = DataProto.from_remote(remote, non_tensors={"response_lengths": np.asarray([1, 2, 3, 4])}) - - proto.materialize_remote_batch() - - assert proto.remote_batch is None - assert proto.batch["tokens"].tolist() == [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10, 11]] - assert proto.non_tensor_batch["response_lengths"].tolist() == [1, 2, 3, 4] - - -def test_dataproto_to_rollout_data_preserves_remote_tensor_rows(): - remote = FakeRemoteBatch( - { - "tokens": torch.tensor([[1, 2, 0], [3, 4, 5]]), - "loss_masks": torch.tensor([[1, 1, 0], [1, 1, 1]], dtype=torch.int), - } - ) - proto = DataProto.from_remote( - remote, - non_tensors={"partition": np.asarray([0, 1])}, - meta_info={ - "total_lengths": [2, 3], - "tokens_lengths": [2, 3], - "loss_masks_lengths": [2, 3], - }, - ) - - rollout_data = dataproto_to_rollout_data(proto) - - assert all(isinstance(row, torch.Tensor) for row in rollout_data["tokens"]) - assert all(isinstance(row, torch.Tensor) for row in rollout_data["loss_masks"]) - assert [row.tolist() for row in rollout_data["tokens"]] == [[1, 2], [3, 4, 5]] - assert [row.tolist() for row in rollout_data["loss_masks"]] == [[1, 1], [1, 1, 1]] - assert "_remote_tensor_owners" not in rollout_data - assert rollout_data["partition"] == [0, 1] - assert rollout_data["total_lengths"] == [2, 3] +def install_fake_mooncake(monkeypatch): + mooncake_module = types.ModuleType("mooncake") + structured_module = types.ModuleType("mooncake.structured_object_store") + structured_module.MooncakeBundleTransfer = FakeMooncakeBundleTransfer + structured_module.export_dataproto_ref = fake_export_dataproto_ref + structured_module.is_dataproto_ref_handle = fake_is_dataproto_ref_handle + monkeypatch.setitem(sys.modules, "mooncake", mooncake_module) + monkeypatch.setitem(sys.modules, "mooncake.structured_object_store", structured_module) + remote_batch._STORE_CACHE.clear() -def test_dataproto_to_rollout_data_legacy_tensor_list_fallback(): - remote = FakeRemoteBatch({"tokens": torch.tensor([[1, 2, 0], [3, 4, 5]])}) - proto = DataProto.from_remote( - remote, - non_tensors={"partition": np.asarray([0, 1])}, - meta_info={"total_lengths": [2, 3], "tokens_lengths": [2, 3]}, +def test_rollout_dict_roundtrips_through_mooncake_handle(monkeypatch): + install_fake_mooncake(monkeypatch) + args = types.SimpleNamespace(mooncake_dataproto_store_init_kwargs={"setup_method": "setup_dummy"}) + data = { + "tokens": [[1, 2], [3, 4, 5], [6]], + "loss_masks": [[1, 1], [1, 1, 1], [1]], + "response_lengths": [2, 3, 1], + "rewards": [1.0, 2.0, 3.0], + "rollout_ids": [10, 11, 12], + "rollout_mask_sums": [2.0, 3.0, 1.0], + "total_lengths": [2, 3, 1], + "raw_reward": [1.0, 2.0, 3.0], + } + + refs = split_rollout_data_by_dp_dataproto( + args, + data, + 2, + [[0, 2], [1]], + micro_batch_indices=[[[0], [1]], [[0]]], + num_microbatches=[2, 1], + global_batch_sizes=[2, 1], ) - rollout_data = dataproto_to_rollout_data(proto, preserve_remote_tensors=False) - - assert rollout_data["tokens"] == [[1, 2], [3, 4, 5]] - assert "_remote_tensor_owners" not in rollout_data - + assert all(ref["type"] == "mooncake_dataproto_ref" for ref in refs) + rollout_data = materialize_dataproto_rollout_data(args, refs[0]) -def test_dataproto_to_rollout_data_keeps_non_remote_tensors_legacy(): - proto = DataProto.from_dict(tensors={"other": torch.tensor([[1, 2], [3, 4]])}) + assert rollout_data["partition"] == [0, 2] + assert [row.tolist() for row in rollout_data["tokens"]] == [[1, 2], [6]] + assert [row.tolist() for row in rollout_data["loss_masks"]] == [[1, 1], [1]] + assert rollout_data["response_lengths"] == [2, 1] + assert rollout_data["rollout_ids"] == [10, 12] + assert rollout_data["rollout_mask_sums"].tolist() == [2.0, 1.0] + assert rollout_data["total_lengths"] == [2, 3, 1] + assert rollout_data["micro_batch_indices"] == [[0], [1]] + assert rollout_data["num_microbatches"] == [2, 1] + assert rollout_data["global_batch_sizes"] == [2, 1] - rollout_data = dataproto_to_rollout_data(proto, preserve_remote_tensors=True) + store = remote_batch.get_cached_mooncake_store({"setup_method": "setup_dummy"}) + cleanup_dataproto_refs(refs, store_init_kwargs={"setup_method": "setup_dummy"}) + assert store.objects == {} - assert rollout_data["other"] == [[1, 2], [3, 4]] - -def test_rollout_transfer_rejects_partition_mismatch(): +def test_rollout_transfer_rejects_partition_mismatch(monkeypatch): + install_fake_mooncake(monkeypatch) args = types.SimpleNamespace(mooncake_dataproto_store_init_kwargs={"setup_method": "setup_dummy"}) try: split_rollout_data_by_dp_dataproto(args, {}, 2, [[]]) @@ -148,16 +145,3 @@ def setup(self): monkeypatch.setitem(sys.modules, "mooncake.store", store_module) assert isinstance(create_mooncake_store(), Store) - - -def test_rejects_invalid_remote_field_name(): - class Store: - def get_hostname(self): - return "localhost" - - try: - MooncakeRemoteBatch.from_tensors({"../tokens": torch.ones(1, 1, dtype=torch.int64)}, Store(), "prefix") - except ValueError as exc: - assert "invalid Mooncake tensor field name" in str(exc) - else: - raise AssertionError("invalid tensor field name should be rejected") From 175dad6dba261c16d52079cb03eec473dbc6636f Mon Sep 17 00:00:00 2001 From: Cruz Zhao Date: Thu, 25 Jun 2026 20:04:16 +0800 Subject: [PATCH 4/4] Rename Mooncake rollout transfer backend Expose the rollout transfer backend as mooncake while keeping mooncake_dataproto as a compatibility alias for existing scripts. Co-Authored-By: Claude Opus 4.6 --- slime/ray/rollout.py | 2 +- slime/utils/arguments.py | 11 +++++++---- slime/utils/data.py | 2 +- slime/utils/remote_batch.py | 2 +- slime/utils/rollout_dataproto.py | 2 +- 5 files changed, 11 insertions(+), 8 deletions(-) diff --git a/slime/ray/rollout.py b/slime/ray/rollout.py index 0ea0979882..225ec69428 100644 --- a/slime/ray/rollout.py +++ b/slime/ray/rollout.py @@ -850,7 +850,7 @@ def _split_train_data_by_dp(self, data): rollout_indices=data["rollout_ids"], ) - if getattr(self.args, "transfer_backend", "ray") == "mooncake_dataproto": + if getattr(self.args, "transfer_backend", "ray") in {"mooncake", "mooncake_dataproto"}: from slime.utils.rollout_dataproto import split_rollout_data_by_dp_dataproto dynamic_global_batch_size = getattr(self, "_dynamic_global_batch_size", None) diff --git a/slime/utils/arguments.py b/slime/utils/arguments.py index 8ff28eded1..b8867dc021 100644 --- a/slime/utils/arguments.py +++ b/slime/utils/arguments.py @@ -404,21 +404,21 @@ def add_rollout_arguments(parser): ) parser.add_argument( "--transfer-backend", - choices=["ray", "mooncake_dataproto"], + choices=["ray", "mooncake", "mooncake_dataproto"], default="ray", - help="Rollout data transfer backend. Keep ray as the default; mooncake_dataproto is experimental.", + help="Rollout data transfer backend. Keep ray as the default; mooncake is experimental.", ) parser.add_argument( "--mooncake-dataproto-hard-pin", action=argparse.BooleanOptionalAction, default=True, - help="Hard-pin Mooncake rollout tensors to the producer segment for mooncake_dataproto transfer.", + help="Hard-pin Mooncake rollout tensors to the producer segment for Mooncake transfer.", ) parser.add_argument( "--mooncake-dataproto-store-init-kwargs", type=json.loads, default=None, - help="JSON kwargs used to initialize MooncakeDistributedStore for mooncake_dataproto transfer.", + help="JSON kwargs used to initialize MooncakeDistributedStore for Mooncake transfer.", ) # sampling @@ -1767,6 +1767,9 @@ def slime_validate_args(args): args.eval_datasets = _resolve_eval_datasets(args) if getattr(args, "transfer_backend", "ray") == "mooncake_dataproto": + args.transfer_backend = "mooncake" + + if getattr(args, "transfer_backend", "ray") == "mooncake": from slime.utils.remote_batch import normalize_store_init_kwargs args.mooncake_dataproto_store_init_kwargs = normalize_store_init_kwargs( diff --git a/slime/utils/data.py b/slime/utils/data.py index 248070049a..1d2b64a339 100644 --- a/slime/utils/data.py +++ b/slime/utils/data.py @@ -290,7 +290,7 @@ def __len__(self): def process_rollout_data(args, rollout_data_ref, dp_rank, dp_size): assert len(rollout_data_ref) == dp_size - if getattr(args, "transfer_backend", "ray") == "mooncake_dataproto": + if getattr(args, "transfer_backend", "ray") in {"mooncake", "mooncake_dataproto"}: from slime.utils.rollout_dataproto import materialize_dataproto_rollout_data rollout_data = materialize_dataproto_rollout_data(args, rollout_data_ref[dp_rank]) diff --git a/slime/utils/remote_batch.py b/slime/utils/remote_batch.py index bafdc7cc2c..2ab36ca784 100644 --- a/slime/utils/remote_batch.py +++ b/slime/utils/remote_batch.py @@ -11,7 +11,7 @@ def normalize_store_init_kwargs(store_init_kwargs: dict[str, Any] | None) -> dict[str, Any]: if store_init_kwargs is None: - raise ValueError("mooncake_dataproto requires --mooncake-dataproto-store-init-kwargs") + raise ValueError("mooncake transfer requires --mooncake-dataproto-store-init-kwargs") if not store_init_kwargs: return {"setup_method": "setup"} setup_method = store_init_kwargs.get("setup_method", "setup") diff --git a/slime/utils/rollout_dataproto.py b/slime/utils/rollout_dataproto.py index c5c258afe7..a527a4b731 100644 --- a/slime/utils/rollout_dataproto.py +++ b/slime/utils/rollout_dataproto.py @@ -110,7 +110,7 @@ def materialize_dataproto_rollout_data(args: Any, ref: Mapping[str, Any]) -> dic def maybe_cleanup_dataproto_refs(args: Any, refs: list[Mapping[str, Any]], suppress_errors: bool = False) -> None: - if getattr(args, "transfer_backend", "ray") != "mooncake_dataproto": + if getattr(args, "transfer_backend", "ray") not in {"mooncake", "mooncake_dataproto"}: return store_init_kwargs = _store_init_kwargs(args) if not suppress_errors: