Describe the bug
Forward-mode autodiff (mx.jvp) through mx.logcumsumexp raises instead of returning the tangent:
RuntimeError: JVP is not implemented for cumulative prod/min/max
Scan::jvp only implements the Sum reduction; logcumsumexp (the LogAddExp scan) falls through to the generic "not implemented" branch even though its vjp is implemented. This breaks mx.jvp and forward-over-reverse compositions that go through logcumsumexp.
To Reproduce
import mlx.core as mx
x = mx.array([1.0, 2.0, 3.0])
t = mx.ones_like(x)
mx.jvp(lambda z: mx.logcumsumexp(z), [x], [t])
# RuntimeError: JVP is not implemented for cumulative prod/min/max
Expected behavior
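mx.jvp of logcumsumexp returns the running softmax-weighted sum of the tangents. With y = logcumsumexp(x), the tangent at position k is
dy_k = sum_{i<=k} exp(x_i - y_k) * t_i,
where the weights exp(x_i - y_k) are the softmax over the prefix x_1..x_k (not over the full vector), so they sum to 1 for every k.
This is consistent with the existing vjp: the adjoint identity <w, Jv> == <v, Jᵀw> should hold.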
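For reference, the expected tangent can be sketched in pure Python, independent of MLX. This is only an illustration of the formula above, assuming the softmax weights are taken over each prefix, i.e. exp(x_i - y_k) with y = logcumsumexp(x). The helper names (logcumsumexp, jvp_reference, vjp_reference) are hypothetical and are not MLX APIs.

```python
import math

def logcumsumexp(x):
    """Inclusive running log-sum-exp of a list of finite floats."""
    out, acc = [], -math.inf
    for v in x:
        # Numerically stable logaddexp(acc, v).
        m = max(acc, v)
        acc = m + math.log(math.exp(acc - m) + math.exp(v - m))
        out.append(acc)
    return out

def jvp_reference(x, t):
    """Tangent J @ t: dy_k = sum_{i<=k} exp(x_i - y_k) * t_i."""
    y = logcumsumexp(x)
    return [
        sum(math.exp(x[i] - y[k]) * t[i] for i in range(k + 1))
        for k in range(len(x))
    ]

def vjp_reference(x, w):
    """Cotangent J^T @ w: dx_i = sum_{k>=i} exp(x_i - y_k) * w_k."""
    y = logcumsumexp(x)
    n = len(x)
    return [
        sum(math.exp(x[i] - y[k]) * w[k] for k in range(i, n))
        for i in range(n)
    ]

def dot(a, b):
    return sum(p * q for p, q in zip(a, b))

x = [1.0, 2.0, 3.0]
v = [1.0, 0.3, -0.7]   # arbitrary tangent
w = [0.5, -1.0, 2.0]   # arbitrary cotangent

# Adjoint identity <w, Jv> == <v, J^T w>, up to rounding.
lhs = dot(w, jvp_reference(x, v))
rhs = dot(v, vjp_reference(x, w))
print(abs(lhs - rhs) < 1e-12)
```

For the reproduction above, where the tangent is all ones, each prefix's weights sum to 1, so the expected tangent is a vector of ones (up to rounding).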
Desktop
- OS: macOS
- Version: main (0.32.0.dev)