Skip to content

Commit 9e478e8

Browse files
authored
Update fused kernels and call _safe_softmax from SDPA
Differential Revision: D61418679 Pull Request resolved: #4772
1 parent eaf383a commit 9e478e8

File tree

1 file changed

+13
-5
lines changed

1 file changed

+13
-5
lines changed

backends/apple/coreml/test/test_coreml_partitioner.py

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -68,15 +68,23 @@ def test_vit_skip_conv(self):
6868
)
6969
)
7070

71+
conv_block = ["aten.convolution.default", "executorch_call_delegate"]
72+
safe_softmax_block = [
73+
"getitem",
74+
"getitem",
75+
"getitem",
76+
"getitem",
77+
"aten.any.dim",
78+
"executorch_call_delegate",
79+
]
80+
final_block = ["getitem"]
81+
total = conv_block + 12 * safe_softmax_block + final_block
82+
7183
assert [
7284
node.target.__name__
7385
for node in delegated_program_manager.exported_program().graph.nodes
7486
if node.op == "call_function"
75-
] == [
76-
"aten.convolution.default",
77-
"executorch_call_delegate",
78-
"getitem",
79-
]
87+
] == total
8088

8189

8290
if __name__ == "__main__":

0 commit comments

Comments
 (0)