Skip to content

Commit 94289ad

Browse files
SS-JIAfacebook-github-bot
authored andcommitted
Clean up organization of supported_ops (#5885)
Summary: Pull Request resolved: #5885 As title. Group supported ops by features instead of op category. This will make it easier to mark that an op has increased its feature set. This also allows the registration code to be simplified a lot. ghstack-source-id: 246400773 exported-using-ghexport Reviewed By: jorgep31415 Differential Revision: D63913433 fbshipit-source-id: 5ff1919d2f1201363d87ad1ecbecb5dbf574ec42
1 parent 84498b2 commit 94289ad

File tree

1 file changed

+25
-59
lines changed

1 file changed

+25
-59
lines changed

backends/vulkan/partitioner/supported_ops.py

Lines changed: 25 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -47,17 +47,16 @@ def __contains__(self, op):
4747
operator.getitem,
4848
]
4949

50-
BINARY_OPS = [
50+
SUPPORTS_DYNAMIC_SHAPE = [
51+
# Binary broadcasting
5152
exir_ops.edge.aten.add.Tensor,
5253
exir_ops.edge.aten.sub.Tensor,
5354
exir_ops.edge.aten.minimum.default,
5455
exir_ops.edge.aten.mul.Tensor,
5556
exir_ops.edge.aten.div.Tensor,
5657
exir_ops.edge.aten.div.Tensor_mode,
5758
exir_ops.edge.aten.pow.Tensor_Tensor,
58-
]
59-
60-
UNARY_OPS = [
59+
# Unary elementwise
6160
exir_ops.edge.aten.abs.default,
6261
exir_ops.edge.aten.clamp.default,
6362
exir_ops.edge.aten.cos.default,
@@ -71,60 +70,46 @@ def __contains__(self, op):
7170
exir_ops.edge.aten.sin.default,
7271
exir_ops.edge.aten.sqrt.default,
7372
exir_ops.edge.aten.tanh.default,
74-
]
75-
76-
MATMUL_OPS = [
73+
# Matrix Multiplication
7774
exir_ops.edge.aten.bmm.default,
7875
exir_ops.edge.aten.mm.default,
7976
exir_ops.edge.aten.addmm.default,
8077
exir_ops.edge.aten.linear.default,
81-
]
82-
83-
POOLING_OPS = [
78+
# Reduction
79+
exir_ops.edge.aten._log_softmax.default,
80+
exir_ops.edge.aten._softmax.default,
81+
# 2D Pooling
8482
exir_ops.edge.aten.avg_pool2d.default,
8583
exir_ops.edge.aten.max_pool2d_with_indices.default,
86-
]
87-
88-
CONVOLUTION_OPS = [
84+
# Convolution
8985
exir_ops.edge.aten.convolution.default,
9086
exir_ops.edge.et_vk.conv_with_clamp.default,
9187
]
9288

93-
REDUCTION_OPS = [
89+
NO_DYNAMIC_SHAPE = [
90+
# Reduction
9491
exir_ops.edge.aten.mean.dim,
9592
exir_ops.edge.aten.sum.dim_IntList,
96-
exir_ops.edge.aten._log_softmax.default,
97-
exir_ops.edge.aten._softmax.default,
98-
]
99-
100-
NORMALIZATION_OPS = [
93+
# Normalization
10194
exir_ops.edge.aten._native_batch_norm_legit_no_training.default,
10295
exir_ops.edge.aten.native_layer_norm.default,
103-
]
104-
105-
SHAPE_MANIPULATION_OPS = [
96+
# Shape Manipulation
10697
exir_ops.edge.aten.squeeze_copy.dims,
10798
exir_ops.edge.aten.unsqueeze_copy.default,
10899
exir_ops.edge.aten.view_copy.default,
109100
exir_ops.edge.aten.permute_copy.default,
110101
exir_ops.edge.aten.t_copy.default,
111-
]
112-
113-
INDEXING_OPS = [
102+
# Indexing and lookup
114103
exir_ops.edge.aten.embedding.default,
115104
exir_ops.edge.aten.index_select.default,
116105
exir_ops.edge.aten.select_copy.int,
117106
exir_ops.edge.aten.slice_copy.Tensor,
118-
]
119-
120-
ORCHESTRATION_OPS = [
107+
# Tensor combination
121108
exir_ops.edge.aten.cat.default,
122109
exir_ops.edge.aten.split_with_sizes_copy.default,
123110
exir_ops.edge.aten.split.Tensor,
124111
exir_ops.edge.aten.repeat.default,
125-
]
126-
127-
CREATION_OPS = [
112+
# Tensor creation
128113
exir_ops.edge.aten.arange.start_step,
129114
exir_ops.edge.aten.clone.default,
130115
exir_ops.edge.aten.constant_pad_nd.default,
@@ -139,39 +124,20 @@ def __contains__(self, op):
139124
]
140125

141126

142-
def register_prim_ops(ops: OpList):
143-
for op in PRIM_OPS:
144-
ops[op].supports_texture = True
145-
ops[op].supports_buffer = True
146-
ops[op].supports_dynamic_shape = True
127+
def enumerate_supported_ops():
128+
ops = OpList()
147129

130+
# Register in order of least to most capabilities
148131

149-
def register_no_dynamic_shape_ops(ops: OpList):
150-
for op in [
151-
*REDUCTION_OPS,
152-
*NORMALIZATION_OPS,
153-
*SHAPE_MANIPULATION_OPS,
154-
*INDEXING_OPS,
155-
*ORCHESTRATION_OPS,
156-
*CREATION_OPS,
157-
]:
132+
for op in NO_DYNAMIC_SHAPE:
158133
ops[op].supports_dynamic_shape = False
159134

160-
161-
def register_dynamic_shape_ops(ops: OpList):
162-
for op in [
163-
*BINARY_OPS,
164-
*UNARY_OPS,
165-
*MATMUL_OPS,
166-
*POOLING_OPS,
167-
*CONVOLUTION_OPS,
168-
]:
135+
for op in SUPPORTS_DYNAMIC_SHAPE:
169136
ops[op].supports_dynamic_shape = True
170137

138+
for op in PRIM_OPS:
139+
ops[op].supports_texture = True
140+
ops[op].supports_buffer = True
141+
ops[op].supports_dynamic_shape = True
171142

172-
def enumerate_supported_ops():
173-
ops = OpList()
174-
register_prim_ops(ops)
175-
register_no_dynamic_shape_ops(ops)
176-
register_dynamic_shape_ops(ops)
177143
return ops

0 commit comments

Comments
 (0)