9
9
import logging
10
10
import os
11
11
import re
12
- from typing import Any , Dict
12
+ import sys
13
+ from typing import Any , Dict , List
13
14
15
+ sys .path .append (os .path .abspath (os .path .join (os .path .dirname (__file__ ), "../.." )))
14
16
from examples .models import MODEL_NAME_TO_MODEL
15
17
16
18
45
47
}
46
48
47
49
50
+ def extract_all_configs (data , target_os = None ):
51
+ if isinstance (data , dict ):
52
+ # If target_os is specified, include "xplat" and the specified branch
53
+ include_branches = {"xplat" , target_os } if target_os else data .keys ()
54
+ return [
55
+ v
56
+ for key , value in data .items ()
57
+ if key in include_branches
58
+ for v in extract_all_configs (value , target_os )
59
+ ]
60
+ elif isinstance (data , list ):
61
+ return [v for item in data for v in extract_all_configs (item , target_os )]
62
+ else :
63
+ return [data ]
64
+
65
+
66
+ def generate_compatible_configs (model_name : str , target_os = None ) -> List [str ]:
67
+ """
68
+ Generate a list of compatible benchmark configurations for a given model name and target OS.
69
+
70
+ Args:
71
+ model_name (str): The name of the model to generate configurations for.
72
+ target_os (Optional[str]): The target operating system (e.g., 'android', 'ios').
73
+
74
+ Returns:
75
+ List[str]: A list of compatible benchmark configurations.
76
+
77
+ Raises:
78
+ None
79
+
80
+ Example:
81
+ generate_compatible_configs('meta-llama/Llama-3.2-1B', 'ios') -> ['llama3_fb16', 'llama3_coreml_ane']
82
+ """
83
+ configs = []
84
+ if is_valid_huggingface_model_id (model_name ):
85
+ if model_name .startswith ("meta-llama/" ):
86
+ # LLaMA models
87
+ repo_name = model_name .split ("meta-llama/" )[1 ]
88
+ if "qlora" in repo_name .lower ():
89
+ configs .append ("llama3_qlora" )
90
+ elif "spinquant" in repo_name .lower ():
91
+ configs .append ("llama3_spinquant" )
92
+ else :
93
+ configs .append ("llama3_fb16" )
94
+ configs .extend (
95
+ [
96
+ config
97
+ for config in BENCHMARK_CONFIGS .get (target_os , [])
98
+ if config .startswith ("llama" )
99
+ ]
100
+ )
101
+ else :
102
+ # Non-LLaMA models
103
+ configs .append ("hf_xnnpack_fp32" )
104
+ elif model_name in MODEL_NAME_TO_MODEL :
105
+ # ExecuTorch in-tree non-GenAI models
106
+ configs .append ("xnnpack_q8" )
107
+ if target_os != "xplat" :
108
+ # Add OS-specific configs
109
+ configs .extend (
110
+ [
111
+ config
112
+ for config in BENCHMARK_CONFIGS .get (target_os , [])
113
+ if not config .startswith ("llama" )
114
+ ]
115
+ )
116
+ else :
117
+ # Skip unknown models with a warning
118
+ logging .warning (f"Unknown or invalid model name '{ model_name } '. Skipping." )
119
+
120
+ return configs
121
+
122
+
48
123
def parse_args () -> Any :
49
124
"""
50
125
Parse command-line arguments.
@@ -82,6 +157,11 @@ def comma_separated(value: str):
82
157
type = comma_separated , # Use the custom parser for comma-separated values
83
158
help = f"Comma-separated device names. Available devices: { list (DEVICE_POOLS .keys ())} " ,
84
159
)
160
+ parser .add_argument (
161
+ "--configs" ,
162
+ type = comma_separated , # Use the custom parser for comma-separated values
163
+ help = f"Comma-separated benchmark configs. Available configs: { extract_all_configs (BENCHMARK_CONFIGS )} " ,
164
+ )
85
165
86
166
return parser .parse_args ()
87
167
@@ -98,11 +178,16 @@ def set_output(name: str, val: Any) -> None:
98
178
set_output("benchmark_configs", {"include": [...]})
99
179
"""
100
180
101
- if os .getenv ("GITHUB_OUTPUT" ):
102
- print (f"Setting { val } to GitHub output" )
103
- with open (str (os .getenv ("GITHUB_OUTPUT" )), "a" ) as env :
104
- print (f"{ name } ={ val } " , file = env )
105
- else :
181
+ github_output = os .getenv ("GITHUB_OUTPUT" )
182
+ if not github_output :
183
+ print (f"::set-output name={ name } ::{ val } " )
184
+ return
185
+
186
+ try :
187
+ with open (github_output , "a" ) as env :
188
+ env .write (f"{ name } ={ val } \n " )
189
+ except PermissionError :
190
+ # Fall back to printing in case of permission error in unit tests
106
191
print (f"::set-output name={ name } ::{ val } " )
107
192
108
193
@@ -123,7 +208,7 @@ def is_valid_huggingface_model_id(model_name: str) -> bool:
123
208
return bool (re .match (pattern , model_name ))
124
209
125
210
126
- def get_benchmark_configs () -> Dict [str , Dict ]:
211
+ def get_benchmark_configs () -> Dict [str , Dict ]: # noqa: C901
127
212
"""
128
213
Gather benchmark configurations for a given set of models on the target operating system and devices.
129
214
@@ -153,48 +238,26 @@ def get_benchmark_configs() -> Dict[str, Dict]:
153
238
}
154
239
"""
155
240
args = parse_args ()
156
- target_os = args .os
157
241
devices = args .devices
158
242
models = args .models
243
+ target_os = args .os
244
+ target_configs = args .configs
159
245
160
246
benchmark_configs = {"include" : []}
161
247
162
248
for model_name in models :
163
249
configs = []
164
- if is_valid_huggingface_model_id (model_name ):
165
- if model_name .startswith ("meta-llama/" ):
166
- # LLaMA models
167
- repo_name = model_name .split ("meta-llama/" )[1 ]
168
- if "qlora" in repo_name .lower ():
169
- configs .append ("llama3_qlora" )
170
- elif "spinquant" in repo_name .lower ():
171
- configs .append ("llama3_spinquant" )
172
- else :
173
- configs .append ("llama3_fb16" )
174
- configs .extend (
175
- [
176
- config
177
- for config in BENCHMARK_CONFIGS .get (target_os , [])
178
- if config .startswith ("llama" )
179
- ]
250
+ configs .extend (generate_compatible_configs (model_name , target_os ))
251
+ print (f"Discovered all supported configs for model '{ model_name } ': { configs } " )
252
+ if target_configs is not None :
253
+ for config in target_configs :
254
+ if config not in configs :
255
+ raise Exception (
256
+ f"Unsupported config '{ config } ' for model '{ model_name } ' on '{ target_os } '. Skipped.\n "
257
+ f"Supported configs are: { configs } "
180
258
)
181
- else :
182
- # Non-LLaMA models
183
- configs .append ("hf_xnnpack_fp32" )
184
- elif model_name in MODEL_NAME_TO_MODEL :
185
- # ExecuTorch in-tree non-GenAI models
186
- configs .append ("xnnpack_q8" )
187
- configs .extend (
188
- [
189
- config
190
- for config in BENCHMARK_CONFIGS .get (target_os , [])
191
- if not config .startswith ("llama" )
192
- ]
193
- )
194
- else :
195
- # Skip unknown models with a warning
196
- logging .warning (f"Unknown or invalid model name '{ model_name } '. Skipping." )
197
- continue
259
+ configs = target_configs
260
+ print (f"Using provided configs { configs } for model '{ model_name } '" )
198
261
199
262
# Add configurations for each valid device
200
263
for device in devices :
0 commit comments