Skip to content

Commit 78b9c8c

Browse files
authored
Arm backend: Fix index_put broadcasted indices in Arm pass (#17247)
- Compute W from broadcasted index shapes in RewriteIndexPutPas - Add broadcast_indices_mismatch test for index_put broadcasting Change-Id: I1a85889754bc46807bbc3fc41846ae89ac48abd1 Signed-off-by: Rob Elliott <Robert.Elliott@arm.com>
1 parent 5cba889 commit 78b9c8c

2 files changed

Lines changed: 31 additions & 1 deletion

File tree

backends/arm/_passes/rewrite_index_put_pass.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66
import math
77
from typing import Any, Iterable, List, Sequence, Set, Type
88

9+
import torch
10+
911
from executorch.backends.arm._passes import ArmPass
1012
from executorch.backends.arm._passes.fuse_view_copy_transform_pass import (
1113
FuseViewCopyTransformPass,
@@ -189,9 +191,17 @@ def call_operator(self, op, args, kwargs, meta):
189191
processed_indices = self._expand_none_indices(
190192
source_shape, indices, plain_meta, full_op
191193
)
194+
index_shapes = [tuple(idx.data.shape) for idx in processed_indices]
195+
try:
196+
broadcast_shape = torch.broadcast_shapes(*index_shapes)
197+
except Exception as exc:
198+
raise RuntimeError(
199+
"RewriteIndexPutPass: failed to broadcast index shapes %s: %s"
200+
% (index_shapes, exc)
201+
) from exc
192202

193203
N, K, W, C = calculate_tosa_values(
194-
list(processed_indices[0].data.shape),
204+
list(broadcast_shape),
195205
[idx.node for idx in processed_indices],
196206
source_shape,
197207
)
@@ -204,6 +214,13 @@ def call_operator(self, op, args, kwargs, meta):
204214
full_op,
205215
plain_meta,
206216
)
217+
idx_shape = list(indices_reshaped.data.shape)
218+
idx_numel = math.prod(idx_shape)
219+
if idx_numel != N * W:
220+
raise RuntimeError(
221+
"RewriteIndexPutPass: flat index numel (%s) does not match expected N*W (%s)"
222+
% (idx_numel, N * W)
223+
)
207224

208225
# Scatter expects a 3D layout; flatten everything into [N, K, C].
209226
reshape_indices = super().call_operator(

backends/arm/test/ops/test_index_put.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -147,6 +147,19 @@
147147
),
148148
0, # used for u55 tests to config n_expected_delgates, only 1 when accumulate is True
149149
),
150+
"broadcast_indices_mismatch": (
151+
lambda: (
152+
torch.zeros((1, 2, 3, 4), dtype=torch.float32),
153+
(
154+
torch.tensor([0], dtype=torch.int64),
155+
torch.tensor([0, 1], dtype=torch.int64),
156+
torch.tensor([1], dtype=torch.int64),
157+
),
158+
torch.randn((2, 4), dtype=torch.float32),
159+
False,
160+
),
161+
0,
162+
),
150163
}
151164
test_data_suite_bf16 = {
152165
"rank3_rand_bf16": (

0 commit comments

Comments
 (0)