Native Liger-Kernel integration in Megatron-LM #5488

@vaibhavjindal

Liger-Kernel is an open-source library of Triton kernels for LLM training, maintained by LinkedIn and used in production at LinkedIn and across the community. As of writing, the project has about 6.4k GitHub stars, 540+ forks, 9M+ PyPI downloads, and 100+ contributors, with active commits every week. It targets the kernels where Megatron-LM's local defaults leave performance on the table for modern decoder LLMs: RMSNorm, vocab-parallel cross-entropy (and the fused lm_head + CE variant), RoPE, SwiGLU MLP, and others.

Over the past quarter we've made Liger's kernels directly callable from Megatron's existing calling conventions:

| Liger PR | Kernel | Status |
|---|---|---|
| #1254 | LigerMegatronRMSNorm: drop-in for Megatron's RMSNorm builder | Merged |
| #1207 | LigerMegatronCrossEntropy: drop-in for vocab-parallel CE | Merged |
| #1260 | TP>1 support + parity tests for both above | Merged |
| megatron-vp-flce branch | LigerMegatronFusedLinearCrossEntropy (VP-FLCE): fuses lm_head matmul + CE; never materializes the [S, B, V/tp] logits tensor | WIP on Liger side |

Each Liger PR ships kernel-level microbenchmarks showing forward + backward wins on H100 across a small shape grid that we maintain for Liger kernels. This design doc covers the Megatron side: making these kernels first-class options instead of monkey-patches.

Benefits of integrating Liger-Kernel natively in Megatron

Beyond the per-kernel speed and memory wins:

  1. A community-maintained kernel collection that goes well beyond the core kernels. Liger ships kernels for SFT loss masking, knowledge distillation (JSD), and alignment objectives (DPO, KTO, ORPO, GRPO). Native integration lets Megatron's post-training paths (megatron/rl, SFT workflows) opt into these without re-implementing them.

  2. Multi-DSL kernels reach Megatron faster. NVIDIA contributors are already landing kernels in Liger via the cuTile collaboration ("Add cutile jsd", linkedin/Liger-Kernel#1228). We also plan to contribute cuTe / cuTeDSL implementations ourselves. With a native integration, these kernels become callable from Megatron as soon as they merge in Liger, with no per-kernel plumbing on the Megatron side.

  3. An open-source, portable alternative to TransformerEngine. Liger's Triton kernels run on broader hardware (AMD via Triton, future architectures) and are inspectable and forkable. This gives Megatron a non-TE path for users who can't adopt TE due to licensing, hardware support, or audit requirements.

  4. Two-way CI signal between the projects. Today Liger's monkey-patch breaks silently when Megatron internals shift, and Megatron's CI is unaware of Liger. A native integration makes each project visible to the other's CI: Megatron gets a clear signal when a Liger update is needed, and Liger maintainers see Megatron-side breakage immediately.

Why native integration, not a monkey-patch

Liger already ships an apply_liger_kernel_to_megatron(rms_norm=True, cross_entropy=True) helper. It works, but:

  • It mutates Megatron's module namespaces at runtime — invisible to anyone reading Megatron's source or grepping for symbols.
  • The patch points fan out (LocalSpecProvider.layer_norm, fused_vocab_parallel_cross_entropy, vocab_parallel_cross_entropy) — adding a new kernel means adding a new monkey-patch.
  • Errors surface late and far from the cause.

A first-class integration makes the swap an explicit, documented config with a single source of truth in Megatron's tree.
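The difference between the two approaches can be shown with a toy example. This is only a sketch with stand-in names (`lib`, `fast_rms_norm`, `NORM_IMPLS`, and `run_model_explicit` are all hypothetical and do not exist in Megatron or Liger): a monkey-patch changes behaviour without touching the call site, while explicit dispatch keeps the choice visible as a named value.

```python
import types

# Hypothetical stand-in for a library module; not a real Megatron namespace.
lib = types.SimpleNamespace()
lib.rms_norm = lambda x: f"native_rms_norm({x})"


def fast_rms_norm(x):
    """Hypothetical replacement kernel."""
    return f"fast_rms_norm({x})"


def run_model(x):
    # Nothing at this call site reveals which implementation will run.
    return lib.rms_norm(x)


# Explicit alternative: the choice is a named value that can be grepped for.
NORM_IMPLS = {"native": lib.rms_norm, "fast": fast_rms_norm}


def run_model_explicit(x, norm_impl="native"):
    return NORM_IMPLS[norm_impl](x)


def apply_patch():
    """Monkey-patch: rebinds the library attribute at runtime."""
    lib.rms_norm = fast_rms_norm


before = run_model("h")
apply_patch()
after = run_model("h")  # same call, different behaviour

print(before, "->", after)
print(run_model_explicit("h", norm_impl="fast"))
```

In the explicit variant an unknown `norm_impl` fails immediately with a `KeyError` at the dispatch site, which is the kind of early, local error the monkey-patch approach tends to lose.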

Proposed design

Two extension points, used asymmetrically because the two kernels live in different architectural layers.

RMSNorm — BackendSpecProvider slot

Layer-internal modules already flow through the BackendSpecProvider Protocol (megatron/core/models/backends.py). We add a new provider that mirrors TESpecProvider / KitchenSpecProvider:

class LigerSpecProvider(LocalSpecProvider):
    def layer_norm(self, rms_norm=False, for_qk=False, has_residual=False):
        if rms_norm:
            return LigerMegatronRMSNorm
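        # Forward every argument explicitly so non-RMSNorm requests fall through unchanged.
        return super().layer_norm(rms_norm=rms_norm, for_qk=for_qk, has_residual=has_residual)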

The provider is selected via a new use_liger=True kwarg on get_gpt_layer_local_submodules. It inherits every other slot (linear, attention, MLP) from LocalSpecProvider and is mutually exclusive with use_kitchen. No new abstractions are introduced.
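To make the slot override and the mutual-exclusion check concrete, here is a self-contained sketch that runs without Megatron or Liger installed. Every class and function in it (`LocalProviderSketch`, `LigerProviderSketch`, `select_provider`, and the `Fake*` norm classes) is a hypothetical stand-in, and the selection logic is an assumption about how the kwarg could be wired, not the code in the sample PR.

```python
class FakeLayerNorm:
    """Stand-in for the local LayerNorm module class."""


class FakeRMSNorm:
    """Stand-in for the local RMSNorm module class."""


class FakeLigerRMSNorm:
    """Stand-in for LigerMegatronRMSNorm."""


class LocalProviderSketch:
    """Stand-in for LocalSpecProvider; only the norm slot is modelled."""

    def layer_norm(self, rms_norm=False, for_qk=False, has_residual=False):
        return FakeRMSNorm if rms_norm else FakeLayerNorm


class LigerProviderSketch(LocalProviderSketch):
    """Overrides one slot and inherits everything else."""

    def layer_norm(self, rms_norm=False, for_qk=False, has_residual=False):
        if rms_norm:
            return FakeLigerRMSNorm
        # Fall through to the parent for anything that is not RMSNorm.
        return super().layer_norm(
            rms_norm=rms_norm, for_qk=for_qk, has_residual=has_residual
        )


def select_provider(use_liger=False, use_kitchen=False):
    """Hypothetical selection helper enforcing the use_liger/use_kitchen mutex."""
    if use_liger and use_kitchen:
        raise ValueError("use_liger and use_kitchen are mutually exclusive")
    if use_kitchen:
        raise NotImplementedError("the Kitchen provider is outside this sketch")
    return LigerProviderSketch() if use_liger else LocalProviderSketch()


provider = select_provider(use_liger=True)
print(provider.layer_norm(rms_norm=True).__name__)   # Liger stand-in
print(provider.layer_norm(rms_norm=False).__name__)  # falls through to local
```

Because only `layer_norm` is overridden, any slot added to the parent later is inherited automatically, which is the property the design relies on.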

Cross-entropy — config-driven dispatch in LanguageModule

Cross-entropy lives outside the spec system; it runs in LanguageModule.compute_language_model_loss after output_layer() produces logits. Megatron already does config-driven dispatch here for 'native' vs 'te'. We extend the existing Literal:

cross_entropy_fusion_impl: Literal['native', 'te', 'liger'] = 'native'

…and add one branch that delegates to a thin wrapper in megatron/core/fusions/liger_cross_entropy.py:

elif self.config.cross_entropy_fusion_impl == 'liger':
    loss = liger_vocab_parallel_cross_entropy(logits, labels, self.pg_collection.tp)

Mirrors the TE path exactly.
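The dispatch shape can be sketched in isolation as follows. This is an illustration only: `ConfigSketch`, `compute_loss`, and the three `*_ce` functions are hypothetical stand-ins that return a tag instead of computing a loss, and the validation in `__post_init__` is an assumption rather than something the proposal specifies.

```python
from dataclasses import dataclass
from typing import Literal, get_args

CrossEntropyImpl = Literal["native", "te", "liger"]


@dataclass
class ConfigSketch:
    """Hypothetical stand-in for the model config; one field only."""

    cross_entropy_fusion_impl: CrossEntropyImpl = "native"

    def __post_init__(self):
        # Literal is not enforced at runtime, so reject unknown values early.
        allowed = get_args(CrossEntropyImpl)
        if self.cross_entropy_fusion_impl not in allowed:
            raise ValueError(
                f"cross_entropy_fusion_impl must be one of {allowed}, "
                f"got {self.cross_entropy_fusion_impl!r}"
            )


# Stand-ins for the three loss implementations; each returns a tag only.
def native_ce(logits, labels, tp_group):
    return "native"


def te_ce(logits, labels, tp_group):
    return "te"


def liger_ce(logits, labels, tp_group):
    return "liger"


def compute_loss(config, logits, labels, tp_group=None):
    """One branch per implementation, keyed on the config value."""
    impl = config.cross_entropy_fusion_impl
    if impl == "te":
        return te_ce(logits, labels, tp_group)
    elif impl == "liger":
        return liger_ce(logits, labels, tp_group)
    return native_ce(logits, labels, tp_group)


print(compute_loss(ConfigSketch(cross_entropy_fusion_impl="liger"), None, None))
```

Keeping the default at `'native'` is what preserves existing behaviour for users who never set the field.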

Why the asymmetry

BackendSpecProvider slots describe transformer-layer submodules. RMSNorm fits. CE doesn't — it's not a layer submodule.

Sample PR

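NVIDIA/Megatron-LM#5438 implements both extension points from the proposed design, the RMSNorm provider slot (§5.1) and the cross-entropy dispatch (§5.2):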

  • ~340 lines net, ~70% of which is tests + user-guide docs.
  • 10 unit tests: forward + backward parity at TP=1 and TP=2 vs Megatron's reference CE, missing-package surface, slot dispatch, fallthrough for non-RMSNorm, use_liger/use_kitchen mutex.
  • Liger-Kernel is an optional runtime dependency: every import is lazy, and a clean ImportError is raised if the package isn't installed.
  • Zero behavioural change for users who don't set the new flags.
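The optional-dependency behaviour in the list above can be approximated with a small helper. This is a generic sketch of the lazy-import pattern, not the code in the sample PR; `lazy_import` and the error wording are hypothetical.

```python
import importlib


def lazy_import(module_name, feature):
    """Import an optional dependency only when a feature actually needs it.

    Raises ImportError with a message naming the feature, so the failure
    points at the config choice rather than at a deep stack frame.
    """
    try:
        return importlib.import_module(module_name)
    except ImportError as exc:
        raise ImportError(
            f"{feature} requires the optional package '{module_name}', "
            f"which is not installed. Install it or disable {feature}."
        ) from exc


# A module that exists is returned as usual.
json_mod = lazy_import("json", "demo feature")
print(json_mod.dumps({"ok": True}))

# A missing module produces the descriptive error instead of a bare one.
try:
    lazy_import("a_module_that_does_not_exist_xyz", "cross_entropy_fusion_impl='liger'")
except ImportError as err:
    message = str(err)
    print(message)
```

Calling the helper inside the function that needs the kernel, rather than at module import time, is what keeps the dependency invisible to users who never enable the feature.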

End-to-end benchmark

Llama-2-7B-sized GPTModel on 8× H100, TP=8, bf16, seq=4096, batch=1. End-to-end forward + backward + gradient sync through get_forward_backward_func + DistributedDataParallel; 5 warmup / 20 measured iters; torch.cuda.Event timing on the GPU side.

| Mode | Median step | tokens/sec | Peak GB/rank |
|---|---|---|---|
| local (native CE + WrappedTorchNorm) | 234.76 ms | 17,447 | 34.23 |
| liger (Liger CE + Liger RMSNorm) | 177.21 ms | 23,113 | 31.16 |

1.32× speedup, −3.07 GB peak memory per rank, attributable purely to the two kernel swaps — every other config knob is identical between runs.
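The derived figures can be re-checked from the raw measurements. The sketch below assumes that tokens/sec counts seq × batch = 4096 tokens per step (all TP ranks work on the same batch) and that the reported values are truncated to whole tokens; both are assumptions about how the numbers were derived.

```python
SEQ_LEN = 4096
BATCH = 1

# Median step time and peak memory per rank, as reported in the benchmark.
runs = {
    "local": {"step_ms": 234.76, "peak_gb": 34.23},
    "liger": {"step_ms": 177.21, "peak_gb": 31.16},
}


def tokens_per_sec(step_ms, seq_len=SEQ_LEN, batch=BATCH):
    """Throughput for one step, assuming seq_len * batch tokens per step."""
    return seq_len * batch / (step_ms / 1000.0)


local_tps = tokens_per_sec(runs["local"]["step_ms"])
liger_tps = tokens_per_sec(runs["liger"]["step_ms"])
speedup = runs["local"]["step_ms"] / runs["liger"]["step_ms"]
mem_saved_gb = runs["local"]["peak_gb"] - runs["liger"]["peak_gb"]

print(f"local: {int(local_tps):,} tokens/sec")
print(f"liger: {int(liger_tps):,} tokens/sec")
print(f"speedup: {speedup:.2f}x, memory saved: {mem_saved_gb:.2f} GB/rank")
```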

Roadmap (after the sample PR lands)

| PR | Kernel | Notes |
|---|---|---|
| #5438 | RMSNorm + vocab-parallel CE | Sample / pattern |
| 2 | Fused-linear vocab-parallel CE (VP-FLCE) | Largest memory win (~11 GB/rank at long context + large vocab; eliminates the [S, B, V/tp] logits tensor). Requires bypassing output_layer() in GPTModel._postprocess and feeding hidden_states + weight to the fused kernel. Needs design alignment with you: config flag vs spec slot vs ColumnParallelLinear subclass. |
| 3 | RoPE | Drop-in replacement for apply_rotary_pos_emb; small surface. |
| 4 | SwiGLU MLP | Plugs into grouped_mlp_modules / activation_func slots. |
| 5 | Other kernels such as layernorm, relu_squared, etc. | TBD |

Each is independent, gated behind its own switch, and continues the same pattern.
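For a sense of scale on the VP-FLCE item, the size of one [S, B, V/tp] logits tensor is straightforward to estimate. The helper below is a back-of-the-envelope sketch: `logits_bytes` is a hypothetical name, the vocabulary sizes and dtypes are illustrative assumptions, and it does not attempt to reproduce the ~11 GB/rank figure, whose exact configuration is not given here. A real cross-entropy path may also hold extra copies (for example an fp32 upcast or the logits gradient), which this ignores.

```python
def logits_bytes(seq_len, batch, vocab_size, tp, bytes_per_elem):
    """Bytes held by one [S, B, V/tp] logits tensor on a single TP rank.

    Assumes vocab_size is evenly divisible by tp.
    """
    if vocab_size % tp != 0:
        raise ValueError("vocab_size must be divisible by tp in this sketch")
    return seq_len * batch * (vocab_size // tp) * bytes_per_elem


# Benchmark-like shape: seq=4096, batch=1, TP=8, bf16 (2 bytes).
# A 32,000-entry vocabulary is an assumption, not stated in the benchmark.
small = logits_bytes(4096, 1, 32_000, 8, 2)
print(f"{small / 2**20:.2f} MiB")

# Illustrative long-context, large-vocab shape with fp32 logits (4 bytes).
large = logits_bytes(131_072, 1, 128_000, 8, 4)
print(f"{large / 2**30:.4f} GiB")
```

The tensor grows linearly in both sequence length and per-rank vocabulary size, which is why the saving from never materializing it is concentrated at long context and large vocabularies.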

Open questions for the meeting

  1. Naming. LigerSpecProvider and cross_entropy_fusion_impl='liger' — preferred conventions?
  2. VP-FLCE shape. The most invasive follow-up. What's the cleanest way to fuse the lm_head matmul with CE without forking GPTModel._postprocess?
  3. CI policy for optional kernels. Tests use pytest.importorskip("liger_kernel") so they no-op when the package isn't installed. Acceptable, or do you want an opt-in pytest marker / dedicated CI bucket?
  4. Functional tests. Want us to author a Liger recipe under tests/test_utils/recipes/?
  5. Docs location. Currently docs/user-guide/features/liger_kernel.md. Right place?
  6. Default behaviour going forward. We're proposing strictly opt-in for now. Is there a path where Liger becomes the default for a specific configuration class (e.g., Llama-style with normalization="RMSNorm") once the kernel set is complete?

@NVIDIA/mcore-oncall
