Skip to content

Commit d9f3269

Browse files
authored
Merge branch 'pytorch:main' into add-profiling-to-xnn-executor-runner-2
2 parents c886ad2 + e78ed83 commit d9f3269

File tree

379 files changed

+19741
-4823
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

379 files changed

+19741
-4823
lines changed

.ci/docker/requirements-ci.txt

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
mpmath==1.3.0
2-
numpy==2.0.0; python_version >= '3.10'
2+
numpy>=2.0.0; python_version >= '3.10'
33
PyYAML==6.0.1
44
ruamel.yaml==0.17.32
55
sympy==1.12
@@ -8,7 +8,7 @@ tomli==2.0.1
88
torchsr==1.0.4
99
transformers==4.47.1
1010
zstd==1.5.5.1
11-
pandas==2.2.2; python_version >= '3.10'
11+
pandas>=2.2.2; python_version >= '3.10'
1212
pytest==7.2.0
1313
pytest-cov==4.1.0
1414
expecttest==0.1.6
@@ -21,7 +21,7 @@ sphinx-gallery==0.14.0
2121
breathe==4.34.0
2222
exhale==0.2.3
2323
docutils==0.16
24-
matplotlib==3.9.4
24+
matplotlib>=3.9.4
2525
# PyTorch Theme
2626
-e git+https://github.com/pytorch/pytorch_sphinx_theme.git#egg=pytorch_sphinx_theme
2727
myst-parser==0.18.1

.ci/scripts/__init__.py

Whitespace-only changes.

.ci/scripts/gather_benchmark_configs.py

Lines changed: 156 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,10 @@
99
import logging
1010
import os
1111
import re
12-
from typing import Any, Dict
12+
import sys
13+
from typing import Any, Dict, List, NamedTuple
1314

15+
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "../..")))
1416
from examples.models import MODEL_NAME_TO_MODEL
1517

1618

@@ -45,6 +47,131 @@
4547
}
4648

4749

50+
class DisabledConfig(NamedTuple):
    """One benchmark config that is switched off, paired with its tracking issue."""

    # Name of the benchmark config to exclude (compared against generated config names).
    config_name: str
    # Link to the GitHub issue explaining why the config is disabled.
    github_issue: str
53+
54+
55+
# Benchmark configs to exclude per model. Keys are model names; each entry
# names the config to drop and links the GitHub issue that tracks it.
DISABLED_CONFIGS: Dict[str, List[DisabledConfig]] = {
    "resnet50": [
        DisabledConfig("qnn_q8", "https://github.com/pytorch/executorch/issues/7892"),
    ],
    "w2l": [
        DisabledConfig("qnn_q8", "https://github.com/pytorch/executorch/issues/7634"),
    ],
    "mobilebert": [
        DisabledConfig("mps", "https://github.com/pytorch/executorch/issues/7904"),
        DisabledConfig("qnn_q8", "https://github.com/pytorch/executorch/issues/7946"),
    ],
    "edsr": [
        DisabledConfig("mps", "https://github.com/pytorch/executorch/issues/7905"),
    ],
    "llama": [
        DisabledConfig("mps", "https://github.com/pytorch/executorch/issues/7907"),
    ],
}
92+
93+
94+
def extract_all_configs(data, target_os=None):
    """
    Flatten a nested structure of dicts and lists into a flat list of leaf values.

    Args:
        data: A dict, a list, or a leaf value. Dicts and lists are walked
            recursively; anything else is treated as a leaf.
        target_os: When truthy, only the dict keys "xplat" and ``target_os``
            are followed (at every dict level). When falsy, all keys are followed.

    Returns:
        A list of the leaf values, in traversal order.
    """
    # Leaf value: wrap it so callers can always extend with the result.
    if not isinstance(data, (dict, list)):
        return [data]

    flattened = []
    if isinstance(data, dict):
        # If target_os is specified, include "xplat" and the specified branch
        allowed_branches = {"xplat", target_os} if target_os else data.keys()
        for branch, subtree in data.items():
            if branch in allowed_branches:
                flattened.extend(extract_all_configs(subtree, target_os))
        return flattened

    for element in data:
        flattened.extend(extract_all_configs(element, target_os))
    return flattened
108+
109+
110+
def generate_compatible_configs(model_name: str, target_os=None) -> List[str]:
    """
    Generate a list of compatible benchmark configurations for a given model name and target OS.

    Args:
        model_name (str): The name of the model to generate configurations for.
        target_os (Optional[str]): The target operating system (e.g., 'android', 'ios').

    Returns:
        List[str]: A list of compatible benchmark configurations, with any
        config listed in DISABLED_CONFIGS for this model removed. Unknown or
        invalid model names produce an empty list (a warning is logged).

    Example:
        generate_compatible_configs('meta-llama/Llama-3.2-1B', 'ios') -> ['llama3_fb16', 'llama3_coreml_ane']
    """
    configs = []
    if is_valid_huggingface_model_id(model_name):
        if model_name.startswith("meta-llama/"):
            # LLaMA models
            repo_name = model_name.split("meta-llama/")[1].lower()
            if "qlora" in repo_name:
                configs.append("llama3_qlora")
            elif "spinquant" in repo_name:
                configs.append("llama3_spinquant")
            else:
                configs.append("llama3_fb16")
            configs.extend(
                config
                for config in BENCHMARK_CONFIGS.get(target_os, [])
                if config.startswith("llama")
            )
        else:
            # Non-LLaMA models
            configs.append("hf_xnnpack_fp32")
    elif model_name in MODEL_NAME_TO_MODEL:
        # ExecuTorch in-tree non-GenAI models
        configs.append("xnnpack_q8")
        if target_os != "xplat":
            # Add OS-specific configs
            configs.extend(
                config
                for config in BENCHMARK_CONFIGS.get(target_os, [])
                if not config.startswith("llama")
            )
    else:
        # Skip unknown models with a warning
        logging.warning(f"Unknown or invalid model name '{model_name}'. Skipping.")

    # Remove disabled configs for the given model. Only configs that were
    # actually candidates are reported, so the log never claims to exclude a
    # config that does not apply to this target OS.
    disabled_by_name = {
        disabled.config_name: disabled
        for disabled in DISABLED_CONFIGS.get(model_name, [])
    }
    enabled_configs = []
    for config in configs:
        disabled = disabled_by_name.get(config)
        if disabled is None:
            enabled_configs.append(config)
            continue
        print(
            f"Excluding disabled config: '{disabled.config_name}' for model '{model_name}' on '{target_os}'. Linked GitHub issue: {disabled.github_issue}"
        )
    return enabled_configs
173+
174+
48175
def parse_args() -> Any:
49176
"""
50177
Parse command-line arguments.
@@ -82,6 +209,11 @@ def comma_separated(value: str):
82209
type=comma_separated, # Use the custom parser for comma-separated values
83210
help=f"Comma-separated device names. Available devices: {list(DEVICE_POOLS.keys())}",
84211
)
212+
parser.add_argument(
213+
"--configs",
214+
type=comma_separated, # Use the custom parser for comma-separated values
215+
help=f"Comma-separated benchmark configs. Available configs: {extract_all_configs(BENCHMARK_CONFIGS)}",
216+
)
85217

86218
return parser.parse_args()
87219

@@ -98,11 +230,16 @@ def set_output(name: str, val: Any) -> None:
98230
set_output("benchmark_configs", {"include": [...]})
99231
"""
100232

101-
if os.getenv("GITHUB_OUTPUT"):
102-
print(f"Setting {val} to GitHub output")
103-
with open(str(os.getenv("GITHUB_OUTPUT")), "a") as env:
104-
print(f"{name}={val}", file=env)
105-
else:
233+
github_output = os.getenv("GITHUB_OUTPUT")
234+
if not github_output:
235+
print(f"::set-output name={name}::{val}")
236+
return
237+
238+
try:
239+
with open(github_output, "a") as env:
240+
env.write(f"{name}={val}\n")
241+
except PermissionError:
242+
# Fall back to printing in case of permission error in unit tests
106243
print(f"::set-output name={name}::{val}")
107244

108245

@@ -123,7 +260,7 @@ def is_valid_huggingface_model_id(model_name: str) -> bool:
123260
return bool(re.match(pattern, model_name))
124261

125262

126-
def get_benchmark_configs() -> Dict[str, Dict]:
263+
def get_benchmark_configs() -> Dict[str, Dict]: # noqa: C901
127264
"""
128265
Gather benchmark configurations for a given set of models on the target operating system and devices.
129266
@@ -153,48 +290,26 @@ def get_benchmark_configs() -> Dict[str, Dict]:
153290
}
154291
"""
155292
args = parse_args()
156-
target_os = args.os
157293
devices = args.devices
158294
models = args.models
295+
target_os = args.os
296+
target_configs = args.configs
159297

160298
benchmark_configs = {"include": []}
161299

162300
for model_name in models:
163301
configs = []
164-
if is_valid_huggingface_model_id(model_name):
165-
if model_name.startswith("meta-llama/"):
166-
# LLaMA models
167-
repo_name = model_name.split("meta-llama/")[1]
168-
if "qlora" in repo_name.lower():
169-
configs.append("llama3_qlora")
170-
elif "spinquant" in repo_name.lower():
171-
configs.append("llama3_spinquant")
172-
else:
173-
configs.append("llama3_fb16")
174-
configs.extend(
175-
[
176-
config
177-
for config in BENCHMARK_CONFIGS.get(target_os, [])
178-
if config.startswith("llama")
179-
]
302+
configs.extend(generate_compatible_configs(model_name, target_os))
303+
print(f"Discovered all supported configs for model '{model_name}': {configs}")
304+
if target_configs is not None:
305+
for config in target_configs:
306+
if config not in configs:
307+
raise Exception(
308+
f"Unsupported config '{config}' for model '{model_name}' on '{target_os}'. Skipped.\n"
309+
f"Supported configs are: {configs}"
180310
)
181-
else:
182-
# Non-LLaMA models
183-
configs.append("hf_xnnpack_fp32")
184-
elif model_name in MODEL_NAME_TO_MODEL:
185-
# ExecuTorch in-tree non-GenAI models
186-
configs.append("xnnpack_q8")
187-
configs.extend(
188-
[
189-
config
190-
for config in BENCHMARK_CONFIGS.get(target_os, [])
191-
if not config.startswith("llama")
192-
]
193-
)
194-
else:
195-
# Skip unknown models with a warning
196-
logging.warning(f"Unknown or invalid model name '{model_name}'. Skipping.")
197-
continue
311+
configs = target_configs
312+
print(f"Using provided configs {configs} for model '{model_name}'")
198313

199314
# Add configurations for each valid device
200315
for device in devices:

.ci/scripts/test_llama.sh

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,8 @@ fi
112112

113113
if [[ "${MODE}" =~ .*quantize_kv.* ]]; then
114114
QUANTIZE_KV_CACHE=ON
115+
# quantize_kv cache transform uses custom kv cache update op
116+
CUSTOM=ON
115117
else
116118
QUANTIZE_KV_CACHE=OFF
117119
fi

.ci/scripts/test_model.sh

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -169,6 +169,13 @@ test_model_with_qnn() {
169169
EXPORT_SCRIPT=inception_v3
170170
elif [[ "${MODEL_NAME}" == "vit" ]]; then
171171
EXPORT_SCRIPT=torchvision_vit
172+
elif [[ "${MODEL_NAME}" == "edsr" ]]; then
173+
EXPORT_SCRIPT=edsr
174+
# Additional deps for edsr
175+
pip install piq
176+
else
177+
echo "Unsupported model $MODEL_NAME"
178+
exit 1
172179
fi
173180

174181
# Use SM8450 for S22, SM8550 for S23, and SM8560 for S24

0 commit comments

Comments
 (0)