Skip to content

Commit fda6fe4

Browse files
committed
Update on "[ET-VK][ez] Clean up organization of supported_ops"
As title. Group supported ops by the features they support rather than by op category. This makes it easier to mark when an op's feature set expands, and it also allows the registration code to be simplified considerably. Differential Revision: [D63913433](https://our.internmc.facebook.com/intern/diff/D63913433/) [ghstack-poisoned]
2 parents ae48f99 + 22390bf commit fda6fe4

File tree

3 files changed

+2
-16
lines changed

3 files changed

+2
-16
lines changed

backends/vulkan/partitioner/supported_ops.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -84,8 +84,6 @@ def __contains__(self, op):
8484
# Convolution ops
8585
exir_ops.edge.aten.convolution.default,
8686
exir_ops.edge.et_vk.conv_with_clamp.default,
87-
# Custom ops
88-
"llama::sdpa_with_kv_cache",
8987
]
9088

9189
NO_DYNAMIC_SHAPE = [

backends/vulkan/partitioner/vulkan_partitioner.py

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -119,19 +119,13 @@ def is_node_supported(
119119
def _is_node_supported(
120120
self, submodules: Mapping[str, torch.nn.Module], node: torch.fx.Node
121121
) -> bool:
122-
target = node.target
123-
if node.target == torch.ops.higher_order.auto_functionalized:
124-
first_arg = node.args[0]
125-
assert isinstance(first_arg, torch._ops.OpOverload)
126-
target = first_arg.name()
127-
128122
if self.is_linear_permute(node):
129123
return True
130124

131-
if target not in VulkanSupportedOperators._ops:
125+
if node.target not in VulkanSupportedOperators._ops:
132126
return False
133127

134-
features = VulkanSupportedOperators._ops[target]
128+
features = VulkanSupportedOperators._ops[node.target]
135129

136130
if self.require_dynamic_shapes and not features.supports_dynamic_shape:
137131
return False

backends/vulkan/vulkan_preprocess.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -36,10 +36,6 @@
3636

3737
from executorch.exir.program._program import _copy_module
3838

39-
from torch.export._remove_auto_functionalized_pass import (
40-
unsafe_remove_auto_functionalized_pass,
41-
)
42-
4339
DEFAULT_DEBUG_HANDLE = 65535
4440

4541

@@ -52,8 +48,6 @@ def preprocess( # noqa: C901
5248
program: ExportedProgram,
5349
module_compile_spec: List[CompileSpec],
5450
) -> PreprocessResult:
55-
program = unsafe_remove_auto_functionalized_pass(program)
56-
5751
passes = [
5852
RemoveCloneOpsTransform(),
5953
AddmmToLinearTransform(),

0 commit comments

Comments
 (0)