Commit ddded1a

[ET][Portable] Fix kernel ops (convolution, avg_pool2d, max_pool2d_with_indices)
Differential Revision: [D49967523](https://our.internmc.facebook.com/intern/diff/D49967523/)
ghstack-source-id: 203341587
Pull Request resolved: #710
1 parent: acd93d7

5 files changed: +171 lines, -103 lines

kernels/portable/cpu/op_avg_pool2d.cpp

Lines changed: 1 addition & 1 deletion
@@ -51,7 +51,7 @@ Tensor& avg_pool2d_out(
 
   ET_KERNEL_CHECK(
       ctx,
-      output_size_is_valid({output_sizes, output_ndim}),
+      output_size_is_valid({output_sizes, output_ndim}, 2),
       InvalidArgument,
       out);
 
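The new second argument tells output_size_is_valid how many trailing spatial dimensions must stay strictly positive; leading batch/channel dimensions may now legally be zero (see the kernel_ops_util.cpp change below). A minimal sketch of the intended semantics, using hypothetical sizes:

  // Hypothetical zero-batch case: input {0, 4, 7, 7} pooled to {0, 4, 3, 3}
  // with a 3x3 kernel and stride 2.
  exec_aten::SizesType output_sizes[] = {0, 4, 3, 3};
  size_t output_ndim = 4;
  // kernel_ndim == 2: only the last two entries (3, 3) must be > 0, so the
  // zero-size batch dimension no longer fails the check.
  bool ok = output_size_is_valid({output_sizes, output_ndim}, 2); // true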

kernels/portable/cpu/op_convolution.cpp

Lines changed: 5 additions & 1 deletion
@@ -301,7 +301,7 @@ Tensor& convolution_out(
 
   ET_KERNEL_CHECK(
       ctx,
-      output_size_is_valid({output_sizes, output_ndim}),
+      output_size_is_valid({output_sizes, output_ndim}, in.dim() - 2),
       InvalidArgument,
       out);
 

@@ -311,6 +311,10 @@ Tensor& convolution_out(
       InvalidArgument,
       out);
 
+  if (out.numel() == 0) {
+    return out;
+  }
+
   ScalarType in_type = in.scalar_type();
   ScalarType bias_type = in_type;
   if (bias.has_value()) {
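With zero-size dimensions now passing validation, the kernel returns before doing any arithmetic on an empty output. A standalone sketch of the same guard pattern (the template and names are illustrative, not part of the diff):

  // Skip the kernel body whenever the output is empty, e.g. a convolution
  // over input {0, 3, 8, 8} producing output {0, 16, 6, 6}.
  template <typename TensorT>
  TensorT& compute_or_skip(TensorT& out) {
    if (out.numel() == 0) {
      return out; // nothing to write; avoids looping over empty extents
    }
    // ... the actual kernel computation would run here ...
    return out;
  }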

kernels/portable/cpu/op_max_pool2d_with_indices.cpp

Lines changed: 1 addition & 1 deletion
@@ -53,7 +53,7 @@ std::tuple<Tensor&, Tensor&> max_pool2d_with_indices_out(
 
   ET_KERNEL_CHECK(
       ctx,
-      output_size_is_valid({output_sizes, output_ndim}),
+      output_size_is_valid({output_sizes, output_ndim}, 2),
       InvalidArgument,
       ret_val);
 

kernels/portable/cpu/util/kernel_ops_util.cpp

Lines changed: 123 additions & 78 deletions
@@ -15,6 +15,36 @@ namespace executor {
 
 using Tensor = exec_aten::Tensor;
 
+namespace {
+
+bool param_array_is_valid(
+    const char* name,
+    IntArrayRef array,
+    int64_t min_val,
+    size_t length,
+    bool allow_empty) {
+  auto size = array.size();
+  if (allow_empty) {
+    ET_LOG_MSG_AND_RETURN_IF_FALSE(
+        size == 0 || size == 1 || size == length,
+        "Expected %s to have size 0, 1 or %zu but got %zd",
+        name,
+        length,
+        size);
+  } else {
+    ET_LOG_MSG_AND_RETURN_IF_FALSE(
+        size == 1 || size == length,
+        "Expected %s to have size 1 or %zu but got %zd",
+        name,
+        length,
+        size);
+  }
+  ET_LOG_AND_RETURN_IF_FALSE(int_array_all_ge(array, min_val));
+  return true;
+}
+
+} // namespace
+
 int64_t val_at(IntArrayRef array, size_t i, int64_t default_val) {
   if (array.size() == 1) {
     return array[0];
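
This anonymous-namespace helper folds together the size and minimum-value checks that the per-parameter validators below used to duplicate. A hedged sketch of what each mode accepts (arrays and values are illustrative, and the brace-initialized ArrayRef syntax is assumed):

  int64_t one[1] = {2};    // broadcast to every kernel dimension
  int64_t two[2] = {2, 3}; // one value per kernel dimension
  // For a 2-D kernel (length == 2):
  param_array_is_valid("stride", {one, 1}, /*min_val=*/1, 2, /*allow_empty=*/false); // true
  param_array_is_valid("stride", {two, 2}, /*min_val=*/1, 2, /*allow_empty=*/false); // true
  param_array_is_valid("stride", {}, /*min_val=*/1, 2, /*allow_empty=*/true);  // true: empty means "use defaults"
  param_array_is_valid("stride", {}, /*min_val=*/1, 2, /*allow_empty=*/false); // false: size 0 rejected
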
@@ -41,38 +71,29 @@ bool int_array_all_ge(IntArrayRef array, int64_t val) {
 }
 
 bool kernel_size_is_valid(IntArrayRef kernel_size, size_t kernel_ndim) {
-  ET_LOG_MSG_AND_RETURN_IF_FALSE(
-      kernel_size.size() == kernel_ndim,
-      "Expected kernel_size to have size %zu but got %zd",
+  return param_array_is_valid(
+      "kernel_size",
+      kernel_size,
+      /*min_val=*/1,
       kernel_ndim,
-      kernel_size.size());
-  ET_LOG_AND_RETURN_IF_FALSE(int_array_all_ge(kernel_size, 1));
-  return true;
+      /*allow_empty=*/false);
 }
 
-bool stride_is_valid(IntArrayRef stride, size_t kernel_ndim) {
-  ET_LOG_MSG_AND_RETURN_IF_FALSE(
-      stride.size() > 0 && stride.size() <= kernel_ndim,
-      "Expected stride to have size between 1 and %zu inclusive "
-      "but got %zd",
-      kernel_ndim,
-      stride.size());
-  ET_LOG_AND_RETURN_IF_FALSE(int_array_all_ge(stride, 1));
-  return true;
+bool stride_is_valid(IntArrayRef stride, size_t kernel_ndim, bool allow_empty) {
+  return param_array_is_valid(
+      "stride", stride, /*min_val=*/1, kernel_ndim, allow_empty);
 }
 
 bool padding_is_valid(
     IntArrayRef padding,
     IntArrayRef kernel_size,
     size_t kernel_ndim,
     bool enforce_half_kernel) {
-  ET_LOG_MSG_AND_RETURN_IF_FALSE(
-      padding.size() > 0 && padding.size() <= kernel_ndim,
-      "Expected padding to have size between 1 and %zu inclusive "
-      "but got %zd",
-      kernel_ndim,
-      padding.size());
-  ET_LOG_AND_RETURN_IF_FALSE(int_array_all_ge(padding, 0));
+  bool valid = param_array_is_valid(
+      "padding", padding, /*min_val=*/0, kernel_ndim, /*allow_empty=*/false);
+  if (!valid) {
+    return false;
+  }
 
   if (enforce_half_kernel) {
     // Padding must be at most half of kernel size.
@@ -94,20 +115,21 @@ bool padding_is_valid(
 }
 
 bool dilation_is_valid(IntArrayRef dilation, size_t kernel_ndim) {
-  ET_LOG_MSG_AND_RETURN_IF_FALSE(
-      dilation.size() > 0 && dilation.size() <= kernel_ndim,
-      "Expected dilation to have size between 1 and %zu inclusive "
-      "but got %zd",
-      kernel_ndim,
-      dilation.size());
-  ET_LOG_AND_RETURN_IF_FALSE(int_array_all_ge(dilation, 1));
-  return true;
+  return param_array_is_valid(
+      "dilation", dilation, /*min_val=*/1, kernel_ndim, /*allow_empty=*/false);
 }
 
 bool output_size_is_valid(
-    exec_aten::ArrayRef<exec_aten::SizesType> output_size) {
+    exec_aten::ArrayRef<exec_aten::SizesType> output_size,
+    size_t kernel_ndim) {
   bool valid = true;
-  for (size_t i = 0; i < output_size.size(); i++) {
+  size_t out_dim = output_size.size();
+  for (size_t i = 0; i < out_dim - kernel_ndim; i++) {
+    if (output_size[i] < 0) {
+      valid = false;
+    }
+  }
+  for (size_t i = out_dim - kernel_ndim; i < out_dim; i++) {
     if (output_size[i] <= 0) {
       valid = false;
     }
@@ -158,37 +180,44 @@ void get_unsqueezed_dim_order(
   return;
 }
 
+int64_t _kernel_output_size_helper(
+    size_t inputSize,
+    int64_t kernelSize,
+    int64_t pad,
+    int64_t stride,
+    int64_t dilation,
+    bool ceil_mode) {
+  int64_t numerator = inputSize + 2 * pad - dilation * (kernelSize - 1) - 1 +
+      (ceil_mode ? stride - 1 : 0);
+  int64_t outputSize = numerator / stride + 1;
+  if (ceil_mode) {
+    // ensure that the last pooling starts inside the image
+    // needed to avoid problems in ceil mode
+    if ((outputSize - 1) * stride >= inputSize + pad) {
+      --outputSize;
+    }
+  }
+  return outputSize;
+}
+
 void calculate_kernel_output_sizes(
     const Tensor& in,
+    size_t kernel_ndim,
     IntArrayRef kernel_size,
     IntArrayRef stride,
     IntArrayRef padding,
     IntArrayRef dilation,
     exec_aten::SizesType* out_sizes,
     bool ceil_mode) {
-  size_t dim_offset = in.dim() - kernel_size.size();
-  for (size_t d = 0; d < kernel_size.size(); ++d) {
-    int64_t dilation_val = 1;
-    if (dilation.size() > 1) {
-      dilation_val = val_at(dilation, d);
-    }
-    int64_t padding_val = val_at(padding, d, /*default=*/0);
-    int64_t stride_val = val_at(stride, d);
-
-    int64_t kernel_len = dilation_val * (val_at(kernel_size, d) - 1) + 1;
-    if (ceil_mode) {
-      out_sizes[d + dim_offset] =
-          std::ceil(
-              static_cast<float>(
-                  in.size(d + dim_offset) + (2 * padding_val) - kernel_len) /
-              static_cast<float>(stride_val)) +
-          1;
-    } else {
-      out_sizes[d + dim_offset] =
-          (in.size(d + dim_offset) + (2 * padding_val) - kernel_len) /
-          stride_val +
-          1;
-    }
+  for (size_t i = 0; i < kernel_ndim; ++i) {
+    auto dim = in.dim() - (kernel_ndim - i);
+    int64_t k = val_at(kernel_size, i);
+    int64_t s = val_at(stride, i, /*default_value=*/k);
+    int64_t d = val_at(dilation, i, /*default_value=*/1);
+    int64_t p = val_at(padding, i, /*default_value=*/0);
+
+    out_sizes[dim] =
+        _kernel_output_size_helper(in.size(dim), k, p, s, d, ceil_mode);
   }
 }
 
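The helper collapses the old ceil/floor branches into one integer computation. For input size n, kernel k, padding p, stride s, and dilation d, it evaluates (a worked reading of the code above, not new behavior):

  out = floor((n + 2p - d*(k - 1) - 1 + c) / s) + 1,   where c = s - 1 in ceil mode and 0 otherwise

with the ceil-mode correction out -> out - 1 whenever (out - 1) * s >= n + p, so the last window still starts inside the padded input. For instance, n = 7, k = 3, p = 1, s = 2, d = 1 gives floor((7 + 2 - 2 - 1) / 2) + 1 = 4.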

@@ -206,16 +235,22 @@ bool check_avg_pool2d_args(
   ET_LOG_AND_RETURN_IF_FALSE(tensor_is_default_or_channels_last_dim_order(in));
   ET_LOG_AND_RETURN_IF_FALSE(tensor_is_default_or_channels_last_dim_order(out));
 
-  ET_LOG_AND_RETURN_IF_FALSE(kernel_size_is_valid(kernel_size, 2));
-  if (stride.size() > 0) {
-    ET_LOG_AND_RETURN_IF_FALSE(stride_is_valid(kernel_size, 2));
-  }
-  ET_LOG_AND_RETURN_IF_FALSE(padding_is_valid(padding, kernel_size, 2, true));
+  ET_LOG_MSG_AND_RETURN_IF_FALSE(
+      (in.dim() == 3 && in.size(0) > 0 && in.size(1) > 0 && in.size(2) > 0) ||
+          (in.dim() == 4 && in.size(1) > 0 && in.size(2) > 0 && in.size(3) > 0),
+      "Expected 3D or 4D (batch mode) tensor with optional 0 dim batch size for input");
+
+  ET_LOG_AND_RETURN_IF_FALSE(
+      kernel_size_is_valid(kernel_size, /*kernel_ndim=*/2));
+  ET_LOG_AND_RETURN_IF_FALSE(
+      stride_is_valid(kernel_size, /*kernel_ndim=*/2, /*allow_empty=*/true));
+  ET_LOG_AND_RETURN_IF_FALSE(padding_is_valid(
+      padding, kernel_size, /*kernel_ndim=*/2, /*enforce_half_kernel=*/true));
 
   if (divisor_override.has_value()) {
     ET_LOG_MSG_AND_RETURN_IF_FALSE(
-        divisor_override.value() > 0,
-        "divisor_override must be > 0, but found %" PRId64,
+        divisor_override.value() != 0,
+        "divisor_override must be non-zero, but found %" PRId64,
         divisor_override.value());
   }
 
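Note the relaxed divisor_override check: only zero is rejected now, which, as far as I can tell, matches ATen's behavior of permitting negative divisors. Worked instance: a 2x2 window summing to 10 with divisor_override = -4 yields 10 / -4 = -2.5 instead of an InvalidArgument error.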

@@ -241,7 +276,7 @@ void get_avg_pool2d_out_target_size(
   }
 
   calculate_kernel_output_sizes(
-      in, kernel_size, stride, padding, {}, out_sizes, ceil_mode);
+      in, 2, kernel_size, stride, padding, {}, out_sizes, ceil_mode);
 }
 
 bool check_convolution_args(
@@ -284,12 +319,11 @@ bool check_convolution_args(
     kernel_size[0] = weight.size(2);
     kernel_size[1] = weight.size(3);
   }
-  ET_LOG_AND_RETURN_IF_FALSE(stride_is_valid(stride, kernel_ndim));
+  ET_LOG_AND_RETURN_IF_FALSE(
+      stride_is_valid(stride, kernel_ndim, /*allow_empty=*/false));
   ET_LOG_AND_RETURN_IF_FALSE(
       padding_is_valid(padding, {kernel_size, kernel_ndim}, kernel_ndim));
-  if (dilation.size() > 0) {
-    ET_LOG_AND_RETURN_IF_FALSE(dilation_is_valid(dilation, kernel_ndim));
-  }
+  ET_LOG_AND_RETURN_IF_FALSE(dilation_is_valid(dilation, kernel_ndim));
 
   ET_LOG_MSG_AND_RETURN_IF_FALSE(
       in.size(1) % groups == 0,
@@ -314,7 +348,7 @@ void get_convolution_out_target_size(
   *out_ndim = in.dim();
 
   out_sizes[0] = in.size(0);
-  out_sizes[1] = weight.size(0);
+  out_sizes[1] = in.size(1) == 0 ? 0 : weight.size(0);
 
   int64_t kernel_size[2];
   size_t kernel_ndim = 2;
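
The ternary keeps the output channel count consistent with an empty input. With illustrative shapes, input {1, 0, 8, 8} and weight {16, 0, 3, 3} now produce out_sizes {1, 0, 6, 6}; numel() is then zero and the early return added to op_convolution.cpp above fires before any data is touched.
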
@@ -326,7 +360,14 @@ void get_convolution_out_target_size(
     kernel_size[1] = weight.size(3);
   }
   calculate_kernel_output_sizes(
-      in, {kernel_size, kernel_ndim}, stride, padding, dilation, out_sizes);
+      in,
+      kernel_ndim,
+      {kernel_size, kernel_ndim},
+      stride,
+      padding,
+      dilation,
+      out_sizes,
+      false);
 }
 
 bool check_max_pool2d_with_indices_args(
@@ -347,14 +388,18 @@ bool check_max_pool2d_with_indices_args(
   ET_LOG_AND_RETURN_IF_FALSE(tensor_is_default_or_channels_last_dim_order(in));
   ET_LOG_AND_RETURN_IF_FALSE(tensor_is_default_or_channels_last_dim_order(out));
 
-  ET_LOG_AND_RETURN_IF_FALSE(kernel_size_is_valid(kernel_size, 2));
-  if (stride.size() > 0) {
-    ET_LOG_AND_RETURN_IF_FALSE(stride_is_valid(kernel_size, 2));
-  }
-  ET_LOG_AND_RETURN_IF_FALSE(padding_is_valid(padding, kernel_size, 2, true));
-  if (dilation.size() > 0) {
-    ET_LOG_AND_RETURN_IF_FALSE(dilation_is_valid(dilation, 2));
-  }
+  ET_LOG_MSG_AND_RETURN_IF_FALSE(
+      (in.dim() == 3 && in.size(0) > 0 && in.size(1) > 0 && in.size(2) > 0) ||
+          (in.dim() == 4 && in.size(1) > 0 && in.size(2) > 0 && in.size(3) > 0),
+      "Expected 3D or 4D (batch mode) tensor with optional 0 dim batch size for input");
+
+  ET_LOG_AND_RETURN_IF_FALSE(
+      kernel_size_is_valid(kernel_size, /*kernel_ndim=*/2));
+  ET_LOG_AND_RETURN_IF_FALSE(
+      stride_is_valid(kernel_size, /*kernel_ndim=*/2, /*allow_empty=*/true));
+  ET_LOG_AND_RETURN_IF_FALSE(padding_is_valid(
+      padding, kernel_size, /*kernel_ndim=*/2, /*enforce_half_kernel=*/true));
+  ET_LOG_AND_RETURN_IF_FALSE(dilation_is_valid(kernel_size, /*kernel_ndim=*/2));
 
   return true;
 }
@@ -379,7 +424,7 @@ void get_max_pool2d_with_indices_out_target_size(
   }
 
   calculate_kernel_output_sizes(
-      in, kernel_size, stride, padding, dilation, out_sizes, ceil_mode);
+      in, 2, kernel_size, stride, padding, dilation, out_sizes, ceil_mode);
 }
 
 } // namespace executor
