Skip to content

Commit 5777ad3

Browse files
committed
[ET-VK][ez] Support exporting of custom operator calls via higher_order_auto_functionalized, checkpoint
As title. This diff adds the ability to partition custom op calls to the Vulkan delegate. Differential Revision: [D63913434](https://our.internmc.facebook.com/intern/diff/D63913434/) [ghstack-poisoned]
1 parent 20a157f commit 5777ad3

File tree

3 files changed

+21
-2
lines changed

3 files changed

+21
-2
lines changed

backends/vulkan/partitioner/supported_ops.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -169,9 +169,16 @@ def register_dynamic_shape_ops(ops: OpList):
169169
ops[op].supports_dynamic_shape = True
170170

171171

172+
def register_custom_ops(ops: OpList):
173+
for op in CUSTOM_OPS:
174+
ops[op].supports_dynamic_shape = True
175+
ops[op].supports_texture = True
176+
177+
172178
def enumerate_supported_ops():
173179
ops = OpList()
174180
register_prim_ops(ops)
175181
register_no_dynamic_shape_ops(ops)
176182
register_dynamic_shape_ops(ops)
183+
register_custom_ops(ops)
177184
return ops

backends/vulkan/partitioner/vulkan_partitioner.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -119,13 +119,19 @@ def is_node_supported(
119119
def _is_node_supported(
120120
self, submodules: Mapping[str, torch.nn.Module], node: torch.fx.Node
121121
) -> bool:
122+
target = node.target
123+
if node.target == torch.ops.higher_order.auto_functionalized:
124+
first_arg = node.args[0]
125+
assert isinstance(first_arg, torch._ops.OpOverload)
126+
target = first_arg.name()
127+
122128
if self.is_linear_permute(node):
123129
return True
124130

125-
if node.target not in VulkanSupportedOperators._ops:
131+
if target not in VulkanSupportedOperators._ops:
126132
return False
127133

128-
features = VulkanSupportedOperators._ops[node.target]
134+
features = VulkanSupportedOperators._ops[target]
129135

130136
if self.require_dynamic_shapes and not features.supports_dynamic_shape:
131137
return False

backends/vulkan/vulkan_preprocess.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,10 @@
3636

3737
from executorch.exir.program._program import _copy_module
3838

39+
from torch.export._remove_auto_functionalized_pass import (
40+
unsafe_remove_auto_functionalized_pass,
41+
)
42+
3943
DEFAULT_DEBUG_HANDLE = 65535
4044

4145

@@ -48,6 +52,8 @@ def preprocess( # noqa: C901
4852
program: ExportedProgram,
4953
module_compile_spec: List[CompileSpec],
5054
) -> PreprocessResult:
55+
program = unsafe_remove_auto_functionalized_pass(program)
56+
5157
passes = [
5258
RemoveCloneOpsTransform(),
5359
AddmmToLinearTransform(),

0 commit comments

Comments
 (0)