
Commit ff1d6af

Optimized atan2, _softmax, cat, clamp, full, relu, remainder, and permute_copy_out ops, and updated them to use the memory allocator.

Differential Revision: D68446171
Pull Request resolved: #7567

1 parent: 08b192a

23 files changed: +4501 −219 lines

backends/cadence/aot/functions_hifi.yaml
Lines changed: 30 additions & 5 deletions

@@ -20,7 +20,12 @@
 - op: _softmax.out
   kernels:
     - arg_meta: null
-      kernel_name: torch::executor::softmax_out
+      kernel_name: cadence::impl::HiFi::softmax_out
+
+- op: atan2.out
+  kernels:
+    - arg_meta: null
+      kernel_name: cadence::impl::HiFi::atan2_out
 
 - op: add.out
   kernels:
@@ -35,7 +40,12 @@
 - op: cat.out
   kernels:
     - arg_meta: null
-      kernel_name: torch::executor::cat_out
+      kernel_name: cadence::impl::HiFi::cat_out
+
+- op: clamp.Tensor_out
+  kernels:
+    - arg_meta: null
+      kernel_name: cadence::impl::HiFi::clamp_tensor_out
 
 - op: clone.out
   kernels:
@@ -60,7 +70,12 @@
 - op: full.out
   kernels:
     - arg_meta: null
-      kernel_name: torch::executor::full_out
+      kernel_name: cadence::impl::HiFi::full_out
+
+- op: gt.Scalar_out
+  kernels:
+    - arg_meta: null
+      kernel_name: torch::executor::gt_scalar_out
 
 - op: gelu.out
   kernels:
@@ -100,7 +115,7 @@
 - op: permute_copy.out
   kernels:
     - arg_meta: null
-      kernel_name: torch::executor::permute_copy_out
+      kernel_name: cadence::impl::HiFi::permute_copy_out
 
 - op: pow.Scalar_out
   kernels:
@@ -117,6 +132,11 @@
     - arg_meta: null
       kernel_name: cadence::impl::HiFi::pow_Tensor_Tensor_out
 
+- op: remainder.Tensor_out
+  kernels:
+    - arg_meta: null
+      kernel_name: cadence::impl::HiFi::remainder_Tensor_out
+
 - op: rsqrt.out
   kernels:
     - arg_meta: null
@@ -170,7 +190,6 @@
     - arg_meta: null
       kernel_name: cadence::impl::HiFi::dequantize_per_tensor_out
 
-
 - func: cadence::quantized_layer_norm.out(Tensor input, Tensor in_scale, Tensor in_zero_point, int[] normalized_shape, Tensor weight, Tensor bias, float eps, float output_scale, int output_zero_point, *, Tensor(a!) out) -> Tensor(a!)
   kernels:
     - arg_meta: null
@@ -184,6 +203,12 @@
   kernels:
     - arg_meta: null
       kernel_name: cadence::impl::HiFi::quantized_linear_out
+
+- func: cadence::quantized_relu.out(Tensor X, Tensor X_zero_point, int out_zero_point, Tensor out_multiplier, Tensor out_shift, *, Tensor(a!) out) -> Tensor(a!)
+  kernels:
+    - arg_meta: null
+      kernel_name: cadence::impl::HiFi::quantized_relu_out
+
 - func: cadence::quantized_linear.per_tensor_out(Tensor src, Tensor weight, Tensor bias, SymInt src_zero_point, SymInt weight_zero_point, SymInt out_multiplier, SymInt out_shift, SymInt out_zero_point, Tensor? offset, *, Tensor(a!) out) -> Tensor(a!)
   kernels:
     - arg_meta: null
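Each entry above binds an op schema to a concrete out-variant kernel; switching kernel_name from torch::executor::* to cadence::impl::HiFi::* is what routes these ops to the optimized HiFi implementations. As a rough illustration, an entry such as kernel_name: cadence::impl::HiFi::atan2_out is expected to resolve to a C++ function along these lines (a hedged sketch of the usual ExecuTorch out-op convention, not the committed source; the exact tensor types and declaration site are assumptions):

// Hypothetical sketch only: the out-variant signature that the YAML
// binding for atan2.out presumably resolves to.
#include <executorch/runtime/kernel/kernel_includes.h>

namespace cadence {
namespace impl {
namespace HiFi {

// ctx comes first and the mutated out tensor is both the last
// parameter and the return value, per the out-op convention.
::executorch::aten::Tensor& atan2_out(
    ::executorch::runtime::KernelRuntimeContext& ctx,
    const ::executorch::aten::Tensor& self,
    const ::executorch::aten::Tensor& other,
    ::executorch::aten::Tensor& out);

} // namespace HiFi
} // namespace impl
} // namespace cadence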

backends/cadence/hifi/kernels/CMakeLists.txt
Lines changed: 5 additions & 0 deletions

@@ -10,14 +10,19 @@ add_library(
   kernels.cpp
   ${EXECUTORCH_ROOT}/backends/cadence/hifi/third-party/nnlib/matmul_asym8uxasym8u_asym8u.cpp
   ${EXECUTORCH_ROOT}/backends/cadence/hifi/third-party/nnlib/xa_nn_broadcast_32.c
+  ${EXECUTORCH_ROOT}/backends/cadence/hifi/third-party/nnlib/xa_nn_concat_32.c
   ${EXECUTORCH_ROOT}/backends/cadence/hifi/third-party/nnlib/xa_nn_elm_add_f32_broadcast.c
+  ${EXECUTORCH_ROOT}/backends/cadence/hifi/third-party/nnlib/xa_nn_elm_atan2_f32.c
+  ${EXECUTORCH_ROOT}/backends/cadence/hifi/third-party/nnlib/xa_nn_elm_clamp_f32_broadcast.c
   ${EXECUTORCH_ROOT}/backends/cadence/hifi/third-party/nnlib/xa_nn_elm_div_f32_broadcast.c
   ${EXECUTORCH_ROOT}/backends/cadence/hifi/third-party/nnlib/xa_nn_elm_div_mode_f32_broadcast.c
   ${EXECUTORCH_ROOT}/backends/cadence/hifi/third-party/nnlib/xa_nn_elm_minimum_maximum_f32.c
   ${EXECUTORCH_ROOT}/backends/cadence/hifi/third-party/nnlib/xa_nn_elm_mul_f32_broadcast.c
   ${EXECUTORCH_ROOT}/backends/cadence/hifi/third-party/nnlib/xa_nn_elm_pow_f32.c
+  ${EXECUTORCH_ROOT}/backends/cadence/hifi/third-party/nnlib/xa_nn_elm_remainder_broadcast_f32.c
   ${EXECUTORCH_ROOT}/backends/cadence/hifi/third-party/nnlib/xa_nn_elm_where_f32xf32_f32.c
   ${EXECUTORCH_ROOT}/backends/cadence/hifi/third-party/nnlib/xa_nn_reduce_32_32.c
+  ${EXECUTORCH_ROOT}/backends/cadence/hifi/third-party/nnlib/xa_nn_transpose_32.c
 )
 # Let files say "include <executorch/path/to/header.h>".
 set(_common_include_directories ${EXECUTORCH_ROOT}/..)

backends/cadence/hifi/kernels/kernels.cpp
Lines changed: 5 additions & 0 deletions

@@ -20,6 +20,11 @@ memcpy(void* dst, const void* src, size_t num_bytes) {
   MEMCPY_8b(dst, src, num_bytes);
 }
 
+void* allocate_temp_memory(KernelRuntimeContext& ctx, size_t size) {
+  Result<void*> temp_mem_res = ctx.allocate_temp(size);
+  return temp_mem_res.ok() ? temp_mem_res.get() : nullptr;
+}
+
 // Quantize a fp32 value to an int8_t/uint8_t value
 template <typename T>
 __attribute__((always_inline)) T
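The new allocate_temp_memory helper wraps KernelRuntimeContext::allocate_temp and flattens its Result<void*> into a nullable pointer, so operator code can request scratch space with a single null check. A minimal usage sketch follows; the calling kernel below is hypothetical, and only allocate_temp_memory itself comes from this commit:

// Hypothetical caller: requests a scratch buffer of num_elm 32-bit
// words from the runtime's temp allocator before invoking an NNLIB
// kernel that needs workspace memory.
#include <executorch/backends/cadence/hifi/kernels/kernels.h>

void example_kernel(KernelRuntimeContext& ctx, int num_elm) {
  // Returns nullptr if the temp allocator is absent or exhausted.
  void* scratch = cadence::impl::HiFi::kernels::allocate_temp_memory(
      ctx, num_elm * sizeof(WORD32));
  if (scratch == nullptr) {
    return; // a real kernel would report an error and bail out
  }
  WORD32* p_scratch = static_cast<WORD32*>(scratch);
  // ... pass p_scratch to an NNLIB call that takes a scratch
  // pointer, e.g. xa_nn_reduce_mean_4D_f32_f32 ...
  (void)p_scratch;
}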

backends/cadence/hifi/kernels/kernels.h
Lines changed: 59 additions & 1 deletion

@@ -7,13 +7,16 @@
  */
 
 #pragma once
-
+#include <executorch/runtime/kernel/kernel_includes.h>
 #include <inttypes.h>
 #include <stddef.h>
 #include <xa_type_def.h>
 /* For NNLIB APIs */
 #include "xa_nnlib_kernels_api.h"
 
+using executorch::runtime::KernelRuntimeContext;
+using executorch::runtime::Result;
+
 /* Potential NNLIB function/APIs */
 
 extern "C" WORD32 xa_nn_broadcast_32_32(
@@ -23,6 +26,16 @@ extern "C" WORD32 xa_nn_broadcast_32_32(
     const int* const in_shape,
     int num_dims);
 
+extern "C" WORD32 xa_nn_concat_32_32(
+    WORD32* __restrict__ p_out,
+    const WORD32* const p_out_shape,
+    const WORD32** pp_inps,
+    const WORD32* const* pp_inps_shape,
+    WORD32 num_out_dims,
+    WORD32 num_inp,
+    WORD32 num_inp_dims,
+    WORD32 axis);
+
 extern "C" WORD32 xa_nn_elm_add_broadcast_4D_f32xf32_f32(
     FLOAT32* __restrict__ p_out,
     const WORD32* const p_out_shape,
@@ -31,6 +44,26 @@ extern "C" WORD32 xa_nn_elm_add_broadcast_4D_f32xf32_f32(
     const FLOAT32* __restrict__ p_inp2,
     const WORD32* const p_inp2_shape);
 
+extern "C" void
+xa_nn_elm_atan2_f32(FLOAT32* z, const FLOAT32* y, const FLOAT32* x, WORD32 N);
+
+extern "C" WORD32 xa_nn_elm_clamp_f32xf32xf32_f32(
+    FLOAT32* __restrict__ p_out,
+    const FLOAT32* __restrict__ p_inp,
+    const FLOAT32* __restrict__ p_min,
+    const FLOAT32* __restrict__ p_max,
+    WORD32 num_elm);
+
+extern "C" WORD32 xa_nn_elm_clamp_broadcast_4D_f32Xf32xf32_f32(
+    FLOAT32* __restrict__ p_out,
+    const WORD32* const p_out_shape,
+    const FLOAT32* __restrict__ p_inp,
+    const WORD32* const p_inp_shape,
+    const FLOAT32* __restrict__ p_min,
+    const WORD32* const p_min_shape,
+    const FLOAT32* __restrict__ p_max,
+    const WORD32* const p_max_shape);
+
 extern "C" WORD32 xa_nn_elm_div_broadcast_4D_f32xf32_f32(
     FLOAT32* __restrict__ p_out,
     const WORD32* const p_out_shape,
@@ -97,6 +130,20 @@ extern "C" void xa_nn_elm_pow_f32(
     const FLOAT32* __restrict__ y,
     WORD32 N);
 
+extern "C" WORD32 xa_nn_elm_remainder_f32xf32_f32(
+    FLOAT32* __restrict__ p_out,
+    const FLOAT32* __restrict__ p_inp1,
+    const FLOAT32* __restrict__ p_inp2,
+    WORD32 num_elm);
+
+extern "C" WORD32 xa_nn_elm_remainder_broadcast_4D_f32xf32_f32(
+    FLOAT32* __restrict__ p_out,
+    const WORD32* const p_out_shape,
+    const FLOAT32* __restrict__ p_inp1,
+    const WORD32* const p_inp1_shape,
+    const FLOAT32* __restrict__ p_inp2,
+    const WORD32* const p_inp2_shape);
+
 extern "C" WORD32 xa_nn_elm_where_f32xf32_f32(
     FLOAT32* __restrict__ p_out,
     const FLOAT32* __restrict__ p_inp1,
@@ -125,11 +172,22 @@ extern "C" WORD32 xa_nn_reduce_mean_4D_f32_f32(
     WORD32 num_axis_dims,
     void* __restrict__ p_scratch_in);
 
+extern "C" WORD32 xa_nn_transpose_32_32(
+    WORD32* __restrict__ p_out,
+    const WORD32* const p_out_shape,
+    const WORD32* __restrict__ p_inp,
+    const WORD32* const p_inp_shape,
+    const WORD32* __restrict__ p_permute_vec,
+    WORD32 num_out_dims,
+    WORD32 num_inp_dims);
+
 namespace cadence {
 namespace impl {
 namespace HiFi {
 namespace kernels {
 
+void* allocate_temp_memory(KernelRuntimeContext& ctx, size_t size);
+
 void memcpy(void* dst, const void* src, size_t num_bytes);
 
 WORD32 matmul_asym8uxasym8u_asym8u(
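These extern "C" declarations expose the Cadence NNLIB entry points that back the new operators: concat for cat, atan2 and clamp and remainder for their element-wise ops, and transpose for permute_copy. As a minimal, hedged sketch of how one of them is driven, the transpose kernel can be called on a contiguous int32 buffer as below; the 0-on-success return convention and 2D support are assumptions about NNLIB, not taken from the commit:

// Minimal sketch: permute a row-major 2x3 int32 buffer into 3x2 via
// the newly declared NNLIB transpose. The real permute_copy_out would
// derive shapes and the permute vector from tensor metadata instead.
#include <executorch/backends/cadence/hifi/kernels/kernels.h>

void transpose_2x3_example() {
  const WORD32 inp[6] = {1, 2, 3, 4, 5, 6};  // 2x3, row-major
  WORD32 out[6];                             // receives the 3x2 result
  const WORD32 inp_shape[2] = {2, 3};
  const WORD32 out_shape[2] = {3, 2};
  const WORD32 permute_vec[2] = {1, 0};      // swap the two axes

  // Assumed NNLIB convention: nonzero return signals an error.
  WORD32 ret = xa_nn_transpose_32_32(
      out,
      out_shape,
      inp,
      inp_shape,
      permute_vec,
      /*num_out_dims=*/2,
      /*num_inp_dims=*/2);
  (void)ret;
}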

backends/cadence/hifi/operators/CMakeLists.txt
Lines changed: 9 additions & 6 deletions

@@ -21,32 +21,35 @@ endif()
 # ATen compliant ops that are needed to run this model.
 set(_aten_ops__srcs
   "${EXECUTORCH_ROOT}/backends/cadence/hifi/operators/op_add.cpp"
+  "${EXECUTORCH_ROOT}/backends/cadence/hifi/operators/op_atan2.cpp"
+  "${EXECUTORCH_ROOT}/backends/cadence/hifi/operators/op_cat.cpp"
+  "${EXECUTORCH_ROOT}/backends/cadence/hifi/operators/op_clamp.cpp"
   "${EXECUTORCH_ROOT}/backends/cadence/hifi/operators/op_div.cpp"
+  "${EXECUTORCH_ROOT}/backends/cadence/hifi/operators/op_full.cpp"
   "${EXECUTORCH_ROOT}/backends/cadence/hifi/operators/op_maximum.cpp"
   "${EXECUTORCH_ROOT}/backends/cadence/hifi/operators/op_mean.cpp"
   "${EXECUTORCH_ROOT}/backends/cadence/hifi/operators/op_minimum.cpp"
   "${EXECUTORCH_ROOT}/backends/cadence/hifi/operators/op_mul.cpp"
+  "${EXECUTORCH_ROOT}/backends/cadence/hifi/operators/op_permute_copy.cpp"
   "${EXECUTORCH_ROOT}/backends/cadence/hifi/operators/op_pow.cpp"
+  "${EXECUTORCH_ROOT}/backends/cadence/hifi/operators/op_remainder.cpp"
   "${EXECUTORCH_ROOT}/backends/cadence/hifi/operators/op_rsqrt.cpp"
+  "${EXECUTORCH_ROOT}/backends/cadence/hifi/operators/op_softmax.cpp"
   "${EXECUTORCH_ROOT}/backends/cadence/hifi/operators/op_sigmoid.cpp"
   "${EXECUTORCH_ROOT}/backends/cadence/hifi/operators/op_sub.cpp"
   "${EXECUTORCH_ROOT}/backends/cadence/hifi/operators/op_tanh.cpp"
   "${EXECUTORCH_ROOT}/backends/cadence/hifi/operators/op_where.cpp"
   "${EXECUTORCH_ROOT}/kernels/portable/cpu/op_bmm.cpp"
-  "${EXECUTORCH_ROOT}/kernels/portable/cpu/op_cat.cpp"
   "${EXECUTORCH_ROOT}/kernels/portable/cpu/op_clone.cpp"
   "${EXECUTORCH_ROOT}/kernels/portable/cpu/op_embedding.cpp"
-  "${EXECUTORCH_ROOT}/kernels/portable/cpu/op_full.cpp"
+  "${EXECUTORCH_ROOT}/kernels/portable/cpu/op_gt.cpp"
   "${EXECUTORCH_ROOT}/kernels/portable/cpu/op_gelu.cpp"
   "${EXECUTORCH_ROOT}/kernels/portable/cpu/op_hardtanh.cpp"
   "${EXECUTORCH_ROOT}/kernels/portable/cpu/op_max_pool2d_with_indices.cpp"
-  "${EXECUTORCH_ROOT}/kernels/portable/cpu/op_permute_copy.cpp"
   "${EXECUTORCH_ROOT}/kernels/portable/cpu/op_slice_copy.cpp"
-  "${EXECUTORCH_ROOT}/kernels/portable/cpu/op_softmax.cpp"
   "${EXECUTORCH_ROOT}/kernels/portable/cpu/op_split_with_sizes_copy.cpp"
   "${EXECUTORCH_ROOT}/kernels/portable/cpu/op_to_copy.cpp"
   "${EXECUTORCH_ROOT}/kernels/portable/cpu/op_view_copy.cpp"
-  "${EXECUTORCH_ROOT}/kernels/portable/cpu/op_where.cpp"
   "${EXECUTORCH_ROOT}/kernels/portable/cpu/pattern/unary_ufunc_realhbbf16_to_floathbf16.cpp"
   "${EXECUTORCH_ROOT}/kernels/portable/cpu/util/activation_ops_util.cpp"
   "${EXECUTORCH_ROOT}/kernels/portable/cpu/util/broadcast_util.cpp"
@@ -74,7 +77,7 @@ target_include_directories(
 # Custom ops that are needed to run the test model.
 add_library(
   custom_ops "quantized_linear_out.cpp" "quantized_layer_norm.cpp"
-  "quantize_per_tensor.cpp" "dequantize_per_tensor.cpp"
+  "quantize_per_tensor.cpp" "quantized_relu_out.cpp" "dequantize_per_tensor.cpp"
 )
 target_include_directories(
   custom_ops PUBLIC ${ROOT_DIR}/.. ${CMAKE_BINARY_DIR}
