
Commit 5be478b

Update on "[Executorch] Add non-flash SDPA for decode"
Add a cpu_sdpa template function in op_sdpa_impl.h that provides a simpler SDPA implementation using standard GEMM (no tiling). This is useful as a baseline and for cases where flash attention is not optimal. The implementation uses a single SeqDim parameter for all tensors and supports causal masking, attention masks, GQA, and multi-threading.

During decode (seq_len == 1), the tiled flash attention implementation has unnecessary overhead from its blocking/tiling logic. The simpler unfused SDPA path using direct GEMM is more efficient for single-query attention, yielding a ~25-30% decode throughput improvement on S25 (41 -> 53 tok/s for a 1.4B parameter model).

This makes cpu_sdpa always available (it was previously gated behind ET_USE_UNFUSED_SDPA) and dispatches to it when seq_len == 1 and inputs are not quantized. Prefill continues to use flash attention.

Differential Revision: [D96044318](https://our.internmc.facebook.com/intern/diff/D96044318/)

[ghstack-poisoned]
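The decode-time unfused path described above reduces to single-query attention: one GEMV of the query against the cached keys, a softmax over the scores, and a probability-weighted sum of the cached values. Below is a minimal C++ sketch of that computation; the function name, buffer layout, and the omission of masking, GQA, and multi-threading are illustrative assumptions, not the actual cpu_sdpa API from op_sdpa_impl.h.

```cpp
#include <algorithm>
#include <cmath>
#include <limits>
#include <vector>

// Illustrative single-query (seq_len == 1) SDPA without flash-attention
// tiling: scores = q . K^T / sqrt(d), softmax, out = p . V.
// Layout and name are assumptions for this sketch, not the real kernel.
std::vector<float> sdpa_decode_single_query(
    const std::vector<float>& q,               // [head_dim]
    const std::vector<std::vector<float>>& K,  // [kv_len][head_dim]
    const std::vector<std::vector<float>>& V,  // [kv_len][head_dim]
    int head_dim) {
  const int kv_len = static_cast<int>(K.size());
  const float scale = 1.0f / std::sqrt(static_cast<float>(head_dim));

  // 1) Scores: one GEMV of the query against every cached key.
  std::vector<float> scores(kv_len);
  float max_score = -std::numeric_limits<float>::infinity();
  for (int t = 0; t < kv_len; ++t) {
    float dot = 0.0f;
    for (int d = 0; d < head_dim; ++d) {
      dot += q[d] * K[t][d];
    }
    scores[t] = dot * scale;
    max_score = std::max(max_score, scores[t]);
  }

  // 2) Numerically stable softmax over the scores.
  float denom = 0.0f;
  for (int t = 0; t < kv_len; ++t) {
    scores[t] = std::exp(scores[t] - max_score);
    denom += scores[t];
  }

  // 3) Output: probability-weighted sum of the cached values.
  std::vector<float> out(head_dim, 0.0f);
  for (int t = 0; t < kv_len; ++t) {
    const float p = scores[t] / denom;
    for (int d = 0; d < head_dim; ++d) {
      out[d] += p * V[t][d];
    }
  }
  return out;
}
```

With a single query row there is nothing for flash attention's block/tile bookkeeping to amortize, which is why a direct GEMV-softmax-GEMV pass like this can come out ahead at decode time.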
2 parents: 056d9f3 + 4942002

1 file changed

.ci/scripts/test_lora.sh (4 additions & 4 deletions)
```diff
@@ -138,12 +138,12 @@ EXPECTED_QUANT_PREFIX="<|im_start|>user Calculate 15% of 80?<|im_end|><|im_start
 Okay, so I need to calculate 15% of 80."
 EXPECTED_QUANT_LORA_PREFIX="
 <|im_start|>user Calculate 15% of 80?<|im_end|><|im_start|>assistant
-To calculate 15% of 80, we can multiply 80 by 15/100.
-80 * 15/100 = 12.
-So, 15% of 80 is 12.
+To calculate 15% of 80, we can multiply 80 by 15/100 and then simplify the fraction.
+So, 15% of 80 is equal to (80 * 15) / 100 = 1200 / 100 = 12.
 #### 12
 The answer is: 12<|im_end|>"
 
+
 # Export Quantized PTE, PTD file, no LoRA.
 # override base.lora_config=null to avoid creating a lora model
 # and loading lora weights.
@@ -203,7 +203,7 @@ fi
 NOW=$(date +"%H:%M:%S")
 echo "Test 4: Quantized, program-data separation lora. Starting to run llama runner at ${NOW}"
 # shellcheck source=/dev/null
-cmake-out/examples/models/llama/llama_main --model_path=qwen_lora_math_q.pte --data_paths="qwen_foundation_q.ptd,qwen_lora_math_q.ptd" --prompt="${PROMPT}" ${RUNTIME_ARGS} > result.txt
+cmake-out/examples/models/llama/llama_main --model_path=qwen_lora_math_q.pte --data_paths="qwen_foundation_q.ptd,qwen_lora_math_q.ptd" --prompt="${PROMPT}" ${RUNTIME_ARGS} --seq_len=104 > result.txt
 NOW=$(date +"%H:%M:%S")
 echo "Finished at ${NOW}"
 
```