Skip to content

Commit 32fee14

Browse files
Gasoonjiafacebook-github-bot
authored andcommitted
integrate gen_oplist as a function (#1498)
Summary: Pull Request resolved: #1498 This function integrates gen_oplist as a function for other files use. Reviewed By: lucylq Differential Revision: D52449487 fbshipit-source-id: b7210c87c0b311e4e08b9de80135c0947d55208a
1 parent 67e4483 commit 32fee14

File tree

1 file changed

+69
-48
lines changed

1 file changed

+69
-48
lines changed

codegen/tools/gen_oplist.py

Lines changed: 69 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -189,6 +189,67 @@ def _dump_yaml(
189189
)
190190

191191

192+
def gen_oplist(
193+
output_path: str,
194+
model_file_path: Optional[str] = None,
195+
ops_schema_yaml_path: Optional[str] = None,
196+
root_ops: Optional[str] = None,
197+
ops_dict: Optional[str] = None,
198+
include_all_operators: bool = False,
199+
):
200+
assert (
201+
model_file_path
202+
or ops_schema_yaml_path
203+
or root_ops
204+
or ops_dict
205+
or include_all_operators
206+
), "Need to provide either model_file_path or ops_schema_yaml_path or root_ops or ops_dict or include_all_operators."
207+
208+
assert output_path, "Need to provide output_path for dumped yaml file."
209+
op_set = set()
210+
source_name = None
211+
et_kernel_metadata = {}
212+
if root_ops:
213+
op_set.update(set(filter(lambda x: len(x) > 0, root_ops.split(","))))
214+
et_kernel_metadata = merge_et_kernel_metadata(
215+
et_kernel_metadata, {op: ["default"] for op in op_set}
216+
)
217+
if ops_dict:
218+
ops_and_metadata = json.loads(ops_dict)
219+
for op, metadata in ops_and_metadata.items():
220+
op_set.update({op})
221+
op_metadata = metadata if len(metadata) > 0 else ["default"]
222+
et_kernel_metadata = merge_et_kernel_metadata(
223+
et_kernel_metadata, {op: op_metadata}
224+
)
225+
if model_file_path:
226+
assert os.path.isfile(
227+
model_file_path
228+
), "The value for --model_file_path needs to be a valid file."
229+
op_set.update(_get_operators(model_file_path))
230+
source_name = model_file_path
231+
et_kernel_metadata = merge_et_kernel_metadata(
232+
et_kernel_metadata, _get_kernel_metadata_for_model(model_file_path)
233+
)
234+
if ops_schema_yaml_path:
235+
assert os.path.isfile(
236+
ops_schema_yaml_path
237+
), "The value for --ops_schema_yaml_path needs to be a valid file."
238+
et_kernel_metadata = merge_et_kernel_metadata(
239+
et_kernel_metadata,
240+
_get_et_kernel_metadata_from_ops_yaml(ops_schema_yaml_path),
241+
)
242+
op_set.update(et_kernel_metadata.keys())
243+
source_name = ops_schema_yaml_path
244+
_dump_yaml(
245+
sorted(op_set),
246+
output_path,
247+
os.path.basename(source_name) if source_name else None,
248+
et_kernel_metadata,
249+
include_all_operators,
250+
)
251+
252+
192253
def main(args: List[Any]) -> None:
193254
"""This binary generates selected_operators.yaml which will be consumed by caffe2/torchgen/gen.py.
194255
It reads the model file, deserialize it and dumps all the operators into selected_operators.yaml so
@@ -233,54 +294,14 @@ def main(args: List[Any]) -> None:
233294
required=False,
234295
)
235296
options = parser.parse_args(args)
236-
assert (
237-
options.model_file_path
238-
or options.ops_schema_yaml_path
239-
or options.root_ops
240-
or options.ops_dict
241-
or options.include_all_operators
242-
), "Need to provide either model_file_path or ops_schema_yaml_path or root_ops or include_all_operators."
243-
op_set = set()
244-
source_name = None
245-
et_kernel_metadata = {}
246-
if options.root_ops:
247-
op_set.update(set(filter(lambda x: len(x) > 0, options.root_ops.split(","))))
248-
et_kernel_metadata = merge_et_kernel_metadata(
249-
et_kernel_metadata, {op: ["default"] for op in op_set}
250-
)
251-
if options.ops_dict:
252-
ops_and_metadata = json.loads(options.ops_dict)
253-
for op, metadata in ops_and_metadata.items():
254-
op_set.update({op})
255-
op_metadata = metadata if len(metadata) > 0 else ["default"]
256-
et_kernel_metadata = merge_et_kernel_metadata(
257-
et_kernel_metadata, {op: op_metadata}
258-
)
259-
if options.model_file_path:
260-
assert os.path.isfile(
261-
options.model_file_path
262-
), "The value for --model_file_path needs to be a valid file."
263-
op_set.update(_get_operators(options.model_file_path))
264-
source_name = options.model_file_path
265-
et_kernel_metadata = merge_et_kernel_metadata(
266-
et_kernel_metadata, _get_kernel_metadata_for_model(options.model_file_path)
267-
)
268-
if options.ops_schema_yaml_path:
269-
assert os.path.isfile(
270-
options.ops_schema_yaml_path
271-
), "The value for --ops_schema_yaml_path needs to be a valid file."
272-
et_kernel_metadata = merge_et_kernel_metadata(
273-
et_kernel_metadata,
274-
_get_et_kernel_metadata_from_ops_yaml(options.ops_schema_yaml_path),
275-
)
276-
op_set.update(et_kernel_metadata.keys())
277-
source_name = options.ops_schema_yaml_path
278-
_dump_yaml(
279-
sorted(op_set),
280-
options.output_path,
281-
os.path.basename(source_name) if source_name else None,
282-
et_kernel_metadata,
283-
options.include_all_operators,
297+
298+
gen_oplist(
299+
output_path=options.output_path,
300+
model_file_path=options.model_file_path,
301+
ops_schema_yaml_path=options.ops_schema_yaml_path,
302+
root_ops=options.root_ops,
303+
ops_dict=options.ops_dict,
304+
include_all_operators=options.include_all_operators,
284305
)
285306

286307

0 commit comments

Comments
 (0)