Commit b771744

SS-JIA authored and facebook-github-bot committed
Update and fix cat
Summary: Clean up the implementation of `aten::cat_out`, and allow it to handle an input tensor list whose entries have different dtypes.

Reviewed By: manuelcandales

Differential Revision: D47600068

fbshipit-source-id: 57e5a5c6fdc4afc341847219c9d396067cefb81d
1 parent 150b051 commit b771744

4 files changed: +113 −143 lines

kernels/portable/cpu/op_cat.cpp (31 additions, 143 deletions)
@@ -2,6 +2,7 @@
 
 #include <cstring>
 
+#include <executorch/kernels/portable/cpu/util/copy_ops_util.h>
 #include <executorch/runtime/kernel/kernel_includes.h>
 
 namespace torch {
@@ -10,159 +11,46 @@ namespace native {
 
 using Tensor = exec_aten::Tensor;
 
-namespace {
-
-// TODO(T128954939): Move this to a common spot so all implementation of
-// this operator can share it. (e.g., DSP-specific)
-/// Asserts that the parameters are valid.
-void check_cat_out_args(
-    exec_aten::ArrayRef<Tensor> tensors,
-    int64_t dim,
-    Tensor& out) {
-  // Ensure the input tensors list is non-empty
-  ET_CHECK_MSG(tensors.size() > 0, "Cat expects non-empty tensor list");
-
-  // Ensure dim is in range. Use `out` as a proxy for all input tensors, since
-  // they will all need to have the same number of dimensions.
-  ET_CHECK_MSG(
-      dim >= 0 && dim < out.dim(),
-      "dim %" PRId64 " out of range [0,%zd)",
-      dim,
-      out.dim());
-
-  size_t cat_dim_size = 0;
-  for (size_t i = 0; i < tensors.size(); ++i) {
-    // All input dtypes must match the output dtype.
-    ET_CHECK_MSG(
-        tensors[i].scalar_type() == out.scalar_type(),
-        "tensors[%zu] dtype %hhd != out dtype %hhd",
-        i,
-        tensors[i].scalar_type(),
-        out.scalar_type());
-
-    // Empty tensors have no shape constraints.
-    if (tensors[i].numel() == 0) {
-      continue;
-    }
-
-    // All input tensors must have the same number of dimensions as the output.
-    ET_CHECK_MSG(
-        tensors[i].dim() == out.dim(),
-        "tensors[%zu].dim() %zd != out.dim() %zd",
-        i,
-        tensors[i].dim(),
-        out.dim());
-
-    // "All tensors must either have the same shape (except in the concatenating
-    // dimension) or be empty."
-    // https://pytorch.org/docs/stable/generated/torch.cat.html
-    for (size_t d = 0; d < tensors[i].dim(); ++d) {
-      if (d != dim) {
-        ET_CHECK_MSG(
-            tensors[i].size(d) == out.size(d),
-            "tensors[%zu].size(%zu) %zd != out.size(%zu) %zd",
-            i,
-            d,
-            tensors[i].size(d),
-            d,
-            out.size(d));
-      }
-    }
-
-    cat_dim_size += tensors[i].size(dim);
-  }
-
-  // The size of the cat dimension of the output should be the sum of the
-  // input cat dimension sizes.
-  ET_CHECK_MSG(
-      out.size(dim) == cat_dim_size,
-      "out.size(%" PRId64 ") %zd != %zu",
-      dim,
-      out.size(dim),
-      cat_dim_size);
-}
-
-void resize_out_tensor(
-    exec_aten::ArrayRef<Tensor>& tensors,
-    int64_t dim,
-    Tensor& out) {
-  Tensor::SizesType expected_output_size[kTensorDimensionLimit];
-
-  // Some elements of expected_output_size may not be set during the loop
-  // over all the tensors. Set all of them ahead of time here so that none are
-  // unset by the end of that loop
-  for (size_t i = 0; i < out.dim(); ++i) {
-    expected_output_size[i] = out.size(i);
-  }
-
-  size_t cat_dim_size = 0;
-  for (size_t i = 0; i < tensors.size(); ++i) {
-    // Empty tensors have no shape constraints.
-    if (tensors[i].numel() == 0) {
-      continue;
-    }
-    for (size_t d = 0; d < tensors[i].dim(); ++d) {
-      if (d != dim) {
-        expected_output_size[d] = tensors[i].size(d);
-      }
-    }
-    cat_dim_size += tensors[i].size(dim);
-  }
-
-  expected_output_size[dim] = cat_dim_size;
-
-  ArrayRef<Tensor::SizesType> output_size{
-      expected_output_size, static_cast<size_t>(out.dim())};
-
-  torch::executor::Error err = resize_tensor(out, output_size);
-  ET_CHECK_MSG(
-      err == torch::executor::Error::Ok,
-      "Failed to resize out Tensor in cat_out");
-}
-} // namespace
-
-/// cat.out(Tensor[] tensors, int dim=0, *, Tensor(a!) out) -> Tensor(a!)
 Tensor& cat_out(
     RuntimeContext& context,
     exec_aten::ArrayRef<Tensor> tensors,
     int64_t dim,
     Tensor& out) {
-  // Support python-style negative indexing. E.g., for the shape {2, 3, 4},
-  // dim = -1 would refer to dim[2], dim = -2 would refer to dim[1], and so on.
   if (dim < 0) {
     dim += out.dim();
   }
 
-  resize_out_tensor(tensors, dim, out);
-
-  // Assert that the args are valid.
-  check_cat_out_args(tensors, dim, out);
-
-  size_t cat_dim = out.size(dim);
-
-  size_t leading_dims = getLeadingDims(out, dim);
-  size_t trailing_dims = getTrailingDims(out, dim);
-
-  size_t element_size = out.element_size();
-  size_t step = cat_dim * trailing_dims * element_size;
-
-  char* out_data = out.data_ptr<char>();
-  for (size_t i = 0, e = tensors.size(); i < e; ++i) {
-    if (tensors[i].numel() == 0) {
-      // Ignore empty tensor.
-      continue;
-    }
-    size_t num_bytes = tensors[i].size(dim) * trailing_dims * element_size;
-
-    const char* src = tensors[i].data_ptr<char>();
-    char* dest = out_data;
-    for (size_t j = 0; j < leading_dims; ++j) {
-      memcpy(dest, src, num_bytes);
-      dest += step;
-      src += num_bytes;
+  check_cat_args(tensors, dim, out);
+
+  Tensor::SizesType expected_out_size[kTensorDimensionLimit];
+  size_t expected_out_dim = 0;
+  get_cat_out_target_size(tensors, dim, expected_out_size, &expected_out_dim);
+  ET_CHECK(
+      resize_tensor(out, {expected_out_size, expected_out_dim}) == Error::Ok);
+
+  const size_t outer = getLeadingDims(out, dim);
+  const size_t dim_stride = getTrailingDims(out, dim);
+  const size_t ninputs = tensors.size();
+
+  const auto out_type = out.scalar_type();
+  ET_SWITCH_REAL_TYPES_AND(Bool, out_type, ctx, "cat", CTYPE_OUT, [&] {
+    CTYPE_OUT* out_ptr = out.mutable_data_ptr<CTYPE_OUT>();
+    for (size_t i = 0; i < outer; ++i) {
+      for (size_t j = 0; j < ninputs; ++j) {
+        const auto in_type = tensors[j].scalar_type();
+        ET_SWITCH_REAL_TYPES_AND(Bool, in_type, ctx, "cat", CTYPE_IN, [&] {
+          size_t inner = tensors[j].size(dim) * dim_stride;
+          const CTYPE_IN* const in_ptr =
+              tensors[j].const_data_ptr<CTYPE_IN>() + i * inner;
+
+          for (size_t k = 0; k < inner; ++k) {
+            out_ptr[k] = static_cast<CTYPE_OUT>(in_ptr[k]);
+          }
+          out_ptr += inner;
+        });
+      }
     }
-    out_data += num_bytes;
-  }
+  });
 
   return out;
 }
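As an aside, the copy pattern the rewritten kernel dispatches to can be pictured with plain arrays: for each index over the leading dimensions, every input contributes a contiguous block of size(dim) * trailing-dims elements, and each element is cast to the output dtype. The sketch below is a minimal standalone illustration of that indexing only; it is not ExecuTorch code, and the shapes, values, and names (a, b, outer, dim_stride) are invented for the example.

// Standalone sketch (not ExecuTorch code): the same outer/inner copy pattern
// used by the rewritten kernel, with two inputs of different element types
// (int32_t and float) cast element-wise to the output type.
#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
  // Concatenate shapes {2, 2} and {2, 3} along dim = 1.
  // outer = product of sizes before dim = 2; dim_stride = product after = 1.
  const size_t outer = 2;
  const size_t dim_stride = 1;

  const std::vector<int32_t> a = {1, 2, 3, 4};                        // {2, 2}
  const std::vector<float> b = {0.5f, 1.5f, 2.5f, 3.5f, 4.5f, 5.5f};  // {2, 3}
  const size_t a_dim = 2, b_dim = 3;  // sizes along the cat dimension

  std::vector<float> out(outer * (a_dim + b_dim));  // output shape {2, 5}
  float* out_ptr = out.data();

  for (size_t i = 0; i < outer; ++i) {
    // Input 0: copy its slice for this outer index, casting int32 -> float.
    const size_t inner_a = a_dim * dim_stride;
    const int32_t* a_ptr = a.data() + i * inner_a;
    for (size_t k = 0; k < inner_a; ++k) {
      out_ptr[k] = static_cast<float>(a_ptr[k]);
    }
    out_ptr += inner_a;

    // Input 1: already float, copied the same way.
    const size_t inner_b = b_dim * dim_stride;
    const float* b_ptr = b.data() + i * inner_b;
    for (size_t k = 0; k < inner_b; ++k) {
      out_ptr[k] = b_ptr[k];
    }
    out_ptr += inner_b;
  }

  for (float v : out) {
    std::printf("%g ", v);
  }
  std::printf("\n");  // Prints: 1 2 0.5 1.5 2.5 3 4 3.5 4.5 5.5
  return 0;
}

With both inputs laid out row-major, the nested loops produce the same result as concatenating along dim 1 after casting the int32 input to float, which is the behavior the dtype-dispatching loop above enables.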

kernels/portable/cpu/targets.bzl (3 additions, 0 deletions)
@@ -169,6 +169,9 @@ _ATEN_OPS = (
     ),
     op_target(
         name = "op_cat",
+        deps = [
+            "//executorch/kernels/portable/cpu/util:copy_ops_util",
+        ],
     ),
     op_target(
         name = "op_clamp",

kernels/portable/cpu/util/copy_ops_util.cpp (68 additions, 0 deletions)
@@ -12,6 +12,74 @@ namespace executor {
 
 using Tensor = exec_aten::Tensor;
 
+void check_cat_args(
+    exec_aten::ArrayRef<Tensor> tensors,
+    int64_t dim,
+    Tensor& out) {
+  // Ensure the input tensors list is non-empty
+  ET_CHECK(tensors.size() > 0);
+
+  // Find the first non-empty tensor in the list to use as a reference
+  size_t ref_i = 0;
+  for (size_t i = 0; i < tensors.size(); ++i) {
+    if (tensors[i].numel() > 0) {
+      ref_i = i;
+      break;
+    }
+  }
+
+  // "All tensors must either have the same shape (except in the concatenating
+  // dimension) or be empty."
+  // https://pytorch.org/docs/stable/generated/torch.cat.html
+  for (size_t i = 0; i < tensors.size(); ++i) {
+    // All input dtypes must be castable to the output dtype.
+    ET_CHECK(canCast(tensors[i].scalar_type(), out.scalar_type()));
+
+    // Empty tensors have no shape constraints.
+    if (tensors[i].numel() == 0) {
+      continue;
+    }
+
+    // All input tensors must have the same number of dimensions.
+    ET_CHECK(tensors[i].dim() == tensors[ref_i].dim());
+
+    for (size_t d = 0; d < tensors[i].dim(); ++d) {
+      if (d != dim) {
+        ET_CHECK(tensors[i].size(d) == tensors[ref_i].size(d));
+      }
+    }
+  }
+
+  // Ensure dim is in range.
+  ET_CHECK(dim >= 0 && dim < tensors[ref_i].dim());
+}
+
+void get_cat_out_target_size(
+    exec_aten::ArrayRef<Tensor> tensors,
+    int64_t dim,
+    Tensor::SizesType* out_sizes,
+    size_t* out_ndim) {
+  // Find the last non-empty tensor in the list to use as a reference
+  size_t ref_i = 0;
+  size_t cat_dim_size = 0;
+  for (size_t i = 0; i < tensors.size(); ++i) {
+    if (tensors[i].numel() > 0) {
+      cat_dim_size += tensors[i].size(dim);
+      ref_i = i;
+    }
+  }
+
+  *out_ndim = tensors[ref_i].dim();
+
+  for (size_t d = 0; d < *out_ndim; ++d) {
+    if (d != dim) {
+      out_sizes[d] = tensors[ref_i].size(d);
+    } else {
+      out_sizes[d] = cat_dim_size;
+    }
+  }
+}
+
 void check_stack_args(
     exec_aten::ArrayRef<Tensor> tensors,
    int64_t dim,
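For reference, the output-shape rule that get_cat_out_target_size encodes (every dimension copies the reference non-empty input's size, except dim, which gets the sum of the non-empty inputs' sizes along dim) can be sketched with ordinary vectors. The helper name cat_target_size and the shapes below are hypothetical, used only to illustrate the rule.

// Standalone sketch of the cat output-shape rule, using plain std::vector
// shapes instead of ExecuTorch tensors. Inputs with numel == 0 are skipped,
// mirroring get_cat_out_target_size above.
#include <cstddef>
#include <cstdio>
#include <vector>

std::vector<size_t> cat_target_size(
    const std::vector<std::vector<size_t>>& shapes, size_t dim) {
  size_t ref_i = 0;
  size_t cat_dim_size = 0;
  for (size_t i = 0; i < shapes.size(); ++i) {
    size_t numel = 1;
    for (size_t s : shapes[i]) {
      numel *= s;
    }
    if (numel > 0) {
      cat_dim_size += shapes[i][dim];
      ref_i = i;  // last non-empty input is the reference
    }
  }

  std::vector<size_t> out = shapes[ref_i];  // copy non-cat dimensions
  out[dim] = cat_dim_size;                  // cat dimension is the sum
  return out;
}

int main() {
  // {2, 3, 4} cat {2, 5, 4} cat {2, 0, 4} along dim 1 -> {2, 8, 4}.
  const auto out =
      cat_target_size({{2, 3, 4}, {2, 5, 4}, {2, 0, 4}}, /*dim=*/1);
  for (size_t s : out) {
    std::printf("%zu ", s);
  }
  std::printf("\n");  // Prints: 2 8 4
  return 0;
}

Running it prints 2 8 4: the empty {2, 0, 4} input is skipped and contributes nothing to the cat dimension, matching the empty-tensor handling in the checks above.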

kernels/portable/cpu/util/copy_ops_util.h (11 additions, 0 deletions)
@@ -5,6 +5,17 @@
 namespace torch {
 namespace executor {
 
+void check_cat_args(
+    exec_aten::ArrayRef<Tensor> tensors,
+    int64_t dim,
+    Tensor& out);
+
+void get_cat_out_target_size(
+    exec_aten::ArrayRef<Tensor> tensors,
+    int64_t dim,
+    Tensor::SizesType* out_sizes,
+    size_t* out_ndim);
+
 void check_stack_args(
     exec_aten::ArrayRef<Tensor> tensors,
     int64_t dim,
