diff --git a/slime/backends/megatron_utils/update_weight/_apply_result_check.py b/slime/backends/megatron_utils/update_weight/_apply_result_check.py new file mode 100644 index 0000000000..5c1da748c5 --- /dev/null +++ b/slime/backends/megatron_utils/update_weight/_apply_result_check.py @@ -0,0 +1,49 @@ +"""Lightweight helper for inspecting engine apply results from delta-weight sync. + +Kept in its own module so it can be unit-tested without pulling in torch, ray, +or any other GPU/distributed dependency. +""" + +from __future__ import annotations + +import logging + +logger = logging.getLogger(__name__) + + +def check_apply_results(results: list) -> None: + """Log and raise on any failed delta-weight apply reported by a receiver engine. + + The SGLang receiver wraps its apply logic in a try/except and always returns + ``(success: bool, message: str)``. ``_finalize_sync`` previously called + ``ray.get()`` and discarded these return values, so a failed apply was + silent: the sender snapshot advanced past what receivers actually hold, and + subsequent diffs were computed against a stale baseline (issue #2104). + + This helper must be called immediately after ``ray.get()`` to surface + failures before the sync is declared complete. + + Args: + results: the list returned by ``ray.get(object_refs)`` — one entry per + engine, expected to be ``(success, message)`` tuples. Entries that + are not 2-tuples are treated as successful (forward-compatible with + engines that do not yet return structured results). + + Raises: + RuntimeError: if one or more engines reported ``success=False``. + """ + failures = [ + (idx, result[1]) + for idx, result in enumerate(results) + if isinstance(result, tuple) and len(result) == 2 and not result[0] + ] + if not failures: + return + for idx, msg in failures: + logger.error("Engine[%d] failed to apply delta weights: %s", idx, msg) + raise RuntimeError( + f"Delta weight apply failed on {len(failures)}/{len(results)} engine(s). " + "The sender snapshot has advanced past what receivers actually hold; " + "subsequent diffs will be computed against a stale baseline. " + "See per-engine error messages logged above." + ) diff --git a/slime/backends/megatron_utils/update_weight/update_weight_from_distributed_delta.py b/slime/backends/megatron_utils/update_weight/update_weight_from_distributed_delta.py index fbe24bbc1c..ac6ad37579 100644 --- a/slime/backends/megatron_utils/update_weight/update_weight_from_distributed_delta.py +++ b/slime/backends/megatron_utils/update_weight/update_weight_from_distributed_delta.py @@ -52,6 +52,7 @@ from slime.utils.timer import Timer, timer from ..sglang import DeltaEncoding, DeltaParam, DeltaSpec +from ._apply_result_check import check_apply_results from .update_weight_from_distributed import UpdateWeightFromDistributed logger = logging.getLogger(__name__) @@ -808,7 +809,8 @@ def _finalize_sync(self) -> None: # Futures unblocks the (commit-then-RPC) chain; ray.get waits for the # receivers' apply to finish. object_refs = [ref for fut in self._pending_publishes for ref in fut.result()] - ray.get(object_refs) + results = ray.get(object_refs) + check_apply_results(results) self._pending_publishes.clear() if not self._published_any: # No delta files needed publishing this sync (e.g. all-zero diff). diff --git a/tests/test_delta_sync_check_apply.py b/tests/test_delta_sync_check_apply.py new file mode 100644 index 0000000000..b399fc0b09 --- /dev/null +++ b/tests/test_delta_sync_check_apply.py @@ -0,0 +1,69 @@ +"""CPU unit tests for ``check_apply_results`` (issue #2104). + +``_finalize_sync`` previously called ``ray.get(object_refs)`` and discarded +the ``(success, msg)`` tuples returned by each SGLang receiver engine. A +failed delta-weight apply was therefore silent, leaving the sender snapshot +permanently ahead of the receiver. + +``check_apply_results`` is the extracted helper that closes this gap. +These tests exercise it directly — no GPU, Ray, or distributed runtime needed. +""" + +from __future__ import annotations + +import importlib.util +from pathlib import Path +import pytest + +# Load _apply_result_check.py directly so the test doesn't trigger the +# package-level __init__.py files that import torch / ray / megatron. +_spec = importlib.util.spec_from_file_location( + "_apply_result_check", + Path(__file__).parent.parent + / "slime/backends/megatron_utils/update_weight/_apply_result_check.py", +) +_mod = importlib.util.module_from_spec(_spec) +_spec.loader.exec_module(_mod) +_check_apply_results = _mod.check_apply_results + + +class TestCheckApplyResults: + """Unit tests for the apply-result inspector added to fix issue #2104.""" + + def test_all_successful_results_do_not_raise(self): + """When every engine reports success, no exception is raised.""" + _check_apply_results([(True, ""), (True, "ok"), (True, "weight version 42")]) + + def test_empty_results_do_not_raise(self): + """An empty result list (no engines) is also valid.""" + _check_apply_results([]) + + def test_single_failed_engine_raises_runtime_error(self): + """A single (False, msg) result raises RuntimeError.""" + with pytest.raises(RuntimeError, match="1/2 engine"): + _check_apply_results([(True, ""), (False, "checksum mismatch")]) + + def test_error_message_contains_failed_engine_index(self, caplog): + """The per-engine log line includes the engine index.""" + import logging + + with caplog.at_level(logging.ERROR), pytest.raises(RuntimeError): + _check_apply_results([(True, ""), (False, "io error on shard 3")]) + + assert "Engine[1]" in caplog.text + assert "io error on shard 3" in caplog.text + + def test_all_failed_engines_raise_with_correct_count(self): + """The RuntimeError message reflects the total failure count.""" + with pytest.raises(RuntimeError, match="3/3"): + _check_apply_results( + [(False, "decode error"), (False, "nccl timeout"), (False, "oom")] + ) + + def test_non_tuple_results_treated_as_success(self): + """ + If an engine returns something other than a (bool, str) tuple (e.g. None + or a bare bool True from an older SGLang version), it is not treated as + a failure — the helper only flags explicit ``(False, msg)`` pairs. + """ + _check_apply_results([True, None, (True, "ok")])