Skip to content

onnx,core: fuse standard ONNX LSTM cell into one LstmEpilogue op#2294

Open
czoli1976 wants to merge 1 commit into
sonos:mainfrom
czoli1976:feature/lstm-gate-fusion
Open

onnx,core: fuse standard ONNX LSTM cell into one LstmEpilogue op#2294
czoli1976 wants to merge 1 commit into
sonos:mainfrom
czoli1976:feature/lstm-gate-fusion

Conversation

@czoli1976
Copy link
Copy Markdown
Contributor

Summary

When testing DTLN — an LSTM-heavy real-time speech denoiser — on tract, profiling showed that in streaming mode (batch=1, one audio frame per call) the LSTM body spends most of its non-matmul time on per-gate op dispatch, not arithmetic. tract's ONNX importer decomposes each LSTM cell into ~15 separately-evaluated Sigmoid/Tanh/Mul/Add ops, each materialising an intermediate tensor. The matmuls are already efficient (memory-bandwidth-bound), so this scaffolding is the dominant overhead for streaming LSTMs.

This PR collapses that chain into one fused LstmEpilogue op, emitted from the importer for the standard cell.

What changed

  • New LstmEpilogue op (core/src/ops/lstm_cell.rs): given the gate pre-activations [batch, 4*hidden] (ONNX gate order i, o, f, c) and c_prev, computes Ht/Ct in one pass — sigmoid over the i,o,f block + tanh over the c block via tract's vectorised sigmoid_f32/tanh_f32 kernels on contiguous slices, then the cell/hidden update. Replaces the ~15-node decomposed chain.
  • Importer emission (onnx/src/ops/rec/lstm.rs): emits LstmEpilogue for the standard cell (no peepholes, concrete hidden size); peephole / symbolic-hidden cells keep the existing decomposed path. tract's ONNX LSTM always uses sigmoid/tanh activations, so the fused op is always valid where it fires. Also computes the gate projections with one matmul per side (Xt·Wᵀ, Ht-1·Rᵀ) + slice rather than four separate matmuls.

Correctness

Same activation kernels as the decomposed path → numerically identical. DTLN end-to-end output stays at 110.47 dB / Pearson 1.00000 vs the native reference (bit-equivalent up to f32 rounding). A unit test checks the cell math against a scalar reference over a multi-row batch.

Perf gain

DTLN, M1 Pro, f32, single-thread, full pipeline (the production regime), restaurant-noise clip, min of 3 runs; lower = faster:

engine ms / 1 s audio ×realtime
tract main (before) 14.71 68×
tract + this PR 13.11 76×
TFLite — Ruy (TFLite default) 13.80 72×
TFLite — XNNPACK delegate 11.76 85×
  • −11% vs tract main (14.71 → 13.11).
  • Now ~5% faster than TFLite's default (Ruy) engine (13.80).
  • TFLite's XNNPACK delegate (11.76) is still ~12% ahead; that residual is matmul memory-bandwidth plus a couple of streaming-graph follow-ups (below), not this epilogue.

The win concentrates in streaming / small-batch inference, where per-op dispatch dominates the small per-frame matmuls. It is parity-preserving for all shapes; for large offline batches the relative gain shrinks as the matmuls dominate.

Who benefits

Any ONNX model with a standard LSTM run in streaming / small-batch mode — e.g. real-time speech enhancement (DTLN, NSNet), streaming ASR LSTM decoders (RNN-T predictors), TTS (Tacotron-2), OCR (CRNN), time-series forecasting (DeepAR).

Scope / notes

Profiling DTLN (an LSTM-heavy streaming denoiser) on tract showed the
LSTM body spends most non-matmul time in per-gate dispatch: the importer
decomposes each cell into ~15 separately-evaluated Sigmoid/Tanh/Mul/Add
ops, each materialising an intermediate tensor. For streaming (batch=1,
one frame per call) that scaffolding dominates the actual arithmetic.

Add a single `LstmEpilogue` op that, given the combined gate
pre-activations [batch, 4*hidden] (ONNX order i,o,f,c) and c_prev,
computes Ht and Ct in one pass using tract's vectorised sigmoid_f32/
tanh_f32 kernels on contiguous slices. The importer emits it for the
standard case (no peepholes, concrete hidden size); tract's ONNX LSTM
always uses sigmoid/tanh activations, and peephole/symbolic-hidden cells
fall through to the existing decomposed form. Also computes the four gate
projections with one matmul per side (Xt·Wᵀ, Ht-1·Rᵀ) + slice instead of
four separate matmuls.

Numerically identical to the decomposed path (same activation kernels):
DTLN end-to-end output stays at 110.47 dB / Pearson 1.00000 vs the native
reference. Unit test checks the cell math against a scalar reference.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant