Skip to content

Commit ed7f187

Browse files
author
GitHub Executorch
committed
Support config filtering in on-demand benchmark flow
1 parent 44d223d commit ed7f187

File tree

3 files changed

+74
-36
lines changed

3 files changed

+74
-36
lines changed

.ci/scripts/gather_benchmark_configs.py

Lines changed: 68 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,22 @@
4545
}
4646

4747

48+
def extract_all_configs(data, os_name=None):
49+
if isinstance(data, dict):
50+
# If os_name is specified, include "xplat" and the specified branch
51+
include_branches = {"xplat", os_name} if os_name else data.keys()
52+
return [
53+
v
54+
for key, value in data.items()
55+
if key in include_branches
56+
for v in extract_all_configs(value, os_name)
57+
]
58+
elif isinstance(data, list):
59+
return [v for item in data for v in extract_all_configs(item, os_name)]
60+
else:
61+
return [data]
62+
63+
4864
def parse_args() -> Any:
4965
"""
5066
Parse command-line arguments.
@@ -82,6 +98,11 @@ def comma_separated(value: str):
8298
type=comma_separated, # Use the custom parser for comma-separated values
8399
help=f"Comma-separated device names. Available devices: {list(DEVICE_POOLS.keys())}",
84100
)
101+
parser.add_argument(
102+
"--configs",
103+
type=comma_separated, # Use the custom parser for comma-separated values
104+
help=f"Comma-separated benchmark configs. Available configs: {extract_all_configs(BENCHMARK_CONFIGS)}",
105+
)
85106

86107
return parser.parse_args()
87108

@@ -153,48 +174,61 @@ def get_benchmark_configs() -> Dict[str, Dict]:
153174
}
154175
"""
155176
args = parse_args()
156-
target_os = args.os
157177
devices = args.devices
158178
models = args.models
179+
configs = args.configs
180+
target_os = args.os
159181

160182
benchmark_configs = {"include": []}
161183

162184
for model_name in models:
163-
configs = []
164-
if is_valid_huggingface_model_id(model_name):
165-
if model_name.startswith("meta-llama/"):
166-
# LLaMA models
167-
repo_name = model_name.split("meta-llama/")[1]
168-
if "qlora" in repo_name.lower():
169-
configs.append("llama3_qlora")
170-
elif "spinquant" in repo_name.lower():
171-
configs.append("llama3_spinquant")
172-
else:
173-
configs.append("llama3_fb16")
174-
configs.extend(
175-
[
176-
config
177-
for config in BENCHMARK_CONFIGS.get(target_os, [])
178-
if config.startswith("llama")
179-
]
185+
186+
if len(configs) > 0:
187+
supported_configs = extract_all_configs(BENCHMARK_CONFIGS, target_os)
188+
for config in configs:
189+
if config not in supported_configs:
190+
raise Exception(
191+
f"Unsupported config '{config}'. Skipping. Available configs: {extract_all_configs(BENCHMARK_CONFIGS, target_os)}"
180192
)
181-
else:
182-
# Non-LLaMA models
183-
configs.append("hf_xnnpack_fp32")
184-
elif model_name in MODEL_NAME_TO_MODEL:
185-
# ExecuTorch in-tree non-GenAI models
186-
configs.append("xnnpack_q8")
187-
configs.extend(
188-
[
189-
config
190-
for config in BENCHMARK_CONFIGS.get(target_os, [])
191-
if not config.startswith("llama")
192-
]
193-
)
193+
print(f"Using provided configs {configs} for model '{model_name}'")
194194
else:
195-
# Skip unknown models with a warning
196-
logging.warning(f"Unknown or invalid model name '{model_name}'. Skipping.")
197-
continue
195+
print(f"Discover all compatible configs for model '{model_name}'")
196+
if is_valid_huggingface_model_id(model_name):
197+
if model_name.startswith("meta-llama/"):
198+
# LLaMA models
199+
repo_name = model_name.split("meta-llama/")[1]
200+
if "qlora" in repo_name.lower():
201+
configs.append("llama3_qlora")
202+
elif "spinquant" in repo_name.lower():
203+
configs.append("llama3_spinquant")
204+
else:
205+
configs.append("llama3_fb16")
206+
configs.extend(
207+
[
208+
config
209+
for config in BENCHMARK_CONFIGS.get(target_os, [])
210+
if config.startswith("llama")
211+
]
212+
)
213+
else:
214+
# Non-LLaMA models
215+
configs.append("hf_xnnpack_fp32")
216+
elif model_name in MODEL_NAME_TO_MODEL:
217+
# ExecuTorch in-tree non-GenAI models
218+
configs.append("xnnpack_q8")
219+
configs.extend(
220+
[
221+
config
222+
for config in BENCHMARK_CONFIGS.get(target_os, [])
223+
if not config.startswith("llama")
224+
]
225+
)
226+
else:
227+
# Skip unknown models with a warning
228+
logging.warning(
229+
f"Unknown or invalid model name '{model_name}'. Skipping."
230+
)
231+
continue
198232

199233
# Add configurations for each valid device
200234
for device in devices:

.github/workflows/android-perf.yml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -82,11 +82,13 @@ jobs:
8282
if [ -z "$DEVICES" ]; then
8383
DEVICES="$CRON_DEFAULT_DEVICES"
8484
fi
85+
BENCHMARK_CONFIGS="${{ inputs.benchmark_configs }}"
8586
8687
PYTHONPATH="${PWD}" python .ci/scripts/gather_benchmark_configs.py \
8788
--os "android" \
8889
--models $MODELS \
89-
--devices $DEVICES
90+
--devices $DEVICES \
91+
--configs $BENCHMARK_CONFIGS
9092
9193
prepare-test-specs:
9294
runs-on: linux.2xlarge

.github/workflows/apple-perf.yml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -82,11 +82,13 @@ jobs:
8282
if [ -z "$DEVICES" ]; then
8383
DEVICES="$CRON_DEFAULT_DEVICES"
8484
fi
85+
BENCHMARK_CONFIGS="${{ inputs.benchmark_configs }}"
8586
8687
PYTHONPATH="${PWD}" python .ci/scripts/gather_benchmark_configs.py \
8788
--os "ios" \
8889
--models $MODELS \
89-
--devices $DEVICES
90+
--devices $DEVICES \
91+
--configs $BENCHMARK_CONFIGS
9092
9193
echo "benchmark_configs is: ${{ steps.set-parameters.outputs.benchmark_configs }}"
9294

0 commit comments

Comments
 (0)