* LICENSE file in the root directory of this source tree.
*/

- #include <cinttypes>
- #include <cstdint>
#include <cstring>

- #include <executorch/kernels/portable/cpu/util/index_util.h>
+ #include <executorch/kernels/portable/cpu/util/advanced_index_util.h>
+ #include <executorch/kernels/portable/cpu/util/broadcast_util.h>
#include <executorch/runtime/kernel/kernel_includes.h>

namespace torch {
@@ -19,133 +18,132 @@ namespace native {

using Tensor = exec_aten::Tensor;

- namespace {
-
- template <typename CTYPE>
- void index_put_out_impl_mask(
-     const Tensor& in,
-     exec_aten::ArrayRef<exec_aten::optional<Tensor>> indices,
-     const Tensor& values,
-     const bool accum,
-     Tensor& out) {
-   // Data pointers
-   const CTYPE* const in_data = in.const_data_ptr<CTYPE>();
-   CTYPE* const out_data = out.mutable_data_ptr<CTYPE>();
-
-   const CTYPE* val_data = values.const_data_ptr<CTYPE>();
-
-   // To start, copy the in into the output
-   memcpy(out_data, in_data, in.nbytes());
-
-   const Tensor& mask = indices[0].value();
-   const bool* const mask_ptr = mask.const_data_ptr<bool>();
-   size_t count = 0;
-   for (int i = 0; i < mask.numel(); ++i) {
-     if (mask_ptr[i]) {
-       if (accum) {
-         out_data[i] += val_data[count];
-       } else {
-         out_data[i] = val_data[count];
-       }
-       if (values.numel() > 1) {
-         count++;
-       }
-     }
-   }
- }
-
- template <typename CTYPE>
- void index_put_out_impl_list(
-     const Tensor& in,
-     exec_aten::ArrayRef<exec_aten::optional<Tensor>> indices,
-     const Tensor& values,
-     const bool accum,
-     Tensor& out) {
-   // Data pointers
-   const CTYPE* const in_data = in.const_data_ptr<CTYPE>();
-   CTYPE* const out_data = out.mutable_data_ptr<CTYPE>();
-
-   const CTYPE* val = values.const_data_ptr<CTYPE>();
-
-   // To start, copy the in into the output
-   memcpy(out_data, in_data, in.nbytes());
-
-   size_t num_idx_queries = get_indices_broadcast_len(indices);
-   for (size_t idx = 0; idx < num_idx_queries; idx++) {
-     const CTYPE* src = in_data;
-     CTYPE* dst = out_data;
-
-     // For each index query, align the src and dst pointers to the position
-     // described by the query.
-     size_t offset = get_index_query_pos_offset(idx, in, indices);
-     src += offset;
-     dst += offset;
-
-     // Calculate the region of data to copy for this query.
-     // For example, a 2x4x3x5 tensor indexing at [1, 1, :, :] should copy 15
-     // elements.
-     size_t copy_len = getTrailingDims(in, indices.size() - 1);
-
-     // If values only contains 1 element, it needs to be broadcasted.
-     if (values.numel() == 1) {
-       CTYPE value = *val;
-
-       for (size_t i = 0; i < copy_len; ++i) {
-         if (accum) {
-           dst[i] += value;
-         } else {
-           dst[i] = value;
-         }
-       }
-     }
-     // General case.
-     else {
-       if (accum) {
-         for (size_t i = 0; i < copy_len; ++i) {
-           dst[i] = src[i] + val[i];
-         }
-         val += copy_len;
-       } else {
-         size_t copy_size = copy_len * sizeof(CTYPE);
-         memcpy(dst, val, copy_size);
-         val += copy_len;
-       }
-     }
-   }
- }
-
- } // namespace
-
Tensor& index_put_out(
    RuntimeContext& ctx,
    const Tensor& in,
    exec_aten::ArrayRef<exec_aten::optional<Tensor>> indices,
    const Tensor& values,
    const bool accumulate,
    Tensor& out) {
+   (void)ctx;
+
  ET_KERNEL_CHECK(
-       ctx,
-       check_index_put_args(in, indices, values, out),
-       InvalidArgument,
-       out);
+       ctx, check_index_args(in, indices, out), InvalidArgument, out);
+
+   ET_KERNEL_CHECK(
+       ctx, tensors_have_same_dtype(in, values), InvalidArgument, out);

-   if (indices.empty() || in.numel() == 0) {
+   ScalarType in_type = in.scalar_type();
+   size_t block_count = count_index_blocks(indices);
+
+   // If the indices list is empty or all indices are null, then the operation
+   // is performed over the entire input tensor. So, this is equivalent to
+   // out = values when accumulate is false. Otherwise, the operation is
+   // out = in + values when accumulate is true.
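+   // For example (illustrative): with in = [1, 2, 3] and values = 4,
+   // accumulate = false yields out = [4, 4, 4], while accumulate = true
+   // yields out = [5, 6, 7].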
+   if (block_count == 0) {
    ET_KERNEL_CHECK(
        ctx, resize_tensor(out, in.sizes()) == Error::Ok, InvalidArgument, out);
-     memcpy(
-         out.mutable_data_ptr<char>(), in.const_data_ptr<char>(), in.nbytes());
+
+     // Check that the values tensor can be broadcast to out
+     ET_KERNEL_CHECK(
+         ctx, tensor_is_broadcastable_to(values, out), InvalidArgument, out);
+
+     ET_SWITCH_REAL_TYPES_AND(Bool, in_type, ctx, __func__, CTYPE, [&]() {
+       apply_binary_elementwise_fn<CTYPE, CTYPE, CTYPE>(
+           [accumulate](const CTYPE val_in, const CTYPE val) {
+             return accumulate ? val_in + val : val;
+           },
+           in,
+           values,
+           out);
+     });
    return out;
  }

+   // The index output shape depends on whether all the non-null indices are
+   // adjacent or not.
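+   // For example (illustrative, where `ix` denotes an integer index tensor):
+   // for a [2, 3, 4, 5] input, indices of the form {null, ix, ix, null}
+   // (adjacent) place the broadcast index dims where the indexed dims were,
+   // while {ix, null, ix, null} (non-adjacent) move them to the front of the
+   // output shape.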
+   bool adjacent = (block_count == 1);
+
+   // Compute the expected index output shape.
+   Tensor::SizesType x_sizes[kTensorDimensionLimit];
+   size_t x_dim = 0;
+   ET_KERNEL_CHECK(
+       ctx,
+       get_index_out_target_size(in, indices, adjacent, x_sizes, &x_dim),
+       InvalidArgument,
+       out);
+
+   // Check that the values tensor can be broadcast to the indexing result
+   ET_KERNEL_CHECK(
+       ctx,
+       tensor_is_broadcastable_to(values.sizes(), {x_sizes, x_dim}),
+       InvalidArgument,
+       out);
+
  ET_KERNEL_CHECK(
      ctx, resize_tensor(out, in.sizes()) == Error::Ok, InvalidArgument, out);

-   ScalarType dtype = in.scalar_type();
-   ET_SWITCH_REAL_TYPES_AND(Bool, dtype, ctx, "index_put", CTYPE, [&]() {
-     if (is_index_mask(in, indices)) {
-       index_put_out_impl_mask<CTYPE>(in, indices, values, accumulate, out);
-     } else {
-       index_put_out_impl_list<CTYPE>(in, indices, values, accumulate, out);
+   // No further action if the input is empty
+   if (in.numel() == 0) {
+     return out;
+   }
+
+   // To start, copy the input data into the out tensor
+   memcpy(out.mutable_data_ptr<char>(), in.const_data_ptr<char>(), in.nbytes());
+
+   // In what follows, `x = in[indices]`. This tensor is implicit, and it would
+   // be much easier to be able to allocate memory, and then call index.Tensor
+   // to compute `x`. But since we can't do that, we have to keep track of its
+   // shape, number of dimensions, number of elements, and use it to translate
+   // coordinates from `x` to `in`.
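+   // For example (illustrative): for a [3, 4] input with indices
+   // {null, tensor([0, 2])}, the implicit `x` has shape [3, 2], and the `x`
+   // coordinate (i, j) maps to the `in` coordinate (i, indices[1][j]).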
+
+   // Compute the dim_map and ix_map needed for `x -> in` coordinate translation
+   int32_t dim_map[kTensorDimensionLimit];
+   int32_t ix_map[kTensorDimensionLimit];
+   size_t start = 0;
+
+   if (adjacent) {
+     start = get_num_leading_null_indices(indices);
+   }
+   size_t bc_ndim = get_indices_broadcast_ndim(indices);
+   compute_dim_map(in, indices, dim_map, block_count == 1);
+   compute_index_map(in, indices, ix_map);
+
+   // Compute the number of elements in the indexed space
+   size_t x_numel = 1;
+   for (size_t i = 0; i < x_dim; i++) {
+     x_numel *= x_sizes[i];
+   }
+
+   ET_SWITCH_REAL_TYPES_AND(Bool, in_type, ctx, __func__, CTYPE, [&]() {
+     const CTYPE* const in_data = in.const_data_ptr<CTYPE>();
+     const CTYPE* const values_data = values.const_data_ptr<CTYPE>();
+     CTYPE* const out_data = out.mutable_data_ptr<CTYPE>();
+
+     for (auto x_ix = 0; x_ix < x_numel; x_ix++) {
+       size_t in_ix = 0;
+
+       size_t x_coord[kTensorDimensionLimit];
+       delinearize_index(x_ix, {x_sizes, x_dim}, x_coord, kTensorDimensionLimit);
+
+       size_t in_coord[kTensorDimensionLimit];
+
+       ET_KERNEL_CHECK(
+           ctx,
+           get_in_coord(
+               in, indices, start, bc_ndim, dim_map, ix_map, x_coord, in_coord),
+           InvalidArgument,
+           out);
+
+       in_ix = coordinateToIndex(in, in_coord);
+
+       // Broadcast values
+       size_t val_ix = linearize_access_indexes(x_coord, x_dim, values);
+       if (accumulate) {
+         out_data[in_ix] += values_data[val_ix];
+       } else {
+         out_data[in_ix] = values_data[val_ix];
+       }
    }
  });