1
- #!/usr/bin/env fbpython
1
+ #!/usr/bin/env python3
2
2
# Copyright (c) Meta Platforms, Inc. and affiliates.
3
3
# All rights reserved.
4
4
#
5
5
# This source code is licensed under the BSD-style license found in the
6
6
# LICENSE file in the root directory of this source tree.
7
7
8
8
import argparse
9
+ import re
9
10
import os
10
11
import sys
11
- from typing import Any , List
12
+ from functools import reduce
13
+ from typing import Any , List , Tuple
12
14
13
- from tools_copy .code_analyzer import gen_oplist_copy_from_core
15
+ import yaml
16
+ from torchgen .selective_build .selector import (
17
+ combine_selective_builders ,
18
+ SelectiveBuilder ,
19
+ )
20
+ from pathlib import Path
14
21
15
22
23
+ def throw_if_any_op_includes_overloads (selective_builder : SelectiveBuilder ) -> None :
24
+ ops = []
25
+ for op_name , op in selective_builder .operators .items ():
26
+ if op .include_all_overloads :
27
+ ops .append (op_name )
28
+ if ops :
29
+ raise Exception ( # noqa: TRY002
30
+ (
31
+ "Operators that include all overloads are "
32
+ + "not allowed since --allow-include-all-overloads "
33
+ + "was not specified: {}"
34
+ ).format (", " .join (ops ))
35
+ )
36
+
37
+ def resolve_model_file_path_to_buck_target (model_file_path : str ) -> str :
38
+ real_path = str (Path (model_file_path ).resolve (strict = True ))
39
+ # try my best to convert to buck target
40
+ prog = re .compile (r"/.*/buck-out/.*/(fbsource|fbcode)/[0-9a-f]*/(.*)/__(.*)_et_oplist__/out/selected_operators.yaml" )
41
+ match = prog .match (real_path )
42
+ if match :
43
+ return f"{ match .group (1 )} //{ match .group (2 )} :{ match .group (3 )} "
44
+ else :
45
+ return real_path
46
+
16
47
def main (argv : List [Any ]) -> None :
17
- """This binary is a wrapper for //executorch/codegen/tools/gen_oplist_copy_from_core.py.
18
- This is needed because we intend to error out for the case where `model_file_list_path`
19
- is empty or invalid, so that the ExecuTorch build will fail when no selective build target
20
- is provided as a dependency to ExecuTorch build.
48
+ """This binary generates 3 files:
49
+
50
+ 1. selected_mobile_ops.h: Primary operators used by templated selective build and Kernel Function
51
+ dtypes captured by tracing
52
+ 2. selected_operators.yaml: Selected root and non-root operators (either via tracing or static analysis)
21
53
"""
22
54
parser = argparse .ArgumentParser (description = "Generate operator lists" )
23
55
parser .add_argument (
56
+ "--output-dir" ,
24
57
"--output_dir" ,
25
58
help = ("The directory to store the output yaml file (selected_operators.yaml)" ),
26
59
required = True ,
27
60
)
28
61
parser .add_argument (
62
+ "--model-file-list-path" ,
29
63
"--model_file_list_path" ,
30
64
help = (
31
65
"Path to a file that contains the locations of individual "
@@ -36,6 +70,7 @@ def main(argv: List[Any]) -> None:
36
70
required = True ,
37
71
)
38
72
parser .add_argument (
73
+ "--allow-include-all-overloads" ,
39
74
"--allow_include_all_overloads" ,
40
75
help = (
41
76
"Flag to allow operators that include all overloads. "
@@ -46,26 +81,99 @@ def main(argv: List[Any]) -> None:
46
81
default = False ,
47
82
required = False ,
48
83
)
84
+ parser .add_argument (
85
+ "--check-ops-not-overlapping" ,
86
+ "--check_ops_not_overlapping" ,
87
+ help = (
88
+ "Flag to check if the operators in the model file list are overlapping. "
89
+ + "If not set, the script will not error out for overlapping operators."
90
+ ),
91
+ action = "store_true" ,
92
+ default = False ,
93
+ required = False ,
94
+ )
95
+ options = parser .parse_args (argv )
96
+
49
97
50
- # check if the build has any dependency on any selective build target. If we have a target, BUCK shold give us either:
98
+ # Check if the build has any dependency on any selective build target. If we have a target, BUCK shold give us either:
51
99
# 1. a yaml file containing selected ops (could be empty), or
52
- # 2. a non-empty list of yaml files in the `model_file_list_path`.
53
- # If none of the two things happened, the build target has no dependency on any selective build and we should error out .
54
- options = parser . parse_args ( argv )
100
+ # 2. a non-empty list of yaml files in the `model_file_list_path` or
101
+ # 3. a non-empty list of directories in the `model_file_list_path`, with each directory containing a `selected_operators.yaml` file .
102
+ # If none of the 3 things happened, the build target has no dependency on any selective build and we should error out.
55
103
if os .path .isfile (options .model_file_list_path ):
56
- pass
104
+ print ("Processing model file: " , options .model_file_list_path )
105
+ model_dicts = []
106
+ model_dict = yaml .safe_load (open (options .model_file_list_path ))
107
+ model_dicts .append (model_dict )
57
108
else :
58
- assert (
59
- options .model_file_list_path [0 ] == "@"
60
- ), "model_file_list_path is not a valid file path, or it doesn't start with '@'. This is likely a BUCK issue."
109
+ print ( "Processing model directory: " , options . model_file_list_path )
110
+ assert options .model_file_list_path [0 ] == "@" , "model_file_list_path is not a valid file path, or it doesn't start with '@'. This is likely a BUCK issue. "
111
+
61
112
model_file_list_path = options .model_file_list_path [1 :]
113
+
114
+ model_dicts = []
62
115
with open (model_file_list_path ) as model_list_file :
63
116
model_file_names = model_list_file .read ().split ()
64
117
assert (
65
118
len (model_file_names ) > 0
66
119
), "BUCK was not able to find any `et_operator_library` in the dependency graph of the current ExecuTorch "
67
120
"build. Please refer to Selective Build wiki page to add at least one."
68
- gen_oplist_copy_from_core .main (argv )
121
+ for model_file_name in model_file_names :
122
+ if not os .path .isfile (model_file_name ):
123
+ model_file_name = os .path .join (model_file_name , "selected_operators.yaml" )
124
+ print ("Processing model file: " , model_file_name )
125
+ assert os .path .isfile (model_file_name ), f"{ model_file_name } is not a valid file path. This is likely a BUCK issue."
126
+ with open (model_file_name , "rb" ) as model_file :
127
+ model_dict = yaml .safe_load (model_file )
128
+ resolved = resolve_model_file_path_to_buck_target (model_file_name )
129
+ for op in model_dict ["operators" ]:
130
+ model_dict ["operators" ][op ]["debug_info" ] = [resolved ]
131
+ model_dicts .append (model_dict )
132
+
133
+ selective_builders = [SelectiveBuilder .from_yaml_dict (m ) for m in model_dicts ]
134
+
135
+ # Optionally check if the operators in the model file list are overlapping.
136
+ if options .check_ops_not_overlapping :
137
+ ops = {}
138
+ for model_dict in model_dicts :
139
+ for op_name in model_dict ["operators" ]:
140
+ if op_name in ops :
141
+ debug_info_1 = ',' .join (ops [op_name ]["debug_info" ])
142
+ debug_info_2 = ',' .join (model_dict ["operators" ][op_name ]["debug_info" ])
143
+ error = f"Operator { op_name } is used in 2 models: { debug_info_1 } and { debug_info_2 } "
144
+ if "//" not in debug_info_1 and "//" not in debug_info_2 :
145
+ error += "\n We can't determine what BUCK targets these model files belong to."
146
+ tail = "."
147
+ else :
148
+ error += "\n Please run the following commands to find out where is the BUCK target being added as a dependency to your target:\n "
149
+ error += f"\n buck2 cquery <mode> \" allpaths(<target>, { debug_info_1 } )\" "
150
+ error += f"\n buck2 cquery <mode> \" allpaths(<target>, { debug_info_2 } )\" "
151
+ tail = "as well as results from BUCK commands listed above."
152
+
153
+ error += "\n \n If issue is not resolved, please post in PyTorch Edge Q&A with this error message" + tail
154
+ raise Exception (error ) # noqa: TRY002
155
+ ops [op_name ] = model_dict ["operators" ][op_name ]
156
+ # We may have 0 selective builders since there may not be any viable
157
+ # pt_operator_library rule marked as a dep for the pt_operator_registry rule.
158
+ # This is potentially an error, and we should probably raise an assertion
159
+ # failure here. However, this needs to be investigated further.
160
+ selective_builder = SelectiveBuilder .from_yaml_dict ({})
161
+ if len (selective_builders ) > 0 :
162
+ selective_builder = reduce (
163
+ combine_selective_builders ,
164
+ selective_builders ,
165
+ )
166
+
167
+ if not options .allow_include_all_overloads :
168
+ throw_if_any_op_includes_overloads (selective_builder )
169
+ with open (
170
+ os .path .join (options .output_dir , "selected_operators.yaml" ), "wb"
171
+ ) as out_file :
172
+ out_file .write (
173
+ yaml .safe_dump (
174
+ selective_builder .to_dict (), default_flow_style = False
175
+ ).encode ("utf-8" ),
176
+ )
69
177
70
178
71
179
if __name__ == "__main__" :
0 commit comments