@@ -189,6 +189,67 @@ def _dump_yaml(
189
189
)
190
190
191
191
192
+ def gen_oplist (
193
+ output_path : str ,
194
+ model_file_path : Optional [str ] = None ,
195
+ ops_schema_yaml_path : Optional [str ] = None ,
196
+ root_ops : Optional [str ] = None ,
197
+ ops_dict : Optional [str ] = None ,
198
+ include_all_operators : bool = False ,
199
+ ):
200
+ assert (
201
+ model_file_path
202
+ or ops_schema_yaml_path
203
+ or root_ops
204
+ or ops_dict
205
+ or include_all_operators
206
+ ), "Need to provide either model_file_path or ops_schema_yaml_path or root_ops or ops_dict or include_all_operators."
207
+
208
+ assert output_path , "Need to provide output_path for dumped yaml file."
209
+ op_set = set ()
210
+ source_name = None
211
+ et_kernel_metadata = {}
212
+ if root_ops :
213
+ op_set .update (set (filter (lambda x : len (x ) > 0 , root_ops .split ("," ))))
214
+ et_kernel_metadata = merge_et_kernel_metadata (
215
+ et_kernel_metadata , {op : ["default" ] for op in op_set }
216
+ )
217
+ if ops_dict :
218
+ ops_and_metadata = json .loads (ops_dict )
219
+ for op , metadata in ops_and_metadata .items ():
220
+ op_set .update ({op })
221
+ op_metadata = metadata if len (metadata ) > 0 else ["default" ]
222
+ et_kernel_metadata = merge_et_kernel_metadata (
223
+ et_kernel_metadata , {op : op_metadata }
224
+ )
225
+ if model_file_path :
226
+ assert os .path .isfile (
227
+ model_file_path
228
+ ), "The value for --model_file_path needs to be a valid file."
229
+ op_set .update (_get_operators (model_file_path ))
230
+ source_name = model_file_path
231
+ et_kernel_metadata = merge_et_kernel_metadata (
232
+ et_kernel_metadata , _get_kernel_metadata_for_model (model_file_path )
233
+ )
234
+ if ops_schema_yaml_path :
235
+ assert os .path .isfile (
236
+ ops_schema_yaml_path
237
+ ), "The value for --ops_schema_yaml_path needs to be a valid file."
238
+ et_kernel_metadata = merge_et_kernel_metadata (
239
+ et_kernel_metadata ,
240
+ _get_et_kernel_metadata_from_ops_yaml (ops_schema_yaml_path ),
241
+ )
242
+ op_set .update (et_kernel_metadata .keys ())
243
+ source_name = ops_schema_yaml_path
244
+ _dump_yaml (
245
+ sorted (op_set ),
246
+ output_path ,
247
+ os .path .basename (source_name ) if source_name else None ,
248
+ et_kernel_metadata ,
249
+ include_all_operators ,
250
+ )
251
+
252
+
192
253
def main (args : List [Any ]) -> None :
193
254
"""This binary generates selected_operators.yaml which will be consumed by caffe2/torchgen/gen.py.
194
255
It reads the model file, deserialize it and dumps all the operators into selected_operators.yaml so
@@ -233,54 +294,14 @@ def main(args: List[Any]) -> None:
233
294
required = False ,
234
295
)
235
296
options = parser .parse_args (args )
236
- assert (
237
- options .model_file_path
238
- or options .ops_schema_yaml_path
239
- or options .root_ops
240
- or options .ops_dict
241
- or options .include_all_operators
242
- ), "Need to provide either model_file_path or ops_schema_yaml_path or root_ops or include_all_operators."
243
- op_set = set ()
244
- source_name = None
245
- et_kernel_metadata = {}
246
- if options .root_ops :
247
- op_set .update (set (filter (lambda x : len (x ) > 0 , options .root_ops .split ("," ))))
248
- et_kernel_metadata = merge_et_kernel_metadata (
249
- et_kernel_metadata , {op : ["default" ] for op in op_set }
250
- )
251
- if options .ops_dict :
252
- ops_and_metadata = json .loads (options .ops_dict )
253
- for op , metadata in ops_and_metadata .items ():
254
- op_set .update ({op })
255
- op_metadata = metadata if len (metadata ) > 0 else ["default" ]
256
- et_kernel_metadata = merge_et_kernel_metadata (
257
- et_kernel_metadata , {op : op_metadata }
258
- )
259
- if options .model_file_path :
260
- assert os .path .isfile (
261
- options .model_file_path
262
- ), "The value for --model_file_path needs to be a valid file."
263
- op_set .update (_get_operators (options .model_file_path ))
264
- source_name = options .model_file_path
265
- et_kernel_metadata = merge_et_kernel_metadata (
266
- et_kernel_metadata , _get_kernel_metadata_for_model (options .model_file_path )
267
- )
268
- if options .ops_schema_yaml_path :
269
- assert os .path .isfile (
270
- options .ops_schema_yaml_path
271
- ), "The value for --ops_schema_yaml_path needs to be a valid file."
272
- et_kernel_metadata = merge_et_kernel_metadata (
273
- et_kernel_metadata ,
274
- _get_et_kernel_metadata_from_ops_yaml (options .ops_schema_yaml_path ),
275
- )
276
- op_set .update (et_kernel_metadata .keys ())
277
- source_name = options .ops_schema_yaml_path
278
- _dump_yaml (
279
- sorted (op_set ),
280
- options .output_path ,
281
- os .path .basename (source_name ) if source_name else None ,
282
- et_kernel_metadata ,
283
- options .include_all_operators ,
297
+
298
+ gen_oplist (
299
+ output_path = options .output_path ,
300
+ model_file_path = options .model_file_path ,
301
+ ops_schema_yaml_path = options .ops_schema_yaml_path ,
302
+ root_ops = options .root_ops ,
303
+ ops_dict = options .ops_dict ,
304
+ include_all_operators = options .include_all_operators ,
284
305
)
285
306
286
307
0 commit comments