From 0dfa1b1017a8ad3e4756d4dcaecca125ffc06db2 Mon Sep 17 00:00:00 2001 From: Maanu Grover Date: Tue, 5 May 2026 20:28:02 -0700 Subject: [PATCH 01/20] add train utils Signed-off-by: Maanu Grover --- megatron/training/utils/train_utils.py | 59 ++++++++++++++++++++++++++ 1 file changed, 59 insertions(+) create mode 100644 megatron/training/utils/train_utils.py diff --git a/megatron/training/utils/train_utils.py b/megatron/training/utils/train_utils.py new file mode 100644 index 00000000000..3cf06a15524 --- /dev/null +++ b/megatron/training/utils/train_utils.py @@ -0,0 +1,59 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. + +import os +from megatron.training.utils.common_utils import print_rank_0 +import torch + +from megatron.core._rank_utils import safe_get_rank +from megatron.training.config import ProfilingConfig + +import logging + +logger = logging.getLogger(__name__) + + +def start_memory_history_recording(profiling: ProfilingConfig | None) -> None: + """Enable the CUDA caching allocator trace so memory snapshots contain history. + + ``torch.cuda.memory._snapshot()`` only includes allocation/free events and + Python stack context after ``_record_memory_history()`` has been enabled. + Without this call, dumped snapshots contain only the current live + allocations — no timeline, no call sites. + + Must be invoked before model construction so every tensor allocation is + captured. Guarded by ``profile_ranks`` so only ranks that will dump a + snapshot pay the recording overhead. + """ + if profiling is None or not profiling.record_memory_history: + return + if safe_get_rank() not in profiling.profile_ranks: + return + + torch.cuda.memory._record_memory_history( + True, + # Retain up to 100k alloc/free events. + trace_alloc_max_entries=100_000, + # Record the Python stack at each event — lets memory_viz show call sites. + trace_alloc_record_context=True, + ) + + def _oom_observer( + device: int, alloc: int, device_alloc: int, device_free: int + ) -> None: + """Dump a snapshot on OOM so we can inspect what was live at the failure.""" + import pickle + + rank = safe_get_rank() + base, ext = os.path.splitext(profiling.memory_snapshot_path) + filename = f"{base}_oom_rank-{rank}{ext}" + snapshot = torch.cuda.memory._snapshot() + with open(filename, "wb") as f: + pickle.dump(snapshot, f) + # logger.info so the message reaches stderr on any profiled rank, not just rank 0. + logger.info(f"[OOM] rank {rank} saved memory snapshot to {filename}") + + torch._C._cuda_attach_out_of_memory_observer(_oom_observer) + print_rank_0( + f"Memory history recording enabled (rank {safe_get_rank()}); " + f"snapshots will be written to '{profiling.memory_snapshot_path}'." + ) From 6bdcab82de365c19c8f845b84f6dbe22816b96c5 Mon Sep 17 00:00:00 2001 From: Maanu Grover Date: Thu, 4 Jun 2026 14:44:38 -0700 Subject: [PATCH 02/20] avoid direct pickle import Signed-off-by: Maanu Grover --- megatron/training/utils/train_utils.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/megatron/training/utils/train_utils.py b/megatron/training/utils/train_utils.py index 3cf06a15524..37647a96bd7 100644 --- a/megatron/training/utils/train_utils.py +++ b/megatron/training/utils/train_utils.py @@ -41,14 +41,10 @@ def _oom_observer( device: int, alloc: int, device_alloc: int, device_free: int ) -> None: """Dump a snapshot on OOM so we can inspect what was live at the failure.""" - import pickle - rank = safe_get_rank() base, ext = os.path.splitext(profiling.memory_snapshot_path) filename = f"{base}_oom_rank-{rank}{ext}" - snapshot = torch.cuda.memory._snapshot() - with open(filename, "wb") as f: - pickle.dump(snapshot, f) + torch.cuda.memory._dump_snapshot(filename) # logger.info so the message reaches stderr on any profiled rank, not just rank 0. logger.info(f"[OOM] rank {rank} saved memory snapshot to {filename}") From 480e93ede80546b35afcb7871ee2745241cac575 Mon Sep 17 00:00:00 2001 From: Maanu Grover Date: Wed, 6 May 2026 14:01:07 -0700 Subject: [PATCH 03/20] make model provider func optional Signed-off-by: Maanu Grover --- examples/bert/pretrain_bert.py | 4 ++-- examples/mimo/train.py | 2 +- examples/multimodal/train.py | 2 +- examples/t5/pretrain_t5.py | 2 +- megatron/training/training.py | 2 +- pretrain_gpt.py | 1 - pretrain_hybrid.py | 1 - pretrain_vlm.py | 2 +- train_rl.py | 2 +- 9 files changed, 8 insertions(+), 10 deletions(-) diff --git a/examples/bert/pretrain_bert.py b/examples/bert/pretrain_bert.py index 3eb95ecf396..9bb3e653e22 100644 --- a/examples/bert/pretrain_bert.py +++ b/examples/bert/pretrain_bert.py @@ -184,6 +184,6 @@ def train_valid_test_datasets_provider(train_val_test_num_samples, vp_stage=None args = parse_and_validate_args(args_defaults={'tokenizer_type': 'BertWordPieceLowerCase'}) full_config = pretrain_cfg_container_from_args(args) - pretrain(full_config, train_valid_test_datasets_provider, model_provider, + pretrain(full_config, train_valid_test_datasets_provider, ModelType.encoder_or_decoder, - forward_step) + forward_step, model_provider) diff --git a/examples/mimo/train.py b/examples/mimo/train.py index 594170faa7e..b934f402158 100644 --- a/examples/mimo/train.py +++ b/examples/mimo/train.py @@ -282,7 +282,7 @@ def model_provider( pretrain( full_config, train_valid_test_datasets_provider, - model_provider, ModelType.encoder_or_decoder, forward_step, + model_provider, ) diff --git a/examples/multimodal/train.py b/examples/multimodal/train.py index 2345bf38cc1..57c500a9478 100644 --- a/examples/multimodal/train.py +++ b/examples/multimodal/train.py @@ -385,9 +385,9 @@ def write_online_eval_to_tensorboard(data, iteration, writer, walltime=None): pretrain( train_valid_test_dataloaders_provider, - model_provider, ModelType.encoder_or_decoder, forward_step, + model_provider, args_defaults={'tokenizer_type': 'GPT2BPETokenizer'}, extra_args_provider=add_multimodal_extra_args, process_non_loss_data_func=write_online_eval_to_tensorboard, diff --git a/examples/t5/pretrain_t5.py b/examples/t5/pretrain_t5.py index 171166d08b2..4b33386e2d4 100644 --- a/examples/t5/pretrain_t5.py +++ b/examples/t5/pretrain_t5.py @@ -275,9 +275,9 @@ def t5_position_embedding_ranks(pp_ranks): pretrain( full_config, train_valid_test_datasets_provider, - model_provider, ModelType.encoder_or_decoder, forward_step, + model_provider, get_embedding_ranks=t5_embedding_ranks, get_position_embedding_ranks=t5_position_embedding_ranks, ) diff --git a/megatron/training/training.py b/megatron/training/training.py index ac7d8b57c4c..8479f353041 100644 --- a/megatron/training/training.py +++ b/megatron/training/training.py @@ -1004,9 +1004,9 @@ def reorder_inner_param_groups(optimizer_state_dict): def pretrain( cfg_container: PretrainConfigContainer, train_valid_test_dataset_provider, - model_provider, model_type, forward_step_func, + model_provider=None, process_non_loss_data_func=None, get_embedding_ranks=None, get_position_embedding_ranks=None, diff --git a/pretrain_gpt.py b/pretrain_gpt.py index bb9e06b71c9..aa59fdaaa51 100644 --- a/pretrain_gpt.py +++ b/pretrain_gpt.py @@ -501,7 +501,6 @@ def get_embedding_ranks(pp_ranks: List[int]): pretrain( full_config, train_valid_test_datasets_provider, - partial(model_provider, gpt_builder), ModelType.encoder_or_decoder, forward_step, store=store, diff --git a/pretrain_hybrid.py b/pretrain_hybrid.py index c2fe3bd510e..e4637f61f9b 100644 --- a/pretrain_hybrid.py +++ b/pretrain_hybrid.py @@ -447,7 +447,6 @@ def train_valid_test_datasets_provider(train_val_test_num_samples, vp_stage=None pretrain( full_config, train_valid_test_datasets_provider, - partial(model_provider, hybrid_builder), ModelType.encoder_or_decoder, forward_step, store=store, diff --git a/pretrain_vlm.py b/pretrain_vlm.py index dff56257ce4..f29d030c479 100644 --- a/pretrain_vlm.py +++ b/pretrain_vlm.py @@ -480,9 +480,9 @@ def llava_position_embedding_ranks(pp_ranks): pretrain( full_config, train_valid_test_datasets_provider, - model_provider, ModelType.encoder_or_decoder, forward_step, + model_provider, get_embedding_ranks=llava_embedding_ranks, get_position_embedding_ranks=llava_position_embedding_ranks, ) diff --git a/train_rl.py b/train_rl.py index 7d742772e91..529184d9a13 100644 --- a/train_rl.py +++ b/train_rl.py @@ -419,7 +419,7 @@ def _model_builder( pretrain( full_config, None, # we don't need to build any datasets for RL training - partial(model_provider, _model_builder), ModelType.encoder_or_decoder, forward_step, + partial(model_provider, _model_builder), ) From 31808f9e125770412fee49ee621e7e44bc952575 Mon Sep 17 00:00:00 2001 From: Maanu Grover Date: Tue, 2 Jun 2026 23:23:55 -0700 Subject: [PATCH 04/20] create model config in rl script Signed-off-by: Maanu Grover --- train_rl.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/train_rl.py b/train_rl.py index 529184d9a13..acf54680f4a 100644 --- a/train_rl.py +++ b/train_rl.py @@ -24,7 +24,7 @@ from megatron.training import get_args, get_timers, pretrain, print_rank_0 from megatron.training.utils import is_hybrid_model from megatron.training.arguments import core_transformer_config_from_args, parse_and_validate_args -from megatron.training.argument_utils import pretrain_cfg_container_from_args +from megatron.training.argument_utils import gpt_config_from_args, hybrid_config_from_args, pretrain_cfg_container_from_args from model_provider import model_provider from megatron.core.packed_seq_params import PackedSeqParams @@ -415,7 +415,11 @@ def _model_builder( extra_args_provider=add_inference_args, args_defaults={}, ) - full_config = pretrain_cfg_container_from_args(args) + if is_hybrid_model(args): + model_cfg = hybrid_config_from_args(args) + else: + model_cfg = gpt_config_from_args(args) + full_config = pretrain_cfg_container_from_args(args, model_cfg) pretrain( full_config, None, # we don't need to build any datasets for RL training From 5a6510745d182d3090169f6859b6ea0051a2cd96 Mon Sep 17 00:00:00 2001 From: Maanu Grover Date: Fri, 5 Jun 2026 11:42:46 -0700 Subject: [PATCH 05/20] temporary fallback for pg collection Signed-off-by: Maanu Grover --- megatron/training/training.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/megatron/training/training.py b/megatron/training/training.py index 8479f353041..eb2cc93aa7c 100644 --- a/megatron/training/training.py +++ b/megatron/training/training.py @@ -1094,6 +1094,8 @@ def pretrain( seed_ep_group=getattr(init_pg_collection, "ep", None), seed_etp_group=getattr(init_pg_collection, "expt_tp", None), ) + # TODO (@maanug): temporary until initialize.py is refactored to build pgcollection as bridge does + pg_collection = ProcessGroupCollection.use_mpu_process_groups() timestamp_after_initialize_megatron = time.time() From aafa83272b84c2a083a67c6e1a574b2dbd77dc14 Mon Sep 17 00:00:00 2001 From: Maanu Grover Date: Fri, 5 Jun 2026 13:46:33 -0700 Subject: [PATCH 06/20] integrate model builder into model+optim setup Signed-off-by: Maanu Grover --- megatron/training/training.py | 36 +++++++++++++++++++++++++++++------ 1 file changed, 30 insertions(+), 6 deletions(-) diff --git a/megatron/training/training.py b/megatron/training/training.py index eb2cc93aa7c..55b1dd1c881 100644 --- a/megatron/training/training.py +++ b/megatron/training/training.py @@ -1223,7 +1223,11 @@ def pretrain( # Model, optimizer, and learning rate. timers('model-and-optimizer-setup', log_level=0).start(barrier=True) model, optimizer, opt_param_scheduler = setup_model_and_optimizer( - model_provider, model_type, checkpointing_context=checkpointing_context + model_provider, + model_type, + checkpointing_context=checkpointing_context, + cfg_container=cfg_container, + pg_collection=pg_collection, ) timers('model-and-optimizer-setup').stop() @@ -1995,7 +1999,9 @@ def setup_model_and_optimizer( model_provider_func, model_type, checkpointing_context=None, - pg_collection=None, + *, + cfg_container: PretrainConfigContainer, + pg_collection: ProcessGroupCollection, ): """Setup model and optimizer.""" args = get_args() @@ -2008,9 +2014,27 @@ def setup_model_and_optimizer( has_rl_optimizer = args.perform_rl_step and not args.no_load_optim skip_optimizer = not (has_normal_optimizer or has_rl_optimizer) wrap_with_ddp = not skip_optimizer - model = get_model( - model_provider_func, model_type, wrap_with_ddp=wrap_with_ddp, pg_collection=pg_collection - ) + + def _build_model_wrapper(wrap_with_ddp: bool): + from megatron.training.utils.train_utils import start_memory_history_recording + + start_memory_history_recording(cfg_container.profiling) + + cfg = cfg_container + model_config = cfg.model + builder_cls = model_config.get_builder_cls() + builder = builder_cls(model_config) + return builder.build_distributed_models( + pg_collection=pg_collection, + ddp_config=cfg.ddp, + overlap_param_gather_with_optimizer_step=cfg.optimizer.overlap_param_gather_with_optimizer_step, + use_megatron_fsdp=cfg.dist.use_megatron_fsdp, + use_torch_fsdp2=cfg.dist.use_torch_fsdp2, + wrap_with_ddp=wrap_with_ddp, + data_parallel_random_init=cfg.rng.data_parallel_random_init, + ) + + model = _build_model_wrapper(wrap_with_ddp) unwrapped_model = unwrap_model(model) if args.logits_save_dir is not None: @@ -2083,7 +2107,7 @@ def setup_model_and_optimizer( args.ffn_hidden_size = moe_ffn_hidden_size * args.moe_upcycling_granularity # get dense model - dense_model_for_upcycling = get_model(model_provider_func, model_type) + dense_model_for_upcycling = _build_model_wrapper(wrap_with_ddp=True) # recover moe upcycling related args in global args before executing upcycling args.num_experts = num_experts From c1cf96a27665446d1b1b956f2f09a5f89c2d9716 Mon Sep 17 00:00:00 2001 From: Maanu Grover Date: Tue, 2 Jun 2026 10:15:57 -0700 Subject: [PATCH 07/20] fix fsdp unit test Signed-off-by: Maanu Grover --- .../test_mcore_fully_sharded_data_parallel.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/tests/unit_tests/distributed/megatron_fsdp/test_mcore_fully_sharded_data_parallel.py b/tests/unit_tests/distributed/megatron_fsdp/test_mcore_fully_sharded_data_parallel.py index 013b5ce4674..d1bedacfa11 100644 --- a/tests/unit_tests/distributed/megatron_fsdp/test_mcore_fully_sharded_data_parallel.py +++ b/tests/unit_tests/distributed/megatron_fsdp/test_mcore_fully_sharded_data_parallel.py @@ -1095,7 +1095,10 @@ def test_full_iteration_cuda_graph_e2e(self, extra_overrides): from megatron.core.rerun_state_machine import destroy_rerun_state_machine from megatron.core.transformer.enums import CudaGraphScope from megatron.training import pretrain - from megatron.training.argument_utils import pretrain_cfg_container_from_args + from megatron.training.argument_utils import ( + gpt_config_from_args, + pretrain_cfg_container_from_args, + ) from megatron.training.arguments import add_megatron_arguments, validate_args from megatron.training.global_vars import set_global_variables, unset_global_variables @@ -1199,7 +1202,8 @@ def pre_step_hook(optimizer, args_, kwargs_): args.world_size = int(os.getenv("WORLD_SIZE", "1")) validate_args(args) set_global_variables(args) - cfg = pretrain_cfg_container_from_args(args) + model_cfg = gpt_config_from_args(args) + cfg = pretrain_cfg_container_from_args(args, model_cfg) from gpt_builders import gpt_builder from model_provider import model_provider @@ -1207,7 +1211,6 @@ def pre_step_hook(optimizer, args_, kwargs_): pretrain( cfg, _pretrain_gpt.train_valid_test_datasets_provider, - partial(model_provider, gpt_builder), ModelType.encoder_or_decoder, wrapped_forward_step, get_embedding_ranks=_pretrain_gpt.get_embedding_ranks, From 663c20014dd0178e6a5e18c1ccb77c8e9a78e813 Mon Sep 17 00:00:00 2001 From: Maanu Grover Date: Fri, 5 Jun 2026 23:40:59 -0700 Subject: [PATCH 08/20] remove stale assertion Signed-off-by: Maanu Grover --- megatron/training/models/hybrid.py | 8 -------- tests/unit_tests/training/models/test_hybrid_builder.py | 8 -------- 2 files changed, 16 deletions(-) diff --git a/megatron/training/models/hybrid.py b/megatron/training/models/hybrid.py index b58a70d3c01..287ca8ec2a3 100644 --- a/megatron/training/models/hybrid.py +++ b/megatron/training/models/hybrid.py @@ -160,14 +160,6 @@ def build_model( else: hybrid_stack_spec = default_hybrid_stack_spec - assert ( - getattr(self._model_config.transformer, "virtual_pipeline_model_parallel_size", None) is None - and vp_stage is None - ), ( - "Virtual pipeline model parallelism is temporarily unsupported in Hybrid " - "models due to upstream MCore HybridModel API dependency" - ) - assert self._model_config.vocab_size is not None, "vocab_size must be configured before calling build_model()" if self._model_config.should_pad_vocab: padded_vocab_size = calculate_padded_vocab_size( diff --git a/tests/unit_tests/training/models/test_hybrid_builder.py b/tests/unit_tests/training/models/test_hybrid_builder.py index d3fb7fdaf8a..9984e224ce3 100644 --- a/tests/unit_tests/training/models/test_hybrid_builder.py +++ b/tests/unit_tests/training/models/test_hybrid_builder.py @@ -319,14 +319,6 @@ def test_infers_post_process_from_pg(self, mock_model, mock_first, mock_last, *_ mock_last.assert_called_once_with(self.pg.pp) assert mock_model.call_args.kwargs["post_process"] is True - @patch("megatron.training.models.hybrid.calculate_padded_vocab_size") - @patch("megatron.training.models.hybrid.is_pp_last_stage", return_value=True) - @patch("megatron.training.models.hybrid.is_pp_first_stage", return_value=True) - @patch("megatron.training.models.hybrid.HybridModel") - def test_virtual_pipeline_raises(self, mock_model, *_): - with pytest.raises(AssertionError, match="Virtual pipeline"): - self.builder.build_model(self.pg, vp_stage=0) - @patch("megatron.training.models.hybrid.calculate_padded_vocab_size") @patch("megatron.training.models.hybrid.is_pp_last_stage", return_value=True) @patch("megatron.training.models.hybrid.is_pp_first_stage", return_value=True) From 1d8f4fa0ebfcc49c6a4a43bc748ef7a18f5dfa90 Mon Sep 17 00:00:00 2001 From: Maanu Grover Date: Sat, 6 Jun 2026 01:09:00 -0700 Subject: [PATCH 09/20] mtp block spec bug fix Signed-off-by: Maanu Grover --- megatron/training/models/gpt.py | 12 ++++++++++-- tests/unit_tests/training/models/test_gpt_builder.py | 12 +++++++----- 2 files changed, 17 insertions(+), 7 deletions(-) diff --git a/megatron/training/models/gpt.py b/megatron/training/models/gpt.py index 633b2ad0b27..6331c1b369d 100644 --- a/megatron/training/models/gpt.py +++ b/megatron/training/models/gpt.py @@ -88,7 +88,15 @@ def default_layer_spec(config: "GPTModelConfig", vp_stage: int) -> ModuleSpec: ) elif isinstance(transformer_cfg, HeterogeneousTransformerConfig): return get_gpt_heterogeneous_layer_spec(transformer_cfg, use_te) - elif use_te: + else: + return _te_or_local_layer_spec(config, vp_stage) + +def _te_or_local_layer_spec(config: "GPTModelConfig", vp_stage: int) -> ModuleSpec: + """Need to be able to call just these branches for mtp transformer layer spec.""" + + transformer_cfg = config.transformer + use_te = transformer_cfg.transformer_impl == "transformer_engine" + if use_te: if "use_te_op_fuser" in inspect.signature(get_gpt_layer_with_transformer_engine_spec).parameters: kwargs = {"use_te_op_fuser": config.use_transformer_engine_op_fuser} else: @@ -396,7 +404,7 @@ def mtp_block_spec( if hasattr(transformer_layer_spec, "layer_specs") and len(transformer_layer_spec.layer_specs) == 0: # Get the decoder layer spec explicitly if no decoder layer in the last stage, # Only happens with block spec (TransformerBlockSubmodules) when using MoE. - spec = default_layer_spec(config, vp_stage) + spec = _te_or_local_layer_spec(config, vp_stage) else: decoder_specs = get_gpt_decoder_layer_specs(transformer_cfg, use_transformer_engine=use_te, normalization=transformer_cfg.normalization, qk_l2_norm=transformer_cfg.qk_l2_norm, vp_stage=vp_stage) spec = decoder_specs[-1] diff --git a/tests/unit_tests/training/models/test_gpt_builder.py b/tests/unit_tests/training/models/test_gpt_builder.py index 20603e780b7..2263525e030 100644 --- a/tests/unit_tests/training/models/test_gpt_builder.py +++ b/tests/unit_tests/training/models/test_gpt_builder.py @@ -851,19 +851,21 @@ def test_uses_explicit_spec_when_layer_specs_nonempty(self, mock_get_mtp): passed_spec = mock_get_mtp.call_args.args[1] assert passed_spec is mock_decoder_specs.return_value[-1] - @patch("megatron.training.models.gpt.default_layer_spec") + @patch("megatron.training.models.gpt._te_or_local_layer_spec") @patch("megatron.core.models.gpt.gpt_layer_specs.get_gpt_mtp_block_spec") - def test_uses_default_layer_spec_for_empty_layer_specs(self, mock_get_mtp, mock_default): + def test_uses_te_or_local_layer_spec_for_empty_layer_specs( + self, mock_get_mtp, mock_te_or_local + ): config = self._make_config(mtp_num_layers=1) spec = Mock(spec=ModuleSpec) - spec.layer_specs = [] # Empty → falls back to default_layer_spec + spec.layer_specs = [] # Empty → falls back to _te_or_local_layer_spec fallback_spec = Mock(spec=ModuleSpec) - mock_default.return_value = fallback_spec + mock_te_or_local.return_value = fallback_spec mock_get_mtp.return_value = Mock(spec=ModuleSpec) mtp_block_spec(config, spec, vp_stage=4) - mock_default.assert_called_once_with(config, 4) + mock_te_or_local.assert_called_once_with(config, 4) passed_spec = mock_get_mtp.call_args.args[1] assert passed_spec is fallback_spec From 3ef8a3ac1cfd4c4f5667ed1605add522e885a07b Mon Sep 17 00:00:00 2001 From: Maanu Grover Date: Sat, 6 Jun 2026 13:25:58 -0700 Subject: [PATCH 10/20] sync with gpt builder Signed-off-by: Maanu Grover --- megatron/training/models/gpt.py | 1 + 1 file changed, 1 insertion(+) diff --git a/megatron/training/models/gpt.py b/megatron/training/models/gpt.py index 6331c1b369d..b88ebbb66ba 100644 --- a/megatron/training/models/gpt.py +++ b/megatron/training/models/gpt.py @@ -113,6 +113,7 @@ def _te_or_local_layer_spec(config: "GPTModelConfig", vp_stage: int) -> ModuleSp use_kitchen_attention=config.transformer.use_kitchen_attention, kitchen_attention_backend=config.transformer.kitchen_attention_backend, mla_down_proj_fusion=getattr(config.transformer, "mla_down_proj_fusion", False), + use_grouped_gemm_for_dense_mlp=config.transformer.use_grouped_gemm_for_dense_mlp, **kwargs, ) else: From 36eee2cc12653afceae98fb1430f759bb3730bb5 Mon Sep 17 00:00:00 2001 From: Maanu Grover Date: Wed, 17 Jun 2026 01:43:24 -0500 Subject: [PATCH 11/20] remove duplicate field Signed-off-by: Maanu Grover --- megatron/training/models/gpt.py | 1 - 1 file changed, 1 deletion(-) diff --git a/megatron/training/models/gpt.py b/megatron/training/models/gpt.py index b88ebbb66ba..2d8f1f3215b 100644 --- a/megatron/training/models/gpt.py +++ b/megatron/training/models/gpt.py @@ -178,7 +178,6 @@ class GPTModelConfig(ModelConfig): """Config file when tp_comm_overlap is enabled.""" ### settings for default layer spec options ### - use_transformer_engine_op_fuser: bool = False use_arbitrary_attention_mask: bool | None = None @override From 2e8930b685bfb7c11707c58a4f0d7e284f30f179 Mon Sep 17 00:00:00 2001 From: Maanu Grover Date: Tue, 16 Jun 2026 17:40:10 -0500 Subject: [PATCH 12/20] address file naming feedback Signed-off-by: Maanu Grover --- megatron/training/training.py | 2 +- megatron/training/utils/__init__.py | 1 + megatron/training/utils/{train_utils.py => utils.py} | 6 +++--- 3 files changed, 5 insertions(+), 4 deletions(-) rename megatron/training/utils/{train_utils.py => utils.py} (100%) diff --git a/megatron/training/training.py b/megatron/training/training.py index 55b1dd1c881..8d4da8f2223 100644 --- a/megatron/training/training.py +++ b/megatron/training/training.py @@ -2016,7 +2016,7 @@ def setup_model_and_optimizer( wrap_with_ddp = not skip_optimizer def _build_model_wrapper(wrap_with_ddp: bool): - from megatron.training.utils.train_utils import start_memory_history_recording + from megatron.training.utils import start_memory_history_recording start_memory_history_recording(cfg_container.profiling) diff --git a/megatron/training/utils/__init__.py b/megatron/training/utils/__init__.py index d6e2fe7c246..15cd4b26d4f 100644 --- a/megatron/training/utils/__init__.py +++ b/megatron/training/utils/__init__.py @@ -27,3 +27,4 @@ ) from megatron.training.utils.log_utils import append_to_progress_log +from megatron.training.utils.utils import start_memory_history_recording diff --git a/megatron/training/utils/train_utils.py b/megatron/training/utils/utils.py similarity index 100% rename from megatron/training/utils/train_utils.py rename to megatron/training/utils/utils.py index 37647a96bd7..9cf6b27f340 100644 --- a/megatron/training/utils/train_utils.py +++ b/megatron/training/utils/utils.py @@ -1,13 +1,13 @@ # Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +import logging import os -from megatron.training.utils.common_utils import print_rank_0 + import torch from megatron.core._rank_utils import safe_get_rank from megatron.training.config import ProfilingConfig - -import logging +from megatron.training.utils.common_utils import print_rank_0 logger = logging.getLogger(__name__) From ce3c57b48ea8b7ac8ffecf388249b3a71099ae0d Mon Sep 17 00:00:00 2001 From: Maanu Grover Date: Tue, 16 Jun 2026 18:12:24 -0500 Subject: [PATCH 13/20] fix model setup in unit tests Signed-off-by: Maanu Grover --- .../distributed/megatron_fsdp/utils.py | 7 ++ tests/unit_tests/test_fp8_param.py | 20 ++++-- tests/unit_tests/test_utilities.py | 20 ++++++ .../transformer/moe/test_upcycling.py | 50 +++++++++++-- .../test_multi_token_prediction.py | 71 +++++++++++++++---- 5 files changed, 146 insertions(+), 22 deletions(-) diff --git a/tests/unit_tests/distributed/megatron_fsdp/utils.py b/tests/unit_tests/distributed/megatron_fsdp/utils.py index 22b594403b1..bbde2f516ce 100644 --- a/tests/unit_tests/distributed/megatron_fsdp/utils.py +++ b/tests/unit_tests/distributed/megatron_fsdp/utils.py @@ -8,6 +8,7 @@ from torch.utils.data.distributed import DistributedSampler from hybrid_builders import hybrid_builder +from megatron.core.process_groups_config import ProcessGroupCollection from megatron.core.distributed import finalize_model_grads from megatron.core.enums import ModelType from megatron.core.num_microbatches_calculator import destroy_num_microbatches_calculator @@ -20,6 +21,8 @@ from megatron.training.utils import is_first_or_last_pipeline_stage from model_provider import model_provider +from tests.unit_tests.test_utilities import Utils + def pretrain_forward_backward( *, model, data_iterator, sequence_length=128, micro_batch_size=2, num_micro_batches=1 @@ -93,9 +96,13 @@ def make_moe_args_model_and_optimizer(ut_filename, **overrides): destroy_num_microbatches_calculator() set_global_variables(args, build_tokenizer=False) + cfg_container = Utils.pretrain_config_from_global_args(args, "hybrid") + pg_collection = ProcessGroupCollection.use_mpu_process_groups() model, optimizer, _ = setup_model_and_optimizer( model_provider_func=partial(model_provider, hybrid_builder), model_type=ModelType.encoder_or_decoder, + cfg_container=cfg_container, + pg_collection=pg_collection, ) return model, optimizer diff --git a/tests/unit_tests/test_fp8_param.py b/tests/unit_tests/test_fp8_param.py index f8e54e0c1ab..06bc9c85993 100644 --- a/tests/unit_tests/test_fp8_param.py +++ b/tests/unit_tests/test_fp8_param.py @@ -6,6 +6,7 @@ import sys import pytest +from megatron.core.process_groups_config import ProcessGroupCollection import torch from transformer_engine.pytorch.fp8 import check_fp8_support @@ -97,7 +98,7 @@ def model_provider( return GPTModel( config=config, transformer_layer_spec=transformer_layer_spec, - vocab_size=args.vocal_size, + vocab_size=args.padded_vocab_size, max_sequence_length=args.max_position_embeddings, pre_process=pre_process, post_process=post_process, @@ -125,7 +126,7 @@ def create_test_args( sys.argv = ['test_fp8_param.py'] args = parse_args() args.num_layers = 4 - args.vocal_size = 128800 + args.padded_vocab_size = 128800 args.hidden_size = 128 args.num_attention_heads = 8 args.max_position_embeddings = 512 @@ -248,15 +249,24 @@ def _run_test_helper( input_ids, labels, position_ids, attention_mask, loss_mask = self.get_batch( self.seq_length, self.micro_batch_size ) + model_parallel_cuda_manual_seed(_SEED) + cfg_container = Utils.pretrain_config_from_global_args(args, "gpt") + pg_collection = ProcessGroupCollection.use_mpu_process_groups() if inference: - gpt_model = get_model( - self.model_provider, ModelType.encoder_or_decoder, wrap_with_ddp=False + model_cfg = cfg_container.model + builder_cls = model_cfg.get_builder_cls() + builder = builder_cls(model_cfg) + gpt_model = builder.build_distributed_models( + pg_collection=pg_collection, wrap_with_ddp=False ) gpt_model[0].eval() optimizer = None else: gpt_model, optimizer, _ = setup_model_and_optimizer( - self.model_provider, ModelType.encoder_or_decoder + self.model_provider, + ModelType.encoder_or_decoder, + cfg_container=cfg_container, + pg_collection=pg_collection, ) assert len(gpt_model) == 1 # Assume only one model in the model provider. diff --git a/tests/unit_tests/test_utilities.py b/tests/unit_tests/test_utilities.py index 8dbc5d5a41b..7ff8aa40d27 100644 --- a/tests/unit_tests/test_utilities.py +++ b/tests/unit_tests/test_utilities.py @@ -1,12 +1,19 @@ # Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. import os +from typing import Literal from datetime import timedelta +from argparse import Namespace import torch from torch._C._distributed_c10d import PrefixStore from torch.distributed import rendezvous import megatron.core.parallel_state as ps +from megatron.training.argument_utils import ( + pretrain_cfg_container_from_args, + gpt_config_from_args, + hybrid_config_from_args, +) class TestModel(torch.nn.Module): @@ -134,6 +141,19 @@ def initialize_model_parallel( ) Utils.inited = True + @staticmethod + def pretrain_config_from_global_args(args: Namespace, model_class: Literal["gpt", "hybrid"]): + if model_class == "gpt": + model_cfg = gpt_config_from_args(args) + elif model_class == "hybrid": + model_cfg = hybrid_config_from_args(args) + else: + raise ValueError( + f"MCore model type {model_class} not supported. Choose one of 'gpt' or 'hybrid'." + ) + + return pretrain_cfg_container_from_args(args, model_cfg) + @staticmethod def fake_initialize_model_parallel( tensor_model_parallel_size=1, diff --git a/tests/unit_tests/transformer/moe/test_upcycling.py b/tests/unit_tests/transformer/moe/test_upcycling.py index ff4fc1ac1ce..abd0b5e4933 100644 --- a/tests/unit_tests/transformer/moe/test_upcycling.py +++ b/tests/unit_tests/transformer/moe/test_upcycling.py @@ -3,6 +3,8 @@ import sys import pytest +from megatron.core.process_groups_config import ProcessGroupCollection +from megatron.training.argument_utils import gpt_config_from_args import torch import torch.distributed @@ -93,7 +95,7 @@ def create_test_args(tp, grouped_gemm, swiglu, squared_relu, use_te): sys.argv = ['test_upcycling.py'] args = parse_args() args.num_layers = 2 - args.vocal_size = 256 + args.padded_vocab_size = 256 args.hidden_size = 128 args.num_attention_heads = 8 args.max_position_embeddings = 256 @@ -183,8 +185,17 @@ def test_upcycling_Local(self, tp_ep, granularity, grouped_gemm, swiglu, squared virtual_pipeline_model_parallel_size=args.virtual_pipeline_model_parallel_size, ) + model_parallel_cuda_manual_seed(_SEED) + cfg_container = Utils.pretrain_config_from_global_args(args, "gpt") + cfg_container.model.transformer_layer_spec = get_gpt_layer_local_spec( + args.num_experts, args.moe_grouped_gemm, args.qk_layernorm + ) + pg_collection = ProcessGroupCollection.use_mpu_process_groups() dense_model, optimizer, opt_param_scheduler = setup_model_and_optimizer( - model_provider, ModelType.encoder_or_decoder + model_provider, + ModelType.encoder_or_decoder, + cfg_container=cfg_container, + pg_collection=pg_collection, ) data = list(range(args.seq_length)) input_ids = torch.tensor(data, dtype=torch.int64).repeat((args.micro_batch_size, 1)).cuda() @@ -206,7 +217,17 @@ def test_upcycling_Local(self, tp_ep, granularity, grouped_gemm, swiglu, squared ) set_upcycling_args(ep, granularity, num_experts=2) # model_parallel_cuda_manual_seed(_SEED+1) - moe_model = get_model(model_provider, ModelType.encoder_or_decoder) + model_cfg = gpt_config_from_args(args) + model_cfg.transformer_layer_spec = get_gpt_layer_local_spec( + args.num_experts, args.moe_grouped_gemm, args.qk_layernorm + ) + builder_cls = model_cfg.get_builder_cls() + builder = builder_cls(model_cfg) + moe_model = builder.build_distributed_models( + pg_collection=pg_collection, + ddp_config=cfg_container.ddp, + data_parallel_random_init=cfg_container.rng.data_parallel_random_init, + ) # Upcycle the dense model to the MoE model moe_model = unwrap_model(moe_model) @@ -254,8 +275,17 @@ def test_upcycling_TE(self, tp_ep, granularity, grouped_gemm, swiglu, squared_re virtual_pipeline_model_parallel_size=args.virtual_pipeline_model_parallel_size, ) + model_parallel_cuda_manual_seed(_SEED) + cfg_container = Utils.pretrain_config_from_global_args(args, "gpt") + cfg_container.model.transformer_layer_spec = get_gpt_layer_with_transformer_engine_spec( + args.num_experts, args.moe_grouped_gemm, args.qk_layernorm + ) + pg_collection = ProcessGroupCollection.use_mpu_process_groups() dense_model, optimizer, opt_param_scheduler = setup_model_and_optimizer( - model_provider, ModelType.encoder_or_decoder + model_provider, + ModelType.encoder_or_decoder, + cfg_container=cfg_container, + pg_collection=pg_collection, ) data = list(range(args.seq_length)) input_ids = torch.tensor(data, dtype=torch.int64).repeat((args.micro_batch_size, 1)).cuda() @@ -277,7 +307,17 @@ def test_upcycling_TE(self, tp_ep, granularity, grouped_gemm, swiglu, squared_re ) set_upcycling_args(ep, granularity) # model_parallel_cuda_manual_seed(_SEED+1) - moe_model = get_model(model_provider, ModelType.encoder_or_decoder) + model_cfg = gpt_config_from_args(args) + model_cfg.transformer_layer_spec = get_gpt_layer_with_transformer_engine_spec( + args.num_experts, args.moe_grouped_gemm, args.qk_layernorm + ) + builder_cls = model_cfg.get_builder_cls() + builder = builder_cls(model_cfg) + moe_model = builder.build_distributed_models( + pg_collection=pg_collection, + ddp_config=cfg_container.ddp, + data_parallel_random_init=cfg_container.rng.data_parallel_random_init, + ) # Upcycle the dense model to the MoE model moe_model = unwrap_model(moe_model) diff --git a/tests/unit_tests/transformer/test_multi_token_prediction.py b/tests/unit_tests/transformer/test_multi_token_prediction.py index 8b1bbb385d4..51b5a22a98a 100644 --- a/tests/unit_tests/transformer/test_multi_token_prediction.py +++ b/tests/unit_tests/transformer/test_multi_token_prediction.py @@ -5,6 +5,7 @@ import types import pytest +from megatron.training.argument_utils import gpt_config_from_args, hybrid_config_from_args import torch from megatron.core.enums import ModelType @@ -477,7 +478,7 @@ def create_test_args( args.num_layers = 2 args.mtp_num_layers = 2 args.mtp_loss_scaling_factor = 0.1 - args.vocab_size = 128800 + args.padded_vocab_size = 128800 args.hidden_size = 128 args.num_attention_heads = 8 args.max_position_embeddings = 256 @@ -609,8 +610,15 @@ def test_sharded_state_dict(self, tp, cp): set_args(args) torch.manual_seed(_SEED) Utils.initialize_model_parallel(tensor_model_parallel_size=tp, context_parallel_size=cp) - gpt_model = get_model(self.model_provider, ModelType.encoder_or_decoder) - gpt_model = unwrap_model(gpt_model) + + model_parallel_cuda_manual_seed(_SEED) + pg_collection = ProcessGroupCollection.use_mpu_process_groups() + model_cfg = gpt_config_from_args(args) + builder_cls = model_cfg.get_builder_cls() + builder = builder_cls(model_cfg) + gpt_model = builder.build_distributed_models( + pg_collection=pg_collection, wrap_with_ddp=False + ) sharded_state_dict = gpt_model[0].sharded_state_dict() for i in range(args.mtp_num_layers): assert f"mtp.layers.{i}.enorm.weight" in sharded_state_dict.keys() @@ -789,8 +797,14 @@ def test_packed_sequences(self, tp, cp): packed_seq_params = batch['packed_seq_params'] # Create model + model_parallel_cuda_manual_seed(_SEED) + cfg_container = Utils.pretrain_config_from_global_args(args, "gpt") + pg_collection = ProcessGroupCollection.use_mpu_process_groups() gpt_model, optimizer, opt_param_scheduler = setup_model_and_optimizer( - self.model_provider, ModelType.encoder_or_decoder + self.model_provider, + ModelType.encoder_or_decoder, + cfg_container=cfg_container, + pg_collection=pg_collection, ) # Forward pass with packed sequences @@ -850,8 +864,15 @@ def test_packed_sequences_with_full_recompute(self): Utils.initialize_model_parallel(tensor_model_parallel_size=1, context_parallel_size=1) batch = self.get_packed_batch(seq_lengths, micro_batch_size=1) + + model_parallel_cuda_manual_seed(_SEED) + cfg_container = Utils.pretrain_config_from_global_args(args, "gpt") + pg_collection = ProcessGroupCollection.use_mpu_process_groups() gpt_model, _, _ = setup_model_and_optimizer( - self.model_provider, ModelType.encoder_or_decoder + self.model_provider, + ModelType.encoder_or_decoder, + cfg_container=cfg_container, + pg_collection=pg_collection, ) output = gpt_model[0].forward( @@ -1291,7 +1312,7 @@ def create_test_args( args = parse_args() args.mtp_num_layers = 2 args.mtp_loss_scaling_factor = 0.1 - args.vocab_size = 128800 + args.padded_vocab_size = 128800 args.hidden_size = 128 args.num_attention_heads = 8 args.num_query_groups = 8 @@ -1315,7 +1336,6 @@ def create_test_args( args.bf16 = True # Unified pattern: "main/mtp/mtp" - main decoder "M*M*", MTP pattern "M*" with 2 depths args.hybrid_layer_pattern = "M*M*/M*/M*" - args.spec = "megatron.core.models.hybrid.hybrid_layer_specs.hybrid_stack_spec" if fp8 is not None: args.fp8 = 'e4m3' @@ -1358,8 +1378,15 @@ def test_sharded_state_dict_mamba(self, tp, cp): set_args(args) torch.manual_seed(_SEED) Utils.initialize_model_parallel(tensor_model_parallel_size=tp, context_parallel_size=cp) - mamba_model = get_model(self.model_provider, ModelType.encoder_or_decoder) - mamba_model = unwrap_model(mamba_model) + + model_parallel_cuda_manual_seed(_SEED) + pg_collection = ProcessGroupCollection.use_mpu_process_groups() + model_cfg = hybrid_config_from_args(args) + builder_cls = model_cfg.get_builder_cls() + builder = builder_cls(model_cfg) + mamba_model = builder.build_distributed_models( + pg_collection=pg_collection, wrap_with_ddp=False + ) sharded_state_dict = mamba_model[0].sharded_state_dict() # Verify MTP layers are in the state dict @@ -1383,8 +1410,14 @@ def test_forward_backward_mamba(self, tmp_path_dist_ckpt, tp, cp): batch = self.get_batch(self.seq_length, self.micro_batch_size) tokens, labels, loss_mask, attention_mask, position_ids = batch.values() + model_parallel_cuda_manual_seed(_SEED) + cfg_container = Utils.pretrain_config_from_global_args(args, "hybrid") + pg_collection = ProcessGroupCollection.use_mpu_process_groups() mamba_model_ref, optimizer, opt_param_scheduler = setup_model_and_optimizer( - self.model_provider, ModelType.encoder_or_decoder + self.model_provider, + ModelType.encoder_or_decoder, + cfg_container=cfg_container, + pg_collection=pg_collection, ) output_ref = mamba_model_ref[0].forward( @@ -1426,8 +1459,15 @@ def set_ckpt_path(ckpt_path): set_ckpt_path(ckpt_dir) torch.manual_seed(_SEED) Utils.initialize_model_parallel(tensor_model_parallel_size=tp, context_parallel_size=cp) + + model_parallel_cuda_manual_seed(_SEED) + cfg_container = Utils.pretrain_config_from_global_args(args, "hybrid") + pg_collection = ProcessGroupCollection.use_mpu_process_groups() mamba_model, optimizer, opt_param_scheduler = setup_model_and_optimizer( - self.model_provider, ModelType.encoder_or_decoder + self.model_provider, + ModelType.encoder_or_decoder, + cfg_container=cfg_container, + pg_collection=pg_collection, ) load_checkpoint(mamba_model, optimizer, opt_param_scheduler, strict=False) @@ -1471,8 +1511,15 @@ def test_attention_mask_validation_mamba(self): set_args(args) torch.manual_seed(_SEED) Utils.initialize_model_parallel(tensor_model_parallel_size=tp, context_parallel_size=cp) + pg_collection = ProcessGroupCollection.use_mpu_process_groups() + model_cfg = hybrid_config_from_args(args) + builder_cls = model_cfg.get_builder_cls() + builder = builder_cls(model_cfg) try: - mamba_model = get_model(self.model_provider, ModelType.encoder_or_decoder) + model_parallel_cuda_manual_seed(_SEED) + mamba_model = builder.build_distributed_models( + pg_collection=pg_collection, wrap_with_ddp=False + ) mamba_model = unwrap_model(mamba_model) assert isinstance(mamba_model[0], HybridModel) assert mamba_model[0].mtp is not None From 72ed6148d7b651439a821340eb07fff2c09119a2 Mon Sep 17 00:00:00 2001 From: Maanu Grover Date: Tue, 16 Jun 2026 21:42:50 -0500 Subject: [PATCH 14/20] formatting Signed-off-by: Maanu Grover --- tests/unit_tests/distributed/megatron_fsdp/utils.py | 3 +-- tests/unit_tests/test_fp8_param.py | 2 +- tests/unit_tests/test_utilities.py | 6 +++--- tests/unit_tests/transformer/moe/test_upcycling.py | 4 ++-- tests/unit_tests/transformer/test_multi_token_prediction.py | 2 +- 5 files changed, 8 insertions(+), 9 deletions(-) diff --git a/tests/unit_tests/distributed/megatron_fsdp/utils.py b/tests/unit_tests/distributed/megatron_fsdp/utils.py index bbde2f516ce..4b4b785f3f9 100644 --- a/tests/unit_tests/distributed/megatron_fsdp/utils.py +++ b/tests/unit_tests/distributed/megatron_fsdp/utils.py @@ -8,11 +8,11 @@ from torch.utils.data.distributed import DistributedSampler from hybrid_builders import hybrid_builder -from megatron.core.process_groups_config import ProcessGroupCollection from megatron.core.distributed import finalize_model_grads from megatron.core.enums import ModelType from megatron.core.num_microbatches_calculator import destroy_num_microbatches_calculator from megatron.core.pipeline_parallel.schedules import get_forward_backward_func +from megatron.core.process_groups_config import ProcessGroupCollection from megatron.core.tensor_parallel.random import model_parallel_cuda_manual_seed from megatron.core.utils import get_attr_wrapped_model from megatron.training.arguments import parse_args, validate_args @@ -20,7 +20,6 @@ from megatron.training.training import setup_model_and_optimizer from megatron.training.utils import is_first_or_last_pipeline_stage from model_provider import model_provider - from tests.unit_tests.test_utilities import Utils diff --git a/tests/unit_tests/test_fp8_param.py b/tests/unit_tests/test_fp8_param.py index 06bc9c85993..4d9ec6b0943 100644 --- a/tests/unit_tests/test_fp8_param.py +++ b/tests/unit_tests/test_fp8_param.py @@ -6,7 +6,6 @@ import sys import pytest -from megatron.core.process_groups_config import ProcessGroupCollection import torch from transformer_engine.pytorch.fp8 import check_fp8_support @@ -17,6 +16,7 @@ from megatron.core.models.gpt.gpt_model import GPTModel from megatron.core.num_microbatches_calculator import destroy_num_microbatches_calculator from megatron.core.optimizer.distrib_optimizer import DistributedOptimizer +from megatron.core.process_groups_config import ProcessGroupCollection from megatron.core.tensor_parallel.random import model_parallel_cuda_manual_seed from megatron.core.utils import is_te_min_version from megatron.training.arguments import core_transformer_config_from_args, parse_args, validate_args diff --git a/tests/unit_tests/test_utilities.py b/tests/unit_tests/test_utilities.py index 7ff8aa40d27..9529a419938 100644 --- a/tests/unit_tests/test_utilities.py +++ b/tests/unit_tests/test_utilities.py @@ -1,8 +1,8 @@ # Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. import os -from typing import Literal -from datetime import timedelta from argparse import Namespace +from datetime import timedelta +from typing import Literal import torch from torch._C._distributed_c10d import PrefixStore @@ -10,9 +10,9 @@ import megatron.core.parallel_state as ps from megatron.training.argument_utils import ( - pretrain_cfg_container_from_args, gpt_config_from_args, hybrid_config_from_args, + pretrain_cfg_container_from_args, ) diff --git a/tests/unit_tests/transformer/moe/test_upcycling.py b/tests/unit_tests/transformer/moe/test_upcycling.py index abd0b5e4933..bcf8e568228 100644 --- a/tests/unit_tests/transformer/moe/test_upcycling.py +++ b/tests/unit_tests/transformer/moe/test_upcycling.py @@ -3,8 +3,6 @@ import sys import pytest -from megatron.core.process_groups_config import ProcessGroupCollection -from megatron.training.argument_utils import gpt_config_from_args import torch import torch.distributed @@ -17,6 +15,7 @@ ) from megatron.core.models.gpt.gpt_model import GPTModel from megatron.core.num_microbatches_calculator import destroy_num_microbatches_calculator +from megatron.core.process_groups_config import ProcessGroupCollection from megatron.core.tensor_parallel.random import model_parallel_cuda_manual_seed from megatron.core.transformer.moe import upcycling_utils from megatron.core.transformer.moe.experts import SequentialMLP, TEGroupedMLP @@ -26,6 +25,7 @@ is_te_min_version, unwrap_model, ) +from megatron.training.argument_utils import gpt_config_from_args from megatron.training.arguments import core_transformer_config_from_args, parse_args, validate_args from megatron.training.global_vars import ( destroy_global_vars, diff --git a/tests/unit_tests/transformer/test_multi_token_prediction.py b/tests/unit_tests/transformer/test_multi_token_prediction.py index 51b5a22a98a..61d6ad33646 100644 --- a/tests/unit_tests/transformer/test_multi_token_prediction.py +++ b/tests/unit_tests/transformer/test_multi_token_prediction.py @@ -5,7 +5,6 @@ import types import pytest -from megatron.training.argument_utils import gpt_config_from_args, hybrid_config_from_args import torch from megatron.core.enums import ModelType @@ -32,6 +31,7 @@ ) from megatron.core.transformer.transformer_config import TransformerConfig from megatron.core.utils import get_batch_on_this_cp_rank, is_te_min_version, unwrap_model +from megatron.training.argument_utils import gpt_config_from_args, hybrid_config_from_args from megatron.training.arguments import core_transformer_config_from_args, parse_args, validate_args from megatron.training.checkpointing import load_checkpoint, save_checkpoint from megatron.training.global_vars import ( From b19952df2196c4ee915662a2a30717b61c25716a Mon Sep 17 00:00:00 2001 From: Maanu Grover Date: Fri, 26 Jun 2026 17:35:54 -0700 Subject: [PATCH 15/20] cleanup access of hybrid builder Signed-off-by: Maanu Grover --- pretrain_hybrid.py | 1 - tests/unit_tests/distributed/megatron_fsdp/utils.py | 2 -- 2 files changed, 3 deletions(-) diff --git a/pretrain_hybrid.py b/pretrain_hybrid.py index e4637f61f9b..e9646e51ef8 100644 --- a/pretrain_hybrid.py +++ b/pretrain_hybrid.py @@ -22,7 +22,6 @@ import torch -from hybrid_builders import hybrid_builder from megatron.core import mpu from megatron.core.datasets.blended_megatron_dataset_builder import BlendedMegatronDatasetBuilder from megatron.core.datasets.gpt_dataset import GPTDataset, GPTDatasetConfig, MockGPTDataset diff --git a/tests/unit_tests/distributed/megatron_fsdp/utils.py b/tests/unit_tests/distributed/megatron_fsdp/utils.py index 4b4b785f3f9..f0a8e8ec7c9 100644 --- a/tests/unit_tests/distributed/megatron_fsdp/utils.py +++ b/tests/unit_tests/distributed/megatron_fsdp/utils.py @@ -7,7 +7,6 @@ from torch.utils.data import DataLoader, Dataset from torch.utils.data.distributed import DistributedSampler -from hybrid_builders import hybrid_builder from megatron.core.distributed import finalize_model_grads from megatron.core.enums import ModelType from megatron.core.num_microbatches_calculator import destroy_num_microbatches_calculator @@ -98,7 +97,6 @@ def make_moe_args_model_and_optimizer(ut_filename, **overrides): cfg_container = Utils.pretrain_config_from_global_args(args, "hybrid") pg_collection = ProcessGroupCollection.use_mpu_process_groups() model, optimizer, _ = setup_model_and_optimizer( - model_provider_func=partial(model_provider, hybrid_builder), model_type=ModelType.encoder_or_decoder, cfg_container=cfg_container, pg_collection=pg_collection, From 5640814c2f79e73c609d239246d29e68186edac5 Mon Sep 17 00:00:00 2001 From: Maanu Grover Date: Fri, 26 Jun 2026 17:47:40 -0700 Subject: [PATCH 16/20] use model builder for rl inference model Signed-off-by: Maanu Grover --- megatron/training/training.py | 11 +++++------ train_rl.py | 24 ------------------------ 2 files changed, 5 insertions(+), 30 deletions(-) diff --git a/megatron/training/training.py b/megatron/training/training.py index 8d4da8f2223..8ebdf0f218e 100644 --- a/megatron/training/training.py +++ b/megatron/training/training.py @@ -1268,7 +1268,7 @@ def pretrain( ) # Build an isolated inference config so training config remains unchanged - inference_config = copy.deepcopy(model_cfg) + inference_config = copy.deepcopy(cfg_container.model) if args.rl_inference_tensor_model_parallel_size is not None: inference_config.tensor_model_parallel_size = args.rl_inference_tensor_model_parallel_size if args.rl_inference_pipeline_model_parallel_size is not None: @@ -1312,12 +1312,11 @@ def pretrain( model_alloc_ctx = nullcontext() with model_alloc_ctx: - inference_model = get_model( - model_provider, - model_type, - wrap_with_ddp=False, + builder_cls = inference_config.get_builder_cls() + builder = builder_cls(inference_config) + inference_model = builder.build_distributed_models( pg_collection=inference_pg_collection, - config=inference_config, + wrap_with_ddp=False, ) inference_model[0].eval() diff --git a/train_rl.py b/train_rl.py index acf54680f4a..0b735ae4e00 100644 --- a/train_rl.py +++ b/train_rl.py @@ -8,7 +8,6 @@ import torch from gpt_builders import gpt_builder -from hybrid_builders import hybrid_builder from megatron.core import mpu from megatron.core.enums import ModelType from megatron.core.models.gpt import GPTModel @@ -389,28 +388,6 @@ def __getitem__(self, idx): # Temporary for transition to core datasets train_valid_test_datasets_provider.is_distributed = True - def _model_builder( - args, pre_process, post_process, vp_stage=None, config=None, pg_collection=None - ): - if is_hybrid_model(args): - return hybrid_builder( - args, - pre_process, - post_process, - vp_stage, - config=config, - pg_collection=pg_collection, - ) - else: - return _gpt_builder( - args, - pre_process, - post_process, - vp_stage, - config=config, - pg_collection=pg_collection, - ) - args = parse_and_validate_args( extra_args_provider=add_inference_args, args_defaults={}, @@ -425,5 +402,4 @@ def _model_builder( None, # we don't need to build any datasets for RL training ModelType.encoder_or_decoder, forward_step, - partial(model_provider, _model_builder), ) From 9738391fb69a6b94c4e22495dd4d8b24d9035449 Mon Sep 17 00:00:00 2001 From: Maanu Grover Date: Fri, 26 Jun 2026 18:31:24 -0700 Subject: [PATCH 17/20] remove model provider from setup model and optimizer Signed-off-by: Maanu Grover --- megatron/training/training.py | 2 -- tests/unit_tests/test_fp8_param.py | 1 - .../unit_tests/transformer/moe/test_upcycling.py | 10 ++-------- .../transformer/test_multi_token_prediction.py | 16 +++------------- 4 files changed, 5 insertions(+), 24 deletions(-) diff --git a/megatron/training/training.py b/megatron/training/training.py index 8ebdf0f218e..875977a43db 100644 --- a/megatron/training/training.py +++ b/megatron/training/training.py @@ -1223,7 +1223,6 @@ def pretrain( # Model, optimizer, and learning rate. timers('model-and-optimizer-setup', log_level=0).start(barrier=True) model, optimizer, opt_param_scheduler = setup_model_and_optimizer( - model_provider, model_type, checkpointing_context=checkpointing_context, cfg_container=cfg_container, @@ -1995,7 +1994,6 @@ def get_megatron_ddp_config(args: argparse.Namespace) -> DistributedDataParallel def setup_model_and_optimizer( - model_provider_func, model_type, checkpointing_context=None, *, diff --git a/tests/unit_tests/test_fp8_param.py b/tests/unit_tests/test_fp8_param.py index 4d9ec6b0943..436f39fcbdb 100644 --- a/tests/unit_tests/test_fp8_param.py +++ b/tests/unit_tests/test_fp8_param.py @@ -263,7 +263,6 @@ def _run_test_helper( optimizer = None else: gpt_model, optimizer, _ = setup_model_and_optimizer( - self.model_provider, ModelType.encoder_or_decoder, cfg_container=cfg_container, pg_collection=pg_collection, diff --git a/tests/unit_tests/transformer/moe/test_upcycling.py b/tests/unit_tests/transformer/moe/test_upcycling.py index bcf8e568228..24a757346bc 100644 --- a/tests/unit_tests/transformer/moe/test_upcycling.py +++ b/tests/unit_tests/transformer/moe/test_upcycling.py @@ -192,10 +192,7 @@ def test_upcycling_Local(self, tp_ep, granularity, grouped_gemm, swiglu, squared ) pg_collection = ProcessGroupCollection.use_mpu_process_groups() dense_model, optimizer, opt_param_scheduler = setup_model_and_optimizer( - model_provider, - ModelType.encoder_or_decoder, - cfg_container=cfg_container, - pg_collection=pg_collection, + ModelType.encoder_or_decoder, cfg_container=cfg_container, pg_collection=pg_collection ) data = list(range(args.seq_length)) input_ids = torch.tensor(data, dtype=torch.int64).repeat((args.micro_batch_size, 1)).cuda() @@ -282,10 +279,7 @@ def test_upcycling_TE(self, tp_ep, granularity, grouped_gemm, swiglu, squared_re ) pg_collection = ProcessGroupCollection.use_mpu_process_groups() dense_model, optimizer, opt_param_scheduler = setup_model_and_optimizer( - model_provider, - ModelType.encoder_or_decoder, - cfg_container=cfg_container, - pg_collection=pg_collection, + ModelType.encoder_or_decoder, cfg_container=cfg_container, pg_collection=pg_collection ) data = list(range(args.seq_length)) input_ids = torch.tensor(data, dtype=torch.int64).repeat((args.micro_batch_size, 1)).cuda() diff --git a/tests/unit_tests/transformer/test_multi_token_prediction.py b/tests/unit_tests/transformer/test_multi_token_prediction.py index 61d6ad33646..cf346d0c87c 100644 --- a/tests/unit_tests/transformer/test_multi_token_prediction.py +++ b/tests/unit_tests/transformer/test_multi_token_prediction.py @@ -801,10 +801,7 @@ def test_packed_sequences(self, tp, cp): cfg_container = Utils.pretrain_config_from_global_args(args, "gpt") pg_collection = ProcessGroupCollection.use_mpu_process_groups() gpt_model, optimizer, opt_param_scheduler = setup_model_and_optimizer( - self.model_provider, - ModelType.encoder_or_decoder, - cfg_container=cfg_container, - pg_collection=pg_collection, + ModelType.encoder_or_decoder, cfg_container=cfg_container, pg_collection=pg_collection ) # Forward pass with packed sequences @@ -869,10 +866,7 @@ def test_packed_sequences_with_full_recompute(self): cfg_container = Utils.pretrain_config_from_global_args(args, "gpt") pg_collection = ProcessGroupCollection.use_mpu_process_groups() gpt_model, _, _ = setup_model_and_optimizer( - self.model_provider, - ModelType.encoder_or_decoder, - cfg_container=cfg_container, - pg_collection=pg_collection, + ModelType.encoder_or_decoder, cfg_container=cfg_container, pg_collection=pg_collection ) output = gpt_model[0].forward( @@ -1414,10 +1408,7 @@ def test_forward_backward_mamba(self, tmp_path_dist_ckpt, tp, cp): cfg_container = Utils.pretrain_config_from_global_args(args, "hybrid") pg_collection = ProcessGroupCollection.use_mpu_process_groups() mamba_model_ref, optimizer, opt_param_scheduler = setup_model_and_optimizer( - self.model_provider, - ModelType.encoder_or_decoder, - cfg_container=cfg_container, - pg_collection=pg_collection, + ModelType.encoder_or_decoder, cfg_container=cfg_container, pg_collection=pg_collection ) output_ref = mamba_model_ref[0].forward( @@ -1464,7 +1455,6 @@ def set_ckpt_path(ckpt_path): cfg_container = Utils.pretrain_config_from_global_args(args, "hybrid") pg_collection = ProcessGroupCollection.use_mpu_process_groups() mamba_model, optimizer, opt_param_scheduler = setup_model_and_optimizer( - self.model_provider, ModelType.encoder_or_decoder, cfg_container=cfg_container, pg_collection=pg_collection, From f2a967e444d42d0da97a4f276096317d9bfc14f1 Mon Sep 17 00:00:00 2001 From: Maanu Grover Date: Fri, 26 Jun 2026 18:45:10 -0700 Subject: [PATCH 18/20] remove model provider arg from pretrain Signed-off-by: Maanu Grover --- megatron/training/training.py | 1 - 1 file changed, 1 deletion(-) diff --git a/megatron/training/training.py b/megatron/training/training.py index 875977a43db..0685e39298f 100644 --- a/megatron/training/training.py +++ b/megatron/training/training.py @@ -1006,7 +1006,6 @@ def pretrain( train_valid_test_dataset_provider, model_type, forward_step_func, - model_provider=None, process_non_loss_data_func=None, get_embedding_ranks=None, get_position_embedding_ranks=None, From 429b3ab9972fb04aba44950a7fb05bf84d387b6a Mon Sep 17 00:00:00 2001 From: Maanu Grover Date: Fri, 26 Jun 2026 18:52:52 -0700 Subject: [PATCH 19/20] clean up imports of gpt builder Signed-off-by: Maanu Grover --- pretrain_gpt.py | 1 - .../megatron_fsdp/test_mcore_fully_sharded_data_parallel.py | 1 - 2 files changed, 2 deletions(-) diff --git a/pretrain_gpt.py b/pretrain_gpt.py index aa59fdaaa51..2356427a71e 100644 --- a/pretrain_gpt.py +++ b/pretrain_gpt.py @@ -23,7 +23,6 @@ import torch -from gpt_builders import gpt_builder from megatron.core import mpu from megatron.core.datasets.blended_megatron_dataset_builder import BlendedMegatronDatasetBuilder from megatron.core.datasets.gpt_dataset import GPTDataset, GPTDatasetConfig, MockGPTDataset diff --git a/tests/unit_tests/distributed/megatron_fsdp/test_mcore_fully_sharded_data_parallel.py b/tests/unit_tests/distributed/megatron_fsdp/test_mcore_fully_sharded_data_parallel.py index d1bedacfa11..ed7ec160900 100644 --- a/tests/unit_tests/distributed/megatron_fsdp/test_mcore_fully_sharded_data_parallel.py +++ b/tests/unit_tests/distributed/megatron_fsdp/test_mcore_fully_sharded_data_parallel.py @@ -1205,7 +1205,6 @@ def pre_step_hook(optimizer, args_, kwargs_): model_cfg = gpt_config_from_args(args) cfg = pretrain_cfg_container_from_args(args, model_cfg) - from gpt_builders import gpt_builder from model_provider import model_provider pretrain( From 8ae2ad62d5860b80dd748918435a8e1b4b195d83 Mon Sep 17 00:00:00 2001 From: Maanu Grover Date: Fri, 26 Jun 2026 18:54:41 -0700 Subject: [PATCH 20/20] remove model builder funcs Signed-off-by: Maanu Grover --- gpt_builders.py | 84 ---------------------------------------------- hybrid_builders.py | 53 ----------------------------- mamba_builders.py | 15 --------- 3 files changed, 152 deletions(-) delete mode 100644 hybrid_builders.py delete mode 100644 mamba_builders.py diff --git a/gpt_builders.py b/gpt_builders.py index 2f3a8c3aff7..ef0ff4cf2ee 100644 --- a/gpt_builders.py +++ b/gpt_builders.py @@ -49,90 +49,6 @@ def _set_if_missing(attr: str, value) -> None: _set_if_missing('yarn_correction_range_round_to_int', args.yarn_correction_range_round_to_int) -def gpt_builder(args, pre_process, post_process, vp_stage=None, config=None, pg_collection=None): - print_rank_0('building GPT model ...') - if config is None: - if args.yaml_cfg is not None: - config = core_transformer_config_from_yaml(args, "language_model") - else: - config = core_transformer_config_from_args(args) - _apply_yarn_config_from_args(config, args) - if args.spec is not None: - transformer_layer_spec = import_module(args.spec) - else: - use_te = args.transformer_impl == "transformer_engine" - - if args.experimental_attention_variant is not None: - transformer_layer_spec = get_transformer_block_with_experimental_attention_variant_spec( - config=config, vp_stage=vp_stage - ) - elif args.num_experts: - # Define the decoder block spec - transformer_layer_spec = get_gpt_decoder_block_spec( - config, - use_transformer_engine=use_te, - normalization=args.normalization, - qk_l2_norm=args.qk_l2_norm, - vp_stage=vp_stage, - ) - elif args.heterogeneous_layers_config_path is not None: - assert not (config.transformer_impl == "inference_optimized") - transformer_layer_spec = get_gpt_heterogeneous_layer_spec(config, use_te) - else: - # Define the decoder layer spec - transformer_layer_spec = _get_transformer_layer_spec(use_te, config) - mtp_block_spec = None - if args.mtp_num_layers is not None: - assert not (config.transformer_impl == "inference_optimized") - if ( - hasattr(transformer_layer_spec, 'layer_specs') - and len(transformer_layer_spec.layer_specs) == 0 - ): - # Get the decoder layer spec explicitly if no decoder layer in the last stage, - # Only happens with block spec (TransformerBlockSubmodules) when using MoE. - transformer_layer_spec_for_mtp = _get_transformer_layer_spec(use_te, config) - else: - # Define the decoder block spec - if args.experimental_attention_variant is not None: - decoder_layer_specs = ( - get_transformer_layer_with_experimental_attention_variant_spec(config=config) - ) - else: - decoder_layer_specs = get_gpt_decoder_layer_specs( - config, - use_transformer_engine=use_te, - normalization=args.normalization, - qk_l2_norm=args.qk_l2_norm, - vp_stage=vp_stage, - ) - transformer_layer_spec_for_mtp = decoder_layer_specs[-1] - # Use spec of the last layer in decoder block as spec of the transformer layer in MTP - mtp_block_spec = get_gpt_mtp_block_spec( - config, transformer_layer_spec_for_mtp, use_transformer_engine=use_te, vp_stage=vp_stage - ) - - model = GPTModel( - config=config, - transformer_layer_spec=transformer_layer_spec, - vocab_size=args.padded_vocab_size, - max_sequence_length=args.max_position_embeddings, - pre_process=pre_process, - post_process=post_process, - fp16_lm_cross_entropy=args.fp16_lm_cross_entropy, - parallel_output=True, - share_embeddings_and_output_weights=not args.untie_embeddings_and_output_weights, - position_embedding_type=args.position_embedding_type, - rotary_percent=args.rotary_percent, - rotary_base=args.rotary_base, - rope_scaling=args.use_rope_scaling, - mtp_block_spec=mtp_block_spec, - vp_stage=vp_stage, - pg_collection=pg_collection, - ) - - return model - - def _get_transformer_layer_spec(use_te, config): """Get transformer layer specification based on configuration. diff --git a/hybrid_builders.py b/hybrid_builders.py deleted file mode 100644 index 7e1c58682ac..00000000000 --- a/hybrid_builders.py +++ /dev/null @@ -1,53 +0,0 @@ -# Copyright (c) 2025-2026, NVIDIA CORPORATION. All rights reserved. - -from model_provider import count_parameters_in_layer -from megatron.core.models.hybrid.hybrid_model import HybridModel -from megatron.core.transformer import TransformerConfig -from megatron.core.transformer.spec_utils import import_module -from megatron.training import print_rank_0 -from megatron.training.arguments import core_transformer_config_from_args -from megatron.core.models.hybrid.hybrid_layer_specs import hybrid_inference_stack_spec - - -def hybrid_builder(args, pre_process, post_process, vp_stage=None, config=None, pg_collection=None): - print_rank_0('building Hybrid model ...') - if config is None: - config = core_transformer_config_from_args(args, TransformerConfig) - - if config.transformer_impl == "inference_optimized": - hybrid_stack_spec = hybrid_inference_stack_spec - assert ( - not config.inference_fuse_tp_communication - ), "inference_fuse_tp_communication is not supported for HybridModel" - elif args.spec is not None: - hybrid_stack_spec = import_module(args.spec) - else: - raise ValueError("You must provide a valid hybrid layer spec via --spec") - - model = HybridModel( - config=config, - hybrid_stack_spec=hybrid_stack_spec, - vocab_size=args.padded_vocab_size, - max_sequence_length=args.max_position_embeddings, - hybrid_layer_pattern=args.hybrid_layer_pattern, - pre_process=pre_process, - post_process=post_process, - fp16_lm_cross_entropy=args.fp16_lm_cross_entropy, - parallel_output=True, - share_embeddings_and_output_weights=not args.untie_embeddings_and_output_weights, - position_embedding_type=args.position_embedding_type, - rotary_percent=args.rotary_percent, - rotary_base=args.rotary_base, - pg_collection=pg_collection, - vp_stage=vp_stage, - ) - - for l in range(model.decoder.num_layers_per_pipeline_rank): - layer_params = count_parameters_in_layer(model, f'decoder.layers.{l}.') - print_rank_0(f" == params layer {l}: {layer_params}") - - return model - - -# Backward-compatible alias -mamba_builder = hybrid_builder diff --git a/mamba_builders.py b/mamba_builders.py deleted file mode 100644 index f824fce9be3..00000000000 --- a/mamba_builders.py +++ /dev/null @@ -1,15 +0,0 @@ -# Copyright (c) 2025-2026, NVIDIA CORPORATION. All rights reserved. -"""Backward-compatible re-export of hybrid_builders. - -Deprecated. Use hybrid_builders instead. -""" -import warnings - -warnings.warn( - "mamba_builders has been deprecated. Use hybrid_builders instead.", - DeprecationWarning, - stacklevel=2, -) - -from hybrid_builders import * # noqa: F401,F403 -from hybrid_builders import hybrid_builder as mamba_builder # noqa: F401