onnx,core: fuse standard ONNX LSTM cell into one LstmEpilogue op #2294
Open
czoli1976 wants to merge 1 commit into
Profiling DTLN (an LSTM-heavy streaming denoiser) on tract showed the LSTM body spends most non-matmul time in per-gate dispatch: the importer decomposes each cell into ~15 separately-evaluated Sigmoid/Tanh/Mul/Add ops, each materialising an intermediate tensor. For streaming (batch=1, one frame per call) that scaffolding dominates the actual arithmetic.

Add a single `LstmEpilogue` op that, given the combined gate pre-activations [batch, 4*hidden] (ONNX order i,o,f,c) and c_prev, computes Ht and Ct in one pass using tract's vectorised sigmoid_f32/tanh_f32 kernels on contiguous slices. The importer emits it for the standard case (no peepholes, concrete hidden size); tract's ONNX LSTM always uses sigmoid/tanh activations, and peephole/symbolic-hidden cells fall through to the existing decomposed form. Also computes the four gate projections with one matmul per side (Xt·Wᵀ, Ht-1·Rᵀ) + slice instead of four separate matmuls.

Numerically identical to the decomposed path (same activation kernels): DTLN end-to-end output stays at 110.47 dB / Pearson 1.00000 vs the native reference. Unit test checks the cell math against a scalar reference.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Summary
When testing DTLN (an LSTM-heavy real-time speech denoiser) on tract, profiling showed that in streaming mode (batch=1, one audio frame per call) the LSTM body spends most of its non-matmul time on per-gate op dispatch, not arithmetic. tract's ONNX importer decomposes each LSTM cell into ~15 separately-evaluated `Sigmoid`/`Tanh`/`Mul`/`Add` ops, each materialising an intermediate tensor. The matmuls are already efficient (memory-bandwidth-bound), so this scaffolding is the dominant overhead for streaming LSTMs.

This PR collapses that chain into one fused `LstmEpilogue` op, emitted from the importer for the standard cell.

What changed
- `LstmEpilogue` op (`core/src/ops/lstm_cell.rs`): given the gate pre-activations `[batch, 4*hidden]` (ONNX gate order i, o, f, c) and `c_prev`, computes `Ht`/`Ct` in one pass: sigmoid over the i,o,f block and tanh over the c block via tract's vectorised `sigmoid_f32`/`tanh_f32` kernels on contiguous slices, then the cell/hidden update. Replaces the ~15-node decomposed chain.
- Importer (`onnx/src/ops/rec/lstm.rs`): emits `LstmEpilogue` for the standard cell (no peepholes, concrete hidden size); peephole / symbolic-hidden cells keep the existing decomposed path. tract's ONNX LSTM always uses sigmoid/tanh activations, so the fused op is always valid where it fires. Also computes the gate projections with one matmul per side (`Xt·Wᵀ`, `Ht-1·Rᵀ`) + slice rather than four separate matmuls.

Correctness
Same activation kernels as the decomposed path → numerically identical. DTLN end-to-end output stays at 110.47 dB / Pearson 1.00000 vs the native reference (bit-equivalent up to f32 rounding). A unit test checks the cell math against a scalar reference over a multi-row batch.
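For readers who want the cell math spelled out, here is a minimal scalar sketch of what the description says the fused path computes for a batch=1 step: one stacked projection per side, then sigmoid over the i, o, f blocks, tanh over the c block, and the cell/hidden update. It is not the code in this PR. The function names (`matvec`, `lstm_epilogue`, `lstm_step`) are hypothetical, biases are omitted for brevity, and it uses plain `f32` math from the standard library instead of tract's `sigmoid_f32`/`tanh_f32` kernels.

```rust
/// Logistic sigmoid for one f32 value.
fn sigmoid(x: f32) -> f32 {
    1.0 / (1.0 + (-x).exp())
}

/// Computes x · wᵀ for one input row; `w` is row-major [rows, x.len()].
fn matvec(w: &[f32], x: &[f32], rows: usize) -> Vec<f32> {
    let cols = x.len();
    assert_eq!(w.len(), rows * cols);
    (0..rows)
        .map(|r| w[r * cols..(r + 1) * cols].iter().zip(x).map(|(a, b)| a * b).sum::<f32>())
        .collect()
}

/// Hypothetical scalar stand-in for the fused epilogue (not tract's code).
/// `gates` is row-major [batch, 4 * hidden] in ONNX order i, o, f, c;
/// `c_prev` is [batch, hidden]. Returns (h_t, c_t), each [batch, hidden].
fn lstm_epilogue(gates: &[f32], c_prev: &[f32], hidden: usize) -> (Vec<f32>, Vec<f32>) {
    assert_eq!(gates.len(), 4 * c_prev.len());
    let mut h_t = Vec::with_capacity(c_prev.len());
    let mut c_t = Vec::with_capacity(c_prev.len());
    for (row, c_row) in gates.chunks_exact(4 * hidden).zip(c_prev.chunks_exact(hidden)) {
        for k in 0..hidden {
            let i = sigmoid(row[k]); // input gate
            let o = sigmoid(row[hidden + k]); // output gate
            let f = sigmoid(row[2 * hidden + k]); // forget gate
            let g = row[3 * hidden + k].tanh(); // candidate cell value
            let c = f * c_row[k] + i * g;
            c_t.push(c);
            h_t.push(o * c.tanh());
        }
    }
    (h_t, c_t)
}

/// One batch=1 step: one matmul per side, then the epilogue.
/// `w` is [4 * hidden, input] and `r` is [4 * hidden, hidden],
/// with the gate blocks stacked in the order i, o, f, c.
fn lstm_step(w: &[f32], r: &[f32], x: &[f32], h_prev: &[f32], c_prev: &[f32]) -> (Vec<f32>, Vec<f32>) {
    let hidden = h_prev.len();
    let xw = matvec(w, x, 4 * hidden);
    let hr = matvec(r, h_prev, 4 * hidden);
    let gates: Vec<f32> = xw.iter().zip(&hr).map(|(a, b)| a + b).collect();
    lstm_epilogue(&gates, c_prev, hidden)
}

fn main() {
    // hidden = 1, input = 2; only the c-gate row of `w` is non-zero.
    let w = [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0];
    let r = [0.0, 0.0, 0.0, 0.0];
    let (h_t, c_t) = lstm_step(&w, &r, &[1.0, 5.0], &[0.0], &[2.0]);
    println!("h_t = {:?}, c_t = {:?}", h_t, c_t);
}
```

Because each row of the stacked weight matrix is the same row a per-gate matmul would use, slicing the stacked result yields the same dot products as four separate projections in this naive form. Whether that holds bit for bit in an optimised matmul kernel depends on its accumulation order, which this sketch does not model.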
Perf gain
DTLN, M1 Pro, f32, single-thread, full pipeline (the production regime), restaurant-noise clip, min of 3 runs; lower = faster:
`main` (before): 14.71 ms/s; this PR: 13.11 ms/s (14.71 → 13.11, about 10.9% less time than `main`).

The win concentrates in streaming / small-batch inference, where per-op dispatch dominates the small per-frame matmuls. It is parity-preserving for all shapes; for large offline batches the relative gain shrinks as the matmuls dominate.
Who benefits
Any ONNX model with a standard LSTM run in streaming / small-batch mode — e.g. real-time speech enhancement (DTLN, NSNet), streaming ASR LSTM decoders (RNN-T predictors), TTS (Tacotron-2), OCR (CRNN), time-series forecasting (DeepAR).
Scope / notes
- `main`; orthogonal to in-flight perf PRs (#2274, "linalg/mmm: cache-adaptive 2D-blocking for the single-thread tile walk": matmul cache-blocking is a different, compounding axis; #2257, "core/ops/scan: reuse body state across iterations (skip per-timestep plan churn)": `Scan` iter-reuse is neutral on seq=1 streaming).
- `Scan` for streaming RNNs (a further ~−3% on DTLN, → 12.68 ms/s); and a declutter pass for LSTM cells that were unrolled at export time (some tf2onnx outputs contain no `LSTM` op, so this importer path never sees them).