Skip to content

[XNNPACK][Partitioner] SDPA Config #4764

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Aug 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions backends/xnnpack/operators/op_sdpa.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,9 +66,12 @@ def define_node(

# Hack to broadcast the scale
q_shape = get_shape(get_input_node(node, 0))
scale = cast(float, node.kwargs["scale"])
embedding_dim = q_shape[-1]
scale = 1 / (embedding_dim**0.5)
if "scale" in node.kwargs and node.kwargs["scale"]:
scale = cast(float, node.kwargs["scale"])

t = torch.full((q_shape[-1],), scale, dtype=mask_dtype)
t = torch.full((embedding_dim,), scale, dtype=mask_dtype)
scale_node = self.get_fake_attr("scale", t)
self.define_tensor(
scale_node,
Expand Down
2 changes: 2 additions & 0 deletions backends/xnnpack/partition/config/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
PowConfig,
QuantizedPerTensorConfig,
ReLUConfig,
# SDPAConfig, TODO: D60553559: preserving SDPA for fairseq fails
SigmoidConfig,
SliceCopyConfig,
SoftmaxConfig,
Expand Down Expand Up @@ -87,6 +88,7 @@
PowConfig,
PreluConfig,
ReLUConfig,
# SDPAConfig, TODO: D60553559: preserving SDPA for fairseq fails
SigmoidConfig,
SliceCopyConfig,
SoftmaxConfig,
Expand Down
23 changes: 23 additions & 0 deletions backends/xnnpack/partition/config/generic_node_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -415,3 +415,26 @@ class BMMConfig(GenericNodePartitionerConfig):

def supported_precision_types(self) -> List[ConfigPrecisionType]:
return [ConfigPrecisionType.FP32]


class SDPAConfig(GenericNodePartitionerConfig):
target_name = "scaled_dot_product_attention.default"

def check_constraints(self, node: torch.fx.Node, ep: ExportedProgram) -> bool:
"""
Requires Mask to have Rank 2
"""
if not self.check_common_constraints(node, ep):
return False

if len(node.all_input_nodes) < 4:
return False
mask_node = node.all_input_nodes[3]
mask_rank = mask_node.meta["val"].dim()
return mask_rank == 2

def get_original_aten(self) -> Optional[torch._ops.OpOverload]:
return torch.ops.aten.scaled_dot_product_attention.default

def supported_precision_types(self) -> List[ConfigPrecisionType]:
return [ConfigPrecisionType.FP32]
11 changes: 8 additions & 3 deletions backends/xnnpack/test/ops/sdpa.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,12 @@
from typing import Optional

import torch
from executorch.backends.xnnpack.partition.config.generic_node_configs import SDPAConfig
from executorch.backends.xnnpack.partition.xnnpack_partitioner2 import (
XnnpackPartitioner,
)
from executorch.backends.xnnpack.test.tester import Tester
from executorch.backends.xnnpack.test.tester.tester import ToEdgeTransformAndLower


class TestSDPA(unittest.TestCase):
Expand Down Expand Up @@ -61,9 +66,9 @@ def _test(self, module, inputs, atol=1e-03, rtol=1e-03):
(
Tester(module, inputs)
.export()
.to_edge()
.check_count({"executorch_exir_dialects_edge__ops_aten_bmm_default": 2})
.partition()
.to_edge_transform_and_lower(
ToEdgeTransformAndLower([XnnpackPartitioner(configs=[SDPAConfig])])
)
.check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
.check_not(
["executorch_exir_dialects_edge__ops_aten_bmm_default"],
Expand Down
Loading