Commit 7574954

permansnils authored and committed
Arm backend: Add SDPA decomposition to annotation pipeline
Adds the SDPA decomposition to the annotation pipeline. Since SDPA is decomposed to safe_softmax rather than regular softmax, the softmax decomposition is updated to also handle safe_softmax.

Co-authored-by: Per Åstrand <[email protected]>
Co-authored-by: Måns Nilsson <[email protected]>
Change-Id: I384034fee8bd372d224e438adef5d3b0d1ec3ee3
1 parent 12ed924 commit 7574954
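
For context, the decomposition rewrites scaled dot-product attention into simpler ops the annotation pipeline already handles. Below is a minimal reference sketch in plain PyTorch, not the exact graph the pass emits; the real decomposition produces aten._safe_softmax rather than aten.softmax, which agrees with softmax for unmasked inputs.

import math
import torch

def sdpa_reference(query, key, value):
    # Scale by 1/sqrt(head_dim), as scaled_dot_product_attention does by default.
    scale = 1.0 / math.sqrt(query.size(-1))
    scores = torch.matmul(query, key.transpose(-2, -1)) * scale
    # The export path emits aten._safe_softmax here; for unmasked inputs it
    # matches regular softmax.
    attn_weights = torch.softmax(scores, dim=-1)
    return torch.matmul(attn_weights, value)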

File tree: 4 files changed, +67 −8 lines


backends/arm/_passes/arm_pass_manager.py

Lines changed: 4 additions & 0 deletions
@@ -59,6 +59,9 @@
 )

 from executorch.backends.arm.tosa_specification import Tosa_0_80, TosaSpecification
+from executorch.backends.transforms.decompose_sdpa import (
+    DecomposeScaledDotProductAttention,
+)
 from executorch.backends.transforms.fuse_view_copy import FuseViewCopyTransform
 from executorch.backends.xnnpack._passes.remove_getitem_op import RemoveGetItemPass
 from executorch.exir import ExportedProgram
@@ -194,6 +197,7 @@ def transform_to_backend_pipeline(self, exported_program: ExportedProgram):
         )

     def transform_for_annotation_pipeline(self, graph_module: GraphModule):
+        self.add_pass(DecomposeScaledDotProductAttention())
         self.add_pass(ReplaceScalarWithTensorArgPassTOSABI())
         self.add_pass(ScalarsToAttributePass())
         self.add_pass(DecomposeLayerNormPass())
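
Decomposition passes in this pipeline typically follow the ExportPass pattern from executorch.exir.pass_base: intercept one target op and rebuild it from simpler ops, leaving everything else untouched. The sketch below shows that shape only; the class name and the relu6-to-clamp rewrite are illustrative, not the actual DecomposeScaledDotProductAttention implementation.

import torch
from executorch.exir.pass_base import ExportPass

class DecomposeExamplePass(ExportPass):  # illustrative name, not the real pass
    def call_operator(self, op, args, kwargs, meta):
        # Pass every op through unchanged except the one being decomposed.
        if op != torch.ops.aten.relu6.default:
            return super().call_operator(op, args, kwargs, meta)
        # Rebuild the op from a simpler primitive the backend already supports.
        return super().call_operator(
            torch.ops.aten.clamp.default, (args[0], 0.0, 6.0), {}, meta
        )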

backends/arm/_passes/decompose_softmax_pass.py

Lines changed: 5 additions & 1 deletion
@@ -8,7 +8,11 @@
 from executorch.exir.pass_base import ExportPass

 # For BI case
-torch_softmax = (torch.ops.aten.softmax.int, torch.ops.aten.log_softmax.int)
+torch_softmax = (
+    torch.ops.aten.softmax.int,
+    torch.ops.aten._safe_softmax.default,
+    torch.ops.aten.log_softmax.int,
+)
 # For MI case
 edge_softmax = (
     exir_ops.edge.aten._softmax.default,
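
For reference, aten._safe_softmax is understood to behave like regular softmax except that rows which are fully masked (every entry is -inf) produce zeros instead of NaN. A sketch of that assumed semantics in plain PyTorch:

import torch

def safe_softmax_reference(x: torch.Tensor, dim: int = -1) -> torch.Tensor:
    out = torch.softmax(x, dim=dim)
    # Rows where every entry is -inf would otherwise become NaN after softmax.
    fully_masked = (x == float("-inf")).all(dim=dim, keepdim=True)
    return torch.where(fully_masked, torch.zeros_like(out), out)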

backends/arm/test/models/test_conformer.py

Lines changed: 13 additions & 7 deletions
@@ -83,7 +83,6 @@ def test_conformer_tosa_BI(self):
             )
         )

-    @unittest.expectedFailure  # TODO(MLETORCH-635)
     def test_conformer_u55_BI(self):
         tester = (
             ArmTester(
@@ -97,13 +96,20 @@ def test_conformer_u55_BI(self):
             .to_executorch()
             .serialize()
         )
+
         if conftest.is_option_enabled("corstone_fvp"):
-            tester.run_method_and_compare_outputs(
-                qtol=1.0,
-                rtol=1.0,
-                atol=5.0,
-                inputs=get_test_inputs(self.dim, self.lengths, self.num_examples),
-            )
+            try:
+                tester.run_method_and_compare_outputs(
+                    qtol=1.0,
+                    rtol=1.0,
+                    atol=5.0,
+                    inputs=get_test_inputs(self.dim, self.lengths, self.num_examples),
+                )
+                self.fail(
+                    "TODO(MLETORCH-635): Expected failure under FVP option, but test passed."
+                )
+            except Exception:
+                pass

     @unittest.expectedFailure  # TODO(MLETORCH-635)
     def test_conformer_u85_BI(self):

backends/arm/test/ops/test_sdpa.py

Lines changed: 45 additions & 0 deletions
@@ -0,0 +1,45 @@
+# Copyright 2025 Arm Limited and/or its affiliates.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+
+from typing import Tuple
+
+import torch
+
+from executorch.backends.arm.test.tester.test_pipeline import (
+    TosaPipelineBI,
+    TosaPipelineMI,
+)
+
+
+class SDPA(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    def forward(self, query, key, value):
+        return torch.nn.functional.scaled_dot_product_attention(
+            query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False
+        )
+
+
+input_t = Tuple[torch.Tensor, torch.Tensor, torch.Tensor]
+
+
+def test_sdpa_MI():
+    test_input = tuple(torch.randn(1, 3, 197, 64) for x in range(3))
+    pipeline = TosaPipelineMI[input_t](SDPA(), test_input, [], [])
+    pipeline.pop_stage("check_count.exir")
+    pipeline.run()
+
+
+def test_sdpa_BI():
+    test_input = tuple(torch.randn(1, 3, 197, 64) for x in range(3))
+    pipeline = TosaPipelineBI[input_t](SDPA(), test_input, [], [])
+    pipeline.pop_stage("check.quant_nodes")
+    pipeline.pop_stage("check_count.exir")
+    pipeline.pop_stage(
+        "run_method_and_compare_outputs"
+    )  # TODO: reference is not quantized
+    pipeline.run()
