From f2fc68110aa4e5c3c4991d40fe770c842f86d686 Mon Sep 17 00:00:00 2001 From: none0663 Date: Tue, 23 Jun 2026 12:10:47 +0800 Subject: [PATCH] fix partial-rollout: cap max_new_tokens by prior response length --- slime/rollout/sglang_rollout.py | 3 +++ slime/rollout/sglang_streaming_rollout.py | 3 +++ 2 files changed, 6 insertions(+) diff --git a/slime/rollout/sglang_rollout.py b/slime/rollout/sglang_rollout.py index 83bbf04d87..de51cc6038 100644 --- a/slime/rollout/sglang_rollout.py +++ b/slime/rollout/sglang_rollout.py @@ -163,6 +163,9 @@ async def generate(args: Namespace, sample: Sample, sampling_params: dict[str, A ), f"Sample status is {sample.status}" prompt_ids = _prepare_prompt_ids(sample, state.tokenizer, state.processor) + # partial-rollout reuses the sample; subtract prior progress so total <= rollout_max_response_len + if args.partial_rollout and sample.response_length > 0: + sampling_params["max_new_tokens"] = max(0, args.rollout_max_response_len - sample.response_length) assert ( sampling_params["max_new_tokens"] >= 0 diff --git a/slime/rollout/sglang_streaming_rollout.py b/slime/rollout/sglang_streaming_rollout.py index 12471148c9..bde9318132 100644 --- a/slime/rollout/sglang_streaming_rollout.py +++ b/slime/rollout/sglang_streaming_rollout.py @@ -58,6 +58,9 @@ async def generate_streaming(args: Namespace, sample: Sample, sampling_params: d ), f"Sample status is {sample.status}" prompt_ids = _prepare_prompt_ids(sample, state.tokenizer, state.processor) + # partial-rollout reuses the sample; subtract prior progress so total <= rollout_max_response_len + if args.partial_rollout and sample.response_length > 0: + sampling_params["max_new_tokens"] = max(0, args.rollout_max_response_len - sample.response_length) assert ( sampling_params["max_new_tokens"] >= 0