
Add math mode option for custom Metal kernels #3728

Open
shubhxho wants to merge 3 commits into ml-explore:main from shubhxho:main



shubhxho commented Jun 19, 2026


Summary

Closes #3592.

This PR introduces an explicit math_mode option to mx.fast.metal_kernel:

mx.fast.metal_kernel(
    name="my_kernel",
    input_names=["x"],
    output_names=["y"],
    source=source,
    math_mode="safe",  # "safe", "relaxed", or "fast"
)

Custom Metal kernels now default to math_mode="safe" to preserve IEEE-compliant behavior. This is particularly important for operations involving special floating-point values, such as ensuring:

exp(-inf) == 0

Preserving this behavior is critical for masked softmax implementations used in causal and sliding-window attention, where masked logits are commonly represented as -inf.
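To illustrate why this matters, here is a minimal pure-Python sketch of a masked softmax over one row. It does not use MLX or Metal. `broken_exp` is a hypothetical stand-in for an exp that mishandles -inf; what a fast-math exp actually returns for -inf is not stated in this PR, and NaN is used here only as one possible failure mode.

```python
import math

NEG_INF = float("-inf")


def masked_softmax(logits, exp=math.exp):
    """Softmax over one row; masked positions hold -inf."""
    # -inf never wins the max, so m comes from the unmasked entries.
    m = max(logits)
    weights = [exp(x - m) for x in logits]
    total = sum(weights)
    return [w / total for w in weights]


def broken_exp(x):
    # Hypothetical: an exp that returns NaN for -inf instead of 0.
    return math.nan if x == NEG_INF else math.exp(x)


row = [0.5, 1.5, NEG_INF, NEG_INF]

# With IEEE-compliant exp, masked positions get exactly zero weight.
probs = masked_softmax(row)

# With the hypothetical broken exp, NaN enters the sum and poisons every output.
poisoned = masked_softmax(row, exp=broken_exp)
```

With a compliant exp, the masked entries contribute exactly zero and the remaining weights sum to one. In the hypothetical broken case, a single NaN in the normalizer propagates to every position in the row, including the unmasked ones.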

Checklist

  • I have read the CONTRIBUTING document
  • I have run pre-commit run --all-files to format my code / installed pre-commit prior to committing changes
  • I have added tests that prove my fix is effective or that my feature works
  • I have updated the necessary documentation (if needed)

Comment thread mlx/backend/metal/device.h Outdated
shubhxho requested a review from zcbenz June 20, 2026 07:48

zcbenz left a comment

Collaborator


It seems that the test is not working: if I change {"math_mode": "safe"} to {"math_mode": "fast"}, it still passes. I'm not sure there is a good way to detect fast math in Metal, though.


Development

Successfully merging this pull request may close these issues.

mx.fast.metal_kernel: Add support for compiler options (-fmetal-math-mode, integer template parameters, Metal 4 Tensor types

2 participants