onnx: auto-enable external_state to inline single-iteration Scan for streaming RNNs#2295
Draft
czoli1976 wants to merge 2 commits into
Draft
onnx: auto-enable external_state to inline single-iteration Scan for streaming RNNs#2295czoli1976 wants to merge 2 commits into
czoli1976 wants to merge 2 commits into
Conversation
When an ONNX LSTM/GRU/RNN exposes its full recurrent state both as input and output (initial_h + Y_h, plus initial_c + Y_c for LSTM), the caller manages state across calls. Set Scan::external_state in that case so the existing declutter_single_loop pass can inline a single-iteration Scan (seq_len == 1) — the streaming / autoregressive-decoder regime where the one-iteration Scan is pure orchestration overhead. Previously external_state was only reachable via the manual force_scan_external_state transform, so streaming RNNs carried a dead Scan on every call. Inlining is sound here because the body's State input is fed from the outer (caller-supplied) input each call (see issue sonos#2157). Measured on DTLN (an LSTM-heavy streaming denoiser): -8% end-to-end, output unchanged at 110.47 dB / Pearson 1.00000 vs the native reference. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Collaborator
|
So i have tried this before, but ran into a problem. It don't think it is possible to distinguish reliably a stateful model meant to be called with pulses of length 1 from a "stateless" (from tract POV) model with state managed on the outside. The difference is only manifest once the user call it with Runnable::run or ::spawn then State::run, but we need to apply the Transform before that. (Now that I think of it, this is how 0.23 where breaking DFN3, right ?) WDYT ? |
Collaborator
|
Ha, your heuristics is, C and H are outputted so the state is managed externally ? |
Contributor
Author
…t flag The previous importer heuristic set external_state whenever a GRU/LSTM node carried initial_h/Y_h (and initial_c/Y_c). That mis-fires: DFN3's GRU nodes also carry initial_h/Y_h, but their state is carried internally by tract under pulse, not by the caller — so inlining the single-iteration Scan would break it (the 0.23 regression kali flagged). Move the decision into declutter_single_loop, which has the whole graph: inline a single-iteration Scan only when every recurrent state has a last-value output that reaches a model output, i.e. the caller can observe the updated state and thread it back. Adds outlet_reaches_model_output. DTLN (state feeds a model output) still inlines, output unchanged at 110.47 dB / Pearson 1.00000. DFN3 df_dec (Y_h reaches no model output, only coefs) is not inlined. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.

Summary
Companion to #2294, from the same DTLN profiling. tract already has a
declutter_single_looppass that inlines a single-iterationScan, but it only fires whenScan::external_stateis set — which until now was reachable only via the manualforce_scan_external_statetransform. So streaming RNNs (seq_len == 1, one frame/token per call) carried a dead one-iteration Scan on every call: pure orchestration overhead (state spawn, per-iter setup, output assembly).This PR has the ONNX recurrent importer set
external_stateautomatically when the caller manages state.What changed
In
onnx/src/ops/rec/common.rs, when an LSTM/GRU/RNN exposes its full recurrent state both as input and output —initial_h+Y_h(andinitial_c+Y_cfor LSTM) — setScan::external_state = true. That is the caller-managed-state contract (streaming denoisers, autoregressive / RNN-T decoders): the caller supplies state in and reads it out every call. The existingdeclutter_single_loopthen inlines theseq_len == 1Scan, feeding the body's State input directly from the outer input each call.13 lines; no new op or pass — it just enables the existing inliner for the case it was designed for.
Correctness / soundness
external_stateonly affectsdeclutter_single_loop(the inline guard); it is inert for multi-iteration Scans, so full-sequence (seq_len > 1) runs are unchanged. Inlining a single-iteration Scan with externally-supplied state is sound — the body runs once and state flows through the model I/O (this satisfies the guard from issue #2157).Measured on DTLN end-to-end: output stays 110.47 dB / Pearson 1.00000 vs the native reference. Existing
tract-onnxtests pass.Perf gain
DTLN, M1 Pro, f32, single-thread, full pipeline, restaurant-noise clip, min of 3 runs; lower = faster:
mainmain(14.71 → 13.54), edging past TFLite's default (Ruy) engine.LstmEpiloguegate-epilogue fusion): together −14% (→ 12.68), comfortably past Ruy. The two are orthogonal — this PR removes Scan orchestration, onnx,core: fuse standard ONNX LSTM cell into one LstmEpilogue op #2294 fuses the gate epilogue.Who benefits
Every
seq_len == 1streaming RNN exported with exposed state — real-time speech enhancement (DTLN, NSNet), streaming ASR LSTM/GRU decoders (RNN-T predictors), autoregressive TTS, and similar on-device / real-time workloads.