
onnx: auto-enable external_state to inline single-iteration Scan for streaming RNNs #2295

Draft
czoli1976 wants to merge 2 commits into sonos:main from czoli1976:feature/scan-external-state

@czoli1976
Contributor

Summary

Companion to #2294, from the same DTLN profiling. tract already has a declutter_single_loop pass that inlines a single-iteration Scan, but it only fires when Scan::external_state is set — which until now was reachable only via the manual force_scan_external_state transform. So streaming RNNs (seq_len == 1, one frame/token per call) carried a dead one-iteration Scan on every call: pure orchestration overhead (state spawn, per-iter setup, output assembly).

This PR has the ONNX recurrent importer set external_state automatically when the caller manages state.

What changed

In onnx/src/ops/rec/common.rs, when an LSTM/GRU/RNN exposes its full recurrent state both as input and output — initial_h + Y_h (and initial_c + Y_c for LSTM) — set Scan::external_state = true. That is the caller-managed-state contract (streaming denoisers, autoregressive / RNN-T decoders): the caller supplies state in and reads it out every call. The existing declutter_single_loop then inlines the seq_len == 1 Scan, feeding the body's State input directly from the outer input each call.
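The importer-side condition described above can be sketched roughly as follows. This is an illustration only: `RecurrentIo` and `caller_manages_state` are hypothetical names, not the types or functions used in onnx/src/ops/rec/common.rs, and the sketch assumes the importer already knows which optional inputs and outputs are wired.

```rust
/// Which optional state inputs/outputs an ONNX recurrent node has wired up.
/// Hypothetical struct for illustration; not tract's actual importer type.
#[derive(Clone, Copy, Debug)]
struct RecurrentIo {
    is_lstm: bool,
    has_initial_h: bool,
    has_y_h: bool,
    has_initial_c: bool,
    has_y_c: bool,
}

/// True when the full recurrent state is exposed both as input and as output,
/// which is the condition this PR uses to set `external_state`.
fn caller_manages_state(io: RecurrentIo) -> bool {
    // Hidden state must be both supplied (initial_h) and returned (Y_h).
    let h_exposed = io.has_initial_h && io.has_y_h;
    // Only LSTM has a cell state; GRU and RNN pass this check trivially.
    let c_exposed = !io.is_lstm || (io.has_initial_c && io.has_y_c);
    h_exposed && c_exposed
}
```

An LSTM that takes initial_c but never outputs Y_c fails the check, since the caller would have no way to read the updated cell state back.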

13 lines; no new op or pass — it just enables the existing inliner for the case it was designed for.

Correctness / soundness

external_state only affects declutter_single_loop (the inline guard); it is inert for multi-iteration Scans, so full-sequence (seq_len > 1) runs are unchanged. Inlining a single-iteration Scan with externally-supplied state is sound — the body runs once and state flows through the model I/O (this satisfies the guard from issue #2157).
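To make the soundness argument concrete, here is a toy model of it. The `body` function is a made-up stand-in for an LSTM/GRU cell, and `scan` and `inlined` are hypothetical names rather than tract code; the point is only that a one-iteration loop with caller-supplied state computes the same value as a single call to the body.

```rust
/// Toy recurrent body: new_state = 0.5 * state + x.
/// Stands in for a real LSTM/GRU cell; purely illustrative.
fn body(state: f32, x: f32) -> f32 {
    0.5 * state + x
}

/// Scan-style loop: run the body once per element, threading the state.
fn scan(initial_state: f32, xs: &[f32]) -> f32 {
    xs.iter().fold(initial_state, |state, &x| body(state, x))
}

/// Inlined form: valid only when the sequence has exactly one element and
/// the state comes in from (and goes back out to) the caller.
fn inlined(caller_state: f32, x: f32) -> f32 {
    body(caller_state, x)
}
```

Threading the state through three `inlined` calls gives the same result as one `scan` over the three-element sequence, which is the contract a streaming caller relies on.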

Measured on DTLN end-to-end: output stays 110.47 dB / Pearson 1.00000 vs the native reference. Existing tract-onnx tests pass.

Perf gain

DTLN, M1 Pro, f32, single-thread, full pipeline, restaurant-noise clip, min of 3 runs; lower = faster:

| config | ms / 1 s audio | ×realtime |
|---|---|---|
| tract main | 14.71 | 68× |
| tract + this PR | 13.54 | 74× |
| tract + #2294 + this PR | 12.68 | 79× |
| TFLite — Ruy (TFLite default) | 13.80 | 72× |
| TFLite — XNNPACK delegate | 11.76 | 85× |
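For readers checking the numbers, the ×realtime column appears to be 1000 ms divided by the per-second processing time, rounded to the nearest integer. The helpers below are hypothetical names and only restate that arithmetic.

```rust
/// Realtime factor: how many seconds of audio are processed per second of
/// compute, given the milliseconds spent per 1 s of audio.
fn realtime_factor(ms_per_second_of_audio: f64) -> f64 {
    1000.0 / ms_per_second_of_audio
}

/// Relative change of `new_ms` versus `baseline_ms` (negative means faster).
fn relative_change(baseline_ms: f64, new_ms: f64) -> f64 {
    (new_ms - baseline_ms) / baseline_ms
}
```

With the table's values, `relative_change(14.71, 13.54)` is about -0.0795, i.e. roughly 8% faster than tract main.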

Who benefits

Every seq_len == 1 streaming RNN exported with exposed state — real-time speech enhancement (DTLN, NSNet), streaming ASR LSTM/GRU decoders (RNN-T predictors), autoregressive TTS, and similar on-device / real-time workloads.

When an ONNX LSTM/GRU/RNN exposes its full recurrent state both as input
and output (initial_h + Y_h, plus initial_c + Y_c for LSTM), the caller
manages state across calls. Set Scan::external_state in that case so the
existing declutter_single_loop pass can inline a single-iteration Scan
(seq_len == 1) — the streaming / autoregressive-decoder regime where the
one-iteration Scan is pure orchestration overhead.

Previously external_state was only reachable via the manual
force_scan_external_state transform, so streaming RNNs carried a dead Scan
on every call. Inlining is sound here because the body's State input is fed
from the outer (caller-supplied) input each call (see issue sonos#2157).

Measured on DTLN (an LSTM-heavy streaming denoiser): -8% end-to-end, output
unchanged at 110.47 dB / Pearson 1.00000 vs the native reference.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
@kali
Collaborator

kali commented May 27, 2026

So I have tried this before, but ran into a problem. I don't think it is possible to reliably distinguish a stateful model meant to be called with pulses of length 1 from a "stateless" (from tract's POV) model with state managed on the outside. The difference only becomes manifest once the user calls it with Runnable::run, or with ::spawn then State::run, but we need to apply the Transform before that. (Now that I think of it, this is how 0.23 was breaking DFN3, right?)

WDYT ?

@kali
Collaborator

kali commented May 27, 2026

Ha, so your heuristic is: C and H are output, so the state is managed externally?

@czoli1976
Contributor Author

You are right, it has a hole as it stands now.

Trying to refine it to see if I can make it more robust; let's see.

czoli1976 marked this pull request as draft May 27, 2026 09:16
…t flag

The previous importer heuristic set external_state whenever a GRU/LSTM node
carried initial_h/Y_h (and initial_c/Y_c). That mis-fires: DFN3's GRU nodes
also carry initial_h/Y_h, but their state is carried internally by tract under
pulse, not by the caller — so inlining the single-iteration Scan would break
it (the 0.23 regression kali flagged).

Move the decision into declutter_single_loop, which has the whole graph:
inline a single-iteration Scan only when every recurrent state has a
last-value output that reaches a model output, i.e. the caller can observe the
updated state and thread it back. Adds outlet_reaches_model_output.

DTLN (state feeds a model output) still inlines, output unchanged at
110.47 dB / Pearson 1.00000. DFN3 df_dec (Y_h reaches no model output, only
coefs) is not inlined.
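The graph-level check described in this commit message can be pictured as a reachability query. The sketch below is not the actual outlet_reaches_model_output implementation: `Graph`, `reaches_model_output`, and `may_inline` are hypothetical, nodes are plain integer ids, and the real pass works on tract outlets rather than on this simplified adjacency map.

```rust
use std::collections::{HashMap, HashSet, VecDeque};

/// Hypothetical minimal graph: edges go from a producer to its consumers.
struct Graph {
    successors: HashMap<usize, Vec<usize>>,
    model_outputs: HashSet<usize>,
}

/// Breadth-first search: does anything downstream of `start` (or `start`
/// itself) count as a model output?
fn reaches_model_output(graph: &Graph, start: usize) -> bool {
    let mut seen = HashSet::from([start]);
    let mut queue = VecDeque::from([start]);
    while let Some(node) = queue.pop_front() {
        if graph.model_outputs.contains(&node) {
            return true;
        }
        for &next in graph.successors.get(&node).into_iter().flatten() {
            // `insert` returns false for already-visited nodes.
            if seen.insert(next) {
                queue.push_back(next);
            }
        }
    }
    false
}

/// Inline only if every recurrent state's last-value outlet is observable
/// by the caller, i.e. reaches some model output.
fn may_inline(graph: &Graph, state_outlets: &[usize]) -> bool {
    state_outlets.iter().all(|&outlet| reaches_model_output(graph, outlet))
}
```

In this toy version, a state outlet that only feeds internal nodes blocks inlining, while one that flows to a model output allows it.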

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>