diff --git a/examples/search-r1/README.md b/examples/search-r1/README.md index 973030e9eb..c9f87337ab 100644 --- a/examples/search-r1/README.md +++ b/examples/search-r1/README.md @@ -176,6 +176,17 @@ CUSTOM_ARGS=( These are the `generate` and `reward_func` functions in `generate_with_search.py`. +### Partial Rollout + +Search-R1 also supports slime's partial rollout path for long-tail search trajectories: + +```bash +--partial-rollout +--mask-offpolicy-in-partial-rollout +``` + +When a rollout round aborts unfinished requests, Search-R1 keeps the generated response, search observations, loss mask, and rollout log probabilities on the `Sample`. The next rollout resumes from that context instead of starting the prompt over. With `--mask-offpolicy-in-partial-rollout`, tokens generated before the latest weight update are kept as context but masked out of training. + ## Appendix: Setting up Local Retriever This section provides detailed instructions for setting up the local dense retriever for use with the local search backend. diff --git a/examples/search-r1/README_zh.md b/examples/search-r1/README_zh.md index f7d9fbb0d9..891c0bf504 100644 --- a/examples/search-r1/README_zh.md +++ b/examples/search-r1/README_zh.md @@ -176,6 +176,17 @@ CUSTOM_ARGS=( 也就是 `generate_with_search.py` 中的 `generate` 和 `reward_func` 两个函数。 +### Partial Rollout + +Search-R1 也支持 slime 的 partial rollout 路径,适合耗时较长、容易出现长尾的搜索轨迹: + +```bash +--partial-rollout +--mask-offpolicy-in-partial-rollout +``` + +当某轮 rollout 中还有未完成请求被 abort 时,Search-R1 会把已经生成的 response、搜索 observation、loss mask 和 rollout log probabilities 保存在 `Sample` 上。下一轮 rollout 会从已有上下文继续生成,而不是从原始 prompt 重新开始。开启 `--mask-offpolicy-in-partial-rollout` 后,旧权重生成的 token 会继续作为上下文保留,但不会参与训练。 + ## 附录:配置本地检索器 本节提供详细的本地密集检索器设置说明,用于本地搜索后端。 diff --git a/examples/search-r1/generate_with_search.py b/examples/search-r1/generate_with_search.py index ae0f18dd15..aa36297f6c 100644 --- a/examples/search-r1/generate_with_search.py +++ b/examples/search-r1/generate_with_search.py @@ -39,6 +39,7 @@ SEMAPHORE = asyncio.Semaphore(SEARCH_R1_CONFIGS["search_concurrency"]) +_SEARCH_R1_TURN_COUNT_KEY = "search_r1_completed_turns" def _passages2string(retrieval_result): @@ -121,6 +122,11 @@ def postprocess_predictions(prediction: str): return action, content +def _last_prediction_action(prediction: str) -> str | None: + matches = re.findall(r"<(search|answer)>.*?", prediction, re.DOTALL) + return matches[-1] if matches else None + + async def execute_predictions(prediction: str) -> str: action, content = postprocess_predictions(prediction) @@ -142,24 +148,73 @@ async def execute_predictions(prediction: str) -> str: return next_obs, done -async def generate(args, sample: Sample, sampling_params) -> Sample: - assert not args.partial_rollout, "Partial rollout is not supported for this function at the moment." +def _count_completed_model_turns(response: str) -> int: + return len(re.findall(r"", response)) + + +def _get_completed_model_turns(sample: Sample) -> int: + turn_count = sample.metadata.get(_SEARCH_R1_TURN_COUNT_KEY) + if isinstance(turn_count, int): + return turn_count + return _count_completed_model_turns(sample.response or "") + + +def _mark_model_turn_completed(sample: Sample, completed_turns: int) -> None: + sample.metadata[_SEARCH_R1_TURN_COUNT_KEY] = completed_turns + + +def _append_trainable_response_tokens( + args, + sample: Sample, + *, + tokens: list[int], + log_probs: list[float] | None, + meta_info: dict, + text: str, +) -> None: + if log_probs is not None: + sample.append_response_tokens( + args, + tokens=tokens, + log_probs=log_probs, + trainable=True, + meta_info=meta_info, + text=text, + ) + return + + sample.response += text + sample.tokens += tokens + sample.response_length += len(tokens) + if sample.loss_mask is None: + sample.loss_mask = [] + sample.loss_mask += [1] * len(tokens) + match meta_info["finish_reason"]["type"]: + case "length": + sample.status = Sample.Status.TRUNCATED + case "abort": + sample.status = Sample.Status.ABORTED + case "stop": + sample.status = Sample.Status.COMPLETED + +async def generate(args, sample: Sample, sampling_params) -> Sample: state = GenerateState(args) url = f"http://{args.sglang_router_ip}:{args.sglang_router_port}/generate" - # Handle partial rollout samples: continue generation from existing response prompt_text = sample.prompt prompt_tokens_ids = state.tokenizer(prompt_text, add_special_tokens=False)["input_ids"] - sample.tokens = list(prompt_tokens_ids) - sample.loss_mask = [] - response = "" - response_token_ids = [] - loss_mask = [] - rollout_log_probs = [] if SEARCH_R1_CONFIGS["return_logprob"] else None - sample.rollout_top_p_token_ids = None - sample.rollout_top_p_token_offsets = None + if not sample.tokens: + sample.tokens = list(prompt_tokens_ids) + if sample.loss_mask is None: + sample.loss_mask = [1] * sample.response_length + elif args.partial_rollout and args.mask_offpolicy_in_partial_rollout and sample.response_length > 0: + sample.loss_mask = [0] * sample.response_length + if SEARCH_R1_CONFIGS["return_logprob"] and sample.rollout_log_probs is None and sample.response_length > 0: + sample.rollout_log_probs = [0.0] * sample.response_length + + response = sample.response or "" # BUGFIX: make the inference engine STOP at the tool/answer boundary. # Without a stop, sglang keeps emitting tokens after / @@ -176,7 +231,17 @@ async def generate(args, sample: Sample, sampling_params) -> Sample: _existing_stop = [_existing_stop] sampling_params = {**sampling_params, "stop": list(dict.fromkeys([*_existing_stop, *_stop_tags]))} - for _turn_idx in range(SEARCH_R1_CONFIGS["max_turns"]): + output = None + completed_turns = _get_completed_model_turns(sample) if args.partial_rollout else 0 + if args.partial_rollout and _last_prediction_action(response) == "answer": + sample.status = Sample.Status.COMPLETED + return sample + + for _turn_idx in range(completed_turns, SEARCH_R1_CONFIGS["max_turns"]): + if state.aborted: + sample.status = Sample.Status.ABORTED + return sample + payload = { "text": prompt_text + response, "sampling_params": sampling_params, @@ -187,14 +252,14 @@ async def generate(args, sample: Sample, sampling_params) -> Sample: output = await post(url, payload) - # abort - if output["meta_info"]["finish_reason"]["type"] == "abort": + cur_response = output["text"] + finish_reason = output["meta_info"]["finish_reason"]["type"] + if finish_reason == "abort": sample.status = Sample.Status.ABORTED return sample - cur_response = output["text"] - # Extract tokens and log probs based on configuration + cur_response_log_probs = None if SEARCH_R1_CONFIGS["return_logprob"]: # Extract log probs from output - required for TIS metrics if "output_token_logprobs" not in output["meta_info"]: @@ -215,23 +280,21 @@ async def generate(args, sample: Sample, sampling_params) -> Sample: cur_response_token_ids = state.tokenizer(cur_response, add_special_tokens=False)["input_ids"] response += cur_response - response_token_ids += cur_response_token_ids - loss_mask += [1] * len(cur_response_token_ids) - sample.append_response_tokens( + _append_trainable_response_tokens( args, + sample, tokens=cur_response_token_ids, log_probs=cur_response_log_probs if SEARCH_R1_CONFIGS["return_logprob"] else None, - trainable=True, - meta_info=output["meta_info"] if "output_token_logprobs" in output["meta_info"] else None, + meta_info=output["meta_info"], + text=cur_response, ) - # Add log probs if enabled - if SEARCH_R1_CONFIGS["return_logprob"]: - rollout_log_probs += cur_response_log_probs - - if output["meta_info"]["finish_reason"]["type"] == "length": + if finish_reason == "length": break + if finish_reason == "stop": + _mark_model_turn_completed(sample, _turn_idx + 1) + next_obs, done = await execute_predictions(cur_response) if done: break @@ -239,29 +302,27 @@ async def generate(args, sample: Sample, sampling_params) -> Sample: assert next_obs != "", "Next observation should not be empty." obs_tokens_ids = state.tokenizer(next_obs, add_special_tokens=False)["input_ids"] response += next_obs - response_token_ids += obs_tokens_ids - loss_mask += [0] * len(obs_tokens_ids) - sample.append_response_tokens(args, tokens=obs_tokens_ids, trainable=False) + sample.append_response_tokens(args, tokens=obs_tokens_ids, trainable=False, text=next_obs) - # Add dummy log probs for observation tokens if enabled (they won't be used due to loss_mask=0) + # Verify alignment when collecting log probs. Observation tokens receive dummy + # log probs inside append_response_tokens because loss_mask marks them non-trainable. if SEARCH_R1_CONFIGS["return_logprob"]: - rollout_log_probs += [0.0] * len(obs_tokens_ids) - - # Verify alignment when collecting log probs - assert len(response_token_ids) == len( - rollout_log_probs - ), f"Token/logp length mismatch: {len(response_token_ids)} tokens vs {len(rollout_log_probs)} logps" - - # Store statistics for wandb logging - sample.tokens = prompt_tokens_ids + response_token_ids - sample.response_length = len(response_token_ids) - sample.response = response - sample.loss_mask = loss_mask - sample.prompt = prompt_text + assert sample.rollout_log_probs is not None + assert len(sample.rollout_log_probs) == sample.response_length, ( + f"Token/logp length mismatch: {sample.response_length} tokens vs " + f"{len(sample.rollout_log_probs)} logps" + ) - # Store log probs if enabled - if SEARCH_R1_CONFIGS["return_logprob"]: - sample.rollout_log_probs = rollout_log_probs if rollout_log_probs else None + sample.prompt = prompt_text + if output is None: + action = _last_prediction_action(response) + sample.status = Sample.Status.COMPLETED if action == "answer" else Sample.Status.TRUNCATED + return sample + + action = _last_prediction_action(response) + if action != "answer" and output["meta_info"]["finish_reason"]["type"] == "stop": + sample.status = Sample.Status.TRUNCATED + return sample match output["meta_info"]["finish_reason"]["type"]: case "length": diff --git a/tests/test_search_r1_partial_rollout.py b/tests/test_search_r1_partial_rollout.py new file mode 100644 index 0000000000..01b8f41835 --- /dev/null +++ b/tests/test_search_r1_partial_rollout.py @@ -0,0 +1,249 @@ +from __future__ import annotations + +import asyncio +import sys +import types +from pathlib import Path +from types import SimpleNamespace + +import pytest + +REPO_ROOT = Path(__file__).resolve().parents[1] +SEARCH_R1_DIR = REPO_ROOT / "examples" / "search-r1" +for path in (REPO_ROOT, SEARCH_R1_DIR): + if str(path) not in sys.path: + sys.path.insert(0, str(path)) + +if "sglang_router" not in sys.modules: + sglang_router_stub = types.ModuleType("sglang_router") + sglang_router_stub.__version__ = "0.2.3" + sys.modules["sglang_router"] = sglang_router_stub + +if "ray" not in sys.modules: + ray_stub = types.ModuleType("ray") + ray_stub._private = types.SimpleNamespace(services=types.SimpleNamespace(get_node_ip_address=lambda: "127.0.0.1")) + sys.modules["ray"] = ray_stub + +if "transformers" not in sys.modules: + transformers_stub = types.ModuleType("transformers") + for name in ("AutoProcessor", "AutoTokenizer", "PreTrainedTokenizerBase", "ProcessorMixin"): + setattr(transformers_stub, name, type(name, (), {})) + sys.modules["transformers"] = transformers_stub + +sglang_rollout_stub = types.ModuleType("slime.rollout.sglang_rollout") + + +class _StubGenerateState: + pass + + +sglang_rollout_stub.GenerateState = _StubGenerateState +previous_sglang_rollout = sys.modules.get("slime.rollout.sglang_rollout") +sys.modules["slime.rollout.sglang_rollout"] = sglang_rollout_stub + +import generate_with_search as search_gen # noqa: E402 +from slime.utils.types import Sample # noqa: E402 + +if previous_sglang_rollout is None: + sys.modules.pop("slime.rollout.sglang_rollout", None) +else: + sys.modules["slime.rollout.sglang_rollout"] = previous_sglang_rollout + +NUM_GPUS = 0 + + +class FakeTokenizer: + def __call__(self, text: str, add_special_tokens: bool = False): + return {"input_ids": self.encode(text, add_special_tokens=add_special_tokens)} + + def encode(self, text: str, add_special_tokens: bool = False): + return [ord(ch) for ch in text] + + +class FakeGenerateState: + def __init__(self, args) -> None: + self.args = args + self.tokenizer = FakeTokenizer() + self.aborted = False + + +def _args(*, partial_rollout: bool = True, mask_offpolicy_in_partial_rollout: bool = False): + return SimpleNamespace( + partial_rollout=partial_rollout, + mask_offpolicy_in_partial_rollout=mask_offpolicy_in_partial_rollout, + sglang_router_ip="127.0.0.1", + sglang_router_port=1234, + sglang_speculative_algorithm=False, + ) + + +def _output(text: str, token_base: int, finish_reason: str = "stop"): + return { + "text": text, + "meta_info": { + "finish_reason": {"type": finish_reason}, + "output_token_logprobs": [[-0.1 * (i + 1), token_base + i] for i in range(len(text))], + }, + } + + +@pytest.fixture(autouse=True) +def patch_generate_state(monkeypatch): + monkeypatch.setattr(search_gen, "GenerateState", FakeGenerateState) + monkeypatch.setitem(search_gen.SEARCH_R1_CONFIGS, "return_logprob", True) + monkeypatch.setitem(search_gen.SEARCH_R1_CONFIGS, "max_turns", 2) + monkeypatch.setattr(search_gen, "execute_predictions", _fake_execute_predictions) + + +async def _fake_execute_predictions(prediction: str): + action, content = search_gen.postprocess_predictions(prediction) + if action == "search": + return f"\n\nresult for {content}\n\n", False + if action == "answer": + return "", True + return "\ninvalid action\n", False + + +def test_search_r1_fresh_rollout_records_turn_state_and_keeps_logprob_alignment(monkeypatch): + calls = [] + outputs = [ + _output("cats", 100), + _output("cats", 200), + ] + + async def fake_post(_url, payload): + calls.append(payload) + return outputs.pop(0) + + monkeypatch.setattr(search_gen, "post", fake_post) + sample = Sample(prompt="Question: cats\n", status=Sample.Status.PENDING) + + result = asyncio.run(search_gen.generate(_args(), sample, {"stop": ""})) + + assert result.status is Sample.Status.COMPLETED + assert result.metadata[search_gen._SEARCH_R1_TURN_COUNT_KEY] == 2 + assert result.response == ( + "cats\n\n" "result for cats\n\n" "cats" + ) + assert len(result.loss_mask) == result.response_length + assert len(result.rollout_log_probs) == result.response_length + assert sum(result.loss_mask) == len("cats") + len("cats") + assert calls[0]["text"] == "Question: cats\n" + assert ( + calls[1]["text"] == "Question: cats\ncats\n\nresult for cats\n\n" + ) + assert calls[0]["sampling_params"]["stop"] == ["", "", ""] + + +def test_search_r1_partial_rollout_resumes_without_clearing_existing_trajectory(monkeypatch): + old_response = "cats\n\nresult for cats\n\n" + old_prompt = "Question: cats\n" + old_prompt_tokens = FakeTokenizer().encode(old_prompt) + old_response_tokens = [10] * len(old_response) + old_log_probs = [-0.7] * len(old_response) + calls = [] + + async def fake_post(_url, payload): + calls.append(payload) + return _output("cats", 300) + + monkeypatch.setattr(search_gen, "post", fake_post) + sample = Sample( + prompt=old_prompt, + tokens=old_prompt_tokens + old_response_tokens, + response=old_response, + response_length=len(old_response), + loss_mask=[0] * len(old_response), + rollout_log_probs=list(old_log_probs), + status=Sample.Status.ABORTED, + metadata={search_gen._SEARCH_R1_TURN_COUNT_KEY: 1}, + ) + + result = asyncio.run(search_gen.generate(_args(mask_offpolicy_in_partial_rollout=True), sample, {})) + + assert result.status is Sample.Status.COMPLETED + assert result.response == old_response + "cats" + assert ( + result.tokens[: len(old_prompt_tokens) + len(old_response_tokens)] == old_prompt_tokens + old_response_tokens + ) + assert result.rollout_log_probs[: len(old_log_probs)] == old_log_probs + assert result.loss_mask[: len(old_response)] == [0] * len(old_response) + assert result.loss_mask[len(old_response) :] == [1] * len("cats") + assert len(result.rollout_log_probs) == result.response_length + assert result.metadata[search_gen._SEARCH_R1_TURN_COUNT_KEY] == 2 + assert calls == [ + { + "text": old_prompt + old_response, + "sampling_params": {"stop": ["", ""]}, + "return_logprob": True, + } + ] + + +def test_search_r1_partial_rollout_falls_back_to_response_tags_for_older_samples(monkeypatch): + old_response = "cats\n\nresult for cats\n\n" + calls = [] + + async def fake_post(_url, payload): + calls.append(payload) + return _output("cats", 400) + + monkeypatch.setattr(search_gen, "post", fake_post) + sample = Sample( + prompt="Question: cats\n", + response=old_response, + response_length=len(old_response), + tokens=FakeTokenizer().encode("Question: cats\n") + [1] * len(old_response), + loss_mask=[0] * len(old_response), + rollout_log_probs=[0.0] * len(old_response), + status=Sample.Status.ABORTED, + ) + + result = asyncio.run(search_gen.generate(_args(), sample, {})) + + assert result.status is Sample.Status.COMPLETED + assert len(calls) == 1 + assert result.metadata[search_gen._SEARCH_R1_TURN_COUNT_KEY] == 2 + + +def test_search_r1_partial_rollout_with_completed_answer_does_not_generate(monkeypatch): + async def fake_post(_url, _payload): + raise AssertionError("completed partial answer should not call SGLang") + + monkeypatch.setattr(search_gen, "post", fake_post) + sample = Sample( + prompt="Question: cats\n", + response="cats", + response_length=len("cats"), + tokens=FakeTokenizer().encode("Question: cats\n") + [1] * len("cats"), + loss_mask=[1] * len("cats"), + rollout_log_probs=[-0.1] * len("cats"), + status=Sample.Status.ABORTED, + metadata={search_gen._SEARCH_R1_TURN_COUNT_KEY: 1}, + ) + + result = asyncio.run(search_gen.generate(_args(), sample, {})) + + assert result.status is Sample.Status.COMPLETED + assert result.response == "cats" + + +def test_search_r1_partial_rollout_supports_non_logprob_mode(monkeypatch): + monkeypatch.setitem(search_gen.SEARCH_R1_CONFIGS, "return_logprob", False) + + async def fake_post(_url, _payload): + return { + "text": "cats trailing junk", + "meta_info": {"finish_reason": {"type": "stop"}}, + } + + monkeypatch.setattr(search_gen, "post", fake_post) + sample = Sample(prompt="Question: cats\n", status=Sample.Status.PENDING) + + result = asyncio.run(search_gen.generate(_args(), sample, {})) + + assert result.status is Sample.Status.COMPLETED + assert result.response == "cats" + assert result.rollout_log_probs is None + assert len(result.loss_mask) == result.response_length + assert result.loss_mask == [1] * len("cats")