Skip to content

Fix stale tangent reuse in repeated JVPs#3738

Open
chrismicah wants to merge 1 commit into
ml-explore:mainfrom
chrismicah:fix-jvp-tangent-map-reassign
Open

Fix stale tangent reuse in repeated JVPs#3738
chrismicah wants to merge 1 commit into
ml-explore:mainfrom
chrismicah:fix-jvp-tangent-map-reassign

Conversation

@chrismicah

Copy link
Copy Markdown
Contributor

Summary

Why

Issue #3629 still reproduces on current main after the earlier partial JVP fixes in #3633 and #3636. On 602b5359, the original ray-trace-shaped repro crashed in 9/10 trials with errors such as:

[squeeze] Cannot squeeze axis 0 with size 125 which is not equal to 1.
[reshape] Cannot reshape array of size 125 into shape (125,4).

The failure is nondeterministic because an older tangent can remain in the JVP transform map when an output id is encountered again. Assigning the output tangent instead of preserving the first inserted value keeps the map aligned with the current primitive output.

Prior related work

Those fixes cover related symptoms, but the original #3629 repro still fails on current main; this PR fixes a distinct current-main regression in the JVP transform map.

Test plan

CMAKE_BUILD_PARALLEL_LEVEL=8 python setup.py build_ext --inplace
PYTHONPATH=python python -m pytest python/tests/test_autograd.py -q -k 'repeated_jvp_with_reused_array_ids or jvp_through_bitwise_ops' --tb=short
PYTHONPATH=python python -m pytest python/tests/test_autograd.py -q --tb=short
python -m black --check python/tests/test_autograd.py
git diff --check

All passed locally.

Fixes #3629

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

[BUG] JVP non-deterministic crashes when called in a loop.

1 participant