9
9
import logging
10
10
import os
11
11
import re
12
- from typing import Any , Dict
12
+ import sys
13
+ from typing import Any , Dict , List , NamedTuple
13
14
15
+ sys .path .append (os .path .abspath (os .path .join (os .path .dirname (__file__ ), "../.." )))
14
16
from examples .models import MODEL_NAME_TO_MODEL
15
17
16
18
45
47
}
46
48
47
49
50
+ class DisabledConfig (NamedTuple ):
51
+ config_name : str
52
+ github_issue : str # Link to the GitHub issue
53
+
54
+
55
+ # Updated DISABLED_CONFIGS
56
+ DISABLED_CONFIGS : Dict [str , List [DisabledConfig ]] = {
57
+ "resnet50" : [
58
+ DisabledConfig (
59
+ config_name = "qnn_q8" ,
60
+ github_issue = "https://github.com/pytorch/executorch/issues/7892" ,
61
+ ),
62
+ ],
63
+ "w2l" : [
64
+ DisabledConfig (
65
+ config_name = "qnn_q8" ,
66
+ github_issue = "https://github.com/pytorch/executorch/issues/7634" ,
67
+ ),
68
+ ],
69
+ "mobilebert" : [
70
+ DisabledConfig (
71
+ config_name = "mps" ,
72
+ github_issue = "https://github.com/pytorch/executorch/issues/7904" ,
73
+ ),
74
+ DisabledConfig (
75
+ config_name = "qnn_q8" ,
76
+ github_issue = "https://github.com/pytorch/executorch/issues/7946" ,
77
+ ),
78
+ ],
79
+ "edsr" : [
80
+ DisabledConfig (
81
+ config_name = "mps" ,
82
+ github_issue = "https://github.com/pytorch/executorch/issues/7905" ,
83
+ ),
84
+ ],
85
+ "llama" : [
86
+ DisabledConfig (
87
+ config_name = "mps" ,
88
+ github_issue = "https://github.com/pytorch/executorch/issues/7907" ,
89
+ ),
90
+ ],
91
+ }
92
+
93
+
94
+ def extract_all_configs (data , target_os = None ):
95
+ if isinstance (data , dict ):
96
+ # If target_os is specified, include "xplat" and the specified branch
97
+ include_branches = {"xplat" , target_os } if target_os else data .keys ()
98
+ return [
99
+ v
100
+ for key , value in data .items ()
101
+ if key in include_branches
102
+ for v in extract_all_configs (value , target_os )
103
+ ]
104
+ elif isinstance (data , list ):
105
+ return [v for item in data for v in extract_all_configs (item , target_os )]
106
+ else :
107
+ return [data ]
108
+
109
+
110
+ def generate_compatible_configs (model_name : str , target_os = None ) -> List [str ]:
111
+ """
112
+ Generate a list of compatible benchmark configurations for a given model name and target OS.
113
+
114
+ Args:
115
+ model_name (str): The name of the model to generate configurations for.
116
+ target_os (Optional[str]): The target operating system (e.g., 'android', 'ios').
117
+
118
+ Returns:
119
+ List[str]: A list of compatible benchmark configurations.
120
+
121
+ Raises:
122
+ None
123
+
124
+ Example:
125
+ generate_compatible_configs('meta-llama/Llama-3.2-1B', 'ios') -> ['llama3_fb16', 'llama3_coreml_ane']
126
+ """
127
+ configs = []
128
+ if is_valid_huggingface_model_id (model_name ):
129
+ if model_name .startswith ("meta-llama/" ):
130
+ # LLaMA models
131
+ repo_name = model_name .split ("meta-llama/" )[1 ]
132
+ if "qlora" in repo_name .lower ():
133
+ configs .append ("llama3_qlora" )
134
+ elif "spinquant" in repo_name .lower ():
135
+ configs .append ("llama3_spinquant" )
136
+ else :
137
+ configs .append ("llama3_fb16" )
138
+ configs .extend (
139
+ [
140
+ config
141
+ for config in BENCHMARK_CONFIGS .get (target_os , [])
142
+ if config .startswith ("llama" )
143
+ ]
144
+ )
145
+ else :
146
+ # Non-LLaMA models
147
+ configs .append ("hf_xnnpack_fp32" )
148
+ elif model_name in MODEL_NAME_TO_MODEL :
149
+ # ExecuTorch in-tree non-GenAI models
150
+ configs .append ("xnnpack_q8" )
151
+ if target_os != "xplat" :
152
+ # Add OS-specific configs
153
+ configs .extend (
154
+ [
155
+ config
156
+ for config in BENCHMARK_CONFIGS .get (target_os , [])
157
+ if not config .startswith ("llama" )
158
+ ]
159
+ )
160
+ else :
161
+ # Skip unknown models with a warning
162
+ logging .warning (f"Unknown or invalid model name '{ model_name } '. Skipping." )
163
+
164
+ # Remove disabled configs for the given model
165
+ disabled_configs = DISABLED_CONFIGS .get (model_name , [])
166
+ disabled_config_names = {disabled .config_name for disabled in disabled_configs }
167
+ for disabled in disabled_configs :
168
+ print (
169
+ f"Excluding disabled config: '{ disabled .config_name } ' for model '{ model_name } ' on '{ target_os } '. Linked GitHub issue: { disabled .github_issue } "
170
+ )
171
+ configs = [config for config in configs if config not in disabled_config_names ]
172
+ return configs
173
+
174
+
48
175
def parse_args () -> Any :
49
176
"""
50
177
Parse command-line arguments.
@@ -82,6 +209,11 @@ def comma_separated(value: str):
82
209
type = comma_separated , # Use the custom parser for comma-separated values
83
210
help = f"Comma-separated device names. Available devices: { list (DEVICE_POOLS .keys ())} " ,
84
211
)
212
+ parser .add_argument (
213
+ "--configs" ,
214
+ type = comma_separated , # Use the custom parser for comma-separated values
215
+ help = f"Comma-separated benchmark configs. Available configs: { extract_all_configs (BENCHMARK_CONFIGS )} " ,
216
+ )
85
217
86
218
return parser .parse_args ()
87
219
@@ -98,11 +230,16 @@ def set_output(name: str, val: Any) -> None:
98
230
set_output("benchmark_configs", {"include": [...]})
99
231
"""
100
232
101
- if os .getenv ("GITHUB_OUTPUT" ):
102
- print (f"Setting { val } to GitHub output" )
103
- with open (str (os .getenv ("GITHUB_OUTPUT" )), "a" ) as env :
104
- print (f"{ name } ={ val } " , file = env )
105
- else :
233
+ github_output = os .getenv ("GITHUB_OUTPUT" )
234
+ if not github_output :
235
+ print (f"::set-output name={ name } ::{ val } " )
236
+ return
237
+
238
+ try :
239
+ with open (github_output , "a" ) as env :
240
+ env .write (f"{ name } ={ val } \n " )
241
+ except PermissionError :
242
+ # Fall back to printing in case of permission error in unit tests
106
243
print (f"::set-output name={ name } ::{ val } " )
107
244
108
245
@@ -123,7 +260,7 @@ def is_valid_huggingface_model_id(model_name: str) -> bool:
123
260
return bool (re .match (pattern , model_name ))
124
261
125
262
126
- def get_benchmark_configs () -> Dict [str , Dict ]:
263
+ def get_benchmark_configs () -> Dict [str , Dict ]: # noqa: C901
127
264
"""
128
265
Gather benchmark configurations for a given set of models on the target operating system and devices.
129
266
@@ -153,48 +290,26 @@ def get_benchmark_configs() -> Dict[str, Dict]:
153
290
}
154
291
"""
155
292
args = parse_args ()
156
- target_os = args .os
157
293
devices = args .devices
158
294
models = args .models
295
+ target_os = args .os
296
+ target_configs = args .configs
159
297
160
298
benchmark_configs = {"include" : []}
161
299
162
300
for model_name in models :
163
301
configs = []
164
- if is_valid_huggingface_model_id (model_name ):
165
- if model_name .startswith ("meta-llama/" ):
166
- # LLaMA models
167
- repo_name = model_name .split ("meta-llama/" )[1 ]
168
- if "qlora" in repo_name .lower ():
169
- configs .append ("llama3_qlora" )
170
- elif "spinquant" in repo_name .lower ():
171
- configs .append ("llama3_spinquant" )
172
- else :
173
- configs .append ("llama3_fb16" )
174
- configs .extend (
175
- [
176
- config
177
- for config in BENCHMARK_CONFIGS .get (target_os , [])
178
- if config .startswith ("llama" )
179
- ]
302
+ configs .extend (generate_compatible_configs (model_name , target_os ))
303
+ print (f"Discovered all supported configs for model '{ model_name } ': { configs } " )
304
+ if target_configs is not None :
305
+ for config in target_configs :
306
+ if config not in configs :
307
+ raise Exception (
308
+ f"Unsupported config '{ config } ' for model '{ model_name } ' on '{ target_os } '. Skipped.\n "
309
+ f"Supported configs are: { configs } "
180
310
)
181
- else :
182
- # Non-LLaMA models
183
- configs .append ("hf_xnnpack_fp32" )
184
- elif model_name in MODEL_NAME_TO_MODEL :
185
- # ExecuTorch in-tree non-GenAI models
186
- configs .append ("xnnpack_q8" )
187
- configs .extend (
188
- [
189
- config
190
- for config in BENCHMARK_CONFIGS .get (target_os , [])
191
- if not config .startswith ("llama" )
192
- ]
193
- )
194
- else :
195
- # Skip unknown models with a warning
196
- logging .warning (f"Unknown or invalid model name '{ model_name } '. Skipping." )
197
- continue
311
+ configs = target_configs
312
+ print (f"Using provided configs { configs } for model '{ model_name } '" )
198
313
199
314
# Add configurations for each valid device
200
315
for device in devices :
0 commit comments