Skip to content

Add op: pixel_unshuffle #4631

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Aug 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions kernels/aten/functions.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -275,6 +275,8 @@

- op: pixel_shuffle.out

- op: pixel_unshuffle.out

- op: pow.Scalar_out

- op: pow.Tensor_Tensor_out
Expand Down
92 changes: 48 additions & 44 deletions kernels/portable/cpu/op_pixel_shuffle.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,49 @@
namespace torch {
namespace executor {
namespace native {
namespace {

template <typename CTYPE>
void pixel_shuffle_impl(const Tensor& in, int64_t upscale_factor, Tensor& out) {
  const CTYPE* const src = in.const_data_ptr<CTYPE>();
  CTYPE* dst = out.mutable_data_ptr<CTYPE>();

  const auto leading_dims = getLeadingDims(in, in.dim() - 3);
  const auto channels = in.size(in.dim() - 3);
  const auto height = in.size(in.dim() - 2);
  const auto width = in.size(in.dim() - 1);

  const auto S = upscale_factor;
  const auto sub_channels = channels / (S * S);

  // Strides of the input, viewed as [n, sub_c, s1, s2, h, w].
  const auto stride_n = channels * height * width;
  const auto stride_c = S * S * height * width;
  const auto stride_s1 = S * height * width;
  const auto stride_s2 = height * width;
  const auto stride_h = width;

  // pixel_shuffle permutes that view to [n, sub_c, h, s1, w, s2]. Walk the
  // output in contiguous order and gather from the input, building each
  // input offset incrementally (one partial sum per loop level).
  for (size_t n = 0; n < leading_dims; n++) {
    const size_t off_n = n * stride_n;
    for (size_t c = 0; c < sub_channels; c++) {
      const size_t off_c = off_n + c * stride_c;
      for (size_t h = 0; h < height; h++) {
        const size_t off_h = off_c + h * stride_h;
        for (size_t s1 = 0; s1 < S; s1++) {
          const size_t off_s1 = off_h + s1 * stride_s1;
          for (size_t w = 0; w < width; w++) {
            const size_t off_w = off_s1 + w;
            for (size_t s2 = 0; s2 < S; s2++) {
              *dst++ = src[off_w + s2 * stride_s2];
            }
          }
        }
      }
    }
  }
}

} // namespace

using SizesType = exec_aten::SizesType;
using Tensor = exec_aten::Tensor;
Expand All @@ -29,11 +72,6 @@ Tensor& pixel_shuffle_out(
InvalidArgument,
out);

const Tensor::SizesType leading_dims = getLeadingDims(in, in.dim() - 3);
const Tensor::SizesType channels = in.size(in.dim() - 3);
const Tensor::SizesType height = in.size(in.dim() - 2);
const Tensor::SizesType width = in.size(in.dim() - 1);

Tensor::SizesType expected_out_size[kTensorDimensionLimit];
size_t expected_out_dim = 0;
get_pixel_shuffle_out_target_size(
Expand All @@ -46,47 +84,13 @@ Tensor& pixel_shuffle_out(
InvalidArgument,
out);

constexpr auto name = "pixel_shuffle.out";

const auto in_type = out.scalar_type();
// in and out must be the same dtype
ET_SWITCH_ALL_TYPES(
in_type,
ctx,
"pixel_shuffle.out",
CTYPE,
[leading_dims, channels, height, width, upscale_factor, &in, &out] {
const CTYPE* const in_data = in.const_data_ptr<CTYPE>();
CTYPE* const out_data = out.mutable_data_ptr<CTYPE>();

const int64_t sub_channels =
channels / (upscale_factor * upscale_factor);
const int64_t S = upscale_factor;

// input strides
int64_t stride_n = channels * height * width;
int64_t stride_c = S * S * height * width;
int64_t stride_s1 = S * height * width;
int64_t stride_s2 = height * width;
int64_t stride_h = width;

// input tensor shape of [n, c, s1, s2, h, w]
// output tensor shape of [n, c, h, s1, w, s2]
size_t i = 0;
for (size_t n = 0; n < leading_dims; n++) {
for (size_t c = 0; c < sub_channels; c++) {
for (size_t h = 0; h < height; h++) {
for (size_t s1 = 0; s1 < S; s1++) {
for (size_t w = 0; w < width; w++) {
for (size_t s2 = 0; s2 < S; s2++) {
int64_t input_offset = n * stride_n + c * stride_c +
s1 * stride_s1 + s2 * stride_s2 + h * stride_h + w;
out_data[i++] = in_data[input_offset];
}
}
}
}
}
}
});
ET_SWITCH_ALL_TYPES(in_type, ctx, name, CTYPE, [&]() {
pixel_shuffle_impl<CTYPE>(in, upscale_factor, out);
});

return out;
}
Expand Down
104 changes: 104 additions & 0 deletions kernels/portable/cpu/op_pixel_unshuffle.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#include <executorch/kernels/portable/cpu/util/copy_ops_util.h>
#include <executorch/runtime/kernel/kernel_includes.h>

namespace torch {
namespace executor {
namespace native {
namespace {

template <typename CTYPE>
void pixel_unshuffle_impl(
    const Tensor& in,
    int64_t downscale_factor,
    Tensor& out) {
  const CTYPE* src = in.const_data_ptr<CTYPE>();
  CTYPE* const dst = out.mutable_data_ptr<CTYPE>();

  // Sizes below come from the *output* tensor, which the caller has already
  // resized to [..., c * S * S, h / S, w / S].
  const auto leading_dims = getLeadingDims(in, in.dim() - 3);
  const auto channels = out.size(in.dim() - 3);
  const auto height = out.size(in.dim() - 2);
  const auto width = out.size(in.dim() - 1);

  const auto S = downscale_factor;
  const auto sub_channels = channels / (S * S);

  // Strides of the output, viewed as [n, sub_c, s1, s2, h, w].
  const auto stride_n = channels * height * width;
  const auto stride_c = S * S * height * width;
  const auto stride_s1 = S * height * width;
  const auto stride_s2 = height * width;
  const auto stride_h = width;

  // Inverse of pixel_shuffle: the input is laid out as the permuted view
  // [n, sub_c, h, s1, w, s2], so read it contiguously and scatter into the
  // output, building each output offset incrementally per loop level.
  for (size_t n = 0; n < leading_dims; n++) {
    const size_t off_n = n * stride_n;
    for (size_t c = 0; c < sub_channels; c++) {
      const size_t off_c = off_n + c * stride_c;
      for (size_t h = 0; h < height; h++) {
        const size_t off_h = off_c + h * stride_h;
        for (size_t s1 = 0; s1 < S; s1++) {
          const size_t off_s1 = off_h + s1 * stride_s1;
          for (size_t w = 0; w < width; w++) {
            const size_t off_w = off_s1 + w;
            for (size_t s2 = 0; s2 < S; s2++) {
              dst[off_w + s2 * stride_s2] = *src++;
            }
          }
        }
      }
    }
  }
}

} // namespace

using SizesType = exec_aten::SizesType;
using Tensor = exec_aten::Tensor;

// pixel_unshuffle.out: rearranges spatial blocks of `in` into extra channels,
// writing the result into `out` (which is resized here) and returning it.
Tensor& pixel_unshuffle_out(
    RuntimeContext& ctx,
    const Tensor& in,
    int64_t downscale_factor,
    Tensor& out) {
  (void)ctx;

  // Reject mismatched dtypes, rank < 3, non-positive factors, and spatial
  // dims that the factor does not divide evenly.
  ET_KERNEL_CHECK(
      ctx,
      check_pixel_unshuffle_args(in, downscale_factor, out),
      InvalidArgument,
      out);

  // Compute the expected output shape and resize `out` to match it.
  // @lint-ignore CLANGTIDY facebook-hte-CArray
  Tensor::SizesType target_sizes[kTensorDimensionLimit];
  size_t target_ndim = 0;
  get_pixel_unshuffle_out_target_size(
      in, downscale_factor, target_sizes, &target_ndim);

  ET_KERNEL_CHECK(
      ctx,
      resize_tensor(out, {target_sizes, target_ndim}) == Error::Ok,
      InvalidArgument,
      out);

  constexpr auto name = "pixel_unshuffle.out";

  // The args check above guarantees `in` and `out` share this dtype.
  const auto dtype = out.scalar_type();
  ET_SWITCH_ALL_TYPES(dtype, ctx, name, CTYPE, [&]() {
    pixel_unshuffle_impl<CTYPE>(in, downscale_factor, out);
  });

  return out;
}

} // namespace native
} // namespace executor
} // namespace torch
36 changes: 36 additions & 0 deletions kernels/portable/cpu/util/copy_ops_util.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -325,6 +325,19 @@ bool check_pixel_shuffle_args(
return true;
}

// Validates the arguments to pixel_unshuffle_out: `in` and `out` must share a
// dtype, both must have rank >= 3 (trailing dims are C, H, W), the downscale
// factor must be positive, and the input's height and width must each be
// divisible by it. Logs and returns false on the first failed check.
bool check_pixel_unshuffle_args(
    const Tensor& in,
    int64_t downscale_factor,
    Tensor& out) {
  ET_LOG_AND_RETURN_IF_FALSE(tensors_have_same_dtype(in, out));
  ET_LOG_AND_RETURN_IF_FALSE(tensor_has_rank_greater_or_equal_to(in, 3));
  ET_LOG_AND_RETURN_IF_FALSE(tensor_has_rank_greater_or_equal_to(out, 3));
  ET_LOG_AND_RETURN_IF_FALSE(downscale_factor > 0);
  // dim-1 is width and dim-2 is height; both must divide evenly by the
  // factor (checked only after the factor is known to be > 0).
  ET_LOG_AND_RETURN_IF_FALSE(in.size(in.dim() - 1) % downscale_factor == 0);
  ET_LOG_AND_RETURN_IF_FALSE(in.size(in.dim() - 2) % downscale_factor == 0);
  return true;
}

void get_pixel_shuffle_out_target_size(
const Tensor& in,
int64_t upscale_factor,
Expand All @@ -347,6 +360,29 @@ void get_pixel_shuffle_out_target_size(
out_sizes[i] = in.size(i) * casted_upscale_factor;
}

void get_pixel_unshuffle_out_target_size(
const Tensor& in,
int64_t downscale_factor,
exec_aten::SizesType* out_sizes,
size_t* out_ndim) {
*out_ndim = in.dim();
const exec_aten::SizesType casted_factor = downscale_factor;

size_t i = 0;
for (; i < in.dim() - 3; ++i) {
// Copy all leading dimensions in.
out_sizes[i] = in.size(i);
}
// The last 3 dimensions are (channel, height, width). Multiply channel by
// the downscale factor squared and divide the height and width by that
// factor.
out_sizes[i] = in.size(i) * (casted_factor * casted_factor);
i++;
out_sizes[i] = in.size(i) / casted_factor;
i++;
out_sizes[i] = in.size(i) / casted_factor;
}

bool check_select_copy_out_args(
const Tensor& in,
int64_t dim,
Expand Down
11 changes: 11 additions & 0 deletions kernels/portable/cpu/util/copy_ops_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,17 @@ void get_pixel_shuffle_out_target_size(
exec_aten::SizesType* out_sizes,
size_t* out_ndim);

// Validates pixel_unshuffle_out arguments (dtype match, rank >= 3, positive
// factor, spatial dims divisible by the factor).
// NOTE: parameter renamed from `upscale_factor` (copy-paste from the
// pixel_shuffle declarations) to `downscale_factor` to match the definition
// and the op's semantics; declaration-only rename, no ABI/caller impact.
bool check_pixel_unshuffle_args(
    const Tensor& in,
    int64_t downscale_factor,
    Tensor& out);

// Computes the pixel_unshuffle output shape: channels multiplied by
// downscale_factor^2, height and width divided by downscale_factor.
void get_pixel_unshuffle_out_target_size(
    const Tensor& in,
    int64_t downscale_factor,
    exec_aten::SizesType* out_sizes,
    size_t* out_ndim);

bool check_select_copy_out_args(
const Tensor& in,
int64_t dim,
Expand Down
5 changes: 5 additions & 0 deletions kernels/portable/functions.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -622,6 +622,11 @@
- arg_meta: null
kernel_name: torch::executor::pixel_shuffle_out

- op: pixel_unshuffle.out
kernels:
- arg_meta: null
kernel_name: torch::executor::pixel_unshuffle_out

- op: pow.Scalar_out
kernels:
- arg_meta: null
Expand Down
7 changes: 2 additions & 5 deletions kernels/test/op_pixel_shuffle_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ TEST_F(OpPixelShuffleOutTest, AllRealDtypesSupported) {
TEST_F(OpPixelShuffleOutTest, LargerInputRank) {
TensorFactory<ScalarType::Int> tf;

// Pixel shuffle allows a 4D (or higher) input tensor, make sure the extra
// Pixel shuffle allows a 3D (or higher) input tensor, make sure the extra
// dimensions don't cause issues.
Tensor a = tf.ones(/*sizes=*/{1, 4, 1, 4, 2, 2});

Expand Down Expand Up @@ -102,11 +102,8 @@ TEST_F(OpPixelShuffleOutTest, InvalidInputChannelsDies) {
TEST_F(OpPixelShuffleOutTest, WrongInputRankDies) {
TensorFactory<ScalarType::Int> tf;

// Pixel shuffle requires a 4D input tensor.
// Pixel shuffle requires a 3D or higher input tensor.
Tensor a = tf.ones(/*sizes=*/{1, 2});

// NOTE: The wrong output rank dies for the portable kernel, but not the aten
// kernel.
Tensor out = tf.zeros(/*sizes=*/{1, 2});

// Using the wrong input shape should exit with an error code.
Expand Down
Loading
Loading