
Commit cee5e89

Update on "[ET-VK] Store unique ptr to Tensor in Value instead of inlined tensor object, to reduce Value struct size from 448 to 80 bytes."
This diff reduces the size of the Value struct in the ExecuTorch Vulkan runtime by storing a unique pointer to the Tensor object instead of an inlined tensor object. The change shrinks the Value struct from 448 bytes to 80 bytes, which can improve performance and reduce memory usage.

Differential Revision: [D66655991](https://our.internmc.facebook.com/intern/diff/D66655991/)

[ghstack-poisoned]
2 parents dc09af3 + 5860506 commit cee5e89
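
For readers unfamiliar with the pattern, here is a minimal, self-contained C++ sketch of the idea behind the change, not the actual ExecuTorch Vulkan code: replacing an inlined member with a std::unique_ptr makes the enclosing value type roughly pointer-sized regardless of how large the payload type is, at the cost of one heap allocation when the payload is actually stored. The LargePayload, InlineValue, and PointerValue names are hypothetical.

#include <cstdio>
#include <memory>

// Hypothetical stand-in for a large object that used to live inline in the
// Value struct (e.g. a tensor with inlined metadata buffers).
struct LargePayload {
  char storage[440];
};

// Variant-style value holding the payload inline: sizeof grows with the payload.
struct InlineValue {
  int tag;
  LargePayload payload;
};

// Variant-style value holding only a unique_ptr: sizeof stays small no matter
// how large LargePayload becomes; the payload is heap-allocated on demand.
struct PointerValue {
  int tag;
  std::unique_ptr<LargePayload> payload;
};

int main() {
  std::printf("inline value : %zu bytes\n", sizeof(InlineValue));   // grows with LargePayload
  std::printf("pointer value: %zu bytes\n", sizeof(PointerValue));  // stays pointer-sized

  PointerValue v;
  v.tag = 1;
  v.payload = std::make_unique<LargePayload>();  // allocation happens only here
  return 0;
}

The real Value evidently keeps other members as well (it ends up at 80 bytes rather than pointer size), but the mechanism is the same: only a pointer to the Tensor object is stored inline.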


69 files changed: 705 additions & 407 deletions

CMakeLists.txt

Lines changed: 9 additions & 3 deletions
@@ -200,8 +200,6 @@ option(EXECUTORCH_BUILD_EXTENSION_TENSOR "Build the Tensor extension" OFF)
 
 option(EXECUTORCH_BUILD_EXTENSION_TRAINING "Build the training extension" OFF)
 
-option(EXECUTORCH_BUILD_GTESTS "Build googletest based test binaries" OFF)
-
 option(EXECUTORCH_BUILD_MPS "Build the MPS backend" OFF)
 
 option(EXECUTORCH_BUILD_NEURON "Build the backends/mediatek directory" OFF)
@@ -216,6 +214,8 @@ option(EXECUTORCH_BUILD_KERNELS_QUANTIZED "Build the quantized kernels" OFF)
 
 option(EXECUTORCH_BUILD_DEVTOOLS "Build the ExecuTorch Developer Tools")
 
+option(EXECUTORCH_BUILD_TESTS "Build CMake-based unit tests" OFF)
+
 option(EXECUTORCH_NNLIB_OPT "Build Cadence backend Hifi nnlib kernel" OFF)
 
 option(EXECUTORCH_CADENCE_CPU_RUNNER "Build Cadence backend CPU runner" OFF)
@@ -330,6 +330,10 @@ if(EXECUTORCH_BUILD_PTHREADPOOL)
   )
 endif()
 
+if(EXECUTORCH_BUILD_TESTS)
+  include(CTest)
+endif()
+
 if(NOT PYTHON_EXECUTABLE)
   resolve_python_executable()
 endif()
@@ -625,7 +629,7 @@ cmake_dependent_option(
 )
 
 # Add googletest if any test targets should be built
-if(EXECUTORCH_BUILD_GTESTS)
+if(BUILD_TESTING)
   add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/third-party/googletest)
 endif()
 
@@ -829,5 +833,7 @@ if(EXECUTORCH_BUILD_VULKAN)
   add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backends/vulkan)
 endif()
 
+include(Test.cmake)
+
 # Print all summary
 executorch_print_configuration_summary()

Test.cmake

Lines changed: 29 additions & 0 deletions
@@ -0,0 +1,29 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+#
+# A helper CMake file to trigger C++ unit tests.
+#
+
+if(BUILD_TESTING)
+  # This contains the list of tests which are always built
+  add_subdirectory(extension/evalue_util/test)
+  add_subdirectory(extension/kernel_util/test)
+  add_subdirectory(extension/memory_allocator/test)
+  add_subdirectory(extension/parallel/test)
+  add_subdirectory(extension/pytree/test)
+  add_subdirectory(kernels/portable/cpu/util/test)
+  add_subdirectory(kernels/prim_ops/test)
+  add_subdirectory(kernels/test)
+  add_subdirectory(runtime/core/exec_aten/testing_util/test)
+  add_subdirectory(runtime/core/exec_aten/util/test)
+  add_subdirectory(runtime/core/portable_type/test)
+  add_subdirectory(runtime/core/test)
+  add_subdirectory(runtime/executor/test)
+  add_subdirectory(runtime/kernel/test)
+  add_subdirectory(runtime/platform/test)
+  add_subdirectory(test/utils)
+endif()

backends/arm/test/ops/test_depthwise_conv.py

Lines changed: 1 addition & 0 deletions
@@ -260,6 +260,7 @@ def test_dw_conv_tosa_BI(self, test_name: str, model: torch.nn.Module):
     ) # Works
 
     @parameterized.expand(testsuite_conv2d, skip_on_empty=True)
+    @unittest.expectedFailure
     def test_dw_conv2d_u55_BI(
         self, test_name: str, model: torch.nn.Module, set_quantize_io: bool = False
     ):

backends/cadence/aot/ops_registrations.py

Lines changed: 31 additions & 2 deletions
@@ -146,7 +146,10 @@
     "quantized_fully_connected(Tensor src, Tensor weight, Tensor bias, int src_zero_point, "
     "Tensor weight_zero_point, Tensor out_multiplier, Tensor out_shift, int out_zero_point, Tensor? offset) -> (Tensor Z)"
 )
-
+lib.define(
+    "quantized_fully_connected.per_tensor(Tensor src, Tensor weight, Tensor bias, int src_zero_point, "
+    "int weight_zero_point, int out_multiplier, int out_shift, int out_zero_point, Tensor? offset) -> (Tensor Z)"
+)
 
 # ------------------------------------ #
 # Migrated from custom_ops.ymal        #
@@ -192,6 +195,10 @@
     "quantized_fully_connected.out(Tensor src, Tensor weight, Tensor bias, int src_zero_point, "
     "Tensor weight_zero_point, Tensor out_multiplier, Tensor out_shift, int out_zero_point, Tensor? offset, *, Tensor(a!) out) -> Tensor(a!)"
 )
+lib.define(
+    "quantized_fully_connected.per_tensor_out(Tensor src, Tensor weight, Tensor bias, int src_zero_point, "
+    "int weight_zero_point, int out_multiplier, int out_shift, int out_zero_point, Tensor? offset, *, Tensor(a!) out) -> Tensor(a!)"
+)
 lib.define(
     "quantized_embedding_byte.out(Tensor weight, Tensor weight_scales, Tensor weight_zero_points, "
     "Tensor indices, bool pruned_weights=False, *, Tensor(a!) out) -> Tensor(a!)"
@@ -595,6 +602,28 @@ def quantized_fully_connected_meta(
     bias: torch.Tensor,
     in_zero_point: int,
     weight_zero_point: torch.Tensor,
+    out_multiplier: torch.Tensor,
+    out_shift: torch.Tensor,
+    out_zero_point: int,
+    offset: Optional[torch.Tensor],
+) -> torch.Tensor:
+    # src comes in shape [leading_dims, in_dim]
+    # weight comes in shape [out_dim, in_dim]
+    # output comes in empty with shape [leading_dims, out_dim]
+    out_size = list(src.size())
+    weight_size = list(weight.size())
+    assert len(weight_size) == 2
+    out_size[-1] = weight_size[0]
+    return src.new_empty(out_size, dtype=src.dtype)
+
+
+@register_fake("cadence::quantized_fully_connected.per_tensor")
+def quantized_fully_connected_per_tensor_meta(
+    src: torch.Tensor,
+    weight: torch.Tensor,
+    bias: torch.Tensor,
+    in_zero_point: int,
+    weight_zero_point: int,
     out_multiplier: int,
     out_shift: int,
     out_zero_point: int,
@@ -607,7 +636,7 @@ def quantized_fully_connected_meta(
     weight_size = list(weight.size())
     assert len(weight_size) == 2
     out_size[-1] = weight_size[0]
-    return src.new_empty(out_size, dtype=torch.uint8)
+    return src.new_empty(out_size, dtype=src.dtype)
 
 
 @register_fake("cadence::convolution")

backends/cadence/aot/replace_ops.py

Lines changed: 6 additions & 6 deletions
@@ -9,6 +9,8 @@
 # 3. functions that replace an ATen op with another semantically equivalent ATen op.
 # 4. functions that concretize optional args.
 
+# pyre-unsafe
+
 import math
 from operator import neg
 from typing import cast, Dict, Iterable, Sequence, Set, Tuple
@@ -1698,12 +1700,6 @@ def call_operator(self, op, args, kwargs, meta):
         if leading_dims != 1:
             return super().call_operator(op, args, kwargs, meta)
 
-        # If the op is quantized::linear, but per-channel quantized, bail.
-        if op == exir_ops.edge.cadence.quantized_linear.default:
-            weight = args[1].to_tensor() if isinstance(args[1], ProxyValue) else args[1]
-            if weight.shape != [1]:
-                return super().call_operator(op, args, kwargs, meta)
-
         # Replace the linear with fully connected op
         return super().call_operator(
             self.linear_to_fc_op[op],
@@ -1893,6 +1889,10 @@ class ReplaceSingleElementTensorArgumentsFromFullOpWithScalarPass(ExportPass):
             exir_ops.edge.cadence.quantized_conv.per_tensor,
             [8, 9, 12, 13],
         ),
+        exir_ops.edge.cadence.quantized_fully_connected: (
+            exir_ops.edge.cadence.quantized_fully_connected.per_tensor,
+            [4, 5, 6],
+        ),
         exir_ops.edge.cadence.quantized_layer_norm: (
             exir_ops.edge.cadence.quantized_layer_norm.per_tensor,
             [1, 2],
New file (Cadence Fusion G3 build script; path not captured in this view)

Lines changed: 91 additions & 0 deletions
@@ -0,0 +1,91 @@
+#!/bin/bash
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+set -euo pipefail
+
+unset CMAKE_PREFIX_PATH
+unset XTENSA_CORE
+export XTENSA_CORE=FCV_FG3GP
+git submodule sync
+git submodule update --init
+./install_requirements.sh
+
+rm -rf cmake-out
+
+STEPWISE_BUILD=false
+
+if $STEPWISE_BUILD; then
+  echo "Building ExecuTorch"
+  cmake -DCMAKE_INSTALL_PREFIX=cmake-out \
+      -DCMAKE_TOOLCHAIN_FILE=./backends/cadence/cadence.cmake \
+      -DCMAKE_BUILD_TYPE=Release \
+      -DEXECUTORCH_ENABLE_EVENT_TRACER=OFF \
+      -DEXECUTORCH_BUILD_EXTENSION_RUNNER_UTIL=ON \
+      -DEXECUTORCH_BUILD_EXECUTOR_RUNNER=ON \
+      -DEXECUTORCH_BUILD_PTHREADPOOL=OFF \
+      -DEXECUTORCH_BUILD_CPUINFO=OFF \
+      -DEXECUTORCH_ENABLE_LOGGING=ON \
+      -DEXECUTORCH_USE_DL=OFF \
+      -DEXECUTORCH_BUILD_CADENCE=OFF \
+      -DFLATC_EXECUTABLE="$(which flatc)" \
+      -DHAVE_FNMATCH_H=OFF \
+      -Bcmake-out .
+
+  echo "Building any Cadence-specific binaries on top"
+  cmake -DBUCK2="$BUCK" \
+      -DCMAKE_TOOLCHAIN_FILE=/home/zonglinpeng/ws/zonglinpeng/executorch/backends/cadence/cadence.cmake \
+      -DCMAKE_INSTALL_PREFIX=cmake-out \
+      -DCMAKE_BUILD_TYPE=Release \
+      -DEXECUTORCH_BUILD_HOST_TARGETS=ON \
+      -DEXECUTORCH_BUILD_EXECUTOR_RUNNER=ON \
+      -DEXECUTORCH_BUILD_PTHREADPOOL=OFF \
+      -DEXECUTORCH_BUILD_CADENCE=ON \
+      -DFLATC_EXECUTABLE="$(which flatc)" \
+      -DEXECUTORCH_ENABLE_LOGGING=ON \
+      -DEXECUTORCH_ENABLE_PROGRAM_VERIFICATION=ON \
+      -DEXECUTORCH_USE_DL=OFF \
+      -DBUILD_EXECUTORCH_PORTABLE_OPS=ON \
+      -DEXECUTORCH_BUILD_KERNELS_CUSTOM=OFF \
+      -DPYTHON_EXECUTABLE=python3 \
+      -DEXECUTORCH_FUSION_G3_OPT=ON \
+      -DEXECUTORCH_BUILD_GFLAGS=ON \
+      -DHAVE_FNMATCH_H=OFF \
+      -Bcmake-out/backends/cadence \
+      backends/cadence
+  cmake --build cmake-out/backends/cadence -j8
+else
+  echo "Building Cadence toolchain with ExecuTorch packages"
+  cmake_prefix_path="${PWD}/cmake-out/lib/cmake/ExecuTorch;${PWD}/cmake-out/third-party/gflags"
+  cmake -DBUCK2="$BUCK" \
+      -DCMAKE_PREFIX_PATH="${cmake_prefix_path}" \
+      -DHAVE_SYS_STAT_H=ON \
+      -DCMAKE_TOOLCHAIN_FILE=./backends/cadence/cadence.cmake \
+      -DCMAKE_INSTALL_PREFIX=cmake-out \
+      -DCMAKE_BUILD_TYPE=Release \
+      -DEXECUTORCH_BUILD_HOST_TARGETS=ON \
+      -DEXECUTORCH_BUILD_EXECUTOR_RUNNER=ON \
+      -DEXECUTORCH_BUILD_PTHREADPOOL=OFF \
+      -DEXECUTORCH_BUILD_CPUINFO=OFF \
+      -DEXECUTORCH_BUILD_FLATC=OFF \
+      -DEXECUTORCH_BUILD_CADENCE=ON \
+      -DFLATC_EXECUTABLE="$(which flatc)" \
+      -DEXECUTORCH_BUILD_EXTENSION_RUNNER_UTIL=ON \
+      -DEXECUTORCH_ENABLE_LOGGING=ON \
+      -DEXECUTORCH_ENABLE_PROGRAM_VERIFICATION=ON \
+      -DEXECUTORCH_USE_DL=OFF \
+      -DBUILD_EXECUTORCH_PORTABLE_OPS=ON \
+      -DEXECUTORCH_BUILD_KERNELS_CUSTOM=OFF \
+      -DPYTHON_EXECUTABLE=python3 \
+      -DEXECUTORCH_FUSION_G3_OPT=ON \
+      -DHAVE_FNMATCH_H=OFF \
+      -Bcmake-out
+  cmake --build cmake-out --target install --config Release -j8
+fi
+
+echo "Run simple model to verify cmake build"
+python3 -m examples.portable.scripts.export --model_name="add"
+xt-run --turbo cmake-out/executor_runner --model_path=add.pte

backends/cadence/build_cadence_xtensa.sh renamed to backends/cadence/build_cadence_hifi4.sh

Lines changed: 4 additions & 2 deletions
@@ -8,6 +8,8 @@
 set -euo pipefail
 
 unset CMAKE_PREFIX_PATH
+unset XTENSA_CORE
+export XTENSA_CORE=nxp_rt600_RI23_11_newlib
 git submodule sync
 git submodule update --init
 ./install_requirements.sh
@@ -53,7 +55,7 @@ if $STEPWISE_BUILD; then
       -DHAVE_FNMATCH_H=OFF \
       -Bcmake-out/backends/cadence \
       backends/cadence
-  cmake --build cmake-out/backends/cadence -j16
+  cmake --build cmake-out/backends/cadence -j8
 else
   echo "Building Cadence toolchain with ExecuTorch packages"
   cmake_prefix_path="${PWD}/cmake-out/lib/cmake/ExecuTorch;${PWD}/cmake-out/third-party/gflags"
@@ -79,7 +81,7 @@ else
       -DEXECUTORCH_NNLIB_OPT=ON \
      -DHAVE_FNMATCH_H=OFF \
      -Bcmake-out
-  cmake --build cmake-out --target install --config Release -j16
+  cmake --build cmake-out --target install --config Release -j8
 fi
 
 echo "Run simple model to verify cmake build"

backends/cadence/hifi/operators/op_mean.cpp

Lines changed: 1 addition & 2 deletions
@@ -145,8 +145,7 @@ Tensor& mean_dim_out(
   ET_SWITCH_REALHB_TYPES(in.scalar_type(), ctx, "mean.out", CTYPE_IN, [&] {
     ET_SWITCH_FLOATH_TYPES(out.scalar_type(), ctx, "mean.out", CTYPE_OUT, [&] {
       CTYPE_OUT* out_data = out.mutable_data_ptr<CTYPE_OUT>();
-      const size_t num =
-          torch::executor::exeget_reduced_dim_product(in, dim_list);
+      const size_t num = torch::executor::get_reduced_dim_product(in, dim_list);
       for (size_t out_ix = 0; out_ix < out.numel(); ++out_ix) {
         CTYPE_OUT sum = 0;
         if (in.numel() > 0) {

backends/qualcomm/_passes/remove_redundancy.py

Lines changed: 12 additions & 1 deletion
@@ -11,7 +11,7 @@
 
 class RemoveRedundancy(ExportPass):
     """
-    Trim the 'identity' operators to reduce the unnecessary copy overhead.
+    Trim certain operators to reduce unnecessary overhead.
     """
 
     redundant_ops = {
@@ -21,6 +21,10 @@ class RemoveRedundancy(ExportPass):
         torch.ops.aten.alias.default,
         exir_ops.edge.aten.alias.default,
        exir_ops.edge.aten.lift_fresh_copy.default,
+        # remove this target if '_skip_dim_order' is set to False
+        exir_ops.edge.dim_order_ops._to_dim_order_copy.default,
+        # remove channel_last / contiguous _to_copy if '_skip_dim_order' is set to True
+        exir_ops.edge.aten._to_copy.default,
     }
 
     def __init__(self):
@@ -31,6 +35,13 @@ def _remove(self, graph_module: torch.fx.GraphModule) -> torch.fx.GraphModule:
             if n.target not in self.redundant_ops:
                 continue
 
+            # do not remove cast operator
+            if (
+                n.target == exir_ops.edge.aten._to_copy.default
+                and "memory_format" not in n.kwargs
+            ):
+                continue
+
             to_be_remove = n
             for user_n in list(n.users.keys()):
                 user_n.replace_input_with(n, n.args[0])

backends/qualcomm/builders/__init__.py

Lines changed: 4 additions & 0 deletions
@@ -14,6 +14,7 @@
     op_ceil,
     op_clamp,
     op_conv2d,
+    op_cos,
     op_depth_to_space,
     op_dequantize,
     op_div,
@@ -43,6 +44,7 @@
     op_rsqrt,
     op_select_copy,
     op_sigmoid,
+    op_sin,
     op_skip_ops,
     op_slice_copy,
     op_softmax,
@@ -71,6 +73,7 @@
     op_ceil,
     op_clamp,
     op_conv2d,
+    op_cos,
     op_depth_to_space,
     op_dequantize,
     op_div,
@@ -100,6 +103,7 @@
     op_rsqrt,
     op_select_copy,
     op_sigmoid,
+    op_sin,
     op_skip_ops,
     op_slice_copy,
     op_softmax,

0 commit comments
