Skip to content

Commit b0b615e

Browse files
author
Github Executorch
committed
Support config filtering in ondemand benchmark flow
1 parent 57a09f4 commit b0b615e

File tree

6 files changed

+302
-44
lines changed

6 files changed

+302
-44
lines changed

.ci/scripts/__init__.py

Whitespace-only changes.

.ci/scripts/gather_benchmark_configs.py

Lines changed: 94 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,10 @@
99
import logging
1010
import os
1111
import re
12-
from typing import Any, Dict
12+
import sys
13+
from typing import Any, Dict, List
1314

15+
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "../..")))
1416
from examples.models import MODEL_NAME_TO_MODEL
1517

1618

@@ -45,6 +47,79 @@
4547
}
4648

4749

50+
def extract_all_configs(data, target_os=None):
51+
if isinstance(data, dict):
52+
# If target_os is specified, include "xplat" and the specified branch
53+
include_branches = {"xplat", target_os} if target_os else data.keys()
54+
return [
55+
v
56+
for key, value in data.items()
57+
if key in include_branches
58+
for v in extract_all_configs(value, target_os)
59+
]
60+
elif isinstance(data, list):
61+
return [v for item in data for v in extract_all_configs(item, target_os)]
62+
else:
63+
return [data]
64+
65+
66+
def generate_compatible_configs(model_name: str, target_os=None) -> List[str]:
67+
"""
68+
Generate a list of compatible benchmark configurations for a given model name and target OS.
69+
70+
Args:
71+
model_name (str): The name of the model to generate configurations for.
72+
target_os (Optional[str]): The target operating system (e.g., 'android', 'ios').
73+
74+
Returns:
75+
List[str]: A list of compatible benchmark configurations.
76+
77+
Raises:
78+
None
79+
80+
Example:
81+
generate_compatible_configs('meta-llama/Llama-3.2-1B', 'ios') -> ['llama3_fb16', 'llama3_coreml_ane']
82+
"""
83+
configs = []
84+
if is_valid_huggingface_model_id(model_name):
85+
if model_name.startswith("meta-llama/"):
86+
# LLaMA models
87+
repo_name = model_name.split("meta-llama/")[1]
88+
if "qlora" in repo_name.lower():
89+
configs.append("llama3_qlora")
90+
elif "spinquant" in repo_name.lower():
91+
configs.append("llama3_spinquant")
92+
else:
93+
configs.append("llama3_fb16")
94+
configs.extend(
95+
[
96+
config
97+
for config in BENCHMARK_CONFIGS.get(target_os, [])
98+
if config.startswith("llama")
99+
]
100+
)
101+
else:
102+
# Non-LLaMA models
103+
configs.append("hf_xnnpack_fp32")
104+
elif model_name in MODEL_NAME_TO_MODEL:
105+
# ExecuTorch in-tree non-GenAI models
106+
configs.append("xnnpack_q8")
107+
if target_os != "xplat":
108+
# Add OS-specific configs
109+
configs.extend(
110+
[
111+
config
112+
for config in BENCHMARK_CONFIGS.get(target_os, [])
113+
if not config.startswith("llama")
114+
]
115+
)
116+
else:
117+
# Skip unknown models with a warning
118+
logging.warning(f"Unknown or invalid model name '{model_name}'. Skipping.")
119+
120+
return configs
121+
122+
48123
def parse_args() -> Any:
49124
"""
50125
Parse command-line arguments.
@@ -82,6 +157,11 @@ def comma_separated(value: str):
82157
type=comma_separated, # Use the custom parser for comma-separated values
83158
help=f"Comma-separated device names. Available devices: {list(DEVICE_POOLS.keys())}",
84159
)
160+
parser.add_argument(
161+
"--configs",
162+
type=comma_separated, # Use the custom parser for comma-separated values
163+
help=f"Comma-separated benchmark configs. Available configs: {extract_all_configs(BENCHMARK_CONFIGS)}",
164+
)
85165

86166
return parser.parse_args()
87167

@@ -123,7 +203,7 @@ def is_valid_huggingface_model_id(model_name: str) -> bool:
123203
return bool(re.match(pattern, model_name))
124204

125205

126-
def get_benchmark_configs() -> Dict[str, Dict]:
206+
def get_benchmark_configs() -> Dict[str, Dict]: # noqa: C901
127207
"""
128208
Gather benchmark configurations for a given set of models on the target operating system and devices.
129209
@@ -153,48 +233,26 @@ def get_benchmark_configs() -> Dict[str, Dict]:
153233
}
154234
"""
155235
args = parse_args()
156-
target_os = args.os
157236
devices = args.devices
158237
models = args.models
238+
target_os = args.os
239+
target_configs = args.configs
159240

160241
benchmark_configs = {"include": []}
161242

162243
for model_name in models:
163244
configs = []
164-
if is_valid_huggingface_model_id(model_name):
165-
if model_name.startswith("meta-llama/"):
166-
# LLaMA models
167-
repo_name = model_name.split("meta-llama/")[1]
168-
if "qlora" in repo_name.lower():
169-
configs.append("llama3_qlora")
170-
elif "spinquant" in repo_name.lower():
171-
configs.append("llama3_spinquant")
172-
else:
173-
configs.append("llama3_fb16")
174-
configs.extend(
175-
[
176-
config
177-
for config in BENCHMARK_CONFIGS.get(target_os, [])
178-
if config.startswith("llama")
179-
]
245+
configs.extend(generate_compatible_configs(model_name, target_os))
246+
print(f"Discovered all supported configs for model '{model_name}': {configs}")
247+
if target_configs is not None:
248+
for config in target_configs:
249+
if config not in configs:
250+
raise Exception(
251+
f"Unsupported config '{config}' for model '{model_name}' on '{target_os}'. Skipped.\n"
252+
f"Supported configs are: {configs}"
180253
)
181-
else:
182-
# Non-LLaMA models
183-
configs.append("hf_xnnpack_fp32")
184-
elif model_name in MODEL_NAME_TO_MODEL:
185-
# ExecuTorch in-tree non-GenAI models
186-
configs.append("xnnpack_q8")
187-
configs.extend(
188-
[
189-
config
190-
for config in BENCHMARK_CONFIGS.get(target_os, [])
191-
if not config.startswith("llama")
192-
]
193-
)
194-
else:
195-
# Skip unknown models with a warning
196-
logging.warning(f"Unknown or invalid model name '{model_name}'. Skipping.")
197-
continue
254+
configs = target_configs
255+
print(f"Using provided configs {configs} for model '{model_name}'")
198256

199257
# Add configurations for each valid device
200258
for device in devices:
Lines changed: 183 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,183 @@
1+
import importlib.util
2+
import os
3+
import subprocess
4+
import unittest
5+
from unittest.mock import mock_open, patch
6+
7+
# Dynamically import the script
8+
script_path = os.path.join(".ci", "scripts", "gather_benchmark_configs.py")
9+
spec = importlib.util.spec_from_file_location("gather_benchmark_configs", script_path)
10+
gather_benchmark_configs = importlib.util.module_from_spec(spec)
11+
spec.loader.exec_module(gather_benchmark_configs)
12+
13+
14+
class TestGatehrBenchmarkConfigs(unittest.TestCase):
15+
16+
def test_extract_all_configs_android(self):
17+
android_configs = gather_benchmark_configs.extract_all_configs(
18+
gather_benchmark_configs.BENCHMARK_CONFIGS, "android"
19+
)
20+
self.assertIn("xnnpack_q8", android_configs)
21+
self.assertIn("qnn_q8", android_configs)
22+
self.assertIn("llama3_spinquant", android_configs)
23+
self.assertIn("llama3_qlora", android_configs)
24+
25+
def test_extract_all_configs_ios(self):
26+
ios_configs = gather_benchmark_configs.extract_all_configs(
27+
gather_benchmark_configs.BENCHMARK_CONFIGS, "ios"
28+
)
29+
30+
self.assertIn("xnnpack_q8", ios_configs)
31+
self.assertIn("coreml_fp16", ios_configs)
32+
self.assertIn("mps", ios_configs)
33+
self.assertIn("llama3_coreml_ane", ios_configs)
34+
self.assertIn("llama3_spinquant", ios_configs)
35+
self.assertIn("llama3_qlora", ios_configs)
36+
37+
def test_generate_compatible_configs_llama_model(self):
38+
model_name = "meta-llama/Llama-3.2-1B"
39+
target_os = "ios"
40+
result = gather_benchmark_configs.generate_compatible_configs(
41+
model_name, target_os
42+
)
43+
expected = ["llama3_fb16", "llama3_coreml_ane"]
44+
self.assertEqual(result, expected)
45+
46+
target_os = "android"
47+
result = gather_benchmark_configs.generate_compatible_configs(
48+
model_name, target_os
49+
)
50+
expected = ["llama3_fb16"]
51+
self.assertEqual(result, expected)
52+
53+
def test_generate_compatible_configs_quantized_llama_model(self):
54+
model_name = "meta-llama/Llama-3.2-1B-Instruct-SpinQuant_INT4_EO8"
55+
result = gather_benchmark_configs.generate_compatible_configs(model_name, None)
56+
expected = ["llama3_spinquant"]
57+
self.assertEqual(result, expected)
58+
59+
model_name = "meta-llama/Llama-3.2-1B-Instruct-QLORA_INT4_EO8"
60+
result = gather_benchmark_configs.generate_compatible_configs(model_name, None)
61+
expected = ["llama3_qlora"]
62+
self.assertEqual(result, expected)
63+
64+
def test_generate_compatible_configs_non_genai_model(self):
65+
model_name = "mv2"
66+
target_os = "xplat"
67+
result = gather_benchmark_configs.generate_compatible_configs(
68+
model_name, target_os
69+
)
70+
expected = ["xnnpack_q8"]
71+
self.assertEqual(result, expected)
72+
73+
target_os = "android"
74+
result = gather_benchmark_configs.generate_compatible_configs(
75+
model_name, target_os
76+
)
77+
expected = ["xnnpack_q8", "qnn_q8"]
78+
self.assertEqual(result, expected)
79+
80+
target_os = "ios"
81+
result = gather_benchmark_configs.generate_compatible_configs(
82+
model_name, target_os
83+
)
84+
expected = ["xnnpack_q8", "coreml_fp16", "mps"]
85+
self.assertEqual(result, expected)
86+
87+
def test_generate_compatible_configs_unknown_model(self):
88+
model_name = "unknown_model"
89+
target_os = "ios"
90+
result = gather_benchmark_configs.generate_compatible_configs(
91+
model_name, target_os
92+
)
93+
self.assertEqual(result, [])
94+
95+
def test_is_valid_huggingface_model_id_valid(self):
96+
valid_model = "meta-llama/Llama-3.2-1B"
97+
self.assertTrue(
98+
gather_benchmark_configs.is_valid_huggingface_model_id(valid_model)
99+
)
100+
101+
@patch("builtins.open", new_callable=mock_open)
102+
@patch("os.getenv", return_value=None)
103+
def test_set_output_no_github_env(self, mock_getenv, mock_file):
104+
with patch("builtins.print") as mock_print:
105+
gather_benchmark_configs.set_output("test_name", "test_value")
106+
mock_print.assert_called_with("::set-output name=test_name::test_value")
107+
108+
def test_device_pools_contains_all_devices(self):
109+
expected_devices = [
110+
"apple_iphone_15",
111+
"apple_iphone_15+ios_18",
112+
"samsung_galaxy_s22",
113+
"samsung_galaxy_s24",
114+
"google_pixel_8_pro",
115+
]
116+
for device in expected_devices:
117+
self.assertIn(device, gather_benchmark_configs.DEVICE_POOLS)
118+
119+
def test_gather_benchmark_configs_cli(self):
120+
args = {
121+
"models": "mv2,dl3",
122+
"os": "ios",
123+
"devices": "apple_iphone_15",
124+
"configs": None,
125+
}
126+
127+
cmd = ["python", ".ci/scripts/gather_benchmark_configs.py"]
128+
for key, value in args.items():
129+
if value is not None:
130+
cmd.append(f"--{key}")
131+
cmd.append(value)
132+
133+
result = subprocess.run(cmd, capture_output=True, text=True)
134+
self.assertEqual(result.returncode, 0, f"Error: {result.stderr}")
135+
self.assertIn('"model": "mv2"', result.stdout)
136+
self.assertIn('"model": "dl3"', result.stdout)
137+
self.assertIn('"config": "coreml_fp16"', result.stdout)
138+
self.assertIn('"config": "xnnpack_q8"', result.stdout)
139+
self.assertIn('"config": "mps"', result.stdout)
140+
141+
def test_gather_benchmark_configs_cli_specified_configs(self):
142+
args = {
143+
"models": "mv2,dl3",
144+
"os": "ios",
145+
"devices": "apple_iphone_15",
146+
"configs": "coreml_fp16,xnnpack_q8",
147+
}
148+
149+
cmd = ["python", ".ci/scripts/gather_benchmark_configs.py"]
150+
for key, value in args.items():
151+
if value is not None:
152+
cmd.append(f"--{key}")
153+
cmd.append(value)
154+
155+
result = subprocess.run(cmd, capture_output=True, text=True)
156+
self.assertEqual(result.returncode, 0, f"Error: {result.stderr}")
157+
self.assertIn('"model": "mv2"', result.stdout)
158+
self.assertIn('"model": "dl3"', result.stdout)
159+
self.assertIn('"config": "coreml_fp16"', result.stdout)
160+
self.assertIn('"config": "xnnpack_q8"', result.stdout)
161+
self.assertNotIn('"config": "mps"', result.stdout)
162+
163+
def test_gather_benchmark_configs_cli_specified_configs_raise(self):
164+
args = {
165+
"models": "mv2,dl3",
166+
"os": "ios",
167+
"devices": "apple_iphone_15",
168+
"configs": "qnn_q8",
169+
}
170+
171+
cmd = ["python", ".ci/scripts/gather_benchmark_configs.py"]
172+
for key, value in args.items():
173+
if value is not None:
174+
cmd.append(f"--{key}")
175+
cmd.append(value)
176+
177+
result = subprocess.run(cmd, capture_output=True, text=True)
178+
self.assertEqual(result.returncode, 1, f"Error: {result.stderr}")
179+
self.assertIn("Unsupported config 'qnn_q8'", result.stderr)
180+
181+
182+
if __name__ == "__main__":
183+
unittest.main()

.github/workflows/android-perf.yml

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -74,19 +74,27 @@ jobs:
7474
CRON_DEFAULT_DEVICES: samsung_galaxy_s22
7575
run: |
7676
set -eux
77+
78+
ARGS="--os android"
79+
7780
MODELS="${{ inputs.models }}"
7881
if [ -z "$MODELS" ]; then
7982
MODELS="$CRON_DEFAULT_MODELS"
8083
fi
84+
ARGS="$ARGS --models $MODELS"
85+
8186
DEVICES="${{ inputs.devices }}"
8287
if [ -z "$DEVICES" ]; then
8388
DEVICES="$CRON_DEFAULT_DEVICES"
8489
fi
90+
ARGS="$ARGS --devices $DEVICES"
91+
92+
BENCHMARK_CONFIGS="${{ inputs.benchmark_configs }}"
93+
if [ -n "$BENCHMARK_CONFIGS" ]; then
94+
ARGS="$ARGS --configs $BENCHMARK_CONFIGS"
95+
fi
8596
86-
PYTHONPATH="${PWD}" python .ci/scripts/gather_benchmark_configs.py \
87-
--os "android" \
88-
--models $MODELS \
89-
--devices $DEVICES
97+
PYTHONPATH="${PWD}" python .ci/scripts/gather_benchmark_configs.py $ARGS
9098
9199
prepare-test-specs:
92100
runs-on: linux.2xlarge

0 commit comments

Comments
 (0)