Skip to content

[BUG] JVP non-deterministic crashes when called in a loop. #3629

@kyrollosyanny

Description

@kyrollosyanny

Describe the bug
I have some minimal ray-tracing code that fails non-deterministically: the same code with the same inputs fails at different points. I share an example below. VJP, on the other hand, works correctly but is very inefficient for this type of problem.

To Reproduce

Include code snippet

import numpy as np
import mlx.core as mx

# Element type passed as `dtype=` to every mx.array / mx.arange call in this script.
DTYPE = mx.float32


def trace_module_mimic(theta, P0, q0, n_total, scalar_a):
    """Minimum 4-surface mimic of a ray-trace through optical elements.

    Args:
        theta: Value placed in the x-offset of surface 0 (all other x-offsets
            are 0). This is the quantity the caller differentiates against.
        P0: Ray positions; columns 0, 1, 2 are read as x, y, z.
            Assumes shape (n_total, 3) -- TODO confirm against callers.
        q0: Ray directions; columns 0, 1, 2 are read as k, l, m.
            Assumes shape (n_total, 3) -- TODO confirm against callers.
        n_total: Number of rows the per-surface offset tables are repeated to.
        scalar_a: Scale factor applied to the direction-update term.

    Returns:
        The positions after the last of the 4 surfaces, built by concatenating
        the x, y and z columns along axis 1.
    """
    # Per-surface offset tables: one row of 4 values repeated n_total times.
    # Only surface 0's x-offset depends on theta.
    xdec = mx.array([theta, 0.0, 0.0, 0.0], dtype=DTYPE)
    xdec = mx.repeat(xdec[None, :], n_total, axis=0)
    ydec = mx.repeat(mx.array([0.0, -1.5, -1.4, -1.3], dtype=DTYPE)[None, :], n_total, axis=0)
    t    = mx.repeat(mx.array([0.0,  5.0,  6.0,  7.0], dtype=DTYPE)[None, :], n_total, axis=0)
    # One coefficient per surface (presumably the surface curvature -- it enters
    # F below the way curvature enters a spherical sag formula).
    c    = mx.array([0.0, 0.01, -0.008, 0.0], dtype=DTYPE)

    for surf in range(4):
        # Column slices (surf:surf+1) keep a trailing axis of length 1.
        xs = xdec[:, surf:surf+1]; ys = ydec[:, surf:surf+1]; zs = t[:, surf:surf+1]
        # Shift positions by this surface's (x, y, z) offsets.
        P0 = mx.concatenate([P0[:, 0:1] - xs, P0[:, 1:2] - ys, P0[:, 2:3] - zs], axis=1)
        k = q0[:, 0:1]; l = q0[:, 1:2]; m = q0[:, 2:3]
        # s0 is the ray parameter at which the shifted z-coordinate becomes 0;
        # (X1, Y1) is the corresponding x/y position.
        s0 = -P0[:, 2:3] / m
        X1 = P0[:, 0:1] + k*s0
        Y1 = P0[:, 1:2] + l*s0
        c_surf = c[surf]
        # sj starts as a Python float and becomes an array after the first update.
        sj = 0.0
        # Fixed 5-step iteration on sj. The denominator equals dF/dsj (up to the
        # 1e-12 term), so each step has the form of a Newton update for F(sj) = 0.
        for _ in range(5):
            Xj = X1 + k*sj; Yj = Y1 + l*sj; Zj = m*sj
            rho2 = Xj*Xj + Yj*Yj
            # 1e-12 is added under the square root (presumably to keep it away
            # from exactly zero).
            sqrt_val = mx.sqrt(1.0 - c_surf*c_surf*rho2 + 1e-12)
            F = Zj - c_surf*rho2/(1.0+sqrt_val)
            E = c_surf/sqrt_val
            sj = sj - F/(-Xj*E*k + -Yj*E*l + m)
        # Recompute the point from the final sj.
        Xj = X1 + k*sj; Yj = Y1 + l*sj; Zj = m*sj
        # mx.where(cond, traced, const) — known-safe ordering
        # Where sj > -1e8 the traced Xj is kept; elsewhere Xj becomes NaN.
        # NOTE(review): despite the name, TIR is True for the rows that are KEPT.
        TIR = (sj > -1e8)
        Xj = mx.where(TIR, Xj, mx.array(float('nan'), dtype=DTYPE))
        # Add 1e-9 to both Xj and Yj on rows where both are exactly 0.
        zero_rho = ((Xj == 0) & (Yj == 0)).astype(DTYPE) * 1e-9
        Xj = Xj + zero_rho; Yj = Yj + zero_rho
        # Snell-like normal update with a broadcast multiply
        # (K, L, M_norm) is the vector (-c*Xj, -c*Yj, 1) divided by nvec_norm.
        nvec_norm = mx.sqrt(c_surf*c_surf*(Xj*Xj+Yj*Yj) + 1.0 + 1e-12)
        K = -c_surf*Xj/nvec_norm; L = -c_surf*Yj/nvec_norm
        M_norm = 1.0/nvec_norm
        # q_new = q0 - 2 * (q0 . n) * n * scalar_a, with n = (K, L, M_norm).
        # The (N, 1) dot product broadcasts against the (N, 3) concatenation.
        q_new = q0 - 2.0 * (k*K + l*L + m*M_norm) * mx.concatenate([K, L, M_norm], axis=1) * scalar_a
        # The new position and direction feed the next surface.
        P0 = mx.concatenate([Xj, Yj, Zj], axis=1)
        q0 = q_new
    return P0


def trace_returning_dict(theta, n_total=125, seed_offset=0.0):
    """Build a fan of rays, trace it, and return the result in a dict.

    Returns a dict holding BOTH x_unit (used in residual) AND o_s (the
    (N, 3) landing array, kept alive in the dict but not used by residual).

    Args:
        theta: Forwarded unchanged to trace_module_mimic.
        n_total: Number of rays generated.
        seed_offset: Constant added to every ray's starting x-coordinate.

    Returns:
        {'x_unit': ..., 'o_s': ...} where 'o_s' is the traced output and
        'x_unit' is its column 0 minus the row-0 value, divided by 10.
    """
    rays = mx.arange(n_total, dtype=DTYPE)
    # Start positions: x = 0.01*i + seed_offset, y = 0.005*i, z = 0, stacked on axis 1.
    P0 = mx.stack([rays * 0.01 + seed_offset, rays * 0.005, mx.zeros_like(rays)], axis=1)
    # Every ray starts with direction (0, 0, 1), broadcast to (n_total, 3).
    q0 = mx.broadcast_to(mx.array([[0.0, 0.0, 1.0]], dtype=DTYPE), (n_total, 3))
    o_s = trace_module_mimic(theta, P0, q0, n_total, scalar_a=0.5)
    # Row 0's x-coordinate is used as the reference ("chief") value.
    chief_x = o_s[0, 0]
    x_unit = (o_s[:, 0] - chief_x) / 10.0
    return {'x_unit': x_unit, 'o_s': o_s}    # HOLD o_s


def residual(theta):
    """Two independent trace calls, concat their x_unit columns.

    Both calls use the default n_total and differ only in seed_offset
    (0.0 and 10.0). Only the 'x_unit' entry of each returned dict is used;
    the 'o_s' entries are not read here.
    """
    s0 = trace_returning_dict(theta, seed_offset=0.0)
    s1 = trace_returning_dict(theta, seed_offset=10.0)
    return mx.concatenate([s0['x_unit'], s1['x_unit']])


# Point at which the JVP of `residual` is taken; identical for every call.
theta = mx.array(0.5, dtype=DTYPE)

print("=== MLX jvp non-deterministic rank-mismatch repro ===")
print("Running 30 jvp calls per trial, 10 trials")
print()
n_cols = 30    # jvp calls per trial
n_trials = 10  # number of trials
crash_count = 0
for trial in range(n_trials):
    # Set to (call index, exception type name, truncated message) on first failure.
    crashed = None
    for i in range(n_cols):
        # Tangent for this call: the scalar i + 1. Only the tangent changes
        # between calls; the primal input `theta` stays the same.
        v = mx.array(float(i+1), dtype=DTYPE)
        try:
            _, jv = mx.jvp(residual, [theta], [v])
            # Evaluate the first returned tangent inside the try so that a
            # failure during evaluation is caught here.
            mx.eval(jv[0])
        except Exception as e:
            # Keep only the first 110 characters of the message; stop this trial.
            crashed = (i, type(e).__name__, str(e)[:110])
            break
    if crashed is None:
        print(f"  trial {trial:2d}: all {n_cols} jvp calls PASSED")
    else:
        i, et, msg = crashed
        print(f"  trial {trial:2d}: CRASH at jvp call {i}: {et}: {msg}")
        crash_count += 1
print()
print(f"{crash_count}/{n_trials} trials crashed.")
# NOTE(review): the loop above only counts raised exceptions; it never inspects
# the shape of jv[0]. The "silent rank-collapsed tangents" wording below is
# therefore not checked by this script -- confirm against the actual error text.
if crash_count > 0:
    print()
    print("==> BUG CONFIRMED: mx.jvp returns silent rank-collapsed tangents")
    print("    (some dimension becomes 0) on identical input, non-deterministically.")
    print("    VJP and FD on the same function are always correct.")

Expected behavior
I expect every JVP call to succeed and its result to match the equivalent VJP result.

Desktop (please complete the following information):

  • MLX 0.31.2

Metadata

Metadata

Assignees

No one assigned

    Labels

    Type

    No type
    No fields configured for issues without a type.

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions