diff --git a/slime/rollout/sglang_rollout.py b/slime/rollout/sglang_rollout.py index bb87360639..c11ac2a9fc 100644 --- a/slime/rollout/sglang_rollout.py +++ b/slime/rollout/sglang_rollout.py @@ -326,17 +326,34 @@ async def generate_and_rm_group( if sample.session_id is None: sample.session_id = str(uuid.uuid4()) - tasks = [] + pairs = [] for idx, sample in enumerate(group): current_sampling_params = sampling_params.copy() if getattr(args, "sglang_enable_deterministic_inference", False): seed = state.group_sampling_seeds[idx] current_sampling_params["sampling_seed"] = seed - tasks.append( - asyncio.create_task(generate_and_rm(args, sample, current_sampling_params, evaluation=evaluation)) + pairs.append( + (sample, asyncio.create_task(generate_and_rm(args, sample, current_sampling_params, evaluation=evaluation))) ) - group = await asyncio.gather(*tasks) + results = await asyncio.gather(*[t for _, t in pairs], return_exceptions=True) + group = [] + for sample, res in zip([s for s, _ in pairs], results): + if isinstance(res, BaseException): + logger.error( + "[generate_and_rm_group] trajectory crashed, isolating idx=%s: %r", + getattr(sample, "index", "?"), res, exc_info=res, + ) + sample.tokens = [0, 0] + sample.response = "" + sample.response_length = 1 + sample.loss_mask = [0] + sample.rollout_log_probs = [0.0] + sample.reward = 0.0 + sample.status = Sample.Status.ABORTED + group.append([sample]) + else: + group.append(res) # for the rm that need the whole group, we will do the rm here if not state.aborted and args.group_rm: