|
45 | 45 | }
|
46 | 46 |
|
47 | 47 |
|
| 48 | +def extract_all_configs(data, os_name=None): |
| 49 | + if isinstance(data, dict): |
| 50 | + # If os_name is specified, include "xplat" and the specified branch |
| 51 | + include_branches = {"xplat", os_name} if os_name else data.keys() |
| 52 | + return [ |
| 53 | + v |
| 54 | + for key, value in data.items() |
| 55 | + if key in include_branches |
| 56 | + for v in extract_all_configs(value, os_name) |
| 57 | + ] |
| 58 | + elif isinstance(data, list): |
| 59 | + return [v for item in data for v in extract_all_configs(item, os_name)] |
| 60 | + else: |
| 61 | + return [data] |
| 62 | + |
| 63 | + |
48 | 64 | def parse_args() -> Any:
|
49 | 65 | """
|
50 | 66 | Parse command-line arguments.
|
@@ -82,6 +98,11 @@ def comma_separated(value: str):
|
82 | 98 | type=comma_separated, # Use the custom parser for comma-separated values
|
83 | 99 | help=f"Comma-separated device names. Available devices: {list(DEVICE_POOLS.keys())}",
|
84 | 100 | )
|
| 101 | + parser.add_argument( |
| 102 | + "--configs", |
| 103 | + type=comma_separated, # Use the custom parser for comma-separated values |
| 104 | + help=f"Comma-separated benchmark configs. Available configs: {extract_all_configs(BENCHMARK_CONFIGS)}", |
| 105 | + ) |
85 | 106 |
|
86 | 107 | return parser.parse_args()
|
87 | 108 |
|
@@ -153,48 +174,61 @@ def get_benchmark_configs() -> Dict[str, Dict]:
|
153 | 174 | }
|
154 | 175 | """
|
155 | 176 | args = parse_args()
|
156 |
| - target_os = args.os |
157 | 177 | devices = args.devices
|
158 | 178 | models = args.models
|
| 179 | + configs = args.configs |
| 180 | + target_os = args.os |
159 | 181 |
|
160 | 182 | benchmark_configs = {"include": []}
|
161 | 183 |
|
162 | 184 | for model_name in models:
|
163 |
| - configs = [] |
164 |
| - if is_valid_huggingface_model_id(model_name): |
165 |
| - if model_name.startswith("meta-llama/"): |
166 |
| - # LLaMA models |
167 |
| - repo_name = model_name.split("meta-llama/")[1] |
168 |
| - if "qlora" in repo_name.lower(): |
169 |
| - configs.append("llama3_qlora") |
170 |
| - elif "spinquant" in repo_name.lower(): |
171 |
| - configs.append("llama3_spinquant") |
172 |
| - else: |
173 |
| - configs.append("llama3_fb16") |
174 |
| - configs.extend( |
175 |
| - [ |
176 |
| - config |
177 |
| - for config in BENCHMARK_CONFIGS.get(target_os, []) |
178 |
| - if config.startswith("llama") |
179 |
| - ] |
| 185 | + |
| 186 | + if len(configs) > 0: |
| 187 | + supported_configs = extract_all_configs(BENCHMARK_CONFIGS, target_os) |
| 188 | + for config in configs: |
| 189 | + if config not in supported_configs: |
| 190 | + raise Exception( |
| 191 | + f"Unsupported config '{config}'. Skipping. Available configs: {extract_all_configs(BENCHMARK_CONFIGS, target_os)}" |
180 | 192 | )
|
181 |
| - else: |
182 |
| - # Non-LLaMA models |
183 |
| - configs.append("hf_xnnpack_fp32") |
184 |
| - elif model_name in MODEL_NAME_TO_MODEL: |
185 |
| - # ExecuTorch in-tree non-GenAI models |
186 |
| - configs.append("xnnpack_q8") |
187 |
| - configs.extend( |
188 |
| - [ |
189 |
| - config |
190 |
| - for config in BENCHMARK_CONFIGS.get(target_os, []) |
191 |
| - if not config.startswith("llama") |
192 |
| - ] |
193 |
| - ) |
| 193 | + print(f"Using provided configs {configs} for model '{model_name}'") |
194 | 194 | else:
|
195 |
| - # Skip unknown models with a warning |
196 |
| - logging.warning(f"Unknown or invalid model name '{model_name}'. Skipping.") |
197 |
| - continue |
| 195 | + print(f"Discover all compatible configs for model '{model_name}'") |
| 196 | + if is_valid_huggingface_model_id(model_name): |
| 197 | + if model_name.startswith("meta-llama/"): |
| 198 | + # LLaMA models |
| 199 | + repo_name = model_name.split("meta-llama/")[1] |
| 200 | + if "qlora" in repo_name.lower(): |
| 201 | + configs.append("llama3_qlora") |
| 202 | + elif "spinquant" in repo_name.lower(): |
| 203 | + configs.append("llama3_spinquant") |
| 204 | + else: |
| 205 | + configs.append("llama3_fb16") |
| 206 | + configs.extend( |
| 207 | + [ |
| 208 | + config |
| 209 | + for config in BENCHMARK_CONFIGS.get(target_os, []) |
| 210 | + if config.startswith("llama") |
| 211 | + ] |
| 212 | + ) |
| 213 | + else: |
| 214 | + # Non-LLaMA models |
| 215 | + configs.append("hf_xnnpack_fp32") |
| 216 | + elif model_name in MODEL_NAME_TO_MODEL: |
| 217 | + # ExecuTorch in-tree non-GenAI models |
| 218 | + configs.append("xnnpack_q8") |
| 219 | + configs.extend( |
| 220 | + [ |
| 221 | + config |
| 222 | + for config in BENCHMARK_CONFIGS.get(target_os, []) |
| 223 | + if not config.startswith("llama") |
| 224 | + ] |
| 225 | + ) |
| 226 | + else: |
| 227 | + # Skip unknown models with a warning |
| 228 | + logging.warning( |
| 229 | + f"Unknown or invalid model name '{model_name}'. Skipping." |
| 230 | + ) |
| 231 | + continue |
198 | 232 |
|
199 | 233 | # Add configurations for each valid device
|
200 | 234 | for device in devices:
|
|
0 commit comments