Describe the bug
The forward-mode derivative (jvp) of mx.logsumexp is wrong when it dispatches to the LogSumExp primitive (i.e. when reducing over the last axis, the common case in softmax/attention code). The Jacobian-vector product should reduce over the axis:
d/dt logsumexp(x) = sum_j softmax(x)_j * t_j
but LogSumExp::jvp returns softmax(x) * t without the sum, so the returned tangent has the input's shape instead of the reduced output shape, and its values are wrong.
The vjp is correct, so this only affects forward-mode autodiff (mx.jvp) and anything built on it.
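As an illustration of the formula above, here is a minimal pure-Python sketch for a single 1-D row (no MLX involved; the helper names `lse_jvp`, `lse_jvp_unreduced` and `finite_difference` are hypothetical). It compares the reduced JVP against a central finite difference and shows that the un-reduced `softmax(x) * t` has the input's length rather than the scalar output's.

```python
import math

def logsumexp(x):
    # Numerically stable: subtract the max before exponentiating.
    m = max(x)
    return m + math.log(sum(math.exp(xi - m) for xi in x))

def softmax(x):
    lse = logsumexp(x)
    return [math.exp(xi - lse) for xi in x]

def lse_jvp(x, t):
    # Correct JVP: a scalar, the softmax-weighted sum of the tangent.
    return sum(s * ti for s, ti in zip(softmax(x), t))

def lse_jvp_unreduced(x, t):
    # The behavior described in this report: softmax(x) * t with no
    # reduction, so the result has the input's length, not a scalar.
    return [s * ti for s, ti in zip(softmax(x), t)]

def finite_difference(x, t, eps=1e-6):
    # Central difference of logsumexp along the direction t.
    plus = [xi + eps * ti for xi, ti in zip(x, t)]
    minus = [xi - eps * ti for xi, ti in zip(x, t)]
    return (logsumexp(plus) - logsumexp(minus)) / (2 * eps)

x = [1.0, 2.0, 3.0]
t = [0.5, -1.0, 2.0]
print(lse_jvp(x, t), finite_difference(x, t))  # agree closely
print(len(lse_jvp_unreduced(x, t)))             # 3, not a scalar
```

With an all-ones tangent, as in the reproduction below, the correct JVP of every row is exactly sum(softmax(x)) = 1.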
To Reproduce
```python
import mlx.core as mx
x = mx.array([[1.0, 2.0, 3.0], [4.0, 1.0, 0.0]])
v = mx.ones_like(x)
_, jv = mx.jvp(lambda z: mx.logsumexp(z, axis=-1, keepdims=True), [x], [v])
print(jv[0].shape) # -> (2, 3) WRONG, expected (2, 1)
print(jv[0]) # un-reduced softmax*v instead of its sum over the axis
# expected (sum of softmax * tangent over the reduced axis):
print(mx.sum(mx.softmax(x, axis=-1) * v, axis=-1, keepdims=True))
```
The adjoint identity <w, Jv> == <v, Jᵀw> between jvp and vjp also fails for logsumexp because of this.
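To make the adjoint check concrete, the sketch below evaluates both sides of <w, Jv> == <v, Jᵀw> in pure Python for the same 2x3 input as the reproduction, using the correct reduced JVP (output shape (2, 1), as with keepdims=True) and the VJP that broadcasts each row's cotangent over softmax(row). The helpers and the tangent/cotangent values are hypothetical stand-ins, not MLX code. With the un-reduced tangent, Jv would have shape (2, 3) instead of w's shape (2, 1), so the two sides are no longer computed over matching shapes.

```python
import math

def softmax_row(row):
    m = max(row)
    exps = [math.exp(r - m) for r in row]
    total = sum(exps)
    return [e / total for e in exps]

def lse_jvp(x, v):
    # Forward mode: one tangent per row (the reduced, keepdims=True shape).
    return [[sum(s * vi for s, vi in zip(softmax_row(row), vrow))]
            for row, vrow in zip(x, v)]

def lse_vjp(x, w):
    # Reverse mode: broadcast each row's cotangent back over softmax(row).
    return [[wr[0] * s for s in softmax_row(row)] for row, wr in zip(x, w)]

def inner(a, b):
    # Frobenius inner product of two equally shaped nested lists.
    return sum(ai * bi for ra, rb in zip(a, b) for ai, bi in zip(ra, rb))

x = [[1.0, 2.0, 3.0], [4.0, 1.0, 0.0]]
v = [[0.3, -1.0, 2.0], [1.5, 0.0, -0.5]]   # input-shaped tangent
w = [[2.0], [-1.0]]                        # output-shaped cotangent

lhs = inner(w, lse_jvp(x, v))   # <w, Jv>
rhs = inner(v, lse_vjp(x, w))   # <v, Jᵀw>
print(abs(lhs - rhs) < 1e-12)   # True
```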
Expected behavior
mx.jvp of logsumexp returns a tangent with the reduced output shape, equal to sum(softmax(x) * t) over the reduced axis, and consistent with mx.vjp.
Desktop
- OS: macOS
- Version: main (0.32.0.dev)