
Commit aa58f62

Update on "[XNNPACK][Partitioner] SDPA Config"
We add the SDPA config for the partitioner here. There is currently an issue with SDPA when it is used from the FairSeq multihead attention models, so it is disabled in the base partitioner until that is resolved. Otherwise, our tests can exercise SDPA correctly from here. This is tracked in D60553559; will follow up on it later.

Differential Revision: [D60323285](https://our.internmc.facebook.com/intern/diff/D60323285/)

[ghstack-poisoned]
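For context, a minimal sketch (not from this diff; it only assumes the public PyTorch API) of the kind of SDPA call the partitioner config is meant to recognize. The optional `scale` kwarg is the same one the `op_sdpa` lowering below reads from `node.kwargs`:

```python
import torch
import torch.nn.functional as F

# Toy query/key/value tensors shaped (batch, heads, seq_len, embedding_dim).
q = torch.randn(1, 4, 16, 64)
k = torch.randn(1, 4, 16, 64)
v = torch.randn(1, 4, 16, 64)

# With no explicit scale, SDPA defaults to 1 / sqrt(embedding_dim).
out_default = F.scaled_dot_product_attention(q, k, v)

# An explicit scale kwarg overrides the default; the lowering below
# checks node.kwargs["scale"] for exactly this case.
out_scaled = F.scaled_dot_product_attention(q, k, v, scale=0.125)
```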
1 parent 94892f6 commit aa58f62

1 file changed: +3 −3


backends/xnnpack/operators/op_sdpa.py

Lines changed: 3 additions & 3 deletions
@@ -66,12 +66,12 @@ def define_node(
 
         # Hack to broadcast the scale
         q_shape = get_shape(get_input_node(node, 0))
-        C = q_shape[-1]
-        scale = 1 / (C**0.5)
+        embedding_dim = q_shape[-1]
+        scale = 1 / (embedding_dim**0.5)
         if "scale" in node.kwargs and node.kwargs["scale"]:
             scale = node.kwargs["scale"]
 
-        t = torch.full((C,), scale, dtype=mask_dtype)
+        t = torch.full((embedding_dim,), scale, dtype=mask_dtype)
         scale_node = self.get_fake_attr("scale", t)
         self.define_tensor(
             scale_node,
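As a standalone illustration of what the renamed lines compute, here is a minimal sketch. `broadcast_scale` is a hypothetical helper invented only for this example; the real code runs inside `define_node` with `get_shape`, `get_input_node`, and the visitor state, which are not reproduced here.

```python
import torch

def broadcast_scale(q_shape, mask_dtype, explicit_scale=None):
    # Hypothetical helper, for illustration only.
    # The default SDPA scale is 1 / sqrt(embedding_dim), where
    # embedding_dim is the last dimension of the query tensor.
    embedding_dim = q_shape[-1]
    scale = explicit_scale if explicit_scale else 1 / (embedding_dim**0.5)
    # "Hack to broadcast the scale": materialize it as a 1-D tensor of
    # length embedding_dim so it can be defined as a constant tensor input.
    return torch.full((embedding_dim,), scale, dtype=mask_dtype)

t = broadcast_scale((1, 4, 16, 64), torch.float32)          # scale = 1/8
t_explicit = broadcast_scale((1, 4, 16, 64), torch.float32, 0.125)
```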
