Commit bd8b985
Update on "Use unfused SDPA for short sequences (q_len <= 128 or kv_len <= 128)"
ATT

Differential Revision: [D96044308](https://our.internmc.facebook.com/intern/diff/D96044308/)

[ghstack-poisoned]
2 parents 7914266 + cb3d6f5 commit bd8b985

49 files changed

Lines changed: 1218 additions & 235 deletions
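The dispatch rule named in the commit title (use unfused SDPA when q_len <= 128 or kv_len <= 128) can be sketched as below. This is an illustrative reconstruction, not code from the diff: the function names, the threshold constant, and the pure-Python single-head attention are all assumptions.

```python
import math

# Threshold taken from the commit title: q_len <= 128 or kv_len <= 128.
SHORT_SEQ_THRESHOLD = 128

def use_unfused_sdpa(q_len: int, kv_len: int) -> bool:
    # Short sequences take the step-by-step (unfused) path, where the
    # overhead of a fused kernel tends to outweigh its benefit.
    return q_len <= SHORT_SEQ_THRESHOLD or kv_len <= SHORT_SEQ_THRESHOLD

def unfused_sdpa(q, k, v):
    """Unfused scaled dot-product attention on lists of row vectors:
    explicit matmul -> softmax -> matmul, single head, no masking."""
    d = len(q[0])
    scale = 1.0 / math.sqrt(d)
    out = []
    for qi in q:
        # Scaled dot products of this query against every key.
        scores = [scale * sum(a * b for a, b in zip(qi, kj)) for kj in k]
        # Numerically stable softmax over the key axis.
        m = max(scores)
        exps = [math.exp(s - m) for s in scores]
        z = sum(exps)
        # Weighted sum of the value rows.
        out.append([
            sum(e / z * vj[t] for e, vj in zip(exps, v))
            for t in range(len(v[0]))
        ])
    return out
```

With identical keys the softmax weights are uniform, so the output is the mean of the value rows, which makes the sketch easy to sanity-check by hand.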


.ci/scripts/test_lora.sh

Lines changed: 4 additions & 4 deletions
@@ -139,12 +139,12 @@ EXPECTED_QUANT_PREFIX="<|im_start|>user Calculate 15% of 80?<|im_end|><|im_start
 Okay, so I need to calculate 15% of 80."
 EXPECTED_QUANT_LORA_PREFIX="
 <|im_start|>user Calculate 15% of 80?<|im_end|><|im_start|>assistant
-To calculate 15% of 80, we can multiply 80 by 15/100.
-80 * 15/100 = 12.
-So, 15% of 80 is 12.
+To calculate 15% of 80, we can multiply 80 by 15/100 and then simplify the fraction.
+So, 15% of 80 is equal to (80 * 15) / 100 = 1200 / 100 = 12.
 #### 12
 The answer is: 12<|im_end|>"
+
 # Export Quantized PTE, PTD file, no LoRA.
 # override base.lora_config=null to avoid creating a lora model
 # and loading lora weights.
@@ -204,7 +204,7 @@ fi
 NOW=$(date +"%H:%M:%S")
 echo "Test 4: Quantized, program-data separation lora. Starting to run llama runner at ${NOW}"
 # shellcheck source=/dev/null
-cmake-out/examples/models/llama/llama_main --model_path=qwen_lora_math_q.pte --data_paths="qwen_foundation_q.ptd,qwen_lora_math_q.ptd" --prompt="${PROMPT}" ${RUNTIME_ARGS} > result.txt
+cmake-out/examples/models/llama/llama_main --model_path=qwen_lora_math_q.pte --data_paths="qwen_foundation_q.ptd,qwen_lora_math_q.ptd" --prompt="${PROMPT}" ${RUNTIME_ARGS} --seq_len=104 > result.txt
 NOW=$(date +"%H:%M:%S")
 echo "Finished at ${NOW}"

.github/workflows/_test_backend.yml

Lines changed: 6 additions & 1 deletion
@@ -36,6 +36,11 @@ on:
       required: false
       type: string
       default: linux.4xlarge.memory
+    docker-image:
+      description: 'Docker image for Linux jobs'
+      required: false
+      type: string
+      default: ci-image:executorch-ubuntu-22.04-clang12

 jobs:
   test-backend-linux:
@@ -50,7 +55,7 @@ jobs:
     with:
       ref: ${{ inputs.ref }}
       runner: ${{ inputs.runner-linux }}
-      docker-image: ci-image:executorch-ubuntu-22.04-clang12
+      docker-image: ${{ inputs.docker-image }}
       submodules: recursive
       timeout: ${{ inputs.timeout }}
       upload-artifact: test-report-${{ matrix.flow }}-${{ matrix.suite }}

.github/workflows/build-wheels-aarch64-linux.yml

Lines changed: 0 additions & 2 deletions
@@ -9,8 +9,6 @@ on:
       - examples/**/*
       - pyproject.toml
       - setup.py
-  tags:
-    - ciflow/binaries/*
   push:
     branches:
       - nightly

.github/workflows/build-wheels-linux.yml

Lines changed: 0 additions & 2 deletions
@@ -9,8 +9,6 @@ on:
       - examples/**/*
       - pyproject.toml
       - setup.py
-  tags:
-    - ciflow/binaries/*
   push:
     branches:
       - nightly

.github/workflows/build-wheels-macos.yml

Lines changed: 0 additions & 2 deletions
@@ -9,8 +9,6 @@ on:
       - examples/**/*
       - pyproject.toml
       - setup.py
-  tags:
-    - ciflow/binaries/*
   push:
     branches:
       - nightly

.github/workflows/build-wheels-windows.yml

Lines changed: 0 additions & 2 deletions
@@ -8,8 +8,6 @@ on:
       - examples/**/*
       - pyproject.toml
       - setup.py
-  tags:
-    - ciflow/binaries/*
   push:
     branches:
       - nightly

.github/workflows/test-backend-arm.yml

Lines changed: 1 addition & 0 deletions
@@ -28,3 +28,4 @@ jobs:
     ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }}
     timeout: 120
     run-linux: true
+    docker-image: ci-image:executorch-ubuntu-22.04-arm-sdk

backends/arm/_passes/decompose_meandim_pass.py

Lines changed: 7 additions & 70 deletions
@@ -16,7 +16,6 @@
 )
 from executorch.backends.arm._passes.size_adjust_input_pass import SizeAdjustInputPass
 from executorch.backends.arm.constants import DQ_OPS, Q_OPS
-from executorch.exir.backend.utils import WhyNoPartitionReporter
 from executorch.exir.dialects._ops import ops as exir_ops
 from executorch.exir.pass_base import ExportPass

@@ -51,14 +50,6 @@ def get_dynamic_meandim_decomposition(op) -> tuple:
     raise RuntimeError(f"Can't get meandim decomposition for op {op}")


-def get_avgpool(op):
-    if op in (exir_ops.edge.aten.mean.dim, exir_ops.edge.aten.mean.default):
-        return exir_ops.edge.aten.avg_pool2d.default
-    if op in (torch.ops.aten.mean.dim, torch.ops.aten.mean.default):
-        return torch.ops.aten.avg_pool2d.default
-    raise RuntimeError(f"Can't get meandim decomposition for op {op}")
-
-
 def get_view(op):
     if op in (exir_ops.edge.aten.mean.dim, exir_ops.edge.aten.mean.default):
         return exir_ops.edge.aten.view_copy.default

@@ -79,23 +70,21 @@ def get_quantization(op):


 class DecomposeMeanDimPass(ArmPass):
-    """Decomposes a meandim into avg_pool and/or sum + mul (1/N).
-
-    ::
+    """Decomposes a meandim into sum + mul (1/N).

-        h, w -> avg_pool
-        n, c -> sum + mul(1/N)
+    Each reduction dimension is handled via REDUCE_SUM followed by
+    multiplication by 1/N, which works on any axis without layout
+    constraints (unlike AVG_POOL2D which only pools over spatial H×W).

     For rank < 4, the input is reshaped to 4D by padding with dim=1 from the
     left.

     Example:
         x = mean_dim(x, (0,2), keepdim=False) # x = (c,h,w)
     Becomes:
-        x = view_copy.default(x, new_shape=(1,c,h,w)) # Reshape to work with avg_pool
-        x = avg_pool2d.default(x, kernel=(1,w), stride=(1,1)) # Reduce w with avg_pool
-        x = sum.dim_IntList(x, dim=1, keepdims=True) # Reduce c with sum
-        x = mul.Tensor(x, 1/c) # Divide by number of channels to get mean
+        x = view_copy.default(x, new_shape=(1,c,h,w)) # Reshape to 4D
+        x = sum.dim_IntList(x, dim=(1,3), keepdims=True) # Reduce c,w with sum
+        x = mul.Tensor(x, 1/(c*w)) # Divide by number of elements to get mean
         x = view_copy.default(x, new_shape=(h)) # Squeeze dims since keepdims = False

     """

@@ -110,14 +99,6 @@ def __init__(self, graph_module, tosa_spec, *args, **kwargs):
         super().__init__(*args, **kwargs)
         self._graph_module = graph_module
         self._tosa_spec = tosa_spec
-        # Lazy import to avoid circular dependency with operator_support
-        from executorch.backends.arm.operator_support.pool_2d_support import (
-            AvgPool2dSupported,
-        )
-
-        self._avg_pool_checker = AvgPool2dSupported(
-            self._tosa_spec, WhyNoPartitionReporter()
-        )

     def call_operator(self, op, args, kwargs, meta, updated=False):
         if op not in (

@@ -168,12 +149,6 @@ def call_operator(self, op, args, kwargs, meta, updated=False):
         x = super().call_operator(view_op, (x, new_shape), {}, meta, True)
         x = self._maybe_insert_q_dq_after(x, meta)

-        # Reduce (h,w) dims by avg pool if possible
-        if not has_symbolic_reduce_dim:
-            x, dims_to_reduce = self._reduce_by_average_pool(
-                op, x, dims_to_reduce, meta
-            )
-
         # Reshape back to 5D if necessary
         if len(input_shape) > 4:
             original_dims = input_shape[:-3]

@@ -259,44 +234,6 @@ def _reduce_by_sum(self, op, input_node, dims, meta):

         return super().call_operator(mul_op, (sum, divisor), {}, meta, True)

-    def _reduce_by_average_pool(self, op, input_node, dims, meta):
-        dims_to_reduce_by_avgpool = [dim for dim in dims if dim >= 2]
-        if len(dims_to_reduce_by_avgpool) == 0:
-            return input_node, dims
-
-        dims_to_reduce_by_sum = [dim for dim in dims if dim < 2]
-
-        avgpool_op = get_avgpool(op)
-        input_shape = input_node.data.size()
-
-        stride = [1, 1]
-        if dims_to_reduce_by_avgpool in ([2, 3], [3, 2]):
-            kernel_size = [input_shape[2], input_shape[3]]
-        elif dims_to_reduce_by_avgpool == [3]:
-            kernel_size = [1, input_shape[3]]
-        elif dims_to_reduce_by_avgpool == [2]:
-            kernel_size = [input_shape[2], 1]
-        else:
-            raise RuntimeError(
-                f"Bad dims {dims_to_reduce_by_avgpool} for {op} decomposition of mean_dim."
-            )
-
-        args = (input_node, kernel_size, stride)
-
-        avg_pool_node = self._graph_module.graph.create_node(
-            "call_function", avgpool_op, args
-        )
-        is_supported = self._avg_pool_checker.is_node_tosa_supported(
-            avg_pool_node, self._tosa_spec
-        )
-
-        if is_supported:
-            out = super().call_operator(avgpool_op, args, {}, meta, True)
-            out = self._maybe_insert_q_dq_after(out, meta)
-            return out, dims_to_reduce_by_sum
-
-        return input_node, dims
-
     def _maybe_insert_q_dq_after(self, op, meta):
         """If the input node of op is a dequant node, insert a q-dq pair after
         op with identical quantization parameters.
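The rewritten docstring's claim — that REDUCE_SUM followed by mul(1/N) reproduces mean over arbitrary dims — can be checked numerically. A minimal sketch using NumPy as a stand-in for the aten ops (view_copy → reshape, sum.dim_IntList → sum, mul.Tensor → *); the shapes follow the docstring example, mean over dims (0, 2) of a (c, h, w) tensor.

```python
import numpy as np

c, h, w = 3, 4, 5
x = np.arange(c * h * w, dtype=np.float64).reshape(c, h, w)

# Reference: the op being decomposed.
expected = x.mean(axis=(0, 2))              # shape (h,)

# The decomposition from the docstring:
y = x.reshape(1, c, h, w)                   # pad rank to 4 from the left
y = y.sum(axis=(1, 3), keepdims=True)       # reduce c and w with sum
y = y * (1.0 / (c * w))                     # divide by number of reduced elements
y = y.reshape(h)                            # squeeze dims since keepdim=False

assert np.allclose(y, expected)
```

Because dims 0 and 2 of the (c, h, w) input map to dims 1 and 3 after padding to 4D, the single sum-and-scale covers both reduction axes at once — the property the pass now relies on instead of routing spatial dims through avg_pool2d.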

backends/arm/_passes/fold_qdq_with_annotated_qparams_pass.py

Lines changed: 63 additions & 35 deletions
@@ -40,6 +40,19 @@ def _get_special_dtype(qspec: QuantArgs) -> TosaSpecialDtype | None:
     return None


+def _merge_qparams(qspec_1: QuantArgs, qspec_2: QuantArgs) -> QuantArgs:
+    """Merge two QuantArgs when inputs are quantized differently.
+
+    Requires same dtype; picks the first's parameters by default.
+    """
+    if qspec_1.dtype != qspec_2.dtype:
+        raise RuntimeError(
+            f"Cannot merge qparams of different dtypes: {qspec_1.dtype} vs {qspec_2.dtype}"
+        )
+    return qspec_1
+
+
 def get_input_qparams(node: Node) -> dict[int, QuantArgs]:
     """Get the input quantization parameters from a node, set by the
     'FoldAndAnnotateQParamsPass'.
@@ -121,57 +134,72 @@ def __init__(
         super().__init__(*args, **kwargs)
         self.exported_program = exported_program

-    def fold_and_annotate_arg(
-        self, graph_module: GraphModule, node: Node, arg_list: list[Node], i: int
-    ) -> None:
-        input_qparams = None
-        nodes_to_remove = set()
+    def _extract_input_params(
+        self, arg_list: list[Node]
+    ) -> tuple[Optional[QuantArgs], set[Node]]:
+        input_qparams: Optional[QuantArgs] = None
+        nodes_to_remove: set[Node] = set()
         for arg in arg_list:
             if not isinstance(arg, Node):
-                return
-
-            arg_quant_params = None
+                return None, set()
+            arg_quant: Optional[QuantArgs] = None
             if arg.target in DQ_OPS:
                 args = arg.args
                 scales = args[1]
                 if (
-                    isinstance(args[1], Node)
+                    isinstance(scales, Node)
                     and self.exported_program is not None
-                    and is_param_node(self.exported_program, args[1])
+                    and is_param_node(self.exported_program, scales)
                 ):
-                    scales = get_param_tensor(self.exported_program, args[1])
+                    scales = get_param_tensor(self.exported_program, scales)
                 zps = args[2]
                 if (
-                    isinstance(args[2], Node)
+                    isinstance(zps, Node)
                     and self.exported_program is not None
-                    and is_param_node(self.exported_program, args[2])
+                    and is_param_node(self.exported_program, zps)
                 ):
-                    zps = get_param_tensor(self.exported_program, args[2])
-                arg_quant_params = QuantArgs.from_operator(
+                    zps = get_param_tensor(self.exported_program, zps)
+                arg_quant = QuantArgs.from_operator(
                     arg.target, (args[0], scales, zps, *args[3:])
                 )
-                # add arg to nodes_to_remove to fold the dq-node
                 nodes_to_remove.add(arg)
-            if input_qparams is not None and input_qparams != arg_quant_params:
-                # Two args are quantized differently
-                raise RuntimeError("Input qparams do not match")
-            input_qparams = arg_quant_params
-        if input_qparams is not None:
-            node.meta["input_qparams"][i] = input_qparams
-            for n in nodes_to_remove:
-                if n.target not in DQ_OPS:
-                    raise RuntimeError(
-                        f"Expected one of {DQ_OPS} dq_op, got {n.target}"
-                    )
+            if arg_quant is not None:
+                if input_qparams is None:
+                    input_qparams = arg_quant
+                elif input_qparams != arg_quant:
+                    input_qparams = _merge_qparams(input_qparams, arg_quant)
+        return input_qparams, nodes_to_remove
+
+    def _annotate_input_params(
+        self,
+        graph_module: GraphModule,
+        node: Node,
+        index: int,
+        input_qparams: QuantArgs,
+        nodes_to_remove: set[Node],
+    ) -> None:
+        node.meta["input_qparams"][index] = input_qparams
+
+        for dq in nodes_to_remove:
+            if dq.target not in DQ_OPS:
+                raise RuntimeError(f"Expected one of {DQ_OPS} dq_op, got {dq.target}")
+            node.replace_input_with(dq, cast(Node, dq.args[0]))
+            if not dq.users:
+                graph_module.graph.erase_node(dq)
+
+        special = _get_special_dtype(input_qparams)
+        if special:
+            node.all_input_nodes[index].meta[TosaSpecialDtype.meta_key()] = special

-                node.replace_input_with(n, cast(Node, n.args[0]))
-                if len(n.users) == 0:
-                    graph_module.graph.erase_node(n)
-            special_dtype = _get_special_dtype(input_qparams)
-            if special_dtype:
-                node.all_input_nodes[i].meta[
-                    TosaSpecialDtype.meta_key()
-                ] = special_dtype
+    def fold_and_annotate_arg(
+        self, graph_module: GraphModule, node: Node, arg_list: list[Node], i: int
+    ) -> None:
+        input_qparams, nodes_to_remove = self._extract_input_params(arg_list)
+        if input_qparams is None:
+            return
+        self._annotate_input_params(
+            graph_module, node, i, input_qparams, nodes_to_remove
+        )

     def _handle_control_flow_node(self, node: Node, graph_module: GraphModule):
         """Fold outmost quant nodes inside submodule.
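The behavioral change introduced by `_merge_qparams` — same-dtype mismatches now merge (first argument's parameters win) instead of raising, while a dtype mismatch still raises — can be sketched with a stand-in class. `FakeQuantArgs` and its fields are illustrative, not the real `QuantArgs`.

```python
from dataclasses import dataclass

@dataclass(frozen=True)
class FakeQuantArgs:
    # Illustrative stand-in for QuantArgs; the real class carries more fields.
    scale: float
    zero_point: int
    dtype: str

def merge_qparams(a: FakeQuantArgs, b: FakeQuantArgs) -> FakeQuantArgs:
    # Mirrors _merge_qparams in the diff: dtype mismatch still raises,
    # otherwise the first argument's parameters are kept.
    if a.dtype != b.dtype:
        raise RuntimeError(
            f"Cannot merge qparams of different dtypes: {a.dtype} vs {b.dtype}"
        )
    return a

merged = merge_qparams(
    FakeQuantArgs(0.1, 0, "int8"), FakeQuantArgs(0.2, 5, "int8")
)
assert merged.scale == 0.1  # first input's parameters win
```

Note the trade-off this encodes: differently quantized same-dtype inputs no longer abort the fold, but the second input's scale and zero point are silently discarded.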

backends/arm/_passes/normalize_while_initial_args_pass.py

Lines changed: 3 additions & 1 deletion
@@ -1,4 +1,4 @@
-# Copyright 2025 Arm Limited and/or its affiliates.
+# Copyright 2025-2026 Arm Limited and/or its affiliates.
 #
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
@@ -82,6 +82,8 @@ def _normalize_node(self, graph_module: GraphModule, node: Node) -> bool:
         new_carried = tuple(carried_inputs + additional_inputs)
         node.update_arg(2, new_carried)
         node.update_arg(3, ())
+        # annotate node so later keying of captured vs loop-carried args is possible
+        node.meta["additional_inputs"] = additional_inputs

         body_module_name = str(cast(Node, node.args[1]).target)
         body_module = cast(GraphModule, graph_module.get_submodule(body_module_name))  # type: ignore
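The stashed `node.meta["additional_inputs"]` makes the captured/carried split recoverable downstream: the pass appends captured inputs to the carried tuple, and a later consumer can peel them apart again by keying on the metadata. A toy sketch with plain tuples standing in for fx nodes; all names here are assumptions.

```python
# Simulate what the pass does to a while node's args.
meta = {}
carried_inputs = ("x", "counter")            # genuine loop-carried values
additional_inputs = ("weight", "bias")       # captured from the outer scope

new_carried = tuple(carried_inputs) + tuple(additional_inputs)
meta["additional_inputs"] = additional_inputs  # what the pass now records

# A later pass recovers the original partition from the metadata.
n_captured = len(meta["additional_inputs"])
loop_carried = new_carried[: len(new_carried) - n_captured]
captured = new_carried[len(new_carried) - n_captured:]

assert loop_carried == carried_inputs
assert captured == additional_inputs
```

Without the annotation, the flattened tuple alone gives no way to tell where the carried values end and the captured ones begin.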
