1
- #!/usr/bin/env fbpython
2
1
# Copyright (c) Meta Platforms, Inc. and affiliates.
3
2
# All rights reserved.
4
3
#
7
6
8
7
import argparse
9
8
import os
9
+ import re
10
10
import sys
11
+ from functools import reduce
12
+ from pathlib import Path
11
13
from typing import Any , List
12
14
13
- from tools_copy .code_analyzer import gen_oplist_copy_from_core
15
+ import yaml
16
+ from torchgen .selective_build .selector import (
17
+ combine_selective_builders ,
18
+ SelectiveBuilder ,
19
+ )
20
+
21
+
22
+ def throw_if_any_op_includes_overloads (selective_builder : SelectiveBuilder ) -> None :
23
+ ops = []
24
+ for op_name , op in selective_builder .operators .items ():
25
+ if op .include_all_overloads :
26
+ ops .append (op_name )
27
+ if ops :
28
+ raise Exception ( # noqa: TRY002
29
+ (
30
+ "Operators that include all overloads are "
31
+ + "not allowed since --allow-include-all-overloads "
32
+ + "was not specified: {}"
33
+ ).format (", " .join (ops ))
34
+ )
35
+
36
+
37
+ def resolve_model_file_path_to_buck_target (model_file_path : str ) -> str :
38
+ real_path = str (Path (model_file_path ).resolve (strict = True ))
39
+ # try my best to convert to buck target
40
+ prog = re .compile (
41
+ r"/.*/buck-out/.*/(fbsource|fbcode)/[0-9a-f]*/(.*)/__(.*)_et_oplist__/out/selected_operators.yaml"
42
+ )
43
+ match = prog .match (real_path )
44
+ if match :
45
+ return f"{ match .group (1 )} //{ match .group (2 )} :{ match .group (3 )} "
46
+ else :
47
+ return real_path
14
48
15
49
16
50
def main (argv : List [Any ]) -> None :
17
- """This binary is a wrapper for //executorch/codegen/tools/gen_oplist_copy_from_core.py.
18
- This is needed because we intend to error out for the case where `model_file_list_path`
19
- is empty or invalid, so that the ExecuTorch build will fail when no selective build target
20
- is provided as a dependency to ExecuTorch build.
51
+ """This binary generates 3 files:
52
+
53
+ 1. selected_mobile_ops.h: Primary operators used by templated selective build and Kernel Function
54
+ dtypes captured by tracing
55
+ 2. selected_operators.yaml: Selected root and non-root operators (either via tracing or static analysis)
21
56
"""
22
57
parser = argparse .ArgumentParser (description = "Generate operator lists" )
23
58
parser .add_argument (
59
+ "--output-dir" ,
24
60
"--output_dir" ,
25
61
help = ("The directory to store the output yaml file (selected_operators.yaml)" ),
26
62
required = True ,
27
63
)
28
64
parser .add_argument (
65
+ "--model-file-list-path" ,
29
66
"--model_file_list_path" ,
30
67
help = (
31
68
"Path to a file that contains the locations of individual "
@@ -36,6 +73,7 @@ def main(argv: List[Any]) -> None:
36
73
required = True ,
37
74
)
38
75
parser .add_argument (
76
+ "--allow-include-all-overloads" ,
39
77
"--allow_include_all_overloads" ,
40
78
help = (
41
79
"Flag to allow operators that include all overloads. "
@@ -46,26 +84,109 @@ def main(argv: List[Any]) -> None:
46
84
default = False ,
47
85
required = False ,
48
86
)
87
+ parser .add_argument (
88
+ "--check-ops-not-overlapping" ,
89
+ "--check_ops_not_overlapping" ,
90
+ help = (
91
+ "Flag to check if the operators in the model file list are overlapping. "
92
+ + "If not set, the script will not error out for overlapping operators."
93
+ ),
94
+ action = "store_true" ,
95
+ default = False ,
96
+ required = False ,
97
+ )
98
+ options = parser .parse_args (argv )
49
99
50
- # check if the build has any dependency on any selective build target. If we have a target, BUCK shold give us either:
100
+ # Check if the build has any dependency on any selective build target. If we have a target, BUCK shold give us either:
51
101
# 1. a yaml file containing selected ops (could be empty), or
52
- # 2. a non-empty list of yaml files in the `model_file_list_path`.
53
- # If none of the two things happened, the build target has no dependency on any selective build and we should error out .
54
- options = parser . parse_args ( argv )
102
+ # 2. a non-empty list of yaml files in the `model_file_list_path` or
103
+ # 3. a non-empty list of directories in the `model_file_list_path`, with each directory containing a `selected_operators.yaml` file .
104
+ # If none of the 3 things happened, the build target has no dependency on any selective build and we should error out.
55
105
if os .path .isfile (options .model_file_list_path ):
56
- pass
106
+ print ("Processing model file: " , options .model_file_list_path )
107
+ model_dicts = []
108
+ model_dict = yaml .safe_load (open (options .model_file_list_path ))
109
+ model_dicts .append (model_dict )
57
110
else :
111
+ print ("Processing model directory: " , options .model_file_list_path )
58
112
assert (
59
113
options .model_file_list_path [0 ] == "@"
60
114
), "model_file_list_path is not a valid file path, or it doesn't start with '@'. This is likely a BUCK issue."
115
+
61
116
model_file_list_path = options .model_file_list_path [1 :]
117
+
118
+ model_dicts = []
62
119
with open (model_file_list_path ) as model_list_file :
63
120
model_file_names = model_list_file .read ().split ()
64
121
assert (
65
122
len (model_file_names ) > 0
66
123
), "BUCK was not able to find any `et_operator_library` in the dependency graph of the current ExecuTorch "
67
124
"build. Please refer to Selective Build wiki page to add at least one."
68
- gen_oplist_copy_from_core .main (argv )
125
+ for model_file_name in model_file_names :
126
+ if not os .path .isfile (model_file_name ):
127
+ model_file_name = os .path .join (
128
+ model_file_name , "selected_operators.yaml"
129
+ )
130
+ print ("Processing model file: " , model_file_name )
131
+ assert os .path .isfile (
132
+ model_file_name
133
+ ), f"{ model_file_name } is not a valid file path. This is likely a BUCK issue."
134
+ with open (model_file_name , "rb" ) as model_file :
135
+ model_dict = yaml .safe_load (model_file )
136
+ resolved = resolve_model_file_path_to_buck_target (model_file_name )
137
+ for op in model_dict ["operators" ]:
138
+ model_dict ["operators" ][op ]["debug_info" ] = [resolved ]
139
+ model_dicts .append (model_dict )
140
+
141
+ selective_builders = [SelectiveBuilder .from_yaml_dict (m ) for m in model_dicts ]
142
+
143
+ # Optionally check if the operators in the model file list are overlapping.
144
+ if options .check_ops_not_overlapping :
145
+ ops = {}
146
+ for model_dict in model_dicts :
147
+ for op_name in model_dict ["operators" ]:
148
+ if op_name in ops :
149
+ debug_info_1 = "," .join (ops [op_name ]["debug_info" ])
150
+ debug_info_2 = "," .join (
151
+ model_dict ["operators" ][op_name ]["debug_info" ]
152
+ )
153
+ error = f"Operator { op_name } is used in 2 models: { debug_info_1 } and { debug_info_2 } "
154
+ if "//" not in debug_info_1 and "//" not in debug_info_2 :
155
+ error += "\n We can't determine what BUCK targets these model files belong to."
156
+ tail = "."
157
+ else :
158
+ error += "\n Please run the following commands to find out where is the BUCK target being added as a dependency to your target:\n "
159
+ error += f'\n buck2 cquery <mode> "allpaths(<target>, { debug_info_1 } )"'
160
+ error += f'\n buck2 cquery <mode> "allpaths(<target>, { debug_info_2 } )"'
161
+ tail = "as well as results from BUCK commands listed above."
162
+
163
+ error += (
164
+ "\n \n If issue is not resolved, please post in PyTorch Edge Q&A with this error message"
165
+ + tail
166
+ )
167
+ raise Exception (error ) # noqa: TRY002
168
+ ops [op_name ] = model_dict ["operators" ][op_name ]
169
+ # We may have 0 selective builders since there may not be any viable
170
+ # pt_operator_library rule marked as a dep for the pt_operator_registry rule.
171
+ # This is potentially an error, and we should probably raise an assertion
172
+ # failure here. However, this needs to be investigated further.
173
+ selective_builder = SelectiveBuilder .from_yaml_dict ({})
174
+ if len (selective_builders ) > 0 :
175
+ selective_builder = reduce (
176
+ combine_selective_builders ,
177
+ selective_builders ,
178
+ )
179
+
180
+ if not options .allow_include_all_overloads :
181
+ throw_if_any_op_includes_overloads (selective_builder )
182
+ with open (
183
+ os .path .join (options .output_dir , "selected_operators.yaml" ), "wb"
184
+ ) as out_file :
185
+ out_file .write (
186
+ yaml .safe_dump (
187
+ selective_builder .to_dict (), default_flow_style = False
188
+ ).encode ("utf-8" ),
189
+ )
69
190
70
191
71
192
if __name__ == "__main__" :
0 commit comments