6
6
* LICENSE file in the root directory of this source tree.
7
7
*/
8
8
9
- #include < cstring>
10
-
11
9
#include < executorch/kernels/portable/cpu/util/transpose_util.h>
12
10
#include < executorch/runtime/kernel/kernel_includes.h>
13
- #include < executorch/runtime/platform/assert.h>
14
11
15
12
namespace torch {
16
13
namespace executor {
@@ -20,43 +17,6 @@ using SizesType = exec_aten::SizesType;
20
17
using StridesType = exec_aten::StridesType;
21
18
using Tensor = exec_aten::Tensor;
22
19
23
- namespace {
24
-
25
- /* *
26
- * Verifies preconditions of transpose_copy_int_out
27
- */
28
- void check_preconditions (
29
- const Tensor& a,
30
- int64_t dim0,
31
- int64_t dim1,
32
- Tensor& out) {
33
- auto a_dim = a.dim ();
34
- ET_CHECK_MSG (
35
- a_dim >= 0 && a_dim == out.dim (), " invalid rank of tensor a: %zd" , a_dim);
36
- if (a_dim == 0 ) {
37
- ET_CHECK (dim0 == 0 || dim0 == -1 );
38
- ET_CHECK (dim1 == 0 || dim1 == -1 );
39
- return ;
40
- }
41
- ET_CHECK_MSG (
42
- dim0 >= 0 && dim0 < a_dim,
43
- " dim0: %" PRId64 " out of bounds [0,%zd)" ,
44
- dim0,
45
- a_dim);
46
- ET_CHECK_MSG (
47
- dim1 >= 0 && dim1 < a_dim,
48
- " dim1: %" PRId64 " out of bounds [0,%zd)" ,
49
- dim1,
50
- a_dim);
51
- ET_CHECK_MSG (
52
- a_dim <= kTensorDimensionLimit ,
53
- " input tensor rank %zd greater than %zu" ,
54
- a_dim,
55
- kTensorDimensionLimit );
56
- }
57
-
58
- } // namespace
59
-
60
20
/* *
61
21
* Swaps dimension 'dim0' of 'a' with 'dim1', and copying
62
22
* that mutation into `out` in a manner such that the data is densely packed
@@ -66,37 +26,40 @@ void check_preconditions(
66
26
*/
67
27
Tensor& transpose_copy_int_out (
68
28
RuntimeContext& ctx,
69
- const Tensor& a ,
29
+ const Tensor& in ,
70
30
int64_t dim0,
71
31
int64_t dim1,
72
32
Tensor& out) {
73
33
(void )ctx;
74
34
75
- ET_CHECK_SAME_DTYPE2 (a, out);
35
+ ET_KERNEL_CHECK (
36
+ ctx,
37
+ check_transpose_copy_args (in, dim0, dim1, out),
38
+ InvalidArgument,
39
+ out);
76
40
77
- // fix python negative indexing
78
41
if (dim0 < 0 ) {
79
- dim0 += out. dim ( );
42
+ dim0 += nonzero_dim (out );
80
43
}
81
44
if (dim1 < 0 ) {
82
- dim1 += out. dim ( );
45
+ dim1 += nonzero_dim (out );
83
46
}
84
- check_preconditions (a, dim0, dim1, out);
85
- #define TRANSPOSE_TENSORS (ctype, dtype ) \
86
- case ScalarType::dtype: \
87
- transpose_tensors<ctype>(a, dim0, dim1, out); \
88
- break ;
89
47
90
- switch (a.scalar_type ()) {
91
- ET_FORALL_SCALAR_TYPES (TRANSPOSE_TENSORS)
92
- default :
93
- ET_CHECK_MSG (
94
- false ,
95
- " Unhandled dtype %" PRId8,
96
- static_cast <int8_t >(a.scalar_type ()));
97
- }
48
+ Tensor::SizesType expected_out_size[kTensorDimensionLimit ];
49
+ size_t expected_out_dim = 0 ;
50
+ get_transpose_out_target_size (
51
+ in, dim0, dim1, expected_out_size, &expected_out_dim);
52
+
53
+ // Resize for dynamic shape
54
+ ET_KERNEL_CHECK (
55
+ ctx,
56
+ resize_tensor (out, {expected_out_size, expected_out_dim}) == Error::Ok,
57
+ InvalidArgument,
58
+ out);
98
59
99
- #undef TRANSPOSE_TENSORS
60
+ ET_SWITCH_ALL_TYPES (in.scalar_type (), ctx, __func__, CTYPE, [&] {
61
+ transpose_tensors<CTYPE>(in, dim0, dim1, out);
62
+ });
100
63
101
64
return out;
102
65
}
0 commit comments