From b5d18fe1c2c127d6e244fd5a57ee86869afeb6df Mon Sep 17 00:00:00 2001
From: obchain
Date: Wed, 17 Jun 2026 12:26:05 +0530
Subject: [PATCH] Fix logsumexp jvp to reduce along the axis

LogSumExp::jvp returned tangents[0] * softmax(x) elementwise without
summing over the reduced axis, so the forward-mode tangent had the input
shape instead of the reduced output shape and the wrong values. The jvp
of logsumexp is sum(softmax(x) * t) over that axis; sum the product back
to the output shape to match. Mirrors the reduction already done in the
forward pass and the broadcasting in the vjp.
---
 mlx/primitives.cpp            | 14 ++++++++++----
 python/tests/test_autograd.py | 27 +++++++++++++++++++++++++++
 2 files changed, 37 insertions(+), 4 deletions(-)

diff --git a/mlx/primitives.cpp b/mlx/primitives.cpp
index e2f31f4954..538d57c8df 100644
--- a/mlx/primitives.cpp
+++ b/mlx/primitives.cpp
@@ -2947,10 +2947,16 @@ std::vector<array> LogSumExp::jvp(
     const std::vector<int>& argnums) {
   assert(primals.size() == 1);
   assert(tangents.size() == 1);
-  return {multiply(
-      tangents[0],
-      softmax(primals[0], std::vector<int>{-1}, true, stream()),
-      stream())};
+  // d/dt logsumexp(x) = sum(softmax(x) * t) over the reduced axis. The result
+  // must be summed back to the output shape, not left as a per-element product.
+  return {
+      sum(multiply(
+              tangents[0],
+              softmax(primals[0], std::vector<int>{-1}, true, stream()),
+              stream()),
+          -1,
+          /* keepdims = */ true,
+          stream())};
 }
 
 std::vector<Shape> LogSumExp::output_shapes(const std::vector<array>& inputs) {
diff --git a/python/tests/test_autograd.py b/python/tests/test_autograd.py
index 221f9229b1..84ae47add8 100644
--- a/python/tests/test_autograd.py
+++ b/python/tests/test_autograd.py
@@ -621,6 +621,33 @@ def test_sort_grad(self):
                 mx.sum(w * jv).item(), mx.sum(v * jtw).item(), places=4
             )
 
+    def test_logsumexp_grad(self):
+        # The jvp of logsumexp reduces along the axis (sum of softmax * tangent),
+        # so the tangent it returns must have the reduced output shape, not the
+        # input shape.
+        x = mx.array([[1.0, 2.0, 3.0], [4.0, 1.0, 0.0]])
+        v = mx.array([[1.0, 0.0, -1.0], [2.0, 1.0, 0.0]])
+        jv = mx.jvp(lambda z: mx.logsumexp(z, axis=-1, keepdims=True), (x,), (v,))[1][0]
+        self.assertEqual(jv.shape, (2, 1))
+        expected = mx.sum(mx.softmax(x, axis=-1) * v, axis=-1, keepdims=True)
+        self.assertTrue(mx.allclose(jv, expected))
+
+        # vjp must be the transpose of the jvp (adjoint test).
+        mx.random.seed(0)
+        for keepdims in (True, False):
+            a = mx.random.normal((4, 6))
+            v = mx.random.normal(a.shape)
+
+            def fun(z):
+                return mx.logsumexp(z, axis=-1, keepdims=keepdims)
+
+            w = mx.random.normal(fun(a).shape)
+            jv = mx.jvp(fun, (a,), (v,))[1][0]
+            jtw = mx.vjp(fun, (a,), (w,))[1][0]
+            self.assertAlmostEqual(
+                mx.sum(w * jv).item(), mx.sum(v * jtw).item(), places=4
+            )
+
     def test_custom_function(self):
         # Make a custom function
         my_exp = mx.custom_function(mx.exp)