[checkpoint_engine][sglang] feat: delta weight sync for disaggregated rollout#6794
[checkpoint_engine][sglang] feat: delta weight sync for disaggregated rollout#6794ChangyiYang wants to merge 5 commits into
Conversation
There was a problem hiding this comment.
Code Review
This pull request introduces an experimental delta weight synchronization mechanism for SGLang rollouts, enabling the transmission of only changed weight elements via NCCL or disk transports. The implementation includes pinned-CPU snapshotting, gap-delta encoding, and integration with SGLang's update APIs. The review feedback identifies critical issues that must be addressed: a missing record_stream call during async copies on side streams which risks silent data corruption, race conditions and incomplete file dispatching when using disk transport with TP > 1, high memory overhead from concatenating unmasked parameter values during encoding, and potential directory creation race conditions across distributed ranks.
Important
The consumer version of Gemini Code Assist on GitHub is being sunset. Starting June 18, 2026, new organization installations will be blocked, and all code review activity will officially cease on July 17, 2026.
For more details on the timeline and next steps, please review the Help Documentation.
| if self.d2h_stream is None: | ||
| self.d2h_stream = torch.cuda.Stream() | ||
| event = torch.cuda.current_stream().record_event() | ||
| with torch.cuda.stream(self.d2h_stream): | ||
| self.d2h_stream.wait_event(event) | ||
| for name, tensor in named_tensors: | ||
| self._allocate(name, tensor) | ||
| self.snapshot[name].copy_(tensor.detach(), non_blocking=True) | ||
| self._snapshot_dirty = True |
There was a problem hiding this comment.
When copying tensors asynchronously on a non-default CUDA stream (self.d2h_stream), you must call tensor.record_stream(self.d2h_stream) to prevent PyTorch's caching allocator from reclaiming or reusing the tensor's memory before the copy operation completes. Since the generator may yield new tensors or FSDP may free/reuse the gathered parameter buffers in subsequent steps, omitting record_stream can lead to silent data corruption in the snapshot.
| if self.d2h_stream is None: | |
| self.d2h_stream = torch.cuda.Stream() | |
| event = torch.cuda.current_stream().record_event() | |
| with torch.cuda.stream(self.d2h_stream): | |
| self.d2h_stream.wait_event(event) | |
| for name, tensor in named_tensors: | |
| self._allocate(name, tensor) | |
| self.snapshot[name].copy_(tensor.detach(), non_blocking=True) | |
| self._snapshot_dirty = True | |
| if self.d2h_stream is None: | |
| self.d2h_stream = torch.cuda.Stream() | |
| event = torch.cuda.current_stream().record_event() | |
| with torch.cuda.stream(self.d2h_stream): | |
| self.d2h_stream.wait_event(event) | |
| for name, tensor in named_tensors: | |
| self._allocate(name, tensor) | |
| self.snapshot[name].copy_(tensor.detach(), non_blocking=True) | |
| tensor.record_stream(self.d2h_stream) | |
| self._snapshot_dirty = True |
| if transport == "disk" and pending_files and self._is_server_tp_leader(): | ||
| await dispatch_disk_files( | ||
| self._engine, | ||
| out_dir=version_dir, | ||
| files=pending_files, | ||
| weight_version=self._delta_version, | ||
| ) | ||
| if not getattr(ce, "delta_keep_files", False): | ||
| import shutil | ||
| shutil.rmtree(version_dir, ignore_errors=True) | ||
|
|
There was a problem hiding this comment.
Using transport='disk' introduces two critical issues on TP > 1:
- Race Condition: The leader rank might call
dispatch_disk_filesand delete theversion_dirbefore other ranks have finished writing their files, or before SGLang has finished reading them. Atorch.distributed.barrier()is required to synchronize all ranks. - Incomplete File List:
pending_filesis local to each rank, so the leader only dispatches its own files. SGLang ranks > 0 will fail to load their corresponding delta files. The leader should reconstruct the complete list of files for all ranks deterministically based on the TP size andflushes_emitted.
if transport == "disk":
if torch.distributed.is_initialized():
torch.distributed.barrier()
if self._is_server_tp_leader() and flushes_emitted > 0:
tp_size = (
self.device_mesh["infer_tp"].size()
if "infer_tp" in self.device_mesh.mesh_dim_names
else 1
)
all_files = [
f"rank{r:04d}_flush{f:06d}.safetensors"
for r in range(tp_size)
for f in range(flushes_emitted)
]
await dispatch_disk_files(
self._engine,
out_dir=version_dir,
files=all_files,
weight_version=self._delta_version,
)
if not getattr(ce, "delta_keep_files", False):
import shutil
shutil.rmtree(version_dir, ignore_errors=True)| shape=list(d.values.shape), | ||
| pos_start=pos_byte_off, | ||
| pos_end=pos_byte_off + nnz * 4, | ||
| pos_width=4, | ||
| val_start=val_off, | ||
| val_end=val_off + nnz, | ||
| ) | ||
| ) | ||
| pos_byte_off += nnz * 4 | ||
| val_off += nnz | ||
| prev_b = b | ||
| prev_param_start = cum[i] | ||
| if not params: | ||
| return EncodedChunk.empty() | ||
| positions = torch.cat(pos_pieces, dim=0) | ||
| values = torch.cat(val_pieces, dim=0) | ||
| return EncodedChunk( | ||
| pos_bytes=positions.cpu().numpy().tobytes(), | ||
| val_tensor=values, | ||
| params=params, | ||
| nnz=val_off, | ||
| ) | ||
|
|
||
|
|
||
| def _encode_deltas(diffs: list[ParamDiff]) -> EncodedChunk: | ||
| """Gap-encode sorted positions. | ||
|
|
||
| Store ``idx[k] - idx[k-1] - 1`` with ``idx[-1] := -1`` so the first delta | ||
| equals the first index. Each parameter downcasts to uint16 if the max gap | ||
| fits, else uint32. At ~2% density on bf16 weights the typical max gap is | ||
| ~300, so uint16 normally suffices; the uint32 fallback covers pathological | ||
| inputs without correctness risk. The receiver inverts via | ||
| ``idx = cumsum(delta + 1) - 1``. | ||
| """ | ||
| if not diffs: | ||
| return EncodedChunk.empty() | ||
| big_val, bounds, big_idx, cum = _sparse_boundaries(diffs) | ||
|
|
||
| kept: list[tuple[ParamDiff, int]] = [] | ||
| per_param_deltas: list[torch.Tensor] = [] | ||
| val_pieces: list[torch.Tensor] = [] | ||
| prev_b = 0 | ||
| prev_param_start = 0 | ||
| for i, d in enumerate(diffs): | ||
| b = bounds[i] | ||
| nnz = b - prev_b | ||
| if nnz > 0: | ||
| local_idx = big_idx[prev_b:b] - prev_param_start # int64, sorted | ||
| prev = torch.cat( | ||
| [ | ||
| torch.tensor( | ||
| [-1], dtype=local_idx.dtype, device=local_idx.device | ||
| ), | ||
| local_idx[:-1], | ||
| ] | ||
| ) | ||
| per_param_deltas.append(local_idx - prev - 1) | ||
| val_pieces.append(big_val[prev_b:b]) | ||
| kept.append((d, nnz)) | ||
| prev_b = b | ||
| prev_param_start = cum[i] | ||
|
|
||
| if not kept: | ||
| return EncodedChunk.empty() | ||
|
|
||
| max_per_param = ( | ||
| torch.stack([d.max() for d in per_param_deltas]).cpu().tolist() | ||
| ) | ||
| pos_byte_pieces: list[bytes] = [] | ||
| pos_byte_off = val_off = 0 | ||
| params: list[DeltaParam] = [] | ||
| for (d, nnz), deltas, max_d in zip( | ||
| kept, per_param_deltas, max_per_param, strict=True | ||
| ): | ||
| width = 2 if int(max_d) <= 65535 else 4 | ||
| np_dtype = np.uint16 if width == 2 else np.uint32 | ||
| b_chunk = deltas.cpu().numpy().astype(np_dtype, copy=False).tobytes() | ||
| pos_byte_pieces.append(b_chunk) | ||
| params.append( | ||
| DeltaParam( | ||
| name=d.name, | ||
| dtype=str(d.values.dtype).replace("torch.", ""), | ||
| shape=list(d.values.shape), | ||
| pos_start=pos_byte_off, | ||
| pos_end=pos_byte_off + len(b_chunk), | ||
| pos_width=width, | ||
| val_start=val_off, | ||
| val_end=val_off + nnz, | ||
| ) | ||
| ) | ||
| pos_byte_off += len(b_chunk) | ||
| val_off += nnz | ||
|
|
||
| values = torch.cat(val_pieces, dim=0) | ||
| return EncodedChunk( | ||
| pos_bytes=b"".join(pos_byte_pieces), | ||
| val_tensor=values, | ||
| params=params, | ||
| nnz=val_off, | ||
| ) |
There was a problem hiding this comment.
Concatenating all parameter values in _sparse_boundaries via big_values = torch.cat(...) allocates a massive temporary tensor containing all parameter elements (even unchanged ones) for every chunk. This introduces significant GPU memory overhead and compute latency. By indexing into each d.values individually using the computed local_idx on the GPU, we can completely avoid allocating big_values and only allocate memory for the sparse set of changed values.
def _sparse_boundaries(
diffs: list[ParamDiff],
) -> tuple[list[int], torch.Tensor, list[int]]:
"""One concat -> one nonzero -> one searchsorted -> one tolist().
Collapses per-parameter host syncs to a single one per chunk.
"""
device = diffs[0].values.device
sizes = [d.values.numel() for d in diffs]
cum = list(itertools.accumulate(sizes))
cum_t = torch.tensor(cum, dtype=torch.int64, device=device)
big_mask = torch.cat([d.mask.contiguous().view(-1) for d in diffs], dim=0)
big_idx = big_mask.nonzero(as_tuple=False).view(-1)
bounds = torch.searchsorted(big_idx, cum_t).tolist()
return bounds, big_idx, cum
def _encode_indices(diffs: list[ParamDiff]) -> EncodedChunk:
if not diffs:
return EncodedChunk.empty()
bounds, big_idx, cum = _sparse_boundaries(diffs)
pos_pieces: list[torch.Tensor] = []
val_pieces: list[torch.Tensor] = []
params: list[DeltaParam] = []
pos_byte_off = val_off = 0
prev_b = 0
prev_param_start = 0
for i, d in enumerate(diffs):
b = bounds[i]
nnz = b - prev_b
if nnz > 0:
local_idx = big_idx[prev_b:b] - prev_param_start
pos_pieces.append(local_idx.to(torch.int32))
val_pieces.append(d.values.contiguous().view(-1)[local_idx])
params.append(
DeltaParam(
name=d.name,
dtype=str(d.values.dtype).replace("torch.", ""),
shape=list(d.values.shape),
pos_start=pos_byte_off,
pos_end=pos_byte_off + nnz * 4,
pos_width=4,
val_start=val_off,
val_end=val_off + nnz,
)
)
pos_byte_off += nnz * 4
val_off += nnz
prev_b = b
prev_param_start = cum[i]
if not params:
return EncodedChunk.empty()
positions = torch.cat(pos_pieces, dim=0)
values = torch.cat(val_pieces, dim=0)
return EncodedChunk(
pos_bytes=positions.cpu().numpy().tobytes(),
val_tensor=values,
params=params,
nnz=val_off,
)
def _encode_deltas(diffs: list[ParamDiff]) -> EncodedChunk:
"""Gap-encode sorted positions.
Store ``idx[k] - idx[k-1] - 1`` with ``idx[-1] := -1`` so the first delta
equals the first index. Each parameter downcasts to uint16 if the max gap
fits, else uint32. At ~2% density on bf16 weights the typical max gap is
~300, so uint16 normally suffices; the uint32 fallback covers pathological
inputs without correctness risk. The receiver inverts via
``idx = cumsum(delta + 1) - 1``.
"""
if not diffs:
return EncodedChunk.empty()
bounds, big_idx, cum = _sparse_boundaries(diffs)
kept: list[tuple[ParamDiff, int, torch.Tensor]] = []
per_param_deltas: list[torch.Tensor] = []
val_pieces: list[torch.Tensor] = []
prev_b = 0
prev_param_start = 0
for i, d in enumerate(diffs):
b = bounds[i]
nnz = b - prev_b
if nnz > 0:
local_idx = big_idx[prev_b:b] - prev_param_start # int64, sorted
prev = torch.cat(
[
torch.tensor(
[-1], dtype=local_idx.dtype, device=local_idx.device
),
local_idx[:-1],
]
)
per_param_deltas.append(local_idx - prev - 1)
val_pieces.append(d.values.contiguous().view(-1)[local_idx])
kept.append((d, nnz, per_param_deltas[-1]))
prev_b = b
prev_param_start = cum[i]
if not kept:
return EncodedChunk.empty()
max_per_param = (
torch.stack([item[2].max() for item in kept]).cpu().tolist()
)
pos_byte_pieces: list[bytes] = []
pos_byte_off = val_off = 0
params: list[DeltaParam] = []
for (d, nnz, deltas), max_d in zip(
kept, max_per_param, strict=True
):
width = 2 if int(max_d) <= 65535 else 4
np_dtype = np.uint16 if width == 2 else np.uint32
b_chunk = deltas.cpu().numpy().astype(np_dtype, copy=False).tobytes()
pos_byte_pieces.append(b_chunk)
params.append(
DeltaParam(
name=d.name,
dtype=str(d.values.dtype).replace("torch.", ""),
shape=list(d.values.shape),
pos_start=pos_byte_off,
pos_end=pos_byte_off + len(b_chunk),
pos_width=width,
val_start=val_off,
val_end=val_off + nnz,
)
)
pos_byte_off += len(b_chunk)
val_off += nnz
values = torch.cat(val_pieces, dim=0)
return EncodedChunk(
pos_bytes=b"".join(pos_byte_pieces),
val_tensor=values,
params=params,
nnz=val_off,
)| from safetensors.torch import save as st_save_bytes # local import | ||
|
|
||
| metadata = { | ||
| "encoding": flush.encoding, |
There was a problem hiding this comment.
In a distributed environment, non-leader ranks might attempt to write their flush files to out_dir before the leader rank has finished creating it, leading to a FileNotFoundError. Adding os.makedirs(out_dir, exist_ok=True) inside write_flush_to_disk ensures that every rank safely creates the directory if it does not exist yet.
| from safetensors.torch import save as st_save_bytes # local import | |
| metadata = { | |
| "encoding": flush.encoding, | |
| from safetensors.torch import save as st_save_bytes # local import | |
| os.makedirs(out_dir, exist_ok=True) | |
| metadata = { | |
| "encoding": flush.encoding, |
|
An e2e example is WIP |
4b49c96 to
d753768
Compare
…saggregated rollout Adds a "delta" checkpoint engine that puts only the changed weights on the trainer->rollout wire, mirroring THUDM/slime's NCCL delta transport. Instead of broadcasting every parameter each sync, the trainer byte-diffs against a pinned-CPU snapshot and broadcasts only the changed (position, value) pairs over the same ray.util.collective group the full-weight NCCLCheckpointEngine uses. Design (follows verl's existing "A" topology: actor rank0 -> rollout workers): - verl/workers/rollout/delta_sync/: framework-agnostic core -- DeltaState (pinned snapshot + bytewise diff, side-stream H2D/D2H pipelining), encode/ decode (indices / gap-deltas, per-param manifest, checksum), and a wrapper that turns verl's (name, tensor) generator into bucketed DeltaFlush objects. - DeltaCheckpointEngine(NCCLCheckpointEngine): send_weights diffs + broadcasts per-flush positions/values (master uses cupy buffers, expandable_segments safe); the rollout worker reconstructs full tensors from the delta into a local mirror and hands them to the standard server_adapter.update_weights. So the trainer->worker wire is sparse, while the worker->engine push is an ordinary full-tensor load -- no SGLang-side delta receiver required. First sync forces a full delta so a dummy-initialized rollout gets a correct base; a per-flush checksum is verified on receipt. Enable with rollout.checkpoint_engine.backend=delta (+ engine_kwargs.delta. encoding). Validated on a 4+4 single-node disaggregated one_step_off GRPO run; unit tests cover encode/decode bit-identity. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
… workers In disaggregated (one_step_off) training the detached actor worker that produces weights has no local rollout handle, but ActorRolloutRefWorker.update_weights falls through to the colocated resume/sleep weight-sync path which assumes one. Return early when self.rollout is None (the disaggregated sync is driven by the checkpoint engine instead), and tolerate the colocated-only LoRA attributes being unset. Lets the naive full-sync baseline run in disaggregated mode. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
146dc0c to
d642348
Compare
Adds test_delta_result_equals_full_sync: seeds the trainer snapshot and the rollout mirror from the same W0, then over several steps diffs W_new, applies only the changed positions onto the mirror (reproducing the rollout worker's receive_weights mirror-combine), and asserts the mirror is byte-equal to W_new -- i.e. the weights a rollout ends up with via delta == what the old full path delivers. CPU-only, no GPU/NCCL/SGLang: the transport only moves bytes, so the lossless guarantee lives entirely in encode/decode/combine. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
…ngine The delta_sync core (DeltaState / encode / wrapper) was placed under workers/rollout/ back when delta lived in the SGLang rollout ServerAdapter. Its only consumer now is DeltaCheckpointEngine, so move it to verl/checkpoint_engine/delta_sync/ (and the test to tests/checkpoint_engine/) and switch to relative imports. Fixes the awkward checkpoint_engine -> workers.rollout dependency direction; no behavior change. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
The delta transport is NCCL-only; remove the leftover disk framing. Drops the `deltas_zstd` encoding (it only existed to zstd-wrap the gap stream at safetensors-write time on the disk path -- a no-op alias for `deltas` without disk) and rewrites the docstrings that still referenced disk safetensors / a DeltaSpec-style SGLang receiver. The receiver is now the rollout worker's local decode-into-mirror; no behavior change for the `indices` / `deltas` encodings. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Summary
Adds delta weight sync for the SGLang rollout and puts it on the
disaggregated trainer→rollout wire: instead of broadcasting every parameter
each sync, ship only the
(position, value)pairs that changed since the lastbroadcast. In RL post-training >99% of BF16 weight bytes are unchanged
step-over-step, so this cuts the trainer→rollout transfer by ~100–1000×.
Design follows THUDM/slime's NCCL delta transport and reuses verl's existing
checkpoint-engine topology (actor rank 0 → rollout workers over
ray.util.collective).What's in this PR
Engine-agnostic core —
verl/workers/rollout/delta_sync/delta_state.py— pinned-CPU snapshot + bytewise diff, d2h/h2d side streams (CUDA-optional)encode.py—indices/deltas(uint16 gap) encodings, per-param manifest, checksum, decoderwrapper.py—iter_delta_flushes():(name, tensor)generator → stream of bucketedDeltaFlushDisaggregated NCCL transport —
verl/checkpoint_engine/delta_checkpoint_engine.py(backend="delta")NCCLCheckpointEngine, so it reuses the same collective group, zmq side-channeland
build_topologyas the full-weight engine.send_weights(trainer rank 0): byte-diff vs snapshot → publish a per-sync manifest over zmq, thencollective.broadcastonly the changed positions/values per flush (master uses cupy buffers so NCCLregistration is safe under
expandable_segments). The first sync forces a full delta so adummy-initialized rollout gets a correct base; subsequent syncs are sparse.
receive_weights(rollout worker): broadcast-recv the flush, verify the per-flush checksum, decode +NaN-mask into a local full-weight mirror, and yield reconstructed full tensors to the standard
server_adapter.update_weights. The trainer→worker wire is sparse; the worker→engine push is anordinary full-tensor load — no SGLang-side delta receiver is required.
Default behavior is unchanged:
deltais opt-in viarollout.checkpoint_engine.backend=delta(+engine_kwargs.delta.encoding).Also includes a small one_step_off fix: skip the colocated resume/sleep weight-sync path on detached
actor workers (which have no local rollout handle) so the disaggregated full-sync baseline runs too.
Validation
Single-node 4+4 disaggregated (one_step_off, FSDP2 + SGLang, Qwen2.5-0.5B, GSM8K GRPO):
backend=delta; per-flush checksum verified on receipt, no corruption(
ppo_klsmall and stable, rollout generates coherently).Notes / scope
deltas_zstdencoding is not yet wired (gap stream only); vLLM/TRTLLM rollouts out of scope.