Skip to content

Commit a4b88a3

Browse files
SS-JIAfacebook-github-bot
authored andcommitted
Support exporting of custom operator calls via higher_order_auto_functionalized (#5884)
Summary: Pull Request resolved: #5884 As title. This diff adds the ability to partition custom op calls to the Vulkan delegate. ghstack-source-id: 246752222 exported-using-ghexport Reviewed By: jorgep31415 Differential Revision: D63913434 fbshipit-source-id: 7ca6cbef461265f59e48f5ce7110ef4c08a6a534
1 parent 400fefa commit a4b88a3

File tree

3 files changed

+16
-2
lines changed

3 files changed

+16
-2
lines changed

backends/vulkan/partitioner/supported_ops.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,8 @@ def __contains__(self, op):
8484
# Convolution
8585
exir_ops.edge.aten.convolution.default,
8686
exir_ops.edge.et_vk.conv_with_clamp.default,
87+
# Custom ops
88+
"llama::sdpa_with_kv_cache",
8789
]
8890

8991
NO_DYNAMIC_SHAPE = [

backends/vulkan/partitioner/vulkan_partitioner.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -144,16 +144,22 @@ def is_node_supported(
144144
def _is_node_supported(
145145
self, submodules: Mapping[str, torch.nn.Module], node: torch.fx.Node
146146
) -> bool:
147+
target = node.target
148+
if node.target == torch.ops.higher_order.auto_functionalized:
149+
first_arg = node.args[0]
150+
assert isinstance(first_arg, torch._ops.OpOverload)
151+
target = first_arg.name()
152+
147153
if self.is_linear_permute(node):
148154
return True
149155

150156
if self.is_in_local_scalar_dense_chain(node):
151157
return True
152158

153-
if node.target not in VulkanSupportedOperators._ops:
159+
if target not in VulkanSupportedOperators._ops:
154160
return False
155161

156-
features = VulkanSupportedOperators._ops[node.target]
162+
features = VulkanSupportedOperators._ops[target]
157163

158164
if self.require_dynamic_shapes and not features.supports_dynamic_shape:
159165
return False

backends/vulkan/vulkan_preprocess.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,10 @@
4040

4141
from executorch.exir.program._program import _copy_module
4242

43+
from torch.export._remove_auto_functionalized_pass import (
44+
unsafe_remove_auto_functionalized_pass,
45+
)
46+
4347
DEFAULT_DEBUG_HANDLE = 65535
4448

4549

@@ -52,6 +56,8 @@ def preprocess( # noqa: C901
5256
program: ExportedProgram,
5357
module_compile_spec: List[CompileSpec],
5458
) -> PreprocessResult:
59+
program = unsafe_remove_auto_functionalized_pass(program)
60+
5561
passes = [
5662
RemoveCloneOpsTransform(),
5763
AddmmToLinearTransform(),

0 commit comments

Comments
 (0)