
Commit 662f285

Update on "[XNNPACK][Partitioner] SDPA Config"
We add the SDPA config for the partitioner here. There is currently an issue with SDPA when it is used from the FairSeq multihead attention models, so the config is disabled in the base partitioner until that is resolved; otherwise, our tests can use SDPA correctly from there. This is tracked in D60553559 and will be followed up on later.

Differential Revision: [D60323285](https://our.internmc.facebook.com/intern/diff/D60323285/)

[ghstack-poisoned]
2 parents 3c540f5 + e6f5435 commit 662f285

File tree

1 file changed: +2 −2 lines changed


backends/xnnpack/operators/op_sdpa.py

Lines changed: 2 additions & 2 deletions
@@ -4,7 +4,7 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
-from typing import Dict
+from typing import cast, Dict
 
 import torch
 from executorch.backends.transforms import get_shape
@@ -69,7 +69,7 @@ def define_node(
         embedding_dim = q_shape[-1]
         scale = 1 / (embedding_dim**0.5)
         if "scale" in node.kwargs and node.kwargs["scale"]:
-            scale = node.kwargs["scale"]
+            scale = cast(float, node.kwargs["scale"])
 
         t = torch.full((embedding_dim,), scale, dtype=mask_dtype)
         scale_node = self.get_fake_attr("scale", t)
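
For context on the second hunk, here is a minimal, self-contained sketch of why the `cast` is useful: `node.kwargs["scale"]` is an untyped FX argument, while the scale value is consumed as a float, so `typing.cast` narrows the static type for the checker without changing runtime behavior. The function name `resolve_sdpa_scale` and the plain dict standing in for `node.kwargs` are illustrative, not part of the actual serializer.

# Sketch: default SDPA scale vs. an explicit "scale" kwarg, narrowed with cast().
from typing import Any, Dict, cast

import torch


def resolve_sdpa_scale(kwargs: Dict[str, Any], embedding_dim: int) -> torch.Tensor:
    # Default matches scaled dot-product attention: 1 / sqrt(embedding_dim).
    scale = 1 / (embedding_dim**0.5)
    if "scale" in kwargs and kwargs["scale"]:
        # kwargs values carry no static type; cast() only informs the type
        # checker and is a no-op at runtime.
        scale = cast(float, kwargs["scale"])
    # Materialize the scale as a constant tensor, as the serializer does.
    return torch.full((embedding_dim,), scale, dtype=torch.float32)


# Example: an explicit scale kwarg overrides the 1/sqrt(E) default.
print(resolve_sdpa_scale({"scale": 0.125}, embedding_dim=64))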
