Add hetero MIMO (Nemotron6-MoE VLM) training entrypoint on the stock pretrain loop#5504

Draft
yashaswikarnati wants to merge 19 commits into
NVIDIA:mainfrom
yashaswikarnati:ykarnati/mimo-hetero-training-entry

Conversation

@yashaswikarnati

Contributor

Adds the heterogeneous MIMO (Nemotron6-MoE VLM) training example layer that drives the stock pretrain() loop across disjoint vision/language process-group grids. This is the example layer on top of the already-merged MIMO core/example PRs (#5374, #5375, #5376, #5285, #5286, #5486).

What's here

  • examples/mimo/pretrain_mimo.py — entrypoint. Builds the per-rank runtime, then calls pretrain() with a real model provider (mimo_model_provider), a named MimoSetup (setup_model_and_optimizer_func) that does per-submodule DDP + MimoOptimizer + resume-load, the cross-module p2p communicator, and the multi-module schedule PGC. Uses skip_model_parallel_init=True and writes no parallel_state globals.
  • examples/mimo/training/bootstrap.py: build_mimo_runtime + the shared mimo_model_provider. RNG is seeded by pretrain()/initialize_megatron from the schedule PGC (no bespoke seeding).
  • examples/mimo/training/data.py — mock VLM data iterator (mock text tokens are kept off the image placeholder id so embedding alignment never miscounts).
  • examples/mimo/scripts/run_hetero_nemotron_20l_mock_train.sh — 8-GPU launch.
  • tests/unit_tests/test_mimo_hetero_data.py.
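
The placeholder constraint mentioned for data.py can be illustrated with a minimal sketch. The function name, the layout (placeholders first, text after) and the toy sizes are assumptions for illustration, not the PR's actual iterator; the sketch only shows one way to draw text ids that can never collide with the image placeholder id.

```python
import random


def build_mock_tokens(seq_length, image_seq_length, vocab_size, image_token_id, seed=0):
    """Return token ids containing exactly `image_seq_length` image placeholders."""
    if not 0 <= image_token_id < vocab_size:
        raise ValueError("image_token_id must be a valid vocabulary id")
    rng = random.Random(seed)
    text = []
    for _ in range(seq_length - image_seq_length):
        # Draw from vocab_size - 1 ids, then shift past the reserved id so the
        # text stream can never contain the placeholder.
        tok = rng.randrange(vocab_size - 1)
        text.append(tok + 1 if tok >= image_token_id else tok)
    return [image_token_id] * image_seq_length + text


tokens = build_mock_tokens(seq_length=16, image_seq_length=8, vocab_size=32, image_token_id=5)
```

Counting the placeholder id in `tokens` then always gives `image_seq_length`, which is the property the embedding alignment relies on.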

Design notes

  • The MimoModel spans disjoint grids and needs per-submodule DDP + a chained per-grid MimoOptimizer, which stock get_model/get_megatron_optimizer cannot express — hence the MimoSetup seam (named, not an inline closure). The model provider is real and get_model-shaped; build_mimo_runtime is its single call site.

🤖 Generated with Claude Code

Run the heterogeneous Nemotron6-MoE VLM through the stock pretrain() loop on
disjoint vision/language grids. Adds the example layer on top of the merged
MIMO core/example PRs (NVIDIA#5374, NVIDIA#5375, NVIDIA#5376, NVIDIA#5285, NVIDIA#5286, NVIDIA#5486):

- pretrain_mimo.py: entrypoint with a real model provider + named MimoSetup
  (per-submodule DDP + MimoOptimizer + resume), cross-module p2p communicator,
  and the multi-module schedule PGC. skip_model_parallel_init; no globals.
- training/bootstrap.py: build_mimo_runtime + shared mimo_model_provider; RNG
  seeded by pretrain()/initialize_megatron from the schedule PGC.
- training/data.py: mock VLM data iterator (text tokens kept off image id).
- scripts/run_hetero_nemotron_20l_mock_train.sh: 8-GPU launch.
- tests/unit_tests/test_mimo_hetero_data.py.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Signed-off-by: ykarnati <ykarnati@nvidia.com>
@copy-pr-bot

copy-pr-bot Bot commented Jun 25, 2026


This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

The merged provider (NVIDIA#5374) exposes add_model_provider_args + config/spec
builders but no prepare_/validate_model_provider_args; architecture flags are
passed on the CLI and validation is validate_args + validate_hetero_grid_args.
Drop the stale prepare/validate calls and their imports.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Signed-off-by: ykarnati <ykarnati@nvidia.com>
Comment thread examples/mimo/training/bootstrap.py Outdated
) -> MultiModulePipelineCommunicator:
"""Build the MIMO cross-module P2P communicator the train schedule uses.

The vision encoder emits a 2D ``[B*S, H]`` activation, so its

Contributor Author

remove verbose docstrings, be concise

Comment thread examples/mimo/training/bootstrap.py Outdated
specs = build_module_grid_specs(args, world_size)
topology = create_topology(specs)

# --- 2. Resolve this rank's role and build the bare model. -------------

Contributor Author

remove verbose commentary

Comment thread examples/mimo/training/bootstrap.py Outdated


@dataclass
class MimoRuntime:

Contributor Author

should we move all of these to the existing runtime folder?

yashaswikarnati and others added 6 commits June 25, 2026 13:33
The 20L mock launcher passed seven args the entrypoint never registered. Of
those, only --image-token-id is consumed (data.py / bootstrap.py); register it
in add_data_args. The other six (--training-stage, --encoder-pp, --llm-expt-dp,
--num-image-tiles, --tokenizer-prompt-format, --image-token) have no reader in
the mock path -- encoder PP is hardcoded to 1 and the mock run builds no
tokenizer -- so drop them from the launcher rather than ship dead CLI surface.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Signed-off-by: ykarnati <ykarnati@nvidia.com>
The merged build_module_grid_specs (args.py) takes a third positional
encoder_module_name that names the encoder ModuleGridSpec; the bootstrap caller
still used the stale two-arg form. Pass the canonical RADIO_ENCODER_MODULE_NAME
(deriving it from the topology would be circular, as the topology is built from
these specs).

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Signed-off-by: ykarnati <ykarnati@nvidia.com>
build_mimo_runtime constructs the MimoModel eagerly, before pretrain() runs
initialize_megatron's seeding. Weight init forks the 'model-parallel-rng'
tracker, which model_parallel_cuda_manual_seed must have added first, so seed
each rank's single module role here (threading that role's parallel groups,
since these ranks have no initialized mpu). pretrain's later re-seed resets the
tracker, so the double seed is harmless.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Signed-off-by: ykarnati <ykarnati@nvidia.com>
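
The ordering constraint in the commit above (the named RNG state must exist before weight init forks it, and a later re-seed resets it) can be sketched with a toy tracker. `ToyRngTracker` and `seed_role` are hypothetical stand-ins written for this illustration; they are not Megatron's tracker or seeding functions, and only the state name 'model-parallel-rng' comes from the commit message.

```python
import random
from contextlib import contextmanager


class ToyRngTracker:
    """Toy stand-in for a named RNG-state tracker (not Megatron's class)."""

    def __init__(self):
        self._states = {}

    def reset(self):
        self._states.clear()

    def add(self, name, seed):
        self._states[name] = random.Random(seed)

    @contextmanager
    def fork(self, name):
        # Forking a state that was never added is the failure mode described above.
        if name not in self._states:
            raise KeyError(f"rng state {name!r} was never added")
        yield self._states[name]


def seed_role(tracker, seed):
    # Seeding starts from a clean tracker, so seeding twice is idempotent.
    tracker.reset()
    tracker.add("model-parallel-rng", seed)


tracker = ToyRngTracker()
try:
    with tracker.fork("model-parallel-rng"):
        pass
    fork_failed = False
except KeyError:
    fork_failed = True

seed_role(tracker, 1234)  # early seed, before the eager model build
with tracker.fork("model-parallel-rng") as rng:
    first_draw = rng.random()

seed_role(tracker, 1234)  # later re-seed, standing in for what pretrain() does
with tracker.fork("model-parallel-rng") as rng:
    second_draw = rng.random()
```

In this toy model the unseeded fork fails, and the second seed simply restarts the stream, which is the sense in which the double seed is described as harmless.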
MockVLMIterator reads args.image_seq_length and args.fp32, neither registered
nor set by core, so the mock data build raised AttributeError. Register
--image-seq-length in add_data_args (the reader already falls back to
seq_length // 2 when unset) and read fp32 via getattr, matching the optional
getattr reads alongside it.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Signed-off-by: ykarnati <ykarnati@nvidia.com>
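
A minimal sketch of the argument handling described in the two data-arg commits above, using only the standard library. The flag names and the `seq_length // 2` fallback come from the commit messages; the argument types, the defaults, the helper name `resolve_mock_data_settings` and the `False` default for `fp32` are assumptions.

```python
import argparse


def add_data_args(parser):
    group = parser.add_argument_group("mock VLM data (illustrative)")
    group.add_argument("--image-token-id", type=int, default=None)
    group.add_argument("--image-seq-length", type=int, default=None)
    return parser


def resolve_mock_data_settings(args):
    # Fall back to half the sequence length when the flag is unset.
    image_seq_length = getattr(args, "image_seq_length", None)
    if image_seq_length is None:
        image_seq_length = args.seq_length // 2
    # fp32 is not registered in this sketch, so read it defensively.
    use_fp32 = getattr(args, "fp32", False)
    return image_seq_length, use_fp32


parser = add_data_args(argparse.ArgumentParser())
parser.add_argument("--seq-length", type=int, default=8192)
args = parser.parse_args(["--image-token-id", "7"])
settings = resolve_mock_data_settings(args)
```

Reading optional attributes through `getattr` with a default avoids the AttributeError that an unregistered flag would otherwise raise.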
…rain

pretrain() hardcodes setup_model_and_optimizer, which builds a single
get_model + get_megatron_optimizer. Heterogeneous MIMO needs a per-submodule
DDP model and a chained per-grid optimizer that path cannot produce. Add an
optional setup_model_and_optimizer_func that, when provided, replaces the stock
setup; it defaults to None so existing callers are unaffected.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Signed-off-by: ykarnati <ykarnati@nvidia.com>
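
The optional-hook shape described in the commit above can be sketched as follows. This is an illustration under assumptions: the real pretrain() takes many more arguments, and the hook's actual signature and return value are not shown in this PR text. `stock_setup_model_and_optimizer` and `NamedSetup` are hypothetical names.

```python
def stock_setup_model_and_optimizer(model_provider):
    """Toy stand-in for the single-model, single-optimizer stock path."""
    return [model_provider()], "stock-optimizer"


class NamedSetup:
    """Named callable (rather than an inline closure) that owns custom setup."""

    def __init__(self, runtime):
        self.runtime = runtime

    def __call__(self, model_provider):
        model = [model_provider()]
        # Placeholder for per-submodule wrapping and a chained optimizer.
        return model, f"chained-optimizer[{self.runtime}]"


def pretrain(model_provider, setup_model_and_optimizer_func=None):
    # A default of None keeps existing callers on the stock path.
    if setup_model_and_optimizer_func is None:
        return stock_setup_model_and_optimizer(model_provider)
    return setup_model_and_optimizer_func(model_provider)


stock_result = pretrain(lambda: "model")
custom_result = pretrain(lambda: "model", NamedSetup("hetero-runtime"))
```

Keeping the setup as a named class makes it easier to test and to reference in tracebacks than an inline closure would be.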
_get_pg_collection_for_optimizer queried ['tp','ep','pp'] and ['dp','ep'] on the
base view, but 'ep' lives only in the grid's expert view, so the lookup raised
"'ep' is not in view 'base'" for every grid (and would have used the wrong, base
'tp' for experts even if it resolved). Fetch tp_ep_pp/expt_dp from the expert
view (['expt_tp','ep','pp'] and 'expt_dp'), matching the schedule pg-collection,
and span grad stats over the full base grid, which already covers every rank the
expert view re-views. Fixes the build for both the MoE language grid and a
non-expert encoder grid (degenerate ep == 1).

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Signed-off-by: ykarnati <ykarnati@nvidia.com>
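
The view mismatch in the commit above can be reproduced with a toy grid. `ToyGrid` is a hypothetical stand-in, not HyperCommGrid, and the dimension names listed per view are assumptions; only the `get_pg(dims, view=...)` call shape and the error text are taken from this page.

```python
class ToyGrid:
    """Toy grid that resolves dimension names against a named view."""

    def __init__(self, views):
        self._views = views  # view name -> tuple of dimension names

    def get_pg(self, dims, view="base"):
        dims = [dims] if isinstance(dims, str) else list(dims)
        for dim in dims:
            if dim not in self._views[view]:
                raise KeyError(f"'{dim}' is not in view '{view}'")
        # A real grid would return a process group; a tuple stands in here.
        return (view, tuple(dims))


grid = ToyGrid(
    {
        "base": ("tp", "cp", "pp", "dp"),
        "expert": ("expt_tp", "ep", "pp", "expt_dp"),
    }
)

try:
    grid.get_pg(["tp", "ep", "pp"])  # stale lookup on the base view
    base_error = None
except KeyError as exc:
    base_error = exc.args[0]

tp_ep_pp = grid.get_pg(["expt_tp", "ep", "pp"], view="expert")
expt_dp = grid.get_pg("expt_dp", view="expert")
```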
Comment thread examples/mimo/training/bootstrap.py Outdated
# --- 3. Seed this rank's single module role BEFORE building the model. -
# The model is built eagerly here, before pretrain()/initialize_megatron runs,
# so weight init's get_cuda_rng_tracker().fork() needs the per-role seed now.
# pretrain's later re-seed resets the tracker, so the double seed is harmless.

Contributor Author

why do we need to seed twice? this does not make sense; doesn't the pretrain function handle the seeding anyway?

Comment thread examples/mimo/training/bootstrap.py Outdated
# --- 5. DDP-wrap the active submodules (per-submodule DDP). ------------
wrap_active_modules_with_ddp(args, mimo_model, topology)

# --- 6. Attach this rank's per-module PGC for the stock train() path. --

Contributor Author

remove this verbose commentary

Comment thread megatron/core/models/mimo/optimizer.py Outdated
# This mirrors standard Megatron's intra_distributed_optimizer_instance_group which
# spans the full world when num_distributed_optimizer_instances == 1.
pg.intra_dist_opt = grid.get_pg(["tp", "cp", "ep", "pp", "dp"])
# Expert groups. 'ep' belongs to the grid's expert view (a re-factorization of

Contributor Author

remove verbose commentary

--adam-beta2 0.95 \
--clip-grad 1.0 \
--ddp-bucket-size 0 \
--no-ckpt-fully-parallel-save \

Contributor Author

why --no-ckpt-fully-parallel-save? can we try with fully-parallel checkpoint save?

Comment thread examples/mimo/pretrain_mimo.py Outdated
args.padded_vocab_size = vocab_size_with_padding(args.vocab_size, args)
args.tensor_model_parallel_size = tp
args.dataloader_type = "external" # per-rank iterator passed through
args.eval_iters = 0 # train-only; positive eval_interval avoids None-division below

Contributor Author

why train-only? and why hard-code eval_iters and eval_interval?

Comment thread examples/mimo/pretrain_mimo.py Outdated
args.mtp_num_layers = 0
if getattr(args, "padded_vocab_size", None) is None:
# No tokenizer in the mock run; pad the vocab for the language TP shard.
tp = args.tensor_model_parallel_size

Contributor Author

tp = args.tensor_model_parallel_size
args.tensor_model_parallel_size = args.llm_tp
args.tensor_model_parallel_size = tp

this is nonsense

Comment thread examples/mimo/pretrain_mimo.py Outdated
per-submodule DDP and a chained per-grid optimizer that neither stock path provides.
"""

def __init__(self, rt, args: argparse.Namespace):

Contributor Author

rt as an arg name is neither informative nor clear. also, can we add the type hint?

@yashaswikarnati yashaswikarnati force-pushed the ykarnati/mimo-hetero-training-entry branch from 02dbcdd to 6d9d6e7 Compare June 26, 2026 01:01
yashaswikarnati and others added 4 commits June 25, 2026 23:57
The vision submodule's inner encoders dict is keyed by the encoder module name
(the provider registers it as {RADIO_ENCODER_MODULE_NAME: spec}), but the mock
data nested its encoder tensor under a fixed "clip_encoder" key. MimoModel.encode
then raised "No inputs found for encoder 'radio_encoder'" inside the encoder's
forward, which the cross-grid teardown masked as a hang. Default the inner key to
the encoder module name so it matches self.encoders.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Signed-off-by: ykarnati <ykarnati@nvidia.com>
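
The key mismatch fixed in the commit above can be shown with plain dictionaries. The string value of `RADIO_ENCODER_MODULE_NAME` is inferred from the quoted error message, and the exception type, `encode` and `make_mock_inputs` are hypothetical simplifications rather than the MimoModel API.

```python
RADIO_ENCODER_MODULE_NAME = "radio_encoder"  # inferred from the error text above


def encode(encoders, inputs):
    """Run each registered encoder on the input stored under its own name."""
    outputs = {}
    for name, encoder in encoders.items():
        if name not in inputs:
            raise ValueError(f"No inputs found for encoder '{name}'")
        outputs[name] = encoder(inputs[name])
    return outputs


def make_mock_inputs(tensor, encoder_key=RADIO_ENCODER_MODULE_NAME):
    # Defaulting the key to the module name keeps data and model in sync.
    return {encoder_key: tensor}


encoders = {RADIO_ENCODER_MODULE_NAME: lambda x: [v * 2 for v in x]}

try:
    encode(encoders, make_mock_inputs([1, 2], encoder_key="clip_encoder"))
    stale_error = None
except ValueError as exc:
    stale_error = str(exc)

fixed_output = encode(encoders, make_mock_inputs([1, 2]))
```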
RADIOEncoderWrapper.forward takes pixel inputs (x/imgs_sizes/packed_seq_params),
not precomputed hidden_states, so the "hidden_states" mock default raised
TypeError inside the encoder forward (masked by the cross-grid teardown as a
hang). Default the mock vision input mode to "pixels" and run the encoder in
dynamic-resolution mode (launcher --dynamic-resolution), whose packed-patch
builder sizes per-image tokens to sum to image_seq_length.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Signed-off-by: ykarnati <ykarnati@nvidia.com>
The dynamic-resolution mock builds 4x the patches expecting RADIO's 0.5x/axis
pixel shuffle to reduce them back to image_seq_length, and assumes class tokens
are dropped. RADIO applies neither by default, so it emitted 16392 tokens
(16384 patches + 8 class tokens) against 4096 image placeholders. Pass
--pixel-shuffle and --disable-vision-class-token so RADIO emits exactly
image_seq_length (4096) embeddings.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Signed-off-by: ykarnati <ykarnati@nvidia.com>
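
The token counts quoted in the commit above follow from simple arithmetic, sketched here. The helper is a hypothetical calculator, not RADIO code; it assumes that a 0.5x shuffle on each of two axes merges every 2 x 2 patch block into one token and that class tokens are either all kept or all dropped.

```python
def radio_output_tokens(num_patches, class_tokens, pixel_shuffle, drop_class_tokens):
    """Count the embeddings emitted for a packed batch of patches."""
    # Halving both axes reduces the patch count by a factor of four.
    patches = num_patches // 4 if pixel_shuffle else num_patches
    return patches + (0 if drop_class_tokens else class_tokens)


image_seq_length = 4096
mock_patches = 4 * image_seq_length  # the mock builds 4x the target patches

without_flags = radio_output_tokens(mock_patches, 8, False, False)
with_flags = radio_output_tokens(mock_patches, 8, True, True)
```

With neither option the count is 16384 + 8 = 16392; with both it is 16384 / 4 = 4096, matching the number of image placeholders.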
train() unconditionally overwrote config.finalize_model_grads_func with the stock
implementation, clobbering the heterogeneous MIMO grad-sync hook installed before
pretrain(); stock finalize then ran with the multimodule schedule pg_collection
(no tp group) and asserted. Only install the stock default when none was set.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Signed-off-by: ykarnati <ykarnati@nvidia.com>
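
The guard described in the commit above amounts to "install the default only when nothing was set". A minimal sketch, with a `SimpleNamespace` standing in for the real config object and both finalize functions reduced to hypothetical stubs:

```python
from types import SimpleNamespace


def finalize_model_grads(model):
    return "stock finalize"


def mimo_finalize_model_grads(model):
    return "mimo finalize"


def install_default_finalize(config):
    # Preserve a hook installed by the caller; default only when unset.
    if getattr(config, "finalize_model_grads_func", None) is None:
        config.finalize_model_grads_func = finalize_model_grads
    return config


stock_config = install_default_finalize(SimpleNamespace(finalize_model_grads_func=None))
mimo_config = install_default_finalize(
    SimpleNamespace(finalize_model_grads_func=mimo_finalize_model_grads)
)
```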
@yashaswikarnati yashaswikarnati force-pushed the ykarnati/mimo-hetero-training-entry branch from e1dab28 to d07d678 Compare June 26, 2026 06:57
Comment thread megatron/training/training.py Outdated
if len(model) == 1:
config.param_sync_func = config.param_sync_func[0]
config.finalize_model_grads_func = finalize_model_grads
# Preserve a caller-installed finalize hook (e.g. heterogeneous MIMO grad sync);

Contributor Author

remove comment - # Preserve a caller-installed finalize hook (e.g. heterogeneous MIMO grad sync);
# only default to the stock implementation when none was set.

Comment thread megatron/core/models/mimo/optimizer.py Outdated
pg.tp_ep_pp = grid.get_pg(["expt_tp", "ep", "pp"], view=_EXPERT_VIEW)
pg.expt_dp = grid.get_pg("expt_dp", view=_EXPERT_VIEW)

# Distributed optimizer grad stats group: must span all ranks holding a unique

Contributor Author

remove verbose commentary

Comment thread megatron/core/models/mimo/optimizer.py Outdated
if TYPE_CHECKING:
from megatron.core.hyper_comm_grid import HyperCommGrid

# Name of the grid's expert rank view; must match the view registered on the grid.

Contributor Author

remove - # Name of the grid's expert rank view; must match the view registered on the grid.

Comment thread megatron/core/models/mimo/optimizer.py Outdated
grid.create_pg(["dp", "ep"])
grid.create_pg(["tp", "cp", "ep", "pp", "dp"])
grid.create_pg(["tp", "cp", "dp", "pp"])
grid.create_pg(["expt_tp", "ep", "pp"], view="expert")

Contributor Author

should this be _EXPERT_VIEW instead of "expert"?

… runs

Per-GPU throughput (training_log and the progress-log job/cumulative throughput)
was normalized by args.world_size, which for heterogeneous MIMO includes the
vision-encoder ranks that bear none of the language FLOPs, skewing TFLOP/s.
Thread an optional throughput_world_size through training_log,
save_checkpoint_and_time, checkpoint_and_decide_exit and
compute_throughputs_and_append_to_progress_log; train() passes the language grid
size (mp x dp_cp). Defaults to None -> args.world_size, so non-MIMO callers are
unchanged.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Signed-off-by: ykarnati <ykarnati@nvidia.com>
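
The effect of the normalization change in the commit above can be seen with a small calculation. The function below is a hypothetical helper, not the training_log code, and the 6-language/2-vision split of the 8 ranks and the FLOP count are assumed numbers chosen only to make the arithmetic visible.

```python
def per_gpu_tflops(flops_per_iteration, elapsed_seconds, world_size, throughput_world_size=None):
    """Per-GPU TFLOP/s, optionally normalized by a subset of the ranks."""
    # None falls back to the full world size, so existing callers are unchanged.
    divisor = world_size if throughput_world_size is None else throughput_world_size
    return flops_per_iteration / (elapsed_seconds * 1e12 * divisor)


language_flops = 2.4e15  # assumed FLOPs per iteration, all on the language grid

stock = per_gpu_tflops(language_flops, 1.0, world_size=8)
language_only = per_gpu_tflops(language_flops, 1.0, world_size=8, throughput_world_size=6)
```

Dividing by all 8 ranks understates the per-GPU figure for the ranks that actually perform the language FLOPs; dividing by the language grid size removes that skew.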
yashaswikarnati and others added 2 commits June 26, 2026 00:34
Trim verbose docstrings/comments in bootstrap.py, optimizer.py and the finalize
guard; compute the mock padded vocab directly from llm_tp (drop the tp swap);
name the MimoSetup runtime arg and type-hint it.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Signed-off-by: ykarnati <ykarnati@nvidia.com>
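
Computing the padded vocabulary directly from the language TP size, without temporarily overwriting `args.tensor_model_parallel_size`, could look like the sketch below. `pad_vocab_size` is a hypothetical helper; the rounding rule (round up to a multiple of the divisibility factor times the TP size) is an assumption about Megatron-style vocab padding and is not confirmed by this PR text.

```python
def pad_vocab_size(vocab_size, divisible_by, tp_size):
    """Round the vocabulary up so it splits evenly across TP shards."""
    multiple = divisible_by * tp_size
    return ((vocab_size + multiple - 1) // multiple) * multiple


# The TP size is passed explicitly, so no global argument has to be swapped.
padded_tp2 = pad_vocab_size(50257, divisible_by=128, tp_size=2)
padded_tp4 = pad_vocab_size(50257, divisible_by=128, tp_size=4)
```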
Merge the per-rank assembly (MimoRuntime, mimo_model_provider, build_mimo_runtime,
the role/seed helpers and the p2p-communicator builder) into runtime.py, which
already held the DDP-wrap helpers, and drop bootstrap.py. This removes the
bootstrap<->runtime import edge and keeps all per-rank runtime setup in one place.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Signed-off-by: ykarnati <ykarnati@nvidia.com>
@yashaswikarnati yashaswikarnati force-pushed the ykarnati/mimo-hetero-training-entry branch from 9c74cc8 to 1eac569 Compare June 26, 2026 17:27
yashaswikarnati and others added 2 commits June 26, 2026 10:45
… DP group

report_memory in save_checkpoint_and_time read the global data-parallel rank via
parallel_state, which heterogeneous MIMO never initializes
(skip_model_parallel_init), so any --save asserted before save_checkpoint ran.
Pass the module DP group (already resolved for save_checkpoint) to both
report_memory calls; defaults to None -> stock mpu read for non-MIMO callers.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Signed-off-by: ykarnati <ykarnati@nvidia.com>
The "successfully saved checkpoint [t .., p ..]" finalize print read the global
mpu tensor/pipeline rank and world size, which heterogeneous MIMO never
initializes, so rank 0 asserted in iter_finalize_fn after the data shards were
written. Resolve rank+size from the threaded tp_group/pp_group (falling back to
the explicit *_rank args, then mpu), mirroring the load-path fix.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Signed-off-by: ykarnati <ykarnati@nvidia.com>
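
The fallback order described in the commit above (threaded group first, then the explicit rank argument, then global state) can be sketched as a small resolver. `FakeGroup`, `resolve_rank` and `global_rank_fn` are hypothetical; `FakeGroup` only mimics the `rank()` and `size()` methods of a process group so the sketch runs without a distributed launch.

```python
class FakeGroup:
    """Minimal stand-in exposing rank() and size() like a process group."""

    def __init__(self, rank, size):
        self._rank = rank
        self._size = size

    def rank(self):
        return self._rank

    def size(self):
        return self._size


def resolve_rank(group=None, explicit_rank=None, global_rank_fn=None):
    # 1. Prefer the process group threaded through by the caller.
    if group is not None:
        return group.rank()
    # 2. Otherwise use an explicitly supplied rank.
    if explicit_rank is not None:
        return explicit_rank
    # 3. Only then touch global state, which may be uninitialized.
    if global_rank_fn is None:
        raise RuntimeError("no process group, explicit rank, or global state available")
    return global_rank_fn()


from_group = resolve_rank(group=FakeGroup(rank=3, size=4), explicit_rank=1)
from_arg = resolve_rank(explicit_rank=1)
from_global = resolve_rank(global_rank_fn=lambda: 0)
```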
@yashaswikarnati yashaswikarnati force-pushed the ykarnati/mimo-hetero-training-entry branch from 0eb4af2 to 62ece16 Compare June 26, 2026 18:41
load_checkpoint's get_rng_state reads the data-parallel world size, which falls
back to the global mpu group (uninitialized under MIMO) when dp_group is None.
The save path already threads it; thread it on the load path too so resume works
on the disjoint grids.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Signed-off-by: ykarnati <ykarnati@nvidia.com>
@yashaswikarnati yashaswikarnati force-pushed the ykarnati/mimo-hetero-training-entry branch from 6ecbb03 to fcd0906 Compare June 26, 2026 19:01
Fully-parallel checkpoint save/load now works on the disjoint grids (the
report_memory, save-finalize and load get_rng_state global-mpu reads are
threaded through the module process groups), so the opt-out is no longer needed;
use the default fully-parallel save.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Signed-off-by: ykarnati <ykarnati@nvidia.com>