Skip to content

port rocm/jax 0.9.2 cherrypicks to amd-main branch (#740)#741

Open
magaonka-amd wants to merge 42 commits into
ROCm:amd-mainfrom
magaonka-amd:merge/pr740-to-amd-main
Open

port rocm/jax 0.9.2 cherrypicks to amd-main branch (#740)#741
magaonka-amd wants to merge 42 commits into
ROCm:amd-mainfrom
magaonka-amd:merge/pr740-to-amd-main

Conversation

@magaonka-amd
Copy link
Copy Markdown

cherry-picking this, but changes are not merged to upstream yet.
Add ROCm lowering for ScaledMatmul/ScaledDot (jax-ml#35995)

charleshofer and others added 30 commits March 20, 2026 12:51
…tignore (ROCm#563)

When jaxlib was built in debug more, an assertion in LLVM code that lazy-loads VHLO dialect could fire, since the code path could execute in a multi-threaded environment, and LLVM dialect repositories aren't thread safe to modify.

This patch applies the same changes that upstream makes to fix this: jax-ml@48c8762

(this includes disabling a call to `jax_mlir_ext.enter_multi_threaded_execution(context)` in `mlir.py`. Presumably, the whole functionality related to `enter_multi_threaded_execution()` multithreaded checks isn't ready yet, and it was prematurely rolled into the production code.

Manual testing
(forgot this skip in the previous PR)
Co-authored-by: Daniel Suo <danielsuo@gmail.com>
Co-authored-by: Jake VanderPlas <jakevdp@google.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

10 participants