Gpt oss jax Changes for me to look at #1
Open
vthumbe1503 wants to merge 63 commits into
* remove import jax.extend.ffi Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com> --------- Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
* first draft; debug plan failure Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * debug uid error Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * tweak params Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * add grad in output Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * clean up prints Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fix prints in test Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * Apply 1 suggestion(s) to 1 file(s) Co-authored-by: Chen Cui <chcui@nvidia.com> Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * address review comments Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fix unfused grad; add softmax_type; add sink to bwd Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * Apply 1 suggestion(s) to 1 file(s) Co-authored-by: Chen Cui <chcui@nvidia.com> Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fix padding mask; add swa tests; remove requires_grad for off-by-one Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * update FE Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * Apply 1 suggestion(s) to 1 file(s) Co-authored-by: Chen Cui <chcui@nvidia.com> Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * Apply 1 suggestion(s) to 1 file(s) Co-authored-by: Chen Cui <chcui@nvidia.com> Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * Apply 1 suggestion(s) to 1 file(s) Co-authored-by: Chen Cui <chcui@nvidia.com> Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * Apply 1 suggestion(s) to 1 file(s) Co-authored-by: Chen Cui <chcui@nvidia.com> Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * Apply 1 suggestion(s) to 1 file(s) Co-authored-by: 
Chen Cui <chcui@nvidia.com> Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * Apply 1 suggestion(s) to 1 file(s) Co-authored-by: Chen Cui <chcui@nvidia.com> Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * Apply 1 suggestion(s) to 1 file(s) Co-authored-by: Chen Cui <chcui@nvidia.com> Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * Apply 1 suggestion(s) to 1 file(s) Co-authored-by: Chen Cui <chcui@nvidia.com> Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * Apply 1 suggestion(s) to 1 file(s) Co-authored-by: Chen Cui <chcui@nvidia.com> Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fix indent Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fix non-determinism and shapes Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * clean up prints Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * add GQA Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * add CP A2A; dq/dk mismatches Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fix CP A2A; need cleaner solution Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fix CP A2A; pending cudnn kernel change Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * minor fixes Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fix world size in unit test; avoid thd format Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fix kernel_backend, dtype in unit test; fix head_dim for FP8 Hopper Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fix thd logic Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fix fp8 context Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * tweak CP logging Signed-off-by: Charlene Yang 
<8636796+cyanguwa@users.noreply.github.com> * allow no_mask/padding for SWA(left,0) Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * Revert "allow no_mask/padding for SWA(left,0)" This reverts commit 08b4ccc67a08b6882080b06aa715f541bb832aca. Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * add softmax_type to Jax Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * add cuDNN version control Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * prettify tests Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * skip 9.13 for MLA, non 192/128 Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * rename compare_with_error Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * small cleanups and improvements Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fix minor CI failures Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * force sink/dsink to be float32 Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * switch FE to GH FE Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * return to GH TE main FE commit Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * update FE to 1.14.1 Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * clean up before CI Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix lint Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * bump up cudnn version Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * add backend selection guard for unit tests Signed-off-by: Charlene Yang 
<8636796+cyanguwa@users.noreply.github.com> * add docstring for softmax type enums in C Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --------- Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Co-authored-by: Chen Cui <chcui@nvidia.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
…NVIDIA#2169) * Add pytest xml report for debug unittest and onnx unittest, and remove the duplicated test line in qa/L0_pytorch_debug_unittest/test.sh --------- Signed-off-by: erindai <shengfangd@nvidia.com>
* Adding Amax Primitive and related args. Signed-off-by: Ming Huang <mingh@nvidia.com> * Enable local-amax for current-scaling and optionally run AR across FSDP/TP/SP. Signed-off-by: Ming Huang <mingh@nvidia.com> * Adding doc for Amax Primitive. Signed-off-by: Ming Huang <mingh@nvidia.com> * Fix the function name conflict. Signed-off-by: Ming Huang <mingh@nvidia.com> * Modification as feedback suggested. Signed-off-by: Ming Huang <mingh@nvidia.com> * Fix errors from lint. Signed-off-by: Ming Huang <mingh@nvidia.com> * Fix the wrong amax-scope in the bwd. Signed-off-by: Ming Huang <mingh@nvidia.com> * Added more description for amax-scope Signed-off-by: Ming Huang <mingh@nvidia.com> * Fix the wrong attribute name. Signed-off-by: Ming Huang <mingh@nvidia.com> * Keep dim for AmaxCalcuation. Signed-off-by: Ming Huang <mingh@nvidia.com> * Remove keepDim and add shardy_rule Signed-off-by: Ming Huang <mingh@nvidia.com> * Fix shardy_rule Signed-off-by: Ming Huang <mingh@nvidia.com> * Remove extra-collective bytes from ref_coll_count due to local amax. Signed-off-by: Ming Huang <mingh@nvidia.com> --------- Signed-off-by: Ming Huang <mingh@nvidia.com> Signed-off-by: Ming-Xu Huang <mingh@nvidia.com> Co-authored-by: Phuong Nguyen <phuonguyen@nvidia.com>
* Rework shardy rules * WAR for compound factor=1 Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com> --------- Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
e81a0e1 to
6675779
Compare
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
1a1b0eb to
0c17c7e
Compare
Signed-off-by: vthumbe1503 <vthumbe@nvidia.com>
for more information, see https://pre-commit.ci
Signed-off-by: vthumbe1503 <vthumbe@nvidia.com>
Add documentation for quantization function parameters and return value. Signed-off-by: vthumbe1503 <vthumbe@nvidia.com>
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
for more information, see https://pre-commit.ci
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
update jax requirements Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
…e into gpt-oss-jax Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
* fix Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com> * fix Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com> * fix Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com> * fix Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com> --------- Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
…DIA#2219) Load modules during initialize Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com> Co-authored-by: JAX Toolbox <jax@nvidia.com>
Signed-off-by: vthumbe1503 <vthumbe@nvidia.com>
for more information, see https://pre-commit.ci
* Introduce QuantizerBase Signed-off-by: Evgeny <etsykunov@nvidia.com> * Expose as a first-class API Signed-off-by: Evgeny <etsykunov@nvidia.com> * Undo QuantizerBase Signed-off-by: Evgeny <etsykunov@nvidia.com> * Make Quantizer a base class without implementations Signed-off-by: Evgeny <etsykunov@nvidia.com> * Support CustomRecipe and CustomRecipeState Signed-off-by: Evgeny <etsykunov@nvidia.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Resolving comments: quantize impl, num_quantizers, defaults Signed-off-by: Evgeny <etsykunov@nvidia.com> * Quantizer factories Signed-off-by: Evgeny <etsykunov@nvidia.com> * Add tests Signed-off-by: Evgeny <etsykunov@nvidia.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * QuantizedTensorBase _get_quantizer() + quantize_() Signed-off-by: Evgeny <etsykunov@nvidia.com> * Experimental note + LayerNormMLP fix Signed-off-by: Evgeny <etsykunov@nvidia.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * tensor._internal -> tensor.base Signed-off-by: Evgeny <etsykunov@nvidia.com> * Expose Signed-off-by: Evgeny <etsykunov@nvidia.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Minor import fix Signed-off-by: Evgeny <etsykunov@nvidia.com> * Single quantizer factory with roles Signed-off-by: Evgeny <etsykunov@nvidia.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * More context for qfactory, fwd/bwd_roles Signed-off-by: Evgeny <etsykunov@nvidia.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Minor Signed-off-by: Evgeny <etsykunov@nvidia.com> * Rename *Base -> *Storage quantized tensors Signed-off-by: Evgeny <etsykunov@nvidia.com> * make_quantizers() will take roles from the operation 
Signed-off-by: Evgeny <etsykunov@nvidia.com> * Improve tests and fix missing imports Signed-off-by: Evgeny <etsykunov@nvidia.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Apply suggestions from code review Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com> * Merge main followup Signed-off-by: Evgeny <etsykunov@nvidia.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by: Evgeny <etsykunov@nvidia.com> Signed-off-by: Evgeny Tsykunov <etsykunov@nvidia.com> Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
* rm using_global_amax_of_x Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com> --------- Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
fix rng_state shape Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com> Co-authored-by: jberchtold-nvidia <158520091+jberchtold-nvidia@users.noreply.github.com>
Disable debug build for cutlass GEMM Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
…e into gpt-oss-jax Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
for more information, see https://pre-commit.ci
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
…e into gpt-oss-jax
for more information, see https://pre-commit.ci
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
…e into gpt-oss-jax
Fix passing args to nvfp4 recipe Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
* Fix the cublas workspace alignment Signed-off-by: Przemek Tredak <ptredak@nvidia.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Fix Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> Signed-off-by: Przemyslaw Tredak <ptrendx@gmail.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by: Przemek Tredak <ptredak@nvidia.com> Signed-off-by: Przemyslaw Tredak <ptrendx@gmail.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
…2222) * Make sure to set usages for linear op quantizers before forward Signed-off-by: Tim Moon <tmoon@nvidia.com> * Avoid unsupported case for fused dbias+quantize kernel Hopper does not support dbias + FP8 cast without FP8 transpose. Signed-off-by: Tim Moon <tmoon@nvidia.com> --------- Signed-off-by: Tim Moon <tmoon@nvidia.com>
Fix code block in fp8_autocast docstring Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com> Co-authored-by: Phuong Nguyen <phuonguyen@nvidia.com>
…2229) Fix shard map issue Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com> Co-authored-by: Phuong Nguyen <phuonguyen@nvidia.com>
* fix overflow of int32 in permute kernels Signed-off-by: Hongxiao Bai <hongxiaob@nvidia.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by: Hongxiao Bai <hongxiaob@nvidia.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Xin Yao <xiny@nvidia.com> Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
…e into gpt-oss-jax Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
vthumbe1503 pushed a commit that referenced this pull request Jun 24, 2026
…ache Enable weight swizzling for most cases
Description
Please include a brief summary of the changes, relevant motivation and context.
Fixes # (issue)
Type of change
Changes
Please list the changes introduced in this PR:
Checklist: