6
6
* LICENSE file in the root directory of this source tree.
7
7
*/
8
8
9
+ #include < executorch/kernels/portable/cpu/util/dtype_util.h>
9
10
#include < executorch/kernels/portable/cpu/util/kernel_ops_util.h>
10
11
#include < executorch/runtime/kernel/kernel_includes.h>
11
12
#include < executorch/runtime/platform/assert.h>
@@ -34,17 +35,22 @@ namespace {
34
35
* the memory level, thereby increasing the speed of memory IO as
35
36
* well as reducing the number of cache misses.
36
37
*/
37
- template <typename CTYPE_IN, typename CTYPE_OUT>
38
- void cumsum_tensors (const Tensor& self, int64_t dim, Tensor& out) {
38
+ template <typename CTYPE_OUT, typename LoadFn = CTYPE_OUT (*)(const void *)>
39
+ void cumsum_tensors (
40
+ const Tensor& self,
41
+ LoadFn load_self,
42
+ int64_t dim,
43
+ Tensor& out) {
39
44
if (self.numel () == 0 ) {
40
45
return ;
41
46
}
42
47
43
- const CTYPE_IN* input_data_base = self.const_data_ptr <CTYPE_IN>();
48
+ const char * const input_data_base =
49
+ reinterpret_cast <const char *>(self.const_data_ptr ());
44
50
CTYPE_OUT* output_data_base = out.mutable_data_ptr <CTYPE_OUT>();
45
51
46
52
if (self.dim () == 0 ) {
47
- output_data_base[0 ] = input_data_base[0 ];
53
+ output_data_base[0 ] = load_self (& input_data_base[0 ]) ;
48
54
return ;
49
55
}
50
56
@@ -57,15 +63,16 @@ void cumsum_tensors(const Tensor& self, int64_t dim, Tensor& out) {
57
63
58
64
for (size_t idx = 0 ; idx < trailing_dims; idx++) {
59
65
output_data_base[start_loc + idx] =
60
- static_cast <CTYPE_OUT>( input_data_base[start_loc + idx]);
66
+ load_self (& input_data_base[( start_loc + idx) * self. element_size () ]);
61
67
}
62
68
63
69
for (size_t j = 1 ; j < dim_size; j++) {
64
70
size_t cur_round_base = start_loc + j * trailing_dims;
65
71
size_t prev_round_base = start_loc + (j - 1 ) * trailing_dims;
66
72
for (size_t idx = 0 ; idx < trailing_dims; idx++) {
67
73
output_data_base[cur_round_base + idx] =
68
- static_cast <CTYPE_OUT>(input_data_base[cur_round_base + idx]) +
74
+ load_self (&input_data_base
75
+ [(cur_round_base + idx) * self.element_size ()]) +
69
76
output_data_base[prev_round_base + idx];
70
77
}
71
78
}
@@ -101,13 +108,15 @@ Tensor& cumsum_out(
101
108
102
109
dim = (self.dim () == 0 ) ? 0 : dim < 0 ? dim + self.dim () : dim;
103
110
104
- ET_SWITCH_REAL_TYPES_AND (
105
- Bool, self.scalar_type (), ctx, " cumsum" , CTYPE_SELF, [&] {
106
- ET_SWITCH_REAL_TYPES_AND (
107
- Bool, out.scalar_type (), ctx, " cumsum" , CTYPE_OUT, [&] {
108
- cumsum_tensors<CTYPE_SELF, CTYPE_OUT>(self, dim, out);
109
- });
110
- });
111
+ // @lint-ignore CLANGTIDY facebook-hte-CArray
112
+ static constexpr const char op_name[] = " cumsum.out" ;
113
+
114
+ ET_SWITCH_REALHBBF16_TYPES (out.scalar_type (), ctx, op_name, CTYPE_OUT, [&] {
115
+ const auto load_self =
116
+ utils::internal::get_load_to_common_fn<CTYPE_OUT, op_name>(
117
+ self, utils::SupportedTensorDtypes::REALHBBF16);
118
+ cumsum_tensors<CTYPE_OUT>(self, load_self, dim, out);
119
+ });
111
120
112
121
return out;
113
122
}
0 commit comments