Skip to content

Enable Fused Kernels by Default for Memory Efficiency#6832

Open
0hujun wants to merge 5 commits into
verl-project:mainfrom
0hujun:main
Open

Enable Fused Kernels by Default for Memory Efficiency#6832
0hujun wants to merge 5 commits into
verl-project:mainfrom
0hujun:main

Conversation

@0hujun

@0hujun 0hujun commented Jun 24, 2026

Copy link
Copy Markdown

Summary

This PR proposes changing the default value of use_fused_kernels from False to True across all engine backends (FSDP2, Megatron, VeOmni, AutoModel). Fused kernels provide significant memory savings (32x reduction in logits memory) and enable longer context training without sacrificing correctness. The change includes graceful fallback logic for incompatible configurations.

Motivation

The Problem

Currently, use_fused_kernels defaults to False in three config classes:

Config Class File Line Default
HFModelConfig verl/workers/config/model.py 138 False
ActorConfig verl/workers/config/actor.py 182 False
EngineConfig verl/workers/config/engine.py 110 False

And in the canonical YAML config:

# verl/trainer/config/model/hf_model.yaml
use_fused_kernels: False

This means users must explicitly opt-in to fused kernels, and most users are unaware of this feature, resulting in:

  1. Unnecessary OOM errors at long context lengths (32K+)
  2. Suboptimal memory usage — full logits materialization wastes ~30 GB of GPU memory
  3. Shorter maximum context — users hit memory limits earlier than necessary

The Benefits

Fused kernels avoid materializing the full logits tensor by computing log_probs and entropy directly from hidden_states and vocab_weights:

Component Without Fused Kernels With Fused Kernels Reduction
Logits (bf16) 9.93 GB 0.31 GB 32x
Logits gradient (fp32) 19.87 GB 0.31 GB 64x
Total peak savings ~29 GB

Real-world impact (Qwen3.5-9B, 32K context, Ascend910 61GB NPU):

  • Without fused kernels: OOM at 32K context
  • With fused kernels: trains successfully at 32K, peak memory 34.4 GB (56%)

@gemini-code-assist gemini-code-assist Bot left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request changes the default value of use_fused_kernels from False to True across several configuration files. The reviewer identified a critical issue where enabling this by default will cause a runtime crash (NotImplementedError) if calculate_sum_pi_squared is enabled, and suggested adding a validation check in ActorConfig.__post_init__ to gracefully disable fused kernels with a warning in that case.

Important

The consumer version of Gemini Code Assist on GitHub is being sunset. Starting June 18, 2026, new organization installations will be blocked, and all code review activity will officially cease on July 17, 2026.
For more details on the timeline and next steps, please review the Help Documentation.

checkpoint: CheckpointConfig = field(default_factory=CheckpointConfig)
optim: OptimizerConfig = field(default_factory=OptimizerConfig)
use_fused_kernels: bool = False
use_fused_kernels: bool = True

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

Enabling use_fused_kernels by default will cause a runtime crash (NotImplementedError) for any configuration that has calculate_sum_pi_squared: True enabled.

Both FSDPEngineWithLMHead (in verl/workers/engine/fsdp/transformer_impl.py) and MegatronEngineWithLMHead (in verl/workers/engine/megatron/transformer_impl.py) explicitly raise NotImplementedError when both calculate_sum_pi_squared and use_fused_kernels are True because fused kernels do not materialize the full logits tensor needed for Sigma pi^2.

To prevent this crash and provide a graceful fallback, please add a check in ActorConfig.__post_init__ to automatically disable use_fused_kernels with a warning when calculate_sum_pi_squared is enabled.

        if self.calculate_sum_pi_squared and self.use_fused_kernels:
            import warnings
            warnings.warn(
                "calculate_sum_pi_squared=True is not supported with use_fused_kernels=True. "
                "Automatically disabling use_fused_kernels to allow Sigma pi^2 computation.",
                UserWarning
            )
            self.use_fused_kernels = False

@Luosuu

Luosuu commented Jun 26, 2026

Copy link
Copy Markdown
Collaborator

@0hujun please fix pre-commit

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants