@@ -47,17 +47,16 @@ def __contains__(self, op):
47
47
operator .getitem ,
48
48
]
49
49
50
- BINARY_OPS = [
50
+ SUPPORTS_DYNAMIC_SHAPE = [
51
+ # Binary broadcasting
51
52
exir_ops .edge .aten .add .Tensor ,
52
53
exir_ops .edge .aten .sub .Tensor ,
53
54
exir_ops .edge .aten .minimum .default ,
54
55
exir_ops .edge .aten .mul .Tensor ,
55
56
exir_ops .edge .aten .div .Tensor ,
56
57
exir_ops .edge .aten .div .Tensor_mode ,
57
58
exir_ops .edge .aten .pow .Tensor_Tensor ,
58
- ]
59
-
60
- UNARY_OPS = [
59
+ # Unary elementwise
61
60
exir_ops .edge .aten .abs .default ,
62
61
exir_ops .edge .aten .clamp .default ,
63
62
exir_ops .edge .aten .cos .default ,
@@ -71,60 +70,46 @@ def __contains__(self, op):
71
70
exir_ops .edge .aten .sin .default ,
72
71
exir_ops .edge .aten .sqrt .default ,
73
72
exir_ops .edge .aten .tanh .default ,
74
- ]
75
-
76
- MATMUL_OPS = [
73
+ # Matrix Multiplication
77
74
exir_ops .edge .aten .bmm .default ,
78
75
exir_ops .edge .aten .mm .default ,
79
76
exir_ops .edge .aten .addmm .default ,
80
77
exir_ops .edge .aten .linear .default ,
81
- ]
82
-
83
- POOLING_OPS = [
78
+ # Reduction
79
+ exir_ops .edge .aten ._log_softmax .default ,
80
+ exir_ops .edge .aten ._softmax .default ,
81
+ # 2D Pooling
84
82
exir_ops .edge .aten .avg_pool2d .default ,
85
83
exir_ops .edge .aten .max_pool2d_with_indices .default ,
86
- ]
87
-
88
- CONVOLUTION_OPS = [
84
+ # Convolution
89
85
exir_ops .edge .aten .convolution .default ,
90
86
exir_ops .edge .et_vk .conv_with_clamp .default ,
91
87
]
92
88
93
- REDUCTION_OPS = [
89
+ NO_DYNAMIC_SHAPE = [
90
+ # Reduction
94
91
exir_ops .edge .aten .mean .dim ,
95
92
exir_ops .edge .aten .sum .dim_IntList ,
96
- exir_ops .edge .aten ._log_softmax .default ,
97
- exir_ops .edge .aten ._softmax .default ,
98
- ]
99
-
100
- NORMALIZATION_OPS = [
93
+ # Normalization
101
94
exir_ops .edge .aten ._native_batch_norm_legit_no_training .default ,
102
95
exir_ops .edge .aten .native_layer_norm .default ,
103
- ]
104
-
105
- SHAPE_MANIPULATION_OPS = [
96
+ # Shape Manipulation
106
97
exir_ops .edge .aten .squeeze_copy .dims ,
107
98
exir_ops .edge .aten .unsqueeze_copy .default ,
108
99
exir_ops .edge .aten .view_copy .default ,
109
100
exir_ops .edge .aten .permute_copy .default ,
110
101
exir_ops .edge .aten .t_copy .default ,
111
- ]
112
-
113
- INDEXING_OPS = [
102
+ # Indexing and lookup
114
103
exir_ops .edge .aten .embedding .default ,
115
104
exir_ops .edge .aten .index_select .default ,
116
105
exir_ops .edge .aten .select_copy .int ,
117
106
exir_ops .edge .aten .slice_copy .Tensor ,
118
- ]
119
-
120
- ORCHESTRATION_OPS = [
107
+ # Tensor combination
121
108
exir_ops .edge .aten .cat .default ,
122
109
exir_ops .edge .aten .split_with_sizes_copy .default ,
123
110
exir_ops .edge .aten .split .Tensor ,
124
111
exir_ops .edge .aten .repeat .default ,
125
- ]
126
-
127
- CREATION_OPS = [
112
+ # Tensor creation
128
113
exir_ops .edge .aten .arange .start_step ,
129
114
exir_ops .edge .aten .clone .default ,
130
115
exir_ops .edge .aten .constant_pad_nd .default ,
@@ -139,39 +124,20 @@ def __contains__(self, op):
139
124
]
140
125
141
126
142
- def register_prim_ops (ops : OpList ):
143
- for op in PRIM_OPS :
144
- ops [op ].supports_texture = True
145
- ops [op ].supports_buffer = True
146
- ops [op ].supports_dynamic_shape = True
127
+ def enumerate_supported_ops ():
128
+ ops = OpList ()
147
129
130
+ # Register in order of least to most capabilities
148
131
149
- def register_no_dynamic_shape_ops (ops : OpList ):
150
- for op in [
151
- * REDUCTION_OPS ,
152
- * NORMALIZATION_OPS ,
153
- * SHAPE_MANIPULATION_OPS ,
154
- * INDEXING_OPS ,
155
- * ORCHESTRATION_OPS ,
156
- * CREATION_OPS ,
157
- ]:
132
+ for op in NO_DYNAMIC_SHAPE :
158
133
ops [op ].supports_dynamic_shape = False
159
134
160
-
161
- def register_dynamic_shape_ops (ops : OpList ):
162
- for op in [
163
- * BINARY_OPS ,
164
- * UNARY_OPS ,
165
- * MATMUL_OPS ,
166
- * POOLING_OPS ,
167
- * CONVOLUTION_OPS ,
168
- ]:
135
+ for op in SUPPORTS_DYNAMIC_SHAPE :
169
136
ops [op ].supports_dynamic_shape = True
170
137
138
+ for op in PRIM_OPS :
139
+ ops [op ].supports_texture = True
140
+ ops [op ].supports_buffer = True
141
+ ops [op ].supports_dynamic_shape = True
171
142
172
- def enumerate_supported_ops ():
173
- ops = OpList ()
174
- register_prim_ops (ops )
175
- register_no_dynamic_shape_ops (ops )
176
- register_dynamic_shape_ops (ops )
177
143
return ops
0 commit comments