Skip to content

Commit 317dc11

Browse files
author
Github Executorch
committed
Support config filtering in ondemand benchmark flow
1 parent 4bc2029 commit 317dc11

File tree

6 files changed

+312
-49
lines changed

6 files changed

+312
-49
lines changed

.ci/scripts/__init__.py

Whitespace-only changes.

.ci/scripts/gather_benchmark_configs.py

Lines changed: 104 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,10 @@
99
import logging
1010
import os
1111
import re
12-
from typing import Any, Dict
12+
import sys
13+
from typing import Any, Dict, List
1314

15+
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "../..")))
1416
from examples.models import MODEL_NAME_TO_MODEL
1517

1618

@@ -45,6 +47,79 @@
4547
}
4648

4749

50+
def extract_all_configs(data, target_os=None):
51+
if isinstance(data, dict):
52+
# If target_os is specified, include "xplat" and the specified branch
53+
include_branches = {"xplat", target_os} if target_os else data.keys()
54+
return [
55+
v
56+
for key, value in data.items()
57+
if key in include_branches
58+
for v in extract_all_configs(value, target_os)
59+
]
60+
elif isinstance(data, list):
61+
return [v for item in data for v in extract_all_configs(item, target_os)]
62+
else:
63+
return [data]
64+
65+
66+
def generate_compatible_configs(model_name: str, target_os=None) -> List[str]:
67+
"""
68+
Generate a list of compatible benchmark configurations for a given model name and target OS.
69+
70+
Args:
71+
model_name (str): The name of the model to generate configurations for.
72+
target_os (Optional[str]): The target operating system (e.g., 'android', 'ios').
73+
74+
Returns:
75+
List[str]: A list of compatible benchmark configurations.
76+
77+
Raises:
78+
None
79+
80+
Example:
81+
generate_compatible_configs('meta-llama/Llama-3.2-1B', 'ios') -> ['llama3_fb16', 'llama3_coreml_ane']
82+
"""
83+
configs = []
84+
if is_valid_huggingface_model_id(model_name):
85+
if model_name.startswith("meta-llama/"):
86+
# LLaMA models
87+
repo_name = model_name.split("meta-llama/")[1]
88+
if "qlora" in repo_name.lower():
89+
configs.append("llama3_qlora")
90+
elif "spinquant" in repo_name.lower():
91+
configs.append("llama3_spinquant")
92+
else:
93+
configs.append("llama3_fb16")
94+
configs.extend(
95+
[
96+
config
97+
for config in BENCHMARK_CONFIGS.get(target_os, [])
98+
if config.startswith("llama")
99+
]
100+
)
101+
else:
102+
# Non-LLaMA models
103+
configs.append("hf_xnnpack_fp32")
104+
elif model_name in MODEL_NAME_TO_MODEL:
105+
# ExecuTorch in-tree non-GenAI models
106+
configs.append("xnnpack_q8")
107+
if target_os != "xplat":
108+
# Add OS-specific configs
109+
configs.extend(
110+
[
111+
config
112+
for config in BENCHMARK_CONFIGS.get(target_os, [])
113+
if not config.startswith("llama")
114+
]
115+
)
116+
else:
117+
# Skip unknown models with a warning
118+
logging.warning(f"Unknown or invalid model name '{model_name}'. Skipping.")
119+
120+
return configs
121+
122+
48123
def parse_args() -> Any:
49124
"""
50125
Parse command-line arguments.
@@ -82,6 +157,11 @@ def comma_separated(value: str):
82157
type=comma_separated, # Use the custom parser for comma-separated values
83158
help=f"Comma-separated device names. Available devices: {list(DEVICE_POOLS.keys())}",
84159
)
160+
parser.add_argument(
161+
"--configs",
162+
type=comma_separated, # Use the custom parser for comma-separated values
163+
help=f"Comma-separated benchmark configs. Available configs: {extract_all_configs(BENCHMARK_CONFIGS)}",
164+
)
85165

86166
return parser.parse_args()
87167

@@ -98,11 +178,16 @@ def set_output(name: str, val: Any) -> None:
98178
set_output("benchmark_configs", {"include": [...]})
99179
"""
100180

101-
if os.getenv("GITHUB_OUTPUT"):
102-
print(f"Setting {val} to GitHub output")
103-
with open(str(os.getenv("GITHUB_OUTPUT")), "a") as env:
104-
print(f"{name}={val}", file=env)
105-
else:
181+
github_output = os.getenv("GITHUB_OUTPUT")
182+
if not github_output:
183+
print(f"::set-output name={name}::{val}")
184+
return
185+
186+
try:
187+
with open(github_output, "a") as env:
188+
env.write(f"{name}={val}\n")
189+
except PermissionError:
190+
# Fall back to printing in case of permission error in unit tests
106191
print(f"::set-output name={name}::{val}")
107192

108193

@@ -123,7 +208,7 @@ def is_valid_huggingface_model_id(model_name: str) -> bool:
123208
return bool(re.match(pattern, model_name))
124209

125210

126-
def get_benchmark_configs() -> Dict[str, Dict]:
211+
def get_benchmark_configs() -> Dict[str, Dict]: # noqa: C901
127212
"""
128213
Gather benchmark configurations for a given set of models on the target operating system and devices.
129214
@@ -153,48 +238,26 @@ def get_benchmark_configs() -> Dict[str, Dict]:
153238
}
154239
"""
155240
args = parse_args()
156-
target_os = args.os
157241
devices = args.devices
158242
models = args.models
243+
target_os = args.os
244+
target_configs = args.configs
159245

160246
benchmark_configs = {"include": []}
161247

162248
for model_name in models:
163249
configs = []
164-
if is_valid_huggingface_model_id(model_name):
165-
if model_name.startswith("meta-llama/"):
166-
# LLaMA models
167-
repo_name = model_name.split("meta-llama/")[1]
168-
if "qlora" in repo_name.lower():
169-
configs.append("llama3_qlora")
170-
elif "spinquant" in repo_name.lower():
171-
configs.append("llama3_spinquant")
172-
else:
173-
configs.append("llama3_fb16")
174-
configs.extend(
175-
[
176-
config
177-
for config in BENCHMARK_CONFIGS.get(target_os, [])
178-
if config.startswith("llama")
179-
]
250+
configs.extend(generate_compatible_configs(model_name, target_os))
251+
print(f"Discovered all supported configs for model '{model_name}': {configs}")
252+
if target_configs is not None:
253+
for config in target_configs:
254+
if config not in configs:
255+
raise Exception(
256+
f"Unsupported config '{config}' for model '{model_name}' on '{target_os}'. Skipped.\n"
257+
f"Supported configs are: {configs}"
180258
)
181-
else:
182-
# Non-LLaMA models
183-
configs.append("hf_xnnpack_fp32")
184-
elif model_name in MODEL_NAME_TO_MODEL:
185-
# ExecuTorch in-tree non-GenAI models
186-
configs.append("xnnpack_q8")
187-
configs.extend(
188-
[
189-
config
190-
for config in BENCHMARK_CONFIGS.get(target_os, [])
191-
if not config.startswith("llama")
192-
]
193-
)
194-
else:
195-
# Skip unknown models with a warning
196-
logging.warning(f"Unknown or invalid model name '{model_name}'. Skipping.")
197-
continue
259+
configs = target_configs
260+
print(f"Using provided configs {configs} for model '{model_name}'")
198261

199262
# Add configurations for each valid device
200263
for device in devices:
Lines changed: 183 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,183 @@
1+
import importlib.util
2+
import os
3+
import subprocess
4+
import unittest
5+
from unittest.mock import mock_open, patch
6+
7+
# Dynamically import the script
8+
script_path = os.path.join(".ci", "scripts", "gather_benchmark_configs.py")
9+
spec = importlib.util.spec_from_file_location("gather_benchmark_configs", script_path)
10+
gather_benchmark_configs = importlib.util.module_from_spec(spec)
11+
spec.loader.exec_module(gather_benchmark_configs)
12+
13+
14+
class TestGatehrBenchmarkConfigs(unittest.TestCase):
15+
16+
def test_extract_all_configs_android(self):
17+
android_configs = gather_benchmark_configs.extract_all_configs(
18+
gather_benchmark_configs.BENCHMARK_CONFIGS, "android"
19+
)
20+
self.assertIn("xnnpack_q8", android_configs)
21+
self.assertIn("qnn_q8", android_configs)
22+
self.assertIn("llama3_spinquant", android_configs)
23+
self.assertIn("llama3_qlora", android_configs)
24+
25+
def test_extract_all_configs_ios(self):
26+
ios_configs = gather_benchmark_configs.extract_all_configs(
27+
gather_benchmark_configs.BENCHMARK_CONFIGS, "ios"
28+
)
29+
30+
self.assertIn("xnnpack_q8", ios_configs)
31+
self.assertIn("coreml_fp16", ios_configs)
32+
self.assertIn("mps", ios_configs)
33+
self.assertIn("llama3_coreml_ane", ios_configs)
34+
self.assertIn("llama3_spinquant", ios_configs)
35+
self.assertIn("llama3_qlora", ios_configs)
36+
37+
def test_generate_compatible_configs_llama_model(self):
38+
model_name = "meta-llama/Llama-3.2-1B"
39+
target_os = "ios"
40+
result = gather_benchmark_configs.generate_compatible_configs(
41+
model_name, target_os
42+
)
43+
expected = ["llama3_fb16", "llama3_coreml_ane"]
44+
self.assertEqual(result, expected)
45+
46+
target_os = "android"
47+
result = gather_benchmark_configs.generate_compatible_configs(
48+
model_name, target_os
49+
)
50+
expected = ["llama3_fb16"]
51+
self.assertEqual(result, expected)
52+
53+
def test_generate_compatible_configs_quantized_llama_model(self):
54+
model_name = "meta-llama/Llama-3.2-1B-Instruct-SpinQuant_INT4_EO8"
55+
result = gather_benchmark_configs.generate_compatible_configs(model_name, None)
56+
expected = ["llama3_spinquant"]
57+
self.assertEqual(result, expected)
58+
59+
model_name = "meta-llama/Llama-3.2-1B-Instruct-QLORA_INT4_EO8"
60+
result = gather_benchmark_configs.generate_compatible_configs(model_name, None)
61+
expected = ["llama3_qlora"]
62+
self.assertEqual(result, expected)
63+
64+
def test_generate_compatible_configs_non_genai_model(self):
65+
model_name = "mv2"
66+
target_os = "xplat"
67+
result = gather_benchmark_configs.generate_compatible_configs(
68+
model_name, target_os
69+
)
70+
expected = ["xnnpack_q8"]
71+
self.assertEqual(result, expected)
72+
73+
target_os = "android"
74+
result = gather_benchmark_configs.generate_compatible_configs(
75+
model_name, target_os
76+
)
77+
expected = ["xnnpack_q8", "qnn_q8"]
78+
self.assertEqual(result, expected)
79+
80+
target_os = "ios"
81+
result = gather_benchmark_configs.generate_compatible_configs(
82+
model_name, target_os
83+
)
84+
expected = ["xnnpack_q8", "coreml_fp16", "mps"]
85+
self.assertEqual(result, expected)
86+
87+
def test_generate_compatible_configs_unknown_model(self):
88+
model_name = "unknown_model"
89+
target_os = "ios"
90+
result = gather_benchmark_configs.generate_compatible_configs(
91+
model_name, target_os
92+
)
93+
self.assertEqual(result, [])
94+
95+
def test_is_valid_huggingface_model_id_valid(self):
96+
valid_model = "meta-llama/Llama-3.2-1B"
97+
self.assertTrue(
98+
gather_benchmark_configs.is_valid_huggingface_model_id(valid_model)
99+
)
100+
101+
@patch("builtins.open", new_callable=mock_open)
102+
@patch("os.getenv", return_value=None)
103+
def test_set_output_no_github_env(self, mock_getenv, mock_file):
104+
with patch("builtins.print") as mock_print:
105+
gather_benchmark_configs.set_output("test_name", "test_value")
106+
mock_print.assert_called_with("::set-output name=test_name::test_value")
107+
108+
def test_device_pools_contains_all_devices(self):
109+
expected_devices = [
110+
"apple_iphone_15",
111+
"apple_iphone_15+ios_18",
112+
"samsung_galaxy_s22",
113+
"samsung_galaxy_s24",
114+
"google_pixel_8_pro",
115+
]
116+
for device in expected_devices:
117+
self.assertIn(device, gather_benchmark_configs.DEVICE_POOLS)
118+
119+
def test_gather_benchmark_configs_cli(self):
120+
args = {
121+
"models": "mv2,dl3",
122+
"os": "ios",
123+
"devices": "apple_iphone_15",
124+
"configs": None,
125+
}
126+
127+
cmd = ["python", ".ci/scripts/gather_benchmark_configs.py"]
128+
for key, value in args.items():
129+
if value is not None:
130+
cmd.append(f"--{key}")
131+
cmd.append(value)
132+
133+
result = subprocess.run(cmd, capture_output=True, text=True)
134+
self.assertEqual(result.returncode, 0, f"Error: {result.stderr}")
135+
self.assertIn('"model": "mv2"', result.stdout)
136+
self.assertIn('"model": "dl3"', result.stdout)
137+
self.assertIn('"config": "coreml_fp16"', result.stdout)
138+
self.assertIn('"config": "xnnpack_q8"', result.stdout)
139+
self.assertIn('"config": "mps"', result.stdout)
140+
141+
def test_gather_benchmark_configs_cli_specified_configs(self):
142+
args = {
143+
"models": "mv2,dl3",
144+
"os": "ios",
145+
"devices": "apple_iphone_15",
146+
"configs": "coreml_fp16,xnnpack_q8",
147+
}
148+
149+
cmd = ["python", ".ci/scripts/gather_benchmark_configs.py"]
150+
for key, value in args.items():
151+
if value is not None:
152+
cmd.append(f"--{key}")
153+
cmd.append(value)
154+
155+
result = subprocess.run(cmd, capture_output=True, text=True)
156+
self.assertEqual(result.returncode, 0, f"Error: {result.stderr}")
157+
self.assertIn('"model": "mv2"', result.stdout)
158+
self.assertIn('"model": "dl3"', result.stdout)
159+
self.assertIn('"config": "coreml_fp16"', result.stdout)
160+
self.assertIn('"config": "xnnpack_q8"', result.stdout)
161+
self.assertNotIn('"config": "mps"', result.stdout)
162+
163+
def test_gather_benchmark_configs_cli_specified_configs_raise(self):
164+
args = {
165+
"models": "mv2,dl3",
166+
"os": "ios",
167+
"devices": "apple_iphone_15",
168+
"configs": "qnn_q8",
169+
}
170+
171+
cmd = ["python", ".ci/scripts/gather_benchmark_configs.py"]
172+
for key, value in args.items():
173+
if value is not None:
174+
cmd.append(f"--{key}")
175+
cmd.append(value)
176+
177+
result = subprocess.run(cmd, capture_output=True, text=True)
178+
self.assertEqual(result.returncode, 1, f"Error: {result.stderr}")
179+
self.assertIn("Unsupported config 'qnn_q8'", result.stderr)
180+
181+
182+
if __name__ == "__main__":
183+
unittest.main()

0 commit comments

Comments
 (0)