Skip to content

[checkpoint_engine][sglang] feat: delta weight sync for disaggregated rollout#6794

Draft
ChangyiYang wants to merge 5 commits into
verl-project:mainfrom
ChangyiYang:feat/delta-weight-sync-sglang
Draft

[checkpoint_engine][sglang] feat: delta weight sync for disaggregated rollout#6794
ChangyiYang wants to merge 5 commits into
verl-project:mainfrom
ChangyiYang:feat/delta-weight-sync-sglang

Conversation

@ChangyiYang

@ChangyiYang ChangyiYang commented Jun 18, 2026

Copy link
Copy Markdown
Contributor

Summary

Adds delta weight sync for the SGLang rollout and puts it on the
disaggregated trainer→rollout wire: instead of broadcasting every parameter
each sync, ship only the (position, value) pairs that changed since the last
broadcast. In RL post-training >99% of BF16 weight bytes are unchanged
step-over-step, so this cuts the trainer→rollout transfer by ~100–1000×.

Design follows THUDM/slime's NCCL delta transport and reuses verl's existing
checkpoint-engine topology (actor rank 0 → rollout workers over
ray.util.collective).

What's in this PR

Engine-agnostic coreverl/workers/rollout/delta_sync/

  • delta_state.py — pinned-CPU snapshot + bytewise diff, d2h/h2d side streams (CUDA-optional)
  • encode.pyindices / deltas (uint16 gap) encodings, per-param manifest, checksum, decoder
  • wrapper.pyiter_delta_flushes(): (name, tensor) generator → stream of bucketed DeltaFlush

Disaggregated NCCL transportverl/checkpoint_engine/delta_checkpoint_engine.py (backend="delta")

  • Subclasses NCCLCheckpointEngine, so it reuses the same collective group, zmq side-channel
    and build_topology as the full-weight engine.
  • send_weights (trainer rank 0): byte-diff vs snapshot → publish a per-sync manifest over zmq, then
    collective.broadcast only the changed positions/values per flush (master uses cupy buffers so NCCL
    registration is safe under expandable_segments). The first sync forces a full delta so a
    dummy-initialized rollout gets a correct base; subsequent syncs are sparse.
  • receive_weights (rollout worker): broadcast-recv the flush, verify the per-flush checksum, decode +
    NaN-mask into a local full-weight mirror, and yield reconstructed full tensors to the standard
    server_adapter.update_weights. The trainer→worker wire is sparse; the worker→engine push is an
    ordinary full-tensor load — no SGLang-side delta receiver is required.

Default behavior is unchanged: delta is opt-in via
rollout.checkpoint_engine.backend=delta (+ engine_kwargs.delta.encoding).

Also includes a small one_step_off fix: skip the colocated resume/sleep weight-sync path on detached
actor workers (which have no local rollout handle) so the disaggregated full-sync baseline runs too.

Validation

Single-node 4+4 disaggregated (one_step_off, FSDP2 + SGLang, Qwen2.5-0.5B, GSM8K GRPO):

  • runs end-to-end with backend=delta; per-flush checksum verified on receipt, no corruption
    (ppo_kl small and stable, rollout generates coherently).
  • unit tests cover encode/decode bit-identity for both encodings.

Notes / scope

  • deltas_zstd encoding is not yet wired (gap stream only); vLLM/TRTLLM rollouts out of scope.

@CLAassistant

CLAassistant commented Jun 18, 2026

Copy link
Copy Markdown

CLA assistant check
All committers have signed the CLA.

@gemini-code-assist gemini-code-assist Bot left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request introduces an experimental delta weight synchronization mechanism for SGLang rollouts, enabling the transmission of only changed weight elements via NCCL or disk transports. The implementation includes pinned-CPU snapshotting, gap-delta encoding, and integration with SGLang's update APIs. The review feedback identifies critical issues that must be addressed: a missing record_stream call during async copies on side streams which risks silent data corruption, race conditions and incomplete file dispatching when using disk transport with TP > 1, high memory overhead from concatenating unmasked parameter values during encoding, and potential directory creation race conditions across distributed ranks.

Important

The consumer version of Gemini Code Assist on GitHub is being sunset. Starting June 18, 2026, new organization installations will be blocked, and all code review activity will officially cease on July 17, 2026.
For more details on the timeline and next steps, please review the Help Documentation.

Comment on lines +163 to +171
if self.d2h_stream is None:
self.d2h_stream = torch.cuda.Stream()
event = torch.cuda.current_stream().record_event()
with torch.cuda.stream(self.d2h_stream):
self.d2h_stream.wait_event(event)
for name, tensor in named_tensors:
self._allocate(name, tensor)
self.snapshot[name].copy_(tensor.detach(), non_blocking=True)
self._snapshot_dirty = True

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

When copying tensors asynchronously on a non-default CUDA stream (self.d2h_stream), you must call tensor.record_stream(self.d2h_stream) to prevent PyTorch's caching allocator from reclaiming or reusing the tensor's memory before the copy operation completes. Since the generator may yield new tensors or FSDP may free/reuse the gathered parameter buffers in subsequent steps, omitting record_stream can lead to silent data corruption in the snapshot.

Suggested change
if self.d2h_stream is None:
self.d2h_stream = torch.cuda.Stream()
event = torch.cuda.current_stream().record_event()
with torch.cuda.stream(self.d2h_stream):
self.d2h_stream.wait_event(event)
for name, tensor in named_tensors:
self._allocate(name, tensor)
self.snapshot[name].copy_(tensor.detach(), non_blocking=True)
self._snapshot_dirty = True
if self.d2h_stream is None:
self.d2h_stream = torch.cuda.Stream()
event = torch.cuda.current_stream().record_event()
with torch.cuda.stream(self.d2h_stream):
self.d2h_stream.wait_event(event)
for name, tensor in named_tensors:
self._allocate(name, tensor)
self.snapshot[name].copy_(tensor.detach(), non_blocking=True)
tensor.record_stream(self.d2h_stream)
self._snapshot_dirty = True

Comment on lines +442 to +452
if transport == "disk" and pending_files and self._is_server_tp_leader():
await dispatch_disk_files(
self._engine,
out_dir=version_dir,
files=pending_files,
weight_version=self._delta_version,
)
if not getattr(ce, "delta_keep_files", False):
import shutil
shutil.rmtree(version_dir, ignore_errors=True)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

Using transport='disk' introduces two critical issues on TP > 1:

  1. Race Condition: The leader rank might call dispatch_disk_files and delete the version_dir before other ranks have finished writing their files, or before SGLang has finished reading them. A torch.distributed.barrier() is required to synchronize all ranks.
  2. Incomplete File List: pending_files is local to each rank, so the leader only dispatches its own files. SGLang ranks > 0 will fail to load their corresponding delta files. The leader should reconstruct the complete list of files for all ranks deterministically based on the TP size and flushes_emitted.
        if transport == "disk":
            if torch.distributed.is_initialized():
                torch.distributed.barrier()
            if self._is_server_tp_leader() and flushes_emitted > 0:
                tp_size = (
                    self.device_mesh["infer_tp"].size()
                    if "infer_tp" in self.device_mesh.mesh_dim_names
                    else 1
                )
                all_files = [
                    f"rank{r:04d}_flush{f:06d}.safetensors"
                    for r in range(tp_size)
                    for f in range(flushes_emitted)
                ]
                await dispatch_disk_files(
                    self._engine,
                    out_dir=version_dir,
                    files=all_files,
                    weight_version=self._delta_version,
                )
                if not getattr(ce, "delta_keep_files", False):
                    import shutil
                    shutil.rmtree(version_dir, ignore_errors=True)

Comment on lines +115 to +255
shape=list(d.values.shape),
pos_start=pos_byte_off,
pos_end=pos_byte_off + nnz * 4,
pos_width=4,
val_start=val_off,
val_end=val_off + nnz,
)
)
pos_byte_off += nnz * 4
val_off += nnz
prev_b = b
prev_param_start = cum[i]
if not params:
return EncodedChunk.empty()
positions = torch.cat(pos_pieces, dim=0)
values = torch.cat(val_pieces, dim=0)
return EncodedChunk(
pos_bytes=positions.cpu().numpy().tobytes(),
val_tensor=values,
params=params,
nnz=val_off,
)


def _encode_deltas(diffs: list[ParamDiff]) -> EncodedChunk:
"""Gap-encode sorted positions.

Store ``idx[k] - idx[k-1] - 1`` with ``idx[-1] := -1`` so the first delta
equals the first index. Each parameter downcasts to uint16 if the max gap
fits, else uint32. At ~2% density on bf16 weights the typical max gap is
~300, so uint16 normally suffices; the uint32 fallback covers pathological
inputs without correctness risk. The receiver inverts via
``idx = cumsum(delta + 1) - 1``.
"""
if not diffs:
return EncodedChunk.empty()
big_val, bounds, big_idx, cum = _sparse_boundaries(diffs)

kept: list[tuple[ParamDiff, int]] = []
per_param_deltas: list[torch.Tensor] = []
val_pieces: list[torch.Tensor] = []
prev_b = 0
prev_param_start = 0
for i, d in enumerate(diffs):
b = bounds[i]
nnz = b - prev_b
if nnz > 0:
local_idx = big_idx[prev_b:b] - prev_param_start # int64, sorted
prev = torch.cat(
[
torch.tensor(
[-1], dtype=local_idx.dtype, device=local_idx.device
),
local_idx[:-1],
]
)
per_param_deltas.append(local_idx - prev - 1)
val_pieces.append(big_val[prev_b:b])
kept.append((d, nnz))
prev_b = b
prev_param_start = cum[i]

if not kept:
return EncodedChunk.empty()

max_per_param = (
torch.stack([d.max() for d in per_param_deltas]).cpu().tolist()
)
pos_byte_pieces: list[bytes] = []
pos_byte_off = val_off = 0
params: list[DeltaParam] = []
for (d, nnz), deltas, max_d in zip(
kept, per_param_deltas, max_per_param, strict=True
):
width = 2 if int(max_d) <= 65535 else 4
np_dtype = np.uint16 if width == 2 else np.uint32
b_chunk = deltas.cpu().numpy().astype(np_dtype, copy=False).tobytes()
pos_byte_pieces.append(b_chunk)
params.append(
DeltaParam(
name=d.name,
dtype=str(d.values.dtype).replace("torch.", ""),
shape=list(d.values.shape),
pos_start=pos_byte_off,
pos_end=pos_byte_off + len(b_chunk),
pos_width=width,
val_start=val_off,
val_end=val_off + nnz,
)
)
pos_byte_off += len(b_chunk)
val_off += nnz

values = torch.cat(val_pieces, dim=0)
return EncodedChunk(
pos_bytes=b"".join(pos_byte_pieces),
val_tensor=values,
params=params,
nnz=val_off,
)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

Concatenating all parameter values in _sparse_boundaries via big_values = torch.cat(...) allocates a massive temporary tensor containing all parameter elements (even unchanged ones) for every chunk. This introduces significant GPU memory overhead and compute latency. By indexing into each d.values individually using the computed local_idx on the GPU, we can completely avoid allocating big_values and only allocate memory for the sparse set of changed values.

def _sparse_boundaries(
    diffs: list[ParamDiff],
) -> tuple[list[int], torch.Tensor, list[int]]:
    """One concat -> one nonzero -> one searchsorted -> one tolist().

    Collapses per-parameter host syncs to a single one per chunk.
    """
    device = diffs[0].values.device
    sizes = [d.values.numel() for d in diffs]
    cum = list(itertools.accumulate(sizes))
    cum_t = torch.tensor(cum, dtype=torch.int64, device=device)

    big_mask = torch.cat([d.mask.contiguous().view(-1) for d in diffs], dim=0)
    big_idx = big_mask.nonzero(as_tuple=False).view(-1)
    bounds = torch.searchsorted(big_idx, cum_t).tolist()
    return bounds, big_idx, cum


def _encode_indices(diffs: list[ParamDiff]) -> EncodedChunk:
    if not diffs:
        return EncodedChunk.empty()
    bounds, big_idx, cum = _sparse_boundaries(diffs)
    pos_pieces: list[torch.Tensor] = []
    val_pieces: list[torch.Tensor] = []
    params: list[DeltaParam] = []
    pos_byte_off = val_off = 0
    prev_b = 0
    prev_param_start = 0
    for i, d in enumerate(diffs):
        b = bounds[i]
        nnz = b - prev_b
        if nnz > 0:
            local_idx = big_idx[prev_b:b] - prev_param_start
            pos_pieces.append(local_idx.to(torch.int32))
            val_pieces.append(d.values.contiguous().view(-1)[local_idx])
            params.append(
                DeltaParam(
                    name=d.name,
                    dtype=str(d.values.dtype).replace("torch.", ""),
                    shape=list(d.values.shape),
                    pos_start=pos_byte_off,
                    pos_end=pos_byte_off + nnz * 4,
                    pos_width=4,
                    val_start=val_off,
                    val_end=val_off + nnz,
                )
            )
            pos_byte_off += nnz * 4
            val_off += nnz
        prev_b = b
        prev_param_start = cum[i]
    if not params:
        return EncodedChunk.empty()
    positions = torch.cat(pos_pieces, dim=0)
    values = torch.cat(val_pieces, dim=0)
    return EncodedChunk(
        pos_bytes=positions.cpu().numpy().tobytes(),
        val_tensor=values,
        params=params,
        nnz=val_off,
    )


def _encode_deltas(diffs: list[ParamDiff]) -> EncodedChunk:
    """Gap-encode sorted positions.

    Store ``idx[k] - idx[k-1] - 1`` with ``idx[-1] := -1`` so the first delta
    equals the first index. Each parameter downcasts to uint16 if the max gap
    fits, else uint32. At ~2% density on bf16 weights the typical max gap is
    ~300, so uint16 normally suffices; the uint32 fallback covers pathological
    inputs without correctness risk. The receiver inverts via
    ``idx = cumsum(delta + 1) - 1``.
    """
    if not diffs:
        return EncodedChunk.empty()
    bounds, big_idx, cum = _sparse_boundaries(diffs)

    kept: list[tuple[ParamDiff, int, torch.Tensor]] = []
    per_param_deltas: list[torch.Tensor] = []
    val_pieces: list[torch.Tensor] = []
    prev_b = 0
    prev_param_start = 0
    for i, d in enumerate(diffs):
        b = bounds[i]
        nnz = b - prev_b
        if nnz > 0:
            local_idx = big_idx[prev_b:b] - prev_param_start  # int64, sorted
            prev = torch.cat(
                [
                    torch.tensor(
                        [-1], dtype=local_idx.dtype, device=local_idx.device
                    ),
                    local_idx[:-1],
                ]
            )
            per_param_deltas.append(local_idx - prev - 1)
            val_pieces.append(d.values.contiguous().view(-1)[local_idx])
            kept.append((d, nnz, per_param_deltas[-1]))
        prev_b = b
        prev_param_start = cum[i]

    if not kept:
        return EncodedChunk.empty()

    max_per_param = (
        torch.stack([item[2].max() for item in kept]).cpu().tolist()
    )
    pos_byte_pieces: list[bytes] = []
    pos_byte_off = val_off = 0
    params: list[DeltaParam] = []
    for (d, nnz, deltas), max_d in zip(
        kept, max_per_param, strict=True
    ):
        width = 2 if int(max_d) <= 65535 else 4
        np_dtype = np.uint16 if width == 2 else np.uint32
        b_chunk = deltas.cpu().numpy().astype(np_dtype, copy=False).tobytes()
        pos_byte_pieces.append(b_chunk)
        params.append(
            DeltaParam(
                name=d.name,
                dtype=str(d.values.dtype).replace("torch.", ""),
                shape=list(d.values.shape),
                pos_start=pos_byte_off,
                pos_end=pos_byte_off + len(b_chunk),
                pos_width=width,
                val_start=val_off,
                val_end=val_off + nnz,
            )
        )
        pos_byte_off += len(b_chunk)
        val_off += nnz

    values = torch.cat(val_pieces, dim=0)
    return EncodedChunk(
        pos_bytes=b"".join(pos_byte_pieces),
        val_tensor=values,
        params=params,
        nnz=val_off,
    )

Comment on lines +127 to +130
from safetensors.torch import save as st_save_bytes # local import

metadata = {
"encoding": flush.encoding,

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

In a distributed environment, non-leader ranks might attempt to write their flush files to out_dir before the leader rank has finished creating it, leading to a FileNotFoundError. Adding os.makedirs(out_dir, exist_ok=True) inside write_flush_to_disk ensures that every rank safely creates the directory if it does not exist yet.

Suggested change
from safetensors.torch import save as st_save_bytes # local import
metadata = {
"encoding": flush.encoding,
from safetensors.torch import save as st_save_bytes # local import
os.makedirs(out_dir, exist_ok=True)
metadata = {
"encoding": flush.encoding,

@gxlvera

gxlvera commented Jun 18, 2026

Copy link
Copy Markdown
Contributor

An e2e example is WIP

@ChangyiYang ChangyiYang force-pushed the feat/delta-weight-sync-sglang branch from 4b49c96 to d753768 Compare June 18, 2026 20:56
@ChangyiYang ChangyiYang changed the title [rollout][sglang] feat: delta weight sync (sparse trainer->rollout updates) [checkpoint_engine][sglang] feat: delta weight sync for disaggregated rollout Jun 24, 2026
ChangyiYang and others added 2 commits June 24, 2026 23:51
…saggregated rollout

Adds a "delta" checkpoint engine that puts only the changed weights on the
trainer->rollout wire, mirroring THUDM/slime's NCCL delta transport. Instead of
broadcasting every parameter each sync, the trainer byte-diffs against a
pinned-CPU snapshot and broadcasts only the changed (position, value) pairs over
the same ray.util.collective group the full-weight NCCLCheckpointEngine uses.

Design (follows verl's existing "A" topology: actor rank0 -> rollout workers):
- verl/workers/rollout/delta_sync/: framework-agnostic core -- DeltaState
  (pinned snapshot + bytewise diff, side-stream H2D/D2H pipelining), encode/
  decode (indices / gap-deltas, per-param manifest, checksum), and a wrapper
  that turns verl's (name, tensor) generator into bucketed DeltaFlush objects.
- DeltaCheckpointEngine(NCCLCheckpointEngine): send_weights diffs + broadcasts
  per-flush positions/values (master uses cupy buffers, expandable_segments
  safe); the rollout worker reconstructs full tensors from the delta into a
  local mirror and hands them to the standard server_adapter.update_weights.
  So the trainer->worker wire is sparse, while the worker->engine push is an
  ordinary full-tensor load -- no SGLang-side delta receiver required. First
  sync forces a full delta so a dummy-initialized rollout gets a correct base;
  a per-flush checksum is verified on receipt.

Enable with rollout.checkpoint_engine.backend=delta (+ engine_kwargs.delta.
encoding). Validated on a 4+4 single-node disaggregated one_step_off GRPO run;
unit tests cover encode/decode bit-identity.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
… workers

In disaggregated (one_step_off) training the detached actor worker that produces
weights has no local rollout handle, but ActorRolloutRefWorker.update_weights
falls through to the colocated resume/sleep weight-sync path which assumes one.
Return early when self.rollout is None (the disaggregated sync is driven by the
checkpoint engine instead), and tolerate the colocated-only LoRA attributes
being unset. Lets the naive full-sync baseline run in disaggregated mode.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
@ChangyiYang ChangyiYang force-pushed the feat/delta-weight-sync-sglang branch from 146dc0c to d642348 Compare June 24, 2026 23:52
ChangyiYang and others added 3 commits June 25, 2026 03:29
Adds test_delta_result_equals_full_sync: seeds the trainer snapshot and the
rollout mirror from the same W0, then over several steps diffs W_new, applies
only the changed positions onto the mirror (reproducing the rollout worker's
receive_weights mirror-combine), and asserts the mirror is byte-equal to W_new
-- i.e. the weights a rollout ends up with via delta == what the old full path
delivers. CPU-only, no GPU/NCCL/SGLang: the transport only moves bytes, so the
lossless guarantee lives entirely in encode/decode/combine.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
…ngine

The delta_sync core (DeltaState / encode / wrapper) was placed under
workers/rollout/ back when delta lived in the SGLang rollout ServerAdapter.
Its only consumer now is DeltaCheckpointEngine, so move it to
verl/checkpoint_engine/delta_sync/ (and the test to tests/checkpoint_engine/)
and switch to relative imports. Fixes the awkward checkpoint_engine ->
workers.rollout dependency direction; no behavior change.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
The delta transport is NCCL-only; remove the leftover disk framing. Drops the
`deltas_zstd` encoding (it only existed to zstd-wrap the gap stream at
safetensors-write time on the disk path -- a no-op alias for `deltas` without
disk) and rewrites the docstrings that still referenced disk safetensors / a
DeltaSpec-style SGLang receiver. The receiver is now the rollout worker's local
decode-into-mirror; no behavior change for the `indices` / `deltas` encodings.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants