Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 49 additions & 0 deletions slime/backends/megatron_utils/update_weight/_apply_result_check.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
"""Lightweight helper for inspecting engine apply results from delta-weight sync.

Kept in its own module so it can be unit-tested without pulling in torch, ray,
or any other GPU/distributed dependency.
"""

from __future__ import annotations

import logging

logger = logging.getLogger(__name__)


def check_apply_results(results: list) -> None:
    """Log and raise on any failed delta-weight apply reported by a receiver engine.

    The SGLang receiver wraps its apply logic in a try/except and always returns
    ``(success: bool, message: str)``. ``_finalize_sync`` previously called
    ``ray.get()`` and discarded these return values, so a failed apply was
    silent: the sender snapshot advanced past what receivers actually hold, and
    subsequent diffs were computed against a stale baseline (issue #2104).

    This helper must be called immediately after ``ray.get()`` to surface
    failures before the sync is declared complete.

    Args:
        results: the list returned by ``ray.get(object_refs)`` — one entry per
            engine, expected to be ``(success, message)`` tuples. Entries that
            are not 2-tuples are treated as successful (forward-compatible with
            engines that do not yet return structured results).

    Raises:
        RuntimeError: if one or more engines reported ``success=False``. The
            exception text contains the failure count followed by the index and
            message of every failing engine, so the cause is still available
            when the log records are not.
    """
    # Only explicit ``(falsy, message)`` 2-tuples count as failures; anything
    # else (None, bare bools, other shapes) is deliberately left alone.
    failures = [
        (idx, result[1])
        for idx, result in enumerate(results)
        if isinstance(result, tuple) and len(result) == 2 and not result[0]
    ]
    if not failures:
        return

    for idx, msg in failures:
        logger.error("Engine[%d] failed to apply delta weights: %s", idx, msg)

    # Repeat the per-engine details inside the exception itself: a caller that
    # catches or re-raises the error may never see the log lines emitted above.
    details = "; ".join(f"Engine[{idx}]: {msg}" for idx, msg in failures)
    raise RuntimeError(
        f"Delta weight apply failed on {len(failures)}/{len(results)} engine(s). "
        "The sender snapshot has advanced past what receivers actually hold; "
        "subsequent diffs will be computed against a stale baseline. "
        "See per-engine error messages logged above. "
        f"Failures: {details}"
    )
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@
from slime.utils.timer import Timer, timer

from ..sglang import DeltaEncoding, DeltaParam, DeltaSpec
from ._apply_result_check import check_apply_results
from .update_weight_from_distributed import UpdateWeightFromDistributed

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -808,7 +809,8 @@ def _finalize_sync(self) -> None:
# Futures unblocks the (commit-then-RPC) chain; ray.get waits for the
# receivers' apply to finish.
object_refs = [ref for fut in self._pending_publishes for ref in fut.result()]
ray.get(object_refs)
results = ray.get(object_refs)
check_apply_results(results)
self._pending_publishes.clear()
if not self._published_any:
# No delta files needed publishing this sync (e.g. all-zero diff).
Expand Down
69 changes: 69 additions & 0 deletions tests/test_delta_sync_check_apply.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
"""CPU unit tests for ``check_apply_results`` (issue #2104).

``_finalize_sync`` previously called ``ray.get(object_refs)`` and discarded
the ``(success, msg)`` tuples returned by each SGLang receiver engine. A
failed delta-weight apply was therefore silent, leaving the sender snapshot
permanently ahead of the receiver.

``check_apply_results`` is the extracted helper that closes this gap.
These tests exercise it directly — no GPU, Ray, or distributed runtime needed.
"""

from __future__ import annotations

import importlib.util
from pathlib import Path
import pytest

# Import the helper straight from its source file instead of through the
# package path, so the package-level __init__.py files (which import
# torch / ray / megatron) are never executed by this test module.
_helper_source = (
    Path(__file__).parent.parent
    / "slime/backends/megatron_utils/update_weight/_apply_result_check.py"
)
_spec = importlib.util.spec_from_file_location(
    name="_apply_result_check",
    location=_helper_source,
)
_mod = importlib.util.module_from_spec(_spec)
# Run the module body so its top-level definitions exist on ``_mod``.
_spec.loader.exec_module(_mod)
_check_apply_results = getattr(_mod, "check_apply_results")


class TestCheckApplyResults:
    """Regression tests for the apply-result helper introduced for issue #2104."""

    def test_all_successful_results_do_not_raise(self):
        """A list made only of ``(True, msg)`` entries passes without raising."""
        all_ok = [(True, ""), (True, "ok"), (True, "weight version 42")]
        _check_apply_results(all_ok)

    def test_empty_results_do_not_raise(self):
        """With no engines at all there is nothing to fail, so nothing is raised."""
        no_engines = []
        _check_apply_results(no_engines)

    def test_single_failed_engine_raises_runtime_error(self):
        """One ``(False, msg)`` entry among successes triggers RuntimeError."""
        one_bad = [(True, ""), (False, "checksum mismatch")]
        with pytest.raises(RuntimeError, match="1/2 engine"):
            _check_apply_results(one_bad)

    def test_error_message_contains_failed_engine_index(self, caplog):
        """The logged error names the failing engine's index and its message."""
        import logging

        outcomes = [(True, ""), (False, "io error on shard 3")]
        with caplog.at_level(logging.ERROR):
            with pytest.raises(RuntimeError):
                _check_apply_results(outcomes)

        captured = caplog.text
        assert "Engine[1]" in captured
        assert "io error on shard 3" in captured

    def test_all_failed_engines_raise_with_correct_count(self):
        """When every engine fails, the exception reports the full count."""
        all_bad = [(False, "decode error"), (False, "nccl timeout"), (False, "oom")]
        with pytest.raises(RuntimeError, match="3/3"):
            _check_apply_results(all_bad)

    def test_non_tuple_results_treated_as_success(self):
        """
        Entries that are not ``(bool, str)`` tuples — for instance ``None`` or a
        bare ``True`` from an older SGLang version — are not counted as
        failures; only explicit ``(False, msg)`` pairs are flagged.
        """
        unstructured = [True, None, (True, "ok")]
        _check_apply_results(unstructured)
Loading