From 6370a343abe5b7d6a81f13dcef560a44262e7e09 Mon Sep 17 00:00:00 2001
From: chrismicah
Date: Sun, 21 Jun 2026 08:40:05 -0500
Subject: [PATCH] Fix stale tangent reuse in repeated JVPs

jvp() recorded each primitive's output tangents with tan_map.insert(),
which leaves an entry untouched when its key is already present. When an
output's array id matched a key that was already in the map, the older
tangent stayed in place and repeated JVPs could return tangents with a
corrupted shape (issue #3629).

Use insert_or_assign() so that the tangent computed for an output always
replaces whatever is stored under the same id.

Add a regression test that runs the reported graph through mx.jvp 90
times (3 trials x 30 calls) and checks the tangent shape after each call.
---
Reviewer note: the line breaks and leading whitespace of this patch were
reconstructed from a copy whose newlines had been collapsed, and the
author e-mail address was already missing from that copy. Check the
context-line indentation against the tree before applying, or apply with
`git apply --ignore-whitespace`.

 mlx/transforms.cpp            |  2 +-
 python/tests/test_autograd.py | 94 +++++++++++++++++++++++++++++++++++
 2 files changed, 95 insertions(+), 1 deletion(-)

diff --git a/mlx/transforms.cpp b/mlx/transforms.cpp
index 9a8207339e..000a702487 100644
--- a/mlx/transforms.cpp
+++ b/mlx/transforms.cpp
@@ -643,7 +643,7 @@ std::pair<std::vector<array>, std::vector<array>> jvp(
     // A primitive's jvp returns one tangent per output
     assert(jvps.size() <= outputs.size());
     for (int i = 0; i < jvps.size(); ++i) {
-      tan_map.insert({outputs[i].id(), jvps[i]});
+      tan_map.insert_or_assign(outputs[i].id(), jvps[i]);
     }
   }
 
diff --git a/python/tests/test_autograd.py b/python/tests/test_autograd.py
index 1ed1bc6997..e2a2b059b0 100644
--- a/python/tests/test_autograd.py
+++ b/python/tests/test_autograd.py
@@ -125,6 +125,100 @@ def fun(x):
         _, (dout,) = mx.jvp(fun, [x], [mx.ones_like(x)])
         self.assertTrue(mx.array_equal(dout, mx.array([1.0, 0.0, 1.0])))
 
+    def test_repeated_jvp_with_reused_array_ids(self):
+        # Repeated JVPs over this graph used to corrupt tangent shapes
+        # nondeterministically because an older tangent could remain in the
+        # transform map when array ids were reused (issue #3629).
+        dtype = mx.float32
+
+        def trace_module(theta, p0, q0, n_total, scalar_a):
+            xdec = mx.array([theta, 0.0, 0.0, 0.0], dtype=dtype)
+            xdec = mx.repeat(xdec[None, :], n_total, axis=0)
+            ydec = mx.repeat(
+                mx.array([0.0, -1.5, -1.4, -1.3], dtype=dtype)[None, :],
+                n_total,
+                axis=0,
+            )
+            t = mx.repeat(
+                mx.array([0.0, 5.0, 6.0, 7.0], dtype=dtype)[None, :],
+                n_total,
+                axis=0,
+            )
+            c = mx.array([0.0, 0.01, -0.008, 0.0], dtype=dtype)
+
+            for surf in range(4):
+                xs = xdec[:, surf : surf + 1]
+                ys = ydec[:, surf : surf + 1]
+                zs = t[:, surf : surf + 1]
+                p0 = mx.concatenate(
+                    [p0[:, 0:1] - xs, p0[:, 1:2] - ys, p0[:, 2:3] - zs],
+                    axis=1,
+                )
+                k = q0[:, 0:1]
+                l = q0[:, 1:2]
+                m = q0[:, 2:3]
+                s0 = -p0[:, 2:3] / m
+                x1 = p0[:, 0:1] + k * s0
+                y1 = p0[:, 1:2] + l * s0
+                c_surf = c[surf]
+                sj = 0.0
+                for _ in range(5):
+                    xj = x1 + k * sj
+                    yj = y1 + l * sj
+                    zj = m * sj
+                    rho2 = xj * xj + yj * yj
+                    sqrt_val = mx.sqrt(1.0 - c_surf * c_surf * rho2 + 1e-12)
+                    f = zj - c_surf * rho2 / (1.0 + sqrt_val)
+                    e = c_surf / sqrt_val
+                    sj = sj - f / (-xj * e * k + -yj * e * l + m)
+
+                xj = x1 + k * sj
+                yj = y1 + l * sj
+                zj = m * sj
+                xj = mx.where(sj > -1e8, xj, mx.array(float("nan"), dtype=dtype))
+                zero_rho = ((xj == 0) & (yj == 0)).astype(dtype) * 1e-9
+                xj = xj + zero_rho
+                yj = yj + zero_rho
+                nvec_norm = mx.sqrt(c_surf * c_surf * (xj * xj + yj * yj) + 1.0 + 1e-12)
+                kk = -c_surf * xj / nvec_norm
+                ll = -c_surf * yj / nvec_norm
+                mm = 1.0 / nvec_norm
+                q0 = (
+                    q0
+                    - 2.0
+                    * (k * kk + l * ll + m * mm)
+                    * mx.concatenate([kk, ll, mm], axis=1)
+                    * scalar_a
+                )
+                p0 = mx.concatenate([xj, yj, zj], axis=1)
+            return p0
+
+        def trace_returning_dict(theta, n_total=125, seed_offset=0.0):
+            rays = mx.arange(n_total, dtype=dtype)
+            p0 = mx.stack(
+                [rays * 0.01 + seed_offset, rays * 0.005, mx.zeros_like(rays)],
+                axis=1,
+            )
+            q0 = mx.broadcast_to(mx.array([[0.0, 0.0, 1.0]], dtype=dtype), (n_total, 3))
+            o_s = trace_module(theta, p0, q0, n_total, scalar_a=0.5)
+            chief_x = o_s[0, 0]
+            x_unit = (o_s[:, 0] - chief_x) / 10.0
+            return {"x_unit": x_unit, "o_s": o_s}
+
+        def residual(theta):
+            s0 = trace_returning_dict(theta, seed_offset=0.0)
+            s1 = trace_returning_dict(theta, seed_offset=10.0)
+            return mx.concatenate([s0["x_unit"], s1["x_unit"]])
+
+        theta = mx.array(0.5, dtype=dtype)
+        for trial in range(3):
+            for i in range(30):
+                _, (jv,) = mx.jvp(
+                    residual, [theta], [mx.array(float(i + 1), dtype=dtype)]
+                )
+                mx.eval(jv)
+                self.assertEqual(jv.shape, (250,), f"trial={trial} jvp={i}")
+
     def test_vjp(self):
         fun = lambda x: 2 * x
         out, dout = mx.vjp(fun, [mx.array(1.0)], [mx.array(2.0)])