Skip to content

Commit eb6d388

Browse files
committed
Update on "[ET-VK] Integrate axis mapping into staging <-> buffer transfer shaders"
## Context Building on the previous diff, this diff integrates axis mapping into staging <-> buffer transfer shaders. Alternative versions of indexing utility functions are introduced to account for axis mapping. The impact on shader latency of using axis mapping in transfer shaders is examined in the next diff. Differential Revision: [D62210117](https://our.internmc.facebook.com/intern/diff/D62210117/) [ghstack-poisoned]
2 parents 7c1ff3b + b2ab1d7 commit eb6d388

File tree

176 files changed

+2608
-773
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

176 files changed

+2608
-773
lines changed

.ci/scripts/test.sh

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -175,7 +175,10 @@ test_model_with_qnn() {
175175
EXPORTED_MODEL_NAME=vit_qnn.pte
176176
fi
177177

178-
"${PYTHON_EXECUTABLE}" -m examples.qualcomm.scripts.${EXPORT_SCRIPT} -b ${CMAKE_OUTPUT_DIR} -m SM8550 --compile_only
178+
# Use SM8450 for S22, SM8550 for S23, and SM8650 for S24
179+
QNN_CHIPSET=SM8450
180+
181+
"${PYTHON_EXECUTABLE}" -m examples.qualcomm.scripts.${EXPORT_SCRIPT} -b ${CMAKE_OUTPUT_DIR} -m ${QNN_CHIPSET} --compile_only
179182
EXPORTED_MODEL=./${EXPORT_SCRIPT}/${EXPORTED_MODEL_NAME}
180183
}
181184

.github/workflows/android-perf.yml

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ on:
1515
description: Target devices to run benchmark
1616
required: false
1717
type: string
18-
default: samsung_galaxy_s2x
18+
default: samsung_galaxy_s22
1919
delegates:
2020
description: Backend delegates
2121
required: false
@@ -45,7 +45,7 @@ on:
4545
description: Target devices to run benchmark
4646
required: false
4747
type: string
48-
default: samsung_galaxy_s2x
48+
default: samsung_galaxy_s22
4949
delegates:
5050
description: Backend delegates
5151
required: false
@@ -85,7 +85,7 @@ jobs:
8585
# during scheduled runs and to provide flexibility for different defaults between
8686
# on-demand and periodic benchmarking.
8787
CRON_DEFAULT_MODELS: "stories110M,dl3,mv3,mv2,ic4,ic3,vit"
88-
CRON_DEFAULT_DEVICES: "samsung_galaxy_s2x"
88+
CRON_DEFAULT_DEVICES: "samsung_galaxy_s22"
8989
CRON_DEFAULT_DELEGATES: "xnnpack,qnn"
9090
run: |
9191
set -ex
@@ -104,7 +104,7 @@ jobs:
104104
105105
# Mapping devices to their corresponding device-pool-arn
106106
declare -A DEVICE_POOL_ARNS
107-
DEVICE_POOL_ARNS[samsung_galaxy_s2x]="arn:aws:devicefarm:us-west-2:308535385114:devicepool:02a2cf0f-6d9b-45ee-ba1a-a086587469e6/e59f866a-30aa-4aa1-87b7-4510e5820dfa"
107+
DEVICE_POOL_ARNS[samsung_galaxy_s22]="arn:aws:devicefarm:us-west-2:308535385114:devicepool:02a2cf0f-6d9b-45ee-ba1a-a086587469e6/e59f866a-30aa-4aa1-87b7-4510e5820dfa"
108108
109109
# Resolve device names with their corresponding ARNs
110110
if [[ ! $(echo "$DEVICES" | jq empty 2>/dev/null) ]]; then
@@ -206,6 +206,10 @@ jobs:
206206
name: build-llm-demo
207207
uses: pytorch/test-infra/.github/workflows/linux_job.yml@main
208208
needs: set-parameters
209+
strategy:
210+
matrix:
211+
delegate: ${{ fromJson(needs.set-parameters.outputs.delegates) }}
212+
fail-fast: false
209213
with:
210214
runner: linux.2xlarge
211215
docker-image: executorch-ubuntu-22.04-clang12-android
@@ -222,8 +226,14 @@ jobs:
222226
PYTHON_EXECUTABLE=python bash .ci/scripts/setup-linux.sh cmake
223227
export ARTIFACTS_DIR_NAME=artifacts-to-be-uploaded
224228
229+
if [[ ${{ matrix.delegate }} == "qnn" ]]; then
230+
PYTHON_EXECUTABLE=python bash .ci/scripts/setup-qnn-deps.sh
231+
PYTHON_EXECUTABLE=python bash .ci/scripts/build-qnn-sdk.sh
232+
fi
233+
225234
# TODO: This needs to be replaced with a generic loader .apk
226235
# Build LLM Demo for Android
236+
export ANDROID_ABIS="arm64-v8a"
227237
bash build/build_android_llm_demo.sh ${ARTIFACTS_DIR_NAME}
228238
229239
# Upload artifacts to S3. The artifacts are needed not only by the device farm but also TorchChat

.github/workflows/lint.yml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,8 @@ jobs:
6565
script: |
6666
FILES_NEEDS_FORMAT=$(/opt/google-java-format -n extension/android/src/main/java/org/pytorch/executorch/*.java \
6767
examples/demo-apps/android/ExecuTorchDemo/app/src/main/java/com/example/executorchdemo/*.java \
68-
examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/*.java)
68+
examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/*.java \
69+
extension/android/benchmark/app/src/main/java/org/pytorch/minibench/*.java)
6970
if [ -n "$FILES_NEEDS_FORMAT" ]; then
7071
echo "Warning: The following files need formatting. Please use google-java-format."
7172
echo "$FILES_NEEDS_FORMAT"

CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -228,6 +228,7 @@ cmake_dependent_option(
228228
)
229229

230230
if(EXECUTORCH_BUILD_KERNELS_CUSTOM_AOT)
231+
set(EXECUTORCH_BUILD_EXTENSION_TENSOR ON)
231232
set(EXECUTORCH_BUILD_KERNELS_CUSTOM ON)
232233
endif()
233234

backends/apple/coreml/runtime/include/coreml_backend/delegate.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ class BackendDelegate;
2020
namespace torch {
2121
namespace executor {
2222

23-
class CoreMLBackendDelegate final : public PyTorchBackendInterface {
23+
class CoreMLBackendDelegate final : public ::executorch::runtime::BackendInterface {
2424
public:
2525
CoreMLBackendDelegate() noexcept;
2626
~CoreMLBackendDelegate() = default;

backends/apple/mps/runtime/MPSBackend.mm

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
namespace torch {
2020
namespace executor {
2121

22-
class MPSBackend final : public PyTorchBackendInterface {
22+
class MPSBackend final : public ::executorch::runtime::BackendInterface {
2323
public:
2424
~MPSBackend() = default;
2525

backends/apple/mps/test/test_mps_utils.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -239,9 +239,7 @@ def lower_module_and_test_output(
239239
)
240240

241241
executorch_program = delegated_program.to_executorch(
242-
config=ExecutorchBackendConfig(
243-
extract_delegate_segments=False, extract_constant_segment=False
244-
)
242+
config=ExecutorchBackendConfig(extract_delegate_segments=False)
245243
)
246244
else:
247245
delegated_program = to_backend(
@@ -258,9 +256,7 @@ def lower_module_and_test_output(
258256
_skip_dim_order=True, # TODO(T182928844): Delegate dim order op to backend.
259257
),
260258
).to_executorch(
261-
config=ExecutorchBackendConfig(
262-
extract_delegate_segments=False, extract_constant_segment=False
263-
)
259+
config=ExecutorchBackendConfig(extract_delegate_segments=False)
264260
)
265261

266262
if bundled_program:

backends/arm/README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ Quantization:
3333
- `arm_quantizer_utils.py` - Utilities for quantization
3434

3535
Runtime:
36-
- `runtime/ArmBackendEthosU.cpp` - The Arm backend implementation of the ExecuTorch runtime backend (PyTorchBackendInterface) for Ethos-U
36+
- `runtime/ArmBackendEthosU.cpp` - The Arm backend implementation of the ExecuTorch runtime backend (BackendInterface) for Ethos-U
3737

3838
Other:
3939
- `third-party/` - Dependencies on other code - in particular the TOSA serialization_lib for compiling to TOSA and the ethos-u-core-driver for the bare-metal backend supporting Ethos-U

backends/arm/TARGETS

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
load("@fbcode_macros//build_defs:python_library.bzl", "python_library")
2+
3+
python_library(
4+
name = "arm_partitioner",
5+
srcs = [
6+
"arm_partitioner.py",
7+
],
8+
typing = True,
9+
deps = [
10+
":arm_backend",
11+
"//executorch/backends/arm/passes:passes",
12+
"//executorch/exir:lib",
13+
],
14+
)
15+
16+
python_library(
17+
name = "arm_backend",
18+
srcs = [
19+
"arm_backend.py",
20+
],
21+
typing = True,
22+
deps = [
23+
"fbsource//third-party/pypi/flatbuffers:flatbuffers",
24+
"fbsource//third-party/pypi/ml-dtypes:ml-dtypes",
25+
"fbsource//third-party/serialization_lib/python/serializer:serializer",
26+
"fbsource//third-party/serialization_lib/python/tosa:tosa",
27+
":arm_vela",
28+
"//executorch/backends/arm/operators:lib",
29+
"//executorch/backends/arm/operators:node_visitor",
30+
"//executorch/backends/arm/passes:passes",
31+
],
32+
)
33+
34+
python_library(
35+
name = "arm_vela",
36+
srcs = [
37+
"arm_vela.py",
38+
],
39+
typing = True,
40+
deps = [
41+
"fbsource//third-party/pypi/ethos-u-vela:ethos-u-vela",
42+
],
43+
)
44+
45+
python_library(
46+
name = "tosa_mapping",
47+
srcs = [
48+
"tosa_mapping.py",
49+
],
50+
typing = True,
51+
deps = [
52+
"fbsource//third-party/serialization_lib/python/serializer:serializer",
53+
"//caffe2:torch",
54+
],
55+
)
56+
57+
python_library(
58+
name = "tosa_quant_utils",
59+
srcs = [
60+
"tosa_quant_utils.py",
61+
],
62+
typing = True,
63+
deps = [
64+
"fbsource//third-party/pypi/numpy:numpy",
65+
"fbsource//third-party/serialization_lib/python/serializer:serializer",
66+
"fbsource//third-party/serialization_lib/python/tosa:tosa",
67+
":tosa_mapping",
68+
"//executorch/exir/dialects:lib",
69+
],
70+
)
71+
72+
python_library(
73+
name = "tosa_utils",
74+
srcs = [
75+
"tosa_utils.py",
76+
],
77+
typing = True,
78+
deps = [
79+
"fbsource//third-party/serialization_lib/python/serializer:serializer",
80+
":tosa_quant_utils",
81+
"//executorch/backends/arm/operators:node_visitor",
82+
],
83+
)

backends/arm/arm_backend.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -159,7 +159,7 @@ def is_tosa(compile_spec: List[CompileSpec]) -> bool:
159159
return False
160160

161161

162-
def get_intermediate_path(compile_spec: List[CompileSpec]) -> str:
162+
def get_intermediate_path(compile_spec: List[CompileSpec]) -> Optional[str]:
163163
for spec in compile_spec:
164164
if spec.key == "debug_artifact_path":
165165
return spec.value.decode()

backends/arm/arm_vela.py

Lines changed: 9 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,12 @@
55

66
import os
77
import struct
8-
import subprocess
98
import tempfile
109

1110
from typing import List
1211

1312
import numpy as np
13+
from ethosu.vela import vela
1414

1515

1616
# Pack either input or output tensor block, compose the related arrays into
@@ -38,21 +38,17 @@ def vela_compile(tosa_graph, args: List[str]):
3838
with tempfile.TemporaryDirectory() as tmpdir:
3939
tosaname = "out.tosa"
4040
flatbuffer = tosa_graph.serialize()
41-
with open(os.path.join(tmpdir, tosaname), "wb") as f:
41+
tosa_path = os.path.join(tmpdir, tosaname)
42+
with open(tosa_path, "wb") as f:
4243
f.write(flatbuffer)
4344

4445
# invoke vela
45-
vela_command = f"cd {tmpdir}; vela {' '.join(args)} {tosaname}"
46-
try:
47-
subprocess.run([vela_command], shell=True, check=True, capture_output=True)
48-
except subprocess.CalledProcessError as process_error:
49-
raise RuntimeError(
50-
f"Vela compiler ('{vela_command}') failed with error:\n \
51-
{process_error.stderr.decode()}\n \
52-
Stdout:\n{process_error.stdout.decode()}"
53-
)
54-
55-
np_path = os.path.join(tmpdir, "output", "out_sg0_vela.npz")
46+
output_dir = os.path.join(tmpdir, "output")
47+
args.append(f"--output-dir={output_dir}")
48+
args.append(tosa_path)
49+
vela.main(" ".join(args).split(" "))
50+
51+
np_path = os.path.join(output_dir, "out_sg0_vela.npz")
5652
blocks = b""
5753

5854
with np.load(np_path, allow_pickle=False) as data:

backends/arm/operators/TARGETS

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
load("@fbcode_macros//build_defs:python_library.bzl", "python_library")
2+
3+
python_library(
4+
name = "node_visitor",
5+
srcs = ["node_visitor.py"],
6+
typing = True,
7+
deps = [
8+
"//executorch/backends/arm:tosa_mapping",
9+
],
10+
)
11+
12+
python_library(
13+
name = "ops",
14+
srcs = glob(["op_*.py"]),
15+
typing = True,
16+
deps = [
17+
"fbsource//third-party/serialization_lib/python/tosa:tosa",
18+
":node_visitor",
19+
"//executorch/backends/arm:tosa_mapping",
20+
"//executorch/backends/arm:tosa_quant_utils",
21+
"//executorch/backends/arm:tosa_utils",
22+
"//executorch/exir:lib",
23+
],
24+
)
25+
26+
python_library(
27+
name = "lib",
28+
srcs = ["__init__.py"],
29+
typing = True,
30+
deps = [
31+
":node_visitor",
32+
":ops",
33+
],
34+
)

backends/arm/operators/op_bmm.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,7 @@ def define_node(
7272
build_rescale(
7373
tosa_fb=tosa_graph,
7474
scale=final_output_scale,
75+
# pyre-ignore[61]: Uninitialized local [61]: Local variable `bmm_result` is undefined, or not always defined.
7576
input_node=bmm_result,
7677
output_name=output.name,
7778
output_type=ts.DType.INT8,

backends/arm/operators/op_conv2d.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
#
33
# This source code is licensed under the BSD-style license found in the
44
# LICENSE file in the root directory of this source tree.
5-
from typing import List
5+
from typing import cast, List
66

77
import serializer.tosa_serializer as ts
88
import torch
@@ -156,11 +156,12 @@ def define_node(
156156
# integer value domain of the next op. Otherwise return float32 output.
157157
if is_quant_node:
158158
# Get scale_factor from input, weight, and output.
159-
_, input_scale, _, _, _, _ = getNodeArgs(node.args[0])
160-
_, weight_scale, _, _, _, _ = getNodeArgs(node.args[1])
159+
_, input_scale, _, _, _, _ = getNodeArgs(cast(torch.fx.Node, node.args[0]))
160+
_, weight_scale, _, _, _, _ = getNodeArgs(cast(torch.fx.Node, node.args[1]))
161161
_, output_scale, output_zp, _, _, _ = getNodeArgs(list(node.users)[0])
162162
build_rescale_conv_output(
163163
tosa_graph,
164+
# pyre-fixme[61]: Uninitialized local [61]: Local variable `conv2d_res` is undefined, or not always defined.
164165
conv2d_res,
165166
output.name,
166167
actual_out_type,

backends/arm/operators/op_mm.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,7 @@ def define_node(
9696
build_rescale(
9797
tosa_fb=tosa_graph,
9898
scale=final_output_scale,
99+
# pyre-ignore[61]: Uninitialized local [61]: Local variable `reshape_intermediate` is undefined, or not always defined.
99100
input_node=reshape_intermediate,
100101
output_name=output.name,
101102
output_type=ts.DType.INT8,

backends/arm/operators/op_mul.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
# This source code is licensed under the BSD-style license found in the
44
# LICENSE file in the root directory of this source tree.
55

6-
from typing import List
6+
from typing import cast, List
77

88
import executorch.backends.arm.tosa_quant_utils as tqutils
99
import executorch.backends.arm.tosa_utils as tutils
@@ -35,8 +35,12 @@ def define_node(
3535
if is_quant_node:
3636
input_A = inputs[0]
3737
input_B = inputs[1]
38-
input_A_qargs = tqutils.get_quant_node_args(node.args[0])
39-
input_B_qargs = tqutils.get_quant_node_args(node.args[1])
38+
input_A_qargs = tqutils.get_quant_node_args(
39+
cast(torch.fx.Node, node.args[0])
40+
)
41+
input_B_qargs = tqutils.get_quant_node_args(
42+
cast(torch.fx.Node, node.args[1])
43+
)
4044

4145
input_A.shape = tutils.tosa_shape(input_A.shape, input_A.dim_order)
4246
input_B.shape = tutils.tosa_shape(input_B.shape, input_B.dim_order)

backends/arm/operators/op_output.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33
# This source code is licensed under the BSD-style license found in the
44
# LICENSE file in the root directory of this source tree.
55

6+
from typing import cast
7+
68
import serializer.tosa_serializer as ts
79
import torch
810

@@ -11,7 +13,7 @@ def process_output(
1113
node: torch.fx.Node,
1214
tosa_graph: ts.TosaSerializer,
1315
):
14-
for output in node.args[0]:
16+
for output in cast(tuple[torch.fx.Node, ...], node.args[0]):
1517
tosa_graph.addOutputTensor(
1618
tosa_graph.currRegion.currBasicBlock.tensors[output.name]
1719
)

0 commit comments

Comments
 (0)