Fix RNNT ONNX decoding memory leaks #73

Open
Qeca wants to merge 1 commit into salute-developers:main from Qeca:codex/fix-rnnt-onnx-memory-leaks

Qeca commented May 28, 2026

What changed

  • Fix RNNT state splitting so that per-sample decoder state no longer holds views into the full batched LSTM states.
  • Align ONNX RNNT decoding with the torch path by reading max_symbols_per_step from the config, with a default of 10.
  • Return token frames from the ONNX ASR decode path for timestamp support.
  • Add focused regression tests for state copying, the symbol-limit config, token frames, and SpecScaler no-mutation behavior.
  • Include small stability fixes for long positional encodings, checkpoint loading, VAD pipeline caching, downloads, flash-attention unpadding, and decoder device lookup.
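
To illustrate the symbol-limit and token-frame items above, here is a toy greedy loop. It is a minimal sketch and not the library's decoder: greedy_decode, toy_predict, BLANK, and the plain-dict config are hypothetical names introduced for illustration. Only the max_symbols_per_step key, its default of 10, and the previous hard-coded limit of 3 come from this description.

```python
from typing import Callable, List, Mapping, Tuple

BLANK = 0                  # hypothetical blank token id
DEFAULT_MAX_SYMBOLS = 10   # default stated in the PR description


def greedy_decode(
    num_frames: int,
    predict: Callable[[int, int], int],
    cfg: Mapping[str, int],
) -> Tuple[List[int], List[int]]:
    """Toy RNNT-style greedy loop: emit up to max_symbols tokens per frame."""
    max_symbols = cfg.get("max_symbols_per_step", DEFAULT_MAX_SYMBOLS)
    tokens: List[int] = []
    token_frames: List[int] = []   # frame index of each emitted token
    for t in range(num_frames):
        emitted = 0
        while emitted < max_symbols:
            tok = predict(t, emitted)
            if tok == BLANK:
                break              # blank advances to the next frame
            tokens.append(tok)
            token_frames.append(t)
            emitted += 1
    return tokens, token_frames


def toy_predict(t: int, emitted: int) -> int:
    """Stand-in model: frame 1 wants to emit five symbols, other frames emit blank."""
    return 7 if (t == 1 and emitted < 5) else BLANK


# Limit taken from the default (10) versus an explicit limit of 3.
tokens_default, frames_default = greedy_decode(3, toy_predict, {})
tokens_limited, frames_limited = greedy_decode(3, toy_predict, {"max_symbols_per_step": 3})
```

In this toy setup the default limit lets all five symbols through on frame 1, while a limit of 3 truncates the output, which is the kind of transcript divergence the change is meant to remove. The token_frames list shows how a frame index per token can be carried out of the loop for timestamps.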

Why

The RNNT split helpers returned views. Keeping those views in dec_state could retain the full [L, B, H] LSTM buffers for each sample and cause memory growth on long audio
or larger batches. The ONNX decoder also had a hard-coded per-frame symbol limit of 3, while the torch decoder defaults to 10, so the two paths could produce different transcripts
from the same weights.
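
The view-retention effect described above can be reproduced with plain NumPy. This is a minimal sketch under stated assumptions: split_state_view and split_state_copy are hypothetical stand-ins for the split helpers, and the sizes L=2, B=8, H=320 are illustrative values, not taken from the model.

```python
import numpy as np


def split_state_view(state: np.ndarray, i: int) -> np.ndarray:
    """Hypothetical 'before' helper: slicing returns a view into the batched state."""
    return state[:, i:i + 1, :]


def split_state_copy(state: np.ndarray, i: int) -> np.ndarray:
    """Hypothetical 'after' helper: .copy() gives the sample its own small buffer."""
    return state[:, i:i + 1, :].copy()


L, B, H = 2, 8, 320  # illustrative sizes (assumption)
state = np.zeros((L, B, H), dtype=np.float32)

view = split_state_view(state, 3)
copy = split_state_copy(state, 3)

# A view keeps a reference to the whole batched array through .base,
# so the full [L, B, H] buffer stays alive as long as the view does.
view_pins_batch = view.base is state
# A copy owns its data and pins only its own [L, 1, H] buffer.
copy_owns_data = copy.base is None

batched_bytes = state.nbytes    # size of the buffer a view keeps alive
per_sample_bytes = copy.nbytes  # size of the buffer a copy keeps alive
```

With these sizes a held view pins B times more memory than a held copy, since the per-sample slice is 1/B of the batched buffer.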

@Alexander4127

Collaborator

Thank you for the report!

We tried to reproduce the issue in our environment, but so far it does not reproduce on our side.

Could you please share a bit more detail about your setup?

  • Library versions you are using
  • Whether you are using CPU or CUDA onnxruntime
  • Approximate amount of data processed
  • Typical audio lengths
  • The observed memory leak type / size

Qeca commented Jun 4, 2026

Author

Sure, here are the details from our setup.

Library versions:
Full pip freeze from the Triton Python backend container:

Jinja2==3.1.6
MarkupSafe==3.0.3
PyGObject==3.42.1
PyJWT==2.3.0
PyYAML==6.0.3
SecretStorage==3.3.1
antlr4-python3-runtime==4.9.3
blinker==1.4
cffi==2.0.0
coloredlogs==15.0.1
cryptography==3.4.8
cuda-bindings==12.9.4
cuda-pathfinder==1.5.5
dbus-python==1.2.18
distlib==0.3.9
distro==1.7.0
filelock==3.16.1
flatbuffers==25.12.19
fsspec==2026.4.0
httplib2==0.20.2
humanfriendly==10.0
hydra-core==1.3.2
importlib-metadata==4.6.4
jeepney==0.7.1
keyring==23.5.0
launchpadlib==1.10.16
lazr.restfulclient==0.14.4
lazr.uri==1.0.6
more-itertools==8.10.0
mpmath==1.3.0
networkx==3.4.2
numpy==1.26.4
nvidia-cublas-cu12==12.8.4.1
nvidia-cuda-cupti-cu12==12.8.90
nvidia-cuda-nvrtc-cu12==12.8.93
nvidia-cuda-runtime-cu12==12.8.90
nvidia-cudnn-cu12==9.10.2.21
nvidia-cufft-cu12==11.3.3.83
nvidia-cusolver-cu12==11.7.3.90
nvidia-cusparse-cu12==12.5.8.93
nvidia-cusparselt-cu12==0.7.1
nvidia-nccl-cu12==2.27.5
nvidia-nvjitlink-cu12==12.8.93
nvidia-nvshmem-cu12==3.4.5
nvidia-nvtx-cu12==12.8.90
oauthlib==3.2.0
omegaconf==2.3.0
onnxruntime-gpu==1.23.2
packaging==26.2
platformdirs==4.3.6
protobuf==7.35.0
pycparser==3.0
pyparsing==2.4.7
python-apt==2.4.0+ubuntu4
sentencepiece==0.2.1
six==1.16.0
soundfile==0.13.1
sympy==1.14.0
torch==2.10.0
torchaudio==2.10.0
tqdm==4.67.3
triton==3.6.0
typing_extensions==4.15.0
virtualenv==20.27.0
wadllib==1.3.6
zipp==1.0.0

ONNXRuntime mode:
We use CUDA ONNXRuntime in the Triton/Python backend, not CPU-only ONNXRuntime.

Approximate amount of data processed:
The issue was observed under concurrent long-audio transcription load. One soak run processed 133 transcription requests in about 10 minutes at request concurrency 4.

Typical audio lengths:
Typical files were long-form audio/video files, approximately 10-20 minutes per request. One smoke file was about 17 minutes long.

Observed memory leak type / size:
It looked like CPU RSS growth / Python object retention in the RNNT ONNX/Triton path, not CUDA VRAM growth. The retained objects appear to be numpy/tensor views of batched
RNNT decoder LSTM states.
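
One way to check whether held decoder states are pinning larger buffers is to compare their visible size with the size of the buffers they keep alive. The following is a rough diagnostic sketch, not part of the PR: owning_buffer and retained_bytes are hypothetical helpers, the shapes are illustrative, and it assumes the states are NumPy arrays whose .base chain leads back to the batched array.

```python
import numpy as np


def owning_buffer(arr: np.ndarray) -> np.ndarray:
    """Follow .base until reaching the ndarray that owns the memory."""
    while isinstance(arr.base, np.ndarray):
        arr = arr.base
    return arr


def retained_bytes(arrays) -> int:
    """Sum the sizes of the distinct owning buffers kept alive by `arrays`."""
    roots = {}
    for a in arrays:
        root = owning_buffer(a)
        roots[id(root)] = root  # de-duplicate views that share one buffer
    return sum(r.nbytes for r in roots.values())


L, B, H = 2, 8, 320  # illustrative sizes (assumption)
batches = [np.zeros((L, B, H), dtype=np.float32) for _ in range(4)]

# One per-sample state held from each batch, as a view and as a copy.
held_views = [s[:, 0:1, :] for s in batches]
held_copies = [s[:, 0:1, :].copy() for s in batches]

visible = sum(a.nbytes for a in held_views)    # what the held objects appear to use
view_retained = retained_bytes(held_views)     # what the views actually keep alive
copy_retained = retained_bytes(held_copies)    # what the copies keep alive
```

A large gap between the visible and retained figures would be consistent with the view-retention explanation. A small gap would suggest the RSS growth has another cause.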

@Alexander4127

Collaborator

Thanks for the details.

We ran a similar soak and saw RSS move from ~4.5 to ~4.7 GB; with .copy() in _split_state the difference was small and hard to separate from restart noise.

Also, since training used segments of up to ~30 s, long raw audio at inference isn’t really supported. Could you clarify how the 10-20 min files were passed in your backend (as a single clip or segmented)?
