Skip to content

Commit b211c00

Browse files
committed
Update on "[ExecuTorch] Allow setting dtype to bf16 in export_llama"
Support creating bf16 PTEs. Differential Revision: [D61981363](https://our.internmc.facebook.com/intern/diff/D61981363/) [ghstack-poisoned]
2 parents b8003af + 862cd85 commit b211c00

File tree

252 files changed

+3304
-1082
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

252 files changed

+3304
-1082
lines changed

.ci/scripts/build-qnn-sdk.sh

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ set_up_aot() {
2929
-DQNN_SDK_ROOT=${QNN_SDK_ROOT} \
3030
-DEXECUTORCH_BUILD_SDK=ON \
3131
-DEXECUTORCH_BUILD_EXTENSION_MODULE=ON \
32+
-DEXECUTORCH_BUILD_EXTENSION_TENSOR=ON \
3233
-DEXECUTORCH_ENABLE_EVENT_TRACER=ON \
3334
-DPYTHON_EXECUTABLE=python3 \
3435
-DEXECUTORCH_SEPARATE_FLATCC_HOST_PROJECT=OFF

.ci/scripts/build_llama_android.sh

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,9 @@ install_executorch_and_backend_lib() {
2222
-DANDROID_PLATFORM=android-23 \
2323
-DCMAKE_INSTALL_PREFIX=cmake-android-out \
2424
-DCMAKE_BUILD_TYPE=Release \
25-
-DEXECUTORCH_BUILD_EXTENSION_MODULE=ON \
2625
-DEXECUTORCH_BUILD_EXTENSION_DATA_LOADER=ON \
26+
-DEXECUTORCH_BUILD_EXTENSION_MODULE=ON \
27+
-DEXECUTORCH_BUILD_EXTENSION_TENSOR=ON \
2728
-DEXECUTORCH_BUILD_XNNPACK=ON \
2829
-DEXECUTORCH_BUILD_KERNELS_OPTIMIZED=ON \
2930
-DEXECUTORCH_BUILD_KERNELS_QUANTIZED=ON \

.ci/scripts/test.sh

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -175,7 +175,10 @@ test_model_with_qnn() {
175175
EXPORTED_MODEL_NAME=vit_qnn.pte
176176
fi
177177

178-
"${PYTHON_EXECUTABLE}" -m examples.qualcomm.scripts.${EXPORT_SCRIPT} -b ${CMAKE_OUTPUT_DIR} -m SM8550 --compile_only
178+
# Use SM8450 for S22, SM8550 for S23, and SM8650 for S24
179+
QNN_CHIPSET=SM8450
180+
181+
"${PYTHON_EXECUTABLE}" -m examples.qualcomm.scripts.${EXPORT_SCRIPT} -b ${CMAKE_OUTPUT_DIR} -m ${QNN_CHIPSET} --compile_only
179182
EXPORTED_MODEL=./${EXPORT_SCRIPT}/${EXPORTED_MODEL_NAME}
180183
}
181184

.ci/scripts/test_llama.sh

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -107,8 +107,9 @@ cmake_install_executorch_libraries() {
107107
retry cmake \
108108
-DCMAKE_INSTALL_PREFIX=cmake-out \
109109
-DCMAKE_BUILD_TYPE=Debug \
110-
-DEXECUTORCH_BUILD_EXTENSION_MODULE=ON \
111110
-DEXECUTORCH_BUILD_EXTENSION_DATA_LOADER=ON \
111+
-DEXECUTORCH_BUILD_EXTENSION_MODULE=ON \
112+
-DEXECUTORCH_BUILD_EXTENSION_TENSOR=ON \
112113
-DEXECUTORCH_BUILD_KERNELS_CUSTOM="$CUSTOM" \
113114
-DEXECUTORCH_BUILD_KERNELS_OPTIMIZED=ON \
114115
-DEXECUTORCH_BUILD_KERNELS_QUANTIZED=ON \

.ci/scripts/test_llava.sh

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,9 @@ cmake_install_executorch_libraries() {
2020
cmake \
2121
-DCMAKE_INSTALL_PREFIX=cmake-out \
2222
-DCMAKE_BUILD_TYPE=${BUILD_TYPE} \
23-
-DEXECUTORCH_BUILD_EXTENSION_MODULE=ON \
2423
-DEXECUTORCH_BUILD_EXTENSION_DATA_LOADER=ON \
24+
-DEXECUTORCH_BUILD_EXTENSION_MODULE=ON \
25+
-DEXECUTORCH_BUILD_EXTENSION_TENSOR=ON \
2526
-DEXECUTORCH_BUILD_KERNELS_CUSTOM=ON \
2627
-DEXECUTORCH_BUILD_KERNELS_OPTIMIZED=ON \
2728
-DEXECUTORCH_BUILD_KERNELS_QUANTIZED=ON \
@@ -61,7 +62,7 @@ export_llava() {
6162
# Download a new image with different size, to test if the model can handle different image sizes
6263
prepare_image_tensor() {
6364
echo "Downloading image"
64-
curl -o basketball.jpg https://upload.wikimedia.org/wikipedia/commons/7/73/Chicago_Bulls_and_New_Jersey_Nets%2C_March_28%2C_1991.jpg
65+
curl -o basketball.jpg https://upload.wikimedia.org/wikipedia/commons/7/73/Chicago_Bulls_and_New_Jersey_Nets%2C_March_28%2C_1991.jpg
6566
$PYTHON_EXECUTABLE -m executorch.examples.models.llava.image_util --image-path basketball.jpg --output-path image.pt
6667
}
6768

.github/workflows/android-perf.yml

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ on:
1515
description: Target devices to run benchmark
1616
required: false
1717
type: string
18-
default: samsung_galaxy_s2x
18+
default: samsung_galaxy_s22
1919
delegates:
2020
description: Backend delegates
2121
required: false
@@ -45,7 +45,7 @@ on:
4545
description: Target devices to run benchmark
4646
required: false
4747
type: string
48-
default: samsung_galaxy_s2x
48+
default: samsung_galaxy_s22
4949
delegates:
5050
description: Backend delegates
5151
required: false
@@ -85,7 +85,7 @@ jobs:
8585
# during scheduled runs and to provide flexibility for different defaults between
8686
# on-demand and periodic benchmarking.
8787
CRON_DEFAULT_MODELS: "stories110M,dl3,mv3,mv2,ic4,ic3,vit"
88-
CRON_DEFAULT_DEVICES: "samsung_galaxy_s2x"
88+
CRON_DEFAULT_DEVICES: "samsung_galaxy_s22"
8989
CRON_DEFAULT_DELEGATES: "xnnpack,qnn"
9090
run: |
9191
set -ex
@@ -104,7 +104,7 @@ jobs:
104104
105105
# Mapping devices to their corresponding device-pool-arn
106106
declare -A DEVICE_POOL_ARNS
107-
DEVICE_POOL_ARNS[samsung_galaxy_s2x]="arn:aws:devicefarm:us-west-2:308535385114:devicepool:02a2cf0f-6d9b-45ee-ba1a-a086587469e6/e59f866a-30aa-4aa1-87b7-4510e5820dfa"
107+
DEVICE_POOL_ARNS[samsung_galaxy_s22]="arn:aws:devicefarm:us-west-2:308535385114:devicepool:02a2cf0f-6d9b-45ee-ba1a-a086587469e6/e59f866a-30aa-4aa1-87b7-4510e5820dfa"
108108
109109
# Resolve device names with their corresponding ARNs
110110
if [[ ! $(echo "$DEVICES" | jq empty 2>/dev/null) ]]; then
@@ -206,6 +206,10 @@ jobs:
206206
name: build-llm-demo
207207
uses: pytorch/test-infra/.github/workflows/linux_job.yml@main
208208
needs: set-parameters
209+
strategy:
210+
matrix:
211+
delegate: ${{ fromJson(needs.set-parameters.outputs.delegates) }}
212+
fail-fast: false
209213
with:
210214
runner: linux.2xlarge
211215
docker-image: executorch-ubuntu-22.04-clang12-android
@@ -222,8 +226,14 @@ jobs:
222226
PYTHON_EXECUTABLE=python bash .ci/scripts/setup-linux.sh cmake
223227
export ARTIFACTS_DIR_NAME=artifacts-to-be-uploaded
224228
229+
if [[ ${{ matrix.delegate }} == "qnn" ]]; then
230+
PYTHON_EXECUTABLE=python bash .ci/scripts/setup-qnn-deps.sh
231+
PYTHON_EXECUTABLE=python bash .ci/scripts/build-qnn-sdk.sh
232+
fi
233+
225234
# TODO: This needs to be replaced with a generic loader .apk
226235
# Build LLM Demo for Android
236+
export ANDROID_ABIS="arm64-v8a"
227237
bash build/build_android_llm_demo.sh ${ARTIFACTS_DIR_NAME}
228238
229239
# Upload artifacts to S3. The artifacts are needed not only by the device farm but also TorchChat

.github/workflows/lint.yml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,8 @@ jobs:
6565
script: |
6666
FILES_NEEDS_FORMAT=$(/opt/google-java-format -n extension/android/src/main/java/org/pytorch/executorch/*.java \
6767
examples/demo-apps/android/ExecuTorchDemo/app/src/main/java/com/example/executorchdemo/*.java \
68-
examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/*.java)
68+
examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/*.java \
69+
extension/android/benchmark/app/src/main/java/org/pytorch/minibench/*.java)
6970
if [ -n "$FILES_NEEDS_FORMAT" ]; then
7071
echo "Warning: The following files need formatting. Please use google-java-format."
7172
echo "$FILES_NEEDS_FORMAT"

CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -228,6 +228,7 @@ cmake_dependent_option(
228228
)
229229

230230
if(EXECUTORCH_BUILD_KERNELS_CUSTOM_AOT)
231+
set(EXECUTORCH_BUILD_EXTENSION_TENSOR ON)
231232
set(EXECUTORCH_BUILD_KERNELS_CUSTOM ON)
232233
endif()
233234

backends/apple/coreml/runtime/include/coreml_backend/delegate.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ class BackendDelegate;
2020
namespace torch {
2121
namespace executor {
2222

23-
class CoreMLBackendDelegate final : public PyTorchBackendInterface {
23+
class CoreMLBackendDelegate final : public ::executorch::runtime::BackendInterface {
2424
public:
2525
CoreMLBackendDelegate() noexcept;
2626
~CoreMLBackendDelegate() = default;

backends/apple/mps/runtime/MPSBackend.mm

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
namespace torch {
2020
namespace executor {
2121

22-
class MPSBackend final : public PyTorchBackendInterface {
22+
class MPSBackend final : public ::executorch::runtime::BackendInterface {
2323
public:
2424
~MPSBackend() = default;
2525

backends/apple/mps/test/test_mps_utils.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -239,9 +239,7 @@ def lower_module_and_test_output(
239239
)
240240

241241
executorch_program = delegated_program.to_executorch(
242-
config=ExecutorchBackendConfig(
243-
extract_delegate_segments=False, extract_constant_segment=False
244-
)
242+
config=ExecutorchBackendConfig(extract_delegate_segments=False)
245243
)
246244
else:
247245
delegated_program = to_backend(
@@ -258,9 +256,7 @@ def lower_module_and_test_output(
258256
_skip_dim_order=True, # TODO(T182928844): Delegate dim order op to backend.
259257
),
260258
).to_executorch(
261-
config=ExecutorchBackendConfig(
262-
extract_delegate_segments=False, extract_constant_segment=False
263-
)
259+
config=ExecutorchBackendConfig(extract_delegate_segments=False)
264260
)
265261

266262
if bundled_program:

backends/arm/README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ Quantization:
3333
- `arm_quantizer_utils.py` - Utilities for quantization
3434

3535
Runtime:
36-
- `runtime/ArmBackendEthosU.cpp` - The Arm backend implementation of the ExecuTorch runtime backend (PyTorchBackendInterface) for Ethos-U
36+
- `runtime/ArmBackendEthosU.cpp` - The Arm backend implementation of the ExecuTorch runtime backend (BackendInterface) for Ethos-U
3737

3838
Other:
3939
- `third-party/` - Dependencies on other code - in particular the TOSA serialization_lib for compiling to TOSA and the ethos-u-core-driver for the bare-metal backend supporting Ethos-U

backends/arm/TARGETS

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
load("@fbcode_macros//build_defs:python_library.bzl", "python_library")
2+
3+
python_library(
4+
name = "arm_partitioner",
5+
srcs = [
6+
"arm_partitioner.py",
7+
],
8+
typing = True,
9+
deps = [
10+
":arm_backend",
11+
"//executorch/backends/arm/passes:passes",
12+
"//executorch/exir:lib",
13+
],
14+
)
15+
16+
python_library(
17+
name = "arm_backend",
18+
srcs = [
19+
"arm_backend.py",
20+
],
21+
typing = True,
22+
deps = [
23+
"fbsource//third-party/pypi/flatbuffers:flatbuffers",
24+
"fbsource//third-party/pypi/ml-dtypes:ml-dtypes",
25+
"fbsource//third-party/serialization_lib/python/serializer:serializer",
26+
"fbsource//third-party/serialization_lib/python/tosa:tosa",
27+
":arm_vela",
28+
"//executorch/backends/arm/operators:lib",
29+
"//executorch/backends/arm/operators:node_visitor",
30+
"//executorch/backends/arm/passes:passes",
31+
],
32+
)
33+
34+
python_library(
35+
name = "arm_vela",
36+
srcs = [
37+
"arm_vela.py",
38+
],
39+
typing = True,
40+
deps = [
41+
"fbsource//third-party/pypi/ethos-u-vela:ethos-u-vela",
42+
],
43+
)
44+
45+
python_library(
46+
name = "tosa_mapping",
47+
srcs = [
48+
"tosa_mapping.py",
49+
],
50+
typing = True,
51+
deps = [
52+
"fbsource//third-party/serialization_lib/python/serializer:serializer",
53+
"//caffe2:torch",
54+
],
55+
)
56+
57+
python_library(
58+
name = "tosa_quant_utils",
59+
srcs = [
60+
"tosa_quant_utils.py",
61+
],
62+
typing = True,
63+
deps = [
64+
"fbsource//third-party/pypi/numpy:numpy",
65+
"fbsource//third-party/serialization_lib/python/serializer:serializer",
66+
"fbsource//third-party/serialization_lib/python/tosa:tosa",
67+
":tosa_mapping",
68+
"//executorch/exir/dialects:lib",
69+
],
70+
)
71+
72+
python_library(
73+
name = "tosa_utils",
74+
srcs = [
75+
"tosa_utils.py",
76+
],
77+
typing = True,
78+
deps = [
79+
"fbsource//third-party/serialization_lib/python/serializer:serializer",
80+
":tosa_quant_utils",
81+
"//executorch/backends/arm/operators:node_visitor",
82+
],
83+
)

backends/arm/arm_backend.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -159,7 +159,7 @@ def is_tosa(compile_spec: List[CompileSpec]) -> bool:
159159
return False
160160

161161

162-
def get_intermediate_path(compile_spec: List[CompileSpec]) -> str:
162+
def get_intermediate_path(compile_spec: List[CompileSpec]) -> Optional[str]:
163163
for spec in compile_spec:
164164
if spec.key == "debug_artifact_path":
165165
return spec.value.decode()

backends/arm/arm_vela.py

Lines changed: 9 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,12 @@
55

66
import os
77
import struct
8-
import subprocess
98
import tempfile
109

1110
from typing import List
1211

1312
import numpy as np
13+
from ethosu.vela import vela
1414

1515

1616
# Pack either input or output tensor block, compose the related arrays into
@@ -38,21 +38,17 @@ def vela_compile(tosa_graph, args: List[str]):
3838
with tempfile.TemporaryDirectory() as tmpdir:
3939
tosaname = "out.tosa"
4040
flatbuffer = tosa_graph.serialize()
41-
with open(os.path.join(tmpdir, tosaname), "wb") as f:
41+
tosa_path = os.path.join(tmpdir, tosaname)
42+
with open(tosa_path, "wb") as f:
4243
f.write(flatbuffer)
4344

4445
# invoke vela
45-
vela_command = f"cd {tmpdir}; vela {' '.join(args)} {tosaname}"
46-
try:
47-
subprocess.run([vela_command], shell=True, check=True, capture_output=True)
48-
except subprocess.CalledProcessError as process_error:
49-
raise RuntimeError(
50-
f"Vela compiler ('{vela_command}') failed with error:\n \
51-
{process_error.stderr.decode()}\n \
52-
Stdout:\n{process_error.stdout.decode()}"
53-
)
54-
55-
np_path = os.path.join(tmpdir, "output", "out_sg0_vela.npz")
46+
output_dir = os.path.join(tmpdir, "output")
47+
args.append(f"--output-dir={output_dir}")
48+
args.append(tosa_path)
49+
vela.main(" ".join(args).split(" "))
50+
51+
np_path = os.path.join(output_dir, "out_sg0_vela.npz")
5652
blocks = b""
5753

5854
with np.load(np_path, allow_pickle=False) as data:

backends/arm/operators/TARGETS

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
load("@fbcode_macros//build_defs:python_library.bzl", "python_library")
2+
3+
python_library(
4+
name = "node_visitor",
5+
srcs = ["node_visitor.py"],
6+
typing = True,
7+
deps = [
8+
"//executorch/backends/arm:tosa_mapping",
9+
],
10+
)
11+
12+
python_library(
13+
name = "ops",
14+
srcs = glob(["op_*.py"]),
15+
typing = True,
16+
deps = [
17+
"fbsource//third-party/serialization_lib/python/tosa:tosa",
18+
":node_visitor",
19+
"//executorch/backends/arm:tosa_mapping",
20+
"//executorch/backends/arm:tosa_quant_utils",
21+
"//executorch/backends/arm:tosa_utils",
22+
"//executorch/exir:lib",
23+
],
24+
)
25+
26+
python_library(
27+
name = "lib",
28+
srcs = ["__init__.py"],
29+
typing = True,
30+
deps = [
31+
":node_visitor",
32+
":ops",
33+
],
34+
)

backends/arm/operators/op_bmm.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,7 @@ def define_node(
7272
build_rescale(
7373
tosa_fb=tosa_graph,
7474
scale=final_output_scale,
75+
# pyre-ignore[61]: Uninitialized local [61]: Local variable `bmm_result` is undefined, or not always defined.
7576
input_node=bmm_result,
7677
output_name=output.name,
7778
output_type=ts.DType.INT8,

backends/arm/operators/op_conv2d.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
#
33
# This source code is licensed under the BSD-style license found in the
44
# LICENSE file in the root directory of this source tree.
5-
from typing import List
5+
from typing import cast, List
66

77
import serializer.tosa_serializer as ts
88
import torch
@@ -156,11 +156,12 @@ def define_node(
156156
# integer value domain of the next op. Otherwise return float32 output.
157157
if is_quant_node:
158158
# Get scale_factor from input, weight, and output.
159-
_, input_scale, _, _, _, _ = getNodeArgs(node.args[0])
160-
_, weight_scale, _, _, _, _ = getNodeArgs(node.args[1])
159+
_, input_scale, _, _, _, _ = getNodeArgs(cast(torch.fx.Node, node.args[0]))
160+
_, weight_scale, _, _, _, _ = getNodeArgs(cast(torch.fx.Node, node.args[1]))
161161
_, output_scale, output_zp, _, _, _ = getNodeArgs(list(node.users)[0])
162162
build_rescale_conv_output(
163163
tosa_graph,
164+
# pyre-fixme[61]: Uninitialized local [61]: Local variable `conv2d_res` is undefined, or not always defined.
164165
conv2d_res,
165166
output.name,
166167
actual_out_type,

0 commit comments

Comments
 (0)