Commit 38346fd

cad-audio and dijopaul authored
Added HiFi optimized mean and where ops. (#6483)
Adding mean and where ops optimized on HiFi.

Co-authored-by: dijopaul <[email protected]>
1 parent 4bbe994 commit 38346fd

8 files changed, +1870 −10 lines changed

backends/cadence/aot/functions_hifi.yaml

Lines changed: 6 additions & 1 deletion

@@ -62,6 +62,11 @@
     - arg_meta: null
       kernel_name: torch::executor::full_out

+- op: mean.out
+  kernels:
+    - arg_meta: null
+      kernel_name: cadence::impl::HiFi::mean_dim_out
+
 - op: mul.out
   kernels:
     - arg_meta: null

@@ -105,7 +110,7 @@
 - op: where.self_out
   kernels:
     - arg_meta: null
-      kernel_name: torch::executor::where_out
+      kernel_name: cadence::impl::HiFi::where_out

 # custom ops
 - func: cadence::quantize_per_tensor.out(Tensor input, float scale, int zero_point, int quant_min, int quant_max, ScalarType dtype, *, Tensor(a!) out) -> Tensor(a!)
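For reference, the kernel_name registered above resolves to the mean kernel this commit adds (see the op_mean.cpp listing below); its C++ signature is:

// Defined in namespace cadence::impl::HiFi::native (op_mean.cpp below).
Tensor& mean_dim_out(
    RuntimeContext& ctx,
    const Tensor& in,
    optional<ArrayRef<int64_t>> dim_list,
    bool keepdim,
    optional<ScalarType> dtype,
    Tensor& out);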

backends/cadence/hifi/kernels/CMakeLists.txt

Lines changed: 2 additions & 0 deletions

@@ -13,6 +13,8 @@ add_library(
   ${EXECUTORCH_ROOT}/backends/cadence/hifi/third-party/nnlib/xa_nn_elm_div_f32_broadcast.c
   ${EXECUTORCH_ROOT}/backends/cadence/hifi/third-party/nnlib/xa_nn_elm_div_mode_f32_broadcast.c
   ${EXECUTORCH_ROOT}/backends/cadence/hifi/third-party/nnlib/xa_nn_elm_mul_f32_broadcast.c
+  ${EXECUTORCH_ROOT}/backends/cadence/hifi/third-party/nnlib/xa_nn_elm_where_f32xf32_f32.c
+  ${EXECUTORCH_ROOT}/backends/cadence/hifi/third-party/nnlib/xa_nn_reduce_32_32.c
 )

 # Let files say "include <executorch/path/to/header.h>".
 set(_common_include_directories ${EXECUTORCH_ROOT}/..)

backends/cadence/hifi/kernels/kernels.h

Lines changed: 28 additions & 0 deletions

@@ -55,6 +55,34 @@ extern "C" WORD32 xa_nn_elm_mul_broadcast_4D_f32xf32_f32(
     const FLOAT32* __restrict__ p_inp2,
     const WORD32* const p_inp2_shape);

+extern "C" WORD32 xa_nn_elm_where_f32xf32_f32(
+    FLOAT32* __restrict__ p_out,
+    const FLOAT32* __restrict__ p_inp1,
+    const FLOAT32* __restrict__ p_inp2,
+    const unsigned char* __restrict__ p_condition,
+    WORD32 num_elm);
+
+extern "C" WORD32 xa_nn_elm_where_broadcast_4D_f32xf32_f32(
+    FLOAT32* __restrict__ p_out,
+    const WORD32* const p_out_shape,
+    const FLOAT32* __restrict__ p_inp1,
+    const WORD32* const p_inp1_shape,
+    const FLOAT32* __restrict__ p_inp2,
+    const WORD32* const p_inp2_shape,
+    const unsigned char* __restrict__ p_condition,
+    const WORD32* const p_condition_shape);
+
+extern "C" WORD32 xa_nn_reduce_mean_4D_f32_f32(
+    FLOAT32* __restrict__ p_out,
+    const WORD32* const p_out_shape,
+    const FLOAT32* __restrict__ p_inp,
+    const WORD32* const p_inp_shape,
+    const WORD32* __restrict__ p_axis,
+    WORD32 num_out_dims,
+    WORD32 num_inp_dims,
+    WORD32 num_axis_dims,
+    void* __restrict__ p_scratch_in);
+
 namespace cadence {
 namespace impl {
 namespace HiFi {
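As an aside, a minimal portable sketch of the element-wise select that the non-broadcast where kernel presumably implements; the select semantics (condition picks inp1, else inp2) are an assumption read off the signature, and this reference loop is a stand-in, not the vectorized NNLib implementation:

// Reference model of the assumed contract of xa_nn_elm_where_f32xf32_f32:
// out[i] = condition[i] ? inp1[i] : inp2[i] for i in [0, num_elm).
static void where_f32_reference(
    float* p_out,
    const float* p_inp1,
    const float* p_inp2,
    const unsigned char* p_condition,
    int num_elm) {
  for (int i = 0; i < num_elm; i++) {
    p_out[i] = p_condition[i] ? p_inp1[i] : p_inp2[i];
  }
}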

backends/cadence/hifi/operators/CMakeLists.txt

Lines changed: 3 additions & 9 deletions

@@ -22,19 +22,12 @@ endif()
 set(_aten_ops__srcs
   "${EXECUTORCH_ROOT}/backends/cadence/hifi/operators/op_add.cpp"
   "${EXECUTORCH_ROOT}/backends/cadence/hifi/operators/op_div.cpp"
+  "${EXECUTORCH_ROOT}/backends/cadence/hifi/operators/op_mean.cpp"
   "${EXECUTORCH_ROOT}/backends/cadence/hifi/operators/op_mul.cpp"
   "${EXECUTORCH_ROOT}/backends/cadence/hifi/operators/op_sigmoid.cpp"
   "${EXECUTORCH_ROOT}/backends/cadence/hifi/operators/op_sub.cpp"
   "${EXECUTORCH_ROOT}/backends/cadence/hifi/operators/op_tanh.cpp"
-  "${EXECUTORCH_ROOT}/kernels/portable/cpu/util/activation_ops_util.cpp"
-  "${EXECUTORCH_ROOT}/kernels/portable/cpu/util/copy_ops_util.cpp"
-  "${EXECUTORCH_ROOT}/kernels/portable/cpu/util/broadcast_util.cpp"
-  "${EXECUTORCH_ROOT}/kernels/portable/cpu/util/index_util.cpp"
-  "${EXECUTORCH_ROOT}/kernels/portable/cpu/util/kernel_ops_util.cpp"
-  "${EXECUTORCH_ROOT}/kernels/portable/cpu/util/matmul_ops_util.cpp"
-  "${EXECUTORCH_ROOT}/kernels/portable/cpu/util/reduce_util.cpp"
-  "${EXECUTORCH_ROOT}/kernels/portable/cpu/util/repeat_util.cpp"
-  "${EXECUTORCH_ROOT}/kernels/portable/cpu/util/slice_util.cpp"
+  "${EXECUTORCH_ROOT}/backends/cadence/hifi/operators/op_where.cpp"
   "${EXECUTORCH_ROOT}/kernels/portable/cpu/op_bmm.cpp"
   "${EXECUTORCH_ROOT}/kernels/portable/cpu/op_cat.cpp"
   "${EXECUTORCH_ROOT}/kernels/portable/cpu/op_clone.cpp"

@@ -57,6 +50,7 @@ set(_aten_ops__srcs
   "${EXECUTORCH_ROOT}/kernels/portable/cpu/util/matmul_ops_util.cpp"
   "${EXECUTORCH_ROOT}/kernels/portable/cpu/util/reduce_util.cpp"
   "${EXECUTORCH_ROOT}/kernels/portable/cpu/util/repeat_util.cpp"
+  "${EXECUTORCH_ROOT}/kernels/portable/cpu/util/slice_util.cpp"
 )
 add_library(aten_ops_cadence ${_aten_ops__srcs})
 target_link_libraries(aten_ops_cadence PUBLIC executorch)
backends/cadence/hifi/operators/op_mean.cpp (new file)

Lines changed: 170 additions & 0 deletions

@@ -0,0 +1,170 @@
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <cstdlib>

#include <executorch/kernels/portable/cpu/util/kernel_ops_util.h>
#include <executorch/kernels/portable/cpu/util/reduce_util.h>
#include <executorch/runtime/kernel/kernel_includes.h>
#include <executorch/runtime/platform/assert.h>

#include <executorch/backends/cadence/hifi/kernels/kernels.h>

using exec_aten::ScalarType;
using exec_aten::Tensor;
using executorch::aten::RuntimeContext;
using executorch::runtime::ArrayRef;
using torch::executor::Error;
using torch::executor::optional;

namespace cadence {
namespace impl {
namespace HiFi {
namespace native {

// Copies the input/output shapes into fixed-size arrays and normalizes the
// reduction axes (negative dims are wrapped to num_inp_dims + d) for the
// NNLib kernel. Returns the number of reduction axes.
int prepare_data(
    const Tensor& in,
    Tensor& out,
    optional<ArrayRef<int64_t>> dim_list,
    int* inp_shape,
    int* out_shape,
    int* p_axis,
    int num_inp_dims,
    int num_out_dims) {
  for (int i = 0; i < num_inp_dims; i++) {
    inp_shape[i] = in.size(i);
  }

  for (int i = 0; i < num_out_dims; i++) {
    out_shape[i] = out.size(i);
  }

  int num_axis_dims = 0;
  for (const auto& d : dim_list.value()) {
    if (d < 0) {
      p_axis[num_axis_dims] = num_inp_dims + d;
      num_axis_dims++;
    } else {
      p_axis[num_axis_dims] = d;
      num_axis_dims++;
    }
  }

  return num_axis_dims;
}

Tensor& mean_dim_out(
    RuntimeContext& ctx,
    const Tensor& in,
    optional<ArrayRef<int64_t>> dim_list,
    bool keepdim,
    optional<ScalarType> dtype,
    Tensor& out) {
  ET_KERNEL_CHECK(
      ctx,
      torch::executor::check_mean_dim_args(in, dim_list, keepdim, dtype, out),
      InvalidArgument,
      out);

  ET_KERNEL_CHECK(
      ctx,
      torch::executor::resize_reduction_out(in, dim_list, keepdim, out) ==
          Error::Ok,
      InvalidArgument,
      out);

  constexpr auto name = "mean.out";
  constexpr int kNnlibMaxDim = 4;

  // Take the NNLib fast path only for float tensors of at most 4 dims.
  bool optimized = true;

  if (out.scalar_type() != ScalarType::Float)
    optimized = false;

  if (in.dim() > kNnlibMaxDim)
    optimized = false;

  if (optimized) {
    float* __restrict__ p_out = out.mutable_data_ptr<float>();
    const float* __restrict__ p_inp =
        (const float* __restrict__)in.const_data_ptr<float>();

    int num_inp_dims = in.dim();
    int num_out_dims = out.dim();

    int inp_shape[kNnlibMaxDim];
    int out_shape[kNnlibMaxDim];
    int p_axis[kNnlibMaxDim];

    for (int i = 0; i < kNnlibMaxDim; i++) {
      out_shape[i] = 1;
      inp_shape[i] = 1;
      p_axis[i] = 1;
    }

    int num_axis_dims = prepare_data(
        in,
        out,
        dim_list,
        inp_shape,
        out_shape,
        p_axis,
        num_inp_dims,
        num_out_dims);

    // Reducing over every input dim collapses the output to one element.
    if (num_axis_dims == num_inp_dims) {
      num_out_dims = 1;
      out_shape[0] = 1;
    }

    int scratch_size = xa_nn_reduce_getsize_nhwc(
        -3, inp_shape, num_inp_dims, p_axis, num_axis_dims, 1);

    void* __restrict__ p_scratch_in = (void* __restrict__)malloc(scratch_size);

    xa_nn_reduce_mean_4D_f32_f32(
        p_out,
        out_shape,
        p_inp,
        inp_shape,
        p_axis,
        num_out_dims,
        num_inp_dims,
        num_axis_dims,
        p_scratch_in);

    free(p_scratch_in);

    return out;
  }

  // Portable fallback: accumulate over the dim list, then divide by the
  // reduced element count.
  ET_SWITCH_REALHB_TYPES(in.scalar_type(), ctx, name, CTYPE_IN, [&] {
    ET_SWITCH_FLOATH_TYPES(out.scalar_type(), ctx, name, CTYPE_OUT, [&] {
      CTYPE_OUT* out_data = out.mutable_data_ptr<CTYPE_OUT>();
      const size_t num = torch::executor::get_reduced_dim_product(in, dim_list);

      for (size_t out_ix = 0; out_ix < out.numel(); ++out_ix) {
        CTYPE_OUT sum = 0;
        if (in.numel() > 0) {
          sum = torch::executor::map_reduce_over_dim_list<CTYPE_IN, CTYPE_OUT>(
              [](CTYPE_IN v) { return static_cast<CTYPE_OUT>(v); },
              [](CTYPE_OUT outv, CTYPE_OUT acc) { return acc + outv; },
              in,
              dim_list,
              out_ix);
        }
        out_data[out_ix] = sum / static_cast<float>(num);
      }
    });
  });

  return out;
}

} // namespace native
} // namespace HiFi
} // namespace impl
} // namespace cadence
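To illustrate the axis normalization performed by prepare_data above, a minimal standalone sketch; the driver and the example dim values here are hypothetical, for demonstration only:

#include <cstdio>

// Hypothetical demo: wrap negative reduction axes the way prepare_data does,
// i.e. d < 0 maps to num_inp_dims + d.
int main() {
  const int num_inp_dims = 4; // e.g. a 4-D input, the NNLib maximum
  const int dim_list[2] = {-1, 1}; // assumed reduction dims for this demo
  int p_axis[4];
  int num_axis_dims = 0;
  for (int d : dim_list) {
    p_axis[num_axis_dims++] = d < 0 ? num_inp_dims + d : d;
  }
  // Prints: axes = 3 1 (dim -1 of a 4-D tensor wraps to dim 3)
  printf("axes = ");
  for (int i = 0; i < num_axis_dims; i++) printf("%d ", p_axis[i]);
  printf("\n");
  return 0;
}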
