Skip to content

Commit 22390bf

Browse files
committed
Update base for Update on "[ET-VK][ez] Clean up organization of supported_ops"
As title. Group supported ops by features instead of op category. This will make it easier to mark that an op has increased its feature set. This also allows the registration code to be simplified a lot. Differential Revision: [D63913433](https://our.internmc.facebook.com/intern/diff/D63913433/) [ghstack-poisoned]
1 parent 5777ad3 commit 22390bf

File tree

3 files changed

+2
-21
lines changed

3 files changed

+2
-21
lines changed

backends/vulkan/partitioner/supported_ops.py

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -169,16 +169,9 @@ def register_dynamic_shape_ops(ops: OpList):
169169
ops[op].supports_dynamic_shape = True
170170

171171

172-
def register_custom_ops(ops: OpList):
173-
for op in CUSTOM_OPS:
174-
ops[op].supports_dynamic_shape = True
175-
ops[op].supports_texture = True
176-
177-
178172
def enumerate_supported_ops():
179173
ops = OpList()
180174
register_prim_ops(ops)
181175
register_no_dynamic_shape_ops(ops)
182176
register_dynamic_shape_ops(ops)
183-
register_custom_ops(ops)
184177
return ops

backends/vulkan/partitioner/vulkan_partitioner.py

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -119,19 +119,13 @@ def is_node_supported(
119119
def _is_node_supported(
120120
self, submodules: Mapping[str, torch.nn.Module], node: torch.fx.Node
121121
) -> bool:
122-
target = node.target
123-
if node.target == torch.ops.higher_order.auto_functionalized:
124-
first_arg = node.args[0]
125-
assert isinstance(first_arg, torch._ops.OpOverload)
126-
target = first_arg.name()
127-
128122
if self.is_linear_permute(node):
129123
return True
130124

131-
if target not in VulkanSupportedOperators._ops:
125+
if node.target not in VulkanSupportedOperators._ops:
132126
return False
133127

134-
features = VulkanSupportedOperators._ops[target]
128+
features = VulkanSupportedOperators._ops[node.target]
135129

136130
if self.require_dynamic_shapes and not features.supports_dynamic_shape:
137131
return False

backends/vulkan/vulkan_preprocess.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -36,10 +36,6 @@
3636

3737
from executorch.exir.program._program import _copy_module
3838

39-
from torch.export._remove_auto_functionalized_pass import (
40-
unsafe_remove_auto_functionalized_pass,
41-
)
42-
4339
DEFAULT_DEBUG_HANDLE = 65535
4440

4541

@@ -52,8 +48,6 @@ def preprocess( # noqa: C901
5248
program: ExportedProgram,
5349
module_compile_spec: List[CompileSpec],
5450
) -> PreprocessResult:
55-
program = unsafe_remove_auto_functionalized_pass(program)
56-
5751
passes = [
5852
RemoveCloneOpsTransform(),
5953
AddmmToLinearTransform(),

0 commit comments

Comments
 (0)