Skip to content

integrate gen_oplist as a function #1498

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
117 changes: 69 additions & 48 deletions codegen/tools/gen_oplist.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,67 @@ def _dump_yaml(
)


def gen_oplist(
output_path: str,
model_file_path: Optional[str] = None,
ops_schema_yaml_path: Optional[str] = None,
root_ops: Optional[str] = None,
ops_dict: Optional[str] = None,
include_all_operators: bool = False,
):
assert (
model_file_path
or ops_schema_yaml_path
or root_ops
or ops_dict
or include_all_operators
), "Need to provide either model_file_path or ops_schema_yaml_path or root_ops or ops_dict or include_all_operators."

assert output_path, "Need to provide output_path for dumped yaml file."
op_set = set()
source_name = None
et_kernel_metadata = {}
if root_ops:
op_set.update(set(filter(lambda x: len(x) > 0, root_ops.split(","))))
et_kernel_metadata = merge_et_kernel_metadata(
et_kernel_metadata, {op: ["default"] for op in op_set}
)
if ops_dict:
ops_and_metadata = json.loads(ops_dict)
for op, metadata in ops_and_metadata.items():
op_set.update({op})
op_metadata = metadata if len(metadata) > 0 else ["default"]
et_kernel_metadata = merge_et_kernel_metadata(
et_kernel_metadata, {op: op_metadata}
)
if model_file_path:
assert os.path.isfile(
model_file_path
), "The value for --model_file_path needs to be a valid file."
op_set.update(_get_operators(model_file_path))
source_name = model_file_path
et_kernel_metadata = merge_et_kernel_metadata(
et_kernel_metadata, _get_kernel_metadata_for_model(model_file_path)
)
if ops_schema_yaml_path:
assert os.path.isfile(
ops_schema_yaml_path
), "The value for --ops_schema_yaml_path needs to be a valid file."
et_kernel_metadata = merge_et_kernel_metadata(
et_kernel_metadata,
_get_et_kernel_metadata_from_ops_yaml(ops_schema_yaml_path),
)
op_set.update(et_kernel_metadata.keys())
source_name = ops_schema_yaml_path
_dump_yaml(
sorted(op_set),
output_path,
os.path.basename(source_name) if source_name else None,
et_kernel_metadata,
include_all_operators,
)


def main(args: List[Any]) -> None:
"""This binary generates selected_operators.yaml which will be consumed by caffe2/torchgen/gen.py.
It reads the model file, deserialize it and dumps all the operators into selected_operators.yaml so
Expand Down Expand Up @@ -233,54 +294,14 @@ def main(args: List[Any]) -> None:
required=False,
)
options = parser.parse_args(args)
assert (
options.model_file_path
or options.ops_schema_yaml_path
or options.root_ops
or options.ops_dict
or options.include_all_operators
), "Need to provide either model_file_path or ops_schema_yaml_path or root_ops or include_all_operators."
op_set = set()
source_name = None
et_kernel_metadata = {}
if options.root_ops:
op_set.update(set(filter(lambda x: len(x) > 0, options.root_ops.split(","))))
et_kernel_metadata = merge_et_kernel_metadata(
et_kernel_metadata, {op: ["default"] for op in op_set}
)
if options.ops_dict:
ops_and_metadata = json.loads(options.ops_dict)
for op, metadata in ops_and_metadata.items():
op_set.update({op})
op_metadata = metadata if len(metadata) > 0 else ["default"]
et_kernel_metadata = merge_et_kernel_metadata(
et_kernel_metadata, {op: op_metadata}
)
if options.model_file_path:
assert os.path.isfile(
options.model_file_path
), "The value for --model_file_path needs to be a valid file."
op_set.update(_get_operators(options.model_file_path))
source_name = options.model_file_path
et_kernel_metadata = merge_et_kernel_metadata(
et_kernel_metadata, _get_kernel_metadata_for_model(options.model_file_path)
)
if options.ops_schema_yaml_path:
assert os.path.isfile(
options.ops_schema_yaml_path
), "The value for --ops_schema_yaml_path needs to be a valid file."
et_kernel_metadata = merge_et_kernel_metadata(
et_kernel_metadata,
_get_et_kernel_metadata_from_ops_yaml(options.ops_schema_yaml_path),
)
op_set.update(et_kernel_metadata.keys())
source_name = options.ops_schema_yaml_path
_dump_yaml(
sorted(op_set),
options.output_path,
os.path.basename(source_name) if source_name else None,
et_kernel_metadata,
options.include_all_operators,

gen_oplist(
output_path=options.output_path,
model_file_path=options.model_file_path,
ops_schema_yaml_path=options.ops_schema_yaml_path,
root_ops=options.root_ops,
ops_dict=options.ops_dict,
include_all_operators=options.include_all_operators,
)


Expand Down