From 57289727ed74bc8e8fa85076b506798ca6287208 Mon Sep 17 00:00:00 2001
From: Rod Boev
Date: Thu, 25 Jun 2026 19:29:21 -0400
Subject: [PATCH] fix(config): seed no-model refs for prepare-data (#997)

Signed-off-by: Rod Boev
---

Notes:
    Reviewer commentary only. It sits after the three-dash separator, so it
    is not part of the commit message, and no hunk below is altered by it.

    nemo_gym/global_config.py
      * _is_no_model_initial_config() returns True only when the initial
        config holds POLICY_MODEL_KEY_NAME -> "responses_api_models" ->
        "dummy_model" as DictConfig/dict values and that dummy model's
        ENTRYPOINT_KEY_NAME value equals "app.py". Any other shape returns
        False.
      * _seed_no_model_response_model_refs() first collects the names of
        server instance configs whose server ref type is
        "responses_api_models". It then walks the direct values of each
        server's inner run-server config dict (nested values are not
        followed) and, for every value that is_server_ref() recognises as a
        "responses_api_models" ref with a name not yet known, records that
        name once. Under open_dict() it finally writes, for each recorded
        name, a top-level entry
        {"responses_api_models": {"dummy_model": <deep copy>}} where the
        deep copy is taken from the policy-model dummy_model entry of
        GlobalConfigDictParserConfig.NO_MODEL_GLOBAL_CONFIG_DICT.
      * parse() calls the seeding step right after
        raise_on_missing_values() and before the existing dummy-model TODO
        block, and only when _is_no_model_initial_config() is True for
        initial_global_config_dict.

    tests/unit_tests/test_global_config.py
      * The new test feeds an agent config that references "user_model" and
        "judge_model". It expects ServerRefNotFoundError from a plain
        get_global_config_dict() call, then expects "user_model",
        "judge_model" and "policy_model" to each carry a dummy_model with
        entrypoint "app.py" once NO_MODEL_GLOBAL_CONFIG_DICT is passed as
        the initial config.

    tests/unit_tests/test_train_data_utils.py
      * The new test replaces TrainDataProcessor.run with a capturing stub,
        calls nemo_gym.cli.dataset.prepare_data(), and checks that exactly
        one config was captured and that it contains "user_model" with a
        dummy_model whose entrypoint is "app.py".

    NOTE(review): the From / Signed-off-by lines carry no e-mail address in
    this copy, and the blank context lines and indentation were laid out
    from the hunk line counts. Confirm both against the upstream commit
    before applying.

 nemo_gym/global_config.py                 | 49 +++++++++++++++++++
 tests/unit_tests/test_global_config.py    | 59 +++++++++++++++++++++++
 tests/unit_tests/test_train_data_utils.py | 49 +++++++++++++++++++
 3 files changed, 157 insertions(+)

diff --git a/nemo_gym/global_config.py b/nemo_gym/global_config.py
index c9d7e25aeb..7b3dc5a2fd 100644
--- a/nemo_gym/global_config.py
+++ b/nemo_gym/global_config.py
@@ -275,6 +275,52 @@ def filter_for_server_instance_configs(self, global_config_dict: DictConfig) ->
 
         return server_instance_configs
 
+    def _is_no_model_initial_config(self, initial_global_config_dict: DictConfig) -> bool:
+        policy_model_config = initial_global_config_dict.get(POLICY_MODEL_KEY_NAME)
+        if not isinstance(policy_model_config, (DictConfig, dict)):
+            return False
+
+        responses_api_models = policy_model_config.get("responses_api_models")
+        if not isinstance(responses_api_models, (DictConfig, dict)):
+            return False
+
+        dummy_model_config = responses_api_models.get("dummy_model")
+        return (
+            isinstance(dummy_model_config, (DictConfig, dict))
+            and dummy_model_config.get(ENTRYPOINT_KEY_NAME) == "app.py"
+        )
+
+    def _seed_no_model_response_model_refs(self, global_config_dict: DictConfig) -> None:
+        server_instance_configs = self.filter_for_server_instance_configs(global_config_dict)
+        existing_response_model_names = {
+            server_config.name
+            for server_config in server_instance_configs
+            if server_config.get_server_ref().type == "responses_api_models"
+        }
+        refs_to_seed = []
+
+        for server_instance_config in server_instance_configs:
+            run_server_config_dict = server_instance_config.get_inner_run_server_config_dict()
+            for maybe_ref_config in run_server_config_dict.values():
+                maybe_server_ref = is_server_ref(maybe_ref_config)
+                if (
+                    maybe_server_ref
+                    and maybe_server_ref.type == "responses_api_models"
+                    and maybe_server_ref.name not in existing_response_model_names
+                ):
+                    refs_to_seed.append(maybe_server_ref.name)
+                    existing_response_model_names.add(maybe_server_ref.name)
+
+        dummy_model_config = GlobalConfigDictParserConfig.NO_MODEL_GLOBAL_CONFIG_DICT[POLICY_MODEL_KEY_NAME][
+            "responses_api_models"
+        ]["dummy_model"]
+
+        with open_dict(global_config_dict):
+            for ref_name in refs_to_seed:
+                global_config_dict[ref_name] = {
+                    "responses_api_models": {"dummy_model": OmegaConf.create(deepcopy(dummy_model_config))}
+                }
+
     def raise_on_no_server_instances(self, global_config_dict: DictConfig) -> None:
         """Fail fast if a run has no server instances to start.
 
@@ -579,6 +625,9 @@ def parse(self, parse_config: Optional[GlobalConfigDictParserConfig] = None) ->
         # value surfaces as an opaque MissingMandatoryValue deep in the pipeline.
         self.raise_on_missing_values(global_config_dict)
 
+        if self._is_no_model_initial_config(initial_global_config_dict):
+            self._seed_no_model_response_model_refs(global_config_dict)
+
         # TODO @bxyu-nvidia: We need a better way of handling dummy model configs
         with open_dict(global_config_dict):
             for top_level_value in global_config_dict.values():
diff --git a/tests/unit_tests/test_global_config.py b/tests/unit_tests/test_global_config.py
index 49f913d1f4..0785467aab 100644
--- a/tests/unit_tests/test_global_config.py
+++ b/tests/unit_tests/test_global_config.py
@@ -1191,6 +1191,65 @@ def hydra_main_wrapper(fn):
 
         assert expected_global_config_dict == actual_global_config_dict
 
+    def test_no_model_config_seeds_referenced_model_server_refs(self, monkeypatch: MonkeyPatch) -> None:
+        self._mock_versions_for_testing(monkeypatch)
+
+        monkeypatch.delenv(NEMO_GYM_CONFIG_DICT_ENV_VAR_NAME, raising=False)
+        monkeypatch.setattr(nemo_gym.global_config, "_GLOBAL_CONFIG_DICT", None)
+
+        exists_mock = MagicMock()
+        exists_mock.return_value = False
+        monkeypatch.setattr(nemo_gym.global_config.Path, "exists", exists_mock)
+
+        find_open_port_mock = MagicMock()
+        find_open_port_mock.return_value = 12345
+        monkeypatch.setattr(nemo_gym.global_config, "_find_open_port_using_range", find_open_port_mock)
+
+        hydra_main_mock = MagicMock()
+
+        def hydra_main_wrapper(fn):
+            config_dict = DictConfig(
+                {
+                    "example_agent": {
+                        "responses_api_agents": {
+                            "simple_agent": {
+                                "entrypoint": "app.py",
+                                "model_server": {
+                                    "type": "responses_api_models",
+                                    "name": "user_model",
+                                },
+                                "judge_model_server": {
+                                    "type": "responses_api_models",
+                                    "name": "judge_model",
+                                },
+                            }
+                        }
+                    }
+                }
+            )
+            return lambda: fn(config_dict)
+
+        hydra_main_mock.return_value = hydra_main_wrapper
+        monkeypatch.setattr(nemo_gym.global_config.hydra, "main", hydra_main_mock)
+
+        with raises(
+            ServerRefNotFoundError,
+            match="responses_api_models/'user_model', which is not defined in the merged config.",
+        ):
+            get_global_config_dict()
+
+        monkeypatch.setattr(nemo_gym.global_config, "_GLOBAL_CONFIG_DICT", None)
+
+        global_config_dict = get_global_config_dict(
+            global_config_dict_parser_config=GlobalConfigDictParserConfig(
+                initial_global_config_dict=GlobalConfigDictParserConfig.NO_MODEL_GLOBAL_CONFIG_DICT,
+            )
+        )
+
+        assert global_config_dict["user_model"]["responses_api_models"]["dummy_model"]["entrypoint"] == "app.py"
+        assert global_config_dict["judge_model"]["responses_api_models"]["dummy_model"]["entrypoint"] == "app.py"
+        assert global_config_dict["policy_model"]["responses_api_models"]["dummy_model"]["entrypoint"] == "app.py"
+
     def test_dummy_model_override(self, monkeypatch: MonkeyPatch) -> None:
         self._mock_versions_for_testing(monkeypatch)
 
diff --git a/tests/unit_tests/test_train_data_utils.py b/tests/unit_tests/test_train_data_utils.py
index 31ada07687..678d235009 100644
--- a/tests/unit_tests/test_train_data_utils.py
+++ b/tests/unit_tests/test_train_data_utils.py
@@ -19,6 +19,7 @@
 from pydantic import ValidationError
 from pytest import MonkeyPatch, raises
 
+import nemo_gym.cli.dataset
 import nemo_gym.global_config
 import nemo_gym.train_data_utils
 from nemo_gym.config_types import DatasetConfig, ResponsesAPIAgentServerInstanceConfig
@@ -439,6 +440,54 @@ def test_validate_backend_credentials_valid(self, monkeypatch: MonkeyPatch) -> N
 
 
 class TestValidateSamplesAndAggregateMetrics:
+    def test_prepare_data_no_model_config_seeds_non_policy_model_refs(self, monkeypatch: MonkeyPatch) -> None:
+        monkeypatch.delenv(nemo_gym.global_config.NEMO_GYM_CONFIG_DICT_ENV_VAR_NAME, raising=False)
+        monkeypatch.setattr(nemo_gym.global_config, "_GLOBAL_CONFIG_DICT", None)
+
+        exists_mock = MagicMock()
+        exists_mock.return_value = False
+        monkeypatch.setattr(nemo_gym.global_config.Path, "exists", exists_mock)
+
+        find_open_port_mock = MagicMock()
+        find_open_port_mock.return_value = 12345
+        monkeypatch.setattr(nemo_gym.global_config, "_find_open_port_using_range", find_open_port_mock)
+
+        hydra_main_mock = MagicMock()
+
+        def hydra_main_wrapper(fn):
+            config_dict = DictConfig(
+                {
+                    "example_agent": {
+                        "responses_api_agents": {
+                            "simple_agent": {
+                                "entrypoint": "app.py",
+                                "model_server": {
+                                    "type": "responses_api_models",
+                                    "name": "user_model",
+                                },
+                            }
+                        }
+                    }
+                }
+            )
+            return lambda: fn(config_dict)
+
+        hydra_main_mock.return_value = hydra_main_wrapper
+        monkeypatch.setattr(nemo_gym.global_config.hydra, "main", hydra_main_mock)
+
+        captured_configs = []
+
+        def run_mock(_processor, global_config_dict):
+            captured_configs.append(global_config_dict)
+
+        monkeypatch.setattr(nemo_gym.cli.dataset.TrainDataProcessor, "run", run_mock)
+
+        nemo_gym.cli.dataset.prepare_data()
+
+        assert len(captured_configs) == 1
+        global_config_dict = captured_configs[0]
+        assert global_config_dict["user_model"]["responses_api_models"]["dummy_model"]["entrypoint"] == "app.py"
+
     def test_validate_samples_and_aggregate_metrics_sanity(self, monkeypatch: MonkeyPatch) -> None:
         mock_write_file = mock_open()
         write_filenames = []