Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 26 additions & 13 deletions nemo_gym/rollout_collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -662,19 +662,32 @@ def run_examples(

async def _post_subroutine(row: Dict) -> Tuple[Dict, Dict]:
    """POST one rollout row to its agent's ``/run`` endpoint, retrying failed statuses.

    The target server is taken from ``row["agent_ref"]["name"]`` and the whole
    ``row`` is sent as the JSON body. The enclosing ``semaphore`` is held for
    the full duration of the call, including the pauses between retries.

    Args:
        row: The rollout request; must contain ``row["agent_ref"]["name"]``.

    Returns:
        A ``(row, response_json)`` tuple, where ``response_json`` is whatever
        ``get_response_json`` returns for the successful response.

    Raises:
        Exception: Whatever ``raise_for_status`` raised, either immediately for
            a 4xx status or after the final attempt for any other failure.
    """
    max_attempts = 4
    retry_delay_seconds = 1.0

    async with semaphore:
        # Retry a failed /run a few times so one flaky call (e.g. a momentarily
        # overloaded code-exec server) only costs that rollout an attempt
        # instead of aborting the whole batch. 4xx is deterministic, so it is
        # re-raised immediately.
        # NOTE(review): only errors raised by raise_for_status() are retried
        # here; an exception raised by server_client.post() itself propagates
        # on the first occurrence. Confirm whether post() already retries
        # connection failures internally.
        attempt = 1
        while True:
            res = await server_client.post(server_name=row["agent_ref"]["name"], url_path="/run", json=row)
            try:
                await raise_for_status(res)
            except Exception as e:
                status = getattr(e, "status", None) or getattr(res, "status", None)
                if is_global_aiohttp_client_request_debug_enabled():
                    print(
                        "[rollout_collection] /run failed "
                        f"status={status} attempt={attempt}/{max_attempts} "
                        f"row={json.dumps(_rollout_request_debug_summary(row), sort_keys=True)}",
                        flush=True,
                    )
                is_client_error = isinstance(status, int) and 400 <= status < 500
                # Re-raise in place (bare ``raise`` keeps the original
                # traceback) when retrying cannot help or the budget is spent,
                # so the last failure is not followed by a pointless sleep.
                if is_client_error or attempt >= max_attempts:
                    raise
            else:
                return row, await get_response_json(res)

            attempt += 1
            await asyncio.sleep(retry_delay_seconds)

return tqdm.as_completed(
map(_post_subroutine, examples),
Expand Down
9 changes: 9 additions & 0 deletions nemo_gym/server_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,6 +232,15 @@ async def raise_for_status(response: ClientResponse) -> None: # pragma: no cove
except ClientResponseError as e:
# Set the response content here so we have access to it down the line.
e.response_content = content
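# request_info and headers carry multidict.CIMultiDictProxy objects (and
# history presumably holds ClientResponse objects), which don't pickle. That
# breaks Ray's cross-actor error propagation (rollout collection dies with
# "can't pickle CIMultiDictProxy" on any resource-server 5xx). Drop them so
# the error stays picklable; keep status/message/response_content. e.args is
# reset as well because exceptions are pickled via their args, which would
# otherwise still reference the dropped objects.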
e.request_info = None
e.history = ()
e.headers = None
e.args = (e.status, e.message)
raise e


Expand Down