Commit 5be478b
Update on "[Executorch] Add non-flash SDPA for decode"
Add a cpu_sdpa template function in op_sdpa_impl.h that provides a
simpler SDPA implementation using standard GEMM (no tiling). This is
useful as a baseline and for cases where flash attention is not optimal.
The implementation uses a single SeqDim parameter for all tensors and
supports causal masking, attention masks, GQA, and multi-threading.
During decode (seq_len == 1), the tiled flash attention implementation
has unnecessary overhead from its blocking/tiling logic. The simpler
unfused SDPA path using direct GEMM is more efficient for single-query
attention, yielding a ~25-30% decode throughput improvement on S25
(41 -> 53 tok/s for a 1.4B-parameter model).
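For the single-query case, the unfused path reduces to softmax(q·Kᵀ/√d)·V computed with plain dot products. A minimal standalone sketch of that idea (hypothetical code for illustration, not the actual op_sdpa_impl.h template; no masking, GQA, or threading):

```cpp
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <vector>

// Hypothetical sketch of single-query (seq_len == 1) unfused SDPA:
// out = softmax(q @ K^T / sqrt(d)) @ V, computed with straight
// dot products instead of the tiled flash-attention kernel.
std::vector<float> sdpa_decode(
    const std::vector<float>& q,                 // [d]
    const std::vector<std::vector<float>>& K,    // [kv_len][d]
    const std::vector<std::vector<float>>& V) {  // [kv_len][d]
  const size_t kv_len = K.size();
  const size_t d = q.size();
  const float scale = 1.0f / std::sqrt(static_cast<float>(d));

  // 1) Scaled scores: s[j] = (q . K[j]) * scale
  std::vector<float> s(kv_len);
  float max_s = -INFINITY;
  for (size_t j = 0; j < kv_len; ++j) {
    float dot = 0.0f;
    for (size_t i = 0; i < d; ++i) dot += q[i] * K[j][i];
    s[j] = dot * scale;
    max_s = std::max(max_s, s[j]);
  }

  // 2) Numerically stable softmax over the full KV length.
  float denom = 0.0f;
  for (size_t j = 0; j < kv_len; ++j) {
    s[j] = std::exp(s[j] - max_s);
    denom += s[j];
  }

  // 3) Weighted sum of value rows: out = sum_j softmax(s)[j] * V[j]
  std::vector<float> out(d, 0.0f);
  for (size_t j = 0; j < kv_len; ++j) {
    const float w = s[j] / denom;
    for (size_t i = 0; i < d; ++i) out[i] += w * V[j][i];
  }
  return out;
}
```

Because there is only one query row, there is no need to block the KV cache into tiles or carry running softmax statistics, which is where the flash-attention overhead comes from at seq_len == 1.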
This makes cpu_sdpa always available (previously gated behind
ET_USE_UNFUSED_SDPA) and dispatches to it when seq_len == 1 and
inputs are not quantized. Prefill continues to use flash attention.
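The dispatch condition described above can be sketched as follows (illustrative names only, not the actual ExecuTorch symbols):

```cpp
#include <cstdint>

// Hypothetical sketch of the kernel choice described in the commit
// message; the enum and function names are illustrative, not the
// real ExecuTorch identifiers.
enum class SdpaKernel { Flash, UnfusedGemm };

// Decode (seq_len == 1) on unquantized inputs takes the simpler
// unfused GEMM path; prefill keeps using tiled flash attention.
SdpaKernel select_sdpa_kernel(int64_t seq_len, bool is_quantized) {
  if (seq_len == 1 && !is_quantized) {
    return SdpaKernel::UnfusedGemm;
  }
  return SdpaKernel::Flash;
}
```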
Differential Revision: [D96044318](https://our.internmc.facebook.com/intern/diff/D96044318/)
[ghstack-poisoned]

1 file changed: 4 additions & 4 deletions
(Diff content not captured in this extract: two hunks, around lines 138-149 and 203-209 of the changed file.)