
Commit 94892f6

[XNNPACK][Partitioner] SDPA Config
We add the SDPA config for the partitioner here. There is currently an issue with SDPA when it is used from the FairSeq multihead attention models, so the config is disabled in the base partitioner until that is resolved; tests can still pass it to the partitioner explicitly and exercise SDPA correctly (see the sketch below). Tracking this in D60553559; will follow up on it later. Differential Revision: [D60323285](https://our.internmc.facebook.com/intern/diff/D60323285/) [ghstack-poisoned]
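For reference, a minimal sketch of how a test opts in to the config explicitly while it stays out of the base partitioner. The toy module and tensor shapes are illustrative assumptions; the imports and the tester chain mirror the test change in this commit.

```python
# Minimal sketch, assuming an illustrative module and input shapes; the
# partitioner/tester imports mirror the test change in this commit.
import torch

from executorch.backends.xnnpack.partition.config.generic_node_configs import SDPAConfig
from executorch.backends.xnnpack.partition.xnnpack_partitioner2 import XnnpackPartitioner
from executorch.backends.xnnpack.test.tester import Tester
from executorch.backends.xnnpack.test.tester.tester import ToEdgeTransformAndLower


class ToySDPA(torch.nn.Module):
    def forward(self, q, k, v, mask):
        # The partitioner config requires the mask to be rank 2.
        return torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask)


q = k = v = torch.randn(1, 4, 8, 16)  # (batch, heads, seq_len, head_dim)
mask = torch.randn(8, 8)              # rank-2 mask, broadcast over batch/heads

(
    Tester(ToySDPA(), (q, k, v, mask))
    .export()
    .to_edge_transform_and_lower(
        ToEdgeTransformAndLower([XnnpackPartitioner(configs=[SDPAConfig])])
    )
    .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
)
```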
Parent: c2caa04

4 files changed · +39 −6 lines changed

backends/xnnpack/operators/op_sdpa.py

Lines changed: 6 additions & 3 deletions
@@ -4,7 +4,7 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.

-from typing import cast, Dict
+from typing import Dict

 import torch
 from executorch.backends.transforms import get_shape
@@ -66,9 +66,12 @@ def define_node(

         # Hack to broadcast the scale
         q_shape = get_shape(get_input_node(node, 0))
-        scale = cast(float, node.kwargs["scale"])
+        C = q_shape[-1]
+        scale = 1 / (C**0.5)
+        if "scale" in node.kwargs and node.kwargs["scale"]:
+            scale = node.kwargs["scale"]

-        t = torch.full((q_shape[-1],), scale, dtype=mask_dtype)
+        t = torch.full((C,), scale, dtype=mask_dtype)
         scale_node = self.get_fake_attr("scale", t)
         self.define_tensor(
             scale_node,
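The hunk above makes the serializer default the scale to 1/sqrt(C), where C is the last dimension of the query, and only use `node.kwargs["scale"]` when the caller actually supplied one; previously `node.kwargs["scale"]` would raise a KeyError if the kwarg was absent. A standalone sketch of that defaulting logic (the helper name and free-function form are illustrative, not the serializer's actual API):

```python
# Standalone sketch of the scale-defaulting logic in the hunk above; the helper
# name and free-function form are illustrative, not the serializer's actual API.
def resolve_sdpa_scale(q_shape, kwargs):
    C = q_shape[-1]               # last dimension of the query tensor
    scale = 1 / (C**0.5)          # default matches 1/sqrt(head_dim)
    if "scale" in kwargs and kwargs["scale"]:
        scale = kwargs["scale"]   # an explicit truthy scale overrides the default
    return scale                  # a missing, None, or zero scale keeps the default


assert resolve_sdpa_scale((1, 4, 8, 16), {}) == 1 / (16**0.5)
assert resolve_sdpa_scale((1, 4, 8, 16), {"scale": 0.5}) == 0.5
```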

backends/xnnpack/partition/config/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -40,6 +40,7 @@
     PowConfig,
     QuantizedPerTensorConfig,
     ReLUConfig,
+    # SDPAConfig, TODO: D60553559: preserving SDPA for fairseq fails
     SigmoidConfig,
     SliceCopyConfig,
     SoftmaxConfig,
@@ -87,6 +88,7 @@
     PowConfig,
     PreluConfig,
     ReLUConfig,
+    # SDPAConfig, TODO: D60553559: preserving SDPA for fairseq fails
     SigmoidConfig,
     SliceCopyConfig,
     SoftmaxConfig,

backends/xnnpack/partition/config/generic_node_configs.py

Lines changed: 23 additions & 0 deletions
@@ -415,3 +415,26 @@ class BMMConfig(GenericNodePartitionerConfig):

     def supported_precision_types(self) -> List[ConfigPrecisionType]:
         return [ConfigPrecisionType.FP32]
+
+
+class SDPAConfig(GenericNodePartitionerConfig):
+    target_name = "scaled_dot_product_attention.default"
+
+    def check_constraints(self, node: torch.fx.Node, ep: ExportedProgram) -> bool:
+        """
+        Requires Mask to have Rank 2
+        """
+        if not self.check_common_constraints(node, ep):
+            return False
+
+        if len(node.all_input_nodes) < 4:
+            return False
+        mask_node = node.all_input_nodes[3]
+        mask_rank = mask_node.meta["val"].dim()
+        return mask_rank == 2
+
+    def get_original_aten(self) -> Optional[torch._ops.OpOverload]:
+        return torch.ops.aten.scaled_dot_product_attention.default
+
+    def supported_precision_types(self) -> List[ConfigPrecisionType]:
+        return [ConfigPrecisionType.FP32]
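At the model level, the constraint above means an SDPA node is only handed to the delegate when it carries a mask input and that mask is a rank-2 tensor. A small illustrative sketch (module name and shapes are assumptions, not part of this commit):

```python
# Illustrative sketch only: the module name and shapes are assumptions, not part
# of this commit. A rank-2 mask satisfies SDPAConfig.check_constraints; a
# rank-3 or rank-4 mask keeps the node out of the delegate.
import torch
import torch.nn.functional as F


class MaskedSDPA(torch.nn.Module):
    def forward(self, q, k, v, mask):
        return F.scaled_dot_product_attention(q, k, v, attn_mask=mask)


q = k = v = torch.randn(1, 4, 8, 16)
good_mask = torch.randn(8, 8)        # dim() == 2 -> eligible for partitioning
bad_mask = torch.randn(1, 1, 8, 8)   # dim() == 4 -> check_constraints returns False
```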

backends/xnnpack/test/ops/sdpa.py

Lines changed: 8 additions & 3 deletions
@@ -8,7 +8,12 @@
 from typing import Optional

 import torch
+from executorch.backends.xnnpack.partition.config.generic_node_configs import SDPAConfig
+from executorch.backends.xnnpack.partition.xnnpack_partitioner2 import (
+    XnnpackPartitioner,
+)
 from executorch.backends.xnnpack.test.tester import Tester
+from executorch.backends.xnnpack.test.tester.tester import ToEdgeTransformAndLower


 class TestSDPA(unittest.TestCase):
@@ -61,9 +66,9 @@ def _test(self, module, inputs, atol=1e-03, rtol=1e-03):
         (
             Tester(module, inputs)
             .export()
-            .to_edge()
-            .check_count({"executorch_exir_dialects_edge__ops_aten_bmm_default": 2})
-            .partition()
+            .to_edge_transform_and_lower(
+                ToEdgeTransformAndLower([XnnpackPartitioner(configs=[SDPAConfig])])
+            )
             .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
             .check_not(
                 ["executorch_exir_dialects_edge__ops_aten_bmm_default"],
