Skip to content

Commit 4940f44

Browse files
author
Github Executorch
committed
Enable composable benchmark configs for flexible model+device+optimization scheduling
1 parent 6ab4399 commit 4940f44

File tree

5 files changed

+354
-196
lines changed

5 files changed

+354
-196
lines changed
Lines changed: 204 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,204 @@
1+
#!/usr/bin/env python
2+
# Copyright (c) Meta Platforms, Inc. and affiliates.
3+
# All rights reserved.
4+
#
5+
# This source code is licensed under the BSD-style license found in the
6+
# LICENSE file in the root directory of this source tree.
7+
8+
import json
9+
import logging
10+
import os
11+
import re
12+
from typing import Any, Dict
13+
14+
from examples.models import MODEL_NAME_TO_MODEL
15+
16+
17+
# Device pools for AWS Device Farm.
# Maps a human-readable device name (the value accepted by --devices) to the
# ARN of the corresponding Device Farm device pool that jobs are scheduled on.
DEVICE_POOLS = dict(
    apple_iphone_15="arn:aws:devicefarm:us-west-2:308535385114:devicepool:02a2cf0f-6d9b-45ee-ba1a-a086587469e6/3b5acd2e-92e2-4778-b651-7726bafe129d",
    samsung_galaxy_s22="arn:aws:devicefarm:us-west-2:308535385114:devicepool:02a2cf0f-6d9b-45ee-ba1a-a086587469e6/e59f866a-30aa-4aa1-87b7-4510e5820dfa",
    samsung_galaxy_s24="arn:aws:devicefarm:us-west-2:308535385114:devicepool:02a2cf0f-6d9b-45ee-ba1a-a086587469e6/98f8788c-2e25-4a3c-8bb2-0d1e8897c0db",
    google_pixel_8_pro="arn:aws:devicefarm:us-west-2:308535385114:devicepool:02a2cf0f-6d9b-45ee-ba1a-a086587469e6/d65096ab-900b-4521-be8b-a3619b69236a",
)
24+
25+
# Predefined benchmark configurations.
# "xplat" entries apply to any model regardless of OS; "android"/"ios" entries
# are looked up by the --os argument and mixed in per model kind (llama vs not).
BENCHMARK_CONFIGS = dict(
    xplat=[
        "xnnpack_q8",
        "hf_xnnpack_fp32",
        "llama3_fb16",
        "llama3_spinquant",
        "llama3_qlora",
    ],
    android=[
        "qnn_q8",
        "llama3_qnn_htp",
    ],
    ios=[
        "coreml_fp16",
        "mps",
        "llama3_coreml_ane",
    ],
)
44+
45+
46+
def parse_args() -> Any:
    """
    Parse command-line arguments.

    Returns:
        argparse.Namespace: Parsed command-line arguments with attributes
        ``os`` (str or None), ``models`` (list of str), ``devices`` (list of str).

    Example:
        parse_args() -> Namespace(models=['mv3', 'meta-llama/Llama-3.2-1B-Instruct-QLORA_INT4_EO8'],
                                  os='android',
                                  devices=['samsung_galaxy_s22'])
    """
    from argparse import ArgumentParser

    def comma_separated(value: str):
        """
        Parse a comma-separated string into a list.
        """
        return value.split(",")

    parser = ArgumentParser("Gather all benchmark configs.")
    parser.add_argument(
        "--os",
        type=str,
        choices=["android", "ios"],
        # Optional: when omitted, only cross-platform configs are emitted
        # (BENCHMARK_CONFIGS.get(None, []) downstream is an empty list).
        help="The target OS.",
    )
    parser.add_argument(
        "--models",
        type=comma_separated,  # Use the custom parser for comma-separated values
        # Required: the downstream matrix builder iterates over this value, so
        # leaving it unset would crash with an opaque TypeError instead of a
        # clear argparse usage error.
        required=True,
        help=f"Comma-separated model IDs or names. Valid values include {MODEL_NAME_TO_MODEL}.",
    )
    parser.add_argument(
        "--devices",
        type=comma_separated,  # Use the custom parser for comma-separated values
        # Required for the same reason as --models.
        required=True,
        help=f"Comma-separated device names. Available devices: {list(DEVICE_POOLS.keys())}",
    )

    return parser.parse_args()
85+
86+
87+
def set_output(name: str, val: Any) -> None:
    """
    Set the output value to be used by other GitHub jobs.

    Appends ``name=val`` to the file named by the GITHUB_OUTPUT environment
    variable when it is set; otherwise falls back to the legacy
    ``::set-output`` workflow command on stdout.

    Args:
        name (str): The name of the output variable.
        val (Any): The value to set for the output variable.

    Example:
        set_output("benchmark_configs", {"include": [...]})
    """
    if not os.getenv("GITHUB_OUTPUT"):
        # Legacy fallback for runners without the GITHUB_OUTPUT file.
        print(f"::set-output name={name}::{val}")
        return

    print(f"Setting {val} to GitHub output")
    with open(str(os.getenv("GITHUB_OUTPUT")), "a") as env:
        print(f"{name}={val}", file=env)
105+
106+
107+
def is_valid_huggingface_model_id(model_name: str) -> bool:
    """
    Validate if the model name matches the pattern for HuggingFace model IDs.

    A valid ID is ``<org>/<repo>`` where the org part is alphanumeric plus
    ``-``/``_`` and the repo part additionally allows ``.``.

    Args:
        model_name (str): The model name to validate.

    Returns:
        bool: True if the model name matches the valid pattern, False otherwise.

    Example:
        is_valid_huggingface_model_id('meta-llama/Llama-3.2-1B') -> True
    """
    return re.match(r"^[a-zA-Z0-9-_]+/[a-zA-Z0-9-_.]+$", model_name) is not None
122+
123+
124+
def get_benchmark_configs() -> Dict[str, Dict]:
    """
    Gather benchmark configurations for a given set of models on the target
    operating system and devices.

    Reads the models, target OS, and devices from the command line (see
    parse_args), builds one record per (model, config, device) combination,
    and publishes the result as the ``benchmark_configs`` GitHub output.
    Unknown models and unsupported devices are skipped with a warning.

    Example matrix shape:
        {
            "include": [
                {"model": "meta-llama/Llama-3.2-1B", "config": "hf_xnnpack_fp32", "device": "arn:aws:..."},
                {"model": "mv3", "config": "xnnpack_q8", "device": "arn:aws:..."},
                ...
            ]
        }
    """
    args = parse_args()
    # OS-specific configs for the requested target (empty when --os is unset).
    os_configs = BENCHMARK_CONFIGS.get(args.os, [])

    matrix = {"include": []}

    for model_name in args.models:
        configs = []
        if is_valid_huggingface_model_id(model_name):
            if model_name.startswith("meta-llama/"):
                # LLaMA models: pick the base config from the repo-name
                # variant, then add the OS-specific llama configs.
                repo_name = model_name.split("meta-llama/")[1]
                lowered = repo_name.lower()
                if "qlora" in lowered:
                    configs.append("llama3_qlora")
                elif "spinquant" in lowered:
                    configs.append("llama3_spinquant")
                else:
                    configs.append("llama3_fb16")
                configs += [c for c in os_configs if c.startswith("llama")]
            else:
                # Non-LLaMA HuggingFace models
                configs.append("hf_xnnpack_fp32")
        elif model_name in MODEL_NAME_TO_MODEL:
            # ExecuTorch in-tree non-GenAI models: cross-platform quantized
            # config plus the non-llama OS-specific configs.
            configs.append("xnnpack_q8")
            configs += [c for c in os_configs if not c.startswith("llama")]
        else:
            # Skip unknown models with a warning
            logging.warning(f"Unknown or invalid model name '{model_name}'. Skipping.")
            continue

        # Add configurations for each valid device
        for device in args.devices:
            pool_arn = DEVICE_POOLS.get(device)
            if pool_arn is None:
                logging.warning(f"Unsupported device '{device}'. Skipping.")
                continue
            for config in configs:
                matrix["include"].append(
                    {
                        "model": model_name,
                        "config": config,
                        "device": pool_arn,
                    }
                )

    set_output("benchmark_configs", json.dumps(matrix))


if __name__ == "__main__":
    get_benchmark_configs()

.ci/scripts/test_llama.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -208,7 +208,7 @@ prepare_artifacts_upload() {
208208
PARAMS="params.json"
209209
CHECKPOINT_FILE_NAME=""
210210
touch "${PARAMS}"
211-
if [[ "${MODEL_NAME}" == "stories110M" ]]; then
211+
if [[ "${MODEL_NAME}" == "llama" ]] || [[ "${MODEL_NAME}" == "stories"* ]] || [[ "${MODEL_NAME}" == "tinyllama" ]]; then
212212
CHECKPOINT_FILE_NAME="stories110M.pt"
213213
download_stories_model_artifacts
214214
else

.ci/scripts/test_model.sh

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -209,25 +209,25 @@ test_model_with_mps() {
209209
if [[ "${BACKEND}" == "portable" ]]; then
210210
echo "Testing ${MODEL_NAME} with portable kernels..."
211211
test_model
212-
elif [[ "${BACKEND}" == "qnn" ]]; then
212+
elif [[ "${BACKEND}" == *"qnn"* ]]; then
213213
echo "Testing ${MODEL_NAME} with qnn..."
214214
test_model_with_qnn
215215
if [[ $? -eq 0 ]]; then
216216
prepare_artifacts_upload
217217
fi
218-
elif [[ "${BACKEND}" == "coreml" ]]; then
218+
elif [[ "${BACKEND}" == *"coreml"* ]]; then
219219
echo "Testing ${MODEL_NAME} with coreml..."
220220
test_model_with_coreml
221221
if [[ $? -eq 0 ]]; then
222222
prepare_artifacts_upload
223223
fi
224-
elif [[ "${BACKEND}" == "mps" ]]; then
224+
elif [[ "${BACKEND}" == *"mps"* ]]; then
225225
echo "Testing ${MODEL_NAME} with mps..."
226226
test_model_with_mps
227227
if [[ $? -eq 0 ]]; then
228228
prepare_artifacts_upload
229229
fi
230-
elif [[ "${BACKEND}" == "xnnpack" ]]; then
230+
elif [[ "${BACKEND}" == *"xnnpack"* ]]; then
231231
echo "Testing ${MODEL_NAME} with xnnpack..."
232232
WITH_QUANTIZATION=true
233233
WITH_DELEGATION=true

0 commit comments

Comments
 (0)