6
6
* LICENSE file in the root directory of this source tree.
7
7
*/
8
8
9
- #include < cstring>
10
-
11
9
#include < executorch/kernels/portable/cpu/util/transpose_util.h>
12
10
#include < executorch/runtime/kernel/kernel_includes.h>
13
- #include < executorch/runtime/platform/assert.h >
11
+ #include < cstring >
14
12
15
13
namespace torch {
16
14
namespace executor {
@@ -20,55 +18,47 @@ using SizesType = exec_aten::SizesType;
20
18
using StridesType = exec_aten::StridesType;
21
19
using Tensor = exec_aten::Tensor;
22
20
23
- namespace {
24
-
25
- /* *
26
- * Verifies preconditions of t_copy_int_out
27
- */
28
- void check_preconditions (const Tensor& a, Tensor& out) {
29
- auto a_dim = a.dim ();
30
- ET_CHECK_MSG (
31
- a_dim >= 0 && a_dim <= 2 ,
32
- " Rank of tensor a has to be <=2 but received tensor of rank : %zd.:" ,
33
- a_dim);
34
- if (a_dim < 2 ) {
35
- ET_CHECK_SAME_SHAPE_AND_DTYPE2 (a, out);
36
- } else {
37
- ET_CHECK_SAME_DTYPE2 (a, out);
38
- ET_CHECK_MSG (
39
- (a.sizes ()[0 ] == out.sizes ()[1 ]) && (a.sizes ()[1 ] == out.sizes ()[0 ]),
40
- " Input tensor and output tensor shapes do not support transposing" );
41
- ET_CHECK_MSG (out.dim () == 2 , " Output tensor must have same dim (2)" );
42
- }
43
- }
44
-
45
- } // namespace
46
-
47
21
/* *
48
22
* Expects input to be <= 2-D tensor and transposes dimensions 0 and 1.
49
23
* 0-D and 1-D tensors are returned as is. When input is a 2-D tensor this
50
24
* is equivalent to transpose(input, 0, 1).
51
25
* t_copy.out(Tensor self, Tensor(a!) out)
52
26
*/
53
- Tensor& t_copy_out (RuntimeContext& ctx, const Tensor& a , Tensor& out) {
27
+ Tensor& t_copy_out (RuntimeContext& ctx, const Tensor& in , Tensor& out) {
54
28
(void )ctx;
55
- check_preconditions (a, out);
56
- int dim_1 = a.sizes ().size () == 2 ? 1 : 0 ;
57
- #define TRANSPOSE_TENSORS (ctype, dtype ) \
58
- case ScalarType::dtype: \
59
- transpose_tensors<ctype>(a, 0 , dim_1, out); \
60
- break ;
61
29
62
- switch (a.scalar_type ()) {
63
- ET_FORALL_SCALAR_TYPES (TRANSPOSE_TENSORS)
64
- default :
65
- ET_CHECK_MSG (
66
- false ,
67
- " Unhandled dtype %" PRId8,
68
- static_cast <int8_t >(a.scalar_type ()));
30
+ ET_KERNEL_CHECK (ctx, check_t_copy_args (in, out), InvalidArgument, out);
31
+
32
+ ScalarType in_type = in.scalar_type ();
33
+
34
+ if (in.dim () < 2 ) {
35
+ // Resize for dynamic shape
36
+ ET_KERNEL_CHECK (
37
+ ctx, resize_tensor (out, in.sizes ()) == Error::Ok, InvalidArgument, out);
38
+
39
+ ET_SWITCH_ALL_TYPES (in_type, ctx, __func__, CTYPE, [&]() {
40
+ const CTYPE* in_data = in.const_data_ptr <CTYPE>();
41
+ CTYPE* out_data = out.mutable_data_ptr <CTYPE>();
42
+ memcpy (out_data, in_data, in.nbytes ());
43
+ });
44
+
45
+ return out;
69
46
}
70
47
71
- #undef TRANSPOSE_TENSORS
48
+ Tensor::SizesType expected_out_size[kTensorDimensionLimit ];
49
+ size_t expected_out_dim = 0 ;
50
+ get_transpose_out_target_size (in, 1 , 0 , expected_out_size, &expected_out_dim);
51
+
52
+ // Resize for dynamic shape
53
+ ET_KERNEL_CHECK (
54
+ ctx,
55
+ resize_tensor (out, {expected_out_size, expected_out_dim}) == Error::Ok,
56
+ InvalidArgument,
57
+ out);
58
+
59
+ ET_SWITCH_ALL_TYPES (in_type, ctx, __func__, CTYPE, [&] {
60
+ transpose_tensors<CTYPE>(in, 1 , 0 , out);
61
+ });
72
62
73
63
return out;
74
64
}
0 commit comments