 #include <cstring>
 
+#include <executorch/kernels/portable/cpu/util/copy_ops_util.h>
 #include <executorch/runtime/kernel/kernel_includes.h>
 
 namespace torch {
@@ -10,159 +11,46 @@ namespace native {
 
 using Tensor = exec_aten::Tensor;
 
-namespace {
-
-// TODO(T128954939): Move this to a common spot so all implementation of
-// this operator can share it. (e.g., DSP-specific)
-/// Asserts that the parameters are valid.
-void check_cat_out_args(
-    exec_aten::ArrayRef<Tensor> tensors,
-    int64_t dim,
-    Tensor& out) {
-  // Ensure the input tensors list is non-empty
-  ET_CHECK_MSG(tensors.size() > 0, "Cat expects non-empty tensor list");
-
-  // Ensure dim is in range. Use `out` as a proxy for all input tensors, since
-  // they will all need to have the same number of dimensions.
-  ET_CHECK_MSG(
-      dim >= 0 && dim < out.dim(),
-      "dim %" PRId64 " out of range [0,%zd)",
-      dim,
-      out.dim());
-
-  size_t cat_dim_size = 0;
-  for (size_t i = 0; i < tensors.size(); ++i) {
-    // All input dtypes must match the output dtype.
-    ET_CHECK_MSG(
-        tensors[i].scalar_type() == out.scalar_type(),
-        "tensors[%zu] dtype %hhd != out dtype %hhd",
-        i,
-        tensors[i].scalar_type(),
-        out.scalar_type());
-
-    // Empty tensors have no shape constraints.
-    if (tensors[i].numel() == 0) {
-      continue;
-    }
-
-    // All input tensors must have the same number of dimensions as the output.
-    ET_CHECK_MSG(
-        tensors[i].dim() == out.dim(),
-        "tensors[%zu].dim() %zd != out.dim() %zd",
-        i,
-        tensors[i].dim(),
-        out.dim());
-
-    // "All tensors must either have the same shape (except in the concatenating
-    // dimension) or be empty."
-    // https://pytorch.org/docs/stable/generated/torch.cat.html
-    for (size_t d = 0; d < tensors[i].dim(); ++d) {
-      if (d != dim) {
-        ET_CHECK_MSG(
-            tensors[i].size(d) == out.size(d),
-            "tensors[%zu].size(%zu) %zd != out.size(%zu) %zd",
-            i,
-            d,
-            tensors[i].size(d),
-            d,
-            out.size(d));
-      }
-    }
-
-    cat_dim_size += tensors[i].size(dim);
-  }
-
-  // The size of the cat dimension of the output should be the sum of the
-  // input cat dimension sizes.
-  ET_CHECK_MSG(
-      out.size(dim) == cat_dim_size,
-      "out.size(%" PRId64 ") %zd != %zu",
-      dim,
-      out.size(dim),
-      cat_dim_size);
-}
-
-void resize_out_tensor(
-    exec_aten::ArrayRef<Tensor>& tensors,
-    int64_t dim,
-    Tensor& out) {
-  Tensor::SizesType expected_output_size[kTensorDimensionLimit];
-
-  // Some elements of expected_output_size may not be set during the loop
-  // over all the tensors. Set all of them ahead of time here so that none are
-  // unset by the end of that loop
-  for (size_t i = 0; i < out.dim(); ++i) {
-    expected_output_size[i] = out.size(i);
-  }
-
-  size_t cat_dim_size = 0;
-  for (size_t i = 0; i < tensors.size(); ++i) {
-    // Empty tensors have no shape constraints.
-    if (tensors[i].numel() == 0) {
-      continue;
-    }
-    for (size_t d = 0; d < tensors[i].dim(); ++d) {
-      if (d != dim) {
-        expected_output_size[d] = tensors[i].size(d);
-      }
-    }
-    cat_dim_size += tensors[i].size(dim);
-  }
-
-  expected_output_size[dim] = cat_dim_size;
-
-  ArrayRef<Tensor::SizesType> output_size{
-      expected_output_size, static_cast<size_t>(out.dim())};
-
-  torch::executor::Error err = resize_tensor(out, output_size);
-  ET_CHECK_MSG(
-      err == torch::executor::Error::Ok,
-      "Failed to resize out Tensor in cat_out");
-}
-} // namespace
-
-/// cat.out(Tensor[] tensors, int dim=0, *, Tensor(a!) out) -> Tensor(a!)
 Tensor& cat_out(
     RuntimeContext& context,
     exec_aten::ArrayRef<Tensor> tensors,
     int64_t dim,
     Tensor& out) {
-  // Support python-style negative indexing. E.g., for the shape {2, 3, 4},
-  // dim = -1 would refer to dim[2], dim = -2 would refer to dim[1], and so on.
   if (dim < 0) {
     dim += out.dim();
   }
 
-  resize_out_tensor(tensors, dim, out);
-
-  // Assert that the args are valid.
-  check_cat_out_args(tensors, dim, out);
-
-  size_t cat_dim = out.size(dim);
-
-  size_t leading_dims = getLeadingDims(out, dim);
-  size_t trailing_dims = getTrailingDims(out, dim);
-
-  size_t element_size = out.element_size();
-  size_t step = cat_dim * trailing_dims * element_size;
-
-  char* out_data = out.data_ptr<char>();
-  for (size_t i = 0, e = tensors.size(); i < e; ++i) {
-    if (tensors[i].numel() == 0) {
-      // Ignore empty tensor.
-      continue;
-    }
-    size_t num_bytes = tensors[i].size(dim) * trailing_dims * element_size;
-
-    const char* src = tensors[i].data_ptr<char>();
-    char* dest = out_data;
-    for (size_t j = 0; j < leading_dims; ++j) {
-      memcpy(dest, src, num_bytes);
-      dest += step;
-      src += num_bytes;
+  check_cat_args(tensors, dim, out);
+
+  Tensor::SizesType expected_out_size[kTensorDimensionLimit];
+  size_t expected_out_dim = 0;
+  get_cat_out_target_size(tensors, dim, expected_out_size, &expected_out_dim);
+  ET_CHECK(
+      resize_tensor(out, {expected_out_size, expected_out_dim}) == Error::Ok);
+
+  const size_t outer = getLeadingDims(out, dim);
+  const size_t dim_stride = getTrailingDims(out, dim);
+  const size_t ninputs = tensors.size();
+
+  const auto out_type = out.scalar_type();
+  ET_SWITCH_REAL_TYPES_AND(Bool, out_type, ctx, "cat", CTYPE_OUT, [&] {
+    CTYPE_OUT* out_ptr = out.mutable_data_ptr<CTYPE_OUT>();
+    for (size_t i = 0; i < outer; ++i) {
+      for (size_t j = 0; j < ninputs; ++j) {
+        const auto in_type = tensors[j].scalar_type();
+        ET_SWITCH_REAL_TYPES_AND(Bool, in_type, ctx, "cat", CTYPE_IN, [&] {
+          size_t inner = tensors[j].size(dim) * dim_stride;
+          const CTYPE_IN* const in_ptr =
+              tensors[j].const_data_ptr<CTYPE_IN>() + i * inner;
+
+          for (size_t k = 0; k < inner; ++k) {
+            out_ptr[k] = static_cast<CTYPE_OUT>(in_ptr[k]);
+          }
+          out_ptr += inner;
+        });
+      }
     }
-    out_data += num_bytes;
-  }
+  });
 
   return out;
 }
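Note: the rewritten kernel replaces the old byte-wise memcpy loop with a per-element, dtype-dispatched copy, so the inputs no longer have to share the output's scalar type. As a rough standalone sketch of the traversal (plain std::vector instead of the ExecuTorch Tensor API; cat_copy is a hypothetical helper written for illustration, not part of the library), each leading-dimension slice of the output receives each input's size(dim) * trailing_dims elements in turn, cast element by element:

#include <cstddef>
#include <cstdint>
#include <iostream>
#include <vector>

// Hypothetical illustration of the copy pattern only, not ExecuTorch code.
// Each input is flattened to {outer, inner_j}, where
// inner_j = size(dim) * trailing_dims; elements are cast from int32_t to
// float, mirroring the CTYPE_IN -> CTYPE_OUT cast in the kernel.
std::vector<float> cat_copy(
    const std::vector<std::vector<int32_t>>& inputs, size_t outer) {
  std::vector<float> out;
  for (size_t i = 0; i < outer; ++i) {   // loop over leading dims
    for (const auto& in : inputs) {      // append each input's i-th slice
      const size_t inner = in.size() / outer;
      for (size_t k = 0; k < inner; ++k) {
        out.push_back(static_cast<float>(in[i * inner + k]));
      }
    }
  }
  return out;
}

int main() {
  // Two 2x2 tensors concatenated along dim 1 produce a 2x4 result.
  std::vector<std::vector<int32_t>> inputs = {{1, 2, 3, 4}, {5, 6, 7, 8}};
  for (float v : cat_copy(inputs, /*outer=*/2)) {
    std::cout << v << ' ';  // prints: 1 2 5 6 3 4 7 8
  }
  std::cout << '\n';
  return 0;
}

The out_ptr += inner advance after each input in the kernel corresponds to the append order above: within each leading-dimension block, the output's cat dimension receives the inputs' slices one after another.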