
Commit 1faa1bb

manuelcandales authored and facebook-github-bot committed
Add op: enable transposed convolution (#4197)
Summary:
Pull Request resolved: #4197

Reviewed By: tarun292

Differential Revision: D59589884

fbshipit-source-id: d0b7248b227081e5fefbf4d12cfa2a686d3c646d
1 parent 90d7d07 commit 1faa1bb

File tree

4 files changed: +378 -72 lines changed

kernels/portable/cpu/op_convolution.cpp

Lines changed: 140 additions & 58 deletions
@@ -50,8 +50,10 @@ void conv2d_impl(
     StridesArrayRef out_strides,
     const size_t batch,
     const size_t group,
-    const size_t out_c) {
+    const size_t out_c,
+    bool transposed) {
   size_t in_C = in_sizes[1];
+  size_t out_C = out_sizes[1];
 
   size_t out_H = out_sizes[2];
   size_t in_H = in_sizes[2];
@@ -64,13 +66,15 @@ void conv2d_impl(
   size_t in_C_per_group = in_C / groups;
   size_t in_c_start = group * in_C_per_group;
 
+  size_t out_C_per_group = out_C / groups;
+  size_t out_c_start = group * out_C_per_group;
+
   exec_aten::SizesType in_coord[kTensorDimensionLimit];
   in_coord[0] = batch;
   exec_aten::SizesType out_coord[kTensorDimensionLimit];
   out_coord[0] = batch;
   out_coord[1] = out_c;
   exec_aten::SizesType w_coord[kTensorDimensionLimit];
-  w_coord[0] = out_c;
 
   const int64_t stride_y = val_at(stride, 0);
   const int64_t padding_y = val_at(padding, 0, /*default_value=*/0);
@@ -79,53 +83,115 @@ void conv2d_impl(
   const int64_t padding_x = val_at(padding, 1, /*default_value=*/0);
   const int64_t dilation_x = val_at(dilation, 1);
 
-  // Compute 2D output region
-  for (size_t out_y = 0; out_y < out_H; ++out_y) {
-    out_coord[2] = out_y;
-    for (size_t out_x = 0; out_x < out_W; ++out_x) {
-      out_coord[3] = out_x;
-
-      CTYPE accum = 0.0f;
-      for (size_t in_c = in_c_start; in_c < in_c_start + in_C_per_group;
-           ++in_c) {
-        in_coord[1] = in_c;
-        w_coord[1] = in_c - in_c_start;
-
-        for (size_t w_y = 0; w_y < w_H; ++w_y) {
-          w_coord[2] = w_y;
-
-          size_t in_y = stride_y * out_y + dilation_y * w_y - padding_y;
-          in_coord[2] = in_y;
-          // Only proceed if input y coordinate is within bounds
-          if (in_y >= 0 && in_y < in_H) {
-            for (size_t w_x = 0; w_x < w_W; ++w_x) {
-              w_coord[3] = w_x;
-
-              size_t in_x = stride_x * out_x + dilation_x * w_x - padding_x;
-              in_coord[3] = in_x;
-
-              // Only proceed if input coordinate is within bounds
-              if (in_x >= 0 && in_x < in_W) {
-                size_t in_idx =
-                    calculate_linear_index(in_coord, in_strides.data(), 4);
-                CTYPE in_val = in_ptr[in_idx];
-
-                size_t w_idx =
-                    calculate_linear_index(w_coord, w_strides.data(), 4);
-                CTYPE w_val = w_ptr[w_idx];
-
-                accum += in_val * w_val;
+  if (!transposed) {
+    w_coord[0] = out_c;
+    // Compute 2D output region
+    for (size_t out_y = 0; out_y < out_H; ++out_y) {
+      out_coord[2] = out_y;
+      for (size_t out_x = 0; out_x < out_W; ++out_x) {
+        out_coord[3] = out_x;
+
+        CTYPE accum = 0.0f;
+        for (size_t in_c = in_c_start; in_c < in_c_start + in_C_per_group;
+             ++in_c) {
+          in_coord[1] = in_c;
+          w_coord[1] = in_c - in_c_start;
+
+          for (size_t w_y = 0; w_y < w_H; ++w_y) {
+            w_coord[2] = w_y;
+
+            size_t in_y = stride_y * out_y + dilation_y * w_y - padding_y;
+            in_coord[2] = in_y;
+            // Only proceed if input y coordinate is within bounds
+            if (in_y >= 0 && in_y < in_H) {
+              for (size_t w_x = 0; w_x < w_W; ++w_x) {
+                w_coord[3] = w_x;
+
+                size_t in_x = stride_x * out_x + dilation_x * w_x - padding_x;
+                in_coord[3] = in_x;
+
+                // Only proceed if input x coordinate is within bounds
+                if (in_x >= 0 && in_x < in_W) {
+                  size_t in_idx =
+                      calculate_linear_index(in_coord, in_strides.data(), 4);
+                  CTYPE in_val = in_ptr[in_idx];
+
+                  size_t w_idx =
+                      calculate_linear_index(w_coord, w_strides.data(), 4);
+                  CTYPE w_val = w_ptr[w_idx];
+
+                  accum += in_val * w_val;
+                }
               }
             }
           }
         }
+
+        if (bias_ptr != nullptr) {
+          accum += convert<CTYPE, CTYPE_BIAS>(bias_ptr[out_c]);
+        }
+        size_t out_idx =
+            calculate_linear_index(out_coord, out_strides.data(), 4);
+        out_ptr[out_idx] = accum;
       }
+    }
+  } else { // transposed convolution
+    if (bias_ptr != nullptr) {
+      out_coord[2] = 0;
+      out_coord[3] = 0;
+      size_t out_c_start_idx =
+          calculate_linear_index(out_coord, out_strides.data(), 4);
+      size_t out_c_end_idx = out_c_start_idx + out_H * out_W;
+      for (size_t out_ix = out_c_start_idx; out_ix < out_c_end_idx; out_ix++) {
+        out_ptr[out_ix] = convert<CTYPE, CTYPE_BIAS>(bias_ptr[out_c]);
+      }
+    }
+
+    w_coord[1] = out_c - out_c_start;
+
+    for (size_t in_y = 0; in_y < in_H; ++in_y) {
+      in_coord[2] = in_y;
+
+      for (size_t in_x = 0; in_x < in_W; ++in_x) {
+        in_coord[3] = in_x;
+
+        for (size_t in_c = in_c_start; in_c < in_c_start + in_C_per_group;
+             ++in_c) {
+          in_coord[1] = in_c;
+
+          size_t in_idx =
+              calculate_linear_index(in_coord, in_strides.data(), 4);
+          CTYPE in_val = in_ptr[in_idx];
+
+          w_coord[0] = in_c;
+          for (size_t w_y = 0; w_y < w_H; ++w_y) {
+            w_coord[2] = w_y;
+            size_t out_y = stride_y * in_y + dilation_y * w_y - padding_y;
+            out_coord[2] = out_y;
+
+            // Only proceed if output y coordinate is within bounds
+            if (out_y >= 0 && out_y < out_H) {
+              for (size_t w_x = 0; w_x < w_W; ++w_x) {
+                w_coord[3] = w_x;
+                size_t out_x = stride_x * in_x + dilation_x * w_x - padding_x;
+                out_coord[3] = out_x;
 
-      if (bias_ptr != nullptr) {
-        accum += convert<CTYPE, CTYPE_BIAS>(bias_ptr[out_c]);
+                // Only proceed if output x coordinate is within bounds
+                if (out_x >= 0 && out_x < out_W) {
+                  size_t w_idx =
+                      calculate_linear_index(w_coord, w_strides.data(), 4);
+                  CTYPE w_val = w_ptr[w_idx];
+
+                  size_t out_idx =
+                      calculate_linear_index(out_coord, out_strides.data(), 4);
+
+                  out_ptr[out_idx] += in_val * w_val;
+                }
+              }
+            }
+          }
+        }
       }
-      size_t out_idx = calculate_linear_index(out_coord, out_strides.data(), 4);
-      out_ptr[out_idx] = accum;
     }
   }
 }
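The heart of the change is the new else branch above. A regular convolution gathers: each output element is computed from a window of input elements. The transposed branch scatters instead: it walks the input and accumulates each in_val * w_val product into out[stride * in + dilation * k - padding]. Because the scatter accumulates with +=, every output element must be pre-initialized, either with the bias (the per-channel fill at the top of the branch) or with zero (the memset added further down). The same index arithmetic can be seen in this standalone 1-D sketch; the helper and driver below are illustrative only, not part of the kernel:

#include <cstdint>
#include <cstdio>
#include <vector>

// Scatter-style 1-D transposed convolution: the mirror image of the
// gather in a regular convolution. Hypothetical helper for illustration.
std::vector<float> conv_transpose1d(
    const std::vector<float>& in,
    const std::vector<float>& w,
    int64_t stride,
    int64_t padding,
    int64_t dilation,
    int64_t output_padding) {
  const int64_t in_len = static_cast<int64_t>(in.size());
  const int64_t w_len = static_cast<int64_t>(w.size());
  // Standard ConvTranspose output-size formula.
  const int64_t out_len = (in_len - 1) * stride - 2 * padding +
      dilation * (w_len - 1) + output_padding + 1;
  std::vector<float> out(out_len, 0.0f); // must start zeroed: we accumulate
  for (int64_t i = 0; i < in_len; ++i) {
    for (int64_t k = 0; k < w_len; ++k) {
      const int64_t o = stride * i + dilation * k - padding;
      if (o >= 0 && o < out_len) { // skip positions clipped by padding
        out[o] += in[i] * w[k];
      }
    }
  }
  return out;
}

int main() {
  // in = [1, 2], w = [1, 1, 1], stride 2: each input element stamps the
  // whole kernel into the output; the stamps overlap at index 2, so the
  // result is [1, 1, 3, 2, 2].
  for (float v : conv_transpose1d({1, 2}, {1, 1, 1}, 2, 0, 1, 0)) {
    std::printf("%g ", v);
  }
  std::printf("\n");
  return 0;
}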
@@ -138,14 +204,9 @@ void convolution_wrapper(
     IntArrayRef stride,
     IntArrayRef padding,
     IntArrayRef dilation,
+    bool transposed,
     int64_t groups,
     Tensor& out) {
-  size_t out_N = in.size(0);
-  size_t out_C = weight.size(0);
-
-  // Compute the number of in and out channels in each group
-  size_t out_C_per_group = out_C / groups;
-
   SizesArrayRef in_sizes = in.sizes();
   SizesArrayRef weight_sizes = weight.sizes();
   SizesArrayRef out_sizes = out.sizes();
@@ -233,6 +294,15 @@ void convolution_wrapper(
   const CTYPE_BIAS* const bias_ptr =
       bias.has_value() ? bias.value().const_data_ptr<CTYPE_BIAS>() : nullptr;
 
+  size_t out_N = out.size(0);
+  size_t out_C_per_group = out.size(1) / groups;
+
+  if (transposed && bias_ptr == nullptr) {
+    // If bias is not present, we need to initialize the output to 0
+    // before we can accumulate into it.
+    memset(out_ptr, 0, out.nbytes());
+  }
+
   for (size_t batch = 0; batch < out_N; ++batch) {
     for (size_t group = 0; group < groups; ++group) {
       // Align channel offset based on the group
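The memset above pairs with the bias pre-fill inside conv2d_impl: the scatter path only ever does out_ptr[out_idx] += ..., so each output element must start at its additive identity. Pre-filling with the bias is equivalent to zero-initializing and adding the bias after accumulation, which this small self-check illustrates (values and names are made up for illustration):

#include <cassert>
#include <cstdio>
#include <vector>

int main() {
  // Per-element scatter contributions for one output channel (made up).
  const std::vector<float> contrib = {1.0f, 1.0f, 3.0f, 2.0f, 2.0f};
  const float bias = 0.5f;

  // Path A: zero-initialize (the memset case), add the bias at the end.
  std::vector<float> a(contrib.size(), 0.0f);
  for (size_t i = 0; i < a.size(); ++i) {
    a[i] += contrib[i];
    a[i] += bias;
  }

  // Path B: pre-fill with the bias (what the transposed branch does when
  // bias_ptr != nullptr), then accumulate only the scatter contributions.
  std::vector<float> b(contrib.size(), bias);
  for (size_t i = 0; i < b.size(); ++i) {
    b[i] += contrib[i];
  }

  for (size_t i = 0; i < a.size(); ++i) {
    assert(a[i] == b[i]);
  }
  std::printf("paths agree\n");
  return 0;
}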
@@ -257,7 +327,8 @@ void convolution_wrapper(
           {out_strides, 4},
           batch,
           group,
-          out_c);
+          out_c,
+          transposed);
     }
   }
 }
@@ -273,8 +344,8 @@ Tensor& convolution_out(
     IntArrayRef stride,
     IntArrayRef padding,
     IntArrayRef dilation,
-    __ET_UNUSED bool transposed,
-    __ET_UNUSED IntArrayRef output_padding,
+    bool transposed,
+    IntArrayRef output_padding,
     int64_t groups,
     Tensor& out) {
   (void)ctx;
@@ -298,7 +369,16 @@ Tensor& convolution_out(
   size_t output_ndim = 0;
   exec_aten::SizesType output_sizes[kTensorDimensionLimit];
   get_convolution_out_target_size(
-      in, weight, stride, padding, dilation, output_sizes, &output_ndim);
+      in,
+      weight,
+      stride,
+      padding,
+      dilation,
+      transposed,
+      output_padding,
+      groups,
+      output_sizes,
+      &output_ndim);
 
   ET_KERNEL_CHECK(
       ctx,
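get_convolution_out_target_size now needs transposed, output_padding, and groups because a transposed convolution's output shape follows a different formula: each spatial dimension grows rather than shrinks. The sketch below shows the standard ConvTranspose size computation per dimension; it is an illustration of the formula, not the actual helper's implementation:

#include <cstdint>
#include <cstdio>

// Expected output extent of one spatial dimension of a transposed
// convolution (standard ConvTranspose formula).
int64_t transposed_out_dim(
    int64_t in_dim,
    int64_t kernel,
    int64_t stride,
    int64_t padding,
    int64_t dilation,
    int64_t output_padding) {
  return (in_dim - 1) * stride - 2 * padding + dilation * (kernel - 1) +
      output_padding + 1;
}

int main() {
  // e.g. a 4x4 input with a 3x3 kernel at stride 2 upsamples to 9x9.
  std::printf("%lld\n", (long long)transposed_out_dim(4, 3, 2, 0, 1, 0));
  return 0;
}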
@@ -321,12 +401,14 @@ Tensor& convolution_out(
   if (bias.has_value()) {
     bias_type = bias.value().scalar_type();
   }
-  ET_SWITCH_REAL_TYPES(in_type, ctx, "convolution.out", CTYPE, [&]() {
-    ET_SWITCH_REAL_TYPES_AND(
-        Bool, bias_type, ctx, "convolution.out", CTYPE_BIAS, [&]() {
-          convolution_wrapper<CTYPE, CTYPE_BIAS>(
-              in, weight, bias, stride, padding, dilation, groups, out);
-        });
+
+  constexpr auto name = "convolution.out";
+
+  ET_SWITCH_REALH_TYPES(in_type, ctx, name, CTYPE, [&]() {
+    ET_SWITCH_REALHB_TYPES(bias_type, ctx, name, CTYPE_BIAS, [&]() {
+      convolution_wrapper<CTYPE, CTYPE_BIAS>(
+          in, weight, bias, stride, padding, dilation, transposed, groups, out);
+    });
   });
 
   return out;
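The dtype dispatch also widens here: ET_SWITCH_REALH_TYPES covers the real types plus Half for the input, and ET_SWITCH_REALHB_TYPES additionally admits Bool for the bias. Conceptually, the nested switches map the two runtime dtypes onto compile-time template arguments and instantiate the kernel for that pair, along these lines (a simplified stand-in, not the real macro expansion):

#include <cstdio>

enum class ScalarType { Float, Double, Bool };

// Hypothetical stand-in for convolution_wrapper<CTYPE, CTYPE_BIAS>.
template <typename CTYPE, typename CTYPE_BIAS>
void kernel() {
  std::printf(
      "kernel<%zu-byte, %zu-byte>\n", sizeof(CTYPE), sizeof(CTYPE_BIAS));
}

// Map a runtime dtype to a compile-time type and invoke the body with a
// value of that type, as the ET_SWITCH_* macros do for their type lists.
template <typename Fn>
void dispatch(ScalarType t, Fn fn) {
  switch (t) {
    case ScalarType::Float:
      fn(float{});
      break;
    case ScalarType::Double:
      fn(double{});
      break;
    case ScalarType::Bool:
      fn(bool{});
      break;
  }
}

int main() {
  ScalarType in_type = ScalarType::Float;
  ScalarType bias_type = ScalarType::Double;
  // Nested dispatch yields the cross-product instantiation, mirroring the
  // nested ET_SWITCH calls in convolution_out.
  dispatch(in_type, [&](auto in_v) {
    using CTYPE = decltype(in_v);
    dispatch(bias_type, [&](auto bias_v) {
      using CTYPE_BIAS = decltype(bias_v);
      kernel<CTYPE, CTYPE_BIAS>();
    });
  });
  return 0;
}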
