diff --git a/examples/geo3k_vlm_multi_turn/rollout.py b/examples/geo3k_vlm_multi_turn/rollout.py index 2bb1bef2fc..04eb97d9b1 100644 --- a/examples/geo3k_vlm_multi_turn/rollout.py +++ b/examples/geo3k_vlm_multi_turn/rollout.py @@ -167,6 +167,24 @@ def _prepare_initial_inputs(sample: Sample, processor, tokenizer): return prompt_ids, image_data, sample.multimodal_train_inputs +def _remaining_generation_budget( + sample: Sample, response_tokens: list[int], args: Any, sampling_params: dict +) -> int | None: + # SGLang max_new_tokens is a response-side budget; prompt tokens are already + # part of the request and should only count against the optional context cap. + budgets = [] + + context_len = getattr(args, "rollout_max_context_len", None) + if context_len is not None: + budgets.append(int(context_len) - len(sample.tokens)) + + max_new_tokens = sampling_params.get("max_new_tokens") + if max_new_tokens is not None: + budgets.append(int(max_new_tokens) - len(response_tokens)) + + return min(budgets) if budgets else None + + def _prepare_start_state(sample: Sample, state, args: Any, sampling_params: dict): prompt_ids, image_data, init_mm_train = _prepare_initial_inputs(sample, state.processor, state.tokenizer) current_image_data = image_data @@ -181,11 +199,7 @@ def _prepare_start_state(sample: Sample, state, args: Any, sampling_params: dict sample.rollout_log_probs = sample.rollout_log_probs or [] sample.response_length = len(response_tokens) - budget = None - if args.rollout_max_context_len is not None: - budget = args.rollout_max_context_len - len(sample.tokens) - elif sampling_params.get("max_new_tokens") is not None: - budget = sampling_params["max_new_tokens"] - len(sample.tokens) + budget = _remaining_generation_budget(sample, response_tokens, args, sampling_params) return current_image_data, response_tokens, budget, multimodal_train_inputs_buffer diff --git a/tests/test_geo3k_vlm_multi_turn_budget.py b/tests/test_geo3k_vlm_multi_turn_budget.py new file mode 
"""Unit tests for the initial generation budget of the geo3k multi-turn rollout.

File: tests/test_geo3k_vlm_multi_turn_budget.py

The rollout module is imported with its heavy dependencies replaced by stub
modules, so the tests exercise only ``_prepare_start_state``.
"""
from __future__ import annotations

import importlib
import sys
import types
from argparse import Namespace
from pathlib import Path
from types import SimpleNamespace

import pytest

# Make the repository root importable so ``examples.*`` resolves.
REPO_ROOT = Path(__file__).resolve().parents[1]
if str(REPO_ROOT) not in sys.path:
    sys.path.insert(0, str(REPO_ROOT))

MODULE_NAME = "examples.geo3k_vlm_multi_turn.rollout"
pytestmark = pytest.mark.unit


class _FakeTokenizer:
    """Tokenizer stand-in: one id per whitespace-separated word, from 100."""

    bos_token_id = None

    def encode(self, text, add_special_tokens=False):
        """Return ``[100, 101, ...]``, one id per word of ``str(text)``."""
        del add_special_tokens  # accepted for signature compatibility, unused
        word_count = len(str(text).split())
        return list(range(100, 100 + word_count))


class _FakeSample(SimpleNamespace):
    """Attribute bag carrying the sample fields these tests set and inspect."""

    def __init__(self, *, prompt="a b c", tokens=None, loss_mask=None, rollout_log_probs=None):
        fields = {
            "prompt": prompt,
            # Always a fresh list, so callers' lists are never shared.
            "tokens": list(tokens) if tokens else [],
            "multimodal_inputs": None,
            "multimodal_train_inputs": None,
            "loss_mask": loss_mask,
            "rollout_log_probs": rollout_log_probs,
            "response_length": 0,
        }
        super().__init__(**fields)


@pytest.fixture()
def rollout(monkeypatch):
    """Yield a freshly imported rollout module backed by stub dependencies.

    Each stub is registered in ``sys.modules`` under its full dotted name
    before the import. The rollout module itself is evicted from
    ``sys.modules`` before and after, so every test gets a fresh import.
    """
    stub_specs = {
        "torch": {
            "Tensor": object,
            "cat": lambda values, dim=0: list(values),
        },
        "slime.rollout.sglang_rollout": {"GenerateState": object},
        "slime.utils.http_utils": {"post": None},
        "slime.utils.processing_utils": {
            "encode_image_for_rollout_engine": lambda image: image,
        },
        "slime.utils.types": {"Sample": object},
    }
    for module_name, attributes in stub_specs.items():
        stub = types.ModuleType(module_name)
        for attr_name, attr_value in attributes.items():
            setattr(stub, attr_name, attr_value)
        monkeypatch.setitem(sys.modules, module_name, stub)

    sys.modules.pop(MODULE_NAME, None)
    try:
        module = importlib.import_module(MODULE_NAME)
        yield module
    finally:
        sys.modules.pop(MODULE_NAME, None)


def _start_state(rollout, sample, *, context_len=None, max_new_tokens=5):
    """Call ``rollout._prepare_start_state`` with minimal fake state and args."""
    fake_state = SimpleNamespace(processor=None, tokenizer=_FakeTokenizer())
    fake_args = Namespace(rollout_max_context_len=context_len)
    sampling_params = {"max_new_tokens": max_new_tokens}
    return rollout._prepare_start_state(sample, fake_state, fake_args, sampling_params)


def test_geo3k_multi_turn_response_budget_does_not_charge_prompt_tokens(rollout):
    """A 3-token prompt with no response yet keeps the full budget of 5."""
    sample = _FakeSample(prompt="a b c")

    _, response_tokens, budget, _ = _start_state(rollout, sample)

    assert response_tokens == []
    assert sample.tokens == [100, 101, 102]
    assert budget == 5


def test_geo3k_multi_turn_response_budget_counts_existing_response_only(rollout):
    """Two existing response tokens reduce the budget of 5 to 3."""
    sample = _FakeSample(
        prompt="a b c",
        tokens=[100, 101, 102, 201, 202],
        loss_mask=[1, 1],
        rollout_log_probs=[-0.1, -0.2],
    )

    _, response_tokens, budget, _ = _start_state(rollout, sample)

    assert response_tokens == [201, 202]
    assert sample.response_length == 2
    assert budget == 3


def test_geo3k_multi_turn_budget_respects_context_and_response_limits(rollout):
    """With both caps set, the expected budget is the tighter one.

    Context room is 10 - 7 = 3 and response room is 5 - 4 = 1, so 1 is expected.
    """
    sample = _FakeSample(
        prompt="a b c",
        tokens=[100, 101, 102, 201, 202, 203, 204],
        loss_mask=[1, 1, 1, 1],
        rollout_log_probs=[-0.1, -0.2, -0.3, -0.4],
    )

    _, response_tokens, budget, _ = _start_state(rollout, sample, context_len=10)

    assert response_tokens == [201, 202, 203, 204]
    assert budget == 1