
Commit 8d4abd9

[XNNPACK][Partitioner] SDPA Config
Differential Revision: D60323285
Pull Request resolved: #4764
1 parent e6f5435 commit 8d4abd9


4 files changed: +38 -5 lines changed


backends/xnnpack/operators/op_sdpa.py

Lines changed: 5 additions & 2 deletions
@@ -66,9 +66,12 @@ def define_node(
 
         # Hack to broadcast the scale
         q_shape = get_shape(get_input_node(node, 0))
-        scale = cast(float, node.kwargs["scale"])
+        embedding_dim = q_shape[-1]
+        scale = 1 / (embedding_dim**0.5)
+        if "scale" in node.kwargs and node.kwargs["scale"]:
+            scale = cast(float, node.kwargs["scale"])
 
-        t = torch.full((q_shape[-1],), scale, dtype=mask_dtype)
+        t = torch.full((embedding_dim,), scale, dtype=mask_dtype)
         scale_node = self.get_fake_attr("scale", t)
         self.define_tensor(
             scale_node,
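
For context, the fallback added above matches PyTorch's own SDPA semantics: when no scale kwarg is supplied, attention scores are scaled by 1/sqrt(embedding_dim), where embedding_dim is the last dimension of the query. A minimal sketch of that equivalence (illustration only, not ExecuTorch code, assuming a PyTorch build where F.scaled_dot_product_attention accepts the scale argument):

import math

import torch
import torch.nn.functional as F

# Toy query/key/value: (batch, heads, seq_len, embedding_dim)
q = torch.randn(1, 8, 16, 64)
k = torch.randn(1, 8, 16, 64)
v = torch.randn(1, 8, 16, 64)

# Default scale that the diff above now reproduces when node.kwargs has no "scale"
default_scale = 1 / math.sqrt(q.shape[-1])  # 1 / sqrt(64) == 0.125

out_implicit = F.scaled_dot_product_attention(q, k, v)                       # scale left unset
out_explicit = F.scaled_dot_product_attention(q, k, v, scale=default_scale)  # same result
assert torch.allclose(out_implicit, out_explicit, atol=1e-5)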

backends/xnnpack/partition/config/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -40,6 +40,7 @@
     PowConfig,
     QuantizedPerTensorConfig,
     ReLUConfig,
+    # SDPAConfig, TODO: D60553559: preserving SDPA for fairseq fails
     SigmoidConfig,
     SliceCopyConfig,
     SoftmaxConfig,
@@ -87,6 +88,7 @@
     PowConfig,
     PreluConfig,
     ReLUConfig,
+    # SDPAConfig, TODO: D60553559: preserving SDPA for fairseq fails
     SigmoidConfig,
     SliceCopyConfig,
     SoftmaxConfig,
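
Because SDPAConfig is deliberately left out of the default config lists above (see the TODO referencing D60553559), SDPA delegation has to be opted into explicitly. A sketch of opting in, mirroring the imports and the configs= argument used in the test change further down; that argument is assumed from the test, not from separate documentation:

from executorch.backends.xnnpack.partition.config.generic_node_configs import SDPAConfig
from executorch.backends.xnnpack.partition.xnnpack_partitioner2 import XnnpackPartitioner

# Pass the SDPA config explicitly since it is not part of the defaults.
partitioner = XnnpackPartitioner(configs=[SDPAConfig])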

backends/xnnpack/partition/config/generic_node_configs.py

Lines changed: 23 additions & 0 deletions
@@ -415,3 +415,26 @@ class BMMConfig(GenericNodePartitionerConfig):
 
     def supported_precision_types(self) -> List[ConfigPrecisionType]:
         return [ConfigPrecisionType.FP32]
+
+
+class SDPAConfig(GenericNodePartitionerConfig):
+    target_name = "scaled_dot_product_attention.default"
+
+    def check_constraints(self, node: torch.fx.Node, ep: ExportedProgram) -> bool:
+        """
+        Requires Mask to have Rank 2
+        """
+        if not self.check_common_constraints(node, ep):
+            return False
+
+        if len(node.all_input_nodes) < 4:
+            return False
+        mask_node = node.all_input_nodes[3]
+        mask_rank = mask_node.meta["val"].dim()
+        return mask_rank == 2
+
+    def get_original_aten(self) -> Optional[torch._ops.OpOverload]:
+        return torch.ops.aten.scaled_dot_product_attention.default
+
+    def supported_precision_types(self) -> List[ConfigPrecisionType]:
+        return [ConfigPrecisionType.FP32]
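
The constraint above only admits SDPA nodes whose attention mask (the fourth input) has rank 2. A toy illustration of what that means at the model level (hypothetical module, not part of the commit):

import torch
import torch.nn.functional as F

class ToySDPA(torch.nn.Module):
    # Hypothetical module used only to show which mask shapes the config accepts.
    def forward(self, q, k, v, mask):
        return F.scaled_dot_product_attention(q, k, v, attn_mask=mask)

q = k = v = torch.randn(1, 4, 8, 32)   # (batch, heads, seq_len, embedding_dim)
mask_2d = torch.zeros(8, 8)            # rank-2 mask: check_constraints returns True
mask_4d = torch.zeros(1, 4, 8, 8)      # rank-4 mask: check_constraints returns False

# Both run eagerly; only the rank-2 case is eligible for XNNPACK delegation.
ToySDPA()(q, k, v, mask_2d)
ToySDPA()(q, k, v, mask_4d)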

backends/xnnpack/test/ops/sdpa.py

Lines changed: 8 additions & 3 deletions
@@ -8,7 +8,12 @@
 from typing import Optional
 
 import torch
+from executorch.backends.xnnpack.partition.config.generic_node_configs import SDPAConfig
+from executorch.backends.xnnpack.partition.xnnpack_partitioner2 import (
+    XnnpackPartitioner,
+)
 from executorch.backends.xnnpack.test.tester import Tester
+from executorch.backends.xnnpack.test.tester.tester import ToEdgeTransformAndLower
 
 
 class TestSDPA(unittest.TestCase):
@@ -61,9 +66,9 @@ def _test(self, module, inputs, atol=1e-03, rtol=1e-03):
         (
             Tester(module, inputs)
             .export()
-            .to_edge()
-            .check_count({"executorch_exir_dialects_edge__ops_aten_bmm_default": 2})
-            .partition()
+            .to_edge_transform_and_lower(
+                ToEdgeTransformAndLower([XnnpackPartitioner(configs=[SDPAConfig])])
+            )
             .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
             .check_not(
                 ["executorch_exir_dialects_edge__ops_aten_bmm_default"],
