import numpy as np
import mlx.core as mx
# Single dtype passed to every array constructor in this script.
DTYPE = mx.float32
def trace_module_mimic(theta, P0, q0, n_total, scalar_a):
    """Minimum 4-surface mimic of a ray-trace through optical elements.

    For each of four hard-coded surfaces the ray origins are shifted by that
    surface's offsets, an intersection parameter is refined with a fixed
    number of Newton-style updates, and the ray directions are updated from
    the surface normal at the intersection point.

    Args:
        theta: x offset applied at surface 0 (the other three surfaces use
            0.0). Presumably a scalar; the driver passes a 0-d array.
        P0: ray origins; columns 0, 1, 2 are read as x, y, z.
        q0: ray directions; columns 0, 1, 2 are read as k, l, m.
        n_total: row count the per-surface tables are repeated to.
            Presumably equal to the number of rays in P0 -- TODO confirm.
        scalar_a: factor multiplied into the direction update.

    Returns:
        The points computed on the last surface, as the column-wise
        concatenation [Xj, Yj, Zj].
    """
    # Per-surface x offsets; only surface 0 depends on theta.
    xdec = mx.array([theta, 0.0, 0.0, 0.0], dtype=DTYPE)
    # Add a leading axis and repeat it n_total times (one row per ray).
    xdec = mx.repeat(xdec[None, :], n_total, axis=0)
    # Per-surface y and z offsets, repeated the same way.
    ydec = mx.repeat(mx.array([0.0, -1.5, -1.4, -1.3], dtype=DTYPE)[None, :], n_total, axis=0)
    t = mx.repeat(mx.array([0.0, 5.0, 6.0, 7.0], dtype=DTYPE)[None, :], n_total, axis=0)
    # Per-surface coefficient used in the surface function below; 0.0 makes
    # the surface term vanish (surfaces 0 and 3).
    c = mx.array([0.0, 0.01, -0.008, 0.0], dtype=DTYPE)
    for surf in range(4):
        # Width-1 column slices keep a trailing axis so they broadcast
        # against the matching columns of P0.
        xs = xdec[:, surf:surf+1]; ys = ydec[:, surf:surf+1]; zs = t[:, surf:surf+1]
        # Shift the origins by this surface's offsets.
        P0 = mx.concatenate([P0[:, 0:1] - xs, P0[:, 1:2] - ys, P0[:, 2:3] - zs], axis=1)
        k = q0[:, 0:1]; l = q0[:, 1:2]; m = q0[:, 2:3]
        # Path parameter at which the shifted ray reaches z = 0, and the
        # x/y coordinates at that point.
        s0 = -P0[:, 2:3] / m
        X1 = P0[:, 0:1] + k*s0
        Y1 = P0[:, 1:2] + l*s0
        c_surf = c[surf]
        # Starts as a Python float; becomes an array after the first update.
        sj = 0.0
        # Fixed five Newton-style updates of sj (no convergence test).
        for _ in range(5):
            Xj = X1 + k*sj; Yj = Y1 + l*sj; Zj = m*sj
            rho2 = Xj*Xj + Yj*Yj
            # 1e-12 is added under the square root before it is taken.
            sqrt_val = mx.sqrt(1.0 - c_surf*c_surf*rho2 + 1e-12)
            # Residual: z minus the surface term c*rho2 / (1 + sqrt(...)).
            F = Zj - c_surf*rho2/(1.0+sqrt_val)
            E = c_surf/sqrt_val
            # Divide the residual by its slope along the ray.
            sj = sj - F/(-Xj*E*k + -Yj*E*l + m)
        # Re-evaluate the point with the final sj.
        Xj = X1 + k*sj; Yj = Y1 + l*sj; Zj = m*sj
        # mx.where(cond, traced, const) — known-safe ordering
        # NOTE(review): the mask is named TIR, but the condition only tests
        # sj > -1e8; rows that fail it get NaN in Xj.
        TIR = (sj > -1e8)
        Xj = mx.where(TIR, Xj, mx.array(float('nan'), dtype=DTYPE))
        # Nudge rows where both Xj and Yj are exactly 0 by 1e-9.
        zero_rho = ((Xj == 0) & (Yj == 0)).astype(DTYPE) * 1e-9
        Xj = Xj + zero_rho; Yj = Yj + zero_rho
        # Snell-like normal update with a broadcast multiply
        nvec_norm = mx.sqrt(c_surf*c_surf*(Xj*Xj+Yj*Yj) + 1.0 + 1e-12)
        # Components of (-c*X, -c*Y, 1) divided by nvec_norm.
        K = -c_surf*Xj/nvec_norm; L = -c_surf*Yj/nvec_norm
        M_norm = 1.0/nvec_norm
        # q0 - 2 * (q . n) * n, scaled by scalar_a; k, l, m are the
        # components of q0 read at the top of this iteration.
        q_new = q0 - 2.0 * (k*K + l*L + m*M_norm) * mx.concatenate([K, L, M_norm], axis=1) * scalar_a
        # The intersection point and new direction feed the next surface.
        P0 = mx.concatenate([Xj, Yj, Zj], axis=1)
        q0 = q_new
    return P0
def trace_returning_dict(theta, n_total=125, seed_offset=0.0):
    """Returns a dict holding BOTH x_unit (used in residual) AND o_s (the
    (N, 3) landing array, kept alive in the dict but not used by residual).

    Args:
        theta: forwarded unchanged to trace_module_mimic.
        n_total: number of rays generated.
        seed_offset: constant added to the x coordinate of every ray origin.

    Returns:
        A dict with keys 'x_unit' and 'o_s'.
    """
    # Ray index 0 .. n_total-1, used to spread the starting positions.
    rays = mx.arange(n_total, dtype=DTYPE)
    # Origins: x = 0.01*i + seed_offset, y = 0.005*i, z = 0, stacked as columns.
    P0 = mx.stack([rays * 0.01 + seed_offset, rays * 0.005, mx.zeros_like(rays)], axis=1)
    # Every ray starts with direction (0, 0, 1).
    q0 = mx.broadcast_to(mx.array([[0.0, 0.0, 1.0]], dtype=DTYPE), (n_total, 3))
    o_s = trace_module_mimic(theta, P0, q0, n_total, scalar_a=0.5)
    # x coordinate of the first ray, used as the reference.
    chief_x = o_s[0, 0]
    # x of every ray relative to that reference, divided by 10.
    x_unit = (o_s[:, 0] - chief_x) / 10.0
    return {'x_unit': x_unit, 'o_s': o_s} # HOLD o_s
def residual(theta):
    """Two independent trace calls, concat their x_unit columns.

    Args:
        theta: forwarded unchanged to both trace_returning_dict calls.

    Returns:
        The concatenation of the two 'x_unit' arrays, seed_offset=0.0 first.
    """
    # Both dicts stay bound to locals, so each 'o_s' entry remains referenced
    # until this function returns.
    s0 = trace_returning_dict(theta, seed_offset=0.0)
    s1 = trace_returning_dict(theta, seed_offset=10.0)
    return mx.concatenate([s0['x_unit'], s1['x_unit']])
# Driver: run mx.jvp on `residual` repeatedly with the same primal and count
# how many trials raise.
theta = mx.array(0.5, dtype=DTYPE)

print("=== MLX jvp non-deterministic rank-mismatch repro ===")
print("Running 30 jvp calls per trial, 10 trials")
print()

n_cols = 30
n_trials = 10
crash_count = 0

for trial in range(n_trials):
    for call_idx in range(n_cols):
        # The tangent changes on every call; the primal never does.
        v = mx.array(float(call_idx + 1), dtype=DTYPE)
        try:
            _, jv = mx.jvp(residual, [theta], [v])
            # Force evaluation so a failure surfaces inside this try.
            mx.eval(jv[0])
        except Exception as exc:
            # The first failure ends the trial; report it and move on.
            exc_name = type(exc).__name__
            detail = str(exc)[:110]
            print(f" trial {trial:2d}: CRASH at jvp call {call_idx}: {exc_name}: {detail}")
            crash_count += 1
            break
    else:
        # Reached only when the inner loop finished without a break.
        print(f" trial {trial:2d}: all {n_cols} jvp calls PASSED")

print()
print(f"{crash_count}/{n_trials} trials crashed.")
if crash_count > 0:
    print()
    print("==> BUG CONFIRMED: mx.jvp returns silent rank-collapsed tangents")
    print(" (some dimension becomes 0) on identical input, non-deterministically.")
    print(" VJP and FD on the same function are always correct.")
Describe the bug
I have some minimal ray-tracing code that fails non-deterministically: the same code with the same inputs fails at different points from run to run. An example that reproduces the failure is included in this report. VJP, on the other hand, works correctly but is very inefficient for this problem type.
To Reproduce
Include code snippet
Expected behavior
I expect every jvp call to pass and its result to be equal to the VJP equivalent.
Desktop (please complete the following information):