Skip to content

Implement jvp for cumulative logsumexp#3711

Open
obchain wants to merge 1 commit into
ml-explore:mainfrom
obchain:fix/scan-jvp-logcumsumexp
Open

Implement jvp for cumulative logsumexp#3711
obchain wants to merge 1 commit into
ml-explore:mainfrom
obchain:fix/scan-jvp-logcumsumexp

Conversation

@obchain

@obchain obchain commented Jun 17, 2026

Copy link
Copy Markdown
Contributor

Proposed changes

Fixes #3710.

Scan::jvp only implemented the Sum reduction and threw for everything else, so forward-mode differentiation through mx.logcumsumexp raised JVP is not implemented for cumulative prod/min/max. Its vjp was already implemented, so only forward mode was affected.

The jvp of logcumsumexp is the running softmax-weighted sum of the tangents:

d/dt logcumsumexp(x)_k = sum_{i<=k} softmax(x)_i * t_i

This is computed in log space for numerical stability by splitting the tangent into its positive and negative parts, mirroring the structure of the existing LogAddExp vjp. Exclusive scans leave the first element with no inputs (output -inf, locally constant), so its tangent is set to zero — this also avoids an inf - inf in the expression.

cumprod / cummax / cummin jvps are unchanged (still not implemented) and the error message stays accurate.

Before:

>>> mx.jvp(lambda z: mx.logcumsumexp(z), [mx.array([1.,2.,3.])], [mx.ones(3)])
RuntimeError: JVP is not implemented for cumulative prod/min/max

Added a test in test_autograd.py that checks the jvp against an explicit softmax-weighted reference and verifies the jvp/vjp adjoint identity for every combination of the reverse / inclusive flags across axes.

Checklist

Put an x in the boxes that apply.

  • I have read the CONTRIBUTING document
  • I have run pre-commit run --all-files to format my code / installed pre-commit prior to committing changes
  • I have added tests that prove my fix is effective or that my feature works
  • I have updated the necessary documentation (if needed)

Scan::jvp only handled the Sum reduction and threw for everything else,
so forward-mode differentiation through mx.logcumsumexp raised. The jvp
is the running softmax-weighted sum of the tangents,

    d/dt logcumsumexp(x)_k = sum_{i<=k} softmax(x)_i * t_i,

computed in log space by splitting the tangent into its positive and
negative parts, mirroring the existing vjp. Exclusive scans leave the
first element with no inputs (output -inf, locally constant), so its
tangent is set to zero, which also avoids an inf - inf there.

@aeiwz aeiwz left a comment

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Finding

[P2] Preserve complex tangent values — mlx/primitives.cpp (

mlx/mlx/primitives.cpp

Lines 4308 to 4316 in cba7502

auto log_min = array(finfo(t.dtype()).min, t.dtype());
auto log_abs_t = log(abs(t, stream()), stream());
auto log_t_positive =
where(greater(t, zero, stream()), log_abs_t, log_min, stream());
auto log_t_negative =
where(less(t, zero, stream()), log_abs_t, log_min, stream());
auto masked_scan = [&](const array& log_t) {
return exp(
)

The positive/negative split uses abs(t) and comparisons, which discards the phase of complex tangents. Since logcumsumexp supports complex arrays, JVPs such as tangent [1+1j, 2-1j]
return real magnitudes instead of the expected complex derivative. The implementation should either support complex tangents directly or explicitly reject them. Add a complex JVP test.

Numerically verified expected derivative:

[1+1j, 1.73106-0.462117j]

The current expression produces approximately:

[1.41421+0j, 2.01504+0j]

No other actionable findings.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

[BUG] jvp of mx.logcumsumexp is not implemented

2 participants