"""Minimal MLX reproducer: grad of `mx.maximum(sqrt(x^2 + y^2), const)` is NaN
at the origin (x=y=0), even though the forward value is finite.
Run:
python repro_max_sqrt_nan.py
Expected output: the grad at the (0,0) sample is NaN.
Tested on MLX 0.31.2 / Apple Silicon GPU.
The forward clamp via `mx.maximum` works fine, but
the gradient still flows through the `sqrt(0)` branch where the local
derivative is mathematically `0/0 = NaN`. The `0 * NaN = NaN` trap then
poisons the gradient at every input that produces rho=0.
"""
import mlx.core as mx
# Make the GPU the default device for everything that follows; the module
# docstring records that this report was tested on an Apple Silicon GPU.
mx.set_default_device(mx.gpu)
def f_maximum_outside_sqrt(xy):
    """Sum the radii of ``xy``, floored at 1e-10 AFTER taking the sqrt.

    The last axis of ``xy`` is indexed at 0 and 1 to obtain the two
    components. The radius ``sqrt(x*x + y*y)`` is computed first and only
    then clamped from below with ``mx.maximum``. This is the variant the
    reproducer reports as yielding a NaN gradient at (0, 0).

    Returns the scalar sum of the clamped radii.
    """
    x_comp = xy[..., 0]
    y_comp = xy[..., 1]
    radius = mx.sqrt(x_comp * x_comp + y_comp * y_comp)
    # The floor is built with the radius dtype so the clamp does not promote.
    floor = mx.array(1e-10, dtype=radius.dtype)
    return mx.sum(mx.maximum(radius, floor))
def f_where_outside_sqrt(xy):
    """Sum the radii of ``xy``, floored at 1e-10 via ``mx.where`` AFTER the sqrt.

    Same computation as the ``mx.maximum`` variant, except the clamp is
    written as an explicit select: keep the radius where it exceeds 1e-10,
    otherwise substitute the constant. The reproducer reports this variant
    as also yielding a NaN gradient at (0, 0).

    Returns the scalar sum of the selected values.
    """
    x_comp = xy[..., 0]
    y_comp = xy[..., 1]
    radius = mx.sqrt(x_comp * x_comp + y_comp * y_comp)
    above_floor = radius > 1e-10
    floor = mx.array(1e-10, dtype=radius.dtype)
    clamped = mx.where(above_floor, radius, floor)
    return mx.sum(clamped)
def f_eps_inside_sqrt(xy):
    """Sum the radii of ``xy`` with a 1e-20 offset added INSIDE the sqrt.

    Because the offset is added to ``x*x + y*y`` before the square root,
    the sqrt argument is never exactly zero. The reproducer presents this
    as the workaround whose gradient stays finite at (0, 0).

    Returns the scalar sum of the offset radii.
    """
    x_comp = xy[..., 0]
    y_comp = xy[..., 1]
    shifted_sq_radius = x_comp * x_comp + y_comp * y_comp + 1e-20
    return mx.sum(mx.sqrt(shifted_sq_radius))
# Three sample points: the chief ray at the origin (0, 0) and two off-axis rays.
sample_points = [
    [0.0, 0.0],
    [0.5, 0.3],
    [-0.7, 0.1],
]
xy = mx.array(sample_points, dtype=mx.float32)
print(f"input xy:\n{xy}\n")

# (label, function) pairs, evaluated and differentiated in the listed order.
variants = (
    ("mx.maximum(sqrt, const) outside", f_maximum_outside_sqrt),
    ("mx.where(>eps, sqrt, const)", f_where_outside_sqrt),
    ("sqrt(x*x + y*y + 1e-20) [fix]", f_eps_inside_sqrt),
)

for name, f in variants:
    # Forward value first, then the reverse-mode gradient w.r.t. xy.
    val = f(xy)
    grad_fn = mx.grad(f)
    g = grad_fn(xy)
    print(f"{name}:")
    print(f" forward = {float(val):.6f} (always finite)")
    print(f" grad = {g}")
    print()
Describe the bug
The VJP (reverse-mode gradient) does not match the JVP (forward-mode gradient) or a finite-difference estimate. This is the `0 * NaN = NaN` trap: the gradient through `mx.maximum(rho, eps)` is NaN at `rho = 0`, even though the forward value is finite.
To Reproduce
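Run the reproducer script above (`python repro_max_sqrt_nan.py`) and inspect the gradient printed for the (0, 0) sample.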
Expected behavior
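The VJP (reverse-mode) gradient should match the forward-mode (JVP) gradient and the finite-difference estimate, rather than returning NaN at `rho = 0`.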