
[BUG] jvp of mx.logcumsumexp is not implemented #3710

@obchain

Description

Describe the bug

Forward-mode autodiff (mx.jvp) through mx.logcumsumexp raises instead of returning the tangent:

RuntimeError: JVP is not implemented for cumulative prod/min/max

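Scan::jvp only implements the Sum reduction. Although the vjp of logcumsumexp (the LogAddExp scan) is implemented, its jvp falls through to the generic "not implemented" branch, which is why the error message mentions prod/min/max. This breaks mx.jvp, as well as forward-over-reverse compositions (for example, Hessian-vector products computed as a jvp of a grad) that go through logcumsumexp.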
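One way a LogAddExp branch could compute the tangent in a single forward pass is the recurrence below. This is a minimal plain-Python sketch, not MLX's C++ Scan code, and it assumes finite inputs; the names logcumsumexp_ref and logcumsumexp_jvp are hypothetical. With y = logcumsumexp(x), the tangent s_k = sum_{i<=k} exp(x_i - y_k) * t_i satisfies s_k = exp(y_{k-1} - y_k) * s_{k-1} + exp(x_k - y_k) * t_k. Because y_k >= y_{k-1} and y_k >= x_k, both exponents are non-positive, so the scan cannot overflow.

```python
import math


def logcumsumexp_ref(xs):
    """Reference logcumsumexp: y_k = log(sum_{i<=k} exp(x_i)), computed stably.

    Assumes finite inputs (hypothetical helper, not an MLX API).
    """
    ys = []
    acc = -math.inf
    for x in xs:
        # Stable logaddexp(acc, x): factor out the larger term.
        m = max(acc, x)
        acc = m + math.log(math.exp(acc - m) + math.exp(x - m))
        ys.append(acc)
    return ys


def logcumsumexp_jvp(xs, ts):
    """Return (y, s), where s is the forward-mode tangent of logcumsumexp at xs along ts.

    Uses s_k = exp(y_{k-1} - y_k) * s_{k-1} + exp(x_k - y_k) * t_k with y_{-1} = -inf.
    """
    ys = logcumsumexp_ref(xs)
    tangents = []
    s = 0.0
    y_prev = -math.inf
    for x, t, y in zip(xs, ts, ys):
        # Rescale the previous running sum to the new normalizer, then add the new term.
        s = math.exp(y_prev - y) * s + math.exp(x - y) * t
        tangents.append(s)
        y_prev = y
    return ys, tangents


_, s = logcumsumexp_jvp([1.0, 2.0, 3.0], [1.0, 1.0, 1.0])
print(s)  # every entry is 1 up to rounding
```

With the same inputs as the reproducer (all-ones tangents), every entry of the tangent is 1 up to rounding, because the prefix weights exp(x_i - y_k) sum to 1 for each k.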

To Reproduce

import mlx.core as mx

x = mx.array([1.0, 2.0, 3.0])
t = mx.ones_like(x)
mx.jvp(lambda z: mx.logcumsumexp(z), [x], [t])
# RuntimeError: JVP is not implemented for cumulative prod/min/max

Expected behavior

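mx.jvp of logcumsumexp should return the running softmax-weighted sum of the tangents, where the softmax at position k is taken over the prefix x_1..x_k rather than over all of x. With y = logcumsumexp(x) and tangent t:

jvp(x, t)_k = sum_{i<=k} exp(x_i - y_k) * t_i = sum_{i<=k} softmax(x_{1..k})_i * t_i

This is consistent with the existing vjp: for any tangent v and cotangent w, the adjoint identity <w, Jv> == <v, Jᵀw> should hold. For the reproducer above (t = ones), each prefix's weights sum to 1, so the expected tangent is [1, 1, 1].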
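The expected jvp and the adjoint identity can be checked with a small plain-Python reference. This is a sketch, not MLX code: the helpers jvp, vjp, and dot are hypothetical, and both products are written directly from the Jacobian of logcumsumexp, J[k][i] = exp(x_i - y_k) for i <= k and 0 otherwise (since dy_k/dx_i = exp(x_i) / sum_{j<=k} exp(x_j)), rather than taken from MLX's existing vjp implementation. Inputs are assumed finite.

```python
import math


def logcumsumexp(xs):
    # Stable running log-sum-exp: y_k = log(sum_{i<=k} exp(x_i)).
    ys, acc = [], -math.inf
    for x in xs:
        m = max(acc, x)
        acc = m + math.log(math.exp(acc - m) + math.exp(x - m))
        ys.append(acc)
    return ys


def jvp(xs, v):
    # (J v)_k = sum_{i<=k} exp(x_i - y_k) * v_i  (prefix-softmax-weighted sum)
    ys = logcumsumexp(xs)
    return [sum(math.exp(xs[i] - ys[k]) * v[i] for i in range(k + 1))
            for k in range(len(xs))]


def vjp(xs, w):
    # (J^T w)_i = sum_{k>=i} w_k * exp(x_i - y_k)
    ys = logcumsumexp(xs)
    n = len(xs)
    return [sum(w[k] * math.exp(xs[i] - ys[k]) for k in range(i, n))
            for i in range(n)]


def dot(a, b):
    return sum(p * q for p, q in zip(a, b))


x = [1.0, 2.0, 3.0]
v = [0.3, -1.2, 0.5]   # tangent
w = [1.0, 0.25, -2.0]  # cotangent
lhs = dot(w, jvp(x, v))
rhs = dot(v, vjp(x, w))
print(lhs, rhs)  # the two values agree up to floating-point rounding
```

Because both functions are built from the same Jacobian entries, <w, Jv> and <v, Jᵀw> differ only by summation order; a fixed Scan::jvp could be compared against the existing vjp in the same way.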

Desktop

  • OS: macOS
  • Version: main (0.32.0.dev)
