Skip to content

Commit c67bb38

Browse files
Gasoonjiafacebook-github-bot
authored andcommitted
integrate gen_oplist as a function (#1498)
Summary: This function integrates gen_oplist as a function for other files use. Reviewed By: lucylq Differential Revision: D52449487
1 parent 53b99f3 commit c67bb38

File tree

1 file changed

+67
-48
lines changed

1 file changed

+67
-48
lines changed

codegen/tools/gen_oplist.py

Lines changed: 67 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -189,6 +189,65 @@ def _dump_yaml(
189189
)
190190

191191

192+
def gen_oplist(
193+
model_file_path: Optional[str],
194+
ops_schema_yaml_path: Optional[str],
195+
root_ops: Optional[str],
196+
ops_dict: Optional[str],
197+
include_all_operators: Optional[bool],
198+
output_path: str,
199+
):
200+
assert (
201+
model_file_path
202+
or ops_schema_yaml_path
203+
or root_ops
204+
or ops_dict
205+
or include_all_operators
206+
), "Need to provide either model_file_path or ops_schema_yaml_path or root_ops or ops_dict or include_all_operators."
207+
op_set = set()
208+
source_name = None
209+
et_kernel_metadata = {}
210+
if root_ops:
211+
op_set.update(set(filter(lambda x: len(x) > 0, root_ops.split(","))))
212+
et_kernel_metadata = merge_et_kernel_metadata(
213+
et_kernel_metadata, {op: ["default"] for op in op_set}
214+
)
215+
if ops_dict:
216+
ops_and_metadata = json.loads(ops_dict)
217+
for op, metadata in ops_and_metadata.items():
218+
op_set.update({op})
219+
op_metadata = metadata if len(metadata) > 0 else ["default"]
220+
et_kernel_metadata = merge_et_kernel_metadata(
221+
et_kernel_metadata, {op: op_metadata}
222+
)
223+
if model_file_path:
224+
assert os.path.isfile(
225+
model_file_path
226+
), "The value for --model_file_path needs to be a valid file."
227+
op_set.update(_get_operators(model_file_path))
228+
source_name = model_file_path
229+
et_kernel_metadata = merge_et_kernel_metadata(
230+
et_kernel_metadata, _get_kernel_metadata_for_model(model_file_path)
231+
)
232+
if ops_schema_yaml_path:
233+
assert os.path.isfile(
234+
ops_schema_yaml_path
235+
), "The value for --ops_schema_yaml_path needs to be a valid file."
236+
et_kernel_metadata = merge_et_kernel_metadata(
237+
et_kernel_metadata,
238+
_get_et_kernel_metadata_from_ops_yaml(ops_schema_yaml_path),
239+
)
240+
op_set.update(et_kernel_metadata.keys())
241+
source_name = ops_schema_yaml_path
242+
_dump_yaml(
243+
sorted(op_set),
244+
output_path,
245+
os.path.basename(source_name) if source_name else None,
246+
et_kernel_metadata,
247+
include_all_operators,
248+
)
249+
250+
192251
def main(args: List[Any]) -> None:
193252
"""This binary generates selected_operators.yaml which will be consumed by caffe2/torchgen/gen.py.
194253
It reads the model file, deserialize it and dumps all the operators into selected_operators.yaml so
@@ -233,54 +292,14 @@ def main(args: List[Any]) -> None:
233292
required=False,
234293
)
235294
options = parser.parse_args(args)
236-
assert (
237-
options.model_file_path
238-
or options.ops_schema_yaml_path
239-
or options.root_ops
240-
or options.ops_dict
241-
or options.include_all_operators
242-
), "Need to provide either model_file_path or ops_schema_yaml_path or root_ops or include_all_operators."
243-
op_set = set()
244-
source_name = None
245-
et_kernel_metadata = {}
246-
if options.root_ops:
247-
op_set.update(set(filter(lambda x: len(x) > 0, options.root_ops.split(","))))
248-
et_kernel_metadata = merge_et_kernel_metadata(
249-
et_kernel_metadata, {op: ["default"] for op in op_set}
250-
)
251-
if options.ops_dict:
252-
ops_and_metadata = json.loads(options.ops_dict)
253-
for op, metadata in ops_and_metadata.items():
254-
op_set.update({op})
255-
op_metadata = metadata if len(metadata) > 0 else ["default"]
256-
et_kernel_metadata = merge_et_kernel_metadata(
257-
et_kernel_metadata, {op: op_metadata}
258-
)
259-
if options.model_file_path:
260-
assert os.path.isfile(
261-
options.model_file_path
262-
), "The value for --model_file_path needs to be a valid file."
263-
op_set.update(_get_operators(options.model_file_path))
264-
source_name = options.model_file_path
265-
et_kernel_metadata = merge_et_kernel_metadata(
266-
et_kernel_metadata, _get_kernel_metadata_for_model(options.model_file_path)
267-
)
268-
if options.ops_schema_yaml_path:
269-
assert os.path.isfile(
270-
options.ops_schema_yaml_path
271-
), "The value for --ops_schema_yaml_path needs to be a valid file."
272-
et_kernel_metadata = merge_et_kernel_metadata(
273-
et_kernel_metadata,
274-
_get_et_kernel_metadata_from_ops_yaml(options.ops_schema_yaml_path),
275-
)
276-
op_set.update(et_kernel_metadata.keys())
277-
source_name = options.ops_schema_yaml_path
278-
_dump_yaml(
279-
sorted(op_set),
280-
options.output_path,
281-
os.path.basename(source_name) if source_name else None,
282-
et_kernel_metadata,
283-
options.include_all_operators,
295+
296+
gen_oplist(
297+
model_file_path=options.model_file_path,
298+
ops_schema_yaml_path=options.ops_schema_yaml_path,
299+
root_ops=options.root_ops,
300+
ops_dict=options.ops_dict,
301+
include_all_operators=options.include_all_operators,
302+
output_path=options.output_path,
284303
)
285304

286305

0 commit comments

Comments
 (0)