  * LICENSE file in the root directory of this source tree.
  */

+#include <executorch/backends/cadence/fusion_g3/operators/operators.h>
+#include <executorch/backends/cadence/fusion_g3/operators/xt_utils.h>
+
 #include <cstring>

 #include <xa_nnlib_kernels_api.h>

+#include <executorch/backends/cadence/fusion_g3/operators/xt_macros.h>
 #include <executorch/kernels/portable/cpu/util/copy_ops_util.h>
 #include <executorch/runtime/kernel/kernel_includes.h>

+using ::executorch::aten::ArrayRef;
 using ::executorch::aten::ScalarType;
 using ::executorch::aten::Tensor;
 using ::executorch::runtime::Error;
@@ -23,7 +28,6 @@ using ::executorch::runtime::KernelRuntimeContext;
  * updated to have support for below data types, these can be removed and
  * operator need to be updated accordingly
  */
-enum datatype { Ushort = 20, Uint = 23 };

 namespace cadence {
 namespace impl {
@@ -32,20 +36,22 @@ namespace native {

 Tensor& cat_out(
     KernelRuntimeContext& ctx,
-    exec_aten::ArrayRef<Tensor> tensors,
+    ArrayRef<Tensor> tensors,
     int64_t dim,
     Tensor& out) {
   if (dim < 0) {
     dim += out.dim();
   }

+  int kTensorDimensionLimit = executorch::runtime::kTensorDimensionLimit;
+
+#ifdef OP_ARG_CHECK
   ET_KERNEL_CHECK(
       ctx,
       torch::executor::check_cat_args(tensors, dim, out),
       InvalidArgument,
       out);

-  int kTensorDimensionLimit = executorch::runtime::kTensorDimensionLimit;
   Tensor::SizesType expected_out_size[kTensorDimensionLimit];
   size_t expected_out_dim = 0;
   torch::executor::get_cat_out_target_size(
@@ -57,14 +63,28 @@ Tensor& cat_out(
           out, {expected_out_size, expected_out_dim}) == Error::Ok,
       InvalidArgument,
       out);
+#endif
+  // Special handling when all inputs are 1D-empty tensors for aten
+  // consistency In that case, just return an 1D-empty tensor without checking
+  // dim
+  bool all_1d_empty = true;
+  for (size_t i = 0; i < tensors.size(); ++i) {
+    if (tensors[i].numel() != 0 || tensors[i].dim() != 1) {
+      all_1d_empty = false;
+      break;
+    }
+  }
+  if (all_1d_empty) {
+    return out;
+  }

   const signed char* inp_tensors[tensors.size()];
   const int* inp_tensors_shapes[tensors.size()];

   int inp_shapes_size[tensors.size()];

   int temp_sizes[tensors.size()][kTensorDimensionLimit];
-  exec_aten::ArrayRef<Tensor::SizesType> temp_size;
+  ArrayRef<Tensor::SizesType> temp_size;

   for (int i = 0; i < tensors.size(); i++) {
     inp_tensors[i] = tensors[i].const_data_ptr<signed char>();
@@ -79,88 +99,32 @@ Tensor& cat_out(

   signed char* out_data = out.mutable_data_ptr<signed char>();

-  const exec_aten::ArrayRef<Tensor::SizesType> out_size = out.sizes();
+  const ArrayRef<Tensor::SizesType> out_size = out.sizes();
   int out_shapes[kTensorDimensionLimit];
   for (int i = 0; i < out_size.size(); i++) // output shapes
   {
     out_shapes[i] = out_size[i];
   }

-  if (out.scalar_type() == ScalarType::Int) {
-    xa_nn_cat(
-        out_data,
-        out_shapes,
-        inp_tensors,
-        inp_tensors_shapes,
-        inp_shapes_size[0],
-        tensors.size(),
-        (int)dim,
-        sizeof(int));
-  } else if (out.scalar_type() == ScalarType::Short) {
-    xa_nn_cat(
-        out_data,
-        out_shapes,
-        inp_tensors,
-        inp_tensors_shapes,
-        inp_shapes_size[0],
-        tensors.size(),
-        (int)dim,
-        sizeof(short));
-  } else if (out.scalar_type() == ScalarType::Char) {
-    xa_nn_cat(
-        out_data,
-        out_shapes,
-        inp_tensors,
-        inp_tensors_shapes,
-        inp_shapes_size[0],
-        tensors.size(),
-        (int)dim,
-        sizeof(char));
-  } else if (out.scalar_type() == (ScalarType)Uint) {
-    xa_nn_cat(
-        out_data,
-        out_shapes,
-        inp_tensors,
-        inp_tensors_shapes,
-        inp_shapes_size[0],
-        tensors.size(),
-        (int)dim,
-        sizeof(int));
-  } else if (out.scalar_type() == (ScalarType)Ushort) {
-    xa_nn_cat(
+  if ((out.scalar_type() == ScalarType::Int) ||
+      (out.scalar_type() == ScalarType::Short) ||
+      (out.scalar_type() == ScalarType::Char) ||
+      (out.scalar_type() == ScalarType::UInt32) ||
+      (out.scalar_type() == ScalarType::UInt16) ||
+      (out.scalar_type() == ScalarType::Byte)) {
+    XT_KERNEL_CHECK(
+        ctx,
+        out,
+        xa_nn_cat,
         out_data,
         out_shapes,
         inp_tensors,
         inp_tensors_shapes,
         inp_shapes_size[0],
         tensors.size(),
         (int)dim,
-        sizeof(short));
-  } else if (out.scalar_type() == ScalarType::Byte) {
-    xa_nn_cat(
-        out_data,
-        out_shapes,
-        inp_tensors,
-        inp_tensors_shapes,
-        inp_shapes_size[0],
-        tensors.size(),
-        (int)dim,
-        sizeof(char));
-
+        get_element_size(out.scalar_type()));
   } else {
-    // Special handling when all inputs are 1D-empty tensors for aten
-    // consistency In that case, just return an 1D-empty tensor without checking
-    // dim
-    bool all_1d_empty = true;
-    for (size_t i = 0; i < tensors.size(); ++i) {
-      if (tensors[i].numel() != 0 || tensors[i].dim() != 1) {
-        all_1d_empty = false;
-        break;
-      }
-    }
-    if (all_1d_empty) {
-      return out;
-    }
     const size_t outer = executorch::runtime::getLeadingDims(out, dim);
     const size_t dim_stride = executorch::runtime::getTrailingDims(out, dim);
     const size_t ninputs = tensors.size();
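
The consolidated branch above relies on `cat` being width-agnostic: the kernel only moves bytes, so passing `get_element_size(out.scalar_type())` lets a single `xa_nn_cat` call (wrapped in `XT_KERNEL_CHECK`) cover every whitelisted integral dtype instead of one branch per type. A minimal standalone sketch of that idea, with a hypothetical `cat_rows` helper (not the xa_nnlib or ExecuTorch API) and concatenation along the leading dimension only:

```cpp
// Illustration only: concatenation never reinterprets values, so one byte-copy
// routine parameterized by element size serves int8/int16/int32 (and their
// unsigned counterparts) alike.
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <iostream>
#include <vector>

void cat_rows(
    signed char* out,
    const std::vector<const signed char*>& inputs,
    const std::vector<size_t>& rows_per_input,
    size_t row_elems, // elements per row, identical across inputs
    size_t elem_size) { // dtype width, the role of get_element_size(...)
  for (size_t i = 0; i < inputs.size(); ++i) {
    const size_t nbytes = rows_per_input[i] * row_elems * elem_size;
    std::memcpy(out, inputs[i], nbytes);
    out += nbytes;
  }
}

int main() {
  // The same routine concatenates int16 rows ...
  int16_t a[2][3] = {{1, 2, 3}, {4, 5, 6}};
  int16_t b[1][3] = {{7, 8, 9}};
  int16_t out16[3][3];
  cat_rows(
      reinterpret_cast<signed char*>(out16),
      {reinterpret_cast<const signed char*>(a),
       reinterpret_cast<const signed char*>(b)},
      {2, 1},
      3,
      sizeof(int16_t));

  // ... and int32 rows, with only elem_size changing.
  int32_t c[1][2] = {{10, 20}};
  int32_t d[1][2] = {{30, 40}};
  int32_t out32[2][2];
  cat_rows(
      reinterpret_cast<signed char*>(out32),
      {reinterpret_cast<const signed char*>(c),
       reinterpret_cast<const signed char*>(d)},
      {1, 1},
      2,
      sizeof(int32_t));

  std::cout << out16[2][1] << " " << out32[1][0] << "\n"; // prints "8 30"
  return 0;
}
```

This mirrors the trade-off the diff makes: type validation stays per-dtype (the `ScalarType` whitelist), while the data movement itself is shared across widths.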