[ET][Portable] Implement output broadcasting for split_with_sizes_copy #712

Closed
90 changes: 72 additions & 18 deletions kernels/portable/cpu/op_split_with_sizes_copy.cpp
@@ -9,6 +9,7 @@
 #include <cstdint>
 #include <cstring>
 
+#include <executorch/kernels/portable/cpu/util/broadcast_util.h>
 #include <executorch/kernels/portable/cpu/util/copy_ops_util.h>
 #include <executorch/runtime/kernel/kernel_includes.h>

@@ -38,16 +39,25 @@ void split_with_sizes_copy_out(
       InvalidArgument,
       out);
 
-  Tensor::SizesType expected_out_size[kTensorDimensionLimit];
-  size_t expected_out_dim = 0;
+  // If out is empty, then nothing needs to be done after checking the args.
+  // Valid args implies that in.size(dim) == 0 and split_sizes is also empty.
+  if (out.size() == 0) {
+    return;
+  }
+
+  // Check that all chunks broadcast to their respective out tensor
+  Tensor::SizesType target_out_sizes[kTensorDimensionLimit];
+  size_t target_out_ndim = in.dim();
+  for (size_t d = 0; d < in.dim(); ++d) {
+    target_out_sizes[d] = static_cast<Tensor::SizesType>(in.size(d));
+  }
+
   for (size_t i = 0; i < split_sizes.size(); i++) {
-    expected_out_size[expected_out_dim++] = split_sizes[i];
-    get_split_with_sizes_copy_out_target_size(
-        in, split_sizes[i], dim, expected_out_size, &expected_out_dim);
+    target_out_sizes[dim] = static_cast<Tensor::SizesType>(split_sizes[i]);
     ET_KERNEL_CHECK(
         ctx,
-        resize_tensor(out[i], {expected_out_size, expected_out_dim}) ==
-            Error::Ok,
+        tensor_is_broadcastable_to(
+            {target_out_sizes, target_out_ndim}, out[i].sizes()),
         InvalidArgument,
         out);
   }
@@ -62,21 +72,65 @@ void split_with_sizes_copy_out(
   ET_SWITCH_REAL_TYPES_AND(Bool, in_type, ctx, __func__, CTYPE_IN, [&]() {
     ET_SWITCH_REAL_TYPES_AND(Bool, out_type, ctx, __func__, CTYPE_OUT, [&]() {
       const CTYPE_IN* in_data = in.const_data_ptr<CTYPE_IN>();
-      for (size_t i = 0, e = out.size(); i < e; ++i) {
-        size_t out_step = out[i].size(dim) * trailing_dims;
-        if (out_step == 0) {
+
+      // Iterate through list of out tensors
+      for (size_t i = 0; i < out.size(); ++i) {
+        const Tensor& out_tensor = out[i];
+
+        // If out tensor is empty, no action is required
+        if (out_tensor.numel() == 0) {
           continue;
         }
-        const CTYPE_IN* src = in_data;
-        CTYPE_OUT* dest = out[i].mutable_data_ptr<CTYPE_OUT>();
-        for (size_t j = 0; j < leading_dims; ++j) {
-          for (size_t k = 0; k < out_step; ++k) {
-            dest[k] = convert<CTYPE_OUT, CTYPE_IN>(src[k]);
+
+        size_t chunk_step = split_sizes[i] * trailing_dims;
+
+        // Update target out shape
+        target_out_sizes[dim] = static_cast<Tensor::SizesType>(split_sizes[i]);
+        ArrayRef<Tensor::SizesType> target_shape(
+            {target_out_sizes, target_out_ndim});
+
+        // Check if output involves broadcasting
+        const bool is_broadcasted = !out_tensor.sizes().equals(target_shape);
+
+        CTYPE_OUT* out_data = out_tensor.mutable_data_ptr<CTYPE_OUT>();
+
+        // Simpler logic if there's no broadcasting
+        if (!is_broadcasted) {
+          const CTYPE_IN* src = in_data;
+          for (size_t j = 0; j < leading_dims; ++j) {
+            for (size_t k = 0; k < chunk_step; ++k) {
+              out_data[k] = convert<CTYPE_OUT, CTYPE_IN>(src[k]);
+            }
+            src += step;
+            out_data += chunk_step;
+          }
+        } else { // Otherwise, we need to do a copy with broadcasting
+          // Compute target strides
+          Tensor::StridesType target_out_strides[kTensorDimensionLimit];
+          target_out_strides[in.dim() - 1] = 1;
+          for (int d = in.dim() - 2; d >= 0; --d) {
+            target_out_strides[d] = target_out_strides[d + 1] *
+                static_cast<Tensor::StridesType>(target_out_sizes[d + 1]);
+          }
+          ArrayRef<Tensor::StridesType> target_strides(
+              {target_out_strides, target_out_ndim});
+
+          // For each element in the out tensor, find its corresponding index
+          // in the input tensor and copy it over
+          for (size_t ix = 0; ix < out_tensor.numel(); ++ix) {
+            size_t out_coord[kTensorDimensionLimit];
+            delinearize_index(ix, out_tensor, out_coord, kTensorDimensionLimit);
+
+            size_t in_linear_index = linearize_access_indexes(
+                out_coord, out_tensor.dim(), target_shape, target_strides);
+
+            out_data[ix] =
+                convert<CTYPE_OUT, CTYPE_IN>(in_data[in_linear_index]);
           }
-          src += step;
-          dest += out_step;
         }
-        in_data += out_step;
+
+        // Move input data pointer
+        in_data += chunk_step;
       }
     });
   });
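Editor's note: the broadcasting branch above works by delinearizing each out index into a coordinate and linearizing it back against the source chunk. Below is a minimal standalone sketch of that indexing in plain C++, with illustrative shapes and hypothetical names, and no ExecuTorch dependencies; it is not the kernel itself.

#include <cstdio>
#include <vector>

// Sketch: copy a source chunk of shape {2, 1, 3} into an out buffer of shape
// {2, 4, 3}. Each out element's coordinate is mapped back to the chunk, with
// broadcast (size-1) dims contributing offset 0.
int main() {
  std::vector<float> chunk = {0, 1, 2, 3, 4, 5}; // shape {2, 1, 3}
  std::vector<size_t> chunk_sizes = {2, 1, 3};
  std::vector<size_t> chunk_strides = {3, 3, 1}; // contiguous strides

  std::vector<size_t> out_sizes = {2, 4, 3}; // dim 1 broadcasts from 1 to 4
  std::vector<size_t> out_strides = {12, 3, 1};
  std::vector<float> out(2 * 4 * 3);

  for (size_t ix = 0; ix < out.size(); ++ix) {
    size_t src_ix = 0;
    for (size_t d = 0; d < out_sizes.size(); ++d) {
      size_t coord = (ix / out_strides[d]) % out_sizes[d]; // delinearize
      if (coord < chunk_sizes[d]) {
        src_ix += coord * chunk_strides[d]; // linearize against the chunk
      } // else: broadcast dim of size 1 contributes nothing
    }
    out[ix] = chunk[src_ix];
  }

  // out[(i, j, k)] == chunk[(i, 0, k)] for every j.
  printf("%g %g %g\n", out[3], out[4], out[5]); // prints: 0 1 2
  return 0;
}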
1 change: 1 addition & 0 deletions kernels/portable/cpu/targets.bzl
@@ -701,6 +701,7 @@ _ATEN_OPS = (
     op_target(
         name = "op_split_with_sizes_copy",
         deps = [
+            "//executorch/kernels/portable/cpu/util:broadcast_util",
             "//executorch/kernels/portable/cpu/util:copy_ops_util",
         ],
     ),
41 changes: 29 additions & 12 deletions kernels/portable/cpu/util/broadcast_util.cpp
@@ -77,11 +77,9 @@ Tensor make_tensor(
 } // namespace
 
 bool tensor_is_broadcastable_to(
-    const Tensor& broadcast_from,
-    const Tensor& broadcast_to) {
+    const exec_aten::ArrayRef<Tensor::SizesType> broadcast_from_shape,
+    const exec_aten::ArrayRef<Tensor::SizesType> broadcast_to_shape) {
   bool feasible_bcast = true;
-  auto broadcast_to_shape = broadcast_to.sizes();
-  auto broadcast_from_shape = broadcast_from.sizes();
 
   if (broadcast_to_shape.size() < broadcast_from_shape.size()) {
     return false;
@@ -103,6 +101,13 @@
   return feasible_bcast;
 }
 
+bool tensor_is_broadcastable_to(
+    const Tensor& broadcast_from,
+    const Tensor& broadcast_to) {
+  return tensor_is_broadcastable_to(
+      broadcast_from.sizes(), broadcast_to.sizes());
+}
+
 bool tensors_are_broadcastable_between(
     const exec_aten::ArrayRef<Tensor::SizesType> a_shape,
     const exec_aten::ArrayRef<Tensor::SizesType> b_shape) {
@@ -264,27 +269,39 @@ void delinearize_index(
 size_t linearize_access_indexes(
     ArrayRef<size_t> indexes_broadcast_to,
     ssize_t broadcast_to_ndim,
-    const Tensor& broadcast_from) {
-  size_t num_skip_dims = broadcast_to_ndim - broadcast_from.dim();
+    exec_aten::ArrayRef<Tensor::SizesType> broadcast_from_shape,
+    exec_aten::ArrayRef<Tensor::StridesType> broadcast_from_strides) {
+  size_t num_skip_dims = broadcast_to_ndim - broadcast_from_shape.size();
   ArrayRef<size_t> indexes_broadcast_from = indexes_broadcast_to.slice(
       num_skip_dims, broadcast_to_ndim - num_skip_dims);
 
-  ET_CHECK(indexes_broadcast_from.size() == broadcast_from.dim());
+  ET_CHECK(indexes_broadcast_from.size() == broadcast_from_shape.size());
 
   size_t linear_index = 0;
   for (size_t i = 0; i < indexes_broadcast_from.size(); ++i) {
     // If this dimension is broadcasted, add zero to the linear address.
-    if (indexes_broadcast_from[i] >= broadcast_from.size(i)) {
+    if (indexes_broadcast_from[i] >= broadcast_from_shape[i]) {
       ET_CHECK_MSG(
-          broadcast_from.size(i) == 1,
-          "Expected dim size == 1 if broadcasted, but actual dim size is %zd",
-          broadcast_from.size(i));
+          broadcast_from_shape[i] == 1,
+          "Expected dim size == 1 if broadcasted, but actual dim size is %zu",
+          static_cast<size_t>(broadcast_from_shape[i]));
       continue;
     }
-    linear_index += indexes_broadcast_from[i] * broadcast_from.strides()[i];
+    linear_index += indexes_broadcast_from[i] * broadcast_from_strides[i];
   }
   return linear_index;
 }
 
+size_t linearize_access_indexes(
+    ArrayRef<size_t> indexes_broadcast_to,
+    ssize_t broadcast_to_ndim,
+    const Tensor& broadcast_from) {
+  return linearize_access_indexes(
+      indexes_broadcast_to,
+      broadcast_to_ndim,
+      broadcast_from.sizes(),
+      broadcast_from.strides());
+}
+
 } // namespace executor
 } // namespace torch
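Editor's note: the new shape-based overload of tensor_is_broadcastable_to checks the usual right-aligned broadcasting rule. A self-contained sketch of that rule on plain vectors, with hypothetical names rather than the ExecuTorch API:

#include <cstdio>
#include <vector>

// Right-aligned broadcast rule: compare shapes from the trailing dim; each
// "from" dim must equal the "to" dim or be 1, and "from" cannot have more
// dims than "to".
bool broadcastable_to(
    const std::vector<int>& from, const std::vector<int>& to) {
  if (to.size() < from.size()) {
    return false;
  }
  for (size_t i = 0; i < from.size(); ++i) {
    int f = from[from.size() - 1 - i];
    int t = to[to.size() - 1 - i];
    if (f != t && f != 1) {
      return false;
    }
  }
  return true;
}

int main() {
  printf("%d\n", broadcastable_to({2, 1, 3}, {2, 4, 3})); // 1: size-1 dim expands
  printf("%d\n", broadcastable_to({4, 3}, {2, 4, 3}));    // 1: leading dim added
  printf("%d\n", broadcastable_to({2, 2, 3}, {2, 4, 3})); // 0: 2 vs 4 mismatch
  return 0;
}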
41 changes: 35 additions & 6 deletions kernels/portable/cpu/util/broadcast_util.h
@@ -14,12 +14,25 @@
 namespace torch {
 namespace executor {
 
+/**
+ * Check whether or not the broadcast_from_shape can be broadcasted onto the
+ * broadcast_to_shape.
+ *
+ * @param[in] broadcast_from_shape The tensor shape which we want to broadcast.
+ * @param[in] broadcast_to_shape The tensor shape which we want to broadcast to.
+ * @returns A bool to indicate whether or not the shape can be broadcasted.
+ *
+ */
+bool tensor_is_broadcastable_to(
+    const exec_aten::ArrayRef<Tensor::SizesType> broadcast_from_shape,
+    const exec_aten::ArrayRef<Tensor::SizesType> broadcast_to_shape);
+
 /**
  * Check whether or not the broadcast_from tensor should and can be broadcasted
  * onto the broadcast_to tensor. broadcast_tensor should only be called if this
  * returns true.
  *
- * @param[in] broadcast_from The tensor to which we want to broadcast from.
+ * @param[in] broadcast_from The tensor which we want to broadcast from.
  * @param[in] broadcast_to The tensor to which we want to broadcast to.
  * @returns A bool to indicate whether or not the tensor can be broadcasted.
  *
@@ -29,11 +42,11 @@ bool tensor_is_broadcastable_to(
     const Tensor& broadcast_to);
 
 /**
- * Returns true if the two tensors can both be broadcasted to a common shape.
+ * Returns true if the two tensor shapes can both be broadcasted to a common
+ * shape.
  *
  * @param[in] a_shape The sizes of the first tensor to be tested.
  * @param[in] b_shape The sizes of the second tensor to be tested.
- *
  * @returns true if the tensors are broadcastable, false otherwise.
  */
 bool tensors_are_broadcastable_between(
@@ -45,7 +58,6 @@ bool tensors_are_broadcastable_between(
  *
  * @param[in] a The first tensor to be tested.
  * @param[in] b The second tensor to be tested.
- *
  * @returns true if the tensors are broadcastable, false otherwise.
  */
 bool tensors_are_broadcastable_between(const Tensor& a, const Tensor& b);
@@ -195,12 +207,29 @@ void delinearize_index(
     size_t* out_indexes,
     const size_t out_indexes_len);
 
+/**
+ * Return the linear index for broadcast_from tensor, given the access indexes
+ * and number of dimensions of broadcast_to tensor, and the shape and strides
+ * of broadcast_from tensor.
+ *
+ * @param[in] indexes_broadcast_to The access indexes of broadcast_to tensor.
+ * @param[in] broadcast_to_ndim The number of dims of broadcast_to tensor.
+ * @param[in] broadcast_from_shape The shape of the broadcasted tensor.
+ * @param[in] broadcast_from_strides The strides of the broadcasted tensor.
+ * @returns The flattened index for broadcast_from tensor.
+ */
+size_t linearize_access_indexes(
+    ArrayRef<size_t> indexes_broadcast_to,
+    ssize_t broadcast_to_ndim,
+    exec_aten::ArrayRef<Tensor::SizesType> broadcast_from_shape,
+    exec_aten::ArrayRef<Tensor::StridesType> broadcast_from_strides);
+
 /**
  * Return the linear index for broadcast_from tensor, given the indexes of
  * broadcast_to tensor and itself.
  *
- * @param[in] indexes The tensor access indexes of broadcast_to tensor
- * @param[in] broadcast_to_ndim The number of dims of the broadcasted shape.
+ * @param[in] indexes_broadcast_to The access indexes of broadcast_to tensor.
+ * @param[in] broadcast_to_ndim The number of dims of broadcast_to tensor.
  * @param[in] broadcast_from The tensor to be broadcasted.
  * @returns The flattened index for broadcast_from tensor.
  */
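Editor's note: for reference, here is a standalone sketch of what the new shape/stride-based linearize_access_indexes overload computes: a coordinate in the broadcast-to tensor maps to a linear offset in the broadcast-from tensor, with size-1 (broadcast) dims contributing zero. Hypothetical names, plain C++, not the ExecuTorch function:

#include <cassert>
#include <cstdio>

size_t linearize(
    const size_t* to_coord,
    size_t to_ndim,
    const int* from_shape,
    const int* from_strides,
    size_t from_ndim) {
  size_t skip = to_ndim - from_ndim; // leading dims missing in `from`
  size_t offset = 0;
  for (size_t i = 0; i < from_ndim; ++i) {
    size_t c = to_coord[skip + i];
    if (c >= static_cast<size_t>(from_shape[i])) {
      assert(from_shape[i] == 1); // only legal when this dim was broadcast
      continue;                   // broadcast dim contributes offset 0
    }
    offset += c * static_cast<size_t>(from_strides[i]);
  }
  return offset;
}

int main() {
  // broadcast_from: shape {1, 3} with contiguous strides {3, 1};
  // broadcast_to: shape {2, 4, 3}.
  const int from_shape[] = {1, 3};
  const int from_strides[] = {3, 1};
  const size_t coord[] = {1, 2, 2}; // an index into the {2, 4, 3} tensor
  // Dim 0 of the destination is skipped entirely, the size-1 dim maps to 0,
  // and the last dim contributes 2 * 1, so the source offset is 2.
  printf("%zu\n", linearize(coord, 3, from_shape, from_strides, 2));
  return 0;
}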