Skip to content

Commit c1032d7

Browse files
manuelcandalesfacebook-github-bot
authored andcommitted
Fix & cleanup op t_copy (#701)
Summary: Pull Request resolved: #701 ghstack-source-id: 203341581 exported-using-ghexport Reviewed By: SS-JIA Differential Revision: D49735860 fbshipit-source-id: b6acfd1fbc89c2fe1475368e48f210b71733ef1f
1 parent eaca8f7 commit c1032d7

File tree

2 files changed

+38
-42
lines changed

2 files changed

+38
-42
lines changed

kernels/portable/cpu/op_t_copy.cpp

Lines changed: 32 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,9 @@
66
* LICENSE file in the root directory of this source tree.
77
*/
88

9-
#include <cstring>
10-
119
#include <executorch/kernels/portable/cpu/util/transpose_util.h>
1210
#include <executorch/runtime/kernel/kernel_includes.h>
13-
#include <executorch/runtime/platform/assert.h>
11+
#include <cstring>
1412

1513
namespace torch {
1614
namespace executor {
@@ -20,55 +18,47 @@ using SizesType = exec_aten::SizesType;
2018
using StridesType = exec_aten::StridesType;
2119
using Tensor = exec_aten::Tensor;
2220

23-
namespace {
24-
25-
/**
26-
* Verifies preconditions of t_copy_int_out
27-
*/
28-
void check_preconditions(const Tensor& a, Tensor& out) {
29-
auto a_dim = a.dim();
30-
ET_CHECK_MSG(
31-
a_dim >= 0 && a_dim <= 2,
32-
"Rank of tensor a has to be <=2 but received tensor of rank : %zd.:",
33-
a_dim);
34-
if (a_dim < 2) {
35-
ET_CHECK_SAME_SHAPE_AND_DTYPE2(a, out);
36-
} else {
37-
ET_CHECK_SAME_DTYPE2(a, out);
38-
ET_CHECK_MSG(
39-
(a.sizes()[0] == out.sizes()[1]) && (a.sizes()[1] == out.sizes()[0]),
40-
"Input tensor and output tensor shapes do not support transposing");
41-
ET_CHECK_MSG(out.dim() == 2, "Output tensor must have same dim (2)");
42-
}
43-
}
44-
45-
} // namespace
46-
4721
/**
4822
* Expects input to be <= 2-D tensor and transposes dimensions 0 and 1.
4923
* 0-D and 1-D tensors are returned as is. When input is a 2-D tensor this
5024
* is equivalent to transpose(input, 0, 1).
5125
* t_copy.out(Tensor self, Tensor(a!) out)
5226
*/
53-
Tensor& t_copy_out(RuntimeContext& ctx, const Tensor& a, Tensor& out) {
27+
Tensor& t_copy_out(RuntimeContext& ctx, const Tensor& in, Tensor& out) {
5428
(void)ctx;
55-
check_preconditions(a, out);
56-
int dim_1 = a.sizes().size() == 2 ? 1 : 0;
57-
#define TRANSPOSE_TENSORS(ctype, dtype) \
58-
case ScalarType::dtype: \
59-
transpose_tensors<ctype>(a, 0, dim_1, out); \
60-
break;
6129

62-
switch (a.scalar_type()) {
63-
ET_FORALL_SCALAR_TYPES(TRANSPOSE_TENSORS)
64-
default:
65-
ET_CHECK_MSG(
66-
false,
67-
"Unhandled dtype %" PRId8,
68-
static_cast<int8_t>(a.scalar_type()));
30+
ET_KERNEL_CHECK(ctx, check_t_copy_args(in, out), InvalidArgument, out);
31+
32+
ScalarType in_type = in.scalar_type();
33+
34+
if (in.dim() < 2) {
35+
// Resize for dynamic shape
36+
ET_KERNEL_CHECK(
37+
ctx, resize_tensor(out, in.sizes()) == Error::Ok, InvalidArgument, out);
38+
39+
ET_SWITCH_ALL_TYPES(in_type, ctx, __func__, CTYPE, [&]() {
40+
const CTYPE* in_data = in.const_data_ptr<CTYPE>();
41+
CTYPE* out_data = out.mutable_data_ptr<CTYPE>();
42+
memcpy(out_data, in_data, in.nbytes());
43+
});
44+
45+
return out;
6946
}
7047

71-
#undef TRANSPOSE_TENSORS
48+
Tensor::SizesType expected_out_size[kTensorDimensionLimit];
49+
size_t expected_out_dim = 0;
50+
get_transpose_out_target_size(in, 1, 0, expected_out_size, &expected_out_dim);
51+
52+
// Resize for dynamic shape
53+
ET_KERNEL_CHECK(
54+
ctx,
55+
resize_tensor(out, {expected_out_size, expected_out_dim}) == Error::Ok,
56+
InvalidArgument,
57+
out);
58+
59+
ET_SWITCH_ALL_TYPES(in_type, ctx, __func__, CTYPE, [&] {
60+
transpose_tensors<CTYPE>(in, 1, 0, out);
61+
});
7262

7363
return out;
7464
}

kernels/portable/cpu/util/transpose_util.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -135,6 +135,12 @@ void transpose_tensors(
135135
}
136136
}
137137

138+
inline bool check_t_copy_args(const Tensor& in, Tensor& out) {
139+
ET_LOG_AND_RETURN_IF_FALSE(tensors_have_same_dtype(in, out));
140+
ET_LOG_AND_RETURN_IF_FALSE(tensor_has_rank_smaller_or_equal_to(in, 2));
141+
return true;
142+
}
143+
138144
inline bool check_transpose_copy_args(
139145
const Tensor& in,
140146
int64_t dim0,

0 commit comments

Comments
 (0)