[ET][Portable] Rewrite index_put #724

Closed
223 changes: 110 additions & 113 deletions kernels/portable/cpu/op_index_put.cpp
@@ -6,11 +6,10 @@
* LICENSE file in the root directory of this source tree.
*/

#include <cinttypes>
#include <cstdint>
#include <cstring>

#include <executorch/kernels/portable/cpu/util/index_util.h>
#include <executorch/kernels/portable/cpu/util/advanced_index_util.h>
#include <executorch/kernels/portable/cpu/util/broadcast_util.h>
#include <executorch/runtime/kernel/kernel_includes.h>

namespace torch {
@@ -19,133 +18,131 @@ namespace native {

using Tensor = exec_aten::Tensor;

namespace {

template <typename CTYPE>
void index_put_out_impl_mask(
const Tensor& in,
exec_aten::ArrayRef<exec_aten::optional<Tensor>> indices,
const Tensor& values,
const bool accum,
Tensor& out) {
// Data pointers
const CTYPE* const in_data = in.const_data_ptr<CTYPE>();
CTYPE* const out_data = out.mutable_data_ptr<CTYPE>();

const CTYPE* val_data = values.const_data_ptr<CTYPE>();

// To start, copy the in into the output
memcpy(out_data, in_data, in.nbytes());

const Tensor& mask = indices[0].value();
const bool* const mask_ptr = mask.const_data_ptr<bool>();
size_t count = 0;
for (int i = 0; i < mask.numel(); ++i) {
if (mask_ptr[i]) {
if (accum) {
out_data[i] += val_data[count];
} else {
out_data[i] = val_data[count];
}
if (values.numel() > 1) {
count++;
}
}
}
}

template <typename CTYPE>
void index_put_out_impl_list(
const Tensor& in,
exec_aten::ArrayRef<exec_aten::optional<Tensor>> indices,
const Tensor& values,
const bool accum,
Tensor& out) {
// Data pointers
const CTYPE* const in_data = in.const_data_ptr<CTYPE>();
CTYPE* const out_data = out.mutable_data_ptr<CTYPE>();

const CTYPE* val = values.const_data_ptr<CTYPE>();

// To start, copy the in into the output
memcpy(out_data, in_data, in.nbytes());

size_t num_idx_queries = get_indices_broadcast_len(indices);
for (size_t idx = 0; idx < num_idx_queries; idx++) {
const CTYPE* src = in_data;
CTYPE* dst = out_data;

// For each index query, align the src and dst pointers to the position
// described by the query.
size_t offset = get_index_query_pos_offset(idx, in, indices);
src += offset;
dst += offset;

// Calculate the region of data to copy for this query.
// For example, a 2x4x3x5 tensor indexing at [1, 1, :, :] should copy 15
// elements.
size_t copy_len = getTrailingDims(in, indices.size() - 1);

// If values only contains 1 element, it needs to be broadcasted.
if (values.numel() == 1) {
CTYPE value = *val;

for (size_t i = 0; i < copy_len; ++i) {
if (accum) {
dst[i] += value;
} else {
dst[i] = value;
}
}
}
// General case.
else {
if (accum) {
for (size_t i = 0; i < copy_len; ++i) {
dst[i] = src[i] + val[i];
}
val += copy_len;
} else {
size_t copy_size = copy_len * sizeof(CTYPE);
memcpy(dst, val, copy_size);
val += copy_len;
}
}
}
}

} // namespace

Tensor& index_put_out(
RuntimeContext& ctx,
const Tensor& in,
exec_aten::ArrayRef<exec_aten::optional<Tensor>> indices,
const Tensor& values,
const bool accumulate,
Tensor& out) {
(void)ctx;

ET_KERNEL_CHECK(
ctx,
check_index_put_args(in, indices, values, out),
InvalidArgument,
out);
ctx, check_index_args(in, indices, out), InvalidArgument, out);

ET_KERNEL_CHECK(
ctx, tensors_have_same_dtype(in, values), InvalidArgument, out);

if (indices.empty() || in.numel() == 0) {
ScalarType in_type = in.scalar_type();
size_t block_count = count_index_blocks(indices);

// If the indices list is empty or all indices are null, the operation is
// performed over the entire input tensor. So when accumulate is false this
// is equivalent to out = values, and when accumulate is true it is
// out = in + values.
if (block_count == 0) {
ET_KERNEL_CHECK(
ctx, resize_tensor(out, in.sizes()) == Error::Ok, InvalidArgument, out);
memcpy(
out.mutable_data_ptr<char>(), in.const_data_ptr<char>(), in.nbytes());

// Check that the values tensor can be broadcast to out
ET_KERNEL_CHECK(
ctx, tensor_is_broadcastable_to(values, out), InvalidArgument, out);

ET_SWITCH_REAL_TYPES_AND(Bool, in_type, ctx, __func__, CTYPE, [&]() {
apply_binary_elementwise_fn<CTYPE, CTYPE, CTYPE>(
[accumulate](const CTYPE val_in, const CTYPE val) {
return accumulate ? val_in + val : val;
},
in,
values,
out);
});
return out;
}
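
// Worked example of the branch above (hypothetical 1-D tensors): with
// in = [1, 2, 3], values = [10, 20, 30], and every index null,
//   accumulate == false  ->  out = values      = [10, 20, 30]
//   accumulate == true   ->  out = in + values = [11, 22, 33]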

// The index output shape depends on whether all the non-null indices are
// adjacent or not.
bool adjacent = (block_count == 1);

// Compute the expected index output shape.
Tensor::SizesType x_sizes[kTensorDimensionLimit];
size_t x_dim = 0;
ET_KERNEL_CHECK(
ctx,
get_index_out_target_size(in, indices, adjacent, x_sizes, &x_dim),
InvalidArgument,
out);

// Check that the values tensor can be broadcast to the indexing result
ET_KERNEL_CHECK(
ctx,
tensor_is_broadcastable_to(values.sizes(), {x_sizes, x_dim}),
InvalidArgument,
out);

ET_KERNEL_CHECK(
ctx, resize_tensor(out, in.sizes()) == Error::Ok, InvalidArgument, out);

ScalarType dtype = in.scalar_type();
ET_SWITCH_REAL_TYPES_AND(Bool, dtype, ctx, "index_put", CTYPE, [&]() {
if (is_index_mask(in, indices)) {
index_put_out_impl_mask<CTYPE>(in, indices, values, accumulate, out);
} else {
index_put_out_impl_list<CTYPE>(in, indices, values, accumulate, out);
// No further action if the input is empty
if (in.numel() == 0) {
return out;
}

// To start, copy the input data into the out tensor
memcpy(out.mutable_data_ptr<char>(), in.const_data_ptr<char>(), in.nbytes());

// In what follows, `x = in[indices]`. This tensor is implicit: it would be
// much easier to allocate memory and call index.Tensor to compute `x`, but
// since we can't do that, we keep track of its shape, number of dimensions,
// and number of elements, and use them to translate coordinates from `x`
// to `in`.

// Compute the dim_map and ix_map needed for `x -> in` coordinate translation
int32_t dim_map[kTensorDimensionLimit];
int32_t ix_map[kTensorDimensionLimit];
size_t start = 0;

if (adjacent) {
start = get_num_leading_null_indices(indices);
}
size_t bc_ndim = get_indices_broadcast_ndim(indices);
compute_dim_map(in, indices, dim_map, block_count == 1);
compute_index_map(in, indices, ix_map);

// Compute the number of elements in the indexed space
size_t x_numel = 1;
for (size_t i = 0; i < x_dim; i++) {
x_numel *= x_sizes[i];
}

ET_SWITCH_REAL_TYPES_AND(Bool, in_type, ctx, __func__, CTYPE, [&]() {
const CTYPE* const values_data = values.const_data_ptr<CTYPE>();
CTYPE* const out_data = out.mutable_data_ptr<CTYPE>();

for (auto x_ix = 0; x_ix < x_numel; x_ix++) {
size_t in_ix = 0;

size_t x_coord[kTensorDimensionLimit];
delinearize_index(x_ix, {x_sizes, x_dim}, x_coord, kTensorDimensionLimit);

size_t in_coord[kTensorDimensionLimit];

ET_KERNEL_CHECK(
ctx,
get_in_coord(
in, indices, start, bc_ndim, dim_map, ix_map, x_coord, in_coord),
InvalidArgument,
out);

in_ix = coordinateToIndex(in, in_coord);

// Broadcast values
size_t val_ix = linearize_access_indexes(x_coord, x_dim, values);
if (accumulate) {
out_data[in_ix] += values_data[val_ix];
} else {
out_data[in_ix] = values_data[val_ix];
}
}
});

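The core of the rewrite is the loop above: the indexed tensor x = in[indices] is never materialized. Instead, every flat position of x is delinearized, translated to a coordinate of `in`, and the broadcast value is written (or accumulated) at that spot. The following is a minimal standalone sketch of that idea on plain arrays; the 4x4 shape, the single index tensor, and the hand-rolled coordinate math are assumptions for illustration only, not the advanced_index_util helpers used by the kernel.

#include <array>
#include <cstdio>

int main() {
  // in: a 4x4 tensor of zeros, flattened row-major. Start with out = in.
  std::array<float, 16> in{};
  std::array<float, 16> out = in;

  // indices = {tensor([1, 3])}: one index tensor along dim 0, so the implicit
  // x = in[indices] has shape {2, 4}.
  const int idx0[2] = {1, 3};
  const int x_sizes[2] = {2, 4};
  const float value = 10.0f;  // single-element values tensor, broadcast to x
  const bool accumulate = false;

  const int x_numel = x_sizes[0] * x_sizes[1];
  for (int x_ix = 0; x_ix < x_numel; ++x_ix) {
    // Delinearize x_ix into an x coordinate (row-major).
    const int x_coord[2] = {x_ix / x_sizes[1], x_ix % x_sizes[1]};
    // Translate to an `in` coordinate: dim 0 goes through the index tensor,
    // dim 1 is passed through unchanged.
    const int in_coord[2] = {idx0[x_coord[0]], x_coord[1]};
    const int in_ix = in_coord[0] * 4 + in_coord[1];
    out[in_ix] = accumulate ? out[in_ix] + value : value;
  }

  // Rows 1 and 3 of out now hold 10; everything else stays 0.
  for (int r = 0; r < 4; ++r) {
    for (int c = 0; c < 4; ++c) {
      std::printf("%4.0f", out[r * 4 + c]);
    }
    std::printf("\n");
  }
  return 0;
}
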
3 changes: 2 additions & 1 deletion kernels/portable/cpu/targets.bzl
@@ -381,7 +381,8 @@ _ATEN_OPS = (
op_target(
name = "op_index_put",
deps = [
"//executorch/kernels/portable/cpu/util:index_util",
"//executorch/kernels/portable/cpu/util:advanced_index_util",
"//executorch/kernels/portable/cpu/util:broadcast_util",
],
),
op_target(
18 changes: 13 additions & 5 deletions kernels/portable/cpu/util/broadcast_util.cpp
@@ -254,18 +254,26 @@ void get_broadcast_target_size(

void delinearize_index(
size_t linear_index,
const Tensor& t,
exec_aten::ArrayRef<Tensor::SizesType> shape,
size_t* out_indexes,
const size_t out_indexes_len) {
ET_CHECK(t.dim() <= out_indexes_len);
for (auto i = 0; i < t.dim(); ++i) {
auto dim = t.dim() - 1 - i;
auto dim_size = t.size(dim);
ET_CHECK(shape.size() <= out_indexes_len);
for (auto i = 0; i < shape.size(); ++i) {
auto dim = shape.size() - 1 - i;
auto dim_size = shape[dim];
out_indexes[dim] = linear_index % dim_size;
linear_index /= dim_size;
}
}

void delinearize_index(
size_t linear_index,
const Tensor& t,
size_t* out_indexes,
const size_t out_indexes_len) {
delinearize_index(linear_index, t.sizes(), out_indexes, out_indexes_len);
}

size_t linearize_access_indexes(
ArrayRef<size_t> indexes_broadcast_to,
ssize_t broadcast_to_ndim,
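As a quick illustration of what the new shape-based overload shown above computes, here is a self-contained sketch of the same row-major delinearization using plain size_t arrays instead of exec_aten types; the function and variable names are made up.

#include <cstddef>
#include <cstdio>

void delinearize(size_t linear_index, const size_t* shape, size_t ndim,
                 size_t* out_indexes) {
  // Peel off the fastest-moving (last) dimension first, exactly as above.
  for (size_t i = 0; i < ndim; ++i) {
    const size_t dim = ndim - 1 - i;
    out_indexes[dim] = linear_index % shape[dim];
    linear_index /= shape[dim];
  }
}

int main() {
  const size_t shape[3] = {2, 3, 4};
  size_t coord[3];
  delinearize(17, shape, 3, coord);
  // Prints "1 1 1" since 17 == 1 * (3 * 4) + 1 * 4 + 1 in row-major order.
  std::printf("%zu %zu %zu\n", coord[0], coord[1], coord[2]);
  return 0;
}
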
15 changes: 15 additions & 0 deletions kernels/portable/cpu/util/broadcast_util.h
@@ -192,6 +192,21 @@ inline void resize_to_broadcast_target_size(
__ET_DEPRECATED void free_broadcast_tensor(
const exec_aten::Tensor& broadcast_tensor);

/**
* Delinearize a flattened index to per-dimension indexes.
*
* @param[in] linear_index The flattened index
* @param[in] shape The tensor shape
* @param[out] out_indexes The per-dimension indexes
* @param[in] out_indexes_len The maximum size of the out_indexes array
* @returns void
*/
void delinearize_index(
size_t linear_index,
exec_aten::ArrayRef<Tensor::SizesType> shape,
size_t* out_indexes,
const size_t out_indexes_len);

/**
* Delinearize a flattened index to per-dimension indexes.
*
14 changes: 8 additions & 6 deletions kernels/test/op_index_put_test.cpp
@@ -880,10 +880,7 @@ TEST(OpIndexPutOutTest, InvalidIndicesShapesDies) {
op_index_put_out(x, indices, values, /*accumulate=*/false, out), "");
}

TEST(OpIndexPutOutTest, InvalidIndicesShapeDies2) {
if (torch::executor::testing::SupportedFeatures::get()->is_aten) {
GTEST_SKIP() << "ATen kernel will support non-linear shapes";
}
TEST(OpIndexPutOutTest, NonLinearIndices) {
TensorFactory<ScalarType::Float> tf;
TensorFactory<ScalarType::Long> tfl;

@@ -903,8 +900,13 @@ TEST(OpIndexPutOutTest, InvalidIndicesShapeDies2) {
);
// clang-format on

ET_EXPECT_KERNEL_FAILURE_WITH_MSG(
op_index_put_out(x, indices, values, /*accumulate=*/false, out), "");
Tensor expected =
tf.make({4, 4}, {0, 0, 0, 0, 10, 0, 0, 10, 0, 0, 0, 0, 0, 0, 0, 0});

Tensor ret = op_index_put_out(x, indices, values, /*accumulate=*/false, out);

EXPECT_TENSOR_EQ(ret, out);
EXPECT_TENSOR_EQ(ret, expected);
}

//
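For reference, a "non-linear" index here means an index tensor with more than one dimension, which the rewritten kernel now accepts instead of erroring. The sketch below uses hypothetical shapes and index values (not the ones elided from the test) to show the semantics the renamed test exercises: a {2,2} index tensor on dim 0 of a {4,4} input gives an implicit x of shape {2, 2, 4}, and index_put writes the broadcast value over all of it.

#include <array>
#include <cstdio>

int main() {
  std::array<float, 16> out{};  // a 4x4 tensor of zeros, row-major

  // indices = {tensor of shape {2,2}} indexing dim 0: the index tensor's
  // shape replaces the indexed dimension, so x = in[indices] is {2, 2, 4}.
  const int idx[2][2] = {{1, 3}, {0, 2}};
  const float value = 10.0f;  // single-element values tensor, broadcast

  for (int a = 0; a < 2; ++a) {
    for (int b = 0; b < 2; ++b) {
      for (int col = 0; col < 4; ++col) {
        out[idx[a][b] * 4 + col] = value;  // accumulate == false
      }
    }
  }

  // Spot-check: position [3][0] was written through idx[0][1] == 3.
  std::printf("out[3][0] = %.0f\n", out[3 * 4 + 0]);
  return 0;
}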