Skip to content

Commit eaca8f7

Browse files
manuelcandalesfacebook-github-bot
authored andcommitted
Fix & cleanup op transpose_copy (#700)
Summary: Pull Request resolved: #700 ghstack-source-id: 203341579 exported-using-ghexport Reviewed By: SS-JIA Differential Revision: D49735856 fbshipit-source-id: 47a8c43728d2c3dd511495f9c306807fc4851180
1 parent d3df91a commit eaca8f7

File tree

2 files changed

+49
-59
lines changed

2 files changed

+49
-59
lines changed

kernels/portable/cpu/op_transpose_copy.cpp

Lines changed: 22 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,8 @@
66
* LICENSE file in the root directory of this source tree.
77
*/
88

9-
#include <cstring>
10-
119
#include <executorch/kernels/portable/cpu/util/transpose_util.h>
1210
#include <executorch/runtime/kernel/kernel_includes.h>
13-
#include <executorch/runtime/platform/assert.h>
1411

1512
namespace torch {
1613
namespace executor {
@@ -20,43 +17,6 @@ using SizesType = exec_aten::SizesType;
2017
using StridesType = exec_aten::StridesType;
2118
using Tensor = exec_aten::Tensor;
2219

23-
namespace {
24-
25-
/**
26-
* Verifies preconditions of transpose_copy_int_out
27-
*/
28-
void check_preconditions(
29-
const Tensor& a,
30-
int64_t dim0,
31-
int64_t dim1,
32-
Tensor& out) {
33-
auto a_dim = a.dim();
34-
ET_CHECK_MSG(
35-
a_dim >= 0 && a_dim == out.dim(), "invalid rank of tensor a: %zd", a_dim);
36-
if (a_dim == 0) {
37-
ET_CHECK(dim0 == 0 || dim0 == -1);
38-
ET_CHECK(dim1 == 0 || dim1 == -1);
39-
return;
40-
}
41-
ET_CHECK_MSG(
42-
dim0 >= 0 && dim0 < a_dim,
43-
"dim0: %" PRId64 " out of bounds [0,%zd)",
44-
dim0,
45-
a_dim);
46-
ET_CHECK_MSG(
47-
dim1 >= 0 && dim1 < a_dim,
48-
"dim1: %" PRId64 " out of bounds [0,%zd)",
49-
dim1,
50-
a_dim);
51-
ET_CHECK_MSG(
52-
a_dim <= kTensorDimensionLimit,
53-
"input tensor rank %zd greater than %zu",
54-
a_dim,
55-
kTensorDimensionLimit);
56-
}
57-
58-
} // namespace
59-
6020
/**
6121
* Swaps dimension 'dim0' of 'a' with 'dim1', and copying
6222
* that mutation into `out` in a manner such that the data is densely packed
@@ -66,37 +26,40 @@ void check_preconditions(
6626
*/
6727
Tensor& transpose_copy_int_out(
6828
RuntimeContext& ctx,
69-
const Tensor& a,
29+
const Tensor& in,
7030
int64_t dim0,
7131
int64_t dim1,
7232
Tensor& out) {
7333
(void)ctx;
7434

75-
ET_CHECK_SAME_DTYPE2(a, out);
35+
ET_KERNEL_CHECK(
36+
ctx,
37+
check_transpose_copy_args(in, dim0, dim1, out),
38+
InvalidArgument,
39+
out);
7640

77-
// fix python negative indexing
7841
if (dim0 < 0) {
79-
dim0 += out.dim();
42+
dim0 += nonzero_dim(out);
8043
}
8144
if (dim1 < 0) {
82-
dim1 += out.dim();
45+
dim1 += nonzero_dim(out);
8346
}
84-
check_preconditions(a, dim0, dim1, out);
85-
#define TRANSPOSE_TENSORS(ctype, dtype) \
86-
case ScalarType::dtype: \
87-
transpose_tensors<ctype>(a, dim0, dim1, out); \
88-
break;
8947

90-
switch (a.scalar_type()) {
91-
ET_FORALL_SCALAR_TYPES(TRANSPOSE_TENSORS)
92-
default:
93-
ET_CHECK_MSG(
94-
false,
95-
"Unhandled dtype %" PRId8,
96-
static_cast<int8_t>(a.scalar_type()));
97-
}
48+
Tensor::SizesType expected_out_size[kTensorDimensionLimit];
49+
size_t expected_out_dim = 0;
50+
get_transpose_out_target_size(
51+
in, dim0, dim1, expected_out_size, &expected_out_dim);
52+
53+
// Resize for dynamic shape
54+
ET_KERNEL_CHECK(
55+
ctx,
56+
resize_tensor(out, {expected_out_size, expected_out_dim}) == Error::Ok,
57+
InvalidArgument,
58+
out);
9859

99-
#undef TRANSPOSE_TENSORS
60+
ET_SWITCH_ALL_TYPES(in.scalar_type(), ctx, __func__, CTYPE, [&] {
61+
transpose_tensors<CTYPE>(in, dim0, dim1, out);
62+
});
10063

10164
return out;
10265
}

kernels/portable/cpu/util/transpose_util.h

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -135,5 +135,32 @@ void transpose_tensors(
135135
}
136136
}
137137

138+
inline bool check_transpose_copy_args(
139+
const Tensor& in,
140+
int64_t dim0,
141+
int64_t dim1,
142+
Tensor& out) {
143+
ET_LOG_AND_RETURN_IF_FALSE(tensors_have_same_dtype(in, out));
144+
ET_LOG_AND_RETURN_IF_FALSE(tensor_has_dim(in, dim0));
145+
ET_LOG_AND_RETURN_IF_FALSE(tensor_has_dim(in, dim1));
146+
return true;
147+
}
148+
149+
inline void get_transpose_out_target_size(
150+
const Tensor& in,
151+
SizesType dim0,
152+
SizesType dim1,
153+
SizesType* out_sizes,
154+
size_t* out_ndim) {
155+
*out_ndim = in.dim();
156+
157+
size_t i = 0;
158+
for (; i < in.dim() - 1; ++i) {
159+
out_sizes[i] = in.size(i);
160+
}
161+
out_sizes[dim0] = in.size(dim1);
162+
out_sizes[dim1] = in.size(dim0);
163+
}
164+
138165
} // namespace executor
139166
} // namespace torch

0 commit comments

Comments
 (0)