Commit ecb11ed
Update on "[ET][Memory planning] Improve greedy memory planning."
This diff replaces the old greedy algorithm. The old algorithm produced plans about 35% worse than the theoretical optimum. This matters even more for long contexts, where the additional overhead can be a few hundred MB. For example, the theoretical optimum for llama3_2 8B, a 4-bit quantized model with a context length of 2k, needs about 1 GB of memory. This theoretical optimum can be observed by looking at the peaks in the memory profile. The current algorithm resulted in about 1.6 GB of planned memory; the new algorithm reduces that to about 1.1 GB.

Differential Revision: [D68448332](https://our.internmc.facebook.com/intern/diff/D68448332/)

cc JacobSzwejbka angelayi

[ghstack-poisoned]
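The planner source itself is elsewhere in this 214-file diff, but to make the idea concrete, here is a minimal sketch of lifetime-aware greedy placement in the first-fit-decreasing style such planners commonly use. This is an illustration under assumptions, not the ExecuTorch implementation; TensorSpec, greedy_plan, and the exact heuristic are hypothetical:

# Illustrative sketch only -- NOT the ExecuTorch planner. It shows the general
# technique: place each tensor at the lowest offset whose address range is free
# for the tensor's entire lifetime, visiting tensors largest-first so big
# tensors anchor the layout.
from typing import Dict, List, NamedTuple, Tuple


class TensorSpec(NamedTuple):  # hypothetical stand-in for a planner's spec
    name: str
    size: int        # bytes, assumed already aligned
    lifetime: range  # half-open range of node indices where the tensor is alive


def _overlaps(a: range, b: range) -> bool:
    return a.start < b.stop and b.start < a.stop


def greedy_plan(specs: List[TensorSpec]) -> Tuple[Dict[str, int], int]:
    placements: List[Tuple[int, int, range]] = []  # (offset, size, lifetime)
    offsets: Dict[str, int] = {}
    for spec in sorted(specs, key=lambda s: s.size, reverse=True):
        # Only tensors whose lifetimes overlap this one constrain placement.
        conflicts = sorted(
            (off, sz) for off, sz, lt in placements if _overlaps(lt, spec.lifetime)
        )
        offset = 0
        for off, sz in conflicts:  # first-fit scan over the address gaps
            if off - offset >= spec.size:
                break
            offset = max(offset, off + sz)
        placements.append((offset, spec.size, spec.lifetime))
        offsets[spec.name] = offset
    total = max((off + sz for off, sz, _ in placements), default=0)
    return offsets, total


# Tensors that are never alive at the same time share one slot:
offsets, total = greedy_plan(
    [
        TensorSpec("a", 1024, range(0, 3)),
        TensorSpec("b", 1024, range(3, 6)),
        TensorSpec("c", 512, range(0, 6)),
    ]
)
assert offsets["a"] == offsets["b"] == 0  # disjoint lifetimes reuse the offset
assert total == 1536  # the peak of the memory profile, not the sum of sizes

The closer the planned total tracks the peak of the memory profile, the closer the plan is to the theoretical optimum described above.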
2 parents 3975a75 + e8ebe1a commit ecb11ed

File tree

214 files changed: +13620, -1683 lines

.ci/docker/requirements-ci.txt

Lines changed: 3 additions & 3 deletions
@@ -1,5 +1,5 @@
 mpmath==1.3.0
-numpy==2.0.0; python_version >= '3.10'
+numpy>=2.0.0; python_version >= '3.10'
 PyYAML==6.0.1
 ruamel.yaml==0.17.32
 sympy==1.12
@@ -8,7 +8,7 @@ tomli==2.0.1
 torchsr==1.0.4
 transformers==4.47.1
 zstd==1.5.5.1
-pandas==2.2.2; python_version >= '3.10'
+pandas>=2.2.2; python_version >= '3.10'
 pytest==7.2.0
 pytest-cov==4.1.0
 expecttest==0.1.6
@@ -21,7 +21,7 @@ sphinx-gallery==0.14.0
 breathe==4.34.0
 exhale==0.2.3
 docutils==0.16
-matplotlib==3.9.4
+matplotlib>=3.9.4
 # PyTorch Theme
 -e git+https://github.com/pytorch/pytorch_sphinx_theme.git#egg=pytorch_sphinx_theme
 myst-parser==0.18.1

.ci/scripts/gather_benchmark_configs.py

Lines changed: 53 additions & 1 deletion
@@ -10,7 +10,7 @@
 import os
 import re
 import sys
-from typing import Any, Dict, List
+from typing import Any, Dict, List, NamedTuple
 
 sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "../..")))
 from examples.models import MODEL_NAME_TO_MODEL
@@ -47,6 +47,50 @@
 }
 
 
+class DisabledConfig(NamedTuple):
+    config_name: str
+    github_issue: str  # Link to the GitHub issue
+
+
+# Updated DISABLED_CONFIGS
+DISABLED_CONFIGS: Dict[str, List[DisabledConfig]] = {
+    "resnet50": [
+        DisabledConfig(
+            config_name="qnn_q8",
+            github_issue="https://github.com/pytorch/executorch/issues/7892",
+        ),
+    ],
+    "w2l": [
+        DisabledConfig(
+            config_name="qnn_q8",
+            github_issue="https://github.com/pytorch/executorch/issues/7634",
+        ),
+    ],
+    "mobilebert": [
+        DisabledConfig(
+            config_name="mps",
+            github_issue="https://github.com/pytorch/executorch/issues/7904",
+        ),
+        DisabledConfig(
+            config_name="qnn_q8",
+            github_issue="https://github.com/pytorch/executorch/issues/7946",
+        ),
+    ],
+    "edsr": [
+        DisabledConfig(
+            config_name="mps",
+            github_issue="https://github.com/pytorch/executorch/issues/7905",
+        ),
+    ],
+    "llama": [
+        DisabledConfig(
+            config_name="mps",
+            github_issue="https://github.com/pytorch/executorch/issues/7907",
+        ),
+    ],
+}
+
+
 def extract_all_configs(data, target_os=None):
     if isinstance(data, dict):
         # If target_os is specified, include "xplat" and the specified branch
@@ -117,6 +161,14 @@ def generate_compatible_configs(model_name: str, target_os=None) -> List[str]:
         # Skip unknown models with a warning
         logging.warning(f"Unknown or invalid model name '{model_name}'. Skipping.")
 
+    # Remove disabled configs for the given model
+    disabled_configs = DISABLED_CONFIGS.get(model_name, [])
+    disabled_config_names = {disabled.config_name for disabled in disabled_configs}
+    for disabled in disabled_configs:
+        print(
+            f"Excluding disabled config: '{disabled.config_name}' for model '{model_name}' on '{target_os}'. Linked GitHub issue: {disabled.github_issue}"
+        )
+    configs = [config for config in configs if config not in disabled_config_names]
     return configs
 
 
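As a quick illustration of the filtering added above (a hypothetical call; the exact return value depends on the full BENCHMARK_CONFIGS, most of which is outside this diff):

# Hypothetical usage, assuming the module is imported as in the tests below.
# "resnet50"/"qnn_q8" is a real entry in the DISABLED_CONFIGS added above.
configs = gather_benchmark_configs.generate_compatible_configs(
    "resnet50", target_os="android"
)
# A notice pointing at issue 7892 is printed, and "qnn_q8" is filtered out.
assert "qnn_q8" not in configs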

.ci/scripts/tests/test_gather_benchmark_configs.py

Lines changed: 89 additions & 21 deletions
@@ -1,36 +1,41 @@
 import importlib.util
 import os
+import re
 import subprocess
 import sys
 import unittest
 from unittest.mock import mock_open, patch
 
 import pytest
 
-# Dynamically import the script
-script_path = os.path.join(".ci", "scripts", "gather_benchmark_configs.py")
-spec = importlib.util.spec_from_file_location("gather_benchmark_configs", script_path)
-gather_benchmark_configs = importlib.util.module_from_spec(spec)
-spec.loader.exec_module(gather_benchmark_configs)
-
 
 @pytest.mark.skipif(
     sys.platform != "linux", reason="The script under test runs on Linux runners only"
 )
 class TestGatehrBenchmarkConfigs(unittest.TestCase):
 
+    @classmethod
+    def setUpClass(cls):
+        # Dynamically import the script
+        script_path = os.path.join(".ci", "scripts", "gather_benchmark_configs.py")
+        spec = importlib.util.spec_from_file_location(
+            "gather_benchmark_configs", script_path
+        )
+        cls.gather_benchmark_configs = importlib.util.module_from_spec(spec)
+        spec.loader.exec_module(cls.gather_benchmark_configs)
+
     def test_extract_all_configs_android(self):
-        android_configs = gather_benchmark_configs.extract_all_configs(
-            gather_benchmark_configs.BENCHMARK_CONFIGS, "android"
+        android_configs = self.gather_benchmark_configs.extract_all_configs(
+            self.gather_benchmark_configs.BENCHMARK_CONFIGS, "android"
         )
         self.assertIn("xnnpack_q8", android_configs)
         self.assertIn("qnn_q8", android_configs)
         self.assertIn("llama3_spinquant", android_configs)
         self.assertIn("llama3_qlora", android_configs)
 
     def test_extract_all_configs_ios(self):
-        ios_configs = gather_benchmark_configs.extract_all_configs(
-            gather_benchmark_configs.BENCHMARK_CONFIGS, "ios"
+        ios_configs = self.gather_benchmark_configs.extract_all_configs(
+            self.gather_benchmark_configs.BENCHMARK_CONFIGS, "ios"
         )
 
         self.assertIn("xnnpack_q8", ios_configs)
@@ -40,51 +45,114 @@ def test_extract_all_configs_ios(self):
         self.assertIn("llama3_spinquant", ios_configs)
         self.assertIn("llama3_qlora", ios_configs)
 
+    def test_skip_disabled_configs(self):
+        # Use patch as a context manager to avoid modifying DISABLED_CONFIGS and BENCHMARK_CONFIGS
+        with patch.dict(
+            self.gather_benchmark_configs.DISABLED_CONFIGS,
+            {
+                "mv3": [
+                    self.gather_benchmark_configs.DisabledConfig(
+                        config_name="disabled_config1",
+                        github_issue="https://github.com/org/repo/issues/123",
+                    ),
+                    self.gather_benchmark_configs.DisabledConfig(
+                        config_name="disabled_config2",
+                        github_issue="https://github.com/org/repo/issues/124",
+                    ),
+                ]
+            },
+        ), patch.dict(
+            self.gather_benchmark_configs.BENCHMARK_CONFIGS,
+            {
+                "ios": [
+                    "disabled_config1",
+                    "disabled_config2",
+                    "enabled_config1",
+                    "enabled_config2",
+                ]
+            },
+        ):
+            result = self.gather_benchmark_configs.generate_compatible_configs(
+                "mv3", target_os="ios"
+            )
+
+            # Assert that disabled configs are excluded
+            self.assertNotIn("disabled_config1", result)
+            self.assertNotIn("disabled_config2", result)
+            # Assert enabled configs are included
+            self.assertIn("enabled_config1", result)
+            self.assertIn("enabled_config2", result)
+
+    def test_disabled_configs_have_github_links(self):
+        github_issue_regex = re.compile(r"https://github\.com/.+/.+/issues/\d+")
+
+        for (
+            model_name,
+            disabled_configs,
+        ) in self.gather_benchmark_configs.DISABLED_CONFIGS.items():
+            for disabled in disabled_configs:
+                with self.subTest(model_name=model_name, config=disabled.config_name):
+                    # Assert that disabled is an instance of DisabledConfig
+                    self.assertIsInstance(
+                        disabled, self.gather_benchmark_configs.DisabledConfig
+                    )
+
+                    # Assert that github_issue is provided and matches the expected pattern
+                    self.assertTrue(
+                        disabled.github_issue
+                        and github_issue_regex.match(disabled.github_issue),
+                        f"Invalid or missing GitHub issue link for '{disabled.config_name}' in model '{model_name}'.",
+                    )
+
     def test_generate_compatible_configs_llama_model(self):
         model_name = "meta-llama/Llama-3.2-1B"
         target_os = "ios"
-        result = gather_benchmark_configs.generate_compatible_configs(
+        result = self.gather_benchmark_configs.generate_compatible_configs(
             model_name, target_os
         )
         expected = ["llama3_fb16", "llama3_coreml_ane"]
         self.assertEqual(result, expected)
 
         target_os = "android"
-        result = gather_benchmark_configs.generate_compatible_configs(
+        result = self.gather_benchmark_configs.generate_compatible_configs(
             model_name, target_os
         )
         expected = ["llama3_fb16"]
         self.assertEqual(result, expected)
 
     def test_generate_compatible_configs_quantized_llama_model(self):
         model_name = "meta-llama/Llama-3.2-1B-Instruct-SpinQuant_INT4_EO8"
-        result = gather_benchmark_configs.generate_compatible_configs(model_name, None)
+        result = self.gather_benchmark_configs.generate_compatible_configs(
+            model_name, None
+        )
         expected = ["llama3_spinquant"]
         self.assertEqual(result, expected)
 
         model_name = "meta-llama/Llama-3.2-1B-Instruct-QLORA_INT4_EO8"
-        result = gather_benchmark_configs.generate_compatible_configs(model_name, None)
+        result = self.gather_benchmark_configs.generate_compatible_configs(
+            model_name, None
+        )
         expected = ["llama3_qlora"]
         self.assertEqual(result, expected)
 
     def test_generate_compatible_configs_non_genai_model(self):
         model_name = "mv2"
         target_os = "xplat"
-        result = gather_benchmark_configs.generate_compatible_configs(
+        result = self.gather_benchmark_configs.generate_compatible_configs(
             model_name, target_os
         )
         expected = ["xnnpack_q8"]
         self.assertEqual(result, expected)
 
         target_os = "android"
-        result = gather_benchmark_configs.generate_compatible_configs(
+        result = self.gather_benchmark_configs.generate_compatible_configs(
             model_name, target_os
         )
         expected = ["xnnpack_q8", "qnn_q8"]
         self.assertEqual(result, expected)
 
         target_os = "ios"
-        result = gather_benchmark_configs.generate_compatible_configs(
+        result = self.gather_benchmark_configs.generate_compatible_configs(
             model_name, target_os
         )
         expected = ["xnnpack_q8", "coreml_fp16", "mps"]
@@ -93,22 +161,22 @@ def test_generate_compatible_configs_non_genai_model(self):
     def test_generate_compatible_configs_unknown_model(self):
         model_name = "unknown_model"
         target_os = "ios"
-        result = gather_benchmark_configs.generate_compatible_configs(
+        result = self.gather_benchmark_configs.generate_compatible_configs(
             model_name, target_os
         )
         self.assertEqual(result, [])
 
     def test_is_valid_huggingface_model_id_valid(self):
         valid_model = "meta-llama/Llama-3.2-1B"
         self.assertTrue(
-            gather_benchmark_configs.is_valid_huggingface_model_id(valid_model)
+            self.gather_benchmark_configs.is_valid_huggingface_model_id(valid_model)
         )
 
     @patch("builtins.open", new_callable=mock_open)
     @patch("os.getenv", return_value=None)
     def test_set_output_no_github_env(self, mock_getenv, mock_file):
         with patch("builtins.print") as mock_print:
-            gather_benchmark_configs.set_output("test_name", "test_value")
+            self.gather_benchmark_configs.set_output("test_name", "test_value")
             mock_print.assert_called_with("::set-output name=test_name::test_value")
 
     def test_device_pools_contains_all_devices(self):
@@ -120,7 +188,7 @@ def test_device_pools_contains_all_devices(self):
             "google_pixel_8_pro",
         ]
         for device in expected_devices:
-            self.assertIn(device, gather_benchmark_configs.DEVICE_POOLS)
+            self.assertIn(device, self.gather_benchmark_configs.DEVICE_POOLS)
 
     def test_gather_benchmark_configs_cli(self):
         args = {

CMakeLists.txt

Lines changed: 0 additions & 2 deletions
@@ -164,8 +164,6 @@ else()
   set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} -O2")
 endif()
 
-set(CMAKE_CXX_FLAGS_DEBUG "-O0 -g")
-
 option(EXECUTORCH_BUILD_ANDROID_JNI "Build Android JNI" OFF)
 
 option(EXECUTORCH_BUILD_ARM_BAREMETAL
