Skip to content

Arm backend: Add SDPA decomposition to annotation pipeline #10657

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
May 2, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions backends/arm/_passes/arm_pass_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,9 @@
)

from executorch.backends.arm.tosa_specification import Tosa_0_80, TosaSpecification
from executorch.backends.transforms.decompose_sdpa import (
DecomposeScaledDotProductAttention,
)
from executorch.backends.transforms.fuse_view_copy import FuseViewCopyTransform
from executorch.backends.xnnpack._passes.remove_getitem_op import RemoveGetItemPass
from executorch.exir import ExportedProgram
Expand Down Expand Up @@ -194,6 +197,7 @@ def transform_to_backend_pipeline(self, exported_program: ExportedProgram):
)

def transform_for_annotation_pipeline(self, graph_module: GraphModule):
self.add_pass(DecomposeScaledDotProductAttention())
self.add_pass(ReplaceScalarWithTensorArgPassTOSABI())
self.add_pass(ScalarsToAttributePass())
self.add_pass(DecomposeLayerNormPass())
Expand Down
6 changes: 5 additions & 1 deletion backends/arm/_passes/decompose_softmax_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,11 @@
from executorch.exir.pass_base import ExportPass

# For BI case
torch_softmax = (torch.ops.aten.softmax.int, torch.ops.aten.log_softmax.int)
torch_softmax = (
torch.ops.aten.softmax.int,
torch.ops.aten._safe_softmax.default,
torch.ops.aten.log_softmax.int,
)
# For MI case
edge_softmax = (
exir_ops.edge.aten._softmax.default,
Expand Down
20 changes: 13 additions & 7 deletions backends/arm/test/models/test_conformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,6 @@ def test_conformer_tosa_BI(self):
)
)

@unittest.expectedFailure # TODO(MLETORCH-635)
def test_conformer_u55_BI(self):
tester = (
ArmTester(
Expand All @@ -97,13 +96,20 @@ def test_conformer_u55_BI(self):
.to_executorch()
.serialize()
)

if conftest.is_option_enabled("corstone_fvp"):
tester.run_method_and_compare_outputs(
qtol=1.0,
rtol=1.0,
atol=5.0,
inputs=get_test_inputs(self.dim, self.lengths, self.num_examples),
)
try:
tester.run_method_and_compare_outputs(
qtol=1.0,
rtol=1.0,
atol=5.0,
inputs=get_test_inputs(self.dim, self.lengths, self.num_examples),
)
self.fail(
"TODO(MLETORCH-635): Expected failure under FVP option, but test passed."
)
except Exception:
pass

@unittest.expectedFailure # TODO(MLETORCH-635)
def test_conformer_u85_BI(self):
Expand Down
45 changes: 45 additions & 0 deletions backends/arm/test/ops/test_sdpa.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
# Copyright 2025 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.


from typing import Tuple

import torch

from executorch.backends.arm.test.tester.test_pipeline import (
TosaPipelineBI,
TosaPipelineMI,
)


class SDPA(torch.nn.Module):
def __init__(self):
super().__init__()

def forward(self, query, key, value):
return torch.nn.functional.scaled_dot_product_attention(
query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False
)


input_t = Tuple[torch.Tensor, torch.Tensor, torch.Tensor]


def test_sdpa_MI():
test_input = tuple(torch.randn(1, 3, 197, 64) for x in range(3))
pipeline = TosaPipelineMI[input_t](SDPA(), test_input, [], [])
pipeline.pop_stage("check_count.exir")
pipeline.run()


def test_sdpa_BI():
test_input = tuple(torch.randn(1, 3, 197, 64) for x in range(3))
pipeline = TosaPipelineBI[input_t](SDPA(), test_input, [], [])
pipeline.pop_stage("check.quant_nodes")
pipeline.pop_stage("check_count.exir")
pipeline.pop_stage(
"run_method_and_compare_outputs"
) # TODO: reference is not quantized
pipeline.run()
Loading