Skip to content

Commit d989680

Browse files
manuelcandalesfacebook-github-bot
authored andcommitted
Reduce build size of op_pixel_shuffle & op_pixel_unshuffle (#6020)
Summary: Pull Request resolved: #6020 pixel_shuffle: 60 K -> 1.8 K pixel_unshuffle: 57 K -> 1.7 K ghstack-source-id: 246985128 exported-using-ghexport Reviewed By: malfet, swolchok Differential Revision: D63994872 fbshipit-source-id: 1840abf154c71ec18f66ed47e38dfe1caa317ddd
1 parent 1083643 commit d989680

File tree

2 files changed

+20
-22
lines changed

2 files changed

+20
-22
lines changed

kernels/portable/cpu/op_pixel_shuffle.cpp

Lines changed: 10 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,11 @@ namespace executor {
1414
namespace native {
1515
namespace {
1616

17-
template <typename CTYPE>
1817
void pixel_shuffle_impl(const Tensor& in, int64_t upscale_factor, Tensor& out) {
19-
const CTYPE* const in_data = in.const_data_ptr<CTYPE>();
20-
CTYPE* const out_data = out.mutable_data_ptr<CTYPE>();
18+
const char* const in_data =
19+
reinterpret_cast<const char*>(in.const_data_ptr());
20+
char* const out_data = reinterpret_cast<char*>(out.mutable_data_ptr());
21+
const auto elem_size = in.element_size();
2122

2223
const auto leading_dims = getLeadingDims(in, in.dim() - 3);
2324
const auto channels = in.size(in.dim() - 3);
@@ -45,7 +46,11 @@ void pixel_shuffle_impl(const Tensor& in, int64_t upscale_factor, Tensor& out) {
4546
for (size_t s2 = 0; s2 < S; s2++) {
4647
size_t input_offset = n * stride_n + c * stride_c +
4748
s1 * stride_s1 + s2 * stride_s2 + h * stride_h + w;
48-
out_data[i++] = in_data[input_offset];
49+
std::memcpy(
50+
out_data + i * elem_size,
51+
in_data + input_offset * elem_size,
52+
elem_size);
53+
i++;
4954
}
5055
}
5156
}
@@ -88,13 +93,7 @@ Tensor& pixel_shuffle_out(
8893
InvalidArgument,
8994
out);
9095

91-
constexpr auto name = "pixel_shuffle.out";
92-
93-
const auto in_type = out.scalar_type();
94-
// in and out must be the same dtype
95-
ET_SWITCH_ALL_TYPES(in_type, ctx, name, CTYPE, [&]() {
96-
pixel_shuffle_impl<CTYPE>(in, upscale_factor, out);
97-
});
96+
pixel_shuffle_impl(in, upscale_factor, out);
9897

9998
return out;
10099
}

kernels/portable/cpu/op_pixel_unshuffle.cpp

Lines changed: 10 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -14,13 +14,14 @@ namespace executor {
1414
namespace native {
1515
namespace {
1616

17-
template <typename CTYPE>
1817
void pixel_unshuffle_impl(
1918
const Tensor& in,
2019
int64_t downscale_factor,
2120
Tensor& out) {
22-
const CTYPE* const in_data = in.const_data_ptr<CTYPE>();
23-
CTYPE* const out_data = out.mutable_data_ptr<CTYPE>();
21+
const char* const in_data =
22+
reinterpret_cast<const char*>(in.const_data_ptr());
23+
char* const out_data = reinterpret_cast<char*>(out.mutable_data_ptr());
24+
const auto elem_size = in.element_size();
2425

2526
const auto leading_dims = getLeadingDims(in, in.dim() - 3);
2627
const auto channels = out.size(in.dim() - 3);
@@ -48,7 +49,11 @@ void pixel_unshuffle_impl(
4849
for (size_t s2 = 0; s2 < S; s2++) {
4950
size_t output_offset = n * stride_n + c * stride_c +
5051
s1 * stride_s1 + s2 * stride_s2 + h * stride_h + w;
51-
out_data[output_offset] = in_data[i++];
52+
std::memcpy(
53+
out_data + output_offset * elem_size,
54+
in_data + i * elem_size,
55+
elem_size);
56+
i++;
5257
}
5358
}
5459
}
@@ -88,13 +93,7 @@ Tensor& pixel_unshuffle_out(
8893
InvalidArgument,
8994
out);
9095

91-
constexpr auto name = "pixel_unshuffle.out";
92-
93-
const auto in_type = out.scalar_type();
94-
// in and out must be the same dtype
95-
ET_SWITCH_ALL_TYPES(in_type, ctx, name, CTYPE, [&]() {
96-
pixel_unshuffle_impl<CTYPE>(in, downscale_factor, out);
97-
});
96+
pixel_unshuffle_impl(in, downscale_factor, out);
9897

9998
return out;
10099
}

0 commit comments

Comments
 (0)