From cba7502e4a74bd92cd7ca96ba87643191b55b1d4 Mon Sep 17 00:00:00 2001 From: obchain Date: Wed, 17 Jun 2026 12:44:53 +0530 Subject: [PATCH] Implement jvp for cumulative logsumexp Scan::jvp only handled the Sum reduction and threw for everything else, so forward-mode differentiation through mx.logcumsumexp raised. The jvp is the running softmax-weighted sum of the tangents, d/dt logcumsumexp(x)_k = sum_{i<=k} softmax(x)_i * t_i, computed in log space by splitting the tangent into its positive and negative parts, mirroring the existing vjp. Exclusive scans leave the first element with no inputs (output -inf, locally constant), so its tangent is set to zero, which also avoids an inf - inf there. --- mlx/primitives.cpp | 36 +++++++++++++++++++++++++++++++++++ python/tests/test_autograd.py | 35 ++++++++++++++++++++++++++++++++++ 2 files changed, 71 insertions(+) diff --git a/mlx/primitives.cpp b/mlx/primitives.cpp index e2f31f4954..c42171a0e4 100644 --- a/mlx/primitives.cpp +++ b/mlx/primitives.cpp @@ -4296,6 +4296,42 @@ std::vector Scan::jvp( if (reduce_type_ == Scan::Sum) { return {cumsum(tangents[0], axis_, reverse_, inclusive_, stream())}; + } else if (reduce_type_ == Scan::LogAddExp) { + // d/dt logcumsumexp(x)_k = sum_{i<=k} softmax(x)_i * t_i. Compute it in log + // space for stability by splitting the tangent into its positive and + // negative parts, mirroring the vjp. + auto x = primals[0]; + auto t = tangents[0]; + auto y = logcumsumexp(x, axis_, reverse_, inclusive_, stream()); + + auto zero = zeros({1}, t.dtype(), stream()); + auto log_min = array(finfo(t.dtype()).min, t.dtype()); + auto log_abs_t = log(abs(t, stream()), stream()); + auto log_t_positive = + where(greater(t, zero, stream()), log_abs_t, log_min, stream()); + auto log_t_negative = + where(less(t, zero, stream()), log_abs_t, log_min, stream()); + + auto masked_scan = [&](const array& log_t) { + return exp( + subtract( + logcumsumexp( + add(log_t, x, stream()), + axis_, + reverse_, + inclusive_, + stream()), + y, + stream()), + stream()); + }; + auto out = subtract( + masked_scan(log_t_positive), masked_scan(log_t_negative), stream()); + // An exclusive scan leaves the first element with no inputs, so the output + // is -inf and locally constant there: its jvp is zero (this also avoids an + // inf - inf in the expression above). + return { + where(isneginf(y, stream()), zeros_like(out, stream()), out, stream())}; } else { throw std::runtime_error( "JVP is not implemented for cumulative prod/min/max"); diff --git a/python/tests/test_autograd.py b/python/tests/test_autograd.py index 221f9229b1..4a204222b7 100644 --- a/python/tests/test_autograd.py +++ b/python/tests/test_autograd.py @@ -590,6 +590,41 @@ def fun(y): expected = mx.array([0.0, 0.0, 0.0, 9.0, 1.0]) self.assertTrue(mx.allclose(out, expected)) + def test_logcumsumexp_grad(self): + # The jvp of logcumsumexp is the running softmax-weighted sum of the + # tangents; check it against an explicit reference and against the vjp. + x = mx.array([1.0, 3.0, 2.0, 4.0]) + v = mx.array([1.0, -1.0, 2.0, 0.5]) + jv = mx.jvp(lambda z: mx.logcumsumexp(z), (x,), (v,))[1][0] + # d/dt logcumsumexp(x)_k = sum_{i<=k} softmax(x[:k+1])_i * v_i + ref = [] + for k in range(x.size): + w = mx.softmax(x[: k + 1]) + ref.append(mx.sum(w * v[: k + 1]).item()) + self.assertTrue(mx.allclose(jv, mx.array(ref))) + + # vjp must be the transpose of the jvp (adjoint test) for every + # combination of the reverse / inclusive flags and across axes. + mx.random.seed(0) + for reverse in (False, True): + for inclusive in (True, False): + for axis in (0, 1, -1): + a = mx.random.normal((4, 6)) + v = mx.random.normal(a.shape) + w = mx.random.normal(a.shape) + + def fun(z): + return mx.logcumsumexp( + z, axis=axis, reverse=reverse, inclusive=inclusive + ) + + jv = mx.jvp(fun, (a,), (v,))[1][0] + jtw = mx.vjp(fun, (a,), (w,))[1][0] + self.assertTrue(mx.all(mx.isfinite(jv))) + self.assertAlmostEqual( + mx.sum(w * jv).item(), mx.sum(v * jtw).item(), places=4 + ) + def test_topk_grad(self): a = mx.array([[1, 2, 6, 4, 5], [9, 5, 6, 7, 8]], mx.float32)