Commit f93a5b5

[XNNPACK][Partitioner] SDPA Config (#4797)
We add the SDPA config for the partitioner here. There is currently an issue with SDPA when it is used from the FairSeq multihead attention models, so it is disabled in the base partitioner until that is resolved (tracked in D60553559); in our tests, however, we can opt in and use SDPA correctly. Will follow up on this later.

Differential Revision: [D60323285](https://our.internmc.facebook.com/intern/diff/D60323285/)

Co-authored-by: Max Ren <[email protected]>

Pull Request resolved: #4764
1 parent 7a2d885 commit f93a5b5

File tree: 4 files changed, +38 -5 lines changed

backends/xnnpack/operators/op_sdpa.py

Lines changed: 5 additions & 2 deletions
@@ -66,9 +66,12 @@ def define_node(
 
         # Hack to broadcast the scale
         q_shape = get_shape(get_input_node(node, 0))
-        scale = cast(float, node.kwargs["scale"])
+        embedding_dim = q_shape[-1]
+        scale = 1 / (embedding_dim**0.5)
+        if "scale" in node.kwargs and node.kwargs["scale"]:
+            scale = cast(float, node.kwargs["scale"])
 
-        t = torch.full((q_shape[-1],), scale, dtype=mask_dtype)
+        t = torch.full((embedding_dim,), scale, dtype=mask_dtype)
         scale_node = self.get_fake_attr("scale", t)
         self.define_tensor(
             scale_node,
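
For context, this default matches torch.nn.functional.scaled_dot_product_attention, which falls back to a scale of 1/sqrt(E), where E is the query's last (embedding) dimension, whenever no explicit scale is given. A minimal sketch of the fallback logic the diff implements; the shapes and the resolve_scale helper are illustrative, not part of the commit:

import torch

# Hypothetical query shape: (batch, num_heads, seq_len, embedding_dim)
q = torch.randn(1, 4, 16, 64)

def resolve_scale(q_shape, kwargs):
    # Mirror the diff: default to 1/sqrt(embedding_dim) unless an explicit,
    # truthy "scale" kwarg is present on the node. Note that a scale of 0.0
    # is falsy and would also hit the fallback, as in the committed code.
    embedding_dim = q_shape[-1]
    scale = 1 / (embedding_dim**0.5)
    if "scale" in kwargs and kwargs["scale"]:
        scale = float(kwargs["scale"])
    return scale

assert resolve_scale(q.shape, {}) == 1 / (64**0.5)
assert resolve_scale(q.shape, {"scale": 0.25}) == 0.25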

backends/xnnpack/partition/config/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -40,6 +40,7 @@
     PowConfig,
     QuantizedPerTensorConfig,
     ReLUConfig,
+    # SDPAConfig, TODO: D60553559: preserving SDPA for fairseq fails
     SigmoidConfig,
     SliceCopyConfig,
     SoftmaxConfig,
@@ -87,6 +88,7 @@
     PowConfig,
     PreluConfig,
     ReLUConfig,
+    # SDPAConfig, TODO: D60553559: preserving SDPA for fairseq fails
     SigmoidConfig,
     SliceCopyConfig,
     SoftmaxConfig,

backends/xnnpack/partition/config/generic_node_configs.py

Lines changed: 23 additions & 0 deletions
@@ -415,3 +415,26 @@ class BMMConfig(GenericNodePartitionerConfig):
 
     def supported_precision_types(self) -> List[ConfigPrecisionType]:
         return [ConfigPrecisionType.FP32]
+
+
+class SDPAConfig(GenericNodePartitionerConfig):
+    target_name = "scaled_dot_product_attention.default"
+
+    def check_constraints(self, node: torch.fx.Node, ep: ExportedProgram) -> bool:
+        """
+        Requires Mask to have Rank 2
+        """
+        if not self.check_common_constraints(node, ep):
+            return False
+
+        if len(node.all_input_nodes) < 4:
+            return False
+        mask_node = node.all_input_nodes[3]
+        mask_rank = mask_node.meta["val"].dim()
+        return mask_rank == 2
+
+    def get_original_aten(self) -> Optional[torch._ops.OpOverload]:
+        return torch.ops.aten.scaled_dot_product_attention.default
+
+    def supported_precision_types(self) -> List[ConfigPrecisionType]:
+        return [ConfigPrecisionType.FP32]
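
To make the rank-2 requirement concrete: after export, the attention mask is the fourth input to the SDPA node, and the rank recorded in its meta["val"] decides whether the node is partitionable. A hypothetical sketch of shapes (not from the commit) that would pass or fail the check:

import torch

# A (seq_len, seq_len) additive mask has rank 2 and satisfies the
# constraint; a (batch, heads, seq_len, seq_len) mask has rank 4 and
# would make check_constraints return False.
q = k = v = torch.randn(1, 4, 16, 64)
rank2_mask = torch.zeros(16, 16)        # accepted by SDPAConfig
rank4_mask = torch.zeros(1, 4, 16, 16)  # rejected by SDPAConfig

out = torch.nn.functional.scaled_dot_product_attention(
    q, k, v, attn_mask=rank2_mask
)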

backends/xnnpack/test/ops/sdpa.py

Lines changed: 8 additions & 3 deletions
@@ -8,7 +8,12 @@
 from typing import Optional
 
 import torch
+from executorch.backends.xnnpack.partition.config.generic_node_configs import SDPAConfig
+from executorch.backends.xnnpack.partition.xnnpack_partitioner2 import (
+    XnnpackPartitioner,
+)
 from executorch.backends.xnnpack.test.tester import Tester
+from executorch.backends.xnnpack.test.tester.tester import ToEdgeTransformAndLower
 
 
 class TestSDPA(unittest.TestCase):
@@ -61,9 +66,9 @@ def _test(self, module, inputs, atol=1e-03, rtol=1e-03):
         (
             Tester(module, inputs)
             .export()
-            .to_edge()
-            .check_count({"executorch_exir_dialects_edge__ops_aten_bmm_default": 2})
-            .partition()
+            .to_edge_transform_and_lower(
+                ToEdgeTransformAndLower([XnnpackPartitioner(configs=[SDPAConfig])])
+            )
             .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
             .check_not(
                 ["executorch_exir_dialects_edge__ops_aten_bmm_default"],
