
Commit 1315a92

dulinriley authored and facebook-github-bot committed
Add aten::pixel_shuffle.out portable variant (#351)
Summary: Pull Request resolved: #351

The executorch portable runtime was missing an implementation of `aten::pixel_shuffle.out`, which is used by the PiCA decoder. This is basically just a reshape operator; I adapted the implementation from aten's `PixelShuffleKernel.cpp`, reworking it to avoid some helper functions that don't exist in the portable runtime. I don't know much about how to make a more optimized implementation, or whether there's a way to get parallelism automatically like the normal aten kernels do.

Reviewed By: manuelcandales

Differential Revision: D49173297

fbshipit-source-id: 625de6ec519b40929d4db9aae9c61fc1c49da691
1 parent 1200d59 commit 1315a92
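
As the summary notes, pixel_shuffle is pure data movement: it copies a [N, C*r*r, H, W] tensor into a [N, C, H*r, W*r] tensor, moving each group of r*r input channels into an r-by-r spatial block, i.e. out[n][c][h*r + s1][w*r + s2] = in[n][c*r*r + s1*r + s2][h][w]. The standalone C++ sketch below is not part of this commit; the name pixel_shuffle_ref and the flat std::vector representation are illustrative only. It reproduces that index mapping on the same 1x4x2x2 input the new correctness test uses.

#include <cstdint>
#include <cstdio>
#include <vector>

// Illustrative reference only (not from the commit): pixel_shuffle on a
// contiguous NCHW buffer, [N, C*r*r, H, W] -> [N, C, H*r, W*r], where
// out[n][c][h*r + s1][w*r + s2] = in[n][c*r*r + s1*r + s2][h][w].
std::vector<int> pixel_shuffle_ref(
    const std::vector<int>& in,
    int64_t N,
    int64_t C_out,
    int64_t H,
    int64_t W,
    int64_t r) {
  std::vector<int> out(in.size());
  int64_t i = 0; // the output is written contiguously in this loop order
  for (int64_t n = 0; n < N; ++n) {
    for (int64_t c = 0; c < C_out; ++c) {
      for (int64_t h = 0; h < H; ++h) {
        for (int64_t s1 = 0; s1 < r; ++s1) {
          for (int64_t w = 0; w < W; ++w) {
            for (int64_t s2 = 0; s2 < r; ++s2) {
              // Flat offset into the [N, C*r*r, H, W] input.
              int64_t in_channel = c * r * r + s1 * r + s2;
              int64_t in_offset =
                  ((n * C_out * r * r + in_channel) * H + h) * W + w;
              out[i++] = in[in_offset];
            }
          }
        }
      }
    }
  }
  return out;
}

int main() {
  // Same case as the correctness test: 1x4x2x2 input holding 0..15, r = 2.
  std::vector<int> in(16);
  for (int v = 0; v < 16; ++v) {
    in[v] = v;
  }
  std::vector<int> out =
      pixel_shuffle_ref(in, /*N=*/1, /*C_out=*/1, /*H=*/2, /*W=*/2, /*r=*/2);
  for (int v : out) {
    std::printf("%d ", v);
  }
  std::printf("\n");
  return 0;
}

Compiled and run, this prints 0 4 1 5 8 12 9 13 2 6 3 7 10 14 11 15, which is the expected output in the correctness test below.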

8 files changed: +296 additions, 0 deletions
kernels/portable/cpu/op_pixel_shuffle.cpp

Lines changed: 96 additions & 0 deletions

@@ -0,0 +1,96 @@
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <executorch/kernels/portable/cpu/util/copy_ops_util.h>
#include <executorch/runtime/kernel/kernel_includes.h>

namespace torch {
namespace executor {
namespace native {

using SizesType = exec_aten::SizesType;
using Tensor = exec_aten::Tensor;

Tensor& pixel_shuffle_out(
    RuntimeContext& ctx,
    const Tensor& in,
    int64_t upscale_factor,
    Tensor& out) {
  (void)ctx;

  ET_KERNEL_CHECK(
      ctx,
      check_pixel_shuffle_args(in, upscale_factor, out),
      InvalidArgument,
      out);

  const Tensor::SizesType leading_dims = getLeadingDims(in, in.dim() - 3);
  const Tensor::SizesType channels = in.size(in.dim() - 3);
  const Tensor::SizesType height = in.size(in.dim() - 2);
  const Tensor::SizesType width = in.size(in.dim() - 1);

  Tensor::SizesType expected_out_size[kTensorDimensionLimit];
  size_t expected_out_dim = 0;
  get_pixel_shuffle_out_target_size(
      in, upscale_factor, expected_out_size, &expected_out_dim);

  // Make sure the output tensor is the right size.
  ET_KERNEL_CHECK(
      ctx,
      resize_tensor(out, {expected_out_size, expected_out_dim}) == Error::Ok,
      InvalidArgument,
      out);

  const auto in_type = out.scalar_type();
  // in and out must be the same dtype
  ET_SWITCH_ALL_TYPES(
      in_type,
      ctx,
      __func__,
      CTYPE,
      [leading_dims, channels, height, width, upscale_factor, &in, &out] {
        const CTYPE* const in_data = in.const_data_ptr<CTYPE>();
        CTYPE* const out_data = out.mutable_data_ptr<CTYPE>();

        const int64_t sub_channels =
            channels / (upscale_factor * upscale_factor);
        const int64_t S = upscale_factor;

        // input strides
        int64_t stride_n = channels * height * width;
        int64_t stride_c = S * S * height * width;
        int64_t stride_s1 = S * height * width;
        int64_t stride_s2 = height * width;
        int64_t stride_h = width;

        // input tensor shape of [n, c, s1, s2, h, w]
        // output tensor shape of [n, c, h, s1, w, s2]
        size_t i = 0;
        for (size_t n = 0; n < leading_dims; n++) {
          for (size_t c = 0; c < sub_channels; c++) {
            for (size_t h = 0; h < height; h++) {
              for (size_t s1 = 0; s1 < S; s1++) {
                for (size_t w = 0; w < width; w++) {
                  for (size_t s2 = 0; s2 < S; s2++) {
                    int64_t input_offset = n * stride_n + c * stride_c +
                        s1 * stride_s1 + s2 * stride_s2 + h * stride_h + w;
                    out_data[i++] = in_data[input_offset];
                  }
                }
              }
            }
          }
        }
      });

  return out;
}

} // namespace native
} // namespace executor
} // namespace torch
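
A note on the loop nest above: the output index i simply increments, because visiting the output in (n, c, h, s1, w, s2) order walks the contiguous [n, c, h*r + s1, w*r + s2] output layout front to back; only the input offset needs explicit stride arithmetic.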

kernels/portable/cpu/targets.bzl

Lines changed: 6 additions & 0 deletions

@@ -578,6 +578,12 @@ _ATEN_OPS = (
             "//executorch/kernels/portable/cpu/util:copy_ops_util",
         ],
     ),
+    op_target(
+        name = "op_pixel_shuffle",
+        deps = [
+            "//executorch/kernels/portable/cpu/util:copy_ops_util",
+        ],
+    ),
     op_target(
         name = "op_reciprocal",
         deps = [

kernels/portable/cpu/util/copy_ops_util.cpp

Lines changed: 35 additions & 0 deletions

@@ -128,6 +128,41 @@ void get_permute_copy_out_target_size(
   }
 }
 
+bool check_pixel_shuffle_args(
+    const Tensor& in,
+    int64_t upscale_factor,
+    Tensor& out) {
+  ET_LOG_AND_RETURN_IF_FALSE(tensors_have_same_dtype(in, out));
+  ET_LOG_AND_RETURN_IF_FALSE(tensor_has_rank_greater_or_equal_to(in, 3));
+  ET_LOG_AND_RETURN_IF_FALSE(tensor_has_rank_greater_or_equal_to(out, 3));
+  ET_LOG_AND_RETURN_IF_FALSE(upscale_factor > 0);
+  ET_LOG_AND_RETURN_IF_FALSE(
+      in.size(in.dim() - 3) % (upscale_factor * upscale_factor) == 0);
+  return true;
+}
+
+void get_pixel_shuffle_out_target_size(
+    const Tensor& in,
+    int64_t upscale_factor,
+    Tensor::SizesType* out_sizes,
+    size_t* out_ndim) {
+  *out_ndim = in.dim();
+  const Tensor::SizesType casted_upscale_factor = upscale_factor;
+
+  size_t i = 0;
+  for (; i < in.dim() - 3; ++i) {
+    // Copy all leading dimensions in.
+    out_sizes[i] = in.size(i);
+  }
+  // The last 3 dimensions are (channel, height, width). Divide by the upscale
+  // factor squared and multiply the height and width by that factor.
+  out_sizes[i] = in.size(i) / (casted_upscale_factor * casted_upscale_factor);
+  i++;
+  out_sizes[i] = in.size(i) * casted_upscale_factor;
+  i++;
+  out_sizes[i] = in.size(i) * casted_upscale_factor;
+}
+
 bool check_stack_args(
     exec_aten::ArrayRef<Tensor> tensors,
     int64_t dim,
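
As a concrete example (not part of the diff): for the [1, 4, 2, 2] input used in the correctness test below with upscale_factor 2, these helpers accept the arguments (4 is divisible by 2*2) and compute out_sizes = [1, 4 / (2*2), 2*2, 2*2] = [1, 1, 4, 4].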

kernels/portable/cpu/util/copy_ops_util.h

Lines changed: 11 additions & 0 deletions

@@ -32,6 +32,17 @@ void get_permute_copy_out_target_size(
     Tensor::SizesType* out_sizes,
     size_t* out_ndim);
 
+bool check_pixel_shuffle_args(
+    const Tensor& in,
+    int64_t upscale_factor,
+    Tensor& out);
+
+void get_pixel_shuffle_out_target_size(
+    const Tensor& in,
+    int64_t upscale_factor,
+    Tensor::SizesType* out_sizes,
+    size_t* out_ndim);
+
 bool check_stack_args(
     exec_aten::ArrayRef<Tensor> tensors,
     int64_t dim,

kernels/portable/functions.yaml

Lines changed: 5 additions & 0 deletions

@@ -522,6 +522,11 @@
     - arg_meta: null
       kernel_name: torch::executor::permute_copy_out
 
+- op: pixel_shuffle.out
+  kernels:
+    - arg_meta: null
+      kernel_name: torch::executor::pixel_shuffle_out
+
 - op: pow.Tensor_Scalar_out
   kernels:
     - arg_meta: null

kernels/test/op_pixel_shuffle_test.cpp

Lines changed: 130 additions & 0 deletions

@@ -0,0 +1,130 @@
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <executorch/kernels/test/FunctionHeaderWrapper.h> // Declares the operator
#include <executorch/kernels/test/TestUtil.h>
#include <executorch/kernels/test/supported_features.h>
#include <executorch/runtime/core/exec_aten/exec_aten.h>
#include <executorch/runtime/core/exec_aten/testing_util/tensor_factory.h>
#include <executorch/runtime/core/exec_aten/testing_util/tensor_util.h>

#include <gtest/gtest.h>

using namespace ::testing;
using exec_aten::Scalar;
using exec_aten::ScalarType;
using exec_aten::Tensor;
using torch::executor::testing::SupportedFeatures;
using torch::executor::testing::TensorFactory;

Tensor&
op_pixel_shuffle_out(const Tensor& self, int64_t upscale_factor, Tensor& out) {
  exec_aten::RuntimeContext context{};
  return torch::executor::aten::pixel_shuffle_outf(
      context, self, upscale_factor, out);
}

//
// Correctness Tests
//

template <ScalarType DTYPE_IN>
void test_pixel_shuffle() {
  TensorFactory<DTYPE_IN> tf_in;

  const std::vector<int32_t> sizes = {1, 4, 2, 2};
  const std::vector<int32_t> out_sizes = {1, 1, 4, 4};

  // Destination for the pixel_shuffle.
  Tensor out = tf_in.zeros(out_sizes);

  op_pixel_shuffle_out(
      tf_in.make(sizes, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}),
      2,
      out);
  EXPECT_TENSOR_EQ(
      out,
      // Pixel shuffle distributes channels amongst the spatial dimensions.
      tf_in.make(
          out_sizes, {0, 4, 1, 5, 8, 12, 9, 13, 2, 6, 3, 7, 10, 14, 11, 15}));
}

/**
 * Uses the function templates above to test all input dtypes.
 */
TEST(OpPixelShuffleOutKernelTest, AllRealDtypesSupported) {
#define ENUMERATE_TEST_ENTRY(ctype, dtype) \
  test_pixel_shuffle<ScalarType::dtype>();

  ET_FORALL_REAL_TYPES(ENUMERATE_TEST_ENTRY)

#undef ENUMERATE_TEST_ENTRY
}

TEST(OpPixelShuffleOutKernelTest, LargerInputRank) {
  TensorFactory<ScalarType::Int> tf;

  // Pixel shuffle allows a 4D (or higher) input tensor, make sure the extra
  // dimensions don't cause issues.
  Tensor a = tf.ones(/*sizes=*/{1, 4, 1, 4, 2, 2});

  const std::vector<int32_t> out_sizes = {1, 4, 1, 1, 4, 4};
  Tensor out = tf.zeros(out_sizes);

  op_pixel_shuffle_out(a, 2, out);
  EXPECT_TENSOR_EQ(out, tf.ones(out_sizes));
}

// Mismatched shape tests.
TEST(OpPixelShuffleOutKernelTest, InvalidInputChannelsDies) {
  TensorFactory<ScalarType::Int> tf;

  // Input tensors with invalid shapes. 7 is not divisible by upsample_factor
  // ** 2.
  Tensor a = tf.ones(/*sizes=*/{1, 7, 4, 4});

  Tensor out = tf.zeros(/*sizes=*/{1, 1, 8, 8});

  // Using the wrong input shape should exit with an error code.
  ET_EXPECT_KERNEL_FAILURE(op_pixel_shuffle_out(a, 2, out));
}

TEST(OpPixelShuffleOutKernelTest, WrongInputRankDies) {
  TensorFactory<ScalarType::Int> tf;

  // Pixel shuffle requires a 4D input tensor.
  Tensor a = tf.ones(/*sizes=*/{1, 2});

  // NOTE: The wrong output rank dies for the portable kernel, but not the aten
  // kernel.
  Tensor out = tf.zeros(/*sizes=*/{1, 2});

  // Using the wrong input shape should exit with an error code.
  ET_EXPECT_KERNEL_FAILURE(op_pixel_shuffle_out(a, 2, out));
}

TEST(OpPixelShuffleOutKernelTest, DifferentDtypeDies) {
  TensorFactory<ScalarType::Int> tf;
  TensorFactory<ScalarType::Float> tf_float;

  Tensor a = tf.ones(/*sizes=*/{1, 18, 4, 4});

  // Pixel shuffle requires two tensors with the same dtype.
  Tensor out = tf_float.zeros(/*sizes=*/{1, 2, 12, 12});

  // Using the wrong output dtype should exit with an error code.
  ET_EXPECT_KERNEL_FAILURE(op_pixel_shuffle_out(a, 3, out));
}

TEST(OpPixelShuffleOutKernelTest, NegativeUpscaleFactorDies) {
  TensorFactory<ScalarType::Int> tf;
  Tensor a = tf.ones(/*sizes=*/{1, 18, 4, 4});
  Tensor out = tf.zeros(/*sizes=*/{1, 2, 12, 12});
  // Using a negative upscale factor should exit with an error code.
  ET_EXPECT_KERNEL_FAILURE(op_pixel_shuffle_out(a, -3, out));
}

kernels/test/targets.bzl

Lines changed: 1 addition & 0 deletions

@@ -240,6 +240,7 @@ def define_common_targets():
     _common_op_test("op_nonzero_test", ["aten", "portable"])
     _common_op_test("op_ones_test", ["aten", "portable"])
     _common_op_test("op_permute_copy_test", ["aten", "portable"])
+    _common_op_test("op_pixel_shuffle_test", ["aten", "portable"])
     _common_op_test("op_reciprocal_test", ["aten", "portable"])
     _common_op_test("op_relu_test", ["aten", "portable"])
     _common_op_test("op_remainder_test", ["aten", "portable"])

runtime/core/exec_aten/util/tensor_util.h

Lines changed: 12 additions & 0 deletions

@@ -502,6 +502,18 @@ inline bool tensor_is_rank(exec_aten::Tensor t, size_t rank) {
   return true;
 }
 
+inline bool tensor_has_rank_greater_or_equal_to(
+    exec_aten::Tensor t,
+    size_t rank) {
+  ET_LOG_MSG_AND_RETURN_IF_FALSE(
+      t.dim() >= rank,
+      "Expected tensor.dim() to be >= %zu, but got %zu",
+      static_cast<size_t>(rank),
+      static_cast<size_t>(t.dim()));
+
+  return true;
+}
+
 inline bool tensor_has_dim(exec_aten::Tensor t, int64_t d) {
   ET_LOG_MSG_AND_RETURN_IF_FALSE(
       d > 0 ? d < t.dim() : t.dim() + d >= 0,
