[BUG] mx.logsumexp jvp does not reduce along the axis (wrong shape and values) #3707

@obchain

Describe the bug

The forward-mode derivative (jvp) of mx.logsumexp is wrong when it dispatches to the LogSumExp primitive (i.e. reducing the last axis, which is the common case used by softmax/attention code). The Jacobian-vector product should reduce over the axis:

d/dε logsumexp(x + ε t) |_{ε=0} = sum_j softmax(x)_j * t_j    (t is the input tangent)

but LogSumExp::jvp returns softmax(x) * t without the sum, so the returned tangent has the input shape instead of the reduced output shape and holds the wrong values.
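For reference, here is a minimal framework-independent sketch of the reduced JVP, checked against a central finite difference. It uses only the Python standard library, not MLX; the helper names (`logsumexp`, `softmax`, `logsumexp_jvp`, `finite_difference`) and the tangent values are made up for illustration and are not MLX APIs.

```python
import math

def logsumexp(row):
    # Numerically stable logsumexp over one row (the reduced axis).
    m = max(row)
    return m + math.log(sum(math.exp(v - m) for v in row))

def softmax(row):
    # Numerically stable softmax over one row.
    m = max(row)
    exps = [math.exp(v - m) for v in row]
    total = sum(exps)
    return [e / total for e in exps]

def logsumexp_jvp(row, tangent):
    # JVP of logsumexp along the row: sum_j softmax(x)_j * t_j.
    # The result is one scalar per row, i.e. the reduced output shape.
    return sum(p * t for p, t in zip(softmax(row), tangent))

def finite_difference(row, tangent, eps=1e-6):
    # Central difference of logsumexp in the direction of the tangent.
    plus = [v + eps * t for v, t in zip(row, tangent)]
    minus = [v - eps * t for v, t in zip(row, tangent)]
    return (logsumexp(plus) - logsumexp(minus)) / (2 * eps)

# Same x as in the reproduction; the tangent values are arbitrary.
x = [[1.0, 2.0, 3.0], [4.0, 1.0, 0.0]]
t = [[0.5, -1.0, 2.0], [1.0, 0.0, -3.0]]

for row, tan in zip(x, t):
    # The two numbers printed per row should agree closely.
    print(logsumexp_jvp(row, tan), finite_difference(row, tan))
```

With an all-ones tangent the reduced JVP is 1 for every row, because the softmax weights sum to 1. That gives a quick sanity check for the reproduction below.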

The vjp is correct, so this only affects forward-mode autodiff (mx.jvp) and anything built on it.

To Reproduce

import mlx.core as mx

x = mx.array([[1.0, 2.0, 3.0], [4.0, 1.0, 0.0]])
v = mx.ones_like(x)

_, jv = mx.jvp(lambda z: mx.logsumexp(z, axis=-1, keepdims=True), [x], [v])
print(jv[0].shape)   # -> (2, 3)  WRONG, expected (2, 1)
print(jv[0])         # un-reduced softmax*v instead of its sum over the axis

# expected (sum of softmax * tangent over the reduced axis):
print(mx.sum(mx.softmax(x, axis=-1) * v, axis=-1, keepdims=True))

The adjoint identity <w, Jv> == <v, Jᵀw> between jvp and vjp also fails for logsumexp because of this.
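The adjoint check can be sketched without MLX as follows. This is an illustration in plain Python under the assumption that the reduction is over the last axis of a 2-D input; `jvp_row`, `vjp_row`, and the values of `v` and `w` are hypothetical and chosen only for the example.

```python
import math

def softmax(row):
    # Numerically stable softmax over one row.
    m = max(row)
    exps = [math.exp(a - m) for a in row]
    total = sum(exps)
    return [e / total for e in exps]

def jvp_row(row, v_row):
    # J v for one row: reduces to a scalar, sum_j softmax(x)_j * v_j.
    return sum(p * t for p, t in zip(softmax(row), v_row))

def vjp_row(row, w_scalar):
    # J^T w for one row: expands the scalar cotangent back to the input shape.
    return [w_scalar * p for p in softmax(row)]

x = [[1.0, 2.0, 3.0], [4.0, 1.0, 0.0]]
v = [[0.5, -1.0, 2.0], [1.0, 0.0, -3.0]]  # input tangent, same shape as x
w = [0.7, -1.3]                           # output cotangent, one entry per row

# <w, J v>: inner product in the reduced output space.
lhs = sum(w_i * jvp_row(x_i, v_i) for x_i, v_i, w_i in zip(x, v, w))

# <v, J^T w>: inner product in the input space.
rhs = sum(
    v_ij * g_ij
    for x_i, v_i, w_i in zip(x, v, w)
    for v_ij, g_ij in zip(v_i, vjp_row(x_i, w_i))
)

print(lhs, rhs)  # the two values should match up to rounding
```

With the un-reduced tangent described above, `J v` has the input shape, so the left-hand inner product against a cotangent of the output shape no longer lines up.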

Expected behavior

mx.jvp of logsumexp returns a tangent with the reduced output shape, equal to sum(softmax(x) * t) over the reduced axis, and consistent with mx.vjp.

Desktop

  • OS: macOS
  • Version: main (0.32.0.dev)
