
[ET-VK][ez] Support exporting of custom operator calls via higher_order_auto_functionalized #5884


Closed · wants to merge 7 commits
2 changes: 2 additions & 0 deletions backends/vulkan/partitioner/supported_ops.py
@@ -84,6 +84,8 @@ def __contains__(self, op):
     # Convolution
     exir_ops.edge.aten.convolution.default,
     exir_ops.edge.et_vk.conv_with_clamp.default,
+    # Custom ops
+    "llama::sdpa_with_kv_cache",
 ]

 NO_DYNAMIC_SHAPE = [
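For context: the entry above is a plain string rather than an exir_ops handle, because the op lives in a custom namespace outside the edge dialect. The string is the op's schema-qualified name, which is what OpOverload.name() returns and what the partitioner change below compares against. A minimal sketch of that behavior (the aten op merely stands in for the llama custom op, whose library must be loaded separately):

import torch

# For a default overload, OpOverload.name() returns the schema-qualified
# "namespace::opname" string; non-default overloads carry a ".overload"
# suffix (e.g. "aten::add.Tensor").
print(torch.ops.aten.relu.default.name())  # expected: "aten::relu"

# Assumption: with the llama custom-op library loaded,
# torch.ops.llama.sdpa_with_kv_cache.default.name() likewise yields
# "llama::sdpa_with_kv_cache", matching the list entry above.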
10 changes: 8 additions & 2 deletions backends/vulkan/partitioner/vulkan_partitioner.py
@@ -144,16 +144,22 @@ def is_node_supported(
     def _is_node_supported(
         self, submodules: Mapping[str, torch.nn.Module], node: torch.fx.Node
     ) -> bool:
+        target = node.target
+        if node.target == torch.ops.higher_order.auto_functionalized:
+            first_arg = node.args[0]
+            assert isinstance(first_arg, torch._ops.OpOverload)
+            target = first_arg.name()
+
         if self.is_linear_permute(node):
             return True

         if self.is_in_local_scalar_dense_chain(node):
             return True

-        if node.target not in VulkanSupportedOperators._ops:
+        if target not in VulkanSupportedOperators._ops:
             return False

-        features = VulkanSupportedOperators._ops[node.target]
+        features = VulkanSupportedOperators._ops[target]

         if self.require_dynamic_shapes and not features.supports_dynamic_shape:
             return False
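Background for the change above: when an exported program contains a custom op that mutates one of its arguments, functionalization wraps the call in the torch.ops.higher_order.auto_functionalized higher-order op. The node's target is then the wrapper itself, and the real OpOverload travels as node.args[0], which is why the partitioner unwraps it before the supported-ops lookup. A self-contained sketch of that graph shape (the demo op and module are hypothetical, purely for illustration; exact export behavior can vary across PyTorch versions):

import torch
from torch.export import export
from torch.library import Library, impl

# Register a tiny custom op that mutates its argument in place.
lib = Library("demo", "DEF")
lib.define("inc_(Tensor(a!) x) -> ()")

@impl(lib, "inc_", "CompositeExplicitAutograd")
def inc_(x):
    x.add_(1)

class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.register_buffer("cache", torch.zeros(4))

    def forward(self, x):
        torch.ops.demo.inc_(self.cache)  # in-place mutation
        return x + self.cache

ep = export(M(), (torch.zeros(4),))
for node in ep.graph.nodes:
    if node.target == torch.ops.higher_order.auto_functionalized:
        # The wrapped OpOverload rides along as the first argument.
        print(node.args[0].name())  # expected: "demo::inc_"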
6 changes: 6 additions & 0 deletions backends/vulkan/vulkan_preprocess.py
@@ -40,6 +40,10 @@

 from executorch.exir.program._program import _copy_module

+from torch.export._remove_auto_functionalized_pass import (
+    unsafe_remove_auto_functionalized_pass,
+)
+
 DEFAULT_DEBUG_HANDLE = 65535


@@ -52,6 +56,8 @@ def preprocess(  # noqa: C901
         program: ExportedProgram,
         module_compile_spec: List[CompileSpec],
     ) -> PreprocessResult:
+        program = unsafe_remove_auto_functionalized_pass(program)
+
         passes = [
             RemoveCloneOpsTransform(),
             AddmmToLinearTransform(),
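Note the ordering constraint here: the auto_functionalized wrapper has to be removed before any of the backend passes run, so that the passes and the serializer see a direct call to the custom op again. A rough sketch of the effect (the helper name is illustrative; "unsafe" refers to the graph regaining in-place mutation and no longer being purely functional):

from torch.export import ExportedProgram
from torch.export._remove_auto_functionalized_pass import (
    unsafe_remove_auto_functionalized_pass,
)

def prepare_for_backend_passes(program: ExportedProgram) -> ExportedProgram:
    # Replaces each auto_functionalized(op, ...) node with a direct,
    # mutating call to op, so downstream passes keyed on op targets
    # (like the partitioner's supported-ops table) work unchanged.
    return unsafe_remove_auto_functionalized_pass(program)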