Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
85 changes: 62 additions & 23 deletions slime/backends/megatron_utils/server/megatron_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -421,7 +421,12 @@ def _args_to_dict(args) -> dict[str, Any]:

def _build_http_app(sample_manager, args, update_from_disk_fn=None):
app = web.Application(client_max_size=64 * 1024 * 1024)
update_state = {"in_progress": False}
update_state = {
"in_progress": False,
"updating_model_path": None,
"update_future": None,
}
update_lock = asyncio.Lock()

async def detect(_request: web.Request) -> web.Response:
return web.json_response({"server_type": "megatron_server"})
Expand All @@ -443,37 +448,71 @@ async def update_from_disk(request: web.Request) -> web.Response:
model_path = _get_update_model_path(payload)
if model_path is None:
return _json_error("missing model_path", 400)
if update_state["in_progress"]:
return _json_error("update_from_disk is already in progress", 409)

async with update_lock:
if getattr(args, "load", None) == model_path:
return web.json_response({"ok": True, "model_path": model_path, "skipped": True})

if update_state["in_progress"]:
if update_state["updating_model_path"] == model_path and update_state["update_future"] is not None:
update_future = update_state["update_future"]
coalesced = True
else:
updating_model_path = update_state["updating_model_path"]
return _json_error(f"update_from_disk is already in progress for {updating_model_path}", 409)
else:
update_future = asyncio.get_running_loop().create_future()
update_state["in_progress"] = True
update_state["updating_model_path"] = model_path
update_state["update_future"] = update_future
coalesced = False

if coalesced:
result = await asyncio.shield(update_future)
if result.get("ok") is True:
result = dict(result)
result["coalesced"] = True
return web.json_response(result)
return _json_error(result.get("error", "update_from_disk failed"), int(result.get("status", 500)))

timeout_s = _get_update_timeout_s(payload, args)
update_state["in_progress"] = True
result = None
error = None
try:
before_loads = await _wait_until_idle(sample_manager, timeout_s)
update_result = await asyncio.to_thread(update_from_disk_fn, model_path)
after_loads = await _ray_get(sample_manager.get_loads.remote())
except TimeoutError as e:
return _json_error(str(e), 503)
error = {"ok": False, "status": 503, "error": str(e)}
except Exception as e:
return _json_error(f"update_from_disk failed: {e}", 500)
error = {"ok": False, "status": 500, "error": f"update_from_disk failed: {e}"}
finally:
update_state["in_progress"] = False

# Reflect the freshly loaded checkpoint in /info. The actors restore
# their own args after loading, so only the server-side copy needs to be
# kept in sync here.
args.load = model_path
args.ref_load = model_path

return web.json_response(
{
"ok": True,
"model_path": model_path,
"before_loads": before_loads,
"after_loads": after_loads,
"update_result": update_result,
}
)
if error is None:
result = {
"ok": True,
"model_path": model_path,
"before_loads": before_loads,
"after_loads": after_loads,
"update_result": update_result,
}
# Reflect the freshly loaded checkpoint in /info. The actors restore
# their own args after loading, so only the server-side copy needs to be
# kept in sync here.
args.load = model_path
args.ref_load = model_path

async with update_lock:
if result is not None:
update_future.set_result(result)
else:
update_future.set_result(error)
update_state["in_progress"] = False
update_state["updating_model_path"] = None
update_state["update_future"] = None

if result is not None:
return web.json_response(result)
return _json_error(error["error"], error["status"])

async def generate(request: web.Request) -> web.Response:
if update_state["in_progress"]:
Expand Down
Loading