From 4646352f61a188f150aabed985badf285de00061 Mon Sep 17 00:00:00 2001 From: Zilin Zhu Date: Sun, 28 Jun 2026 06:58:52 +0000 Subject: [PATCH] Add update_from_disk to megatron server --- .../megatron_utils/server/megatron_server.py | 85 ++++++++++++++----- 1 file changed, 62 insertions(+), 23 deletions(-) diff --git a/slime/backends/megatron_utils/server/megatron_server.py b/slime/backends/megatron_utils/server/megatron_server.py index d3224b828c..ef9abbeef6 100644 --- a/slime/backends/megatron_utils/server/megatron_server.py +++ b/slime/backends/megatron_utils/server/megatron_server.py @@ -421,7 +421,12 @@ def _args_to_dict(args) -> dict[str, Any]: def _build_http_app(sample_manager, args, update_from_disk_fn=None): app = web.Application(client_max_size=64 * 1024 * 1024) - update_state = {"in_progress": False} + update_state = { + "in_progress": False, + "updating_model_path": None, + "update_future": None, + } + update_lock = asyncio.Lock() async def detect(_request: web.Request) -> web.Response: return web.json_response({"server_type": "megatron_server"}) @@ -443,37 +448,71 @@ async def update_from_disk(request: web.Request) -> web.Response: model_path = _get_update_model_path(payload) if model_path is None: return _json_error("missing model_path", 400) - if update_state["in_progress"]: - return _json_error("update_from_disk is already in progress", 409) + + async with update_lock: + if getattr(args, "load", None) == model_path: + return web.json_response({"ok": True, "model_path": model_path, "skipped": True}) + + if update_state["in_progress"]: + if update_state["updating_model_path"] == model_path and update_state["update_future"] is not None: + update_future = update_state["update_future"] + coalesced = True + else: + updating_model_path = update_state["updating_model_path"] + return _json_error(f"update_from_disk is already in progress for {updating_model_path}", 409) + else: + update_future = asyncio.get_running_loop().create_future() + update_state["in_progress"] = True + update_state["updating_model_path"] = model_path + update_state["update_future"] = update_future + coalesced = False + + if coalesced: + result = await asyncio.shield(update_future) + if result.get("ok") is True: + result = dict(result) + result["coalesced"] = True + return web.json_response(result) + return _json_error(result.get("error", "update_from_disk failed"), int(result.get("status", 500))) timeout_s = _get_update_timeout_s(payload, args) - update_state["in_progress"] = True + result = None + error = None try: before_loads = await _wait_until_idle(sample_manager, timeout_s) update_result = await asyncio.to_thread(update_from_disk_fn, model_path) after_loads = await _ray_get(sample_manager.get_loads.remote()) except TimeoutError as e: - return _json_error(str(e), 503) + error = {"ok": False, "status": 503, "error": str(e)} except Exception as e: - return _json_error(f"update_from_disk failed: {e}", 500) + error = {"ok": False, "status": 500, "error": f"update_from_disk failed: {e}"} finally: - update_state["in_progress"] = False - - # Reflect the freshly loaded checkpoint in /info. The actors restore - # their own args after loading, so only the server-side copy needs to be - # kept in sync here. - args.load = model_path - args.ref_load = model_path - - return web.json_response( - { - "ok": True, - "model_path": model_path, - "before_loads": before_loads, - "after_loads": after_loads, - "update_result": update_result, - } - ) + if error is None: + result = { + "ok": True, + "model_path": model_path, + "before_loads": before_loads, + "after_loads": after_loads, + "update_result": update_result, + } + # Reflect the freshly loaded checkpoint in /info. The actors restore + # their own args after loading, so only the server-side copy needs to be + # kept in sync here. + args.load = model_path + args.ref_load = model_path + + async with update_lock: + if result is not None: + update_future.set_result(result) + else: + update_future.set_result(error) + update_state["in_progress"] = False + update_state["updating_model_path"] = None + update_state["update_future"] = None + + if result is not None: + return web.json_response(result) + return _json_error(error["error"], error["status"]) async def generate(request: web.Request) -> web.Response: if update_state["in_progress"]: