Skip to content

Actually fixing reverse AD#33

Merged
TonyZhou729 merged 2 commits into
mainfrom
reverseAD2
May 28, 2026
Merged

Actually fixing reverse AD#33
TonyZhou729 merged 2 commits into
mainfrom
reverseAD2

Conversation

@cgiovanetti

Copy link
Copy Markdown
Collaborator

Fixing #31. The VeryChord thing was a red herring, turns out the problem was with the inf padding in the hyrex outputs (which is why we were only seeing NaNs in a couple of densities that touch e.g. xe).

Here's what was happening:

  1. On the backward pass of reverse AD, the RHS of the perturbations equations get evaluated again. This already opens us up to the double-where NaN issue in JAX
  2. Sometimes, t1 in evolution_one_k ends up pushing these arrays very close to lastval, getting closer and closer for larger k. lensing=True adds a few additional larger k. This increases the risk of fast_interp needing to interpolate at an index beyond lastval.
  3. If we are unlucky with our grid so that we do end up beyond the lastval boundary, fast_interp gives an inf. This doesn't show up on forward passes since this isn't the branch of the jnp.where that actually gets evaluated, but because we're taking a gradient JAX freaks out unless both branches of where are finite.

So the fix was pretty simple, once we have the outputs from HyRex I just patch lastnum on to all of the inf values of the array so this doesn't trigger. No change to speed or accuracy, gives finite gradients.

Also got rid of the previous VeryChord tolerance fix, it doesn't actually get us anything.

@cgiovanetti cgiovanetti requested a review from TonyZhou729 May 27, 2026 21:06
@TonyZhou729 TonyZhou729 merged commit fb154e8 into main May 28, 2026
1 check passed
@cgiovanetti cgiovanetti deleted the reverseAD2 branch May 28, 2026 15:23
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants