@@ -47,17 +47,16 @@ def __contains__(self, op):
47
47
operator .getitem ,
48
48
]
49
49
50
- BINARY_OPS = [
50
+ SUPPORTS_DYNAMIC_SHAPE = [
51
+ # Binary broadcasting operators
51
52
exir_ops .edge .aten .add .Tensor ,
52
53
exir_ops .edge .aten .sub .Tensor ,
53
54
exir_ops .edge .aten .minimum .default ,
54
55
exir_ops .edge .aten .mul .Tensor ,
55
56
exir_ops .edge .aten .div .Tensor ,
56
57
exir_ops .edge .aten .div .Tensor_mode ,
57
58
exir_ops .edge .aten .pow .Tensor_Tensor ,
58
- ]
59
-
60
- UNARY_OPS = [
59
+ # Unary elementwise operators
61
60
exir_ops .edge .aten .abs .default ,
62
61
exir_ops .edge .aten .clamp .default ,
63
62
exir_ops .edge .aten .cos .default ,
@@ -71,60 +70,48 @@ def __contains__(self, op):
71
70
exir_ops .edge .aten .sin .default ,
72
71
exir_ops .edge .aten .sqrt .default ,
73
72
exir_ops .edge .aten .tanh .default ,
74
- ]
75
-
76
- MATMUL_OPS = [
73
+ # Matrix Multiplication Operators
77
74
exir_ops .edge .aten .bmm .default ,
78
75
exir_ops .edge .aten .mm .default ,
79
76
exir_ops .edge .aten .addmm .default ,
80
77
exir_ops .edge .aten .linear .default ,
81
- ]
82
-
83
- POOLING_OPS = [
78
+ # Reduction operators
79
+ exir_ops .edge .aten ._log_softmax .default ,
80
+ exir_ops .edge .aten ._softmax .default ,
81
+ # 2D Pooling ops
84
82
exir_ops .edge .aten .avg_pool2d .default ,
85
83
exir_ops .edge .aten .max_pool2d_with_indices .default ,
86
- ]
87
-
88
- CONVOLUTION_OPS = [
84
+ # Convolution ops
89
85
exir_ops .edge .aten .convolution .default ,
90
86
exir_ops .edge .et_vk .conv_with_clamp .default ,
87
+ # Custom ops
88
+ "llama::sdpa_with_kv_cache" ,
91
89
]
92
90
93
- REDUCTION_OPS = [
91
+ NO_DYNAMIC_SHAPE = [
92
+ # Reduction operators
94
93
exir_ops .edge .aten .mean .dim ,
95
94
exir_ops .edge .aten .sum .dim_IntList ,
96
- exir_ops .edge .aten ._log_softmax .default ,
97
- exir_ops .edge .aten ._softmax .default ,
98
- ]
99
-
100
- NORMALIZATION_OPS = [
95
+ # Normalization operators
101
96
exir_ops .edge .aten ._native_batch_norm_legit_no_training .default ,
102
97
exir_ops .edge .aten .native_layer_norm .default ,
103
- ]
104
-
105
- SHAPE_MANIPULATION_OPS = [
98
+ # Shape Manipulation operators
106
99
exir_ops .edge .aten .squeeze_copy .dims ,
107
100
exir_ops .edge .aten .unsqueeze_copy .default ,
108
101
exir_ops .edge .aten .view_copy .default ,
109
102
exir_ops .edge .aten .permute_copy .default ,
110
103
exir_ops .edge .aten .t_copy .default ,
111
- ]
112
-
113
- INDEXING_OPS = [
104
+ # Indexing and lookup operators
114
105
exir_ops .edge .aten .embedding .default ,
115
106
exir_ops .edge .aten .index_select .default ,
116
107
exir_ops .edge .aten .select_copy .int ,
117
108
exir_ops .edge .aten .slice_copy .Tensor ,
118
- ]
119
-
120
- ORCHESTRATION_OPS = [
109
+ # Tensor combination operators
121
110
exir_ops .edge .aten .cat .default ,
122
111
exir_ops .edge .aten .split_with_sizes_copy .default ,
123
112
exir_ops .edge .aten .split .Tensor ,
124
113
exir_ops .edge .aten .repeat .default ,
125
- ]
126
-
127
- CREATION_OPS = [
114
+ # Tensor creation operators
128
115
exir_ops .edge .aten .arange .start_step ,
129
116
exir_ops .edge .aten .clone .default ,
130
117
exir_ops .edge .aten .constant_pad_nd .default ,
@@ -139,46 +126,20 @@ def __contains__(self, op):
139
126
]
140
127
141
128
142
- def register_prim_ops (ops : OpList ):
143
- for op in PRIM_OPS :
144
- ops [op ].supports_texture = True
145
- ops [op ].supports_buffer = True
146
- ops [op ].supports_dynamic_shape = True
129
+ def enumerate_supported_ops ():
130
+ ops = OpList ()
147
131
132
+ # Register in order of least to most capabilities
148
133
149
- def register_no_dynamic_shape_ops (ops : OpList ):
150
- for op in [
151
- * REDUCTION_OPS ,
152
- * NORMALIZATION_OPS ,
153
- * SHAPE_MANIPULATION_OPS ,
154
- * INDEXING_OPS ,
155
- * ORCHESTRATION_OPS ,
156
- * CREATION_OPS ,
157
- ]:
134
+ for op in NO_DYNAMIC_SHAPE :
158
135
ops [op ].supports_dynamic_shape = False
159
136
160
-
161
- def register_dynamic_shape_ops (ops : OpList ):
162
- for op in [
163
- * BINARY_OPS ,
164
- * UNARY_OPS ,
165
- * MATMUL_OPS ,
166
- * POOLING_OPS ,
167
- * CONVOLUTION_OPS ,
168
- ]:
137
+ for op in SUPPORTS_DYNAMIC_SHAPE :
169
138
ops [op ].supports_dynamic_shape = True
170
139
171
-
172
- def register_custom_ops (ops : OpList ):
173
- for op in CUSTOM_OPS :
174
- ops [op ].supports_dynamic_shape = True
140
+ for op in PRIM_OPS :
175
141
ops [op ].supports_texture = True
142
+ ops [op ].supports_buffer = True
143
+ ops [op ].supports_dynamic_shape = True
176
144
177
-
178
- def enumerate_supported_ops ():
179
- ops = OpList ()
180
- register_prim_ops (ops )
181
- register_no_dynamic_shape_ops (ops )
182
- register_dynamic_shape_ops (ops )
183
- register_custom_ops (ops )
184
145
return ops
0 commit comments