Skip to content

Commit c737322

Browse files
committed
Update
[ghstack-poisoned]
2 parents cdd702a + 58bda89 commit c737322

File tree

157 files changed

+10431
-1074
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

157 files changed

+10431
-1074
lines changed

.ci/scripts/gather_benchmark_configs.py

Lines changed: 53 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
import os
1111
import re
1212
import sys
13-
from typing import Any, Dict, List
13+
from typing import Any, Dict, List, NamedTuple
1414

1515
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "../..")))
1616
from examples.models import MODEL_NAME_TO_MODEL
@@ -47,6 +47,50 @@
4747
}
4848

4949

50+
class DisabledConfig(NamedTuple):
51+
config_name: str
52+
github_issue: str # Link to the GitHub issue
53+
54+
55+
# Updated DISABLED_CONFIGS
56+
DISABLED_CONFIGS: Dict[str, List[DisabledConfig]] = {
57+
"resnet50": [
58+
DisabledConfig(
59+
config_name="qnn_q8",
60+
github_issue="https://github.com/pytorch/executorch/issues/7892",
61+
),
62+
],
63+
"w2l": [
64+
DisabledConfig(
65+
config_name="qnn_q8",
66+
github_issue="https://github.com/pytorch/executorch/issues/7634",
67+
),
68+
],
69+
"mobilebert": [
70+
DisabledConfig(
71+
config_name="mps",
72+
github_issue="https://github.com/pytorch/executorch/issues/7904",
73+
),
74+
DisabledConfig(
75+
config_name="qnn_q8",
76+
github_issue="https://github.com/pytorch/executorch/issues/7946",
77+
),
78+
],
79+
"edsr": [
80+
DisabledConfig(
81+
config_name="mps",
82+
github_issue="https://github.com/pytorch/executorch/issues/7905",
83+
),
84+
],
85+
"llama": [
86+
DisabledConfig(
87+
config_name="mps",
88+
github_issue="https://github.com/pytorch/executorch/issues/7907",
89+
),
90+
],
91+
}
92+
93+
5094
def extract_all_configs(data, target_os=None):
5195
if isinstance(data, dict):
5296
# If target_os is specified, include "xplat" and the specified branch
@@ -117,6 +161,14 @@ def generate_compatible_configs(model_name: str, target_os=None) -> List[str]:
117161
# Skip unknown models with a warning
118162
logging.warning(f"Unknown or invalid model name '{model_name}'. Skipping.")
119163

164+
# Remove disabled configs for the given model
165+
disabled_configs = DISABLED_CONFIGS.get(model_name, [])
166+
disabled_config_names = {disabled.config_name for disabled in disabled_configs}
167+
for disabled in disabled_configs:
168+
print(
169+
f"Excluding disabled config: '{disabled.config_name}' for model '{model_name}' on '{target_os}'. Linked GitHub issue: {disabled.github_issue}"
170+
)
171+
configs = [config for config in configs if config not in disabled_config_names]
120172
return configs
121173

122174

.ci/scripts/tests/test_gather_benchmark_configs.py

Lines changed: 89 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,36 +1,41 @@
11
import importlib.util
22
import os
3+
import re
34
import subprocess
45
import sys
56
import unittest
67
from unittest.mock import mock_open, patch
78

89
import pytest
910

10-
# Dynamically import the script
11-
script_path = os.path.join(".ci", "scripts", "gather_benchmark_configs.py")
12-
spec = importlib.util.spec_from_file_location("gather_benchmark_configs", script_path)
13-
gather_benchmark_configs = importlib.util.module_from_spec(spec)
14-
spec.loader.exec_module(gather_benchmark_configs)
15-
1611

1712
@pytest.mark.skipif(
1813
sys.platform != "linux", reason="The script under test runs on Linux runners only"
1914
)
2015
class TestGatehrBenchmarkConfigs(unittest.TestCase):
2116

17+
@classmethod
18+
def setUpClass(cls):
19+
# Dynamically import the script
20+
script_path = os.path.join(".ci", "scripts", "gather_benchmark_configs.py")
21+
spec = importlib.util.spec_from_file_location(
22+
"gather_benchmark_configs", script_path
23+
)
24+
cls.gather_benchmark_configs = importlib.util.module_from_spec(spec)
25+
spec.loader.exec_module(cls.gather_benchmark_configs)
26+
2227
def test_extract_all_configs_android(self):
23-
android_configs = gather_benchmark_configs.extract_all_configs(
24-
gather_benchmark_configs.BENCHMARK_CONFIGS, "android"
28+
android_configs = self.gather_benchmark_configs.extract_all_configs(
29+
self.gather_benchmark_configs.BENCHMARK_CONFIGS, "android"
2530
)
2631
self.assertIn("xnnpack_q8", android_configs)
2732
self.assertIn("qnn_q8", android_configs)
2833
self.assertIn("llama3_spinquant", android_configs)
2934
self.assertIn("llama3_qlora", android_configs)
3035

3136
def test_extract_all_configs_ios(self):
32-
ios_configs = gather_benchmark_configs.extract_all_configs(
33-
gather_benchmark_configs.BENCHMARK_CONFIGS, "ios"
37+
ios_configs = self.gather_benchmark_configs.extract_all_configs(
38+
self.gather_benchmark_configs.BENCHMARK_CONFIGS, "ios"
3439
)
3540

3641
self.assertIn("xnnpack_q8", ios_configs)
@@ -40,51 +45,114 @@ def test_extract_all_configs_ios(self):
4045
self.assertIn("llama3_spinquant", ios_configs)
4146
self.assertIn("llama3_qlora", ios_configs)
4247

48+
def test_skip_disabled_configs(self):
49+
# Use patch as a context manager to avoid modifying DISABLED_CONFIGS and BENCHMARK_CONFIGS
50+
with patch.dict(
51+
self.gather_benchmark_configs.DISABLED_CONFIGS,
52+
{
53+
"mv3": [
54+
self.gather_benchmark_configs.DisabledConfig(
55+
config_name="disabled_config1",
56+
github_issue="https://github.com/org/repo/issues/123",
57+
),
58+
self.gather_benchmark_configs.DisabledConfig(
59+
config_name="disabled_config2",
60+
github_issue="https://github.com/org/repo/issues/124",
61+
),
62+
]
63+
},
64+
), patch.dict(
65+
self.gather_benchmark_configs.BENCHMARK_CONFIGS,
66+
{
67+
"ios": [
68+
"disabled_config1",
69+
"disabled_config2",
70+
"enabled_config1",
71+
"enabled_config2",
72+
]
73+
},
74+
):
75+
result = self.gather_benchmark_configs.generate_compatible_configs(
76+
"mv3", target_os="ios"
77+
)
78+
79+
# Assert that disabled configs are excluded
80+
self.assertNotIn("disabled_config1", result)
81+
self.assertNotIn("disabled_config2", result)
82+
# Assert enabled configs are included
83+
self.assertIn("enabled_config1", result)
84+
self.assertIn("enabled_config2", result)
85+
86+
def test_disabled_configs_have_github_links(self):
87+
github_issue_regex = re.compile(r"https://github\.com/.+/.+/issues/\d+")
88+
89+
for (
90+
model_name,
91+
disabled_configs,
92+
) in self.gather_benchmark_configs.DISABLED_CONFIGS.items():
93+
for disabled in disabled_configs:
94+
with self.subTest(model_name=model_name, config=disabled.config_name):
95+
# Assert that disabled is an instance of DisabledConfig
96+
self.assertIsInstance(
97+
disabled, self.gather_benchmark_configs.DisabledConfig
98+
)
99+
100+
# Assert that github_issue is provided and matches the expected pattern
101+
self.assertTrue(
102+
disabled.github_issue
103+
and github_issue_regex.match(disabled.github_issue),
104+
f"Invalid or missing GitHub issue link for '{disabled.config_name}' in model '{model_name}'.",
105+
)
106+
43107
def test_generate_compatible_configs_llama_model(self):
44108
model_name = "meta-llama/Llama-3.2-1B"
45109
target_os = "ios"
46-
result = gather_benchmark_configs.generate_compatible_configs(
110+
result = self.gather_benchmark_configs.generate_compatible_configs(
47111
model_name, target_os
48112
)
49113
expected = ["llama3_fb16", "llama3_coreml_ane"]
50114
self.assertEqual(result, expected)
51115

52116
target_os = "android"
53-
result = gather_benchmark_configs.generate_compatible_configs(
117+
result = self.gather_benchmark_configs.generate_compatible_configs(
54118
model_name, target_os
55119
)
56120
expected = ["llama3_fb16"]
57121
self.assertEqual(result, expected)
58122

59123
def test_generate_compatible_configs_quantized_llama_model(self):
60124
model_name = "meta-llama/Llama-3.2-1B-Instruct-SpinQuant_INT4_EO8"
61-
result = gather_benchmark_configs.generate_compatible_configs(model_name, None)
125+
result = self.gather_benchmark_configs.generate_compatible_configs(
126+
model_name, None
127+
)
62128
expected = ["llama3_spinquant"]
63129
self.assertEqual(result, expected)
64130

65131
model_name = "meta-llama/Llama-3.2-1B-Instruct-QLORA_INT4_EO8"
66-
result = gather_benchmark_configs.generate_compatible_configs(model_name, None)
132+
result = self.gather_benchmark_configs.generate_compatible_configs(
133+
model_name, None
134+
)
67135
expected = ["llama3_qlora"]
68136
self.assertEqual(result, expected)
69137

70138
def test_generate_compatible_configs_non_genai_model(self):
71139
model_name = "mv2"
72140
target_os = "xplat"
73-
result = gather_benchmark_configs.generate_compatible_configs(
141+
result = self.gather_benchmark_configs.generate_compatible_configs(
74142
model_name, target_os
75143
)
76144
expected = ["xnnpack_q8"]
77145
self.assertEqual(result, expected)
78146

79147
target_os = "android"
80-
result = gather_benchmark_configs.generate_compatible_configs(
148+
result = self.gather_benchmark_configs.generate_compatible_configs(
81149
model_name, target_os
82150
)
83151
expected = ["xnnpack_q8", "qnn_q8"]
84152
self.assertEqual(result, expected)
85153

86154
target_os = "ios"
87-
result = gather_benchmark_configs.generate_compatible_configs(
155+
result = self.gather_benchmark_configs.generate_compatible_configs(
88156
model_name, target_os
89157
)
90158
expected = ["xnnpack_q8", "coreml_fp16", "mps"]
@@ -93,22 +161,22 @@ def test_generate_compatible_configs_non_genai_model(self):
93161
def test_generate_compatible_configs_unknown_model(self):
94162
model_name = "unknown_model"
95163
target_os = "ios"
96-
result = gather_benchmark_configs.generate_compatible_configs(
164+
result = self.gather_benchmark_configs.generate_compatible_configs(
97165
model_name, target_os
98166
)
99167
self.assertEqual(result, [])
100168

101169
def test_is_valid_huggingface_model_id_valid(self):
102170
valid_model = "meta-llama/Llama-3.2-1B"
103171
self.assertTrue(
104-
gather_benchmark_configs.is_valid_huggingface_model_id(valid_model)
172+
self.gather_benchmark_configs.is_valid_huggingface_model_id(valid_model)
105173
)
106174

107175
@patch("builtins.open", new_callable=mock_open)
108176
@patch("os.getenv", return_value=None)
109177
def test_set_output_no_github_env(self, mock_getenv, mock_file):
110178
with patch("builtins.print") as mock_print:
111-
gather_benchmark_configs.set_output("test_name", "test_value")
179+
self.gather_benchmark_configs.set_output("test_name", "test_value")
112180
mock_print.assert_called_with("::set-output name=test_name::test_value")
113181

114182
def test_device_pools_contains_all_devices(self):
@@ -120,7 +188,7 @@ def test_device_pools_contains_all_devices(self):
120188
"google_pixel_8_pro",
121189
]
122190
for device in expected_devices:
123-
self.assertIn(device, gather_benchmark_configs.DEVICE_POOLS)
191+
self.assertIn(device, self.gather_benchmark_configs.DEVICE_POOLS)
124192

125193
def test_gather_benchmark_configs_cli(self):
126194
args = {

CMakeLists.txt

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -164,8 +164,6 @@ else()
164164
set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} -O2")
165165
endif()
166166

167-
set(CMAKE_CXX_FLAGS_DEBUG "-O0 -g")
168-
169167
option(EXECUTORCH_BUILD_ANDROID_JNI "Build Android JNI" OFF)
170168

171169
option(EXECUTORCH_BUILD_ARM_BAREMETAL

backends/arm/operator_support/to_copy_support.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,7 @@ def is_node_supported(self, node: fx.Node, tosa_spec: TosaSpecification) -> bool
125125
# Check dim_order (to_dim_order_copy)
126126
if "dim_order" in node.kwargs:
127127
dim_order = node.kwargs["dim_order"]
128+
# pyre-ignore[6]
128129
if dim_order != list(range(len(dim_order))):
129130
logger.info(
130131
f"Argument {dim_order=} is not supported for "

backends/cadence/aot/TARGETS

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,7 @@ python_library(
6565
"//executorch/backends/cadence/aot/quantizer:fusion_pass",
6666
"//executorch/backends/cadence/runtime:runtime",
6767
"//executorch/backends/cadence/aot/quantizer:quantizer",
68+
"//executorch/backends/xnnpack/quantizer:xnnpack_quantizer",
6869
"//executorch/backends/transforms:decompose_sdpa",
6970
"//executorch/backends/transforms:remove_clone_ops",
7071
"//executorch/exir:lib",

0 commit comments

Comments
 (0)