
Commit 8b8987a

Update on "Use c10 version of half/bfloat16 in executorch"
Accomplished by importing relevant files from c10 into executorch/runtime/core/portable_type/c10, and then using `using` in the top-level ExecuTorch headers. This approach should keep the ExecuTorch build hermetic for embedded use cases. In the future, we should add a CI job to ensure the c10 files stay identical to the PyTorch ones.

Differential Revision: [D66106969](https://our.internmc.facebook.com/intern/diff/D66106969/)

[ghstack-poisoned]
2 parents 789c521 + a82f9c4 commit 8b8987a
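
For context, the `using` re-export pattern the commit message describes would look roughly like the sketch below. This is a minimal illustration only, not the actual ExecuTorch header from this commit: the include paths and the `torch::executor` namespace placement are assumptions.

```cpp
// Hypothetical top-level ExecuTorch header forwarding to the vendored c10 types.
// Include paths and namespace names below are illustrative assumptions.

// Vendored copies of the PyTorch c10 headers, assumed to live under
// executorch/runtime/core/portable_type/c10 as described in the commit message.
#include <executorch/runtime/core/portable_type/c10/util/BFloat16.h>
#include <executorch/runtime/core/portable_type/c10/util/Half.h>

namespace torch {
namespace executor {

// Existing ExecuTorch code keeps spelling these names; the definitions now
// come from the vendored c10 headers, so there is one implementation to maintain.
using Half = c10::Half;
using BFloat16 = c10::BFloat16;

} // namespace executor
} // namespace torch
```

Because the c10 sources are copied into the ExecuTorch tree rather than pulled from PyTorch at build time, the embedded build stays hermetic; the proposed CI job would guard against the copies drifting from upstream.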


236 files changed (+6184 / -5138 lines)


.ci/scripts/__init__.py

Whitespace-only changes.

.ci/scripts/gather_benchmark_configs.py

Lines changed: 104 additions & 41 deletions
@@ -9,8 +9,10 @@
 import logging
 import os
 import re
-from typing import Any, Dict
+import sys
+from typing import Any, Dict, List
 
+sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "../..")))
 from examples.models import MODEL_NAME_TO_MODEL
 
 
@@ -45,6 +47,79 @@
 }
 
 
+def extract_all_configs(data, target_os=None):
+    if isinstance(data, dict):
+        # If target_os is specified, include "xplat" and the specified branch
+        include_branches = {"xplat", target_os} if target_os else data.keys()
+        return [
+            v
+            for key, value in data.items()
+            if key in include_branches
+            for v in extract_all_configs(value, target_os)
+        ]
+    elif isinstance(data, list):
+        return [v for item in data for v in extract_all_configs(item, target_os)]
+    else:
+        return [data]
+
+
+def generate_compatible_configs(model_name: str, target_os=None) -> List[str]:
+    """
+    Generate a list of compatible benchmark configurations for a given model name and target OS.
+
+    Args:
+        model_name (str): The name of the model to generate configurations for.
+        target_os (Optional[str]): The target operating system (e.g., 'android', 'ios').
+
+    Returns:
+        List[str]: A list of compatible benchmark configurations.
+
+    Raises:
+        None
+
+    Example:
+        generate_compatible_configs('meta-llama/Llama-3.2-1B', 'ios') -> ['llama3_fb16', 'llama3_coreml_ane']
+    """
+    configs = []
+    if is_valid_huggingface_model_id(model_name):
+        if model_name.startswith("meta-llama/"):
+            # LLaMA models
+            repo_name = model_name.split("meta-llama/")[1]
+            if "qlora" in repo_name.lower():
+                configs.append("llama3_qlora")
+            elif "spinquant" in repo_name.lower():
+                configs.append("llama3_spinquant")
+            else:
+                configs.append("llama3_fb16")
+                configs.extend(
+                    [
+                        config
+                        for config in BENCHMARK_CONFIGS.get(target_os, [])
+                        if config.startswith("llama")
+                    ]
+                )
+        else:
+            # Non-LLaMA models
+            configs.append("hf_xnnpack_fp32")
+    elif model_name in MODEL_NAME_TO_MODEL:
+        # ExecuTorch in-tree non-GenAI models
+        configs.append("xnnpack_q8")
+        if target_os != "xplat":
+            # Add OS-specific configs
+            configs.extend(
+                [
+                    config
+                    for config in BENCHMARK_CONFIGS.get(target_os, [])
+                    if not config.startswith("llama")
+                ]
+            )
+    else:
+        # Skip unknown models with a warning
+        logging.warning(f"Unknown or invalid model name '{model_name}'. Skipping.")
+
+    return configs
+
+
 def parse_args() -> Any:
     """
     Parse command-line arguments.
@@ -82,6 +157,11 @@ def comma_separated(value: str):
         type=comma_separated,  # Use the custom parser for comma-separated values
         help=f"Comma-separated device names. Available devices: {list(DEVICE_POOLS.keys())}",
     )
+    parser.add_argument(
+        "--configs",
+        type=comma_separated,  # Use the custom parser for comma-separated values
+        help=f"Comma-separated benchmark configs. Available configs: {extract_all_configs(BENCHMARK_CONFIGS)}",
+    )
 
     return parser.parse_args()
 
@@ -98,11 +178,16 @@ def set_output(name: str, val: Any) -> None:
         set_output("benchmark_configs", {"include": [...]})
     """
 
-    if os.getenv("GITHUB_OUTPUT"):
-        print(f"Setting {val} to GitHub output")
-        with open(str(os.getenv("GITHUB_OUTPUT")), "a") as env:
-            print(f"{name}={val}", file=env)
-    else:
+    github_output = os.getenv("GITHUB_OUTPUT")
+    if not github_output:
+        print(f"::set-output name={name}::{val}")
+        return
+
+    try:
+        with open(github_output, "a") as env:
+            env.write(f"{name}={val}\n")
+    except PermissionError:
+        # Fall back to printing in case of permission error in unit tests
        print(f"::set-output name={name}::{val}")
 
 
@@ -123,7 +208,7 @@ def is_valid_huggingface_model_id(model_name: str) -> bool:
     return bool(re.match(pattern, model_name))
 
 
-def get_benchmark_configs() -> Dict[str, Dict]:
+def get_benchmark_configs() -> Dict[str, Dict]:  # noqa: C901
     """
     Gather benchmark configurations for a given set of models on the target operating system and devices.
 
@@ -153,48 +238,26 @@ def get_benchmark_configs() -> Dict[str, Dict]:
         }
     """
     args = parse_args()
-    target_os = args.os
     devices = args.devices
     models = args.models
+    target_os = args.os
+    target_configs = args.configs
 
     benchmark_configs = {"include": []}
 
     for model_name in models:
         configs = []
-        if is_valid_huggingface_model_id(model_name):
-            if model_name.startswith("meta-llama/"):
-                # LLaMA models
-                repo_name = model_name.split("meta-llama/")[1]
-                if "qlora" in repo_name.lower():
-                    configs.append("llama3_qlora")
-                elif "spinquant" in repo_name.lower():
-                    configs.append("llama3_spinquant")
-                else:
-                    configs.append("llama3_fb16")
-                    configs.extend(
-                        [
-                            config
-                            for config in BENCHMARK_CONFIGS.get(target_os, [])
-                            if config.startswith("llama")
-                        ]
+        configs.extend(generate_compatible_configs(model_name, target_os))
+        print(f"Discovered all supported configs for model '{model_name}': {configs}")
+        if target_configs is not None:
+            for config in target_configs:
+                if config not in configs:
+                    raise Exception(
+                        f"Unsupported config '{config}' for model '{model_name}' on '{target_os}'. Skipped.\n"
+                        f"Supported configs are: {configs}"
                    )
-            else:
-                # Non-LLaMA models
-                configs.append("hf_xnnpack_fp32")
-        elif model_name in MODEL_NAME_TO_MODEL:
-            # ExecuTorch in-tree non-GenAI models
-            configs.append("xnnpack_q8")
-            configs.extend(
-                [
-                    config
-                    for config in BENCHMARK_CONFIGS.get(target_os, [])
-                    if not config.startswith("llama")
-                ]
-            )
-        else:
-            # Skip unknown models with a warning
-            logging.warning(f"Unknown or invalid model name '{model_name}'. Skipping.")
-            continue
+            configs = target_configs
+            print(f"Using provided configs {configs} for model '{model_name}'")
 
         # Add configurations for each valid device
         for device in devices:
Lines changed: 189 additions & 0 deletions
@@ -0,0 +1,189 @@
+import importlib.util
+import os
+import subprocess
+import sys
+import unittest
+from unittest.mock import mock_open, patch
+
+import pytest
+
+# Dynamically import the script
+script_path = os.path.join(".ci", "scripts", "gather_benchmark_configs.py")
+spec = importlib.util.spec_from_file_location("gather_benchmark_configs", script_path)
+gather_benchmark_configs = importlib.util.module_from_spec(spec)
+spec.loader.exec_module(gather_benchmark_configs)
+
+
+@pytest.mark.skipif(
+    sys.platform != "linux", reason="The script under test runs on Linux runners only"
+)
+class TestGatherBenchmarkConfigs(unittest.TestCase):
+
+    def test_extract_all_configs_android(self):
+        android_configs = gather_benchmark_configs.extract_all_configs(
+            gather_benchmark_configs.BENCHMARK_CONFIGS, "android"
+        )
+        self.assertIn("xnnpack_q8", android_configs)
+        self.assertIn("qnn_q8", android_configs)
+        self.assertIn("llama3_spinquant", android_configs)
+        self.assertIn("llama3_qlora", android_configs)
+
+    def test_extract_all_configs_ios(self):
+        ios_configs = gather_benchmark_configs.extract_all_configs(
+            gather_benchmark_configs.BENCHMARK_CONFIGS, "ios"
+        )
+
+        self.assertIn("xnnpack_q8", ios_configs)
+        self.assertIn("coreml_fp16", ios_configs)
+        self.assertIn("mps", ios_configs)
+        self.assertIn("llama3_coreml_ane", ios_configs)
+        self.assertIn("llama3_spinquant", ios_configs)
+        self.assertIn("llama3_qlora", ios_configs)
+
+    def test_generate_compatible_configs_llama_model(self):
+        model_name = "meta-llama/Llama-3.2-1B"
+        target_os = "ios"
+        result = gather_benchmark_configs.generate_compatible_configs(
+            model_name, target_os
+        )
+        expected = ["llama3_fb16", "llama3_coreml_ane"]
+        self.assertEqual(result, expected)
+
+        target_os = "android"
+        result = gather_benchmark_configs.generate_compatible_configs(
+            model_name, target_os
+        )
+        expected = ["llama3_fb16"]
+        self.assertEqual(result, expected)
+
+    def test_generate_compatible_configs_quantized_llama_model(self):
+        model_name = "meta-llama/Llama-3.2-1B-Instruct-SpinQuant_INT4_EO8"
+        result = gather_benchmark_configs.generate_compatible_configs(model_name, None)
+        expected = ["llama3_spinquant"]
+        self.assertEqual(result, expected)
+
+        model_name = "meta-llama/Llama-3.2-1B-Instruct-QLORA_INT4_EO8"
+        result = gather_benchmark_configs.generate_compatible_configs(model_name, None)
+        expected = ["llama3_qlora"]
+        self.assertEqual(result, expected)
+
+    def test_generate_compatible_configs_non_genai_model(self):
+        model_name = "mv2"
+        target_os = "xplat"
+        result = gather_benchmark_configs.generate_compatible_configs(
+            model_name, target_os
+        )
+        expected = ["xnnpack_q8"]
+        self.assertEqual(result, expected)
+
+        target_os = "android"
+        result = gather_benchmark_configs.generate_compatible_configs(
+            model_name, target_os
+        )
+        expected = ["xnnpack_q8", "qnn_q8"]
+        self.assertEqual(result, expected)
+
+        target_os = "ios"
+        result = gather_benchmark_configs.generate_compatible_configs(
+            model_name, target_os
+        )
+        expected = ["xnnpack_q8", "coreml_fp16", "mps"]
+        self.assertEqual(result, expected)
+
+    def test_generate_compatible_configs_unknown_model(self):
+        model_name = "unknown_model"
+        target_os = "ios"
+        result = gather_benchmark_configs.generate_compatible_configs(
+            model_name, target_os
+        )
+        self.assertEqual(result, [])
+
+    def test_is_valid_huggingface_model_id_valid(self):
+        valid_model = "meta-llama/Llama-3.2-1B"
+        self.assertTrue(
+            gather_benchmark_configs.is_valid_huggingface_model_id(valid_model)
+        )
+
+    @patch("builtins.open", new_callable=mock_open)
+    @patch("os.getenv", return_value=None)
+    def test_set_output_no_github_env(self, mock_getenv, mock_file):
+        with patch("builtins.print") as mock_print:
+            gather_benchmark_configs.set_output("test_name", "test_value")
+            mock_print.assert_called_with("::set-output name=test_name::test_value")
+
+    def test_device_pools_contains_all_devices(self):
+        expected_devices = [
+            "apple_iphone_15",
+            "apple_iphone_15+ios_18",
+            "samsung_galaxy_s22",
+            "samsung_galaxy_s24",
+            "google_pixel_8_pro",
+        ]
+        for device in expected_devices:
+            self.assertIn(device, gather_benchmark_configs.DEVICE_POOLS)
+
+    def test_gather_benchmark_configs_cli(self):
+        args = {
+            "models": "mv2,dl3",
+            "os": "ios",
+            "devices": "apple_iphone_15",
+            "configs": None,
+        }
+
+        cmd = ["python", ".ci/scripts/gather_benchmark_configs.py"]
+        for key, value in args.items():
+            if value is not None:
+                cmd.append(f"--{key}")
+                cmd.append(value)
+
+        result = subprocess.run(cmd, capture_output=True, text=True)
+        self.assertEqual(result.returncode, 0, f"Error: {result.stderr}")
+        self.assertIn('"model": "mv2"', result.stdout)
+        self.assertIn('"model": "dl3"', result.stdout)
+        self.assertIn('"config": "coreml_fp16"', result.stdout)
+        self.assertIn('"config": "xnnpack_q8"', result.stdout)
+        self.assertIn('"config": "mps"', result.stdout)
+
+    def test_gather_benchmark_configs_cli_specified_configs(self):
+        args = {
+            "models": "mv2,dl3",
+            "os": "ios",
+            "devices": "apple_iphone_15",
+            "configs": "coreml_fp16,xnnpack_q8",
+        }
+
+        cmd = ["python", ".ci/scripts/gather_benchmark_configs.py"]
+        for key, value in args.items():
+            if value is not None:
+                cmd.append(f"--{key}")
+                cmd.append(value)
+
+        result = subprocess.run(cmd, capture_output=True, text=True)
+        self.assertEqual(result.returncode, 0, f"Error: {result.stderr}")
+        self.assertIn('"model": "mv2"', result.stdout)
+        self.assertIn('"model": "dl3"', result.stdout)
+        self.assertIn('"config": "coreml_fp16"', result.stdout)
+        self.assertIn('"config": "xnnpack_q8"', result.stdout)
+        self.assertNotIn('"config": "mps"', result.stdout)
+
+    def test_gather_benchmark_configs_cli_specified_configs_raise(self):
+        args = {
+            "models": "mv2,dl3",
+            "os": "ios",
+            "devices": "apple_iphone_15",
+            "configs": "qnn_q8",
+        }
+
+        cmd = ["python", ".ci/scripts/gather_benchmark_configs.py"]
+        for key, value in args.items():
+            if value is not None:
+                cmd.append(f"--{key}")
+                cmd.append(value)
+
+        result = subprocess.run(cmd, capture_output=True, text=True)
+        self.assertEqual(result.returncode, 1, f"Error: {result.stderr}")
+        self.assertIn("Unsupported config 'qnn_q8'", result.stderr)
+
+
+if __name__ == "__main__":
+    unittest.main()
