Skip to content

Gpt oss jax Changes for me to look at#1

Open
vthumbe1503 wants to merge 63 commits into
users/vthumbe/gpt_oss_swiglu_integrationfrom
gpt-oss-jax
Open

Gpt oss jax Changes for me to look at#1
vthumbe1503 wants to merge 63 commits into
users/vthumbe/gpt_oss_swiglu_integrationfrom
gpt-oss-jax

Conversation

@vthumbe1503

Copy link
Copy Markdown
Owner

Description

Please include a brief summary of the changes, relevant motivation and context.

Fixes # (issue)

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Change A
  • Change B

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

phu0ngng and others added 5 commits September 22, 2025 12:58
* remove import jax.extend.ffi

Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>

---------

Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
* first draft; debug plan failure

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* debug uid error

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* tweak params

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* add grad in output

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* clean up prints

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix prints in test

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* Apply 1 suggestion(s) to 1 file(s)

Co-authored-by: Chen Cui <chcui@nvidia.com>
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* address review comments

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix unfused grad; add softmax_type; add sink to bwd

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* Apply 1 suggestion(s) to 1 file(s)

Co-authored-by: Chen Cui <chcui@nvidia.com>
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix padding mask; add swa tests; remove requires_grad for off-by-one

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* update FE

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* Apply 1 suggestion(s) to 1 file(s)

Co-authored-by: Chen Cui <chcui@nvidia.com>
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* Apply 1 suggestion(s) to 1 file(s)

Co-authored-by: Chen Cui <chcui@nvidia.com>
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* Apply 1 suggestion(s) to 1 file(s)

Co-authored-by: Chen Cui <chcui@nvidia.com>
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* Apply 1 suggestion(s) to 1 file(s)

Co-authored-by: Chen Cui <chcui@nvidia.com>
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* Apply 1 suggestion(s) to 1 file(s)

Co-authored-by: Chen Cui <chcui@nvidia.com>
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* Apply 1 suggestion(s) to 1 file(s)

Co-authored-by: Chen Cui <chcui@nvidia.com>
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* Apply 1 suggestion(s) to 1 file(s)

Co-authored-by: Chen Cui <chcui@nvidia.com>
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* Apply 1 suggestion(s) to 1 file(s)

Co-authored-by: Chen Cui <chcui@nvidia.com>
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* Apply 1 suggestion(s) to 1 file(s)

Co-authored-by: Chen Cui <chcui@nvidia.com>
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix indent

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix non-determinism and shapes

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* clean up prints

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* add GQA

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* add CP A2A; dq/dk mismatches

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix CP A2A; need cleaner solution

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix CP A2A; pending cudnn kernel change

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* minor fixes

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix world size in unit test; avoid thd format

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix kernel_backend, dtype in unit test; fix head_dim for FP8 Hopper

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix thd logic

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix fp8 context

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* tweak CP logging

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* allow no_mask/padding for SWA(left,0)

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* Revert "allow no_mask/padding for SWA(left,0)"

This reverts commit 08b4ccc67a08b6882080b06aa715f541bb832aca.

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* add softmax_type to Jax

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* add cuDNN version control

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* prettify tests

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* skip 9.13 for MLA, non 192/128

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* rename compare_with_error

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* small cleanups and improvements

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix minor CI failures

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* force sink/dsink to be float32

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* switch FE to GH FE

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* return to GH TE main FE commit

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* update FE to 1.14.1

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* clean up before CI

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix lint

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* bump up cudnn version

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* add backend selection guard for unit tests

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* add docstring for softmax type enums in C

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

---------

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
Co-authored-by: Chen Cui <chcui@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
…NVIDIA#2169)

* Add pytest xml report for debug unittest and onnx unittest, and remove the duplicated test line in qa/L0_pytorch_debug_unittest/test.sh

---------

Signed-off-by: erindai <shengfangd@nvidia.com>
* Adding Amax Primitive and related args.

Signed-off-by: Ming Huang <mingh@nvidia.com>

* Enable local-amax for current-scaling and optionally run AR aross FSDP/TP/SP.

Signed-off-by: Ming Huang <mingh@nvidia.com>

* Adding doc for Amax Primitive.

Signed-off-by: Ming Huang <mingh@nvidia.com>

* Fix the function name conflict.

Signed-off-by: Ming Huang <mingh@nvidia.com>

* Modification as feedback suggested.

Signed-off-by: Ming Huang <mingh@nvidia.com>

* Fix errors from lint.

Signed-off-by: Ming Huang <mingh@nvidia.com>

* Fix the wrong amax-scope in the bwd.

Signed-off-by: Ming Huang <mingh@nvidia.com>

* Added more description for amax-scope

Signed-off-by: Ming Huang <mingh@nvidia.com>

* Fix the wrong attribute name.

Signed-off-by: Ming Huang <mingh@nvidia.com>

* Keep dim for AmaxCalcuation.

Signed-off-by: Ming Huang <mingh@nvidia.com>

* Remove keepDim and add shardy_rule

Signed-off-by: Ming Huang <mingh@nvidia.com>

* Fix shardy_rule

Signed-off-by: Ming Huang <mingh@nvidia.com>

* Remove extra-collective bytes from ref_coll_count due to local amax.

Signed-off-by: Ming Huang <mingh@nvidia.com>

---------

Signed-off-by: Ming Huang <mingh@nvidia.com>
Signed-off-by: Ming-Xu Huang <mingh@nvidia.com>
Co-authored-by: Phuong Nguyen <phuonguyen@nvidia.com>
* Rework shardy rules

* WAR for compound factor=1

Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>

---------

Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
@vthumbe1503 vthumbe1503 force-pushed the gpt-oss-jax branch 6 times, most recently from e81a0e1 to 6675779 Compare September 24, 2025 00:47
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
vthumbe1503 and others added 17 commits September 23, 2025 18:41
Signed-off-by: vthumbe1503 <vthumbe@nvidia.com>
Signed-off-by: vthumbe1503 <vthumbe@nvidia.com>
Add documentation for quantization function parameters and return value.

Signed-off-by: vthumbe1503 <vthumbe@nvidia.com>
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
update jax requirements

Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
…e into gpt-oss-jax

Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
* fix

Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fix

Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fix

Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fix

Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

---------

Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
jberchtold-nvidia and others added 29 commits September 30, 2025 16:53
…DIA#2219)

Load modules during initialize

Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
Co-authored-by: JAX Toolbox <jax@nvidia.com>
Signed-off-by: vthumbe1503 <vthumbe@nvidia.com>
* Introduce QuantizerBase

Signed-off-by: Evgeny <etsykunov@nvidia.com>

* Expose as a first-class API

Signed-off-by: Evgeny <etsykunov@nvidia.com>

* Undo QuantizerBase

Signed-off-by: Evgeny <etsykunov@nvidia.com>

* Make Quantizer a base class without implementations

Signed-off-by: Evgeny <etsykunov@nvidia.com>

* Support CustomRecipe and CustomRecipeState

Signed-off-by: Evgeny <etsykunov@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Resolving comments: quantize impl, num_quantizers, defaults

Signed-off-by: Evgeny <etsykunov@nvidia.com>

* Quantizer factories

Signed-off-by: Evgeny <etsykunov@nvidia.com>

* Add tests

Signed-off-by: Evgeny <etsykunov@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* QuantizedTensorBase _get_quantizer() + quantize_()

Signed-off-by: Evgeny <etsykunov@nvidia.com>

* Experimental note + LayerNormMLP fix

Signed-off-by: Evgeny <etsykunov@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* tensor._internal -> tensor.base

Signed-off-by: Evgeny <etsykunov@nvidia.com>

* Expose

Signed-off-by: Evgeny <etsykunov@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Minor import fix

Signed-off-by: Evgeny <etsykunov@nvidia.com>

* Single quantizer factory with roles

Signed-off-by: Evgeny <etsykunov@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* More context for qfactory, fwd/bwd_roles

Signed-off-by: Evgeny <etsykunov@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Minor

Signed-off-by: Evgeny <etsykunov@nvidia.com>

* Rename *Base -> *Storage quantized tensors

Signed-off-by: Evgeny <etsykunov@nvidia.com>

* make_quantizers() will take roles from the operation

Signed-off-by: Evgeny <etsykunov@nvidia.com>

* Improve tests and fix missing imports

Signed-off-by: Evgeny <etsykunov@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Apply suggestions from code review

Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>

* Merge main followup

Signed-off-by: Evgeny <etsykunov@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Signed-off-by: Evgeny <etsykunov@nvidia.com>
Signed-off-by: Evgeny Tsykunov <etsykunov@nvidia.com>
Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
* rm using_global_amax_of_x

Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>

---------

Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
fix rng_state shape

Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Co-authored-by: jberchtold-nvidia <158520091+jberchtold-nvidia@users.noreply.github.com>
)

Fix QuantizedTensorBase -> QuantizedTensorStorage

Signed-off-by: Evgeny <etsykunov@nvidia.com>
Disable debug build for cutlass GEMM

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
…e into gpt-oss-jax

Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
Fix passing args to nvfp4 recipe

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
* Fix the cublas workspace alignment

Signed-off-by: Przemek Tredak <ptredak@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Fix

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Signed-off-by: Przemyslaw Tredak <ptrendx@gmail.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Signed-off-by: Przemek Tredak <ptredak@nvidia.com>
Signed-off-by: Przemyslaw Tredak <ptrendx@gmail.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
…2222)

* Make sure to set usages for linear op quantizers before forward

Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Avoid unsupported case for fused dbias+quantize kernel

Hopper does not support dbias + FP8 cast without FP8 transpose.

Signed-off-by: Tim Moon <tmoon@nvidia.com>

---------

Signed-off-by: Tim Moon <tmoon@nvidia.com>
Fix code block in fp8_autocast docstring

Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
Co-authored-by: Phuong Nguyen <phuonguyen@nvidia.com>
…2229)

Fix shard map issue

Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
Co-authored-by: Phuong Nguyen <phuonguyen@nvidia.com>
* fix overflow of int32 in permute kernels

Signed-off-by: Hongxiao Bai <hongxiaob@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Signed-off-by: Hongxiao Bai <hongxiaob@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Xin Yao <xiny@nvidia.com>
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
…e into gpt-oss-jax

Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
vthumbe1503 pushed a commit that referenced this pull request Jun 24, 2026
…ache

Enable weight swizzling for most cases
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.