Skip to content

Commit b99f1d8

Browse files
committed
Update on "[ET-VK] Introduce vTensorPtr to prevent reference invalidation and remove get_val() API"
## Context Currently when writing operators developers will save a reference to a `vTensor` retrieved from a `ComputeGraph`'s list of `values_` like so: ``` vTensor& vten = graph.get_val(vref).toTensor(); ``` However, this is dangerous since if any values are added once the reference has been stored, `values_` which is a `std::vector` may have been resized and therefore have its contents moved, meaning the reference is now invalid. To protect against this, this changeset introduces the `vTensorPtr` class which is a wrapper around a `vTensor*`. When constructed, it will increment a counter in the `ComputeGraph` instance, and when destroyed it will decrement the counter. `ComputeGraph` cannot add any values while the counter is not zero. Since `Value` can be converted to other non-trivial types, this changeset also removes the `get_val` function entirely to guard against unsafe behaviour. Differential Revision: [D55984187](https://our.internmc.facebook.com/intern/diff/D55984187/) [ghstack-poisoned]
2 parents ff3213f + 650869c commit b99f1d8

File tree

68 files changed

+843
-354
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

68 files changed

+843
-354
lines changed

CMakeLists.txt

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -144,6 +144,8 @@ option(EXECUTORCH_BUILD_COREML "Build the Core ML backend" OFF)
144144

145145
option(EXECUTORCH_BUILD_CUSTOM "Build the custom kernels" OFF)
146146

147+
option(EXECUTORCH_BUILD_CUSTOM_OPS_AOT "Build the custom ops lib for AOT" OFF)
148+
147149
option(EXECUTORCH_BUILD_EXTENSION_DATA_LOADER "Build the Data Loader extension"
148150
OFF)
149151

@@ -185,12 +187,19 @@ cmake_dependent_option(
185187
cmake_dependent_option(EXECUTORCH_BUILD_CPUINFO "Build cpuinfo library." ON
186188
"NOT EXECUTORCH_BUILD_ARM_BAREMETAL" OFF)
187189

190+
if(EXECUTORCH_BUILD_CUSTOM_OPS_AOT)
191+
set(EXECUTORCH_BUILD_CUSTOM ON)
192+
endif()
193+
188194
if(EXECUTORCH_BUILD_CUSTOM)
189195
set(EXECUTORCH_BUILD_OPTIMIZED ON)
190196
endif()
191197

192198
if(EXECUTORCH_BUILD_CPUINFO)
193199
# --- cpuinfo
200+
set(ORIGINAL_CMAKE_POSITION_INDEPENDENT_CODE_FLAG
201+
${CMAKE_POSITION_INDEPENDENT_CODE})
202+
set(CMAKE_POSITION_INDEPENDENT_CODE ON)
194203
set(CPUINFO_SOURCE_DIR "backends/xnnpack/third-party/cpuinfo")
195204
set(CPUINFO_BUILD_TOOLS
196205
OFF
@@ -212,10 +221,15 @@ if(EXECUTORCH_BUILD_CPUINFO)
212221
CACHE STRING "")
213222
set(CLOG_SOURCE_DIR "${CPUINFO_SOURCE_DIR}/deps/clog")
214223
add_subdirectory("${CPUINFO_SOURCE_DIR}")
224+
set(CMAKE_POSITION_INDEPENDENT_CODE
225+
${ORIGINAL_CMAKE_POSITION_INDEPENDENT_CODE_FLAG})
215226
endif()
216227

217228
if(EXECUTORCH_BUILD_PTHREADPOOL)
218229
# --- pthreadpool
230+
set(ORIGINAL_CMAKE_POSITION_INDEPENDENT_CODE_FLAG
231+
${CMAKE_POSITION_INDEPENDENT_CODE})
232+
set(CMAKE_POSITION_INDEPENDENT_CODE ON)
219233
set(PTHREADPOOL_SOURCE_DIR "backends/xnnpack/third-party/pthreadpool")
220234
set(PTHREADPOOL_BUILD_TESTS
221235
OFF
@@ -235,6 +249,8 @@ if(EXECUTORCH_BUILD_PTHREADPOOL)
235249
CACHE STRING "")
236250
endif()
237251
add_subdirectory("${PTHREADPOOL_SOURCE_DIR}")
252+
set(CMAKE_POSITION_INDEPENDENT_CODE
253+
${ORIGINAL_CMAKE_POSITION_INDEPENDENT_CODE_FLAG})
238254
endif()
239255

240256
if(NOT PYTHON_EXECUTABLE)
@@ -546,6 +562,9 @@ if(EXECUTORCH_BUILD_PYBIND)
546562
list(APPEND _dep_libs custom_ops)
547563
endif()
548564

565+
if(EXECUTORCH_BUILD_CUSTOM_OPS_AOT)
566+
list(APPEND _dep_libs custom_ops_aot_lib)
567+
endif()
549568
# compile options for pybind
550569

551570
set(_pybind_compile_options -Wno-deprecated-declarations -fPIC -frtti

backends/qualcomm/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -253,7 +253,7 @@ target_link_libraries(qnn_executorch_backend
253253
qnn_executorch_header
254254
qnn_schema
255255
qnn_manager
256-
executorch
256+
executorch_no_prim_ops
257257
qcir_utils
258258
)
259259
target_link_libraries(utils

backends/vulkan/runtime/graph/ComputeGraph.cpp

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ namespace vkcompute {
3737
}
3838

3939
VALUE_PTR_CLASS_IMPL(vTensorPtr, vTensor, Tensor)
40+
VALUE_PTR_CLASS_IMPL(TensorRefPtr, TensorRef, TensorRef)
4041
VALUE_PTR_CLASS_IMPL(StagingPtr, api::StorageBuffer, Staging)
4142
VALUE_PTR_CLASS_IMPL(IntListPtr, std::vector<int64_t>, IntList)
4243
VALUE_PTR_CLASS_IMPL(DoubleListPtr, std::vector<double>, DoubleList)
@@ -195,18 +196,17 @@ ValueRef ComputeGraph::add_tensor(
195196
}
196197

197198
ValueRef ComputeGraph::add_tensor_like(
198-
const ValueRef vref,
199+
const ValueRef idx,
199200
const api::StorageType storage_type,
200201
const api::GPUMemoryLayout memory_layout) {
201-
TensorRef tref = get_tref(vref);
202-
return add_tensor(tref.sizes, tref.dtype, storage_type, memory_layout);
202+
return add_tensor(
203+
get_sizes_of(idx), get_dtype_of(idx), storage_type, memory_layout);
203204
}
204205

205206
ValueRef ComputeGraph::add_tensor_like(
206-
const ValueRef vref,
207+
const ValueRef idx,
207208
const api::GPUMemoryLayout memory_layout) {
208-
TensorRef tref = get_tref(vref);
209-
return add_tensor(tref.sizes, tref.dtype, memory_layout);
209+
return add_tensor(get_sizes_of(idx), get_dtype_of(idx), memory_layout);
210210
}
211211

212212
ValueRef ComputeGraph::add_tensor(

backends/vulkan/runtime/graph/ComputeGraph.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@ class ComputeGraph;
5555
};
5656

5757
DECL_VALUE_PTR_CLASS(vTensorPtr, vTensor)
58+
DECL_VALUE_PTR_CLASS(TensorRefPtr, TensorRef)
5859
DECL_VALUE_PTR_CLASS(StagingPtr, api::StorageBuffer)
5960
DECL_VALUE_PTR_CLASS(IntListPtr, std::vector<int64_t>)
6061
DECL_VALUE_PTR_CLASS(DoubleListPtr, std::vector<double>)
@@ -132,6 +133,7 @@ class ComputeGraph final {
132133
}
133134

134135
GET_AND_CHECK_VAL_AS_PTR_TYPE_FNS(vTensorPtr, tensor, Tensor)
136+
GET_AND_CHECK_VAL_AS_PTR_TYPE_FNS(TensorRefPtr, tref, TensorRef)
135137
GET_AND_CHECK_VAL_AS_PTR_TYPE_FNS(StagingPtr, staging, Staging)
136138
GET_AND_CHECK_VAL_AS_PTR_TYPE_FNS(IntListPtr, int_list, IntList)
137139
GET_AND_CHECK_VAL_AS_PTR_TYPE_FNS(DoubleListPtr, double_list, DoubleList)
@@ -148,7 +150,6 @@ class ComputeGraph final {
148150
return values_.at(idx).is##type_name(); \
149151
}
150152

151-
GET_AND_CHECK_VAL_AS_TYPE_FNS(TensorRef, tref, TensorRef)
152153
GET_AND_CHECK_VAL_AS_TYPE_FNS(int64_t, int, Int)
153154
GET_AND_CHECK_VAL_AS_TYPE_FNS(double, double, Double)
154155
GET_AND_CHECK_VAL_AS_TYPE_FNS(bool, bool, Bool)
@@ -392,6 +393,7 @@ class ComputeGraph final {
392393
//
393394

394395
friend class vTensorPtr;
396+
friend class TensorRefPtr;
395397
friend class StagingPtr;
396398
friend class IntListPtr;
397399
friend class DoubleListPtr;

backends/vulkan/runtime/graph/ops/PrepackNode.cpp

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -56,11 +56,11 @@ api::StorageBuffer PrepackNode::create_staging_buffer(ComputeGraph* graph) {
5656
return staging;
5757
}
5858

59-
TensorRef tref = graph->get_tref(tref_);
60-
size_t numel = api::utils::multiply_integers(tref.sizes);
61-
api::StorageBuffer staging(graph->context(), tref.dtype, numel);
62-
size_t nbytes = numel * api::element_size(tref.dtype);
63-
copy_ptr_to_staging(tref.data, staging, nbytes);
59+
TensorRefPtr tref = graph->get_tref(tref_);
60+
size_t numel = api::utils::multiply_integers(tref->sizes);
61+
api::StorageBuffer staging(graph->context(), tref->dtype, numel);
62+
size_t nbytes = numel * api::element_size(tref->dtype);
63+
copy_ptr_to_staging(tref->data, staging, nbytes);
6464
return staging;
6565
}
6666

backends/vulkan/runtime/graph/ops/impl/Conv2d.cpp

Lines changed: 16 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@
1717

1818
#include <executorch/backends/vulkan/runtime/graph/ops/utils/ShaderNameUtils.h>
1919

20+
#include <iostream>
21+
2022
namespace vkcompute {
2123

2224
void resize_conv2d_node(
@@ -35,8 +37,8 @@ void resize_conv2d_node(
3537
new_out_sizes.at(ndim - 4) = self->sizes().at(ndim - 4);
3638
}
3739

38-
TensorRef weight_ref = graph->get_tref(extra_args[0]);
39-
const auto& weight_sizes = weight_ref.sizes;
40+
TensorRefPtr weight_ref = graph->get_tref(extra_args[0]);
41+
const auto& weight_sizes = weight_ref->sizes;
4042
new_out_sizes.at(ndim - 3) =
4143
transposed ? weight_sizes.at(ndim - 3) : weight_sizes.at(ndim - 4);
4244

@@ -59,11 +61,14 @@ ValueRef prepack_biases(
5961
const ValueRef vref,
6062
const ValueRef weight,
6163
const bool transposed) {
62-
TensorRef tref = graph.get_tref(weight);
63-
const int64_t out_channels = transposed ? tref.sizes.at(1) : tref.sizes.at(0);
64+
auto sizes = graph.get_sizes_of(weight);
65+
const int64_t out_channels = transposed ? sizes.at(1) : sizes.at(0);
6466

6567
ValueRef v = graph.add_tensor(
66-
{out_channels}, tref.dtype, api::kTexture2D, api::kWidthPacked);
68+
{out_channels},
69+
graph.get_dtype_of(weight),
70+
api::kTexture2D,
71+
api::kWidthPacked);
6772
vTensorPtr t = graph.get_tensor(v);
6873

6974
api::ShaderInfo shader = get_nchw_to_image_shader(*t);
@@ -102,7 +107,7 @@ api::ShaderInfo get_conv2d_shader(
102107
case Conv2dMethod::Depthwise:
103108
kernel_name = "conv2d_dw";
104109
if (!prepack_weights) {
105-
const auto weight_sizes = graph.get_tref(weight).sizes;
110+
const auto& weight_sizes = graph.get_tref(weight)->sizes;
106111
if (weight_sizes.at(2) == 3 && weight_sizes.at(3) == 3) {
107112
kernel_name += "_output_tile_3x3";
108113
}
@@ -180,12 +185,12 @@ ValueRef prepack_weights(
180185
ComputeGraph& graph,
181186
const ValueRef vref,
182187
const Conv2dMethod method) {
183-
const auto original_sizes = graph.get_tref(vref).sizes;
184-
const auto final_sizes = get_final_sizes(graph.get_tref(vref).sizes, method);
188+
const auto original_sizes = graph.get_sizes_of(vref);
189+
const auto final_sizes = get_final_sizes(original_sizes, method);
185190

186191
ValueRef v = graph.add_tensor(
187192
final_sizes,
188-
graph.get_tref(vref).dtype,
193+
graph.get_dtype_of(vref),
189194
api::kTexture2D,
190195
api::kChannelsPacked);
191196
vTensorPtr t = graph.get_tensor(v);
@@ -239,7 +244,7 @@ Conv2dParams create_conv2d_params(
239244
p.kernel_size.data[1] +
240245
(p.kernel_size.data[1] - 1) * (p.dilation.data[1] - 1),
241246
});
242-
const auto weight_sizes = graph.get_tref(weight).sizes;
247+
const auto weight_sizes = graph.get_sizes_of(weight);
243248
const int32_t in_group_size =
244249
api::utils::safe_downcast<int32_t>(api::utils::align_up(
245250
transposed ? weight_sizes.at(0) : weight_sizes.at(1), INT64_C(4)));
@@ -267,7 +272,7 @@ Conv2dMethod get_conv2d_method(
267272
const ValueRef weight,
268273
const int64_t groups,
269274
const bool transposed) {
270-
const auto weight_sizes = graph.get_tref(weight).sizes;
275+
const auto weight_sizes = graph.get_sizes_of(weight);
271276
if (!transposed && weight_sizes.at(0) == groups && weight_sizes.at(1) == 1) {
272277
return Conv2dMethod::Depthwise;
273278
}

backends/vulkan/runtime/graph/ops/impl/Staging.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,7 @@ ValueRef prepack_if_tensor_ref(
9797
ValueRef prepack_if_tensor_ref(ComputeGraph& graph, const ValueRef v) {
9898
if (graph.val_is_tref(v)) {
9999
api::GPUMemoryLayout layout =
100-
graph.suggested_memory_layout(graph.get_tref(v).sizes);
100+
graph.suggested_memory_layout(graph.get_tref(v)->sizes);
101101
return prepack(graph, v, layout);
102102
} else {
103103
return v;

backends/vulkan/runtime/graph/ops/impl/utils/KernelUtils.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ api::utils::ivec2 make_ivec2_kernel_size(
2121
if (kernel_size_only) {
2222
return make_ivec2_from_list(graph, weight);
2323
} else {
24-
const auto weight_sizes = graph.get_tref(weight).sizes;
24+
const auto weight_sizes = graph.get_tref(weight)->sizes;
2525
return api::utils::make_ivec2({weight_sizes.at(3), weight_sizes.at(2)});
2626
}
2727
}

backends/vulkan/test/op_tests/utils/codegen.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -248,8 +248,8 @@ def virtual_resize(self, ref: ValueRefList) -> str:
248248
assert ref.src_cpp_type == AT_TENSOR and ref.is_in
249249
if self.prepack_ref(ref):
250250
return ""
251-
ret_str = f"{self.graph}{self.dot}get_val({ref.name}.value).toTensor()"
252-
ret_str += f".virtual_resize({ref.src_cpp_name}.sizes().vec());\n"
251+
ret_str = f"{self.graph}{self.dot}get_tensor({ref.name}.value)"
252+
ret_str += f"->virtual_resize({ref.src_cpp_name}.sizes().vec());\n"
253253
return ret_str
254254

255255
def copy_into_staging(self, ref: ValueRefList) -> str:

backends/xnnpack/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ add_library(xnnpack_backend STATIC ${_xnnpack_backend__srcs})
8181
target_link_libraries(xnnpack_backend
8282
PRIVATE
8383
${xnnpack_third_party}
84-
executorch
84+
executorch_no_prim_ops
8585
xnnpack_schema)
8686

8787
target_include_directories(xnnpack_backend

docs/source/build-run-qualcomm-ai-engine-direct-backend.md

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,10 @@ Python APIs on x64 are required to compile models to Qualcomm AI Engine Direct b
115115

116116
```bash
117117
cd $EXECUTORCH_ROOT
118+
# Workaround for fbs files in exir/_serialize
119+
cp schema/program.fbs exir/_serialize/program.fbs
120+
cp schema/scalar_type.fbs exir/_serialize/scalar_type.fbs
121+
118122
mkdir build_x86_64
119123
cd build_x86_64
120124
cmake .. -DEXECUTORCH_BUILD_QNN=ON -DQNN_SDK_ROOT=${QNN_SDK_ROOT}
@@ -138,8 +142,8 @@ mkdir build_android
138142
cd build_android
139143
# build executorch & qnn_executorch_backend
140144
cmake .. \
141-
-DBUCK2=buck2 \
142145
-DCMAKE_INSTALL_PREFIX=$PWD \
146+
-DEXECUTORCH_BUILD_SDK=ON \
143147
-DEXECUTORCH_BUILD_QNN=ON \
144148
-DQNN_SDK_ROOT=$QNN_SDK_ROOT \
145149
-DCMAKE_TOOLCHAIN_FILE=$ANDROID_NDK/build/cmake/android.toolchain.cmake \
@@ -220,6 +224,7 @@ So, we can run `qnn_executor_runner` like
220224
```bash
221225
adb push ./deeplab_v3/dlv3_qnn.pte ${DEVICE_DIR}
222226
adb push ${EXECUTORCH_ROOT}/build_android/examples/qualcomm/qnn_executor_runner ${DEVICE_DIR}
227+
adb push ${EXECUTORCH_ROOT}/build_android/lib/libqnn_executorch_backend.so ${DEVICE_DIR}
223228
adb shell "cd ${DEVICE_DIR} \
224229
&& export LD_LIBRARY_PATH=${DEVICE_DIR} \
225230
&& export ADSP_LIBRARY_PATH=${DEVICE_DIR} \

docs/source/build-run-xtensa.md

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ Step 2. Make sure you have completed the ExecuTorch setup tutorials linked to at
6464
The working tree is:
6565

6666
```
67-
examples/xtensa/
67+
examples/cadence/
6868
├── aot
6969
├── kernels
7070
├── ops
@@ -75,7 +75,7 @@ examples/xtensa/
7575

7676
***AoT (Ahead-of-Time) Components***:
7777

78-
The AoT folder contains all of the python scripts and functions needed to export the model to an ExecuTorch `.pte` file. In our case, [export_example.py](https://github.com/pytorch/executorch/blob/main/examples/xtensa/aot/export_example.py) is an API that takes a model (nn.Module) and representative inputs and runs it through the quantizer (from [quantizer.py](https://github.com/pytorch/executorch/blob/main/examples/xtensa/aot/quantizer.py)). Then a few compiler passes, also defined in [quantizer.py](https://github.com/pytorch/executorch/blob/main/examples/xtensa/aot/quantizer.py), will replace operators with custom ones that are supported and optimized on the chip. Any operator needed to compute things should be defined in [meta_registrations.py](https://github.com/pytorch/executorch/blob/main/examples/xtensa/aot/meta_registrations.py) and have corresponding implemetations in the other folders.
78+
The AoT folder contains all of the python scripts and functions needed to export the model to an ExecuTorch `.pte` file. In our case, [export_example.py](https://github.com/pytorch/executorch/blob/main/examples/cadence/aot/export_example.py) is an API that takes a model (nn.Module) and representative inputs and runs it through the quantizer (from [quantizer.py](https://github.com/pytorch/executorch/blob/main/examples/cadence/aot/quantizer.py)). Then a few compiler passes, also defined in [quantizer.py](https://github.com/pytorch/executorch/blob/main/examples/cadence/aot/quantizer.py), will replace operators with custom ones that are supported and optimized on the chip. Any operator needed to compute things should be defined in [meta_registrations.py](https://github.com/pytorch/executorch/blob/main/examples/cadence/aot/meta_registrations.py) and have corresponding implemetations in the other folders.
7979

8080
***Operators***:
8181

@@ -101,14 +101,14 @@ python3 -m examples.portable.scripts.export --model_name="add"
101101
***Quantized Operators***:
102102

103103
The other, more complex model are custom operators, including:
104-
- a quantized [linear](https://pytorch.org/docs/stable/generated/torch.nn.Linear.html) operation. The model is defined [here](https://github.com/pytorch/executorch/blob/main/examples/xtensa/tests/quantized_linear_example.py#L28). Linear is the backbone of most Automatic Speech Recognition (ASR) models.
105-
- a quantized [conv1d](https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html) operation. The model is defined [here](https://github.com/pytorch/executorch/blob/main/examples/xtensa/tests/quantized_conv1d_example.py#L36). Convolutions are important in wake word and many denoising models.
104+
- a quantized [linear](https://pytorch.org/docs/stable/generated/torch.nn.Linear.html) operation. The model is defined [here](https://github.com/pytorch/executorch/blob/main/examples/cadence/tests/quantized_linear_example.py#L28). Linear is the backbone of most Automatic Speech Recognition (ASR) models.
105+
- a quantized [conv1d](https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html) operation. The model is defined [here](https://github.com/pytorch/executorch/blob/main/examples/cadence/tests/quantized_conv1d_example.py#L36). Convolutions are important in wake word and many denoising models.
106106

107107
In both cases the generated file is called `XtensaDemoModel.pte`.
108108

109109
```bash
110110
cd executorch
111-
python3 -m examples.xtensa.tests.quantized_<linear,conv1d>_example
111+
python3 -m examples.cadence.tests.quantized_<linear,conv1d>_example
112112
```
113113

114114
***Small Model: RNNT predictor***:
@@ -118,7 +118,7 @@ The predictor is a sequence of basic ops (embedding, ReLU, linear, layer norm) a
118118

119119
```bash
120120
cd executorch
121-
python3 -m examples.xtensa.tests.rnnt_predictor_quantized_example
121+
python3 -m examples.cadence.tests.rnnt_predictor_quantized_example
122122
```
123123

124124
The generated file is called `XtensaDemoModel.pte`.
@@ -131,7 +131,7 @@ In this step, you'll be building the DSP firmware image that consists of the sam
131131
***Step 1***. Configure the environment variables needed to point to the Xtensa toolchain that you have installed in the previous step. The three environment variables that need to be set include:
132132
```bash
133133
# Directory in which the Xtensa toolchain was installed
134-
export XTENSA_TOOLCHAIN=/home/user_name/xtensa/XtDevTools/install/tools
134+
export XTENSA_TOOLCHAIN=/home/user_name/cadence/XtDevTools/install/tools
135135
# The version of the toolchain that was installed. This is essentially the name of the directory
136136
# that is present in the XTENSA_TOOLCHAIN directory from above.
137137
export TOOLCHAIN_VER=RI-2021.8-linux
@@ -151,7 +151,7 @@ cd executorch
151151
rm -rf cmake-out
152152
# prebuild and install executorch library
153153
cmake -DBUCK2=buck2 \
154-
-DCMAKE_TOOLCHAIN_FILE=<path_to_executorch>/examples/xtensa/xtensa.cmake \
154+
-DCMAKE_TOOLCHAIN_FILE=<path_to_executorch>/examples/cadence/cadence.cmake \
155155
-DCMAKE_INSTALL_PREFIX=cmake-out \
156156
-DCMAKE_BUILD_TYPE=Debug \
157157
-DPYTHON_EXECUTABLE=python3 \
@@ -165,18 +165,18 @@ cmake -DBUCK2=buck2 \
165165
-Bcmake-out .
166166

167167
cmake --build cmake-out -j8 --target install --config Debug
168-
# build xtensa runner
168+
# build cadence runner
169169
cmake -DCMAKE_BUILD_TYPE=Debug \
170-
-DCMAKE_TOOLCHAIN_FILE=<path_to_executorch>/examples/xtensa/xtensa.cmake \
170+
-DCMAKE_TOOLCHAIN_FILE=<path_to_executorch>/examples/cadence/cadence.cmake \
171171
-DCMAKE_PREFIX_PATH=<path_to_executorch>/cmake-out \
172172
-DMODEL_PATH=<path_to_program_file_generated_in_previous_step> \
173173
-DNXP_SDK_ROOT_DIR=<path_to_nxp_sdk_root> -DEXECUTORCH_BUILD_FLATC=0 \
174174
-DFLATC_EXECUTABLE="$(which flatc)" \
175175
-DNN_LIB_BASE_DIR=<path_to_nnlib_cloned_in_step_2> \
176-
-Bcmake-out/examples/xtensa \
177-
examples/xtensa
176+
-Bcmake-out/examples/cadence \
177+
examples/cadence
178178

179-
cmake --build cmake-out/examples/xtensa -j8 -t xtensa_executorch_example
179+
cmake --build cmake-out/examples/cadence -j8 -t cadence_executorch_example
180180
```
181181

182182
After having succesfully run the above step you should see two binary files in their CMake output directory.
@@ -213,6 +213,6 @@ First 20 elements of output 0
213213

214214
In this tutorial, you have learned how to export a quantized operation, build the ExecuTorch runtime and run this model on the Xtensa HiFi4 DSP chip.
215215

216-
The (quantized linear) model in this tutorial is a typical operation appearing in ASR models, and can be extended to a complete ASR model by creating the model as a new test and adding the needed operators/kernels to [operators](https://github.com/pytorch/executorch/blob/main/examples/xtensa/ops) and [kernels](https://github.com/pytorch/executorch/blob/main/examples/xtensa/kernels).
216+
The (quantized linear) model in this tutorial is a typical operation appearing in ASR models, and can be extended to a complete ASR model by creating the model as a new test and adding the needed operators/kernels to [operators](https://github.com/pytorch/executorch/blob/main/examples/cadence/ops) and [kernels](https://github.com/pytorch/executorch/blob/main/examples/cadence/kernels).
217217

218218
Other models can be created following the same structure, always assuming that operators and kernels are available.

0 commit comments

Comments
 (0)