jaxlib/gpu/rnn_kernels.cc: remove premature input_tensor_desc destruction (ROCM-21859) #765

Draft
srinivamd wants to merge 3 commits into rocm-jaxlib-v0.9.0 from fix/rnn-input-tensor-desc-v0.9.0-backport

@srinivamd

Problem

PR #726 ("Fix HIP memory leaks in RNN kernels", merged 2026-03-05) added
gpudnnDestroyTensorDescriptor(input_tensor_desc) inside the #ifdef JAX_GPU_HIP
cleanup blocks at the end of both DnnRNNForward_ and DnnRNNBackward_.

MIOpen requires input_tensor_desc to remain valid throughout the
gpudnnRNNBackwardWeights call because it uses the descriptor as part of its
execution buffer. Destroying it in the forward / workspace-sizing path before
gpudnnRNNBackwardWeights completes triggers miopenStatusUnknownError on gfx1201
(Navi48 / RDNA4).
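
The ordering constraint described above can be illustrated with a minimal, self-contained C++ sketch. Every name in it (Status, MockTensorDesc, MockBackwardWeights, MockDestroy) is a hypothetical stand-in and not part of MIOpen or jaxlib. The mock only marks the descriptor as dead instead of freeing it, so the failure is observable without undefined behavior.

```cpp
#include <cassert>  // only needed if you assert on the results

// Hypothetical stand-ins for illustration; NOT the MIOpen or jaxlib API.
enum class Status { kSuccess, kUnknownError };

struct MockTensorDesc {
  bool alive = true;  // flipped to false by MockDestroy
};

// Models a library call that reads the descriptor while it executes.
Status MockBackwardWeights(const MockTensorDesc* input_desc) {
  return (input_desc != nullptr && input_desc->alive) ? Status::kSuccess
                                                      : Status::kUnknownError;
}

// Models a destroy call. It marks the descriptor dead rather than freeing
// memory, which keeps this example well-defined.
void MockDestroy(MockTensorDesc* desc) { desc->alive = false; }

// Destroy runs before the consumer: the consumer sees a dead descriptor.
Status RunWithPrematureDestroy() {
  MockTensorDesc input_desc;
  MockDestroy(&input_desc);
  return MockBackwardWeights(&input_desc);
}

// Destroy runs after the consumer has finished: the call succeeds.
Status RunWithDeferredDestroy() {
  MockTensorDesc input_desc;
  Status status = MockBackwardWeights(&input_desc);
  MockDestroy(&input_desc);
  return status;
}
```

Under these assumptions, RunWithPrematureDestroy() returns Status::kUnknownError and RunWithDeferredDestroy() returns Status::kSuccess.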

This regressed experimental_rnn_test::test_lstm1 and test_lstm9 on gfx1201
(first seen in image 1317, built 2026-04-29, ROCm 7.13.0). The same root cause
was independently identified by PR #729 / PR #730.

Fix

Remove the two gpudnnDestroyTensorDescriptor(input_tensor_desc) lines from
the #ifdef JAX_GPU_HIP cleanup blocks in DnnRNNForward_ and DnnRNNBackward_.
input_tensor_desc lifetime is managed by MIOpen internally; the explicit destroy
here is both premature and incorrect.
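
This PR only removes the two calls. Purely as an illustration of "destroy after last use" ordering, and not as a description of what rnn_kernels.cc does, the sketch below ties destruction to scope exit with std::unique_ptr and a custom deleter. FakeDesc, FakeCreate, FakeDestroy, FakeBackwardWeights, and EventLog are hypothetical names that assume a caller-owned handle; they are not MIOpen or jaxlib APIs.

```cpp
#include <cassert>
#include <memory>
#include <string>
#include <vector>

// Hypothetical stand-ins for illustration; NOT the MIOpen or jaxlib API.
struct FakeDesc {};

// Records the order in which the fake calls happen.
std::vector<std::string>& EventLog() {
  static std::vector<std::string> log;
  return log;
}

FakeDesc* FakeCreate() {
  EventLog().push_back("create");
  return new FakeDesc();
}

void FakeDestroy(FakeDesc* desc) {
  EventLog().push_back("destroy");
  delete desc;
}

// Models a call that needs a live descriptor for its whole duration.
void FakeBackwardWeights(const FakeDesc* desc) {
  assert(desc != nullptr);
  EventLog().push_back("backward_weights");
}

// Custom deleter so unique_ptr calls FakeDestroy at scope exit.
struct FakeDescDeleter {
  void operator()(FakeDesc* desc) const { FakeDestroy(desc); }
};
using ScopedDesc = std::unique_ptr<FakeDesc, FakeDescDeleter>;

void RunBackward() {
  ScopedDesc input_desc(FakeCreate());
  FakeBackwardWeights(input_desc.get());
  // No explicit destroy here: the deleter runs when input_desc goes out of
  // scope, which is after the last use above.
}
```

After one call to RunBackward() on an empty log, the log holds "create", "backward_weights", "destroy" in that order, so the destroy cannot run ahead of the consumer.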

Tests

  • tests/experimental_rnn_test.py::RNNTest::test_lstm1
  • tests/experimental_rnn_test.py::RNNTest::test_lstm9

References

…ensor_desc destruction

PR #726 ("Fix HIP memory leaks in RNN kernels") introduced
gpudnnDestroyTensorDescriptor(input_tensor_desc) in DnnRNNForward_ and
DnnRNNBackward_ cleanup blocks. MIOpen requires input_tensor_desc to
remain valid through gpudnnRNNBackwardWeights because it uses the
descriptor as part of its execution buffer; destroying it early triggers
miopenStatusUnknownError on gfx1201 (RDNA4).

Remove the two premature destroy calls. The descriptor is stack-allocated
so it is freed when the function returns.

Fixes: ROCM-21859 (test_lstm1, test_lstm9 on gfx1201)
Regressed-by: #726
Equivalent to: #729 (targeting v0.9.0)
…or_desc destroy

The previous commit accidentally introduced an extra } after
DnnRNNForward_. Remove it and also remove the premature
gpudnnDestroyTensorDescriptor(input_tensor_desc) from DnnRNNBackward_.
…tion (ROCM-21859)

PR #726 ("Fix HIP memory leaks in RNN kernels", 2026-03-05) introduced
gpudnnDestroyTensorDescriptor(input_tensor_desc) in the cleanup blocks of
both DnnRNNForward_ and DnnRNNBackward_. MIOpen requires input_tensor_desc
to remain valid throughout gpudnnRNNBackwardWeights because it uses the
descriptor as part of its execution buffer; destroying it early triggers
miopenStatusUnknownError on gfx1201 (RDNA4, Navi48).

Remove the two premature destroy calls. input_tensor_desc is stack-allocated
so MIOpen owns no persistent reference beyond the call; it is automatically
released when the function returns.

Fixes test_lstm1 and test_lstm9 failures on gfx1201 (image 1317).
Regressed-by: #726
Equivalent-to: #729 (targeting v0.9.0)
@srinivamd marked this pull request as draft on April 30, 2026, 14:27