Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions examples/search-r1/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,17 @@ CUSTOM_ARGS=(

These are the `generate` and `reward_func` functions in `generate_with_search.py`.

### Partial Rollout

Search-R1 also supports slime's partial rollout path for long-tail search trajectories:

```bash
--partial-rollout
--mask-offpolicy-in-partial-rollout
```

When a rollout round aborts unfinished requests, Search-R1 keeps the generated response, search observations, loss mask, and rollout log probabilities on the `Sample`. The next rollout resumes generation from that saved context instead of restarting from the original prompt. With `--mask-offpolicy-in-partial-rollout`, tokens generated before the latest weight update are kept as context but masked out of training.

## Appendix: Setting up Local Retriever

This section provides detailed instructions for setting up the local dense retriever for use with the local search backend.
Expand Down
11 changes: 11 additions & 0 deletions examples/search-r1/README_zh.md
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,17 @@ CUSTOM_ARGS=(

也就是 `generate_with_search.py` 中的 `generate` 和 `reward_func` 两个函数。

### Partial Rollout

Search-R1 也支持 slime 的 partial rollout 路径,适合耗时较长、容易出现长尾的搜索轨迹:

```bash
--partial-rollout
--mask-offpolicy-in-partial-rollout
```

当某轮 rollout 中还有未完成请求被 abort 时,Search-R1 会把已经生成的 response、搜索 observation、loss mask 和 rollout log probabilities 保存在 `Sample` 上。下一轮 rollout 会从已有上下文继续生成,而不是从原始 prompt 重新开始。开启 `--mask-offpolicy-in-partial-rollout` 后,旧权重生成的 token 会继续作为上下文保留,但不会参与训练。

## 附录:配置本地检索器

本节提供详细的本地密集检索器设置说明,用于本地搜索后端。
Expand Down
153 changes: 107 additions & 46 deletions examples/search-r1/generate_with_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@


# Module-wide semaphore whose size comes from the "search_concurrency" config entry.
# NOTE(review): presumably bounds concurrent search/retrieval requests -- confirm at the call sites.
SEMAPHORE = asyncio.Semaphore(SEARCH_R1_CONFIGS["search_concurrency"])
# Key in `Sample.metadata` under which the number of completed model turns is stored
# (written by `_mark_model_turn_completed`, read by `_get_completed_model_turns`).
_SEARCH_R1_TURN_COUNT_KEY = "search_r1_completed_turns"


def _passages2string(retrieval_result):
Expand Down Expand Up @@ -121,6 +122,11 @@ def postprocess_predictions(prediction: str):
return action, content


def _last_prediction_action(prediction: str) -> str | None:
matches = re.findall(r"<(search|answer)>.*?</\1>", prediction, re.DOTALL)
return matches[-1] if matches else None


async def execute_predictions(prediction: str) -> str:
action, content = postprocess_predictions(prediction)

Expand All @@ -142,24 +148,73 @@ async def execute_predictions(prediction: str) -> str:
return next_obs, done


async def generate(args, sample: Sample, sampling_params) -> Sample:
assert not args.partial_rollout, "Partial rollout is not supported for this function at the moment."
def _count_completed_model_turns(response: str) -> int:
return len(re.findall(r"</(?:search|answer)>", response))


def _get_completed_model_turns(sample: Sample) -> int:
    """Return how many model turns *sample* has already completed.

    Uses the count stored in ``sample.metadata`` when it is an ``int``;
    otherwise falls back to counting closing action tags in
    ``sample.response`` (treated as empty when falsy).
    """
    stored = sample.metadata.get(_SEARCH_R1_TURN_COUNT_KEY)
    if not isinstance(stored, int):
        # No usable stored count: derive it from the response text.
        return _count_completed_model_turns(sample.response or "")
    return stored


def _mark_model_turn_completed(sample: Sample, completed_turns: int) -> None:
    """Store *completed_turns* in ``sample.metadata`` as the completed-turn count."""
    metadata = sample.metadata
    metadata[_SEARCH_R1_TURN_COUNT_KEY] = completed_turns


def _append_trainable_response_tokens(
    args,
    sample: Sample,
    *,
    tokens: list[int],
    log_probs: list[float] | None,
    meta_info: dict,
    text: str,
) -> None:
    """Append one model-generated, trainable segment to *sample*.

    When *log_probs* is provided, the work is delegated to
    ``sample.append_response_tokens``. When it is ``None``, the sample is
    updated by hand: the text, tokens, response length and loss mask (all
    ones for the new tokens) are extended, and the status is set from
    ``meta_info["finish_reason"]["type"]``.

    Args:
        args: Passed through to ``sample.append_response_tokens``.
        sample: The sample to extend in place.
        tokens: Token ids of the new segment.
        log_probs: Per-token log probs, or ``None`` when they are not collected.
        meta_info: Engine metadata; must contain ``["finish_reason"]["type"]``
            on the manual path.
        text: Decoded text of the new segment.
    """
    if log_probs is None:
        new_token_count = len(tokens)

        sample.response += text
        sample.tokens += tokens
        sample.response_length += new_token_count

        # Start a fresh mask if none exists yet, then mark the new tokens trainable.
        mask = [] if sample.loss_mask is None else sample.loss_mask
        mask += [1] * new_token_count
        sample.loss_mask = mask

        # Any finish reason other than the three below leaves the status untouched.
        reason = meta_info["finish_reason"]["type"]
        if reason == "length":
            sample.status = Sample.Status.TRUNCATED
        elif reason == "abort":
            sample.status = Sample.Status.ABORTED
        elif reason == "stop":
            sample.status = Sample.Status.COMPLETED
        return

    sample.append_response_tokens(
        args,
        tokens=tokens,
        log_probs=log_probs,
        trainable=True,
        meta_info=meta_info,
        text=text,
    )


async def generate(args, sample: Sample, sampling_params) -> Sample:
state = GenerateState(args)

url = f"http://{args.sglang_router_ip}:{args.sglang_router_port}/generate"

# Handle partial rollout samples: continue generation from existing response
prompt_text = sample.prompt
prompt_tokens_ids = state.tokenizer(prompt_text, add_special_tokens=False)["input_ids"]
sample.tokens = list(prompt_tokens_ids)
sample.loss_mask = []
response = ""
response_token_ids = []
loss_mask = []
rollout_log_probs = [] if SEARCH_R1_CONFIGS["return_logprob"] else None
sample.rollout_top_p_token_ids = None
sample.rollout_top_p_token_offsets = None
if not sample.tokens:
sample.tokens = list(prompt_tokens_ids)
if sample.loss_mask is None:
sample.loss_mask = [1] * sample.response_length
elif args.partial_rollout and args.mask_offpolicy_in_partial_rollout and sample.response_length > 0:
sample.loss_mask = [0] * sample.response_length
if SEARCH_R1_CONFIGS["return_logprob"] and sample.rollout_log_probs is None and sample.response_length > 0:
sample.rollout_log_probs = [0.0] * sample.response_length

response = sample.response or ""

# BUGFIX: make the inference engine STOP at the tool/answer boundary.
# Without a stop, sglang keeps emitting tokens after </search> / </answer>
Expand All @@ -176,7 +231,17 @@ async def generate(args, sample: Sample, sampling_params) -> Sample:
_existing_stop = [_existing_stop]
sampling_params = {**sampling_params, "stop": list(dict.fromkeys([*_existing_stop, *_stop_tags]))}

for _turn_idx in range(SEARCH_R1_CONFIGS["max_turns"]):
output = None
completed_turns = _get_completed_model_turns(sample) if args.partial_rollout else 0
if args.partial_rollout and _last_prediction_action(response) == "answer":
sample.status = Sample.Status.COMPLETED
return sample

for _turn_idx in range(completed_turns, SEARCH_R1_CONFIGS["max_turns"]):
if state.aborted:
sample.status = Sample.Status.ABORTED
return sample

payload = {
"text": prompt_text + response,
"sampling_params": sampling_params,
Expand All @@ -187,14 +252,14 @@ async def generate(args, sample: Sample, sampling_params) -> Sample:

output = await post(url, payload)

# abort
if output["meta_info"]["finish_reason"]["type"] == "abort":
cur_response = output["text"]
finish_reason = output["meta_info"]["finish_reason"]["type"]
if finish_reason == "abort":
sample.status = Sample.Status.ABORTED
return sample

cur_response = output["text"]

# Extract tokens and log probs based on configuration
cur_response_log_probs = None
if SEARCH_R1_CONFIGS["return_logprob"]:
# Extract log probs from output - required for TIS metrics
if "output_token_logprobs" not in output["meta_info"]:
Expand All @@ -215,53 +280,49 @@ async def generate(args, sample: Sample, sampling_params) -> Sample:
cur_response_token_ids = state.tokenizer(cur_response, add_special_tokens=False)["input_ids"]

response += cur_response
response_token_ids += cur_response_token_ids
loss_mask += [1] * len(cur_response_token_ids)
sample.append_response_tokens(
_append_trainable_response_tokens(
args,
sample,
tokens=cur_response_token_ids,
log_probs=cur_response_log_probs if SEARCH_R1_CONFIGS["return_logprob"] else None,
trainable=True,
meta_info=output["meta_info"] if "output_token_logprobs" in output["meta_info"] else None,
meta_info=output["meta_info"],
text=cur_response,
)

# Add log probs if enabled
if SEARCH_R1_CONFIGS["return_logprob"]:
rollout_log_probs += cur_response_log_probs

if output["meta_info"]["finish_reason"]["type"] == "length":
if finish_reason == "length":
break

if finish_reason == "stop":
_mark_model_turn_completed(sample, _turn_idx + 1)

next_obs, done = await execute_predictions(cur_response)
if done:
break

assert next_obs != "", "Next observation should not be empty."
obs_tokens_ids = state.tokenizer(next_obs, add_special_tokens=False)["input_ids"]
response += next_obs
response_token_ids += obs_tokens_ids
loss_mask += [0] * len(obs_tokens_ids)
sample.append_response_tokens(args, tokens=obs_tokens_ids, trainable=False)
sample.append_response_tokens(args, tokens=obs_tokens_ids, trainable=False, text=next_obs)

# Add dummy log probs for observation tokens if enabled (they won't be used due to loss_mask=0)
# Verify alignment when collecting log probs. Observation tokens receive dummy
# log probs inside append_response_tokens because loss_mask marks them non-trainable.
if SEARCH_R1_CONFIGS["return_logprob"]:
rollout_log_probs += [0.0] * len(obs_tokens_ids)

# Verify alignment when collecting log probs
assert len(response_token_ids) == len(
rollout_log_probs
), f"Token/logp length mismatch: {len(response_token_ids)} tokens vs {len(rollout_log_probs)} logps"

# Store statistics for wandb logging
sample.tokens = prompt_tokens_ids + response_token_ids
sample.response_length = len(response_token_ids)
sample.response = response
sample.loss_mask = loss_mask
sample.prompt = prompt_text
assert sample.rollout_log_probs is not None
assert len(sample.rollout_log_probs) == sample.response_length, (
f"Token/logp length mismatch: {sample.response_length} tokens vs "
f"{len(sample.rollout_log_probs)} logps"
)

# Store log probs if enabled
if SEARCH_R1_CONFIGS["return_logprob"]:
sample.rollout_log_probs = rollout_log_probs if rollout_log_probs else None
sample.prompt = prompt_text
if output is None:
action = _last_prediction_action(response)
sample.status = Sample.Status.COMPLETED if action == "answer" else Sample.Status.TRUNCATED
return sample

action = _last_prediction_action(response)
if action != "answer" and output["meta_info"]["finish_reason"]["type"] == "stop":
sample.status = Sample.Status.TRUNCATED
return sample

match output["meta_info"]["finish_reason"]["type"]:
case "length":
Expand Down
Loading
Loading