 #include <cstring>

+#include <executorch/kernels/portable/cpu/util/copy_ops_util.h>
 #include <executorch/runtime/kernel/kernel_includes.h>

 namespace torch {
@@ -10,127 +11,48 @@ namespace native {

 using Tensor = exec_aten::Tensor;

-namespace {
-
-// TODO(gasoonjia): Move this to a common spot so all implementation of
-// this operator can share it. (e.g., DSP-specific)
-/// Asserts that the parameters are valid.
-void check_stack_out_args(
-    exec_aten::ArrayRef<Tensor> tensors,
-    int64_t dim,
-    Tensor& out) {
-  // Stack expects non-empty tensor list
-  ET_CHECK_MSG(tensors.size() > 0, "Stack expects non-empty tensor list");
-
-  // Ensure dim is in range. Use `out` as a proxy for all input tensors, since
-  // they will all need to have the same number of dimensions besides the dim
-  // one.
-  ET_CHECK_MSG(
-      dim >= 0 && dim < out.dim(),
-      "dim %" PRId64 " out of range [0,%zd)",
-      dim,
-      out.dim());
-
-  for (size_t i = 0; i < tensors.size(); i++) {
-    // All input dtypes must match the output dtype.
-    ET_CHECK_MSG(
-        tensors[i].scalar_type() == out.scalar_type(),
-        "tensors[%zu] dtype %hhd != out dtype %hhd",
-        i,
-        tensors[i].scalar_type(),
-        out.scalar_type());
-
-    // All input tensors need to be of the same size
-    // Also, since we create a new axis in output for stacking, the output.dim()
-    // should be one larger than input.dim()
-    // https://pytorch.org/docs/stable/generated/torch.stack.html
-    ET_CHECK_MSG(
-        tensors[i].dim() == out.dim() - 1,
-        "tensors[%zu].dim() %zd != out.dim() - 1 %zd",
-        i,
-        tensors[i].dim(),
-        out.dim() - 1);
-
-    // The size of each input tensor should be the same. Here we use `out` as
-    // proxy for comparsion. Also, the size of output tensor should follow these
-    // rules:
-    // - For any input tensor, its size(i) == output.size(i) if i < dim, and its
-    //   size(i) == output.size(i+1) if i >= dim
-    // - For the cat dimension (output[dim]), its size should be the number of
-    //   input tensors
-    for (size_t d = 0; d < tensors[i].dim(); d++) {
-      if (d < dim) {
-        ET_CHECK_MSG(
-            tensors[i].size(d) == out.size(d),
-            "tensors[%zu].size(%zu) %zd != out.size(%zu) %zd | dim = %" PRId64,
-            i,
-            d,
-            tensors[i].size(d),
-            d,
-            out.size(d),
-            dim);
-      } else {
-        ET_CHECK_MSG(
-            tensors[i].size(d) == out.size(d + 1),
-            "tensors[%zu].size(%zu) %zd != out.size(%zu) %zd | dim = %" PRId64,
-            i,
-            d,
-            tensors[i].size(d),
-            d + 1,
-            out.size(d + 1),
-            dim);
-      }
-    }
-  }
-
-  // The size of the stack dimension of the output should be the number of
-  // input tensors
-  ET_CHECK_MSG(
-      out.size(dim) == tensors.size(),
-      "out.size(%" PRId64 ") %zd != number of input tensors %zu",
-      dim,
-      out.size(dim),
-      tensors.size());
-}
-} // namespace
-
-/// stack.out(Tensor[] tensors, int dim=0, *, Tensor(a!) out) -> Tensor(a!)
 Tensor& stack_out(
-    RuntimeContext& context,
+    RuntimeContext& ctx,
     exec_aten::ArrayRef<Tensor> tensors,
     int64_t dim,
     Tensor& out) {
-  (void)context;
-  // Support python-style negative indexing. E.g., for the shape {2, 3, 4},
-  // dim = -1 would refer to dim[2], dim = -2 would refer to dim[1], and so on.
+  (void)ctx;
+
   if (dim < 0) {
     dim += out.dim();
  }

-  // Assert that the args are valid.
-  check_stack_out_args(tensors, dim, out);
-
-  // If one tensor is empty tensor, all tensors are empty since they share same
-  // size. Under that, no need do anything. Just return the out.
-  if (tensors[0].numel() == 0) {
-    return out;
-  }
-
-  size_t leading_dim = getLeadingDims(out, dim);
-  size_t trailing_dim = getTrailingDims(out, dim);
-  size_t num_of_tensors = tensors.size();
-
-  size_t chunk_size = trailing_dim * out.element_size();
-
-  char* dst_ptr = out.data_ptr<char>();
-
-  for (int i = 0; i < leading_dim; i++) {
-    for (int j = 0; j < num_of_tensors; j++) {
-      char* src_ptr = tensors[j].data_ptr<char>() + chunk_size * i;
-      memcpy(dst_ptr, src_ptr, chunk_size);
-      dst_ptr += chunk_size;
+  check_stack_args(tensors, dim, out);
+
+  Tensor::SizesType expected_out_size[kTensorDimensionLimit];
+  size_t expected_out_dim = 0;
+  get_stack_out_target_size(tensors, dim, expected_out_size, &expected_out_dim);
+  ET_CHECK(
+      resize_tensor(out, {expected_out_size, expected_out_dim}) == Error::Ok);
+
+  const size_t outer = getLeadingDims(out, dim);
+  const size_t inner = getTrailingDims(out, dim);
+  const size_t ninputs = tensors.size();
+
+  const auto out_type = out.scalar_type();
+  ET_SWITCH_REAL_TYPES_AND(Bool, out_type, ctx, "stack", CTYPE_OUT, [&] {
+    CTYPE_OUT* out_ptr = out.mutable_data_ptr<CTYPE_OUT>();
+    for (size_t i = 0; i < outer; ++i) {
+      for (size_t j = 0; j < ninputs; ++j) {
+        const auto in_type = tensors[j].scalar_type();
+        ET_SWITCH_REAL_TYPES_AND(Bool, in_type, ctx, "stack", CTYPE_IN, [&] {
+          const CTYPE_IN* const in_ptr =
+              tensors[j].const_data_ptr<CTYPE_IN>() + i * inner;

+          for (size_t k = 0; k < inner; ++k) {
+            out_ptr[k] = static_cast<CTYPE_OUT>(in_ptr[k]);
+          }
+          out_ptr += inner;
+        });
+      }
     }
-  }
+  });
+
   return out;
 }
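For intuition, here is a minimal standalone sketch of the leading/trailing-dims copy pattern that both the removed memcpy loop and the new type-dispatched loop follow: each input is viewed as `outer` leading blocks of `inner` trailing elements, and output block (i, j) is filled from input j's block i. This uses plain std::vector and float rather than the ExecuTorch Tensor API; the helper name and shapes are illustrative only, not part of the kernel.

#include <cstdio>
#include <cstring>
#include <vector>

// Stack equally-shaped inputs along `dim` by interleaving trailing blocks:
// output block (i, j) comes from input j's block i.
std::vector<float> stack_along_dim(
    const std::vector<std::vector<float>>& inputs,
    const std::vector<size_t>& in_sizes,
    size_t dim) {
  size_t outer = 1, inner = 1;
  for (size_t d = 0; d < dim; ++d) {
    outer *= in_sizes[d]; // product of dims before the stack dim
  }
  for (size_t d = dim; d < in_sizes.size(); ++d) {
    inner *= in_sizes[d]; // product of dims from the stack dim onward
  }

  const size_t ninputs = inputs.size();
  std::vector<float> out(outer * ninputs * inner);
  float* out_ptr = out.data();
  for (size_t i = 0; i < outer; ++i) {
    for (size_t j = 0; j < ninputs; ++j) {
      const float* in_ptr = inputs[j].data() + i * inner;
      std::memcpy(out_ptr, in_ptr, inner * sizeof(float));
      out_ptr += inner;
    }
  }
  return out;
}

int main() {
  // Two 2x3 inputs stacked along dim=1 -> output shape 2x2x3.
  std::vector<std::vector<float>> inputs = {
      {0, 1, 2, 3, 4, 5}, {10, 11, 12, 13, 14, 15}};
  std::vector<float> out = stack_along_dim(inputs, {2, 3}, /*dim=*/1);
  for (float v : out) {
    std::printf("%g ", v); // 0 1 2 10 11 12 3 4 5 13 14 15
  }
  std::printf("\n");
  return 0;
}

Compared with this sketch, the new kernel replaces the raw memcpy with an element-wise static_cast inside nested ET_SWITCH dispatches, so inputs whose dtype differs from the output dtype can still be stacked, and it resizes `out` to the expected shape before copying.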