Skip to content

Commit 2c592e6

Browse files
larryliu0820facebook-github-bot
authored andcommitted
Introduce executorch_ops_check
Summary: Introducing `executorch_ops_check`. For any given target, find all the `executorch_generated_lib` targets in the dependent transitive closure and make sure the operators in `executorch_generated_lib` are not overlapping. Differential Revision: D66560425
1 parent ddec0c7 commit 2c592e6

File tree

5 files changed

+152
-152
lines changed

5 files changed

+152
-152
lines changed

codegen/tools/gen_all_oplist.py

Lines changed: 124 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,31 +1,65 @@
1-
#!/usr/bin/env fbpython
1+
#!/usr/bin/env python3
22
# Copyright (c) Meta Platforms, Inc. and affiliates.
33
# All rights reserved.
44
#
55
# This source code is licensed under the BSD-style license found in the
66
# LICENSE file in the root directory of this source tree.
77

88
import argparse
9+
import re
910
import os
1011
import sys
11-
from typing import Any, List
12+
from functools import reduce
13+
from typing import Any, List, Tuple
1214

13-
from tools_copy.code_analyzer import gen_oplist_copy_from_core
15+
import yaml
16+
from torchgen.selective_build.selector import (
17+
combine_selective_builders,
18+
SelectiveBuilder,
19+
)
20+
from pathlib import Path
1421

1522

23+
def throw_if_any_op_includes_overloads(selective_builder: SelectiveBuilder) -> None:
24+
ops = []
25+
for op_name, op in selective_builder.operators.items():
26+
if op.include_all_overloads:
27+
ops.append(op_name)
28+
if ops:
29+
raise Exception( # noqa: TRY002
30+
(
31+
"Operators that include all overloads are "
32+
+ "not allowed since --allow-include-all-overloads "
33+
+ "was not specified: {}"
34+
).format(", ".join(ops))
35+
)
36+
37+
def resolve_model_file_path_to_buck_target(model_file_path: str) -> str:
38+
real_path = str(Path(model_file_path).resolve(strict=True))
39+
# try my best to convert to buck target
40+
prog = re.compile(r"/.*/buck-out/.*/(fbsource|fbcode)/[0-9a-f]*/(.*)/__(.*)_et_oplist__/out/selected_operators.yaml")
41+
match = prog.match(real_path)
42+
if match:
43+
return f"{match.group(1)}//{match.group(2)}:{match.group(3)}"
44+
else:
45+
return real_path
46+
1647
def main(argv: List[Any]) -> None:
17-
"""This binary is a wrapper for //executorch/codegen/tools/gen_oplist_copy_from_core.py.
18-
This is needed because we intend to error out for the case where `model_file_list_path`
19-
is empty or invalid, so that the ExecuTorch build will fail when no selective build target
20-
is provided as a dependency to ExecuTorch build.
48+
"""This binary generates 3 files:
49+
50+
1. selected_mobile_ops.h: Primary operators used by templated selective build and Kernel Function
51+
dtypes captured by tracing
52+
2. selected_operators.yaml: Selected root and non-root operators (either via tracing or static analysis)
2153
"""
2254
parser = argparse.ArgumentParser(description="Generate operator lists")
2355
parser.add_argument(
56+
"--output-dir",
2457
"--output_dir",
2558
help=("The directory to store the output yaml file (selected_operators.yaml)"),
2659
required=True,
2760
)
2861
parser.add_argument(
62+
"--model-file-list-path",
2963
"--model_file_list_path",
3064
help=(
3165
"Path to a file that contains the locations of individual "
@@ -36,6 +70,7 @@ def main(argv: List[Any]) -> None:
3670
required=True,
3771
)
3872
parser.add_argument(
73+
"--allow-include-all-overloads",
3974
"--allow_include_all_overloads",
4075
help=(
4176
"Flag to allow operators that include all overloads. "
@@ -46,26 +81,99 @@ def main(argv: List[Any]) -> None:
4681
default=False,
4782
required=False,
4883
)
84+
parser.add_argument(
85+
"--check-ops-not-overlapping",
86+
"--check_ops_not_overlapping",
87+
help=(
88+
"Flag to check if the operators in the model file list are overlapping. "
89+
+ "If not set, the script will not error out for overlapping operators."
90+
),
91+
action="store_true",
92+
default=False,
93+
required=False,
94+
)
95+
options = parser.parse_args(argv)
96+
4997

50-
# check if the build has any dependency on any selective build target. If we have a target, BUCK shold give us either:
98+
# Check if the build has any dependency on any selective build target. If we have a target, BUCK shold give us either:
5199
# 1. a yaml file containing selected ops (could be empty), or
52-
# 2. a non-empty list of yaml files in the `model_file_list_path`.
53-
# If none of the two things happened, the build target has no dependency on any selective build and we should error out.
54-
options = parser.parse_args(argv)
100+
# 2. a non-empty list of yaml files in the `model_file_list_path` or
101+
# 3. a non-empty list of directories in the `model_file_list_path`, with each directory containing a `selected_operators.yaml` file.
102+
# If none of the 3 things happened, the build target has no dependency on any selective build and we should error out.
55103
if os.path.isfile(options.model_file_list_path):
56-
pass
104+
print("Processing model file: ", options.model_file_list_path)
105+
model_dicts = []
106+
model_dict = yaml.safe_load(open(options.model_file_list_path))
107+
model_dicts.append(model_dict)
57108
else:
58-
assert (
59-
options.model_file_list_path[0] == "@"
60-
), "model_file_list_path is not a valid file path, or it doesn't start with '@'. This is likely a BUCK issue."
109+
print("Processing model directory: ", options.model_file_list_path)
110+
assert options.model_file_list_path[0] == "@" , "model_file_list_path is not a valid file path, or it doesn't start with '@'. This is likely a BUCK issue."
111+
61112
model_file_list_path = options.model_file_list_path[1:]
113+
114+
model_dicts = []
62115
with open(model_file_list_path) as model_list_file:
63116
model_file_names = model_list_file.read().split()
64117
assert (
65118
len(model_file_names) > 0
66119
), "BUCK was not able to find any `et_operator_library` in the dependency graph of the current ExecuTorch "
67120
"build. Please refer to Selective Build wiki page to add at least one."
68-
gen_oplist_copy_from_core.main(argv)
121+
for model_file_name in model_file_names:
122+
if not os.path.isfile(model_file_name):
123+
model_file_name = os.path.join(model_file_name, "selected_operators.yaml")
124+
print("Processing model file: ", model_file_name)
125+
assert os.path.isfile(model_file_name), f"{model_file_name} is not a valid file path. This is likely a BUCK issue."
126+
with open(model_file_name, "rb") as model_file:
127+
model_dict = yaml.safe_load(model_file)
128+
resolved = resolve_model_file_path_to_buck_target(model_file_name)
129+
for op in model_dict["operators"]:
130+
model_dict["operators"][op]["debug_info"] = [resolved]
131+
model_dicts.append(model_dict)
132+
133+
selective_builders = [SelectiveBuilder.from_yaml_dict(m) for m in model_dicts]
134+
135+
# Optionally check if the operators in the model file list are overlapping.
136+
if options.check_ops_not_overlapping:
137+
ops = {}
138+
for model_dict in model_dicts:
139+
for op_name in model_dict["operators"]:
140+
if op_name in ops:
141+
debug_info_1 = ','.join(ops[op_name]["debug_info"])
142+
debug_info_2 = ','.join(model_dict["operators"][op_name]["debug_info"])
143+
error = f"Operator {op_name} is used in 2 models: {debug_info_1} and {debug_info_2}"
144+
if "//" not in debug_info_1 and "//" not in debug_info_2:
145+
error += "\nWe can't determine what BUCK targets these model files belong to."
146+
tail = "."
147+
else:
148+
error += "\nPlease run the following commands to find out where is the BUCK target being added as a dependency to your target:\n"
149+
error += f"\n buck2 cquery <mode> \"allpaths(<target>, {debug_info_1})\""
150+
error += f"\n buck2 cquery <mode> \"allpaths(<target>, {debug_info_2})\""
151+
tail = "as well as results from BUCK commands listed above."
152+
153+
error += "\n\nIf issue is not resolved, please post in PyTorch Edge Q&A with this error message" + tail
154+
raise Exception(error) # noqa: TRY002
155+
ops[op_name] = model_dict["operators"][op_name]
156+
# We may have 0 selective builders since there may not be any viable
157+
# pt_operator_library rule marked as a dep for the pt_operator_registry rule.
158+
# This is potentially an error, and we should probably raise an assertion
159+
# failure here. However, this needs to be investigated further.
160+
selective_builder = SelectiveBuilder.from_yaml_dict({})
161+
if len(selective_builders) > 0:
162+
selective_builder = reduce(
163+
combine_selective_builders,
164+
selective_builders,
165+
)
166+
167+
if not options.allow_include_all_overloads:
168+
throw_if_any_op_includes_overloads(selective_builder)
169+
with open(
170+
os.path.join(options.output_dir, "selected_operators.yaml"), "wb"
171+
) as out_file:
172+
out_file.write(
173+
yaml.safe_dump(
174+
selective_builder.to_dict(), default_flow_style=False
175+
).encode("utf-8"),
176+
)
69177

70178

71179
if __name__ == "__main__":

codegen/tools/gen_oplist.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -249,7 +249,7 @@ def gen_oplist(
249249
_dump_yaml(
250250
sorted(op_set),
251251
output_path,
252-
os.path.basename(source_name) if source_name else None,
252+
source_name,
253253
et_kernel_metadata,
254254
include_all_operators,
255255
)

codegen/tools/gen_oplist_copy_from_core.py

Lines changed: 0 additions & 123 deletions
This file was deleted.

codegen/tools/targets.bzl

Lines changed: 2 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -78,23 +78,14 @@ def define_common_targets(is_fbcode = False):
7878
],
7979
)
8080

81-
runtime.python_library(
82-
name = "gen_oplist_copy_from_core",
83-
srcs = [
84-
"gen_oplist_copy_from_core.py",
85-
],
86-
base_module = "tools_copy.code_analyzer",
87-
external_deps = ["torchgen"],
88-
)
89-
9081
runtime.python_library(
9182
name = "gen_all_oplist_lib",
9283
srcs = ["gen_all_oplist.py"],
9384
base_module = "executorch.codegen.tools",
9485
visibility = [
9586
"//executorch/...",
9687
],
97-
deps = [":gen_oplist_copy_from_core"],
88+
external_deps = ["torchgen"],
9889
)
9990

10091
runtime.python_binary(
@@ -130,7 +121,7 @@ def define_common_targets(is_fbcode = False):
130121
srcs = ["gen_selected_op_variants.py"],
131122
base_module = "executorch.codegen.tools",
132123
visibility = ["//executorch/..."],
133-
deps = [":gen_oplist_copy_from_core"],
124+
deps = [":gen_all_oplist_lib"],
134125
)
135126

136127
runtime.python_binary(

shim/xplat/executorch/codegen/codegen.bzl

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -531,7 +531,7 @@ def executorch_generated_lib(
531531
)
532532

533533
# genrule for selective build from static operator list
534-
oplist_dir_name = name + "_pt_oplist"
534+
oplist_dir_name = name + "_et_oplist"
535535
runtime.genrule(
536536
name = oplist_dir_name,
537537
macros_only = False,
@@ -665,3 +665,27 @@ def executorch_generated_lib(
665665
define_static_target = define_static_targets,
666666
platforms = platforms,
667667
)
668+
669+
# Util macro that takes in a binary or a shared library, find targets ending with `_et_oplist` in the transitive closure of deps,
670+
# get the `selected_operators.yaml` from those targets, try to merge them into a single yaml. This target will fail to build, if
671+
# there are intersections of all `selected_operators.yaml` the `target` is depending on.
672+
#
673+
# An example failure case: a binary `bin` is depending on 2 `executorch_generated_lib`s and they both register `aten::add.out`
674+
# with either the same or different kernels associated to it.
675+
#
676+
# If build successfully, all of the `selected_operators.yaml` will be merged into 1 `selected_operators.yaml` for debugging purpose.
677+
def executorch_ops_check(
678+
name,
679+
target,
680+
):
681+
runtime.genrule(
682+
name = name,
683+
macros_only = False,
684+
cmd = ("$(exe fbsource//xplat/executorch/codegen/tools:gen_all_oplist) " +
685+
"--model_file_list_path $(@query_outputs \"filter('.*_et_oplist', deps({target}))\") " +
686+
"--allow_include_all_overloads " +
687+
"--check_ops_not_overlapping " +
688+
"--output_dir $OUT ").format(target=target),
689+
outs = {"selected_operators.yaml": ["selected_operators.yaml"]},
690+
default_outs = ["."],
691+
)

0 commit comments

Comments
 (0)