Skip to content

Commit c1c5eb5

Browse files
committed
[ET-VK][ez] Support exporting of custom operator calls via higher_order_auto_functionalized
Pull Request resolved: #5884 As title. This diff adds the ability to partition custom op calls to the Vulkan delegate. ghstack-source-id: 246380974 @exported-using-ghexport Differential Revision: [D63913434](https://our.internmc.facebook.com/intern/diff/D63913434/)
1 parent 34e22d1 commit c1c5eb5

File tree

3 files changed

+16
-2
lines changed

3 files changed

+16
-2
lines changed

backends/vulkan/partitioner/supported_ops.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,8 @@ def __contains__(self, op):
8484
# Convolution ops
8585
exir_ops.edge.aten.convolution.default,
8686
exir_ops.edge.et_vk.conv_with_clamp.default,
87+
# Custom ops
88+
"llama::sdpa_with_kv_cache",
8789
]
8890

8991
NO_DYNAMIC_SHAPE = [

backends/vulkan/partitioner/vulkan_partitioner.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -143,16 +143,22 @@ def is_node_supported(
143143
def _is_node_supported(
144144
self, submodules: Mapping[str, torch.nn.Module], node: torch.fx.Node
145145
) -> bool:
146+
target = node.target
147+
if node.target == torch.ops.higher_order.auto_functionalized:
148+
first_arg = node.args[0]
149+
assert isinstance(first_arg, torch._ops.OpOverload)
150+
target = first_arg.name()
151+
146152
if self.is_linear_permute(node):
147153
return True
148154

149155
if self.is_in_local_scalar_dense_chain(node):
150156
return True
151157

152-
if node.target not in VulkanSupportedOperators._ops:
158+
if target not in VulkanSupportedOperators._ops:
153159
return False
154160

155-
features = VulkanSupportedOperators._ops[node.target]
161+
features = VulkanSupportedOperators._ops[target]
156162

157163
if self.require_dynamic_shapes and not features.supports_dynamic_shape:
158164
return False

backends/vulkan/vulkan_preprocess.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,10 @@
4040

4141
from executorch.exir.program._program import _copy_module
4242

43+
from torch.export._remove_auto_functionalized_pass import (
44+
unsafe_remove_auto_functionalized_pass,
45+
)
46+
4347
DEFAULT_DEBUG_HANDLE = 65535
4448

4549

@@ -52,6 +56,8 @@ def preprocess( # noqa: C901
5256
program: ExportedProgram,
5357
module_compile_spec: List[CompileSpec],
5458
) -> PreprocessResult:
59+
program = unsafe_remove_auto_functionalized_pass(program)
60+
5561
passes = [
5662
RemoveCloneOpsTransform(),
5763
AddmmToLinearTransform(),

0 commit comments

Comments
 (0)