@@ -50,8 +50,10 @@ void conv2d_impl(
     StridesArrayRef out_strides,
     const size_t batch,
     const size_t group,
-    const size_t out_c) {
+    const size_t out_c,
+    bool transposed) {
   size_t in_C = in_sizes[1];
+  size_t out_C = out_sizes[1];
 
   size_t out_H = out_sizes[2];
   size_t in_H = in_sizes[2];
@@ -64,13 +66,15 @@ void conv2d_impl(
   size_t in_C_per_group = in_C / groups;
   size_t in_c_start = group * in_C_per_group;
 
+  size_t out_C_per_group = out_C / groups;
+  size_t out_c_start = group * out_C_per_group;
+
   exec_aten::SizesType in_coord[kTensorDimensionLimit];
   in_coord[0] = batch;
   exec_aten::SizesType out_coord[kTensorDimensionLimit];
   out_coord[0] = batch;
   out_coord[1] = out_c;
   exec_aten::SizesType w_coord[kTensorDimensionLimit];
-  w_coord[0] = out_c;
 
   const int64_t stride_y = val_at(stride, 0);
   const int64_t padding_y = val_at(padding, 0, /*default_value=*/0);
@@ -79,53 +83,115 @@ void conv2d_impl(
   const int64_t padding_x = val_at(padding, 1, /*default_value=*/0);
   const int64_t dilation_x = val_at(dilation, 1);
 
-  // Compute 2D output region
-  for (size_t out_y = 0; out_y < out_H; ++out_y) {
-    out_coord[2] = out_y;
-    for (size_t out_x = 0; out_x < out_W; ++out_x) {
-      out_coord[3] = out_x;
-
-      CTYPE accum = 0.0f;
-      for (size_t in_c = in_c_start; in_c < in_c_start + in_C_per_group;
-           ++in_c) {
-        in_coord[1] = in_c;
-        w_coord[1] = in_c - in_c_start;
-
-        for (size_t w_y = 0; w_y < w_H; ++w_y) {
-          w_coord[2] = w_y;
-
-          size_t in_y = stride_y * out_y + dilation_y * w_y - padding_y;
-          in_coord[2] = in_y;
-          // Only proceed if input y coordinate is within bounds
-          if (in_y >= 0 && in_y < in_H) {
-            for (size_t w_x = 0; w_x < w_W; ++w_x) {
-              w_coord[3] = w_x;
-
-              size_t in_x = stride_x * out_x + dilation_x * w_x - padding_x;
-              in_coord[3] = in_x;
-
-              // Only proceed if input coordinate is within bounds
-              if (in_x >= 0 && in_x < in_W) {
-                size_t in_idx =
-                    calculate_linear_index(in_coord, in_strides.data(), 4);
-                CTYPE in_val = in_ptr[in_idx];
-
-                size_t w_idx =
-                    calculate_linear_index(w_coord, w_strides.data(), 4);
-                CTYPE w_val = w_ptr[w_idx];
-
-                accum += in_val * w_val;
+  if (!transposed) {
+    w_coord[0] = out_c;
+    // Compute 2D output region
+    for (size_t out_y = 0; out_y < out_H; ++out_y) {
+      out_coord[2] = out_y;
+      for (size_t out_x = 0; out_x < out_W; ++out_x) {
+        out_coord[3] = out_x;
+
+        CTYPE accum = 0.0f;
+        for (size_t in_c = in_c_start; in_c < in_c_start + in_C_per_group;
+             ++in_c) {
+          in_coord[1] = in_c;
+          w_coord[1] = in_c - in_c_start;
+
+          for (size_t w_y = 0; w_y < w_H; ++w_y) {
+            w_coord[2] = w_y;
+
+            size_t in_y = stride_y * out_y + dilation_y * w_y - padding_y;
+            in_coord[2] = in_y;
+            // Only proceed if input y coordinate is within bounds
+            if (in_y >= 0 && in_y < in_H) {
+              for (size_t w_x = 0; w_x < w_W; ++w_x) {
+                w_coord[3] = w_x;
+
+                size_t in_x = stride_x * out_x + dilation_x * w_x - padding_x;
+                in_coord[3] = in_x;
+
+                // Only proceed if input x coordinate is within bounds
+                if (in_x >= 0 && in_x < in_W) {
+                  size_t in_idx =
+                      calculate_linear_index(in_coord, in_strides.data(), 4);
+                  CTYPE in_val = in_ptr[in_idx];
+
+                  size_t w_idx =
+                      calculate_linear_index(w_coord, w_strides.data(), 4);
+                  CTYPE w_val = w_ptr[w_idx];
+
+                  accum += in_val * w_val;
+                }
               }
             }
           }
         }
+
+        if (bias_ptr != nullptr) {
+          accum += convert<CTYPE, CTYPE_BIAS>(bias_ptr[out_c]);
+        }
+        size_t out_idx =
+            calculate_linear_index(out_coord, out_strides.data(), 4);
+        out_ptr[out_idx] = accum;
       }
+    }
+  } else { // transposed convolution
+    w_coord[1] = out_c - out_c_start;
+
+    for (size_t in_y = 0; in_y < in_H; ++in_y) {
+      in_coord[2] = in_y;
+
+      for (size_t in_x = 0; in_x < in_W; ++in_x) {
+        in_coord[3] = in_x;
+
+        for (size_t in_c = in_c_start; in_c < in_c_start + in_C_per_group;
+             ++in_c) {
+          in_coord[1] = in_c;
+
+          size_t in_idx =
+              calculate_linear_index(in_coord, in_strides.data(), 4);
+          CTYPE in_val = in_ptr[in_idx];
+
+          w_coord[0] = in_c;
+          for (size_t w_y = 0; w_y < w_H; ++w_y) {
+            w_coord[2] = w_y;
+            size_t out_y = stride_y * in_y + dilation_y * w_y - padding_y;
+            out_coord[2] = out_y;
+
+            // Only proceed if output y coordinate is within bounds
+            if (out_y >= 0 && out_y < out_H) {
+              for (size_t w_x = 0; w_x < w_W; ++w_x) {
+                w_coord[3] = w_x;
+                size_t out_x = stride_x * in_x + dilation_x * w_x - padding_x;
+                out_coord[3] = out_x;
+
+                // Only proceed if output x coordinate is within bounds
+                if (out_x >= 0 && out_x < out_W) {
+                  size_t w_idx =
+                      calculate_linear_index(w_coord, w_strides.data(), 4);
+                  CTYPE w_val = w_ptr[w_idx];
+
+                  size_t out_idx =
+                      calculate_linear_index(out_coord, out_strides.data(), 4);
+
+                  out_ptr[out_idx] += in_val * w_val;
+                }
+              }
+            }
+          }
+        }
+      }
+    }
 
-      if (bias_ptr != nullptr) {
-        accum += convert<CTYPE, CTYPE_BIAS>(bias_ptr[out_c]);
+    if (bias_ptr != nullptr) {
+      out_coord[2] = 0;
+      out_coord[3] = 0;
+      size_t out_c_start_idx =
+          calculate_linear_index(out_coord, out_strides.data(), 4);
+      size_t out_c_end_idx = out_c_start_idx + out_H * out_W;
+      for (size_t out_ix = out_c_start_idx; out_ix < out_c_end_idx; out_ix++) {
+        out_ptr[out_ix] += convert<CTYPE, CTYPE_BIAS>(bias_ptr[out_c]);
       }
-      size_t out_idx = calculate_linear_index(out_coord, out_strides.data(), 4);
-      out_ptr[out_idx] = accum;
     }
   }
 }
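For reference, the new `else` branch scatters instead of gathering: every input element is accumulated with `+=` into `out[stride * i + dilation * k - padding]` for each weight tap, which is also why bias is added in a separate pass over the whole output channel. Below is a minimal, hypothetical 1-D sketch of that scatter formulation (a standalone illustration, not part of this patch; `conv_transpose1d_ref` and its parameters are made up for the example):

#include <cstddef>
#include <vector>

// Minimal 1-D sketch of the scatter-style transposed convolution used in the
// transposed branch above. out_len is assumed to be precomputed, just as the
// kernel relies on `out` already having its target size.
std::vector<float> conv_transpose1d_ref(
    const std::vector<float>& in,
    const std::vector<float>& w,
    size_t out_len,
    size_t stride,
    size_t padding,
    size_t dilation) {
  std::vector<float> out(out_len, 0.0f);
  for (size_t i = 0; i < in.size(); ++i) {
    for (size_t k = 0; k < w.size(); ++k) {
      // Same coordinate relation as out_y/out_x in the 2-D kernel. With
      // unsigned arithmetic a position trimmed by padding wraps around and is
      // rejected by the bounds check, matching the kernel's out_y/out_x tests.
      size_t o = stride * i + dilation * k - padding;
      if (o < out_len) {
        out[o] += in[i] * w[k];
      }
    }
  }
  return out;
}

For example, in = {1, 2, 3}, w = {1, 1}, stride = 2, padding = 0, dilation = 1 and out_len = 6 produce {1, 1, 2, 2, 3, 3}.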
@@ -138,14 +204,9 @@ void convolution_wrapper(
     IntArrayRef stride,
     IntArrayRef padding,
     IntArrayRef dilation,
+    bool transposed,
     int64_t groups,
     Tensor& out) {
-  size_t out_N = in.size(0);
-  size_t out_C = weight.size(0);
-
-  // Compute the number of in and out channels in each group
-  size_t out_C_per_group = out_C / groups;
-
   SizesArrayRef in_sizes = in.sizes();
   SizesArrayRef weight_sizes = weight.sizes();
   SizesArrayRef out_sizes = out.sizes();
@@ -233,6 +294,9 @@ void convolution_wrapper(
   const CTYPE_BIAS* const bias_ptr =
       bias.has_value() ? bias.value().const_data_ptr<CTYPE_BIAS>() : nullptr;
 
+  size_t out_N = out.size(0);
+  size_t out_C_per_group = out.size(1) / groups;
+
   for (size_t batch = 0; batch < out_N; ++batch) {
     for (size_t group = 0; group < groups; ++group) {
       // Align channel offset based on the group
@@ -257,7 +321,8 @@ void convolution_wrapper(
             {out_strides, 4},
             batch,
             group,
-            out_c);
+            out_c,
+            transposed);
       }
     }
   }
@@ -273,8 +338,8 @@ Tensor& convolution_out(
     IntArrayRef stride,
     IntArrayRef padding,
     IntArrayRef dilation,
-    __ET_UNUSED bool transposed,
-    __ET_UNUSED IntArrayRef output_padding,
+    bool transposed,
+    IntArrayRef output_padding,
     int64_t groups,
     Tensor& out) {
   (void)ctx;
@@ -298,7 +363,16 @@ Tensor& convolution_out(
   size_t output_ndim = 0;
   exec_aten::SizesType output_sizes[kTensorDimensionLimit];
   get_convolution_out_target_size(
-      in, weight, stride, padding, dilation, output_sizes, &output_ndim);
+      in,
+      weight,
+      stride,
+      padding,
+      dilation,
+      transposed,
+      output_padding,
+      groups,
+      output_sizes,
+      &output_ndim);
 
   ET_KERNEL_CHECK(
       ctx,
@@ -321,12 +395,14 @@ Tensor& convolution_out(
   if (bias.has_value()) {
     bias_type = bias.value().scalar_type();
   }
-  ET_SWITCH_REAL_TYPES(in_type, ctx, "convolution.out", CTYPE, [&]() {
-    ET_SWITCH_REAL_TYPES_AND(
-        Bool, bias_type, ctx, "convolution.out", CTYPE_BIAS, [&]() {
-          convolution_wrapper<CTYPE, CTYPE_BIAS>(
-              in, weight, bias, stride, padding, dilation, groups, out);
-        });
+
+  constexpr auto name = "convolution.out";
+
+  ET_SWITCH_REALH_TYPES(in_type, ctx, name, CTYPE, [&]() {
+    ET_SWITCH_REALHB_TYPES(bias_type, ctx, name, CTYPE_BIAS, [&]() {
+      convolution_wrapper<CTYPE, CTYPE_BIAS>(
+          in, weight, bias, stride, padding, dilation, transposed, groups, out);
+    });
   });
 
   return out;
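The expanded `get_convolution_out_target_size` call now receives `transposed`, `output_padding`, and `groups` because the output shape is computed differently for the two paths. The helper's body is not part of this diff; the following is a hypothetical per-dimension sketch of the standard sizing rules it would have to apply (the function name is illustrative, not an actual helper in the kernel):

#include <cstdint>

// Expected spatial extent of one output dimension, using the standard
// Conv2d / ConvTranspose2d shape formulas.
int64_t conv_out_dim(
    int64_t in,
    int64_t kernel,
    int64_t stride,
    int64_t padding,
    int64_t dilation,
    int64_t output_padding,
    bool transposed) {
  if (!transposed) {
    // Gather path: each output element reads stride/dilation-spaced inputs.
    return (in + 2 * padding - dilation * (kernel - 1) - 1) / stride + 1;
  }
  // Scatter path: inverse of the relation above, plus output_padding.
  return (in - 1) * stride - 2 * padding + dilation * (kernel - 1) +
      output_padding + 1;
}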