Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 49 additions & 0 deletions nemo_gym/global_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,6 +275,52 @@ def filter_for_server_instance_configs(self, global_config_dict: DictConfig) ->

return server_instance_configs

def _is_no_model_initial_config(self, initial_global_config_dict: DictConfig) -> bool:
policy_model_config = initial_global_config_dict.get(POLICY_MODEL_KEY_NAME)
if not isinstance(policy_model_config, (DictConfig, dict)):
return False

responses_api_models = policy_model_config.get("responses_api_models")
if not isinstance(responses_api_models, (DictConfig, dict)):
return False

dummy_model_config = responses_api_models.get("dummy_model")
return (
isinstance(dummy_model_config, (DictConfig, dict))
and dummy_model_config.get(ENTRYPOINT_KEY_NAME) == "app.py"
)

def _seed_no_model_response_model_refs(self, global_config_dict: DictConfig) -> None:
server_instance_configs = self.filter_for_server_instance_configs(global_config_dict)
existing_response_model_names = {
server_config.name
for server_config in server_instance_configs
if server_config.get_server_ref().type == "responses_api_models"
}
refs_to_seed = []

for server_instance_config in server_instance_configs:
run_server_config_dict = server_instance_config.get_inner_run_server_config_dict()
for maybe_ref_config in run_server_config_dict.values():
maybe_server_ref = is_server_ref(maybe_ref_config)
if (
maybe_server_ref
and maybe_server_ref.type == "responses_api_models"
and maybe_server_ref.name not in existing_response_model_names
):
refs_to_seed.append(maybe_server_ref.name)
existing_response_model_names.add(maybe_server_ref.name)

dummy_model_config = GlobalConfigDictParserConfig.NO_MODEL_GLOBAL_CONFIG_DICT[POLICY_MODEL_KEY_NAME][
"responses_api_models"
]["dummy_model"]

with open_dict(global_config_dict):
for ref_name in refs_to_seed:
global_config_dict[ref_name] = {
"responses_api_models": {"dummy_model": OmegaConf.create(deepcopy(dummy_model_config))}
}

def raise_on_no_server_instances(self, global_config_dict: DictConfig) -> None:
"""Fail fast if a run has no server instances to start.

Expand Down Expand Up @@ -579,6 +625,9 @@ def parse(self, parse_config: Optional[GlobalConfigDictParserConfig] = None) ->
# value surfaces as an opaque MissingMandatoryValue deep in the pipeline.
self.raise_on_missing_values(global_config_dict)

if self._is_no_model_initial_config(initial_global_config_dict):
self._seed_no_model_response_model_refs(global_config_dict)

# TODO @bxyu-nvidia: We need a better way of handling dummy model configs
with open_dict(global_config_dict):
for top_level_value in global_config_dict.values():
Expand Down
59 changes: 59 additions & 0 deletions tests/unit_tests/test_global_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -1191,6 +1191,65 @@ def hydra_main_wrapper(fn):

assert expected_global_config_dict == actual_global_config_dict

def test_no_model_config_seeds_referenced_model_server_refs(self, monkeypatch: MonkeyPatch) -> None:
self._mock_versions_for_testing(monkeypatch)

monkeypatch.delenv(NEMO_GYM_CONFIG_DICT_ENV_VAR_NAME, raising=False)
monkeypatch.setattr(nemo_gym.global_config, "_GLOBAL_CONFIG_DICT", None)

exists_mock = MagicMock()
exists_mock.return_value = False
monkeypatch.setattr(nemo_gym.global_config.Path, "exists", exists_mock)

find_open_port_mock = MagicMock()
find_open_port_mock.return_value = 12345
monkeypatch.setattr(nemo_gym.global_config, "_find_open_port_using_range", find_open_port_mock)

hydra_main_mock = MagicMock()

def hydra_main_wrapper(fn):
config_dict = DictConfig(
{
"example_agent": {
"responses_api_agents": {
"simple_agent": {
"entrypoint": "app.py",
"model_server": {
"type": "responses_api_models",
"name": "user_model",
},
"judge_model_server": {
"type": "responses_api_models",
"name": "judge_model",
},
}
}
}
}
)
return lambda: fn(config_dict)

hydra_main_mock.return_value = hydra_main_wrapper
monkeypatch.setattr(nemo_gym.global_config.hydra, "main", hydra_main_mock)

with raises(
ServerRefNotFoundError,
match="responses_api_models/'user_model', which is not defined in the merged config.",
):
get_global_config_dict()

monkeypatch.setattr(nemo_gym.global_config, "_GLOBAL_CONFIG_DICT", None)

global_config_dict = get_global_config_dict(
global_config_dict_parser_config=GlobalConfigDictParserConfig(
initial_global_config_dict=GlobalConfigDictParserConfig.NO_MODEL_GLOBAL_CONFIG_DICT,
)
)

assert global_config_dict["user_model"]["responses_api_models"]["dummy_model"]["entrypoint"] == "app.py"
assert global_config_dict["judge_model"]["responses_api_models"]["dummy_model"]["entrypoint"] == "app.py"
assert global_config_dict["policy_model"]["responses_api_models"]["dummy_model"]["entrypoint"] == "app.py"

def test_dummy_model_override(self, monkeypatch: MonkeyPatch) -> None:
self._mock_versions_for_testing(monkeypatch)

Expand Down
49 changes: 49 additions & 0 deletions tests/unit_tests/test_train_data_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from pydantic import ValidationError
from pytest import MonkeyPatch, raises

import nemo_gym.cli.dataset
import nemo_gym.global_config
import nemo_gym.train_data_utils
from nemo_gym.config_types import DatasetConfig, ResponsesAPIAgentServerInstanceConfig
Expand Down Expand Up @@ -439,6 +440,54 @@ def test_validate_backend_credentials_valid(self, monkeypatch: MonkeyPatch) -> N


class TestValidateSamplesAndAggregateMetrics:
def test_prepare_data_no_model_config_seeds_non_policy_model_refs(self, monkeypatch: MonkeyPatch) -> None:
monkeypatch.delenv(nemo_gym.global_config.NEMO_GYM_CONFIG_DICT_ENV_VAR_NAME, raising=False)
monkeypatch.setattr(nemo_gym.global_config, "_GLOBAL_CONFIG_DICT", None)

exists_mock = MagicMock()
exists_mock.return_value = False
monkeypatch.setattr(nemo_gym.global_config.Path, "exists", exists_mock)

find_open_port_mock = MagicMock()
find_open_port_mock.return_value = 12345
monkeypatch.setattr(nemo_gym.global_config, "_find_open_port_using_range", find_open_port_mock)

hydra_main_mock = MagicMock()

def hydra_main_wrapper(fn):
config_dict = DictConfig(
{
"example_agent": {
"responses_api_agents": {
"simple_agent": {
"entrypoint": "app.py",
"model_server": {
"type": "responses_api_models",
"name": "user_model",
},
}
}
}
}
)
return lambda: fn(config_dict)

hydra_main_mock.return_value = hydra_main_wrapper
monkeypatch.setattr(nemo_gym.global_config.hydra, "main", hydra_main_mock)

captured_configs = []

def run_mock(_processor, global_config_dict):
captured_configs.append(global_config_dict)

monkeypatch.setattr(nemo_gym.cli.dataset.TrainDataProcessor, "run", run_mock)

nemo_gym.cli.dataset.prepare_data()

assert len(captured_configs) == 1
global_config_dict = captured_configs[0]
assert global_config_dict["user_model"]["responses_api_models"]["dummy_model"]["entrypoint"] == "app.py"

def test_validate_samples_and_aggregate_metrics_sanity(self, monkeypatch: MonkeyPatch) -> None:
mock_write_file = mock_open()
write_filenames = []
Expand Down