Enable NVFP4 grouped MLP GLU RHT amax path #4

Open
sraman-rgb wants to merge 26 commits into main from nvfp4-grouped-mlp-glu-rht-amax

sraman-rgb commented May 30, 2026

Owner

Based on main after PR 3048 merged.

Changes:

  • Add the NVFP4 grouped MLP GEMM+SwiGLU+RHT+amax TE wiring.
  • Add precomputed-amax NVFP4 quantize hooks used by the GLU-Hadamard path.

Validation:

  • NVTE_CUTEDSL_FUSED_GROUPED_MLP_FC1_GLU_RHT_AMAX_TMEM=0: 1784 passed, 8968 skipped for tests/pytorch/test_fusible_ops.py::TestSequentialModules::test_grouped_mlp
  • NVTE_CUTEDSL_FUSED_GROUPED_MLP_FC1_GLU_RHT_AMAX_TMEM=1: currently blocked by installed cuDNN FE/CUTLASS helper API mismatch in grouped_gemm_glu_hadamard/hadamard_utils.py (make_trivial_tiled_mma signature)
  • C++ lint: pass
  • Python lint: pass

sraman-rgb changed the base branch from nvfp4-grouped-mlp-wgrad to main June 1, 2026 20:40
sraman-rgb force-pushed the nvfp4-grouped-mlp-glu-rht-amax branch from 2522c2b to 8842a9a June 1, 2026 20:42
eyupcanakman and others added 25 commits June 3, 2026 00:02
bdist_wheel runs with --python-tag=py3 which sets the wheel
filename's tag, but the WHEEL metadata still records the build
python's tag (cp310-cp310). Set the Tag to py3-none in the WHEEL
file right before wheel pack so the metadata matches the filename.

Fixes NVIDIA#2896

Signed-off-by: Eyüp Can Akman <eyupcanakman@gmail.com>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
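
The WHEEL metadata fix in the commit above can be sketched as follows. This is a minimal illustration, not the build script's actual code: the helper name `retag_wheel_metadata` is hypothetical, and it assumes the WHEEL file is plain text whose `Tag:` lines have the form `python-abi-platform`. Keeping the platform part unchanged is also an assumption.

```python
def retag_wheel_metadata(wheel_text: str) -> str:
    """Rewrite every ``Tag:`` line so its python/ABI parts read ``py3-none``.

    Hypothetical helper: the platform component of each tag is preserved.
    """
    out = []
    for line in wheel_text.splitlines():
        if line.startswith("Tag: "):
            # A tag looks like "cp310-cp310-linux_x86_64".
            _python, _abi, platform = line[len("Tag: "):].split("-", 2)
            line = f"Tag: py3-none-{platform}"
        out.append(line)
    return "\n".join(out) + "\n"


example = (
    "Wheel-Version: 1.0\n"
    "Root-Is-Purelib: false\n"
    "Tag: cp310-cp310-linux_x86_64\n"
)
print(retag_wheel_metadata(example))
```

Running such a rewrite on the unpacked wheel directory just before `wheel pack` would keep the metadata consistent with a filename built using `--python-tag=py3`.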
[fix] Fix CUTLASS grouped GEMM segfault for empty groups.

Signed-off-by: yangfan.bai <yangfan.bai@shopee.com>
Co-authored-by: yangfan.bai <yangfan.bai@shopee.com>
* Enable NVFP4 grouped MLP SReLU fusion

Signed-off-by: Siddhartha Raman S <sraman@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Address NVFP4 SReLU review comments

Signed-off-by: Siddhartha Raman S <sraman@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Clarify NVFP4 SReLU alpha scaling

Signed-off-by: Siddhartha Raman S <sraman@nvidia.com>

* Clarify NVFP4 SReLU scale handling

Signed-off-by: Siddhartha Raman S <sraman@nvidia.com>

---------

Signed-off-by: Siddhartha Raman S <sraman@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
* Enable NVFP4 grouped MLP cuDNN wgrad

Signed-off-by: Siddhartha Raman S <sraman@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Address NVFP4 wgrad review comments

Signed-off-by: Siddhartha Raman S <sraman@nvidia.com>

* Use NVFP4 amax helper for wgrad scales

Signed-off-by: Siddhartha Raman S <sraman@nvidia.com>

* Skip NVFP4 wgrad amax lookup for empty tokens

Signed-off-by: Siddhartha Raman S <sraman@nvidia.com>

---------

Signed-off-by: Siddhartha Raman S <sraman@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
* initial impl

Signed-off-by: tdophung <tdophung@nvidia.com>

* Register XLA FFI AttrDecoding for JAXX_Routing_Map_Format

Without this XLA_FFI_REGISTER_ENUM_ATTR_DECODING the FFI handler
templates cannot instantiate AttrDecoding<JAXX_Routing_Map_Format>,
breaking the JAX build in router.cpp.

Signed-off-by: tdophung <tdophung@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* add plumbing from pytorch side, change all internal functions to use the routing map type enum

Signed-off-by: tdophung <tdophung@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* address all comments

Signed-off-by: tdophung <tdophung@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* [PyTorch] Address PR NVIDIA#3009 review: remove .view() calls, int routing_map_format

Apply the four CPU-overhead fixes the reviewer asked for and the
CLAUDE.md "CPU overhead in PyTorch wrappers" section codifies:

1. _validate_routing_map_format returns plain int (not enum); the
   autograd Function + tex.* bindings only see ints. Validates via
   precomputed frozenset and a single dict.get with canonical
   lowercase keys (no .lower()/.upper()).

2. Type annotations on Function.forward use int (not the string
   forward-ref 'RoutingMapFormat').

3. Removed every .view() from FusedTopkScoreFunction.{forward,backward}
   and FusedComputeScoresForMoEAuxLoss.{forward,backward}. C++
   extension now accepts N-D logits/grad_probs, computes num_tokens
   from the product of leading dims, num_experts from the last dim,
   allocates outputs at the user-facing N-D shape, and wraps tensors
   with an explicit 2D shape via makeTransformerEngineTensor only for
   the kernel call. Asserts is_contiguous() on inputs.

4. Bwd allocates grad_logits with torch.empty_like(grad_probs) (N-D)
   instead of allocate-2D-then-view.

PyTorch-extension boundary takes 'int routing_map_format' and casts
to NVTERoutingMapFormat inside; the common-layer C API (nvte_*_v2)
keeps the enum.

Signed-off-by: tdophung <tdophung@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* [PyTorch] Fix routing_map_format validator: pybind11 enum is not an int subclass

pybind11 enum_<NVTERoutingMapFormat> binds as a standalone type, not a
subclass of int. The validator must check isinstance(x, RoutingMapFormat)
before the int branch and explicitly normalize via int(x).

Signed-off-by: tdophung <tdophung@nvidia.com>
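
The validator behaviour described in the two commits above (enum check first, then a precomputed frozenset for ints and a single `dict.get` for strings, with no case folding) can be sketched like this. Everything here is illustrative: the stand-in `RoutingMapFormat` class, the member names `FORMAT_A`/`FORMAT_B`, and the string keys are hypothetical, chosen only to mimic an enum type that is not an `int` subclass but supports `int()`.

```python
class RoutingMapFormat:
    """Stand-in for a bound enum: NOT an int subclass, but int() works."""

    def __init__(self, value: int) -> None:
        self._value = value

    def __int__(self) -> int:
        return self._value


# Hypothetical members and names, for illustration only.
FORMAT_A = RoutingMapFormat(0)
FORMAT_B = RoutingMapFormat(1)

_VALID_INTS = frozenset({0, 1})
_NAME_TO_INT = {"format_a": 0, "format_b": 1}  # canonical lowercase keys


def validate_routing_map_format(fmt) -> int:
    """Normalize an enum member, int, or canonical string to a plain int."""
    # The enum branch must come first: the enum is not an int subclass.
    if isinstance(fmt, RoutingMapFormat):
        return int(fmt)
    if isinstance(fmt, int) and not isinstance(fmt, bool):
        if fmt in _VALID_INTS:
            return fmt
    elif isinstance(fmt, str):
        value = _NAME_TO_INT.get(fmt)  # single lookup, no .lower()/.upper()
        if value is not None:
            return value
    raise ValueError(f"Unsupported routing map format: {fmt!r}")


print(validate_routing_map_format(FORMAT_B))
print(validate_routing_map_format("format_a"))
```

Returning a plain `int` means the autograd function and the extension bindings never need to handle the enum type on the hot path.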

* [PyTorch][JAX][Common] Clean up review-style comments in router code

Signed-off-by: tdophung <tdophung@nvidia.com>

* remove remaining useless comments

Signed-off-by: tdophung <tdophung@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* small comment change

Signed-off-by: tdophung <tdophung@nvidia.com>

---------

Signed-off-by: tdophung <tdophung@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
…ete weights (NVIDIA#3076)

* [PyTorch] Move transform_and_copy_data_ptrs_to_device to experimental submodule

Introduces a `grouped_mlp_experimental` pybind submodule as a labeled
home for hyperspecific helpers that exist to satisfy the cuDNN CuTe
DSL grouped GEMM kernels. The submodule itself is documented as
unstable, so callers can see at the import path that these helpers
are not part of the supported surface.

`copy_data_ptrs_to_device` is genuinely general-purpose and stays at
the top level; only `transform_and_copy_data_ptrs_to_device` moves
into the submodule, and its four call sites in the fused grouped MLP
forward/backward are updated accordingly.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* [PyTorch] Fuse data + scale ptr packing for discrete-weight grouped MLP

Replaces `transform_and_copy_data_ptrs_to_device` with a more focused
helper, `swizzle_scales_and_pack_ptrs_for_discrete_weights`. The new
function takes both the FP8/FP4 weight data tensors and their scale
tensors, swizzles the scales, and copies both pointer arrays to device
in a single kernel launch (down from two — one from
`copy_data_ptrs_to_device` for data and one from the old transform
helper for scales). The two returned pointer arrays are views into a
single packed device buffer.

The general "transform_type" string dispatch is gone: the function
only supports `mxfp8_rowwise`, `mxfp8_columnwise`, and `nvfp4`, which
were the only modes ever used. The four discrete-weight call sites in
the fused grouped MLP forward/backward collapse their paired
`copy_data_ptrs_to_device` + transform calls into a single call.

The implementation moves to a dedicated source file,
`csrc/extensions/grouped_mlp_experimental.cpp`, so the experimental
submodule has a clear home for future helpers tied to the cuDNN CuTe
DSL grouped GEMM kernels. The declaration in `extensions.h` is
grouped under a matching banner. `copy_data_ptrs_to_device` stays in
`utils.cpp` since it remains a general-purpose helper.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* [PyTorch] Use real data shapes and dispatch on a format enum in swizzle helper

In swizzle_scales_and_pack_ptrs_for_discrete_weights, take the data
shape directly from the data tensors instead of inferring it from the
padded scale shape. NVFP4 packs two 4-bit values per byte, so the
byte-shape's inner dim is doubled to recover the logical element
count.

Also replace the trio of is_mxfp8_rowwise / is_mxfp8_columnwise /
is_nvfp4 booleans with a function-local TensorFormat enum. Tensor
properties (scaling mode, dtypes, swizzle param names) are assigned
together per case in a single switch so adding a future format is a
single-point change rather than a fresh boolean threaded through the
function.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: Tim Moon <tmoon@nvidia.com>
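
The shape recovery mentioned in the commit above is simple arithmetic: because NVFP4 stores two 4-bit values per byte, the logical inner dimension is twice the inner dimension of the byte tensor. A minimal sketch, with the hypothetical helper name `nvfp4_logical_shape`:

```python
def nvfp4_logical_shape(byte_shape):
    """Return the logical element shape for a packed 4-bit byte shape.

    Hypothetical helper: only the innermost dimension is doubled.
    """
    *outer, inner = byte_shape
    return (*outer, inner * 2)


print(nvfp4_logical_shape((128, 64)))  # (128, 128)
```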

* [PyTorch] Consolidate utils.cpp into misc.cpp

After moving the experimental grouped-MLP helper out, the only thing
left in extensions/utils.cpp was copy_data_ptrs_to_device, which fits
naturally alongside the cublasLt/cuDNN version getters and
splits_to_offsets already in extensions/misc.cpp. Move it there and
delete the now-empty utils.cpp. Build picks up sources via glob, so
no manifest update is needed.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* [PyTorch] Tidy the grouped MLP experimental helper

Wraps the C++ implementation of
swizzle_scales_and_pack_ptrs_for_discrete_weights in a
`grouped_mlp_experimental` namespace and renames the format-selector
argument from `format` to `swizzle_type` across the declaration,
implementation, and pybind binding. The pybind submodule name was
already `grouped_mlp_experimental`, so the C++ namespace now mirrors
it.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* [PyTorch] Hook NVFP4 optimize_for_gemm into the quantize path

After NVFP4Quantizer::quantize runs, run inplace_swizzle_scale_for_gemm
on the output when optimize_for_gemm is set and the quantize kernel
hasn't already produced swizzled scales. The NVFP4 quantize kernel
rejects with_gemm_swizzled_scales=true and emits compact scales, so
without this hook callers had to follow up with a manual swizzle in
Python (see ops/_common.py:85). The hook is a no-op for MXFP8 (its
quantize kernel sets the flag itself) and for any quantizer with
optimize_for_gemm=false.

Also fixes a latent state-consistency bug in
NVFP4Quantizer::convert_and_update_tensor: it was resetting the C++
wrapper's with_gemm_swizzled_scales to false but never touching the
Python tensor's _with_gemm_swizzled_scales attribute. Re-quantizing
into a tensor that previously held swizzled scales would leave the
Python flag stuck at true while the buffer was compact, mismatched
state that downstream code could mis-read. The Python attribute is
now reset alongside the C++ wrapper, matching what
MXFP8Quantizer::convert_and_update_tensor already does.

Adds test_swizzle_scales_and_pack_ptrs_for_discrete_weights covering
mxfp8_rowwise, mxfp8_columnwise, and nvfp4, comparing the helper's
swizzled output against scales produced by the quantizer with
optimize_for_gemm=true. NVFP4 was the case that surfaced the
quantizer-side issues fixed above.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Fix linter warning

Signed-off-by: Tim Moon <tmoon@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Fix linter warning

Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Avoid swizzle kernel when tensor size is zero

Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Review suggestion from @vthumbe1503

Co-authored-by: vthumbe1503 <vthumbe@nvidia.com>
Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Signed-off-by: Tim Moon <tmoon@nvidia.com>
Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Co-authored-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: vthumbe1503 <vthumbe@nvidia.com>
Signed-off-by: Francesco Bertolotti <francesco.bertolotti@igenius.ai>
Co-authored-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
* Disable new triton ffi for autotuned kernels

Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
* optimize prepare grouped split

Signed-off-by: Zhongbo Zhu <zhongboz@nvidia.com>

* small fix

Signed-off-by: Zhongbo Zhu <zhongboz@nvidia.com>

* precompute offsets

Signed-off-by: Zhongbo Zhu <zhongboz@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* add comments because linter complained

Signed-off-by: Zhongbo Zhu <zhongboz@nvidia.com>

* disable linter

Signed-off-by: Zhongbo Zhu <zhongboz@nvidia.com>

* Redesign nvte multi splits-to-offsets API

Replace nvte_multi_splits_to_offsets with nvte_splits_to_offsets_multi,
which takes parallel arrays of output NVTETensors, strides, and a
per-output include_leading_zero flag. The dedicated cumsum slot is gone:
outputs are now a uniform inclusive scan list where each output's length
is either N or N+1 depending on its leading-zero flag. The per-launch
output cap is internal; the public function loops kernel launches when
num_outputs exceeds it, so callers see no hard limit.

v1 nvte_splits_to_offsets now goes through the same shared kernel, and
the prior duplicated kernel is removed. The PyTorch tex wrapper switches
to MultiTensorWrapper to batch NVTETensor allocation for outputs.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: Tim Moon <tmoon@nvidia.com>
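
The output contract of the redesigned API (a uniform list of inclusive scans, each of length N or N+1 depending on its leading-zero flag) can be modelled in pure Python. This is a reference sketch only, not the CUDA kernel or the C API. In particular, treating `stride` as a multiplier applied to each split size before the scan is an assumption, since the commit message does not spell out its meaning.

```python
from itertools import accumulate


def splits_to_offsets_multi(split_sizes, strides, include_leading_zero):
    """Pure-Python model of a multi-output inclusive scan.

    Assumption: each output scans ``split_size * stride``.
    Each output has len(split_sizes) entries, plus one if its
    leading-zero flag is set.
    """
    outputs = []
    for stride, leading_zero in zip(strides, include_leading_zero):
        scan = list(accumulate(size * stride for size in split_sizes))
        outputs.append([0] + scan if leading_zero else scan)
    return outputs


offsets = splits_to_offsets_multi(
    split_sizes=[3, 0, 5],
    strides=[1, 4],
    include_leading_zero=[True, False],
)
print(offsets)  # [[0, 3, 3, 8], [12, 12, 32]]
```

An empty `strides` list yields an empty result, which mirrors the no-op behaviour for `num_outputs == 0` described in a later commit.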

* Move splits-to-offsets kernel out of common.cu

Relocate the inclusive-scan kernel, helpers, and launch struct into a
dedicated transformer_engine/common/util/splits_to_offsets.cu, wrapped
in namespace transformer_engine::splits_to_offsets. Dtype dispatch in
load_split_size and store_output now uses a switch with NVTE_DEVICE_ERROR
on unsupported dtypes instead of silently treating unknown dtypes as
int64. nvte_splits_to_offsets_multi with num_outputs == 0 is now a noop
rather than an error, and the internal args struct uses bool for the
per-output include_leading_zero flag (C API stays int* for portability).
v1 parameters renamed to split_sizes/num_splits/stride to align with the
multi variant.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Generalize PyTorch splits_to_offsets_multi wrapper

Rename tex.prepare_grouped_splits to tex.splits_to_offsets_multi and
turn it into a general inclusive-scan utility instead of a grouped-MLP
helper. Outputs are configured per-entry via parallel arrays of strides,
include_leading_zero flags, and dtypes; bulk_allocate is opt-in so
general callers get separate per-output at::empty buffers and only the
grouped-MLP hot path takes the shared-storage / 16-byte-aligned route
that cuDNN needs. The wrapper now takes an explicit device, coerces
split_sizes to int64 / CUDA centrally, drops the redundant per-tensor
checks (the core lib validates), and drops non_blocking=true on the
host->device migration to avoid the race Greptile flagged. Update the
grouped MLP caller and tests to the new shape.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Review suggestion from @vthumbe1503

Rename the internal create_grouped_tensor parameter from
provided_tensor_offsets to precomputed_tensor_offsets in the quantizer
subclass impls, and document the optional tensor_offsets contract on
the abstract method declaration.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Signed-off-by: Zhongbo Zhu <zhongboz@nvidia.com>
Signed-off-by: Tim Moon <tmoon@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Tim Moon <tmoon@nvidia.com>
Co-authored-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
…A#3020)

* tests/attention: shrink fp8_vs_f16 configs from B=2 to B=1

The 9 fp8_9..fp8_17 configs in `model_configs_fp8_vs_f16` use shapes
(B=2, S=4096-8192, H=32-128, D=64-192) for the bf16-vs-fp8 reference
comparison. The reference path in `test_dpa_fp8_vs_f16` materializes the
full (B, H, S, S) attention matrix in bf16, and keeps a handful of them
live (S, P, dP, dS, dropout-mask) simultaneously. At B=2, S=8192, H=64
the per-test peak is ~70 GiB, which exceeds the memory of common 80 GB
cards (H100) and pushes the suite into OOM territory on Blackwell (~91
GB measured with the cuDNN caching allocator residue).

Halving B to 1 halves the bytes of every (B, H, S, S) tensor. Measured
on B200 (SM_100, cuDNN 9.23, TE main):

  per-test peak `torch.cuda.max_memory_allocated`:
     before: 70.0 GiB (fp8_14)
     after : 36.1 GiB (fp8_14)         -48%
  per-test peak `nvidia-smi memory.used`:
     before: 96.8 GiB
     after : 51.3 GiB                  -47%
  test outcome (B200, develop FE, this TE):
     identical 618F / 2196P / 891S, wall time within ~3%

The shrunk configs still exercise every distinct shape/mask/SWA/GQA
combination that the originals did -- only B is smaller. The suite now
fits comfortably on 80 GB cards.

fp8_19/20 (B=2, S=2048) are left at B=2 because their peak is small
(~few GiB) and the larger batch is useful coverage for padding_causal.

Signed-off-by: Vedaanta Agarwalla <vagarwalla@nvidia.com>
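
The memory argument in the commit above can be checked with a back-of-the-envelope calculation. The sketch below sizes a single `(B, H, S, S)` tensor in bf16 (2 bytes per element); the helper name `attention_matrix_gib` is hypothetical, and the calculation ignores every other allocation in the test.

```python
def attention_matrix_gib(batch, heads, seq_len, bytes_per_element=2):
    """Size in GiB of one (B, H, S, S) tensor; bf16 uses 2 bytes/element."""
    return batch * heads * seq_len * seq_len * bytes_per_element / 2**30


before = attention_matrix_gib(batch=2, heads=64, seq_len=8192)
after = attention_matrix_gib(batch=1, heads=64, seq_len=8192)
print(before, after)  # 16.0 8.0
```

At 16 GiB per tensor for B=2, keeping a handful of such tensors live at once is consistent with the reported peak of roughly 70 GiB, and halving B halves each of them.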

* address changes recommended by Kshitij

Signed-off-by: Vedaanta Agarwalla <142048820+vedaanta@users.noreply.github.com>

* tests/attention: black format fp8_13 ModelConfig

Line was 105 chars; black requires <=100 with the project's preview+
string_processing settings.

Signed-off-by: Vedaanta Agarwalla <vagarwalla@nvidia.com>

---------

Signed-off-by: Vedaanta Agarwalla <vagarwalla@nvidia.com>
Signed-off-by: Vedaanta Agarwalla <142048820+vedaanta@users.noreply.github.com>
…itectures (NVIDIA#2836)

* [PyTorch] Fix FlashAttention 2 head_dim > 192 on sm103 and other architectures

Replace the exact-match compute capability allowlist with a >= sm80 range
check, matching flash-attn's own gate:
Dao-AILab/flash-attention@bbb21d6

The allowlist ((8,0), (9,0), (10,0), (12,0)) missed sm103 (B300), sm89
(L40S), sm86 (A40), and others where FA2 supports head_dim up to 256.
The sm103 case was validated on hardware with head_dim=256; the remaining
architectures appear to be supported based on flash-attn's >= sm80 guarantee.

Signed-off-by: Pedram Razavi <pedram.razavi@gmail.com>
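
The difference between the old allowlist and the new range check is easy to see with compute capabilities written as `(major, minor)` tuples. This is an illustrative sketch; the function names are hypothetical and the real gate in the attention backend selection may carry additional conditions.

```python
# Exact-match allowlist quoted in the commit message.
OLD_ALLOWLIST = frozenset({(8, 0), (9, 0), (10, 0), (12, 0)})


def allowed_by_allowlist(compute_capability):
    """Old behaviour: only exact matches pass."""
    return compute_capability in OLD_ALLOWLIST


def allowed_by_range(compute_capability):
    """New behaviour: anything at or above sm80 passes."""
    return compute_capability >= (8, 0)


for name, cc in [("sm86", (8, 6)), ("sm89", (8, 9)), ("sm103", (10, 3))]:
    print(name, allowed_by_allowlist(cc), allowed_by_range(cc))
```

Each of the three architectures named in the commit message fails the allowlist but passes the range check, because Python compares tuples element by element.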

* Addressed the review comments

Signed-off-by: Przemek Tredak <ptredak@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Loosen kv-cache test tolerances for FlashAttention head_dim>128

Signed-off-by: Pedram Razavi <pedram.razavi@gmail.com>

---------

Signed-off-by: Pedram Razavi <pedram.razavi@gmail.com>
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>
Co-authored-by: Przemek Tredak <ptredak@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Add DSv3 MXFP8 attention unit test

Signed-off-by: Layali Rashid <lrashid@nvidia.com>
… for fused grouped MLP via fused ops (NVIDIA#3078)

* Interleave/de-interleave utils for GLU fused OP

Signed-off-by: ksivamani <ksivamani@nvidia.com>

* Apply suggestions from code review

Signed-off-by: ksivamani <ksivamani@nvidia.com>

Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Update transformer_engine/pytorch/utils.py

Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

---------

Signed-off-by: ksivamani <ksivamani@nvidia.com>
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
* Add JAX fused attention score_mod support

Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Stabilize score_mod callback cache keys

Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Add distributed JAX score mod attention test

Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Address JAX score_mod review items

Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Use serialized cuDNN graphs for score_mod attention

Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Rename score_mod graph cache helpers

Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Add Flax score_mod attention support

Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Address JAX score_mod review feedback

Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Add flex-attn tests to QA scripts

Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Skip softcap score-mod test before SM90

Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Address score_mod fused attention review comments

Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Clean up score_mod test scaling

Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* test: reuse fused attention reference for score mod

Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* test: inline scaled fused attention inputs

Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* test: initialize fused attention doutput eagerly

Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* test: clarify fused attention reference mask

Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* test: use runner for score mod attention cases

Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* test: consolidate score mod runner setup

Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* test: centralize score mod runner defaults

Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* test: fix distributed score mod import

Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* test: address score mod review comments

Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

---------

Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
* Add cuDNN score_mod attention path

Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Avoid BHSD copies in score_mod attention

Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Test relative position score_mod attention

Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Test softcap score_mod attention

Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Run score_mod graphs on current CUDA stream

Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Add PyTorch score_mod execution plan cache

Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Fix score_mod cache edge cases

Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Fix score_mod callback graph cache keys

Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Address score_mod review feedback

Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Fix score_mod lambda cache keys

Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Address flex attention review feedback

Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Address flex attention backend review feedback

Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Validate score_mod bprop tensor inputs

Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Add flex-attn tests to QA scripts

Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Add lint directive for score_mod tensor type checks

Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Alias dataclass field import in attention utils

Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Skip softcap flex attention tests before sm90

Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Address flex attention review feedback

Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Address attention backend review nits

Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Remove duplicate flex attention asserts

Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Clarify score mod tensor keys

Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Normalize Flex Attention naming

Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Simplify score mod backward graph cache lookup

Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Return only cuDNN graph from helper

Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Refer to Flex Attention in error messages

Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Address Flex Attention review comments

Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

---------

Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
* initial prototype

Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* address review comment

Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>

* cleanup

Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>

* some more

Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>

* done

Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* clean

Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>

* cache python_to_cpp and cpp_to_python casts for dtype

Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* add the missing conversion file

Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* cleanup comments

Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>

* cleanup

Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>

* lint

Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* address review comment

Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* address review comments

Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>

* fix build docs

Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>

* fix review comment, lint

Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>

* address review comments

Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* address review comments

Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>

* address review comments

Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>

* fix docs and address review comments

Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>

* cache

Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>

* address review comments

Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Replace 'te.DType' with 'tepytorch.DType' in tests

Signed-off-by: vthumbe1503 <vthumbe@nvidia.com>

* Apply suggestion from @timmoon10

Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Signed-off-by: vthumbe1503 <vthumbe@nvidia.com>

* address review comment

Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Rename pybind_dtype_casters.h to pybind_dtype_caster.h

Signed-off-by: vthumbe1503 <vthumbe@nvidia.com>

* Fix include statement for pybind_dtype_caster

Signed-off-by: vthumbe1503 <vthumbe@nvidia.com>

---------

Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
Signed-off-by: vthumbe1503 <vthumbe@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Signed-off-by: Siddhartha Raman S <sraman@nvidia.com>
Signed-off-by: Siddhartha Raman S <sraman@nvidia.com>
Signed-off-by: Siddhartha Raman S <sraman@nvidia.com>
Signed-off-by: Siddhartha Raman S <sraman@nvidia.com>
Signed-off-by: Siddhartha Raman S <sraman@nvidia.com>
Signed-off-by: Siddhartha Raman S <sraman@nvidia.com>
sraman-rgb force-pushed the nvfp4-grouped-mlp-glu-rht-amax branch from 22b9b0e to d3533df June 4, 2026 21:13