
Commit f8bc774

Authored by: cad-audio, dijopaul, nishpoonia, mcremon-meta
HiFi optimizations for mean, where, min, max, pow, rem and quantized_linear operators. (#6867)
* Adding mean and where ops optimized on HiFi
* Adding quantized linear optimized versions for int8 and uint8
* Adding pow, remainder, minimum, maximum operators (#33)
* Fix for build issue faced in div_mod on old tools
* Fix build failure due to merge issue
* Fixing review comments on PR 6867

---------

Co-authored-by: dijopaul <[email protected]>
Co-authored-by: nishpoonia <[email protected]>
Co-authored-by: mcremon-meta <[email protected]>
1 parent a8fa857 commit f8bc774

File tree: 13 files changed (+3478, -22 lines)

backends/cadence/aot/functions_hifi.yaml

Lines changed: 31 additions & 1 deletion
@@ -77,10 +77,20 @@
   - arg_meta: null
     kernel_name: torch::executor::max_pool2d_with_indices_out
 
+- op: maximum.out
+  kernels:
+    - arg_meta: null
+      kernel_name: cadence::impl::HiFi::maximum_out
+
 - op: mean.out
   kernels:
     - arg_meta: null
-      kernel_name: cadence::impl::HiFi::mean_dim_out
+      kernel_name: cadence::impl::HiFi::mean_dim_out
+
+- op: minimum.out
+  kernels:
+    - arg_meta: null
+      kernel_name: cadence::impl::HiFi::minimum_out
 
 - op: mul.out
   kernels:
@@ -92,6 +102,26 @@
   - arg_meta: null
     kernel_name: torch::executor::permute_copy_out
 
+- op: pow.Scalar_out
+  kernels:
+    - arg_meta: null
+      kernel_name: cadence::impl::HiFi::pow_Scalar_out
+
+- op: pow.Tensor_Scalar_out
+  kernels:
+    - arg_meta: null
+      kernel_name: cadence::impl::HiFi::pow_Tensor_Scalar_out
+
+- op: pow.Tensor_Tensor_out
+  kernels:
+    - arg_meta: null
+      kernel_name: cadence::impl::HiFi::pow_Tensor_Tensor_out
+
+- op: rsqrt.out
+  kernels:
+    - arg_meta: null
+      kernel_name: cadence::impl::HiFi::rsqrt_out
+
 - op: sigmoid.out
   kernels:
     - arg_meta: null

backends/cadence/hifi/kernels/CMakeLists.txt

Lines changed: 3 additions & 0 deletions
@@ -9,10 +9,13 @@ add_library(
   cadence_kernels
   kernels.cpp
   ${EXECUTORCH_ROOT}/backends/cadence/hifi/third-party/nnlib/matmul_asym8uxasym8u_asym8u.cpp
+  ${EXECUTORCH_ROOT}/backends/cadence/hifi/third-party/nnlib/xa_nn_broadcast_32.c
   ${EXECUTORCH_ROOT}/backends/cadence/hifi/third-party/nnlib/xa_nn_elm_add_f32_broadcast.c
   ${EXECUTORCH_ROOT}/backends/cadence/hifi/third-party/nnlib/xa_nn_elm_div_f32_broadcast.c
   ${EXECUTORCH_ROOT}/backends/cadence/hifi/third-party/nnlib/xa_nn_elm_div_mode_f32_broadcast.c
+  ${EXECUTORCH_ROOT}/backends/cadence/hifi/third-party/nnlib/xa_nn_elm_minimum_maximum_f32.c
   ${EXECUTORCH_ROOT}/backends/cadence/hifi/third-party/nnlib/xa_nn_elm_mul_f32_broadcast.c
+  ${EXECUTORCH_ROOT}/backends/cadence/hifi/third-party/nnlib/xa_nn_elm_pow_f32.c
   ${EXECUTORCH_ROOT}/backends/cadence/hifi/third-party/nnlib/xa_nn_elm_where_f32xf32_f32.c
   ${EXECUTORCH_ROOT}/backends/cadence/hifi/third-party/nnlib/xa_nn_reduce_32_32.c
 )

backends/cadence/hifi/kernels/kernels.h

Lines changed: 42 additions & 0 deletions
@@ -15,6 +15,14 @@
 #include "xa_nnlib_kernels_api.h"
 
 /* Potential NNLIB function/APIs */
+
+extern "C" WORD32 xa_nn_broadcast_32_32(
+    WORD32* __restrict__ p_out,
+    const int* const out_shape,
+    WORD32* __restrict__ p_in,
+    const int* const in_shape,
+    int num_dims);
+
 extern "C" WORD32 xa_nn_elm_add_broadcast_4D_f32xf32_f32(
     FLOAT32* __restrict__ p_out,
     const WORD32* const p_out_shape,
@@ -47,6 +55,34 @@ extern "C" WORD32 xa_nn_elm_div_mode_broadcast_4D_f32xf32_f32(
     const WORD32* const p_inp2_shape,
     WORD32 mode);
 
+extern "C" WORD32 xa_nn_elm_maximum_f32xf32_f32(
+    FLOAT32* __restrict__ p_out,
+    const FLOAT32* __restrict__ p_inp1,
+    const FLOAT32* __restrict__ p_inp2,
+    WORD32 num_elm);
+
+extern "C" WORD32 xa_nn_elm_maximum_broadcast_4D_f32xf32_f32(
+    FLOAT32* __restrict__ p_out,
+    const WORD32* const p_out_shape,
+    const FLOAT32* __restrict__ p_inp1,
+    const WORD32* const p_inp1_shape,
+    const FLOAT32* __restrict__ p_inp2,
+    const WORD32* const p_inp2_shape);
+
+extern "C" WORD32 xa_nn_elm_minimum_f32xf32_f32(
+    FLOAT32* __restrict__ p_out,
+    const FLOAT32* __restrict__ p_inp1,
+    const FLOAT32* __restrict__ p_inp2,
+    WORD32 num_elm);
+
+extern "C" WORD32 xa_nn_elm_minimum_broadcast_4D_f32xf32_f32(
+    FLOAT32* __restrict__ p_out,
+    const WORD32* const p_out_shape,
+    const FLOAT32* __restrict__ p_inp1,
+    const WORD32* const p_inp1_shape,
+    const FLOAT32* __restrict__ p_inp2,
+    const WORD32* const p_inp2_shape);
+
 extern "C" WORD32 xa_nn_elm_mul_broadcast_4D_f32xf32_f32(
     FLOAT32* __restrict__ p_out,
     const WORD32* const p_out_shape,
@@ -55,6 +91,12 @@ extern "C" WORD32 xa_nn_elm_mul_broadcast_4D_f32xf32_f32(
     const FLOAT32* __restrict__ p_inp2,
     const WORD32* const p_inp2_shape);
 
+extern "C" void xa_nn_elm_pow_f32(
+    FLOAT32* restrict z,
+    const FLOAT32* restrict x,
+    const FLOAT32* restrict y,
+    WORD32 N);
+
 extern "C" WORD32 xa_nn_elm_where_f32xf32_f32(
     FLOAT32* __restrict__ p_out,
     const FLOAT32* __restrict__ p_inp1,
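
For reference, a minimal sketch of how the flat (same-shape) maximum variant declared above is invoked. The kernel and header names come from this diff; the wrapper function, buffers, and sizes below are illustrative assumptions only:

    // Hypothetical standalone caller, not part of this commit.
    // Assumes kernels.h is on the include path and the nnlib sources added to
    // backends/cadence/hifi/kernels/CMakeLists.txt are linked in.
    #include <executorch/backends/cadence/hifi/kernels/kernels.h>

    void maximum_flat_example() {
      float a[8] = {0.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f};
      float b[8] = {7.f, 6.f, 5.f, 4.f, 3.f, 2.f, 1.f, 0.f};
      float out[8];
      // Element-wise maximum over 8 contiguous floats; no broadcasting.
      xa_nn_elm_maximum_f32xf32_f32(out, a, b, /*num_elm=*/8);
    }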

backends/cadence/hifi/operators/CMakeLists.txt

Lines changed: 4 additions & 0 deletions
@@ -22,8 +22,12 @@ endif()
 set(_aten_ops__srcs
     "${EXECUTORCH_ROOT}/backends/cadence/hifi/operators/op_add.cpp"
     "${EXECUTORCH_ROOT}/backends/cadence/hifi/operators/op_div.cpp"
+    "${EXECUTORCH_ROOT}/backends/cadence/hifi/operators/op_maximum.cpp"
     "${EXECUTORCH_ROOT}/backends/cadence/hifi/operators/op_mean.cpp"
+    "${EXECUTORCH_ROOT}/backends/cadence/hifi/operators/op_minimum.cpp"
     "${EXECUTORCH_ROOT}/backends/cadence/hifi/operators/op_mul.cpp"
+    "${EXECUTORCH_ROOT}/backends/cadence/hifi/operators/op_pow.cpp"
+    "${EXECUTORCH_ROOT}/backends/cadence/hifi/operators/op_rsqrt.cpp"
     "${EXECUTORCH_ROOT}/backends/cadence/hifi/operators/op_sigmoid.cpp"
     "${EXECUTORCH_ROOT}/backends/cadence/hifi/operators/op_sub.cpp"
     "${EXECUTORCH_ROOT}/backends/cadence/hifi/operators/op_tanh.cpp"
backends/cadence/hifi/operators/op_maximum.cpp

Lines changed: 175 additions & 0 deletions
@@ -0,0 +1,175 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <executorch/backends/cadence/hifi/kernels/kernels.h>
+#include <executorch/kernels/portable/cpu/scalar_utils.h>
+#include <executorch/kernels/portable/cpu/util/broadcast_util.h>
+#include <executorch/kernels/portable/cpu/util/math_util.h>
+#include <executorch/runtime/kernel/kernel_includes.h>
+
+using exec_aten::ScalarType;
+using exec_aten::Tensor;
+using executorch::aten::RuntimeContext;
+using executorch::runtime::can_cast;
+using executorch::runtime::canCast;
+using executorch::runtime::CppTypeToScalarType;
+using executorch::runtime::promoteTypes;
+using torch::executor::apply_binary_elementwise_fn;
+using torch::executor::Error;
+using torch::executor::resize_to_broadcast_target_size;
+
+
+namespace cadence {
+namespace impl {
+namespace HiFi {
+namespace native {
+namespace {
+
+template <
+    bool can_cast,
+    typename CTYPE_A,
+    typename CTYPE_B,
+    typename CTYPE_IN,
+    typename CTYPE_OUT>
+struct MaximumInner;
+
+template <
+    typename CTYPE_A,
+    typename CTYPE_B,
+    typename CTYPE_IN,
+    typename CTYPE_OUT>
+struct MaximumInner<true, CTYPE_A, CTYPE_B, CTYPE_IN, CTYPE_OUT> {
+  static void run(const Tensor& a, const Tensor& b, Tensor& out) {
+    apply_binary_elementwise_fn<CTYPE_A, CTYPE_B, CTYPE_OUT>(
+        // NOLINTNEXTLINE(facebook-hte-ConstantArgumentPassByValue)
+        [](const CTYPE_A val_a, const CTYPE_B val_b) {
+          CTYPE_IN a_casted = static_cast<CTYPE_IN>(val_a);
+          CTYPE_IN b_casted = static_cast<CTYPE_IN>(val_b);
+          CTYPE_IN value =
+              torch::executor::native::utils::max_override(a_casted, b_casted);
+
+          return static_cast<CTYPE_OUT>(value);
+        },
+        a,
+        b,
+        out);
+  }
+};
+
+struct ReportCanCastBug {
+  static void run(const Tensor&, const Tensor&, Tensor&) {
+    ET_DCHECK_MSG(false, "BUG: canCast should have been checked above");
+  }
+};
+
+template <
+    typename CTYPE_A,
+    typename CTYPE_B,
+    typename CTYPE_IN,
+    typename CTYPE_OUT>
+struct MaximumInner<false, CTYPE_A, CTYPE_B, CTYPE_IN, CTYPE_OUT>
+    : public ReportCanCastBug {};
+
+} // namespace
+
+Tensor& maximum_out(
+    RuntimeContext& ctx,
+    const Tensor& a,
+    const Tensor& b,
+    Tensor& out) {
+  (void)ctx;
+
+  ET_KERNEL_CHECK(
+      ctx,
+      resize_to_broadcast_target_size(a, b, out) == Error::Ok,
+      InvalidArgument,
+      out);
+
+  constexpr int kNnlibMaxDim = 4; /* fallback if broadcast and dim > 4 */
+
+  ScalarType a_type = a.scalar_type();
+  ScalarType b_type = b.scalar_type();
+  ScalarType common_type = promoteTypes(a_type, b_type, /*half_to_float*/ true);
+  ScalarType out_type = out.scalar_type();
+
+  ET_KERNEL_CHECK(ctx, canCast(common_type, out_type), InvalidArgument, out);
+
+  bool optimized = true;
+  /* find broadcast */
+  bool a_is_broadcasted = !out.sizes().equals(a.sizes());
+  bool b_is_broadcasted = !out.sizes().equals(b.sizes());
+  bool broadcast = (a_is_broadcasted || b_is_broadcasted);
+
+  int max_dim = a.dim() > b.dim() ? a.dim() : b.dim();
+  max_dim = out.dim() > max_dim ? out.dim() : max_dim;
+
+  if ((a_type != ScalarType::Float) || (b_type != ScalarType::Float))
+    optimized = false;
+  if ((broadcast == true) && (max_dim > kNnlibMaxDim))
+    optimized = false;
+
+  if (optimized) {
+    float* a_data = a.mutable_data_ptr<float>();
+    float* b_data = b.mutable_data_ptr<float>();
+    float* out_data = out.mutable_data_ptr<float>();
+
+    if (broadcast == true) {
+      int out_shape[kNnlibMaxDim];
+      int inp1_shape[kNnlibMaxDim];
+      int inp2_shape[kNnlibMaxDim];
+
+      for (int i = 0; i < kNnlibMaxDim; i++) {
+        out_shape[i] = 1;
+        inp1_shape[i] = 1;
+        inp2_shape[i] = 1;
+      }
+
+      int off_o = kNnlibMaxDim - out.dim();
+      int off_a = kNnlibMaxDim - a.dim();
+      int off_b = kNnlibMaxDim - b.dim();
+
+      for (int i = 0; i < out.dim(); i++) {
+        out_shape[i + off_o] = out.size(i);
+      }
+
+      for (int i = 0; i < a.dim(); i++)
+        inp1_shape[i + off_a] = a.size(i);
+
+      for (int i = 0; i < b.dim(); i++)
+        inp2_shape[i + off_b] = b.size(i);
+
+      xa_nn_elm_maximum_broadcast_4D_f32xf32_f32(
+          out_data, out_shape, a_data, inp1_shape, b_data, inp2_shape);
+    } else {
+      xa_nn_elm_maximum_f32xf32_f32(out_data, a_data, b_data, out.numel());
+    }
+    return out;
+  }
+  ET_SWITCH_REALHB_TYPES(a_type, ctx, "maximum.out", CTYPE_A, [&]() {
+    ET_SWITCH_REALHB_TYPES(b_type, ctx, "maximum.out", CTYPE_B, [&]() {
+      using CTYPE_IN = typename torch::executor::
+          promote_types<CTYPE_A, CTYPE_B, /*half_to_float*/ true>::type;
+      ET_DCHECK(CppTypeToScalarType<CTYPE_IN>::value == common_type);
+      ET_SWITCH_REALHB_TYPES(out_type, ctx, "maximum.out", CTYPE_OUT, [&]() {
+        MaximumInner<
+            can_cast<CTYPE_IN, CTYPE_OUT>::value,
+            CTYPE_A,
+            CTYPE_B,
+            CTYPE_IN,
+            CTYPE_OUT>::run(a, b, out);
+      });
+    });
+  });
+
+  return out;
+}
+
+} // namespace native
+} // namespace HiFi
+} // namespace impl
+} // namespace cadence
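
The broadcast path in maximum_out right-aligns each tensor's sizes into fixed 4-element shape arrays padded with leading 1s before calling the 4D NNLIB kernel. A standalone sketch of that padding convention (the helper name and example shapes are illustrative, not part of the diff):

    #include <algorithm>

    // Pad a tensor's sizes into a 4D shape, right-aligned and filled with 1s,
    // mirroring the shape handling in maximum_out above. A rank-2 tensor of
    // sizes {3, 4} becomes {1, 1, 3, 4}; a rank-1 {4} becomes {1, 1, 1, 4}.
    void pad_shape_to_4d(const int* sizes, int rank, int (&padded)[4]) {
      std::fill(padded, padded + 4, 1);  // leading dims default to 1
      const int offset = 4 - rank;       // right-align the real dimensions
      for (int i = 0; i < rank; i++)
        padded[i + offset] = sizes[i];
    }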
