Commit 1a6ba2e

manuelcandales authored and facebook-github-bot committed
Reduce build size of op_cumsum (#6021)
Summary:
Pull Request resolved: #6021

Build size: 98 K -> 9 K

ghstack-source-id: 246985127
exported-using-ghexport

Reviewed By: malfet, swolchok

Differential Revision: D63997235

fbshipit-source-id: c78379335ea51d84891a0ed414944cdaca3695c7
1 parent d989680 commit 1a6ba2e
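The likely mechanism behind the size win, as the diff below shows: the old dispatch instantiated cumsum_tensors<CTYPE_SELF, CTYPE_OUT> for every (input dtype, output dtype) pair, so the full kernel body was compiled once per pair. The new dispatch instantiates the kernel only per output dtype and converts each input element through a load_self function pointer, so each input dtype costs one tiny load routine instead of a full kernel copy. With N input and M output dtypes, N×M kernel bodies become M kernels plus N loaders, which is consistent with the reported 98 K -> 9 K drop. A minimal standalone sketch of the pattern (load_as and cumsum_1d are illustrative names, not ExecuTorch APIs):

    #include <cstddef>
    #include <cstdint>
    #include <cstdio>

    // One small load routine per input dtype, converting to the compute type.
    // Adding a new input dtype adds only one of these, not a new kernel body.
    template <typename TO, typename FROM>
    TO load_as(const void* p) {
      return static_cast<TO>(*static_cast<const FROM*>(p));
    }

    // Kernel templated only on the output/compute type; the input element is
    // fetched through `load`, so a single instantiation serves all input dtypes.
    template <typename CTYPE_OUT>
    void cumsum_1d(
        const void* in,
        std::size_t elem_size,
        CTYPE_OUT (*load)(const void*),
        CTYPE_OUT* out,
        std::size_t n) {
      const char* base = static_cast<const char*>(in);
      CTYPE_OUT acc = 0;
      for (std::size_t i = 0; i < n; ++i) {
        acc += load(base + i * elem_size); // byte-offset indexing, as in the diff
        out[i] = acc;
      }
    }

    int main() {
      const std::int32_t in[4] = {1, 2, 3, 4};
      float out[4];
      // int32 input accumulated into float output via the int32->float loader.
      cumsum_1d<float>(
          in, sizeof(std::int32_t), &load_as<float, std::int32_t>, out, 4);
      for (float v : out) {
        std::printf("%g ", v); // prints: 1 3 6 10
      }
      std::printf("\n");
      return 0;
    }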

File tree

2 files changed: +23 -13 lines changed


kernels/portable/cpu/op_cumsum.cpp

Lines changed: 22 additions & 13 deletions
@@ -6,6 +6,7 @@
  * LICENSE file in the root directory of this source tree.
  */

+#include <executorch/kernels/portable/cpu/util/dtype_util.h>
 #include <executorch/kernels/portable/cpu/util/kernel_ops_util.h>
 #include <executorch/runtime/kernel/kernel_includes.h>
 #include <executorch/runtime/platform/assert.h>
@@ -34,17 +35,22 @@ namespace {
  * the memory level, thereby increasing the speed of memory IO as
  * well as reducing the number of cache misses.
  */
-template <typename CTYPE_IN, typename CTYPE_OUT>
-void cumsum_tensors(const Tensor& self, int64_t dim, Tensor& out) {
+template <typename CTYPE_OUT, typename LoadFn = CTYPE_OUT (*)(const void*)>
+void cumsum_tensors(
+    const Tensor& self,
+    LoadFn load_self,
+    int64_t dim,
+    Tensor& out) {
   if (self.numel() == 0) {
     return;
   }

-  const CTYPE_IN* input_data_base = self.const_data_ptr<CTYPE_IN>();
+  const char* const input_data_base =
+      reinterpret_cast<const char*>(self.const_data_ptr());
   CTYPE_OUT* output_data_base = out.mutable_data_ptr<CTYPE_OUT>();

   if (self.dim() == 0) {
-    output_data_base[0] = input_data_base[0];
+    output_data_base[0] = load_self(&input_data_base[0]);
     return;
   }

@@ -57,15 +63,16 @@ void cumsum_tensors(const Tensor& self, int64_t dim, Tensor& out) {

   for (size_t idx = 0; idx < trailing_dims; idx++) {
     output_data_base[start_loc + idx] =
-        static_cast<CTYPE_OUT>(input_data_base[start_loc + idx]);
+        load_self(&input_data_base[(start_loc + idx) * self.element_size()]);
   }

   for (size_t j = 1; j < dim_size; j++) {
     size_t cur_round_base = start_loc + j * trailing_dims;
     size_t prev_round_base = start_loc + (j - 1) * trailing_dims;
     for (size_t idx = 0; idx < trailing_dims; idx++) {
       output_data_base[cur_round_base + idx] =
-          static_cast<CTYPE_OUT>(input_data_base[cur_round_base + idx]) +
+          load_self(&input_data_base
+                        [(cur_round_base + idx) * self.element_size()]) +
           output_data_base[prev_round_base + idx];
     }
   }
@@ -101,13 +108,15 @@ Tensor& cumsum_out(

   dim = (self.dim() == 0) ? 0 : dim < 0 ? dim + self.dim() : dim;

-  ET_SWITCH_REAL_TYPES_AND(
-      Bool, self.scalar_type(), ctx, "cumsum", CTYPE_SELF, [&] {
-        ET_SWITCH_REAL_TYPES_AND(
-            Bool, out.scalar_type(), ctx, "cumsum", CTYPE_OUT, [&] {
-              cumsum_tensors<CTYPE_SELF, CTYPE_OUT>(self, dim, out);
-            });
-      });
+  // @lint-ignore CLANGTIDY facebook-hte-CArray
+  static constexpr const char op_name[] = "cumsum.out";
+
+  ET_SWITCH_REALHBBF16_TYPES(out.scalar_type(), ctx, op_name, CTYPE_OUT, [&] {
+    const auto load_self =
+        utils::internal::get_load_to_common_fn<CTYPE_OUT, op_name>(
+            self, utils::SupportedTensorDtypes::REALHBBF16);
+    cumsum_tensors<CTYPE_OUT>(self, load_self, dim, out);
+  });

   return out;
 }

shim/xplat/executorch/kernels/portable/op_registration_util.bzl

Lines changed: 1 addition & 0 deletions
@@ -447,6 +447,7 @@ ATEN_OPS = (
     op_target(
         name = "op_cumsum",
         deps = [
+            "//executorch/kernels/portable/cpu/util:dtype_util",
            "//executorch/runtime/core/exec_aten/util:scalar_type_util",
            "//executorch/runtime/core/exec_aten/util:tensor_util",
            "//executorch/kernels/portable/cpu/util:kernel_ops_util",
