diff --git a/slime/rollout/sglang_rollout.py b/slime/rollout/sglang_rollout.py index 83bbf04d87..de51cc6038 100644 --- a/slime/rollout/sglang_rollout.py +++ b/slime/rollout/sglang_rollout.py @@ -163,6 +163,9 @@ async def generate(args: Namespace, sample: Sample, sampling_params: dict[str, A ), f"Sample status is {sample.status}" prompt_ids = _prepare_prompt_ids(sample, state.tokenizer, state.processor) + # partial-rollout reuses the sample; subtract prior progress so total <= rollout_max_response_len + if args.partial_rollout and sample.response_length > 0: + sampling_params["max_new_tokens"] = max(0, args.rollout_max_response_len - sample.response_length) assert ( sampling_params["max_new_tokens"] >= 0 diff --git a/slime/rollout/sglang_streaming_rollout.py b/slime/rollout/sglang_streaming_rollout.py index 12471148c9..bde9318132 100644 --- a/slime/rollout/sglang_streaming_rollout.py +++ b/slime/rollout/sglang_streaming_rollout.py @@ -58,6 +58,9 @@ async def generate_streaming(args: Namespace, sample: Sample, sampling_params: d ), f"Sample status is {sample.status}" prompt_ids = _prepare_prompt_ids(sample, state.tokenizer, state.processor) + # partial-rollout reuses the sample; subtract prior progress so total <= rollout_max_response_len + if args.partial_rollout and sample.response_length > 0: + sampling_params["max_new_tokens"] = max(0, args.rollout_max_response_len - sample.response_length) assert ( sampling_params["max_new_tokens"] >= 0