Skip to content

Commit 8497cba

Browse files
authored
Introduce executorch_ops_check
Differential Revision: D66560425 Pull Request resolved: #7142
1 parent 1cf9482 commit 8497cba

File tree

7 files changed

+255
-171
lines changed

7 files changed

+255
-171
lines changed

codegen/tools/gen_all_oplist.py

Lines changed: 136 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
#!/usr/bin/env fbpython
21
# Copyright (c) Meta Platforms, Inc. and affiliates.
32
# All rights reserved.
43
#
@@ -7,25 +6,63 @@
76

87
import argparse
98
import os
9+
import re
1010
import sys
11+
from functools import reduce
12+
from pathlib import Path
1113
from typing import Any, List
1214

13-
from tools_copy.code_analyzer import gen_oplist_copy_from_core
15+
import yaml
16+
from torchgen.selective_build.selector import (
17+
combine_selective_builders,
18+
SelectiveBuilder,
19+
)
20+
21+
22+
def throw_if_any_op_includes_overloads(selective_builder: SelectiveBuilder) -> None:
23+
ops = []
24+
for op_name, op in selective_builder.operators.items():
25+
if op.include_all_overloads:
26+
ops.append(op_name)
27+
if ops:
28+
raise Exception( # noqa: TRY002
29+
(
30+
"Operators that include all overloads are "
31+
+ "not allowed since --allow-include-all-overloads "
32+
+ "was not specified: {}"
33+
).format(", ".join(ops))
34+
)
35+
36+
37+
def resolve_model_file_path_to_buck_target(model_file_path: str) -> str:
38+
real_path = str(Path(model_file_path).resolve(strict=True))
39+
# try my best to convert to buck target
40+
prog = re.compile(
41+
r"/.*/buck-out/.*/(fbsource|fbcode)/[0-9a-f]*/(.*)/__(.*)_et_oplist__/out/selected_operators.yaml"
42+
)
43+
match = prog.match(real_path)
44+
if match:
45+
return f"{match.group(1)}//{match.group(2)}:{match.group(3)}"
46+
else:
47+
return real_path
1448

1549

1650
def main(argv: List[Any]) -> None:
17-
"""This binary is a wrapper for //executorch/codegen/tools/gen_oplist_copy_from_core.py.
18-
This is needed because we intend to error out for the case where `model_file_list_path`
19-
is empty or invalid, so that the ExecuTorch build will fail when no selective build target
20-
is provided as a dependency to ExecuTorch build.
51+
"""This binary generates 3 files:
52+
53+
1. selected_mobile_ops.h: Primary operators used by templated selective build and Kernel Function
54+
dtypes captured by tracing
55+
2. selected_operators.yaml: Selected root and non-root operators (either via tracing or static analysis)
2156
"""
2257
parser = argparse.ArgumentParser(description="Generate operator lists")
2358
parser.add_argument(
59+
"--output-dir",
2460
"--output_dir",
2561
help=("The directory to store the output yaml file (selected_operators.yaml)"),
2662
required=True,
2763
)
2864
parser.add_argument(
65+
"--model-file-list-path",
2966
"--model_file_list_path",
3067
help=(
3168
"Path to a file that contains the locations of individual "
@@ -36,6 +73,7 @@ def main(argv: List[Any]) -> None:
3673
required=True,
3774
)
3875
parser.add_argument(
76+
"--allow-include-all-overloads",
3977
"--allow_include_all_overloads",
4078
help=(
4179
"Flag to allow operators that include all overloads. "
@@ -46,26 +84,112 @@ def main(argv: List[Any]) -> None:
4684
default=False,
4785
required=False,
4886
)
87+
parser.add_argument(
88+
"--check-ops-not-overlapping",
89+
"--check_ops_not_overlapping",
90+
help=(
91+
"Flag to check if the operators in the model file list are overlapping. "
92+
+ "If not set, the script will not error out for overlapping operators."
93+
),
94+
action="store_true",
95+
default=False,
96+
required=False,
97+
)
98+
options = parser.parse_args(argv)
4999

50-
# check if the build has any dependency on any selective build target. If we have a target, BUCK shold give us either:
100+
# Check if the build has any dependency on any selective build target. If we have a target, BUCK shold give us either:
51101
# 1. a yaml file containing selected ops (could be empty), or
52-
# 2. a non-empty list of yaml files in the `model_file_list_path`.
53-
# If none of the two things happened, the build target has no dependency on any selective build and we should error out.
54-
options = parser.parse_args(argv)
102+
# 2. a non-empty list of yaml files in the `model_file_list_path` or
103+
# 3. a non-empty list of directories in the `model_file_list_path`, with each directory containing a `selected_operators.yaml` file.
104+
# If none of the 3 things happened, the build target has no dependency on any selective build and we should error out.
55105
if os.path.isfile(options.model_file_list_path):
56-
pass
106+
print("Processing model file: ", options.model_file_list_path)
107+
model_dicts = []
108+
model_dict = yaml.safe_load(open(options.model_file_list_path))
109+
model_dicts.append(model_dict)
57110
else:
111+
print(
112+
"Processing model file list or model directory list: ",
113+
options.model_file_list_path,
114+
)
58115
assert (
59116
options.model_file_list_path[0] == "@"
60117
), "model_file_list_path is not a valid file path, or it doesn't start with '@'. This is likely a BUCK issue."
118+
61119
model_file_list_path = options.model_file_list_path[1:]
120+
121+
model_dicts = []
62122
with open(model_file_list_path) as model_list_file:
63123
model_file_names = model_list_file.read().split()
64124
assert (
65125
len(model_file_names) > 0
66126
), "BUCK was not able to find any `et_operator_library` in the dependency graph of the current ExecuTorch "
67127
"build. Please refer to Selective Build wiki page to add at least one."
68-
gen_oplist_copy_from_core.main(argv)
128+
for model_file_name in model_file_names:
129+
if not os.path.isfile(model_file_name):
130+
model_file_name = os.path.join(
131+
model_file_name, "selected_operators.yaml"
132+
)
133+
print("Processing model file: ", model_file_name)
134+
assert os.path.isfile(
135+
model_file_name
136+
), f"{model_file_name} is not a valid file path. This is likely a BUCK issue."
137+
with open(model_file_name, "rb") as model_file:
138+
model_dict = yaml.safe_load(model_file)
139+
resolved = resolve_model_file_path_to_buck_target(model_file_name)
140+
for op in model_dict["operators"]:
141+
model_dict["operators"][op]["debug_info"] = [resolved]
142+
model_dicts.append(model_dict)
143+
144+
selective_builders = [SelectiveBuilder.from_yaml_dict(m) for m in model_dicts]
145+
146+
# Optionally check if the operators in the model file list are overlapping.
147+
if options.check_ops_not_overlapping:
148+
ops = {}
149+
for model_dict in model_dicts:
150+
for op_name in model_dict["operators"]:
151+
if op_name in ops:
152+
debug_info_1 = ",".join(ops[op_name]["debug_info"])
153+
debug_info_2 = ",".join(
154+
model_dict["operators"][op_name]["debug_info"]
155+
)
156+
error = f"Operator {op_name} is used in 2 models: {debug_info_1} and {debug_info_2}"
157+
if "//" not in debug_info_1 and "//" not in debug_info_2:
158+
error += "\nWe can't determine what BUCK targets these model files belong to."
159+
tail = "."
160+
else:
161+
error += "\nPlease run the following commands to find out where is the BUCK target being added as a dependency to your target:\n"
162+
error += f'\n buck2 cquery <mode> "allpaths(<target>, {debug_info_1})"'
163+
error += f'\n buck2 cquery <mode> "allpaths(<target>, {debug_info_2})"'
164+
tail = "as well as results from BUCK commands listed above."
165+
166+
error += (
167+
"\n\nIf issue is not resolved, please post in PyTorch Edge Q&A with this error message"
168+
+ tail
169+
)
170+
raise Exception(error) # noqa: TRY002
171+
ops[op_name] = model_dict["operators"][op_name]
172+
# We may have 0 selective builders since there may not be any viable
173+
# pt_operator_library rule marked as a dep for the pt_operator_registry rule.
174+
# This is potentially an error, and we should probably raise an assertion
175+
# failure here. However, this needs to be investigated further.
176+
selective_builder = SelectiveBuilder.from_yaml_dict({})
177+
if len(selective_builders) > 0:
178+
selective_builder = reduce(
179+
combine_selective_builders,
180+
selective_builders,
181+
)
182+
183+
if not options.allow_include_all_overloads:
184+
throw_if_any_op_includes_overloads(selective_builder)
185+
with open(
186+
os.path.join(options.output_dir, "selected_operators.yaml"), "wb"
187+
) as out_file:
188+
out_file.write(
189+
yaml.safe_dump(
190+
selective_builder.to_dict(), default_flow_style=False
191+
).encode("utf-8"),
192+
)
69193

70194

71195
if __name__ == "__main__":

codegen/tools/gen_oplist.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -249,7 +249,7 @@ def gen_oplist(
249249
_dump_yaml(
250250
sorted(op_set),
251251
output_path,
252-
os.path.basename(source_name) if source_name else None,
252+
source_name,
253253
et_kernel_metadata,
254254
include_all_operators,
255255
)

codegen/tools/gen_oplist_copy_from_core.py

Lines changed: 0 additions & 123 deletions
This file was deleted.

codegen/tools/targets.bzl

Lines changed: 2 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -78,23 +78,14 @@ def define_common_targets(is_fbcode = False):
7878
],
7979
)
8080

81-
runtime.python_library(
82-
name = "gen_oplist_copy_from_core",
83-
srcs = [
84-
"gen_oplist_copy_from_core.py",
85-
],
86-
base_module = "tools_copy.code_analyzer",
87-
external_deps = ["torchgen"],
88-
)
89-
9081
runtime.python_library(
9182
name = "gen_all_oplist_lib",
9283
srcs = ["gen_all_oplist.py"],
9384
base_module = "executorch.codegen.tools",
9485
visibility = [
9586
"//executorch/...",
9687
],
97-
deps = [":gen_oplist_copy_from_core"],
88+
external_deps = ["torchgen"],
9889
)
9990

10091
runtime.python_binary(
@@ -130,7 +121,7 @@ def define_common_targets(is_fbcode = False):
130121
srcs = ["gen_selected_op_variants.py"],
131122
base_module = "executorch.codegen.tools",
132123
visibility = ["//executorch/..."],
133-
deps = [":gen_oplist_copy_from_core"],
124+
deps = [":gen_all_oplist_lib"],
134125
)
135126

136127
runtime.python_binary(

0 commit comments

Comments
 (0)