
Commit 017a718

manuelcandales authored and facebook-github-bot committed
Rewrite index_put (#724)
Summary:
Pull Request resolved: #724

Rewrite `index_put` to handle all the advanced indexing functionality.

ghstack-source-id: 203418806
exported-using-ghexport

Reviewed By: SS-JIA

Differential Revision: D50031048

fbshipit-source-id: 3cfac559ccec2d5633e804c26b237ae9ae3d831e
1 parent b6c2f01 commit 017a718
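
For context, the "advanced indexing" handled here is PyTorch-style integer-tensor indexing: index tensors select arbitrary coordinates of the input, and `index_put` either overwrites or accumulates a value at each selected coordinate. Below is a minimal standalone sketch of those semantics over a flat row-major buffer; it is plain C++ with illustrative names, not the ExecuTorch tensor API.

```cpp
#include <cstdio>
#include <vector>

int main() {
  // A 3x4 "tensor" stored row-major, zero-initialized.
  const size_t cols = 4;
  std::vector<float> in(3 * cols, 0.f);

  // Two index tensors matched elementwise: coordinates (0,1), (2,3), (2,3).
  std::vector<size_t> row_ix = {0, 2, 2};
  std::vector<size_t> col_ix = {1, 3, 3};
  std::vector<float> values = {10.f, 20.f, 5.f};
  const bool accumulate = true;

  for (size_t k = 0; k < values.size(); ++k) {
    float& slot = in[row_ix[k] * cols + col_ix[k]];
    slot = accumulate ? slot + values[k] : values[k];
  }

  // With accumulate=true, the repeated coordinate (2,3) receives 20 + 5 = 25.
  printf("in[0][1] = %g, in[2][3] = %g\n", in[0 * cols + 1], in[2 * cols + 3]);
  return 0;
}
```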

File tree: 5 files changed (+148, −125 lines)

kernels/portable/cpu/op_index_put.cpp

Lines changed: 110 additions & 113 deletions
```diff
@@ -6,11 +6,10 @@
  * LICENSE file in the root directory of this source tree.
  */
 
-#include <cinttypes>
-#include <cstdint>
 #include <cstring>
 
-#include <executorch/kernels/portable/cpu/util/index_util.h>
+#include <executorch/kernels/portable/cpu/util/advanced_index_util.h>
+#include <executorch/kernels/portable/cpu/util/broadcast_util.h>
 #include <executorch/runtime/kernel/kernel_includes.h>
 
 namespace torch {
@@ -19,133 +18,131 @@ namespace native {
 
 using Tensor = exec_aten::Tensor;
 
-namespace {
-
-template <typename CTYPE>
-void index_put_out_impl_mask(
-    const Tensor& in,
-    exec_aten::ArrayRef<exec_aten::optional<Tensor>> indices,
-    const Tensor& values,
-    const bool accum,
-    Tensor& out) {
-  // Data pointers
-  const CTYPE* const in_data = in.const_data_ptr<CTYPE>();
-  CTYPE* const out_data = out.mutable_data_ptr<CTYPE>();
-
-  const CTYPE* val_data = values.const_data_ptr<CTYPE>();
-
-  // To start, copy the in into the output
-  memcpy(out_data, in_data, in.nbytes());
-
-  const Tensor& mask = indices[0].value();
-  const bool* const mask_ptr = mask.const_data_ptr<bool>();
-  size_t count = 0;
-  for (int i = 0; i < mask.numel(); ++i) {
-    if (mask_ptr[i]) {
-      if (accum) {
-        out_data[i] += val_data[count];
-      } else {
-        out_data[i] = val_data[count];
-      }
-      if (values.numel() > 1) {
-        count++;
-      }
-    }
-  }
-}
-
-template <typename CTYPE>
-void index_put_out_impl_list(
-    const Tensor& in,
-    exec_aten::ArrayRef<exec_aten::optional<Tensor>> indices,
-    const Tensor& values,
-    const bool accum,
-    Tensor& out) {
-  // Data pointers
-  const CTYPE* const in_data = in.const_data_ptr<CTYPE>();
-  CTYPE* const out_data = out.mutable_data_ptr<CTYPE>();
-
-  const CTYPE* val = values.const_data_ptr<CTYPE>();
-
-  // To start, copy the in into the output
-  memcpy(out_data, in_data, in.nbytes());
-
-  size_t num_idx_queries = get_indices_broadcast_len(indices);
-  for (size_t idx = 0; idx < num_idx_queries; idx++) {
-    const CTYPE* src = in_data;
-    CTYPE* dst = out_data;
-
-    // For each index query, align the src and dst pointers to the position
-    // described by the query.
-    size_t offset = get_index_query_pos_offset(idx, in, indices);
-    src += offset;
-    dst += offset;
-
-    // Calculate the region of data to copy for this query.
-    // For example, a 2x4x3x5 tensor indexing at [1, 1, :, :] should copy 15
-    // elements.
-    size_t copy_len = getTrailingDims(in, indices.size() - 1);
-
-    // If values only contains 1 element, it needs to be broadcasted.
-    if (values.numel() == 1) {
-      CTYPE value = *val;
-
-      for (size_t i = 0; i < copy_len; ++i) {
-        if (accum) {
-          dst[i] += value;
-        } else {
-          dst[i] = value;
-        }
-      }
-    }
-    // General case.
-    else {
-      if (accum) {
-        for (size_t i = 0; i < copy_len; ++i) {
-          dst[i] = src[i] + val[i];
-        }
-        val += copy_len;
-      } else {
-        size_t copy_size = copy_len * sizeof(CTYPE);
-        memcpy(dst, val, copy_size);
-        val += copy_len;
-      }
-    }
-  }
-}
-
-} // namespace
-
 Tensor& index_put_out(
     RuntimeContext& ctx,
     const Tensor& in,
     exec_aten::ArrayRef<exec_aten::optional<Tensor>> indices,
     const Tensor& values,
     const bool accumulate,
     Tensor& out) {
+  (void)ctx;
+
   ET_KERNEL_CHECK(
-      ctx,
-      check_index_put_args(in, indices, values, out),
-      InvalidArgument,
-      out);
+      ctx, check_index_args(in, indices, out), InvalidArgument, out);
+
+  ET_KERNEL_CHECK(
+      ctx, tensors_have_same_dtype(in, values), InvalidArgument, out);
 
-  if (indices.empty() || in.numel() == 0) {
+  ScalarType in_type = in.scalar_type();
+  size_t block_count = count_index_blocks(indices);
+
+  // If the indices list is empty or all indices are null, then the operation
+  // is performed over the entire input tensor. So, this is equivalent to
+  // out = values when accumulate is false. Otherwise, the operation is
+  // out = in + values where accumulate is true.
+  if (block_count == 0) {
     ET_KERNEL_CHECK(
        ctx, resize_tensor(out, in.sizes()) == Error::Ok, InvalidArgument, out);
-    memcpy(
-        out.mutable_data_ptr<char>(), in.const_data_ptr<char>(), in.nbytes());
+
+    // Check that the values tensor can be broadcasted to out
+    ET_KERNEL_CHECK(
+        ctx, tensor_is_broadcastable_to(values, out), InvalidArgument, out);
+
+    ET_SWITCH_REAL_TYPES_AND(Bool, in_type, ctx, __func__, CTYPE, [&]() {
+      apply_binary_elementwise_fn<CTYPE, CTYPE, CTYPE>(
+          [accumulate](const CTYPE val_in, const CTYPE val) {
+            return accumulate ? val_in + val : val;
+          },
+          in,
+          values,
+          out);
+    });
     return out;
   }
 
+  // The index output shape depends on whether all the non-null indices are
+  // adjacent or not.
+  bool adjacent = (block_count == 1);
+
+  // Compute the expected index output shape.
+  Tensor::SizesType x_sizes[kTensorDimensionLimit];
+  size_t x_dim = 0;
+  ET_KERNEL_CHECK(
+      ctx,
+      get_index_out_target_size(in, indices, adjacent, x_sizes, &x_dim),
+      InvalidArgument,
+      out);
+
+  // Check that the values tensor can be broadcasted to the indexing result
+  ET_KERNEL_CHECK(
+      ctx,
+      tensor_is_broadcastable_to(values.sizes(), {x_sizes, x_dim}),
+      InvalidArgument,
+      out);
+
   ET_KERNEL_CHECK(
       ctx, resize_tensor(out, in.sizes()) == Error::Ok, InvalidArgument, out);
 
-  ScalarType dtype = in.scalar_type();
-  ET_SWITCH_REAL_TYPES_AND(Bool, dtype, ctx, "index_put", CTYPE, [&]() {
-    if (is_index_mask(in, indices)) {
-      index_put_out_impl_mask<CTYPE>(in, indices, values, accumulate, out);
-    } else {
-      index_put_out_impl_list<CTYPE>(in, indices, values, accumulate, out);
+  // No further action if the input is empty
+  if (in.numel() == 0) {
+    return out;
+  }
+
+  // To start, copy the input data into the out tensor
+  memcpy(out.mutable_data_ptr<char>(), in.const_data_ptr<char>(), in.nbytes());
+
+  // In what follows, `x = in[indices]`. This tensor is implicit, and it would
+  // be much easier to be able to allocate memory, and then call index.Tensor
+  // to compute `x`. But since we can't do that, we have to keep track of its
+  // shape, number of dimensions, number of elements, and use it to translate
+  // coordinates from `x` to `in`.
+
+  // Compute the dim_map and ix_map needed for `x -> in` coordinate translation
+  int32_t dim_map[kTensorDimensionLimit];
+  int32_t ix_map[kTensorDimensionLimit];
+  size_t start = 0;
+
+  if (adjacent) {
+    start = get_num_leading_null_indices(indices);
+  }
+  size_t bc_ndim = get_indices_broadcast_ndim(indices);
+  compute_dim_map(in, indices, dim_map, block_count == 1);
+  compute_index_map(in, indices, ix_map);
+
+  // Compute the number of elements in the indexed space
+  size_t x_numel = 1;
+  for (size_t i = 0; i < x_dim; i++) {
+    x_numel *= x_sizes[i];
+  }
+
+  ET_SWITCH_REAL_TYPES_AND(Bool, in_type, ctx, __func__, CTYPE, [&]() {
+    const CTYPE* const values_data = values.const_data_ptr<CTYPE>();
+    CTYPE* const out_data = out.mutable_data_ptr<CTYPE>();
+
+    for (auto x_ix = 0; x_ix < x_numel; x_ix++) {
+      size_t in_ix = 0;
+
+      size_t x_coord[kTensorDimensionLimit];
+      delinearize_index(x_ix, {x_sizes, x_dim}, x_coord, kTensorDimensionLimit);
+
+      size_t in_coord[kTensorDimensionLimit];
+
+      ET_KERNEL_CHECK(
+          ctx,
+          get_in_coord(
+              in, indices, start, bc_ndim, dim_map, ix_map, x_coord, in_coord),
+          InvalidArgument,
+          out);
+
+      in_ix = coordinateToIndex(in, in_coord);
+
+      // Broadcast values
+      size_t val_ix = linearize_access_indexes(x_coord, x_dim, values);
+      if (accumulate) {
+        out_data[in_ix] += values_data[val_ix];
+      } else {
+        out_data[in_ix] = values_data[val_ix];
+      }
     }
   });
 
```
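
The key idea in the rewrite, per the comments above: the indexed tensor `x = in[indices]` is never materialized. Each flat position of `x` is delinearized into an `x` coordinate, translated into an `in` coordinate, and the value is written or accumulated there. A simplified standalone model of that translation, assuming a single index tensor on dim 0 and no broadcasting (the `delinearize` helper below is hypothetical, not the kernel's own):

```cpp
#include <cstdio>
#include <vector>

// Row-major delinearize: flat index -> per-dimension coordinates.
void delinearize(size_t flat, const std::vector<size_t>& shape, size_t* coord) {
  for (size_t i = shape.size(); i-- > 0;) {
    coord[i] = flat % shape[i];
    flat /= shape[i];
  }
}

int main() {
  // in: shape (4, 3), zero-initialized. indices = ({2, 0},), i.e. dim 0 is
  // indexed by the tensor {2, 0}, so the implicit x = in[indices] is (2, 3).
  std::vector<size_t> in_shape = {4, 3};
  std::vector<float> in(4 * 3, 0.f);
  std::vector<size_t> ix = {2, 0};                 // index tensor for dim 0
  std::vector<size_t> x_shape = {2, 3};            // shape of the implicit x
  std::vector<float> values = {1, 2, 3, 4, 5, 6};  // one value per x position

  for (size_t x_flat = 0; x_flat < 6; ++x_flat) {
    size_t x_coord[2];
    delinearize(x_flat, x_shape, x_coord);
    // Translate x -> in: dim 0 goes through the index tensor, dim 1 is copied.
    size_t in_coord[2] = {ix[x_coord[0]], x_coord[1]};
    in[in_coord[0] * in_shape[1] + in_coord[1]] = values[x_flat];  // overwrite
  }

  // Row 2 of in is now {1, 2, 3} and row 0 is {4, 5, 6}.
  printf("in[2][0] = %g, in[0][0] = %g\n", in[2 * 3 + 0], in[0 * 3 + 0]);
  return 0;
}
```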

kernels/portable/cpu/targets.bzl

Lines changed: 2 additions & 1 deletion
```diff
@@ -381,7 +381,8 @@ _ATEN_OPS = (
     op_target(
         name = "op_index_put",
         deps = [
-            "//executorch/kernels/portable/cpu/util:index_util",
+            "//executorch/kernels/portable/cpu/util:advanced_index_util",
+            "//executorch/kernels/portable/cpu/util:broadcast_util",
        ],
    ),
    op_target(
```

kernels/portable/cpu/util/broadcast_util.cpp

Lines changed: 13 additions & 5 deletions
```diff
@@ -254,18 +254,26 @@ void get_broadcast_target_size(
 
 void delinearize_index(
     size_t linear_index,
-    const Tensor& t,
+    exec_aten::ArrayRef<Tensor::SizesType> shape,
     size_t* out_indexes,
     const size_t out_indexes_len) {
-  ET_CHECK(t.dim() <= out_indexes_len);
-  for (auto i = 0; i < t.dim(); ++i) {
-    auto dim = t.dim() - 1 - i;
-    auto dim_size = t.size(dim);
+  ET_CHECK(shape.size() <= out_indexes_len);
+  for (auto i = 0; i < shape.size(); ++i) {
+    auto dim = shape.size() - 1 - i;
+    auto dim_size = shape[dim];
     out_indexes[dim] = linear_index % dim_size;
     linear_index /= dim_size;
   }
 }
 
+void delinearize_index(
+    size_t linear_index,
+    const Tensor& t,
+    size_t* out_indexes,
+    const size_t out_indexes_len) {
+  delinearize_index(linear_index, t.sizes(), out_indexes, out_indexes_len);
+}
+
 size_t linearize_access_indexes(
     ArrayRef<size_t> indexes_broadcast_to,
     ssize_t broadcast_to_ndim,
```
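
As a sanity check on the modulo/divide arithmetic in the new shape-based overload: delinearizing a flat row-major index and then re-linearizing the resulting coordinates should recover the original value. A quick standalone sketch in plain C++ (no ExecuTorch headers; the inner loop mirrors the one in `delinearize_index`):

```cpp
#include <cassert>
#include <cstddef>

int main() {
  const size_t shape[3] = {2, 4, 3};  // stands in for the ArrayRef shape
  for (size_t flat = 0; flat < 2 * 4 * 3; ++flat) {
    size_t coord[3];
    size_t rem = flat;
    for (size_t i = 3; i-- > 0;) {  // same modulo/divide loop as the overload
      coord[i] = rem % shape[i];
      rem /= shape[i];
    }
    // Re-linearize row-major and compare against the original flat index.
    size_t back = (coord[0] * shape[1] + coord[1]) * shape[2] + coord[2];
    assert(back == flat);
  }
  return 0;
}
```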

kernels/portable/cpu/util/broadcast_util.h

Lines changed: 15 additions & 0 deletions
```diff
@@ -192,6 +192,21 @@ inline void resize_to_broadcast_target_size(
 __ET_DEPRECATED void free_broadcast_tensor(
     const exec_aten::Tensor& broadcast_tensor);
 
+/**
+ * Delinearize a flattened index to per-dimension indexes.
+ *
+ * @param[in] linear_index The flattened index
+ * @param[in] shape The tensor shape
+ * @param[out] out_indexes The per-dimension indexes
+ * @param[in] out_indexes_len The maximum size of the out_indexes array
+ * @returns void
+ */
+void delinearize_index(
+    size_t linear_index,
+    exec_aten::ArrayRef<Tensor::SizesType> shape,
+    size_t* out_indexes,
+    const size_t out_indexes_len);
+
 /**
  * Delinearize a flattened index to per-dimension indexes.
  *
```

kernels/test/op_index_put_test.cpp

Lines changed: 8 additions & 6 deletions
```diff
@@ -880,10 +880,7 @@ TEST(OpIndexPutOutTest, InvalidIndicesShapesDies) {
       op_index_put_out(x, indices, values, /*accumulate=*/false, out), "");
 }
 
-TEST(OpIndexPutOutTest, InvalidIndicesShapeDies2) {
-  if (torch::executor::testing::SupportedFeatures::get()->is_aten) {
-    GTEST_SKIP() << "ATen kernel will support non-linear shapes";
-  }
+TEST(OpIndexPutOutTest, NonLinearIndices) {
   TensorFactory<ScalarType::Float> tf;
   TensorFactory<ScalarType::Long> tfl;
 
@@ -903,8 +900,13 @@ TEST(OpIndexPutOutTest, InvalidIndicesShapeDies2) {
   );
   // clang-format on
 
-  ET_EXPECT_KERNEL_FAILURE_WITH_MSG(
-      op_index_put_out(x, indices, values, /*accumulate=*/false, out), "");
+  Tensor expected =
+      tf.make({4, 4}, {0, 0, 0, 0, 10, 0, 0, 10, 0, 0, 0, 0, 0, 0, 0, 0});
+
+  Tensor ret = op_index_put_out(x, indices, values, /*accumulate=*/false, out);
+
+  EXPECT_TENSOR_EQ(ret, out);
+  EXPECT_TENSOR_EQ(ret, expected);
 }
 
 //
```
