Commit 2b39b3b

[ET][Portable] Implement output broadcasting for split_with_sizes_copy

Differential Revision: [D49979520](https://our.internmc.facebook.com/intern/diff/D49979520/)
ghstack-source-id: 203341589
Pull Request resolved: #712
Parent: 57ac8d1

4 files changed (+137, -36 lines)

kernels/portable/cpu/op_split_with_sizes_copy.cpp (72 additions, 18 deletions)
```diff
@@ -9,6 +9,7 @@
 #include <cstdint>
 #include <cstring>
 
+#include <executorch/kernels/portable/cpu/util/broadcast_util.h>
 #include <executorch/kernels/portable/cpu/util/copy_ops_util.h>
 #include <executorch/runtime/kernel/kernel_includes.h>
 
@@ -38,16 +39,25 @@ void split_with_sizes_copy_out(
       InvalidArgument,
       out);
 
-  Tensor::SizesType expected_out_size[kTensorDimensionLimit];
-  size_t expected_out_dim = 0;
+  // If out is empty, then nothing needs to be done after checking the args.
+  // Valid args implies that in.size(dim) == 0 and split_sizes is also empty.
+  if (out.size() == 0) {
+    return;
+  }
+
+  // Check that all chunks broadcast to their respective out tensor
+  Tensor::SizesType target_out_sizes[kTensorDimensionLimit];
+  size_t target_out_ndim = in.dim();
+  for (size_t d = 0; d < in.dim(); ++d) {
+    target_out_sizes[d] = static_cast<Tensor::SizesType>(in.size(d));
+  }
+
   for (size_t i = 0; i < split_sizes.size(); i++) {
-    expected_out_size[expected_out_dim++] = split_sizes[i];
-    get_split_with_sizes_copy_out_target_size(
-        in, split_sizes[i], dim, expected_out_size, &expected_out_dim);
+    target_out_sizes[dim] = static_cast<Tensor::SizesType>(split_sizes[i]);
     ET_KERNEL_CHECK(
         ctx,
-        resize_tensor(out[i], {expected_out_size, expected_out_dim}) ==
-            Error::Ok,
+        tensor_is_broadcastable_to(
+            {target_out_sizes, target_out_ndim}, out[i].sizes()),
         InvalidArgument,
         out);
   }
@@ -62,21 +72,65 @@ void split_with_sizes_copy_out(
   ET_SWITCH_REAL_TYPES_AND(Bool, in_type, ctx, __func__, CTYPE_IN, [&]() {
     ET_SWITCH_REAL_TYPES_AND(Bool, out_type, ctx, __func__, CTYPE_OUT, [&]() {
       const CTYPE_IN* in_data = in.const_data_ptr<CTYPE_IN>();
-      for (size_t i = 0, e = out.size(); i < e; ++i) {
-        size_t out_step = out[i].size(dim) * trailing_dims;
-        if (out_step == 0) {
+
+      // Iterate through list of out tensors
+      for (size_t i = 0; i < out.size(); ++i) {
+        const Tensor& out_tensor = out[i];
+
+        // If out tensor is empty, no action is required
+        if (out_tensor.numel() == 0) {
           continue;
         }
-        const CTYPE_IN* src = in_data;
-        CTYPE_OUT* dest = out[i].mutable_data_ptr<CTYPE_OUT>();
-        for (size_t j = 0; j < leading_dims; ++j) {
-          for (size_t k = 0; k < out_step; ++k) {
-            dest[k] = convert<CTYPE_OUT, CTYPE_IN>(src[k]);
+
+        size_t chunk_step = split_sizes[i] * trailing_dims;
+
+        // Update target out shape
+        target_out_sizes[dim] = static_cast<Tensor::SizesType>(split_sizes[i]);
+        ArrayRef<Tensor::SizesType> target_shape(
+            {target_out_sizes, target_out_ndim});
+
+        // Check if output involves broadcasting
+        const bool is_broadcasted = !out_tensor.sizes().equals(target_shape);
+
+        CTYPE_OUT* out_data = out_tensor.mutable_data_ptr<CTYPE_OUT>();
+
+        // Simpler logic if there's no broadcasting
+        if (!is_broadcasted) {
+          const CTYPE_IN* src = in_data;
+          for (size_t j = 0; j < leading_dims; ++j) {
+            for (size_t k = 0; k < chunk_step; ++k) {
+              out_data[k] = convert<CTYPE_OUT, CTYPE_IN>(src[k]);
+            }
+            src += step;
+            out_data += chunk_step;
+          }
+        } else { // Otherwise, we need to do a copy with broadcasting
+          // Compute target strides
+          Tensor::StridesType target_out_strides[kTensorDimensionLimit];
+          target_out_strides[in.dim() - 1] = 1;
+          for (int d = in.dim() - 2; d >= 0; --d) {
+            target_out_strides[d] = target_out_strides[d + 1] *
+                static_cast<Tensor::StridesType>(target_out_sizes[d + 1]);
+          }
+          ArrayRef<Tensor::StridesType> target_strides(
+              {target_out_strides, target_out_ndim});
+
+          // For each element in the out tensor, find its corresponding index
+          // in the input tensor and copy it over
+          for (size_t ix = 0; ix < out_tensor.numel(); ++ix) {
+            size_t out_coord[kTensorDimensionLimit];
+            delinearize_index(ix, out_tensor, out_coord, kTensorDimensionLimit);
+
+            size_t in_linear_index = linearize_access_indexes(
+                out_coord, out_tensor.dim(), target_shape, target_strides);
+
+            out_data[ix] =
+                convert<CTYPE_OUT, CTYPE_IN>(in_data[in_linear_index]);
           }
-          src += step;
-          dest += out_step;
         }
-        in_data += out_step;
+
+        // Move input data pointer
+        in_data += chunk_step;
       }
     });
   });
```
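
The broadcast branch above works by delinearizing each flat output index into per-dimension coordinates, then re-linearizing those coordinates against the chunk's shape and strides, so broadcast dimensions contribute nothing to the source offset. Here is a minimal standalone sketch of that index-mapping idea using only std types (`delinearize` and `linearize` are illustrative helpers, not the ExecuTorch API):

```cpp
#include <cstddef>
#include <cstdio>
#include <vector>

// Convert a flat row-major index into per-dimension coordinates for `sizes`.
static void delinearize(size_t ix, const std::vector<size_t>& sizes,
                        std::vector<size_t>& coord) {
  for (int d = static_cast<int>(sizes.size()) - 1; d >= 0; --d) {
    coord[d] = ix % sizes[d];
    ix /= sizes[d];
  }
}

// Map destination coordinates back to a flat source index. A coordinate that
// exceeds the source size marks a broadcast dimension (size 1), so it adds 0,
// mirroring the skip logic in linearize_access_indexes.
static size_t linearize(const std::vector<size_t>& coord,
                        const std::vector<size_t>& src_sizes,
                        const std::vector<size_t>& src_strides) {
  size_t linear = 0;
  for (size_t d = 0; d < coord.size(); ++d) {
    if (coord[d] < src_sizes[d]) {
      linear += coord[d] * src_strides[d];
    }
  }
  return linear;
}

int main() {
  // Broadcast a source chunk of shape {2, 1} into an output of shape {2, 3}.
  const std::vector<size_t> src_sizes{2, 1}, src_strides{1, 1};
  const std::vector<size_t> dst_sizes{2, 3};
  const std::vector<float> src{10.f, 20.f};
  std::vector<float> dst(6);
  std::vector<size_t> coord(2);
  for (size_t ix = 0; ix < dst.size(); ++ix) {
    delinearize(ix, dst_sizes, coord);
    dst[ix] = src[linearize(coord, src_sizes, src_strides)];
  }
  for (float v : dst) {
    printf("%.0f ", v);  // prints: 10 10 10 20 20 20
  }
  printf("\n");
  return 0;
}
```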

kernels/portable/cpu/targets.bzl (1 addition, 0 deletions)
```diff
@@ -701,6 +701,7 @@ _ATEN_OPS = (
     op_target(
         name = "op_split_with_sizes_copy",
         deps = [
+            "//executorch/kernels/portable/cpu/util:broadcast_util",
             "//executorch/kernels/portable/cpu/util:copy_ops_util",
         ],
     ),
```

kernels/portable/cpu/util/broadcast_util.cpp (29 additions, 12 deletions)
```diff
@@ -77,11 +77,9 @@ Tensor make_tensor(
 } // namespace
 
 bool tensor_is_broadcastable_to(
-    const Tensor& broadcast_from,
-    const Tensor& broadcast_to) {
+    const exec_aten::ArrayRef<Tensor::SizesType> broadcast_from_shape,
+    const exec_aten::ArrayRef<Tensor::SizesType> broadcast_to_shape) {
   bool feasible_bcast = true;
-  auto broadcast_to_shape = broadcast_to.sizes();
-  auto broadcast_from_shape = broadcast_from.sizes();
 
   if (broadcast_to_shape.size() < broadcast_from_shape.size()) {
     return false;
@@ -103,6 +101,13 @@ bool tensor_is_broadcastable_to(
   return feasible_bcast;
 }
 
+bool tensor_is_broadcastable_to(
+    const Tensor& broadcast_from,
+    const Tensor& broadcast_to) {
+  return tensor_is_broadcastable_to(
+      broadcast_from.sizes(), broadcast_to.sizes());
+}
+
 bool tensors_are_broadcastable_between(
     const exec_aten::ArrayRef<Tensor::SizesType> a_shape,
     const exec_aten::ArrayRef<Tensor::SizesType> b_shape) {
@@ -264,27 +269,39 @@ void delinearize_index(
 size_t linearize_access_indexes(
     ArrayRef<size_t> indexes_broadcast_to,
     ssize_t broadcast_to_ndim,
-    const Tensor& broadcast_from) {
-  size_t num_skip_dims = broadcast_to_ndim - broadcast_from.dim();
+    exec_aten::ArrayRef<Tensor::SizesType> broadcast_from_shape,
+    exec_aten::ArrayRef<Tensor::StridesType> broadcast_from_strides) {
+  size_t num_skip_dims = broadcast_to_ndim - broadcast_from_shape.size();
   ArrayRef<size_t> indexes_broadcast_from = indexes_broadcast_to.slice(
       num_skip_dims, broadcast_to_ndim - num_skip_dims);
 
-  ET_CHECK(indexes_broadcast_from.size() == broadcast_from.dim());
+  ET_CHECK(indexes_broadcast_from.size() == broadcast_from_shape.size());
 
   size_t linear_index = 0;
   for (size_t i = 0; i < indexes_broadcast_from.size(); ++i) {
     // If this dimension is broadcasted, add zero to the linear address.
-    if (indexes_broadcast_from[i] >= broadcast_from.size(i)) {
+    if (indexes_broadcast_from[i] >= broadcast_from_shape[i]) {
       ET_CHECK_MSG(
-          broadcast_from.size(i) == 1,
-          "Expected dim size == 1 if broadcasted, but actual dim size is %zd",
-          broadcast_from.size(i));
+          broadcast_from_shape[i] == 1,
+          "Expected dim size == 1 if broadcasted, but actual dim size is %zu",
+          static_cast<size_t>(broadcast_from_shape[i]));
       continue;
     }
-    linear_index += indexes_broadcast_from[i] * broadcast_from.strides()[i];
+    linear_index += indexes_broadcast_from[i] * broadcast_from_strides[i];
   }
   return linear_index;
 }
 
+size_t linearize_access_indexes(
+    ArrayRef<size_t> indexes_broadcast_to,
+    ssize_t broadcast_to_ndim,
+    const Tensor& broadcast_from) {
+  return linearize_access_indexes(
+      indexes_broadcast_to,
+      broadcast_to_ndim,
+      broadcast_from.sizes(),
+      broadcast_from.strides());
+}
+
 } // namespace executor
 } // namespace torch
```
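
For reference, the shape-based overload above implements the standard trailing-alignment broadcasting rule: the target must have at least as many dimensions, and each source dimension must equal the corresponding target dimension or be 1. A minimal sketch with std types (not the ExecuTorch signatures):

```cpp
#include <cstddef>
#include <cstdio>
#include <vector>

// Trailing-alignment broadcastability check: align shapes at the last
// dimension; every `from` dim must equal the matching `to` dim or be 1.
static bool is_broadcastable_to(const std::vector<long>& from,
                                const std::vector<long>& to) {
  if (to.size() < from.size()) {
    return false;  // cannot broadcast to a shape with fewer dimensions
  }
  const size_t offset = to.size() - from.size();
  for (size_t d = 0; d < from.size(); ++d) {
    if (from[d] != to[d + offset] && from[d] != 1) {
      return false;
    }
  }
  return true;
}

int main() {
  printf("%d\n", is_broadcastable_to({3, 1}, {2, 3, 4}));  // 1: {3,1} -> {2,3,4}
  printf("%d\n", is_broadcastable_to({3, 2}, {2, 3, 4}));  // 0: 2 vs 4, and 2 != 1
  return 0;
}
```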

kernels/portable/cpu/util/broadcast_util.h (35 additions, 6 deletions)
```diff
@@ -14,12 +14,25 @@
 namespace torch {
 namespace executor {
 
+/**
+ * Check whether or not the broadcast_from_shape can be broadcasted onto the
+ * broadcast_to_shape.
+ *
+ * @param[in] broadcast_from_shape The tensor shape which we want to broadcast.
+ * @param[in] broadcast_to_shape The tensor shape which we want to broadcast to.
+ * @returns A bool to indicate whether or not the shape can be broadcasted.
+ *
+ */
+bool tensor_is_broadcastable_to(
+    const exec_aten::ArrayRef<Tensor::SizesType> broadcast_from_shape,
+    const exec_aten::ArrayRef<Tensor::SizesType> broadcast_to_shape);
+
 /**
  * Check whether or not the broadcast_from tensor should and can be broadcasted
  * onto the broadcast_to tensor. broadcast_tensor should only be called if this
  * returns true.
  *
- * @param[in] broadcast_from The tensor to which we want to broadcast from.
+ * @param[in] broadcast_from The tensor which we want to broadcast from.
  * @param[in] broadcast_to The tensor to which we want to broadcast to.
  * @returns A bool to indicate whether or not the tensor can be broadcasted.
  *
@@ -29,11 +42,11 @@ bool tensor_is_broadcastable_to(
     const Tensor& broadcast_to);
 
 /**
- * Returns true if the two tensors can both be broadcasted to a common shape.
+ * Returns true if the two tensor shapes can both be broadcasted to a common
+ * shape.
  *
  * @param[in] a_shape The sizes of the first tensor to be tested.
  * @param[in] b_shape The sizes of the second tensor to be tested.
- *
  * @returns true if the tensors are broadcastable, false otherwise.
  */
 bool tensors_are_broadcastable_between(
@@ -45,7 +58,6 @@ bool tensors_are_broadcastable_between(
  *
  * @param[in] a The first tensor to be tested.
  * @param[in] b The second tensor to be tested.
- *
  * @returns true if the tensors are broadcastable, false otherwise.
  */
 bool tensors_are_broadcastable_between(const Tensor& a, const Tensor& b);
@@ -195,12 +207,29 @@ void delinearize_index(
     size_t* out_indexes,
     const size_t out_indexes_len);
 
+/**
+ * Return the linear index for the broadcast_from tensor, given the indexes
+ * and number of dimensions of the broadcast_to tensor, and the shape and
+ * strides of the broadcast_from tensor.
+ *
+ * @param[in] indexes_broadcast_to The access indexes of broadcast_to tensor.
+ * @param[in] broadcast_to_ndim The number of dims of broadcast_to tensor.
+ * @param[in] broadcast_from_shape The shape of the broadcasted tensor.
+ * @param[in] broadcast_from_strides The strides of the broadcasted tensor.
+ * @returns The flattened index for the broadcast_from tensor.
+ */
+size_t linearize_access_indexes(
+    ArrayRef<size_t> indexes_broadcast_to,
+    ssize_t broadcast_to_ndim,
+    exec_aten::ArrayRef<Tensor::SizesType> broadcast_from_shape,
+    exec_aten::ArrayRef<Tensor::StridesType> broadcast_from_strides);
+
 /**
  * Return the linear index for the broadcast_from tensor, given the indexes of
  * the broadcast_to tensor and itself.
  *
- * @param[in] indexes The tensor access indexes of broadcast_to tensor
- * @param[in] broadcast_to_ndim The number of dims of the broadcasted shape.
+ * @param[in] indexes_broadcast_to The access indexes of broadcast_to tensor.
+ * @param[in] broadcast_to_ndim The number of dims of broadcast_to tensor.
 * @param[in] broadcast_from The tensor to be broadcasted.
 * @returns The flattened index for broadcast_from tensor.
 */
```
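
The kernel pairs the new shape/strides overload of `linearize_access_indexes` with contiguous strides synthesized from the target shape (the `target_out_strides` loop in the kernel above). A minimal sketch of that row-major stride computation with std types (`contiguous_strides` is a hypothetical helper, not part of this header):

```cpp
#include <cstdio>
#include <vector>

// Row-major (contiguous) strides: the last dimension has stride 1, and each
// earlier stride is the product of all later sizes.
static std::vector<long> contiguous_strides(const std::vector<long>& sizes) {
  std::vector<long> strides(sizes.size());
  long running = 1;
  for (int d = static_cast<int>(sizes.size()) - 1; d >= 0; --d) {
    strides[d] = running;
    running *= sizes[d];
  }
  return strides;
}

int main() {
  for (long s : contiguous_strides({2, 3, 4})) {
    printf("%ld ", s);  // prints: 12 4 1
  }
  printf("\n");
  return 0;
}
```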
