diff --git a/examples/search-r1/README.md b/examples/search-r1/README.md
index 973030e9eb..c9f87337ab 100644
--- a/examples/search-r1/README.md
+++ b/examples/search-r1/README.md
@@ -176,6 +176,17 @@ CUSTOM_ARGS=(
These are the `generate` and `reward_func` functions in `generate_with_search.py`.
+### Partial Rollout
+
+Search-R1 also supports slime's partial rollout path for long-tail search trajectories:
+
+```bash
+--partial-rollout
+--mask-offpolicy-in-partial-rollout
+```
+
+When a rollout round aborts unfinished requests, Search-R1 keeps the generated response, search observations, loss mask, and rollout log probabilities on the `Sample`. The next rollout resumes from that context instead of starting the prompt over. With `--mask-offpolicy-in-partial-rollout`, tokens generated before the latest weight update are kept as context but masked out of training.
+
## Appendix: Setting up Local Retriever
This section provides detailed instructions for setting up the local dense retriever for use with the local search backend.
diff --git a/examples/search-r1/README_zh.md b/examples/search-r1/README_zh.md
index f7d9fbb0d9..891c0bf504 100644
--- a/examples/search-r1/README_zh.md
+++ b/examples/search-r1/README_zh.md
@@ -176,6 +176,17 @@ CUSTOM_ARGS=(
也就是 `generate_with_search.py` 中的 `generate` 和 `reward_func` 两个函数。
+### Partial Rollout
+
+Search-R1 也支持 slime 的 partial rollout 路径,适合耗时较长、容易出现长尾的搜索轨迹:
+
+```bash
+--partial-rollout
+--mask-offpolicy-in-partial-rollout
+```
+
+当某轮 rollout 中还有未完成请求被 abort 时,Search-R1 会把已经生成的 response、搜索 observation、loss mask 和 rollout log probabilities 保存在 `Sample` 上。下一轮 rollout 会从已有上下文继续生成,而不是从原始 prompt 重新开始。开启 `--mask-offpolicy-in-partial-rollout` 后,旧权重生成的 token 会继续作为上下文保留,但不会参与训练。
+
## 附录:配置本地检索器
本节提供详细的本地密集检索器设置说明,用于本地搜索后端。
diff --git a/examples/search-r1/generate_with_search.py b/examples/search-r1/generate_with_search.py
index ae0f18dd15..aa36297f6c 100644
--- a/examples/search-r1/generate_with_search.py
+++ b/examples/search-r1/generate_with_search.py
@@ -39,6 +39,7 @@
SEMAPHORE = asyncio.Semaphore(SEARCH_R1_CONFIGS["search_concurrency"])
+# Key in ``Sample.metadata`` under which the number of completed model turns is
+# stored, so a resumed rollout can read it back instead of re-deriving it.
+_SEARCH_R1_TURN_COUNT_KEY = "search_r1_completed_turns"
def _passages2string(retrieval_result):
@@ -121,6 +122,11 @@ def postprocess_predictions(prediction: str):
return action, content
+def _last_prediction_action(prediction: str) -> str | None:
+ matches = re.findall(r"<(search|answer)>.*?\1>", prediction, re.DOTALL)
+ return matches[-1] if matches else None
+
+
async def execute_predictions(prediction: str) -> str:
action, content = postprocess_predictions(prediction)
@@ -142,24 +148,73 @@ async def execute_predictions(prediction: str) -> str:
return next_obs, done
-async def generate(args, sample: Sample, sampling_params) -> Sample:
- assert not args.partial_rollout, "Partial rollout is not supported for this function at the moment."
+def _count_completed_model_turns(response: str) -> int:
+ return len(re.findall(r"(?:search|answer)>", response))
+
+
+def _get_completed_model_turns(sample: Sample) -> int:
+ turn_count = sample.metadata.get(_SEARCH_R1_TURN_COUNT_KEY)
+ if isinstance(turn_count, int):
+ return turn_count
+ return _count_completed_model_turns(sample.response or "")
+
+
+def _mark_model_turn_completed(sample: Sample, completed_turns: int) -> None:
+    """Store ``completed_turns`` in ``sample.metadata`` under ``_SEARCH_R1_TURN_COUNT_KEY``.
+
+    Any previously stored value for that key is overwritten.
+    """
+    sample.metadata[_SEARCH_R1_TURN_COUNT_KEY] = completed_turns
+
+
+def _append_trainable_response_tokens(
+    args,
+    sample: Sample,
+    *,
+    tokens: list[int],
+    log_probs: list[float] | None,
+    meta_info: dict,
+    text: str,
+) -> None:
+    """Append one model-generated chunk to ``sample`` as trainable tokens.
+
+    Args:
+        args: Passed through unchanged to ``sample.append_response_tokens``.
+        sample: The sample being extended; mutated in place.
+        tokens: Token ids of the new chunk.
+        log_probs: Per-token log probs for the chunk, or ``None`` when log
+            probs are not being collected.
+        meta_info: Engine metadata for the chunk; ``meta_info["finish_reason"]["type"]``
+            is read on the ``log_probs is None`` path.
+        text: Decoded text of the chunk.
+    """
+    # With log probs available, delegate all bookkeeping to the Sample method.
+    if log_probs is not None:
+        sample.append_response_tokens(
+            args,
+            tokens=tokens,
+            log_probs=log_probs,
+            trainable=True,
+            meta_info=meta_info,
+            text=text,
+        )
+        return
+
+    # No log probs: update the sample fields by hand.
+    # NOTE(review): presumably append_response_tokens cannot be used for
+    # trainable tokens without log probs — confirm against its implementation.
+    sample.response += text
+    sample.tokens += tokens
+    sample.response_length += len(tokens)
+    if sample.loss_mask is None:
+        sample.loss_mask = []
+    # Every token in this chunk is marked trainable.
+    sample.loss_mask += [1] * len(tokens)
+    # Map the engine finish reason onto a Sample status; any other finish
+    # reason leaves the current status untouched.
+    match meta_info["finish_reason"]["type"]:
+        case "length":
+            sample.status = Sample.Status.TRUNCATED
+        case "abort":
+            sample.status = Sample.Status.ABORTED
+        case "stop":
+            sample.status = Sample.Status.COMPLETED
+
+async def generate(args, sample: Sample, sampling_params) -> Sample:
state = GenerateState(args)
url = f"http://{args.sglang_router_ip}:{args.sglang_router_port}/generate"
- # Handle partial rollout samples: continue generation from existing response
prompt_text = sample.prompt
prompt_tokens_ids = state.tokenizer(prompt_text, add_special_tokens=False)["input_ids"]
- sample.tokens = list(prompt_tokens_ids)
- sample.loss_mask = []
- response = ""
- response_token_ids = []
- loss_mask = []
- rollout_log_probs = [] if SEARCH_R1_CONFIGS["return_logprob"] else None
- sample.rollout_top_p_token_ids = None
- sample.rollout_top_p_token_offsets = None
+ if not sample.tokens:
+ sample.tokens = list(prompt_tokens_ids)
+ if sample.loss_mask is None:
+ sample.loss_mask = [1] * sample.response_length
+ elif args.partial_rollout and args.mask_offpolicy_in_partial_rollout and sample.response_length > 0:
+ sample.loss_mask = [0] * sample.response_length
+ if SEARCH_R1_CONFIGS["return_logprob"] and sample.rollout_log_probs is None and sample.response_length > 0:
+ sample.rollout_log_probs = [0.0] * sample.response_length
+
+ response = sample.response or ""
# BUGFIX: make the inference engine STOP at the tool/answer boundary.
# Without a stop, sglang keeps emitting tokens after </search> / </answer>
@@ -176,7 +231,17 @@ async def generate(args, sample: Sample, sampling_params) -> Sample:
_existing_stop = [_existing_stop]
sampling_params = {**sampling_params, "stop": list(dict.fromkeys([*_existing_stop, *_stop_tags]))}
- for _turn_idx in range(SEARCH_R1_CONFIGS["max_turns"]):
+ output = None
+ completed_turns = _get_completed_model_turns(sample) if args.partial_rollout else 0
+ if args.partial_rollout and _last_prediction_action(response) == "answer":
+ sample.status = Sample.Status.COMPLETED
+ return sample
+
+ for _turn_idx in range(completed_turns, SEARCH_R1_CONFIGS["max_turns"]):
+ if state.aborted:
+ sample.status = Sample.Status.ABORTED
+ return sample
+
payload = {
"text": prompt_text + response,
"sampling_params": sampling_params,
@@ -187,14 +252,14 @@ async def generate(args, sample: Sample, sampling_params) -> Sample:
output = await post(url, payload)
- # abort
- if output["meta_info"]["finish_reason"]["type"] == "abort":
+ cur_response = output["text"]
+ finish_reason = output["meta_info"]["finish_reason"]["type"]
+ if finish_reason == "abort":
sample.status = Sample.Status.ABORTED
return sample
- cur_response = output["text"]
-
# Extract tokens and log probs based on configuration
+ cur_response_log_probs = None
if SEARCH_R1_CONFIGS["return_logprob"]:
# Extract log probs from output - required for TIS metrics
if "output_token_logprobs" not in output["meta_info"]:
@@ -215,23 +280,21 @@ async def generate(args, sample: Sample, sampling_params) -> Sample:
cur_response_token_ids = state.tokenizer(cur_response, add_special_tokens=False)["input_ids"]
response += cur_response
- response_token_ids += cur_response_token_ids
- loss_mask += [1] * len(cur_response_token_ids)
- sample.append_response_tokens(
+ _append_trainable_response_tokens(
args,
+ sample,
tokens=cur_response_token_ids,
log_probs=cur_response_log_probs if SEARCH_R1_CONFIGS["return_logprob"] else None,
- trainable=True,
- meta_info=output["meta_info"] if "output_token_logprobs" in output["meta_info"] else None,
+ meta_info=output["meta_info"],
+ text=cur_response,
)
- # Add log probs if enabled
- if SEARCH_R1_CONFIGS["return_logprob"]:
- rollout_log_probs += cur_response_log_probs
-
- if output["meta_info"]["finish_reason"]["type"] == "length":
+ if finish_reason == "length":
break
+ if finish_reason == "stop":
+ _mark_model_turn_completed(sample, _turn_idx + 1)
+
next_obs, done = await execute_predictions(cur_response)
if done:
break
@@ -239,29 +302,27 @@ async def generate(args, sample: Sample, sampling_params) -> Sample:
assert next_obs != "", "Next observation should not be empty."
obs_tokens_ids = state.tokenizer(next_obs, add_special_tokens=False)["input_ids"]
response += next_obs
- response_token_ids += obs_tokens_ids
- loss_mask += [0] * len(obs_tokens_ids)
- sample.append_response_tokens(args, tokens=obs_tokens_ids, trainable=False)
+ sample.append_response_tokens(args, tokens=obs_tokens_ids, trainable=False, text=next_obs)
- # Add dummy log probs for observation tokens if enabled (they won't be used due to loss_mask=0)
+ # Verify alignment when collecting log probs. Observation tokens receive dummy
+ # log probs inside append_response_tokens because loss_mask marks them non-trainable.
if SEARCH_R1_CONFIGS["return_logprob"]:
- rollout_log_probs += [0.0] * len(obs_tokens_ids)
-
- # Verify alignment when collecting log probs
- assert len(response_token_ids) == len(
- rollout_log_probs
- ), f"Token/logp length mismatch: {len(response_token_ids)} tokens vs {len(rollout_log_probs)} logps"
-
- # Store statistics for wandb logging
- sample.tokens = prompt_tokens_ids + response_token_ids
- sample.response_length = len(response_token_ids)
- sample.response = response
- sample.loss_mask = loss_mask
- sample.prompt = prompt_text
+ assert sample.rollout_log_probs is not None
+ assert len(sample.rollout_log_probs) == sample.response_length, (
+ f"Token/logp length mismatch: {sample.response_length} tokens vs "
+ f"{len(sample.rollout_log_probs)} logps"
+ )
- # Store log probs if enabled
- if SEARCH_R1_CONFIGS["return_logprob"]:
- sample.rollout_log_probs = rollout_log_probs if rollout_log_probs else None
+ sample.prompt = prompt_text
+ if output is None:
+ action = _last_prediction_action(response)
+ sample.status = Sample.Status.COMPLETED if action == "answer" else Sample.Status.TRUNCATED
+ return sample
+
+ action = _last_prediction_action(response)
+ if action != "answer" and output["meta_info"]["finish_reason"]["type"] == "stop":
+ sample.status = Sample.Status.TRUNCATED
+ return sample
match output["meta_info"]["finish_reason"]["type"]:
case "length":
diff --git a/tests/test_search_r1_partial_rollout.py b/tests/test_search_r1_partial_rollout.py
new file mode 100644
index 0000000000..01b8f41835
--- /dev/null
+++ b/tests/test_search_r1_partial_rollout.py
@@ -0,0 +1,249 @@
+from __future__ import annotations
+
+import asyncio
+import sys
+import types
+from pathlib import Path
+from types import SimpleNamespace
+
+import pytest
+
+# Put the repository root and the search-r1 example directory at the front of
+# sys.path so the imports further down resolve.
+REPO_ROOT = Path(__file__).resolve().parents[1]
+SEARCH_R1_DIR = REPO_ROOT / "examples" / "search-r1"
+for path in (REPO_ROOT, SEARCH_R1_DIR):
+    if str(path) not in sys.path:
+        sys.path.insert(0, str(path))
+
+# Register minimal stand-in modules for packages that have not been imported
+# yet. The check is against sys.modules only, so a stub is installed even when
+# the real package is installed but not yet loaded.
+if "sglang_router" not in sys.modules:
+    sglang_router_stub = types.ModuleType("sglang_router")
+    sglang_router_stub.__version__ = "0.2.3"
+    sys.modules["sglang_router"] = sglang_router_stub
+
+if "ray" not in sys.modules:
+    ray_stub = types.ModuleType("ray")
+    ray_stub._private = types.SimpleNamespace(services=types.SimpleNamespace(get_node_ip_address=lambda: "127.0.0.1"))
+    sys.modules["ray"] = ray_stub
+
+if "transformers" not in sys.modules:
+    transformers_stub = types.ModuleType("transformers")
+    # Each listed name becomes an empty placeholder class on the stub module.
+    for name in ("AutoProcessor", "AutoTokenizer", "PreTrainedTokenizerBase", "ProcessorMixin"):
+        setattr(transformers_stub, name, type(name, (), {}))
+    sys.modules["transformers"] = transformers_stub
+
+# Temporarily replace slime.rollout.sglang_rollout with a stub exposing only
+# GenerateState while generate_with_search is imported, then restore whatever
+# was registered in sys.modules before (or remove the stub if nothing was).
+sglang_rollout_stub = types.ModuleType("slime.rollout.sglang_rollout")
+
+
+class _StubGenerateState:
+    """Empty placeholder exposed as ``GenerateState`` on the stub module."""
+
+    pass
+
+
+sglang_rollout_stub.GenerateState = _StubGenerateState
+previous_sglang_rollout = sys.modules.get("slime.rollout.sglang_rollout")
+sys.modules["slime.rollout.sglang_rollout"] = sglang_rollout_stub
+
+import generate_with_search as search_gen # noqa: E402
+from slime.utils.types import Sample # noqa: E402
+
+if previous_sglang_rollout is None:
+    sys.modules.pop("slime.rollout.sglang_rollout", None)
+else:
+    sys.modules["slime.rollout.sglang_rollout"] = previous_sglang_rollout
+
+# NOTE(review): not referenced anywhere in this module; presumably read by the
+# test runner to schedule this file on a CPU-only worker — confirm.
+NUM_GPUS = 0
+
+
+class FakeTokenizer:
+ def __call__(self, text: str, add_special_tokens: bool = False):
+ return {"input_ids": self.encode(text, add_special_tokens=add_special_tokens)}
+
+ def encode(self, text: str, add_special_tokens: bool = False):
+ return [ord(ch) for ch in text]
+
+
+class FakeGenerateState:
+    """Test replacement for ``GenerateState`` holding only what these tests set up.
+
+    Exposes ``args``, a ``FakeTokenizer`` as ``tokenizer``, and ``aborted``
+    fixed to ``False``.
+    """
+
+    def __init__(self, args) -> None:
+        self.args = args
+        self.tokenizer = FakeTokenizer()
+        self.aborted = False
+
+
+def _args(*, partial_rollout: bool = True, mask_offpolicy_in_partial_rollout: bool = False):
+    """Build the ``SimpleNamespace`` passed as ``args`` to ``generate`` in these tests.
+
+    Only the two partial-rollout flags are configurable; the router address,
+    router port and speculative-algorithm setting are fixed values.
+    """
+    return SimpleNamespace(
+        partial_rollout=partial_rollout,
+        mask_offpolicy_in_partial_rollout=mask_offpolicy_in_partial_rollout,
+        sglang_router_ip="127.0.0.1",
+        sglang_router_port=1234,
+        sglang_speculative_algorithm=False,
+    )
+
+
+def _output(text: str, token_base: int, finish_reason: str = "stop"):
+ return {
+ "text": text,
+ "meta_info": {
+ "finish_reason": {"type": finish_reason},
+ "output_token_logprobs": [[-0.1 * (i + 1), token_base + i] for i in range(len(text))],
+ },
+ }
+
+
+@pytest.fixture(autouse=True)
+def patch_generate_state(monkeypatch):
+    """Autouse fixture that swaps the module's collaborators for fakes in every test.
+
+    Replaces ``GenerateState`` and ``execute_predictions`` on the module under
+    test, and sets ``return_logprob`` to ``True`` and ``max_turns`` to 2 in
+    ``SEARCH_R1_CONFIGS``.
+    """
+    monkeypatch.setattr(search_gen, "GenerateState", FakeGenerateState)
+    monkeypatch.setitem(search_gen.SEARCH_R1_CONFIGS, "return_logprob", True)
+    monkeypatch.setitem(search_gen.SEARCH_R1_CONFIGS, "max_turns", 2)
+    monkeypatch.setattr(search_gen, "execute_predictions", _fake_execute_predictions)
+
+
+async def _fake_execute_predictions(prediction: str):
+ action, content = search_gen.postprocess_predictions(prediction)
+ if action == "search":
+ return f"\n\nresult for {content}\n\n", False
+ if action == "answer":
+ return "", True
+ return "\ninvalid action\n", False
+
+
+def test_search_r1_fresh_rollout_records_turn_state_and_keeps_logprob_alignment(monkeypatch):
+    """Run a fresh sample through ``generate`` with two queued engine outputs.
+
+    Checks that the sample ends COMPLETED with a stored turn count of 2, that
+    ``loss_mask`` and ``rollout_log_probs`` both match ``response_length``, and
+    that the second request carries the prompt plus the accumulated response.
+    """
+    # NOTE(review): the string literals below look like they lost their tag
+    # markup (e.g. <search>...</search>) — confirm against the original patch.
+    calls = []
+    outputs = [
+        _output("cats", 100),
+        _output("cats", 200),
+    ]
+
+    async def fake_post(_url, payload):
+        calls.append(payload)
+        return outputs.pop(0)
+
+    monkeypatch.setattr(search_gen, "post", fake_post)
+    sample = Sample(prompt="Question: cats\n", status=Sample.Status.PENDING)
+
+    result = asyncio.run(search_gen.generate(_args(), sample, {"stop": ""}))
+
+    assert result.status is Sample.Status.COMPLETED
+    assert result.metadata[search_gen._SEARCH_R1_TURN_COUNT_KEY] == 2
+    assert result.response == (
+        "cats\n\n" "result for cats\n\n" "cats"
+    )
+    assert len(result.loss_mask) == result.response_length
+    assert len(result.rollout_log_probs) == result.response_length
+    assert sum(result.loss_mask) == len("cats") + len("cats")
+    assert calls[0]["text"] == "Question: cats\n"
+    assert (
+        calls[1]["text"] == "Question: cats\ncats\n\nresult for cats\n\n"
+    )
+    assert calls[0]["sampling_params"]["stop"] == ["", "", ""]
+
+
+def test_search_r1_partial_rollout_resumes_without_clearing_existing_trajectory(monkeypatch):
+    """Resume an ABORTED sample that already carries a response and a turn count of 1.
+
+    Runs with ``mask_offpolicy_in_partial_rollout=True`` and checks that the
+    existing tokens and log probs survive as a prefix, the old span stays
+    masked with zeros, the newly generated span is masked with ones, the turn
+    count becomes 2, and exactly one request is sent whose text is the old
+    prompt plus the old response.
+    """
+    # NOTE(review): the string literals below look like they lost their tag
+    # markup (e.g. <search>...</search>) — confirm against the original patch.
+    old_response = "cats\n\nresult for cats\n\n"
+    old_prompt = "Question: cats\n"
+    old_prompt_tokens = FakeTokenizer().encode(old_prompt)
+    old_response_tokens = [10] * len(old_response)
+    old_log_probs = [-0.7] * len(old_response)
+    calls = []
+
+    async def fake_post(_url, payload):
+        calls.append(payload)
+        return _output("cats", 300)
+
+    monkeypatch.setattr(search_gen, "post", fake_post)
+    sample = Sample(
+        prompt=old_prompt,
+        tokens=old_prompt_tokens + old_response_tokens,
+        response=old_response,
+        response_length=len(old_response),
+        loss_mask=[0] * len(old_response),
+        rollout_log_probs=list(old_log_probs),
+        status=Sample.Status.ABORTED,
+        metadata={search_gen._SEARCH_R1_TURN_COUNT_KEY: 1},
+    )
+
+    result = asyncio.run(search_gen.generate(_args(mask_offpolicy_in_partial_rollout=True), sample, {}))
+
+    assert result.status is Sample.Status.COMPLETED
+    assert result.response == old_response + "cats"
+    assert (
+        result.tokens[: len(old_prompt_tokens) + len(old_response_tokens)] == old_prompt_tokens + old_response_tokens
+    )
+    assert result.rollout_log_probs[: len(old_log_probs)] == old_log_probs
+    assert result.loss_mask[: len(old_response)] == [0] * len(old_response)
+    assert result.loss_mask[len(old_response) :] == [1] * len("cats")
+    assert len(result.rollout_log_probs) == result.response_length
+    assert result.metadata[search_gen._SEARCH_R1_TURN_COUNT_KEY] == 2
+    assert calls == [
+        {
+            "text": old_prompt + old_response,
+            "sampling_params": {"stop": ["", ""]},
+            "return_logprob": True,
+        }
+    ]
+
+
+def test_search_r1_partial_rollout_falls_back_to_response_tags_for_older_samples(monkeypatch):
+    """Resume an ABORTED sample whose metadata carries no stored turn count.
+
+    Checks that the sample ends COMPLETED after exactly one request and that a
+    turn count of 2 is stored afterwards.
+    """
+    # NOTE(review): the string literals below look like they lost their tag
+    # markup (e.g. <search>...</search>) — confirm against the original patch.
+    old_response = "cats\n\nresult for cats\n\n"
+    calls = []
+
+    async def fake_post(_url, payload):
+        calls.append(payload)
+        return _output("cats", 400)
+
+    monkeypatch.setattr(search_gen, "post", fake_post)
+    sample = Sample(
+        prompt="Question: cats\n",
+        response=old_response,
+        response_length=len(old_response),
+        tokens=FakeTokenizer().encode("Question: cats\n") + [1] * len(old_response),
+        loss_mask=[0] * len(old_response),
+        rollout_log_probs=[0.0] * len(old_response),
+        status=Sample.Status.ABORTED,
+    )
+
+    result = asyncio.run(search_gen.generate(_args(), sample, {}))
+
+    assert result.status is Sample.Status.COMPLETED
+    assert len(calls) == 1
+    assert result.metadata[search_gen._SEARCH_R1_TURN_COUNT_KEY] == 2
+
+
+def test_search_r1_partial_rollout_with_completed_answer_does_not_generate(monkeypatch):
+    """Resume an ABORTED sample and require that no engine request is made.
+
+    ``post`` is replaced with a function that raises ``AssertionError`` when
+    called; the sample must come back COMPLETED with its response unchanged.
+    """
+    # NOTE(review): the response literal below looks like it lost its tag
+    # markup (e.g. <answer>...</answer>) — confirm against the original patch.
+    async def fake_post(_url, _payload):
+        raise AssertionError("completed partial answer should not call SGLang")
+
+    monkeypatch.setattr(search_gen, "post", fake_post)
+    sample = Sample(
+        prompt="Question: cats\n",
+        response="cats",
+        response_length=len("cats"),
+        tokens=FakeTokenizer().encode("Question: cats\n") + [1] * len("cats"),
+        loss_mask=[1] * len("cats"),
+        rollout_log_probs=[-0.1] * len("cats"),
+        status=Sample.Status.ABORTED,
+        metadata={search_gen._SEARCH_R1_TURN_COUNT_KEY: 1},
+    )
+
+    result = asyncio.run(search_gen.generate(_args(), sample, {}))
+
+    assert result.status is Sample.Status.COMPLETED
+    assert result.response == "cats"
+
+
+def test_search_r1_partial_rollout_supports_non_logprob_mode(monkeypatch):
+    """Run a fresh sample with ``return_logprob`` switched off.
+
+    The fake engine output carries no ``output_token_logprobs``. Checks that
+    the sample ends COMPLETED, ``rollout_log_probs`` stays ``None``, and
+    ``loss_mask`` is all ones with the same length as the response.
+    """
+    # NOTE(review): the string literals below look like they lost their tag
+    # markup (e.g. <answer>...</answer>) — confirm against the original patch.
+    monkeypatch.setitem(search_gen.SEARCH_R1_CONFIGS, "return_logprob", False)
+
+    async def fake_post(_url, _payload):
+        return {
+            "text": "cats trailing junk",
+            "meta_info": {"finish_reason": {"type": "stop"}},
+        }
+
+    monkeypatch.setattr(search_gen, "post", fake_post)
+    sample = Sample(prompt="Question: cats\n", status=Sample.Status.PENDING)
+
+    result = asyncio.run(search_gen.generate(_args(), sample, {}))
+
+    assert result.status is Sample.Status.COMPLETED
+    assert result.response == "cats"
+    assert result.rollout_log_probs is None
+    assert len(result.loss_mask) == result.response_length
+    assert result.loss_mask == [1] * len("cats")