@@ -189,6 +189,65 @@ def _dump_yaml(
189
189
)
190
190
191
191
192
+ def gen_oplist (
193
+ model_file_path : Optional [str ],
194
+ ops_schema_yaml_path : Optional [str ],
195
+ root_ops : Optional [str ],
196
+ ops_dict : Optional [str ],
197
+ include_all_operators : Optional [bool ],
198
+ output_path : str ,
199
+ ):
200
+ assert (
201
+ model_file_path
202
+ or ops_schema_yaml_path
203
+ or root_ops
204
+ or ops_dict
205
+ or include_all_operators
206
+ ), "Need to provide either model_file_path or ops_schema_yaml_path or root_ops or ops_dict or include_all_operators."
207
+ op_set = set ()
208
+ source_name = None
209
+ et_kernel_metadata = {}
210
+ if root_ops :
211
+ op_set .update (set (filter (lambda x : len (x ) > 0 , root_ops .split ("," ))))
212
+ et_kernel_metadata = merge_et_kernel_metadata (
213
+ et_kernel_metadata , {op : ["default" ] for op in op_set }
214
+ )
215
+ if ops_dict :
216
+ ops_and_metadata = json .loads (ops_dict )
217
+ for op , metadata in ops_and_metadata .items ():
218
+ op_set .update ({op })
219
+ op_metadata = metadata if len (metadata ) > 0 else ["default" ]
220
+ et_kernel_metadata = merge_et_kernel_metadata (
221
+ et_kernel_metadata , {op : op_metadata }
222
+ )
223
+ if model_file_path :
224
+ assert os .path .isfile (
225
+ model_file_path
226
+ ), "The value for --model_file_path needs to be a valid file."
227
+ op_set .update (_get_operators (model_file_path ))
228
+ source_name = model_file_path
229
+ et_kernel_metadata = merge_et_kernel_metadata (
230
+ et_kernel_metadata , _get_kernel_metadata_for_model (model_file_path )
231
+ )
232
+ if ops_schema_yaml_path :
233
+ assert os .path .isfile (
234
+ ops_schema_yaml_path
235
+ ), "The value for --ops_schema_yaml_path needs to be a valid file."
236
+ et_kernel_metadata = merge_et_kernel_metadata (
237
+ et_kernel_metadata ,
238
+ _get_et_kernel_metadata_from_ops_yaml (ops_schema_yaml_path ),
239
+ )
240
+ op_set .update (et_kernel_metadata .keys ())
241
+ source_name = ops_schema_yaml_path
242
+ _dump_yaml (
243
+ sorted (op_set ),
244
+ output_path ,
245
+ os .path .basename (source_name ) if source_name else None ,
246
+ et_kernel_metadata ,
247
+ include_all_operators ,
248
+ )
249
+
250
+
192
251
def main (args : List [Any ]) -> None :
193
252
"""This binary generates selected_operators.yaml which will be consumed by caffe2/torchgen/gen.py.
194
253
It reads the model file, deserialize it and dumps all the operators into selected_operators.yaml so
@@ -233,54 +292,14 @@ def main(args: List[Any]) -> None:
233
292
required = False ,
234
293
)
235
294
options = parser .parse_args (args )
236
- assert (
237
- options .model_file_path
238
- or options .ops_schema_yaml_path
239
- or options .root_ops
240
- or options .ops_dict
241
- or options .include_all_operators
242
- ), "Need to provide either model_file_path or ops_schema_yaml_path or root_ops or include_all_operators."
243
- op_set = set ()
244
- source_name = None
245
- et_kernel_metadata = {}
246
- if options .root_ops :
247
- op_set .update (set (filter (lambda x : len (x ) > 0 , options .root_ops .split ("," ))))
248
- et_kernel_metadata = merge_et_kernel_metadata (
249
- et_kernel_metadata , {op : ["default" ] for op in op_set }
250
- )
251
- if options .ops_dict :
252
- ops_and_metadata = json .loads (options .ops_dict )
253
- for op , metadata in ops_and_metadata .items ():
254
- op_set .update ({op })
255
- op_metadata = metadata if len (metadata ) > 0 else ["default" ]
256
- et_kernel_metadata = merge_et_kernel_metadata (
257
- et_kernel_metadata , {op : op_metadata }
258
- )
259
- if options .model_file_path :
260
- assert os .path .isfile (
261
- options .model_file_path
262
- ), "The value for --model_file_path needs to be a valid file."
263
- op_set .update (_get_operators (options .model_file_path ))
264
- source_name = options .model_file_path
265
- et_kernel_metadata = merge_et_kernel_metadata (
266
- et_kernel_metadata , _get_kernel_metadata_for_model (options .model_file_path )
267
- )
268
- if options .ops_schema_yaml_path :
269
- assert os .path .isfile (
270
- options .ops_schema_yaml_path
271
- ), "The value for --ops_schema_yaml_path needs to be a valid file."
272
- et_kernel_metadata = merge_et_kernel_metadata (
273
- et_kernel_metadata ,
274
- _get_et_kernel_metadata_from_ops_yaml (options .ops_schema_yaml_path ),
275
- )
276
- op_set .update (et_kernel_metadata .keys ())
277
- source_name = options .ops_schema_yaml_path
278
- _dump_yaml (
279
- sorted (op_set ),
280
- options .output_path ,
281
- os .path .basename (source_name ) if source_name else None ,
282
- et_kernel_metadata ,
283
- options .include_all_operators ,
295
+
296
+ gen_oplist (
297
+ model_file_path = options .model_file_path ,
298
+ ops_schema_yaml_path = options .ops_schema_yaml_path ,
299
+ root_ops = options .root_ops ,
300
+ ops_dict = options .ops_dict ,
301
+ include_all_operators = options .include_all_operators ,
302
+ output_path = options .output_path ,
284
303
)
285
304
286
305
0 commit comments