Skip to content

Commit 95c80f0

Browse files
manuelcandales authored and facebook-github-bot committed
Add op: enable transposed convolution (#4197)
Summary: Pull Request resolved: #4197 Differential Revision: D59589884
1 parent 561c035 commit 95c80f0

File tree

4 files changed

+372
-72
lines changed

4 files changed

+372
-72
lines changed

kernels/portable/cpu/op_convolution.cpp

Lines changed: 134 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -50,8 +50,10 @@ void conv2d_impl(
5050
StridesArrayRef out_strides,
5151
const size_t batch,
5252
const size_t group,
53-
const size_t out_c) {
53+
const size_t out_c,
54+
bool transposed) {
5455
size_t in_C = in_sizes[1];
56+
size_t out_C = out_sizes[1];
5557

5658
size_t out_H = out_sizes[2];
5759
size_t in_H = in_sizes[2];
@@ -64,13 +66,15 @@ void conv2d_impl(
6466
size_t in_C_per_group = in_C / groups;
6567
size_t in_c_start = group * in_C_per_group;
6668

69+
size_t out_C_per_group = out_C / groups;
70+
size_t out_c_start = group * out_C_per_group;
71+
6772
exec_aten::SizesType in_coord[kTensorDimensionLimit];
6873
in_coord[0] = batch;
6974
exec_aten::SizesType out_coord[kTensorDimensionLimit];
7075
out_coord[0] = batch;
7176
out_coord[1] = out_c;
7277
exec_aten::SizesType w_coord[kTensorDimensionLimit];
73-
w_coord[0] = out_c;
7478

7579
const int64_t stride_y = val_at(stride, 0);
7680
const int64_t padding_y = val_at(padding, 0, /*default_value=*/0);
@@ -79,53 +83,115 @@ void conv2d_impl(
7983
const int64_t padding_x = val_at(padding, 1, /*default_value=*/0);
8084
const int64_t dilation_x = val_at(dilation, 1);
8185

82-
// Compute 2D output region
83-
for (size_t out_y = 0; out_y < out_H; ++out_y) {
84-
out_coord[2] = out_y;
85-
for (size_t out_x = 0; out_x < out_W; ++out_x) {
86-
out_coord[3] = out_x;
87-
88-
CTYPE accum = 0.0f;
89-
for (size_t in_c = in_c_start; in_c < in_c_start + in_C_per_group;
90-
++in_c) {
91-
in_coord[1] = in_c;
92-
w_coord[1] = in_c - in_c_start;
93-
94-
for (size_t w_y = 0; w_y < w_H; ++w_y) {
95-
w_coord[2] = w_y;
96-
97-
size_t in_y = stride_y * out_y + dilation_y * w_y - padding_y;
98-
in_coord[2] = in_y;
99-
// Only proceed if input y coordinate is within bounds
100-
if (in_y >= 0 && in_y < in_H) {
101-
for (size_t w_x = 0; w_x < w_W; ++w_x) {
102-
w_coord[3] = w_x;
103-
104-
size_t in_x = stride_x * out_x + dilation_x * w_x - padding_x;
105-
in_coord[3] = in_x;
106-
107-
// Only proceed if input coordinate is within bounds
108-
if (in_x >= 0 && in_x < in_W) {
109-
size_t in_idx =
110-
calculate_linear_index(in_coord, in_strides.data(), 4);
111-
CTYPE in_val = in_ptr[in_idx];
112-
113-
size_t w_idx =
114-
calculate_linear_index(w_coord, w_strides.data(), 4);
115-
CTYPE w_val = w_ptr[w_idx];
116-
117-
accum += in_val * w_val;
86+
if (!transposed) {
87+
w_coord[0] = out_c;
88+
// Compute 2D output region
89+
for (size_t out_y = 0; out_y < out_H; ++out_y) {
90+
out_coord[2] = out_y;
91+
for (size_t out_x = 0; out_x < out_W; ++out_x) {
92+
out_coord[3] = out_x;
93+
94+
CTYPE accum = 0.0f;
95+
for (size_t in_c = in_c_start; in_c < in_c_start + in_C_per_group;
96+
++in_c) {
97+
in_coord[1] = in_c;
98+
w_coord[1] = in_c - in_c_start;
99+
100+
for (size_t w_y = 0; w_y < w_H; ++w_y) {
101+
w_coord[2] = w_y;
102+
103+
size_t in_y = stride_y * out_y + dilation_y * w_y - padding_y;
104+
in_coord[2] = in_y;
105+
// Only proceed if input y coordinate is within bounds
106+
if (in_y >= 0 && in_y < in_H) {
107+
for (size_t w_x = 0; w_x < w_W; ++w_x) {
108+
w_coord[3] = w_x;
109+
110+
size_t in_x = stride_x * out_x + dilation_x * w_x - padding_x;
111+
in_coord[3] = in_x;
112+
113+
// Only proceed if input x coordinate is within bounds
114+
if (in_x >= 0 && in_x < in_W) {
115+
size_t in_idx =
116+
calculate_linear_index(in_coord, in_strides.data(), 4);
117+
CTYPE in_val = in_ptr[in_idx];
118+
119+
size_t w_idx =
120+
calculate_linear_index(w_coord, w_strides.data(), 4);
121+
CTYPE w_val = w_ptr[w_idx];
122+
123+
accum += in_val * w_val;
124+
}
118125
}
119126
}
120127
}
121128
}
129+
130+
if (bias_ptr != nullptr) {
131+
accum += convert<CTYPE, CTYPE_BIAS>(bias_ptr[out_c]);
132+
}
133+
size_t out_idx =
134+
calculate_linear_index(out_coord, out_strides.data(), 4);
135+
out_ptr[out_idx] = accum;
122136
}
137+
}
138+
} else { // transposed convolution
139+
w_coord[1] = out_c - out_c_start;
140+
141+
for (size_t in_y = 0; in_y < in_H; ++in_y) {
142+
in_coord[2] = in_y;
143+
144+
for (size_t in_x = 0; in_x < in_W; ++in_x) {
145+
in_coord[3] = in_x;
146+
147+
for (size_t in_c = in_c_start; in_c < in_c_start + in_C_per_group;
148+
++in_c) {
149+
in_coord[1] = in_c;
150+
151+
size_t in_idx =
152+
calculate_linear_index(in_coord, in_strides.data(), 4);
153+
CTYPE in_val = in_ptr[in_idx];
154+
155+
w_coord[0] = in_c;
156+
for (size_t w_y = 0; w_y < w_H; ++w_y) {
157+
w_coord[2] = w_y;
158+
size_t out_y = stride_y * in_y + dilation_y * w_y - padding_y;
159+
out_coord[2] = out_y;
160+
161+
// Only proceed if output y coordinate is within bounds
162+
if (out_y >= 0 && out_y < out_H) {
163+
for (size_t w_x = 0; w_x < w_W; ++w_x) {
164+
w_coord[3] = w_x;
165+
size_t out_x = stride_x * in_x + dilation_x * w_x - padding_x;
166+
out_coord[3] = out_x;
167+
168+
// Only proceed if output x coordinate is within bounds
169+
if (out_x >= 0 && out_x < out_W) {
170+
size_t w_idx =
171+
calculate_linear_index(w_coord, w_strides.data(), 4);
172+
CTYPE w_val = w_ptr[w_idx];
173+
174+
size_t out_idx =
175+
calculate_linear_index(out_coord, out_strides.data(), 4);
176+
177+
out_ptr[out_idx] += in_val * w_val;
178+
}
179+
}
180+
}
181+
}
182+
}
183+
}
184+
}
123185

124-
if (bias_ptr != nullptr) {
125-
accum += convert<CTYPE, CTYPE_BIAS>(bias_ptr[out_c]);
186+
if (bias_ptr != nullptr) {
187+
out_coord[2] = 0;
188+
out_coord[3] = 0;
189+
size_t out_c_start_idx =
190+
calculate_linear_index(out_coord, out_strides.data(), 4);
191+
size_t out_c_end_idx = out_c_start_idx + out_H * out_W;
192+
for (size_t out_ix = out_c_start_idx; out_ix < out_c_end_idx; out_ix++) {
193+
out_ptr[out_ix] += convert<CTYPE, CTYPE_BIAS>(bias_ptr[out_c]);
126194
}
127-
size_t out_idx = calculate_linear_index(out_coord, out_strides.data(), 4);
128-
out_ptr[out_idx] = accum;
129195
}
130196
}
131197
}
@@ -138,14 +204,9 @@ void convolution_wrapper(
138204
IntArrayRef stride,
139205
IntArrayRef padding,
140206
IntArrayRef dilation,
207+
bool transposed,
141208
int64_t groups,
142209
Tensor& out) {
143-
size_t out_N = in.size(0);
144-
size_t out_C = weight.size(0);
145-
146-
// Compute the number of in and out channels in each group
147-
size_t out_C_per_group = out_C / groups;
148-
149210
SizesArrayRef in_sizes = in.sizes();
150211
SizesArrayRef weight_sizes = weight.sizes();
151212
SizesArrayRef out_sizes = out.sizes();
@@ -233,6 +294,9 @@ void convolution_wrapper(
233294
const CTYPE_BIAS* const bias_ptr =
234295
bias.has_value() ? bias.value().const_data_ptr<CTYPE_BIAS>() : nullptr;
235296

297+
size_t out_N = out.size(0);
298+
size_t out_C_per_group = out.size(1) / groups;
299+
236300
for (size_t batch = 0; batch < out_N; ++batch) {
237301
for (size_t group = 0; group < groups; ++group) {
238302
// Align channel offset based on the group
@@ -257,7 +321,8 @@ void convolution_wrapper(
257321
{out_strides, 4},
258322
batch,
259323
group,
260-
out_c);
324+
out_c,
325+
transposed);
261326
}
262327
}
263328
}
@@ -273,8 +338,8 @@ Tensor& convolution_out(
273338
IntArrayRef stride,
274339
IntArrayRef padding,
275340
IntArrayRef dilation,
276-
__ET_UNUSED bool transposed,
277-
__ET_UNUSED IntArrayRef output_padding,
341+
bool transposed,
342+
IntArrayRef output_padding,
278343
int64_t groups,
279344
Tensor& out) {
280345
(void)ctx;
@@ -298,7 +363,16 @@ Tensor& convolution_out(
298363
size_t output_ndim = 0;
299364
exec_aten::SizesType output_sizes[kTensorDimensionLimit];
300365
get_convolution_out_target_size(
301-
in, weight, stride, padding, dilation, output_sizes, &output_ndim);
366+
in,
367+
weight,
368+
stride,
369+
padding,
370+
dilation,
371+
transposed,
372+
output_padding,
373+
groups,
374+
output_sizes,
375+
&output_ndim);
302376

303377
ET_KERNEL_CHECK(
304378
ctx,
@@ -321,12 +395,14 @@ Tensor& convolution_out(
321395
if (bias.has_value()) {
322396
bias_type = bias.value().scalar_type();
323397
}
324-
ET_SWITCH_REAL_TYPES(in_type, ctx, "convolution.out", CTYPE, [&]() {
325-
ET_SWITCH_REAL_TYPES_AND(
326-
Bool, bias_type, ctx, "convolution.out", CTYPE_BIAS, [&]() {
327-
convolution_wrapper<CTYPE, CTYPE_BIAS>(
328-
in, weight, bias, stride, padding, dilation, groups, out);
329-
});
398+
399+
constexpr auto name = "convolution.out";
400+
401+
ET_SWITCH_REALH_TYPES(in_type, ctx, name, CTYPE, [&]() {
402+
ET_SWITCH_REALHB_TYPES(bias_type, ctx, name, CTYPE_BIAS, [&]() {
403+
convolution_wrapper<CTYPE, CTYPE_BIAS>(
404+
in, weight, bias, stride, padding, dilation, transposed, groups, out);
405+
});
330406
});
331407

332408
return out;

0 commit comments

Comments
 (0)