
Commit 88edab8

Add op: pixel_unshuffle
Differential Revision: D60978345
Pull Request resolved: #4631
1 parent 89a24e0 commit 88edab8

File tree

10 files changed, +346 -50 lines changed


kernels/aten/functions.yaml

Lines changed: 2 additions & 0 deletions
@@ -275,6 +275,8 @@
 
 - op: pixel_shuffle.out
 
+- op: pixel_unshuffle.out
+
 - op: pow.Scalar_out
 
 - op: pow.Tensor_Tensor_out

kernels/portable/cpu/op_pixel_shuffle.cpp

Lines changed: 48 additions & 44 deletions
@@ -12,6 +12,49 @@
 namespace torch {
 namespace executor {
 namespace native {
+namespace {
+
+template <typename CTYPE>
+void pixel_shuffle_impl(const Tensor& in, int64_t upscale_factor, Tensor& out) {
+  const CTYPE* const in_data = in.const_data_ptr<CTYPE>();
+  CTYPE* const out_data = out.mutable_data_ptr<CTYPE>();
+
+  const auto leading_dims = getLeadingDims(in, in.dim() - 3);
+  const auto channels = in.size(in.dim() - 3);
+  const auto height = in.size(in.dim() - 2);
+  const auto width = in.size(in.dim() - 1);
+
+  const auto sub_channels = channels / (upscale_factor * upscale_factor);
+  const auto S = upscale_factor;
+
+  // input strides
+  const auto stride_n = channels * height * width;
+  const auto stride_c = S * S * height * width;
+  const auto stride_s1 = S * height * width;
+  const auto stride_s2 = height * width;
+  const auto stride_h = width;
+
+  // input tensor shape of [n, c, s1, s2, h, w]
+  // output tensor shape of [n, c, h, s1, w, s2]
+  size_t i = 0;
+  for (size_t n = 0; n < leading_dims; n++) {
+    for (size_t c = 0; c < sub_channels; c++) {
+      for (size_t h = 0; h < height; h++) {
+        for (size_t s1 = 0; s1 < S; s1++) {
+          for (size_t w = 0; w < width; w++) {
+            for (size_t s2 = 0; s2 < S; s2++) {
+              size_t input_offset = n * stride_n + c * stride_c +
+                  s1 * stride_s1 + s2 * stride_s2 + h * stride_h + w;
+              out_data[i++] = in_data[input_offset];
+            }
+          }
+        }
+      }
+    }
+  }
+}
+
+} // namespace
 
 using SizesType = exec_aten::SizesType;
 using Tensor = exec_aten::Tensor;
@@ -29,11 +72,6 @@ Tensor& pixel_shuffle_out(
       InvalidArgument,
       out);
 
-  const Tensor::SizesType leading_dims = getLeadingDims(in, in.dim() - 3);
-  const Tensor::SizesType channels = in.size(in.dim() - 3);
-  const Tensor::SizesType height = in.size(in.dim() - 2);
-  const Tensor::SizesType width = in.size(in.dim() - 1);
-
   Tensor::SizesType expected_out_size[kTensorDimensionLimit];
   size_t expected_out_dim = 0;
   get_pixel_shuffle_out_target_size(
@@ -46,47 +84,13 @@ Tensor& pixel_shuffle_out(
       InvalidArgument,
       out);
 
+  constexpr auto name = "pixel_shuffle.out";
+
   const auto in_type = out.scalar_type();
   // in and out must be the same dtype
-  ET_SWITCH_ALL_TYPES(
-      in_type,
-      ctx,
-      "pixel_shuffle.out",
-      CTYPE,
-      [leading_dims, channels, height, width, upscale_factor, &in, &out] {
-        const CTYPE* const in_data = in.const_data_ptr<CTYPE>();
-        CTYPE* const out_data = out.mutable_data_ptr<CTYPE>();
-
-        const int64_t sub_channels =
-            channels / (upscale_factor * upscale_factor);
-        const int64_t S = upscale_factor;
-
-        // input strides
-        int64_t stride_n = channels * height * width;
-        int64_t stride_c = S * S * height * width;
-        int64_t stride_s1 = S * height * width;
-        int64_t stride_s2 = height * width;
-        int64_t stride_h = width;
-
-        // input tensor shape of [n, c, s1, s2, h, w]
-        // output tensor shape of [n, c, h, s1, w, s2]
-        size_t i = 0;
-        for (size_t n = 0; n < leading_dims; n++) {
-          for (size_t c = 0; c < sub_channels; c++) {
-            for (size_t h = 0; h < height; h++) {
-              for (size_t s1 = 0; s1 < S; s1++) {
-                for (size_t w = 0; w < width; w++) {
-                  for (size_t s2 = 0; s2 < S; s2++) {
-                    int64_t input_offset = n * stride_n + c * stride_c +
-                        s1 * stride_s1 + s2 * stride_s2 + h * stride_h + w;
-                    out_data[i++] = in_data[input_offset];
-                  }
-                }
-              }
-            }
-          }
-        }
-      });
+  ET_SWITCH_ALL_TYPES(in_type, ctx, name, CTYPE, [&]() {
+    pixel_shuffle_impl<CTYPE>(in, upscale_factor, out);
+  });
 
   return out;
 }
kernels/portable/cpu/op_pixel_unshuffle.cpp

Lines changed: 104 additions & 0 deletions
@@ -0,0 +1,104 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <executorch/kernels/portable/cpu/util/copy_ops_util.h>
+#include <executorch/runtime/kernel/kernel_includes.h>
+
+namespace torch {
+namespace executor {
+namespace native {
+namespace {
+
+template <typename CTYPE>
+void pixel_unshuffle_impl(
+    const Tensor& in,
+    int64_t downscale_factor,
+    Tensor& out) {
+  const CTYPE* const in_data = in.const_data_ptr<CTYPE>();
+  CTYPE* const out_data = out.mutable_data_ptr<CTYPE>();
+
+  const auto leading_dims = getLeadingDims(in, in.dim() - 3);
+  const auto channels = out.size(in.dim() - 3);
+  const auto height = out.size(in.dim() - 2);
+  const auto width = out.size(in.dim() - 1);
+
+  const auto S = downscale_factor;
+  const auto sub_channels = channels / (S * S);
+
+  // output strides
+  const auto stride_n = channels * height * width;
+  const auto stride_c = S * S * height * width;
+  const auto stride_s1 = S * height * width;
+  const auto stride_s2 = height * width;
+  const auto stride_h = width;
+
+  // input tensor shape of [n, c, h, s1, w, s2]
+  // output tensor shape of [n, c, s1, s2, h, w]
+  size_t i = 0;
+  for (size_t n = 0; n < leading_dims; n++) {
+    for (size_t c = 0; c < sub_channels; c++) {
+      for (size_t h = 0; h < height; h++) {
+        for (size_t s1 = 0; s1 < S; s1++) {
+          for (size_t w = 0; w < width; w++) {
+            for (size_t s2 = 0; s2 < S; s2++) {
+              size_t output_offset = n * stride_n + c * stride_c +
+                  s1 * stride_s1 + s2 * stride_s2 + h * stride_h + w;
+              out_data[output_offset] = in_data[i++];
+            }
+          }
+        }
+      }
+    }
+  }
+}
+
+} // namespace
+
+using SizesType = exec_aten::SizesType;
+using Tensor = exec_aten::Tensor;
+
+Tensor& pixel_unshuffle_out(
+    RuntimeContext& ctx,
+    const Tensor& in,
+    int64_t downscale_factor,
+    Tensor& out) {
+  (void)ctx;
+
+  ET_KERNEL_CHECK(
+      ctx,
+      check_pixel_unshuffle_args(in, downscale_factor, out),
+      InvalidArgument,
+      out);
+
+  // @lint-ignore CLANGTIDY facebook-hte-CArray
+  Tensor::SizesType expected_out_size[kTensorDimensionLimit];
+  size_t expected_out_dim = 0;
+  get_pixel_unshuffle_out_target_size(
+      in, downscale_factor, expected_out_size, &expected_out_dim);
+
+  // Make sure the output tensor is the right size.
+  ET_KERNEL_CHECK(
+      ctx,
+      resize_tensor(out, {expected_out_size, expected_out_dim}) == Error::Ok,
+      InvalidArgument,
+      out);
+
+  constexpr auto name = "pixel_unshuffle.out";
+
+  const auto in_type = out.scalar_type();
+  // in and out must be the same dtype
+  ET_SWITCH_ALL_TYPES(in_type, ctx, name, CTYPE, [&]() {
+    pixel_unshuffle_impl<CTYPE>(in, downscale_factor, out);
+  });
+
+  return out;
+}
+
+} // namespace native
+} // namespace executor
+} // namespace torch
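
The nested loops above walk the input once in order and scatter each element to its block position in the output. As an aside (not part of the commit), the same stride math can be illustrated on a plain buffer; pixel_unshuffle_ref and the concrete sizes below are made up for illustration, assuming the contiguous NCHW layout the kernel relies on.

  #include <cstddef>
  #include <cstdint>
  #include <cstdio>
  #include <vector>

  // Same rearrangement as pixel_unshuffle_impl, on a raw contiguous buffer:
  // input [C, H*S, W*S] -> output [C*S*S, H, W].
  std::vector<float> pixel_unshuffle_ref(
      const std::vector<float>& in, int64_t C, int64_t H, int64_t W, int64_t S) {
    std::vector<float> out(in.size());
    // Output strides: the output is viewed as [C, S, S, H, W], contiguous.
    const int64_t stride_c = S * S * H * W;
    const int64_t stride_s1 = S * H * W;
    const int64_t stride_s2 = H * W;
    const int64_t stride_h = W;

    std::size_t i = 0; // sequential read over the input viewed as [C, H, S, W, S]
    for (int64_t c = 0; c < C; c++) {
      for (int64_t h = 0; h < H; h++) {
        for (int64_t s1 = 0; s1 < S; s1++) {
          for (int64_t w = 0; w < W; w++) {
            for (int64_t s2 = 0; s2 < S; s2++) {
              const int64_t out_off = c * stride_c + s1 * stride_s1 +
                  s2 * stride_s2 + h * stride_h + w;
              out[out_off] = in[i++];
            }
          }
        }
      }
    }
    return out;
  }

  int main() {
    // One channel, 4x4 input, downscale factor 2 -> four 2x2 output channels.
    const int64_t C = 1, H = 2, W = 2, S = 2;
    std::vector<float> in(C * H * S * W * S);
    for (std::size_t k = 0; k < in.size(); k++) {
      in[k] = static_cast<float>(k); // 0..15, row-major in the 4x4 plane
    }
    const std::vector<float> out = pixel_unshuffle_ref(in, C, H, W, S);
    // Output channel 0 gathers the even-row/even-column pixels: 0 2 8 10.
    for (int64_t k = 0; k < H * W; k++) {
      std::printf("%g ", out[k]);
    }
    std::printf("\n");
  }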

kernels/portable/cpu/util/copy_ops_util.cpp

Lines changed: 36 additions & 0 deletions
@@ -325,6 +325,19 @@ bool check_pixel_shuffle_args(
   return true;
 }
 
+bool check_pixel_unshuffle_args(
+    const Tensor& in,
+    int64_t downscale_factor,
+    Tensor& out) {
+  ET_LOG_AND_RETURN_IF_FALSE(tensors_have_same_dtype(in, out));
+  ET_LOG_AND_RETURN_IF_FALSE(tensor_has_rank_greater_or_equal_to(in, 3));
+  ET_LOG_AND_RETURN_IF_FALSE(tensor_has_rank_greater_or_equal_to(out, 3));
+  ET_LOG_AND_RETURN_IF_FALSE(downscale_factor > 0);
+  ET_LOG_AND_RETURN_IF_FALSE(in.size(in.dim() - 1) % downscale_factor == 0);
+  ET_LOG_AND_RETURN_IF_FALSE(in.size(in.dim() - 2) % downscale_factor == 0);
+  return true;
+}
+
 void get_pixel_shuffle_out_target_size(
     const Tensor& in,
     int64_t upscale_factor,
@@ -347,6 +360,29 @@ void get_pixel_shuffle_out_target_size(
     out_sizes[i] = in.size(i) * casted_upscale_factor;
 }
 
+void get_pixel_unshuffle_out_target_size(
+    const Tensor& in,
+    int64_t downscale_factor,
+    exec_aten::SizesType* out_sizes,
+    size_t* out_ndim) {
+  *out_ndim = in.dim();
+  const exec_aten::SizesType casted_factor = downscale_factor;
+
+  size_t i = 0;
+  for (; i < in.dim() - 3; ++i) {
+    // Copy all leading dimensions in.
+    out_sizes[i] = in.size(i);
+  }
+  // The last 3 dimensions are (channel, height, width). Multiply channel by
+  // the downscale factor squared and divide the height and width by that
+  // factor.
+  out_sizes[i] = in.size(i) * (casted_factor * casted_factor);
+  i++;
+  out_sizes[i] = in.size(i) / casted_factor;
+  i++;
+  out_sizes[i] = in.size(i) / casted_factor;
+}
+
 bool check_select_copy_out_args(
     const Tensor& in,
     int64_t dim,
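
As a concrete check of the rule in the comment above (multiply the channel dimension by the downscale factor squared, divide height and width by the factor), here is a small standalone sketch that is not part of the commit; unshuffle_target_size and the example shape are illustrative only.

  #include <cassert>
  #include <cstdint>
  #include <vector>

  // Mirrors get_pixel_unshuffle_out_target_size on a plain size vector:
  // leading dims are copied, then [C, H, W] -> [C*r*r, H/r, W/r].
  std::vector<int64_t> unshuffle_target_size(
      const std::vector<int64_t>& in_sizes, int64_t r) {
    std::vector<int64_t> out_sizes(in_sizes); // leading dims carried over
    const auto d = in_sizes.size();
    out_sizes[d - 3] = in_sizes[d - 3] * r * r;
    out_sizes[d - 2] = in_sizes[d - 2] / r;
    out_sizes[d - 1] = in_sizes[d - 1] / r;
    return out_sizes;
  }

  int main() {
    // [1, 2, 4, 6] with downscale_factor 2 -> [1, 8, 2, 3];
    // the element count (48) is preserved by the rearrangement.
    const std::vector<int64_t> out = unshuffle_target_size({1, 2, 4, 6}, 2);
    assert((out == std::vector<int64_t>{1, 8, 2, 3}));
    return 0;
  }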

kernels/portable/cpu/util/copy_ops_util.h

Lines changed: 11 additions & 0 deletions
@@ -113,6 +113,17 @@ void get_pixel_shuffle_out_target_size(
     exec_aten::SizesType* out_sizes,
     size_t* out_ndim);
 
+bool check_pixel_unshuffle_args(
+    const Tensor& in,
+    int64_t upscale_factor,
+    Tensor& out);
+
+void get_pixel_unshuffle_out_target_size(
+    const Tensor& in,
+    int64_t upscale_factor,
+    exec_aten::SizesType* out_sizes,
+    size_t* out_ndim);
+
 bool check_select_copy_out_args(
     const Tensor& in,
     int64_t dim,

kernels/portable/functions.yaml

Lines changed: 5 additions & 0 deletions
@@ -622,6 +622,11 @@
     - arg_meta: null
       kernel_name: torch::executor::pixel_shuffle_out
 
+- op: pixel_unshuffle.out
+  kernels:
+    - arg_meta: null
+      kernel_name: torch::executor::pixel_unshuffle_out
+
 - op: pow.Scalar_out
   kernels:
     - arg_meta: null
kernels/test/op_pixel_shuffle_test.cpp

Lines changed: 2 additions & 5 deletions
@@ -74,7 +74,7 @@ TEST_F(OpPixelShuffleOutTest, AllRealDtypesSupported) {
 TEST_F(OpPixelShuffleOutTest, LargerInputRank) {
   TensorFactory<ScalarType::Int> tf;
 
-  // Pixel shuffle allows a 4D (or higher) input tensor, make sure the extra
+  // Pixel shuffle allows a 3D (or higher) input tensor, make sure the extra
   // dimensions don't cause issues.
   Tensor a = tf.ones(/*sizes=*/{1, 4, 1, 4, 2, 2});
 
@@ -102,11 +102,8 @@ TEST_F(OpPixelShuffleOutTest, InvalidInputChannelsDies) {
 TEST_F(OpPixelShuffleOutTest, WrongInputRankDies) {
   TensorFactory<ScalarType::Int> tf;
 
-  // Pixel shuffle requires a 4D input tensor.
+  // Pixel shuffle requires a 3D or higher input tensor.
   Tensor a = tf.ones(/*sizes=*/{1, 2});
-
-  // NOTE: The wrong output rank dies for the portable kernel, but not the aten
-  // kernel.
   Tensor out = tf.zeros(/*sizes=*/{1, 2});
 
   // Using the wrong input shape should exit with an error code.
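
The test changes shown here only adjust the existing pixel shuffle tests for the relaxed 3D-or-higher rank requirement. Purely as an illustration of how the new op could be exercised (this sketch is not taken from the commit), a pixel_unshuffle test might mirror those tests; the fixture name OpPixelUnshuffleOutTest, the wrapper op_pixel_unshuffle_out, and the EXPECT_TENSOR_EQ comparison are assumed here by analogy with the existing pixel_shuffle test helpers.

  TEST_F(OpPixelUnshuffleOutTest, SmokeTest) {
    TensorFactory<ScalarType::Int> tf;

    // A single 4x4 channel with downscale factor 2 folds each 2x2 block
    // into one of four 2x2 output channels.
    Tensor a = tf.ones(/*sizes=*/{1, 1, 4, 4});
    Tensor out = tf.zeros(/*sizes=*/{1, 4, 2, 2});

    op_pixel_unshuffle_out(a, 2, out);

    // An all-ones input stays all ones after the rearrangement.
    EXPECT_TENSOR_EQ(out, tf.ones(/*sizes=*/{1, 4, 2, 2}));
  }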
