Skip to content

Commit c0b24e0

Browse files
manuelcandales authored and facebook-github-bot committed
enable channels last transposed convolution
Differential Revision: D59622072
1 parent 074a81e commit c0b24e0

File tree

2 files changed

+84
-16
lines changed

2 files changed

+84
-16
lines changed

kernels/portable/cpu/op_convolution.cpp

Lines changed: 15 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -136,17 +136,6 @@ void conv2d_impl(
136136
}
137137
}
138138
} else { // transposed convolution
139-
if (bias_ptr != nullptr) {
140-
out_coord[2] = 0;
141-
out_coord[3] = 0;
142-
size_t out_c_start_idx =
143-
calculate_linear_index(out_coord, out_strides.data(), 4);
144-
size_t out_c_end_idx = out_c_start_idx + out_H * out_W;
145-
for (size_t out_ix = out_c_start_idx; out_ix < out_c_end_idx; out_ix++) {
146-
out_ptr[out_ix] = convert<CTYPE, CTYPE_BIAS>(bias_ptr[out_c]);
147-
}
148-
}
149-
150139
w_coord[1] = out_c - out_c_start;
151140

152141
for (size_t in_y = 0; in_y < in_H; ++in_y) {
@@ -295,12 +284,22 @@ void convolution_wrapper(
295284
bias.has_value() ? bias.value().const_data_ptr<CTYPE_BIAS>() : nullptr;
296285

297286
size_t out_N = out.size(0);
298-
size_t out_C_per_group = out.size(1) / groups;
287+
size_t out_C = out.size(1);
288+
size_t out_C_per_group = out_C / groups;
299289

300-
if (transposed && bias_ptr == nullptr) {
301-
// If bias is not present, we need to initialize the output to 0
302-
// before we can accumulate into it.
303-
memset(out_ptr, 0, out.nbytes());
290+
if (transposed) {
291+
// For transposed convolution, we need to initialize the output before we
292+
// can accumulate into it.
293+
if (bias_ptr == nullptr) {
294+
// If bias is not present, we need to initialize the output to 0
295+
memset(out_ptr, 0, out.nbytes());
296+
} else {
297+
// If bias is present, we initialize the output to the bias value
298+
for (size_t out_ix = 0; out_ix < out.numel(); ++out_ix) {
299+
out_ptr[out_ix] = convert<CTYPE, CTYPE_BIAS>(
300+
bias_ptr[(out_ix / out_strides[1]) % out_C]);
301+
}
302+
}
304303
}
305304

306305
for (size_t batch = 0; batch < out_N; ++batch) {

kernels/test/op_convolution_test.cpp

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -587,6 +587,75 @@ TEST_F(OpConvCorrectnessTest, TransposedNonDefaultParams) {
587587
EXPECT_TENSOR_CLOSE(out, expected);
588588
}
589589

590+
// Verifies transposed convolution with non-default stride/padding/dilation/
// output_padding/groups when all tensors use the channels-last (NHWC) memory
// format. The expected values are authored in contiguous (NCHW) order and
// permuted into channels-last layout before comparison.
TEST_F(OpConvCorrectnessTest, TransposedNonDefaultParamsChannelsLast) {
  TensorFactory<ScalarType::Float> tf;

  Tensor input = tf.full_channels_last({2, 6, 4, 5}, 2.0);
  Tensor weight = tf.full_channels_last({6, 1, 2, 2}, 0.5);
  Tensor bias = tf.make({3}, {1, 2, 3});
  Tensor out = tf.full_channels_last({2, 3, 3, 6}, 0.7);

  // Expected output, written in contiguous NCHW element order.
  Tensor expected = tf.make(
      {2, 3, 3, 6},
      {1, 1, 1, 1, 1, 1, 1, 3, 3, 1, 3, 3, 1, 3, 3, 1, 3, 3, 2, 2, 2, 2,
       2, 2, 2, 4, 4, 2, 4, 4, 2, 4, 4, 2, 4, 4, 3, 3, 3, 3, 3, 3, 3, 5,
       5, 3, 5, 5, 3, 5, 5, 3, 5, 5, 1, 1, 1, 1, 1, 1, 1, 3, 3, 1, 3, 3,
       1, 3, 3, 1, 3, 3, 2, 2, 2, 2, 2, 2, 2, 4, 4, 2, 4, 4, 2, 4, 4, 2,
       4, 4, 3, 3, 3, 3, 3, 3, 3, 5, 5, 3, 5, 5, 3, 5, 5, 3, 5, 5});

  const std::vector<int32_t> sizes(
      expected.sizes().begin(), expected.sizes().end());

  const int32_t N = sizes[0];
  const int32_t C = sizes[1];
  const int32_t H = sizes[2];
  const int32_t W = sizes[3];
  const int32_t numel = N * C * H * W;

  std::vector<float> nchw_data(
      expected.const_data_ptr<float>(),
      expected.const_data_ptr<float>() + expected.numel());

  // Permute NCHW -> NHWC: decompose each flat contiguous index into its
  // (n, c, h, w) coordinates and recompose it in channels-last order.
  std::vector<float> nhwc_data(numel);
  for (int32_t src = 0; src < numel; ++src) {
    const int32_t w = src % W;
    const int32_t h = (src / W) % H;
    const int32_t c = (src / (W * H)) % C;
    const int32_t n = src / (W * H * C);
    const int32_t dst = ((n * H + h) * W + w) * C + c;
    nhwc_data[dst] = nchw_data[src];
  }

  Tensor expected_channels_last = tf.make_channels_last(sizes, nhwc_data);

  // Deliberately non-default parameters to exercise the general code path.
  int64_t stride[1] = {3};
  int64_t padding[1] = {7};
  int64_t dilation[1] = {5};
  bool transposed = true;
  int64_t output_padding[1] = {2};
  int64_t groups = 3;

  op_convolution_out(
      input,
      weight,
      exec_aten::optional<Tensor>(bias),
      exec_aten::ArrayRef<int64_t>{stride, 1},
      exec_aten::ArrayRef<int64_t>{padding, 1},
      exec_aten::ArrayRef<int64_t>{dilation, 1},
      transposed,
      exec_aten::ArrayRef<int64_t>{output_padding, 1},
      groups,
      out);

  EXPECT_TENSOR_CLOSE(out, expected_channels_last);
}
658+
590659
TEST_F(OpConvCorrectnessTest, InvalidOutputPadding) {
591660
TensorFactory<ScalarType::Float> tf;
592661

0 commit comments

Comments
 (0)