Enable NVFP4 grouped MLP GLU RHT amax path#4
Open
sraman-rgb wants to merge 26 commits into
Open
Conversation
2522c2b to
8842a9a
Compare
bdist_wheel runs with --python-tag=py3 which sets the wheel filename's tag, but the WHEEL metadata still records the build python's tag (cp310-cp310). Set the Tag to py3-none in the WHEEL file right before wheel pack so the metadata matches the filename. Fixes NVIDIA#2896 Signed-off-by: Eyüp Can Akman <eyupcanakman@gmail.com> Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
[fix] Fix CUTLASS grouped GEMM segfault for empty groups. Signed-off-by: yangfan.bai <yangfan.bai@shopee.com> Co-authored-by: yangfan.bai <yangfan.bai@shopee.com>
* Enable NVFP4 grouped MLP SReLU fusion Signed-off-by: Siddhartha Raman S <sraman@nvidia.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Address NVFP4 SReLU review comments Signed-off-by: Siddhartha Raman S <sraman@nvidia.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Clarify NVFP4 SReLU alpha scaling Signed-off-by: Siddhartha Raman S <sraman@nvidia.com> * Clarify NVFP4 SReLU scale handling Signed-off-by: Siddhartha Raman S <sraman@nvidia.com> --------- Signed-off-by: Siddhartha Raman S <sraman@nvidia.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
* Enable NVFP4 grouped MLP cuDNN wgrad Signed-off-by: Siddhartha Raman S <sraman@nvidia.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Address NVFP4 wgrad review comments Signed-off-by: Siddhartha Raman S <sraman@nvidia.com> * Use NVFP4 amax helper for wgrad scales Signed-off-by: Siddhartha Raman S <sraman@nvidia.com> * Skip NVFP4 wgrad amax lookup for empty tokens Signed-off-by: Siddhartha Raman S <sraman@nvidia.com> --------- Signed-off-by: Siddhartha Raman S <sraman@nvidia.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
* initial impl Signed-off-by: tdophung <tdophung@nvidia.com> * Register XLA FFI AttrDecoding for JAXX_Routing_Map_Format Without this XLA_FFI_REGISTER_ENUM_ATTR_DECODING the FFI handler templates cannot instantiate AttrDecoding<JAXX_Routing_Map_Format>, breaking the JAX build in router.cpp. Signed-off-by: tdophung <tdophung@nvidia.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * add plumbing from pytorch side, change all internal functions to using the routing map type enum Signed-off-by: tdophung <tdophung@nvidia.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * address all comments Signed-off-by: tdophung <tdophung@nvidia.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * [PyTorch] Address PR NVIDIA#3009 review: remove .view() calls, int routing_map_format Apply the four CPU-overhead fixes the reviewer asked for and the CLAUDE.md "CPU overhead in PyTorch wrappers" section codifies: 1. _validate_routing_map_format returns plain int (not enum); the autograd Function + tex.* bindings only see ints. Validates via precomputed frozenset and a single dict.get with canonical lowercase keys (no .lower()/.upper()). 2. Type annotations on Function.forward use int (not the string forward-ref 'RoutingMapFormat'). 3. Removed every .view() from FusedTopkScoreFunction.{forward,backward} and FusedComputeScoresForMoEAuxLoss.{forward,backward}. C++ extension now accepts N-D logits/grad_probs, computes num_tokens from the product of leading dims, num_experts from the last dim, allocates outputs at the user-facing N-D shape, and wraps tensors with an explicit 2D shape via makeTransformerEngineTensor only for the kernel call. Asserts is_contiguous() on inputs. 4. Bwd allocates grad_logits with torch.empty_like(grad_probs) (N-D) instead of allocate-2D-then-view. PyTorch-extension boundary takes 'int routing_map_format' and casts to NVTERoutingMapFormat inside; the common-layer C API (nvte_*_v2) keeps the enum. Signed-off-by: tdophung <tdophung@nvidia.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * [PyTorch] Fix routing_map_format validator: pybind11 enum is not an int subclass pybind11 enum_<NVTERoutingMapFormat> binds as a standalone type, not a subclass of int. The validator must check isinstance(x, RoutingMapFormat) before the int branch and explicitly normalize via int(x). Signed-off-by: tdophung <tdophung@nvidia.com> * [PyTorch][JAX][Common] Clean up review-style comments in router code Signed-off-by: tdophung <tdophung@nvidia.com> * remove remaining useless comments Signed-off-by: tdophung <tdophung@nvidia.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * small comment change Signed-off-by: tdophung <tdophung@nvidia.com> --------- Signed-off-by: tdophung <tdophung@nvidia.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
…ete weights (NVIDIA#3076) * [PyTorch] Move transform_and_copy_data_ptrs_to_device to experimental submodule Introduces a `grouped_mlp_experimental` pybind submodule as a labeled home for hyperspecific helpers that exist to satisfy the cuDNN CuTe DSL grouped GEMM kernels. The submodule itself is documented as unstable, so callers can see at the import path that these helpers are not part of the supported surface. `copy_data_ptrs_to_device` is genuinely general-purpose and stays at the top level; only `transform_and_copy_data_ptrs_to_device` moves into the submodule, and its four call sites in the fused grouped MLP forward/backward are updated accordingly. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com> Signed-off-by: Tim Moon <tmoon@nvidia.com> * [PyTorch] Fuse data + scale ptr packing for discrete-weight grouped MLP Replaces `transform_and_copy_data_ptrs_to_device` with a more focused helper, `swizzle_scales_and_pack_ptrs_for_discrete_weights`. The new function takes both the FP8/FP4 weight data tensors and their scale tensors, swizzles the scales, and copies both pointer arrays to device in a single kernel launch (down from two — one from `copy_data_ptrs_to_device` for data and one from the old transform helper for scales). The two returned pointer arrays are views into a single packed device buffer. The general "transform_type" string dispatch is gone: the function only supports `mxfp8_rowwise`, `mxfp8_columnwise`, and `nvfp4`, which were the only modes ever used. The four discrete-weight call sites in the fused grouped MLP forward/backward collapse their paired `copy_data_ptrs_to_device` + transform calls into a single call. The implementation moves to a dedicated source file, `csrc/extensions/grouped_mlp_experimental.cpp`, so the experimental submodule has a clear home for future helpers tied to the cuDNN CuTe DSL grouped GEMM kernels. The declaration in `extensions.h` is grouped under a matching banner. `copy_data_ptrs_to_device` stays in `utils.cpp` since it remains a general-purpose helper. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com> Signed-off-by: Tim Moon <tmoon@nvidia.com> * [PyTorch] Use real data shapes and dispatch on a format enum in swizzle helper In swizzle_scales_and_pack_ptrs_for_discrete_weights, take the data shape directly from the data tensors instead of inferring it from the padded scale shape. NVFP4 packs two 4-bit values per byte, so the byte-shape's inner dim is doubled to recover the logical element count. Also replace the trio of is_mxfp8_rowwise / is_mxfp8_columnwise / is_nvfp4 booleans with a function-local TensorFormat enum. Tensor properties (scaling mode, dtypes, swizzle param names) are assigned together per case in a single switch so adding a future format is a single-point change rather than a fresh boolean threaded through the function. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com> Signed-off-by: Tim Moon <tmoon@nvidia.com> * [PyTorch] Consolidate utils.cpp into misc.cpp After moving the experimental grouped-MLP helper out, the only thing left in extensions/utils.cpp was copy_data_ptrs_to_device, which fits naturally alongside the cublasLt/cuDNN version getters and splits_to_offsets already in extensions/misc.cpp. Move it there and delete the now-empty utils.cpp. Build picks up sources via glob, so no manifest update is needed. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com> Signed-off-by: Tim Moon <tmoon@nvidia.com> * [PyTorch] Tidy the grouped MLP experimental helper Wraps the C++ implementation of swizzle_scales_and_pack_ptrs_for_discrete_weights in a `grouped_mlp_experimental` namespace and renames the format-selector argument from `format` to `swizzle_type` across the declaration, implementation, and pybind binding. The pybind submodule name was already `grouped_mlp_experimental`, so the C++ namespace now mirrors it. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com> Signed-off-by: Tim Moon <tmoon@nvidia.com> * [PyTorch] Hook NVFP4 optimize_for_gemm into the quantize path After NVFP4Quantizer::quantize runs, run inplace_swizzle_scale_for_gemm on the output when optimize_for_gemm is set and the quantize kernel hasn't already produced swizzled scales. The NVFP4 quantize kernel rejects with_gemm_swizzled_scales=true and emits compact scales, so without this hook callers had to follow up with a manual swizzle in Python (see ops/_common.py:85). The hook is a no-op for MXFP8 (its quantize kernel sets the flag itself) and for any quantizer with optimize_for_gemm=false. Also fixes a latent state-consistency bug in NVFP4Quantizer::convert_and_update_tensor: it was resetting the C++ wrapper's with_gemm_swizzled_scales to false but never touching the Python tensor's _with_gemm_swizzled_scales attribute. Re-quantizing into a tensor that previously held swizzled scales would leave the Python flag stuck at true while the buffer was compact, mismatched state that downstream code could mis-read. The Python attribute is now reset alongside the C++ wrapper, matching what MXFP8Quantizer::convert_and_update_tensor already does. Adds test_swizzle_scales_and_pack_ptrs_for_discrete_weights covering mxfp8_rowwise, mxfp8_columnwise, and nvfp4, comparing the helper's swizzled output against scales produced by the quantizer with optimize_for_gemm=true. NVFP4 was the case that surfaced the quantizer-side issues fixed above. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com> Signed-off-by: Tim Moon <tmoon@nvidia.com> * Fix linter warning Signed-off-by: Tim Moon <tmoon@nvidia.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Fix linter warning Signed-off-by: Tim Moon <tmoon@nvidia.com> * Avoid swizzle kernel when tensor size is zero Signed-off-by: Tim Moon <tmoon@nvidia.com> * Review suggestion from @vthumbe1503 Co-authored-by: vthumbe1503 <vthumbe@nvidia.com> Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by: Tim Moon <tmoon@nvidia.com> Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com> Co-authored-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: vthumbe1503 <vthumbe@nvidia.com>
Signed-off-by: Francesco Bertolotti <francesco.bertolotti@igenius.ai> Co-authored-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
* Disable new triton ffi for autotuned kernels Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
* optimize prepare grouped split Signed-off-by: Zhongbo Zhu <zhongboz@nvidia.com> * small fix Signed-off-by: Zhongbo Zhu <zhongboz@nvidia.com> * precompute offsets Signed-off-by: Zhongbo Zhu <zhongboz@nvidia.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * add comments because linter complained Signed-off-by: Zhongbo Zhu <zhongboz@nvidia.com> * disable linter Signed-off-by: Zhongbo Zhu <zhongboz@nvidia.com> * Redesign nvte multi splits-to-offsets API Replace nvte_multi_splits_to_offsets with nvte_splits_to_offsets_multi, which takes parallel arrays of output NVTETensors, strides, and a per-output include_leading_zero flag. The dedicated cumsum slot is gone: outputs are now a uniform inclusive scan list where each output's length is either N or N+1 depending on its leading-zero flag. The per-launch output cap is internal; the public function loops kernel launches when num_outputs exceeds it, so callers see no hard limit. v1 nvte_splits_to_offsets now goes through the same shared kernel, and the prior duplicated kernel is removed. The PyTorch tex wrapper switches to MultiTensorWrapper to batch NVTETensor allocation for outputs. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com> Signed-off-by: Tim Moon <tmoon@nvidia.com> * Move splits-to-offsets kernel out of common.cu Relocate the inclusive-scan kernel, helpers, and launch struct into a dedicated transformer_engine/common/util/splits_to_offsets.cu, wrapped in namespace transformer_engine::splits_to_offsets. Dtype dispatch in load_split_size and store_output now uses a switch with NVTE_DEVICE_ERROR on unsupported dtypes instead of silently treating unknown dtypes as int64. nvte_splits_to_offsets_multi with num_outputs == 0 is now a noop rather than an error, and the internal args struct uses bool for the per-output include_leading_zero flag (C API stays int* for portability). v1 parameters renamed to split_sizes/num_splits/stride to align with the multi variant. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com> Signed-off-by: Tim Moon <tmoon@nvidia.com> * Generalize PyTorch splits_to_offsets_multi wrapper Rename tex.prepare_grouped_splits to tex.splits_to_offsets_multi and turn it into a general inclusive-scan utility instead of a grouped-MLP helper. Outputs are configured per-entry via parallel arrays of strides, include_leading_zero flags, and dtypes; bulk_allocate is opt-in so general callers get separate per-output at::empty buffers and only the grouped-MLP hot path takes the shared-storage / 16-byte-aligned route that cuDNN needs. The wrapper now takes an explicit device, coerces split_sizes to int64 / CUDA centrally, drops the redundant per-tensor checks (the core lib validates), and drops non_blocking=true on the host->device migration to avoid the race Greptile flagged. Update the grouped MLP caller and tests to the new shape. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com> Signed-off-by: Tim Moon <tmoon@nvidia.com> * Review suggestion from @vthumbe1503 Rename the internal create_grouped_tensor parameter from provided_tensor_offsets to precomputed_tensor_offsets in the quantizer subclass impls, and document the optional tensor_offsets contract on the abstract method declaration. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com> Signed-off-by: Tim Moon <tmoon@nvidia.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by: Zhongbo Zhu <zhongboz@nvidia.com> Signed-off-by: Tim Moon <tmoon@nvidia.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Tim Moon <tmoon@nvidia.com> Co-authored-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
…ncl. framework API) (NVIDIA#2443)
…A#3020) * tests/attention: shrink fp8_vs_f16 configs from B=2 to B=1 The 9 fp8_9..fp8_17 configs in `model_configs_fp8_vs_f16` use shapes (B=2, S=4096-8192, H=32-128, D=64-192) for the bf16-vs-fp8 reference comparison. The reference path in `test_dpa_fp8_vs_f16` materializes the full (B, H, S, S) attention matrix in bf16, and keeps a handful of them live (S, P, dP, dS, dropout-mask) simultaneously. At B=2, S=8192, H=64 the per-test peak is ~70 GiB, which exceeds the memory of common 80 GB cards (H100) and pushes the suite into OOM territory on Blackwell (~91 GB measured with the cuDNN caching allocator residue). Halving B to 1 halves the bytes of every (B, H, S, S) tensor. Measured on B200 (SM_100, cuDNN 9.23, TE main): per-test peak `torch.cuda.max_memory_allocated`: before: 70.0 GiB (fp8_14) after : 36.1 GiB (fp8_14) -48% per-test peak `nvidia-smi memory.used`: before: 96.8 GiB after : 51.3 GiB -47% test outcome (B200, develop FE, this TE): identical 618F / 2196P / 891S, wall time within ~3% The shrunk configs still exercise every distinct shape/mask/SWA/GQA combination that the originals did -- only B is smaller. The suite now fits comfortably on 80 GB cards. fp8_19/20 (B=2, S=2048) are left at B=2 because their peak is small (~few GiB) and the larger batch is useful coverage for padding_causal. Signed-off-by: Vedaanta Agarwalla <vagarwalla@nvidia.com> * address changes recommended by Kshitij Signed-off-by: Vedaanta Agarwalla <142048820+vedaanta@users.noreply.github.com> * tests/attention: black format fp8_13 ModelConfig Line was 105 chars; black requires <=100 with the project's preview+ string_processing settings. Signed-off-by: Vedaanta Agarwalla <vagarwalla@nvidia.com> --------- Signed-off-by: Vedaanta Agarwalla <vagarwalla@nvidia.com> Signed-off-by: Vedaanta Agarwalla <142048820+vedaanta@users.noreply.github.com>
…itectures (NVIDIA#2836) * [PyTorch] Fix FlashAttention 2 head_dim > 192 on sm103 and other architectures Replace the exact-match compute capability allowlist with a >= sm80 range check, matching flash-attn's own gate: Dao-AILab/flash-attention@bbb21d6 The allowlist ((8,0), (9,0), (10,0), (12,0)) missed sm103 (B300), sm89 (L40S), sm86 (A40), and others where FA2 supports head_dim up to 256. The sm103 case was validated on hardware with head_dim=256; the remaining architectures appear to be supported based on flash-attn's >= sm80 guarantee. Signed-off-by: Pedram Razavi <pedram.razavi@gmail.com> * Addressed the review comments Signed-off-by: Przemek Tredak <ptredak@nvidia.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Loosen kv-cache test tolerances for FlashAttention head_dim>128 Signed-off-by: Pedram Razavi <pedram.razavi@gmail.com> --------- Signed-off-by: Pedram Razavi <pedram.razavi@gmail.com> Signed-off-by: Przemek Tredak <ptredak@nvidia.com> Co-authored-by: Przemek Tredak <ptredak@nvidia.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Add DSv3 MXFP8 attention unit test Signed-off-by: Layali Rashid <lrashid@nvidia.com>
… for fused grouped MLP via fused ops (NVIDIA#3078) * Interleave/de-interleave utils for GLU fused OP Signed-off-by: ksivamani <ksivamani@nvidia.com> * Apply suggestions from code review Signed-off-by: ksivamani <ksivamani@nvidia.com> Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com> Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com> * Update transformer_engine/pytorch/utils.py Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com> Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com> --------- Signed-off-by: ksivamani <ksivamani@nvidia.com> Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com> Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
* Add JAX fused attention score_mod support Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com> * Stabilize score_mod callback cache keys Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Add distributed JAX score mod attention test Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com> * Address JAX score_mod review items Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Use serialized cuDNN graphs for score_mod attention Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Rename score_mod graph cache helpers Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Add Flax score_mod attention support Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com> * Address JAX score_mod review feedback Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com> * Add flex-attn tests to QA scripts Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com> * Skip softcap score-mod test before SM90 Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com> * Address score_mod fused attention review comments Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com> * Clean up score_mod test scaling Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * test: reuse fused attention reference for score mod Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com> * test: inline scaled fused attention inputs Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com> * test: initialize fused attention doutput eagerly Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com> * test: clarify fused attention reference mask Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com> * test: use runner for score mod attention cases Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com> * test: consolidate score mod runner setup Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com> * test: centralize score mod runner defaults Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * test: fix distributed score mod import Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com> * test: address score mod review comments Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com> --------- Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
* Add cuDNN score_mod attention path Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com> * Avoid BHSD copies in score_mod attention Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com> * Test relative position score_mod attention Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com> * Test softcap score_mod attention Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com> * Run score_mod graphs on current CUDA stream Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com> * Add PyTorch score_mod execution plan cache Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com> * Fix score_mod cache edge cases Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Fix score_mod callback graph cache keys Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com> * Address score_mod review feedback Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Fix score_mod lambda cache keys Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com> * Address flex attention review feedback Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Address flex attention backend review feedback Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com> * Validate score_mod bprop tensor inputs Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com> * Add flex-attn tests to QA scripts Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com> * Add lint directive for score_mod tensor type checks Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com> * Alias dataclass field import in attention utils Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com> * Skip softcap flex attention tests before sm90 Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com> * Address flex attention review feedback Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com> * Address attention backend review nits Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com> * Remove duplicate flex attention asserts Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com> * Clarify score mod tensor keys Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com> * Normalize Flex Attention naming Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com> * Simplify score mod backward graph cache lookup Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com> * Return only cuDNN graph from helper Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Refer to Flex Attention in error messages Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com> * Address Flex Attention review comments Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com> --------- Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
* initial prototype Signed-off-by: Varun Thumbe <vthumbe@nvidia.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * address review comment Signed-off-by: Varun Thumbe <vthumbe@nvidia.com> * cleanup Signed-off-by: Varun Thumbe <vthumbe@nvidia.com> * some more Signed-off-by: Varun Thumbe <vthumbe@nvidia.com> * done Signed-off-by: Varun Thumbe <vthumbe@nvidia.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * clean Signed-off-by: Varun Thumbe <vthumbe@nvidia.com> * cache python_to_cpp and cpp_to_python casts for dtype Signed-off-by: Varun Thumbe <vthumbe@nvidia.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * add the missing conversion file Signed-off-by: Varun Thumbe <vthumbe@nvidia.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * cleanup comments Signed-off-by: Varun Thumbe <vthumbe@nvidia.com> * cleanup Signed-off-by: Varun Thumbe <vthumbe@nvidia.com> * lint Signed-off-by: Varun Thumbe <vthumbe@nvidia.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * address review comment Signed-off-by: Varun Thumbe <vthumbe@nvidia.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * address review comments Signed-off-by: Varun Thumbe <vthumbe@nvidia.com> * fix build docs Signed-off-by: Varun Thumbe <vthumbe@nvidia.com> * fix review comment, lint Signed-off-by: Varun Thumbe <vthumbe@nvidia.com> * address review comments Signed-off-by: Varun Thumbe <vthumbe@nvidia.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * address review comments Signed-off-by: Varun Thumbe <vthumbe@nvidia.com> * address review comments Signed-off-by: Varun Thumbe <vthumbe@nvidia.com> * fix docs and addres review commentsg Signed-off-by: Varun Thumbe <vthumbe@nvidia.com> * cache Signed-off-by: Varun Thumbe <vthumbe@nvidia.com> * address review comments Signed-off-by: Varun Thumbe <vthumbe@nvidia.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Replace 'te.DType' with 'tepytorch.DType' in tests Signed-off-by: vthumbe1503 <vthumbe@nvidia.com> * Apply suggestion from @timmoon10 Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com> Signed-off-by: vthumbe1503 <vthumbe@nvidia.com> * address review comment Signed-off-by: Varun Thumbe <vthumbe@nvidia.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Rename pybind_dtype_casters.h to pybind_dtype_caster.h Signed-off-by: vthumbe1503 <vthumbe@nvidia.com> * Fix include statement for pybind_dtype_caster Signed-off-by: vthumbe1503 <vthumbe@nvidia.com> --------- Signed-off-by: Varun Thumbe <vthumbe@nvidia.com> Signed-off-by: vthumbe1503 <vthumbe@nvidia.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Signed-off-by: Siddhartha Raman S <sraman@nvidia.com>
for more information, see https://pre-commit.ci
Signed-off-by: Siddhartha Raman S <sraman@nvidia.com>
Signed-off-by: Siddhartha Raman S <sraman@nvidia.com>
for more information, see https://pre-commit.ci
Signed-off-by: Siddhartha Raman S <sraman@nvidia.com>
Signed-off-by: Siddhartha Raman S <sraman@nvidia.com>
Signed-off-by: Siddhartha Raman S <sraman@nvidia.com>
22b9b0e to
d3533df
Compare
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Based on main after PR 3048 merged.
Changes:
Validation:
NVTE_CUTEDSL_FUSED_GROUPED_MLP_FC1_GLU_RHT_AMAX_TMEM=0:1784 passed, 8968 skippedfortests/pytorch/test_fusible_ops.py::TestSequentialModules::test_grouped_mlpNVTE_CUTEDSL_FUSED_GROUPED_MLP_FC1_GLU_RHT_AMAX_TMEM=1: currently blocked by installed cuDNN FE/CUTLASS helper API mismatch ingrouped_gemm_glu_hadamard/hadamard_utils.py(make_trivial_tiled_mmasignature)