Skip to content

Commit 003e0ff

Browse files
Github Executorch, Guang Yang
authored and committed
Support config filtering in ondemand benchmark flow
1 parent 09279e6 commit 003e0ff

File tree

3 files changed

+92
-42
lines changed

3 files changed

+92
-42
lines changed

.ci/scripts/gather_benchmark_configs.py

Lines changed: 68 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,22 @@
4545
}
4646

4747

48+
def extract_all_configs(data, os_name=None):
    """Recursively collect every leaf config value from a nested structure.

    Walks dicts and lists, returning a flat list of the scalar leaves.
    When ``os_name`` is given, dict branches are restricted to the
    ``"xplat"`` key plus that OS's key; otherwise every branch is visited.
    """
    if isinstance(data, dict):
        # Without an OS filter, every branch is eligible.
        if os_name is None:
            allowed = set(data.keys())
        else:
            allowed = {"xplat", os_name}
        flattened = []
        for branch, subtree in data.items():
            if branch in allowed:
                flattened.extend(extract_all_configs(subtree, os_name))
        return flattened
    if isinstance(data, list):
        flattened = []
        for entry in data:
            flattened.extend(extract_all_configs(entry, os_name))
        return flattened
    # Scalar leaf: wrap so callers can always extend/concatenate.
    return [data]
63+
4864
def parse_args() -> Any:
4965
"""
5066
Parse command-line arguments.
@@ -82,6 +98,11 @@ def comma_separated(value: str):
8298
type=comma_separated, # Use the custom parser for comma-separated values
8399
help=f"Comma-separated device names. Available devices: {list(DEVICE_POOLS.keys())}",
84100
)
101+
parser.add_argument(
102+
"--configs",
103+
type=comma_separated, # Use the custom parser for comma-separated values
104+
help=f"Comma-separated benchmark configs. Available configs: {extract_all_configs(BENCHMARK_CONFIGS)}",
105+
)
85106

86107
return parser.parse_args()
87108

@@ -123,7 +144,7 @@ def is_valid_huggingface_model_id(model_name: str) -> bool:
123144
return bool(re.match(pattern, model_name))
124145

125146

126-
def get_benchmark_configs() -> Dict[str, Dict]:
147+
def get_benchmark_configs() -> Dict[str, Dict]: # noqa: C901
127148
"""
128149
Gather benchmark configurations for a given set of models on the target operating system and devices.
129150
@@ -153,48 +174,61 @@ def get_benchmark_configs() -> Dict[str, Dict]:
153174
}
154175
"""
155176
args = parse_args()
156-
target_os = args.os
157177
devices = args.devices
158178
models = args.models
179+
target_os = args.os
159180

160181
benchmark_configs = {"include": []}
161182

162183
for model_name in models:
163184
configs = []
164-
if is_valid_huggingface_model_id(model_name):
165-
if model_name.startswith("meta-llama/"):
166-
# LLaMA models
167-
repo_name = model_name.split("meta-llama/")[1]
168-
if "qlora" in repo_name.lower():
169-
configs.append("llama3_qlora")
170-
elif "spinquant" in repo_name.lower():
171-
configs.append("llama3_spinquant")
172-
else:
173-
configs.append("llama3_fb16")
174-
configs.extend(
175-
[
176-
config
177-
for config in BENCHMARK_CONFIGS.get(target_os, [])
178-
if config.startswith("llama")
179-
]
185+
if args.configs is not None:
186+
configs = args.configs
187+
supported_configs = extract_all_configs(BENCHMARK_CONFIGS, target_os)
188+
for config in configs:
189+
if config not in supported_configs:
190+
raise Exception(
191+
f"Unsupported config '{config}'. Skipping. Available configs: {extract_all_configs(BENCHMARK_CONFIGS, target_os)}"
180192
)
181-
else:
182-
# Non-LLaMA models
183-
configs.append("hf_xnnpack_fp32")
184-
elif model_name in MODEL_NAME_TO_MODEL:
185-
# ExecuTorch in-tree non-GenAI models
186-
configs.append("xnnpack_q8")
187-
configs.extend(
188-
[
189-
config
190-
for config in BENCHMARK_CONFIGS.get(target_os, [])
191-
if not config.startswith("llama")
192-
]
193-
)
193+
print(f"Using provided configs {configs} for model '{model_name}'")
194194
else:
195-
# Skip unknown models with a warning
196-
logging.warning(f"Unknown or invalid model name '{model_name}'. Skipping.")
197-
continue
195+
print(f"Discover all compatible configs for model '{model_name}'")
196+
if is_valid_huggingface_model_id(model_name):
197+
if model_name.startswith("meta-llama/"):
198+
# LLaMA models
199+
repo_name = model_name.split("meta-llama/")[1]
200+
if "qlora" in repo_name.lower():
201+
configs.append("llama3_qlora")
202+
elif "spinquant" in repo_name.lower():
203+
configs.append("llama3_spinquant")
204+
else:
205+
configs.append("llama3_fb16")
206+
configs.extend(
207+
[
208+
config
209+
for config in BENCHMARK_CONFIGS.get(target_os, [])
210+
if config.startswith("llama")
211+
]
212+
)
213+
else:
214+
# Non-LLaMA models
215+
configs.append("hf_xnnpack_fp32")
216+
elif model_name in MODEL_NAME_TO_MODEL:
217+
# ExecuTorch in-tree non-GenAI models
218+
configs.append("xnnpack_q8")
219+
configs.extend(
220+
[
221+
config
222+
for config in BENCHMARK_CONFIGS.get(target_os, [])
223+
if not config.startswith("llama")
224+
]
225+
)
226+
else:
227+
# Skip unknown models with a warning
228+
logging.warning(
229+
f"Unknown or invalid model name '{model_name}'. Skipping."
230+
)
231+
continue
198232

199233
# Add configurations for each valid device
200234
for device in devices:

.github/workflows/android-perf.yml

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -74,19 +74,27 @@ jobs:
7474
CRON_DEFAULT_DEVICES: samsung_galaxy_s22
7575
run: |
7676
set -eux
77+
78+
ARGS="--os android"
79+
7780
MODELS="${{ inputs.models }}"
7881
if [ -z "$MODELS" ]; then
7982
MODELS="$CRON_DEFAULT_MODELS"
8083
fi
84+
ARGS="$ARGS --models $MODELS"
85+
8186
DEVICES="${{ inputs.devices }}"
8287
if [ -z "$DEVICES" ]; then
8388
DEVICES="$CRON_DEFAULT_DEVICES"
8489
fi
90+
ARGS="$ARGS --devices $DEVICES"
91+
92+
BENCHMARK_CONFIGS="${{ inputs.benchmark_configs }}"
93+
if [ -n "$BENCHMARK_CONFIGS" ]; then
94+
ARGS="$ARGS --configs $BENCHMARK_CONFIGS"
95+
fi
8596
86-
PYTHONPATH="${PWD}" python .ci/scripts/gather_benchmark_configs.py \
87-
--os "android" \
88-
--models $MODELS \
89-
--devices $DEVICES
97+
PYTHONPATH="${PWD}" python .ci/scripts/gather_benchmark_configs.py $ARGS
9098
9199
prepare-test-specs:
92100
runs-on: linux.2xlarge

.github/workflows/apple-perf.yml

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -74,19 +74,27 @@ jobs:
7474
CRON_DEFAULT_DEVICES: apple_iphone_15
7575
run: |
7676
set -eux
77+
78+
ARGS="--os ios"
79+
7780
MODELS="${{ inputs.models }}"
7881
if [ -z "$MODELS" ]; then
7982
MODELS="$CRON_DEFAULT_MODELS"
8083
fi
84+
ARGS="$ARGS --models $MODELS"
85+
8186
DEVICES="${{ inputs.devices }}"
8287
if [ -z "$DEVICES" ]; then
8388
DEVICES="$CRON_DEFAULT_DEVICES"
8489
fi
90+
ARGS="$ARGS --devices $DEVICES"
91+
92+
BENCHMARK_CONFIGS="${{ inputs.benchmark_configs }}"
93+
if [ -n "$BENCHMARK_CONFIGS" ]; then
94+
ARGS="$ARGS --configs $BENCHMARK_CONFIGS"
95+
fi
8596
86-
PYTHONPATH="${PWD}" python .ci/scripts/gather_benchmark_configs.py \
87-
--os "ios" \
88-
--models $MODELS \
89-
--devices $DEVICES
97+
PYTHONPATH="${PWD}" python .ci/scripts/gather_benchmark_configs.py $ARGS
9098
9199
echo "benchmark_configs is: ${{ steps.set-parameters.outputs.benchmark_configs }}"
92100

0 commit comments

Comments (0)