Skip to content

Commit ae48f99

Browse files
committed
[ET-VK][ez] Clean up organization of supported_ops
As title. Group supported ops by features instead of op category. This will make it easier to mark that an op has increased its feature set. This also allows the registration code to be simplified a lot. Differential Revision: [D63913433](https://our.internmc.facebook.com/intern/diff/D63913433/) [ghstack-poisoned]
1 parent 5777ad3 commit ae48f99

File tree

1 file changed

+26
-65
lines changed

1 file changed

+26
-65
lines changed

backends/vulkan/partitioner/supported_ops.py

Lines changed: 26 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -47,17 +47,16 @@ def __contains__(self, op):
4747
operator.getitem,
4848
]
4949

50-
BINARY_OPS = [
50+
SUPPORTS_DYNAMIC_SHAPE = [
51+
# Binary broadcasting operators
5152
exir_ops.edge.aten.add.Tensor,
5253
exir_ops.edge.aten.sub.Tensor,
5354
exir_ops.edge.aten.minimum.default,
5455
exir_ops.edge.aten.mul.Tensor,
5556
exir_ops.edge.aten.div.Tensor,
5657
exir_ops.edge.aten.div.Tensor_mode,
5758
exir_ops.edge.aten.pow.Tensor_Tensor,
58-
]
59-
60-
UNARY_OPS = [
59+
# Unary elementwise operators
6160
exir_ops.edge.aten.abs.default,
6261
exir_ops.edge.aten.clamp.default,
6362
exir_ops.edge.aten.cos.default,
@@ -71,60 +70,48 @@ def __contains__(self, op):
7170
exir_ops.edge.aten.sin.default,
7271
exir_ops.edge.aten.sqrt.default,
7372
exir_ops.edge.aten.tanh.default,
74-
]
75-
76-
MATMUL_OPS = [
73+
# Matrix Multiplication Operators
7774
exir_ops.edge.aten.bmm.default,
7875
exir_ops.edge.aten.mm.default,
7976
exir_ops.edge.aten.addmm.default,
8077
exir_ops.edge.aten.linear.default,
81-
]
82-
83-
POOLING_OPS = [
78+
# Reduction operators
79+
exir_ops.edge.aten._log_softmax.default,
80+
exir_ops.edge.aten._softmax.default,
81+
# 2D Pooling ops
8482
exir_ops.edge.aten.avg_pool2d.default,
8583
exir_ops.edge.aten.max_pool2d_with_indices.default,
86-
]
87-
88-
CONVOLUTION_OPS = [
84+
# Convolution ops
8985
exir_ops.edge.aten.convolution.default,
9086
exir_ops.edge.et_vk.conv_with_clamp.default,
87+
# Custom ops
88+
"llama::sdpa_with_kv_cache",
9189
]
9290

93-
REDUCTION_OPS = [
91+
NO_DYNAMIC_SHAPE = [
92+
# Reduction operators
9493
exir_ops.edge.aten.mean.dim,
9594
exir_ops.edge.aten.sum.dim_IntList,
96-
exir_ops.edge.aten._log_softmax.default,
97-
exir_ops.edge.aten._softmax.default,
98-
]
99-
100-
NORMALIZATION_OPS = [
95+
# Normalization operators
10196
exir_ops.edge.aten._native_batch_norm_legit_no_training.default,
10297
exir_ops.edge.aten.native_layer_norm.default,
103-
]
104-
105-
SHAPE_MANIPULATION_OPS = [
98+
# Shape Manipulation operators
10699
exir_ops.edge.aten.squeeze_copy.dims,
107100
exir_ops.edge.aten.unsqueeze_copy.default,
108101
exir_ops.edge.aten.view_copy.default,
109102
exir_ops.edge.aten.permute_copy.default,
110103
exir_ops.edge.aten.t_copy.default,
111-
]
112-
113-
INDEXING_OPS = [
104+
# Indexing and lookup operators
114105
exir_ops.edge.aten.embedding.default,
115106
exir_ops.edge.aten.index_select.default,
116107
exir_ops.edge.aten.select_copy.int,
117108
exir_ops.edge.aten.slice_copy.Tensor,
118-
]
119-
120-
ORCHESTRATION_OPS = [
109+
# Tensor combination operators
121110
exir_ops.edge.aten.cat.default,
122111
exir_ops.edge.aten.split_with_sizes_copy.default,
123112
exir_ops.edge.aten.split.Tensor,
124113
exir_ops.edge.aten.repeat.default,
125-
]
126-
127-
CREATION_OPS = [
114+
# Tensor creation operators
128115
exir_ops.edge.aten.arange.start_step,
129116
exir_ops.edge.aten.clone.default,
130117
exir_ops.edge.aten.constant_pad_nd.default,
@@ -139,46 +126,20 @@ def __contains__(self, op):
139126
]
140127

141128

142-
def register_prim_ops(ops: OpList):
143-
for op in PRIM_OPS:
144-
ops[op].supports_texture = True
145-
ops[op].supports_buffer = True
146-
ops[op].supports_dynamic_shape = True
129+
def enumerate_supported_ops():
130+
ops = OpList()
147131

132+
# Register in order of least to most capabilities
148133

149-
def register_no_dynamic_shape_ops(ops: OpList):
150-
for op in [
151-
*REDUCTION_OPS,
152-
*NORMALIZATION_OPS,
153-
*SHAPE_MANIPULATION_OPS,
154-
*INDEXING_OPS,
155-
*ORCHESTRATION_OPS,
156-
*CREATION_OPS,
157-
]:
134+
for op in NO_DYNAMIC_SHAPE:
158135
ops[op].supports_dynamic_shape = False
159136

160-
161-
def register_dynamic_shape_ops(ops: OpList):
162-
for op in [
163-
*BINARY_OPS,
164-
*UNARY_OPS,
165-
*MATMUL_OPS,
166-
*POOLING_OPS,
167-
*CONVOLUTION_OPS,
168-
]:
137+
for op in SUPPORTS_DYNAMIC_SHAPE:
169138
ops[op].supports_dynamic_shape = True
170139

171-
172-
def register_custom_ops(ops: OpList):
173-
for op in CUSTOM_OPS:
174-
ops[op].supports_dynamic_shape = True
140+
for op in PRIM_OPS:
175141
ops[op].supports_texture = True
142+
ops[op].supports_buffer = True
143+
ops[op].supports_dynamic_shape = True
176144

177-
178-
def enumerate_supported_ops():
179-
ops = OpList()
180-
register_prim_ops(ops)
181-
register_no_dynamic_shape_ops(ops)
182-
register_dynamic_shape_ops(ops)
183-
register_custom_ops(ops)
184145
return ops

0 commit comments

Comments
 (0)