@@ -11,7 +11,7 @@
 import re
 from typing import Any, Dict
 
-from examples.models import MODEL_NAME_TO_MODEL
+from executorch.examples.models import MODEL_NAME_TO_MODEL
 
 
 # Device pools for AWS Device Farm
@@ -45,6 +45,22 @@
 }
 
 
+def extract_all_configs(data, os_name=None):
+    if isinstance(data, dict):
+        # If os_name is specified, include "xplat" and the specified branch
+        include_branches = {"xplat", os_name} if os_name else data.keys()
+        return [
+            v
+            for key, value in data.items()
+            if key in include_branches
+            for v in extract_all_configs(value, os_name)
+        ]
+    elif isinstance(data, list):
+        return [v for item in data for v in extract_all_configs(item, os_name)]
+    else:
+        return [data]
+
+
 def parse_args() -> Any:
     """
     Parse command-line arguments.
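
The new helper recursively flattens the nested `BENCHMARK_CONFIGS` mapping into a single list of config names. A minimal sketch of how it behaves, assuming a hypothetical nested shape (the real table in this file may list different names):

```python
BENCHMARK_CONFIGS = {
    "xplat": ["xnnpack_q8", "llama3_fb16"],  # applies to every OS
    "android": ["qnn_q8"],
    "ios": ["coreml_fp16"],
}

extract_all_configs(BENCHMARK_CONFIGS)
# -> ['xnnpack_q8', 'llama3_fb16', 'qnn_q8', 'coreml_fp16']  (all branches)

extract_all_configs(BENCHMARK_CONFIGS, "android")
# -> ['xnnpack_q8', 'llama3_fb16', 'qnn_q8']  ("xplat" plus the named OS)
```

Passing `os_name` prunes the other OS branches while always keeping `"xplat"`, so cross-platform configs are never filtered out.
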
@@ -82,6 +98,11 @@ def comma_separated(value: str):
         type=comma_separated,  # Use the custom parser for comma-separated values
         help=f"Comma-separated device names. Available devices: {list(DEVICE_POOLS.keys())}",
     )
+    parser.add_argument(
+        "--configs",
+        type=comma_separated,  # Use the custom parser for comma-separated values
+        help=f"Comma-separated benchmark configs. Available configs: {extract_all_configs(BENCHMARK_CONFIGS)}",
+    )
 
     return parser.parse_args()
 
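
The `--configs` flag lets a caller pin the exact benchmark configs instead of relying on the auto-discovery logic further down. A hypothetical invocation (the script name, model, and device pool key are illustrative; `--os`, `--models`, and `--devices` are the flags defined earlier in `parse_args`):

```bash
python gather_benchmark_configs.py \
  --os android \
  --models meta-llama/Llama-3.2-1B \
  --devices samsung_galaxy_s22 \
  --configs llama3_fb16,llama3_qlora
```
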
@@ -123,7 +144,7 @@ def is_valid_huggingface_model_id(model_name: str) -> bool:
     return bool(re.match(pattern, model_name))
 
 
-def get_benchmark_configs() -> Dict[str, Dict]:
+def get_benchmark_configs() -> Dict[str, Dict]:  # noqa: C901
     """
     Gather benchmark configurations for a given set of models on the target operating system and devices.
 
@@ -153,48 +174,61 @@ def get_benchmark_configs() -> Dict[str, Dict]:
     }
     """
     args = parse_args()
-    target_os = args.os
     devices = args.devices
     models = args.models
+    target_os = args.os
 
     benchmark_configs = {"include": []}
 
     for model_name in models:
         configs = []
-        if is_valid_huggingface_model_id(model_name):
-            if model_name.startswith("meta-llama/"):
-                # LLaMA models
-                repo_name = model_name.split("meta-llama/")[1]
-                if "qlora" in repo_name.lower():
-                    configs.append("llama3_qlora")
-                elif "spinquant" in repo_name.lower():
-                    configs.append("llama3_spinquant")
-                else:
-                    configs.append("llama3_fb16")
-                configs.extend(
-                    [
-                        config
-                        for config in BENCHMARK_CONFIGS.get(target_os, [])
-                        if config.startswith("llama")
-                    ]
+        if args.configs is not None:
+            configs = args.configs
+            supported_configs = extract_all_configs(BENCHMARK_CONFIGS, target_os)
+            for config in configs:
+                if config not in supported_configs:
+                    raise Exception(
+                        f"Unsupported config '{config}'. Available configs: {supported_configs}"
                     )
-            else:
-                # Non-LLaMA models
-                configs.append("hf_xnnpack_fp32")
-        elif model_name in MODEL_NAME_TO_MODEL:
-            # ExecuTorch in-tree non-GenAI models
-            configs.append("xnnpack_q8")
-            configs.extend(
-                [
-                    config
-                    for config in BENCHMARK_CONFIGS.get(target_os, [])
-                    if not config.startswith("llama")
-                ]
-            )
+            print(f"Using provided configs {configs} for model '{model_name}'")
         else:
-            # Skip unknown models with a warning
-            logging.warning(f"Unknown or invalid model name '{model_name}'. Skipping.")
-            continue
+            print(f"Discovering all compatible configs for model '{model_name}'")
+            if is_valid_huggingface_model_id(model_name):
+                if model_name.startswith("meta-llama/"):
+                    # LLaMA models
+                    repo_name = model_name.split("meta-llama/")[1]
+                    if "qlora" in repo_name.lower():
+                        configs.append("llama3_qlora")
+                    elif "spinquant" in repo_name.lower():
+                        configs.append("llama3_spinquant")
+                    else:
+                        configs.append("llama3_fb16")
+                    configs.extend(
+                        [
+                            config
+                            for config in BENCHMARK_CONFIGS.get(target_os, [])
+                            if config.startswith("llama")
+                        ]
+                    )
+                else:
+                    # Non-LLaMA models
+                    configs.append("hf_xnnpack_fp32")
+            elif model_name in MODEL_NAME_TO_MODEL:
+                # ExecuTorch in-tree non-GenAI models
+                configs.append("xnnpack_q8")
+                configs.extend(
+                    [
+                        config
+                        for config in BENCHMARK_CONFIGS.get(target_os, [])
+                        if not config.startswith("llama")
+                    ]
+                )
+            else:
+                # Skip unknown models with a warning
+                logging.warning(
+                    f"Unknown or invalid model name '{model_name}'. Skipping."
+                )
+                continue
 
     # Add configurations for each valid device
     for device in devices:
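
End to end, the two paths behave as sketched below. This reuses the hypothetical `BENCHMARK_CONFIGS` from the first example; `sys.argv` is patched only because `parse_args()` reads the real command line, and the device pool key is illustrative:

```python
import sys

# Explicit configs: each name is validated against
# extract_all_configs(BENCHMARK_CONFIGS, "android") before use.
sys.argv = [
    "gather_benchmark_configs.py",
    "--os", "android",
    "--models", "meta-llama/Llama-3.2-1B",
    "--devices", "samsung_galaxy_s22",
    "--configs", "llama3_fb16",
]
get_benchmark_configs()  # prints "Using provided configs ['llama3_fb16'] ..."

# An unknown config now fails fast instead of silently producing a job
# that no backend can run.
sys.argv[-1] = "no_such_config"
try:
    get_benchmark_configs()
except Exception as err:
    print(err)  # Unsupported config 'no_such_config'. Available configs: [...]
```
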