@@ -15,6 +15,36 @@ namespace executor {
 
 using Tensor = exec_aten::Tensor;
 
+namespace {
+
+bool param_array_is_valid(
+    const char* name,
+    IntArrayRef array,
+    int64_t min_val,
+    size_t length,
+    bool allow_empty) {
+  auto size = array.size();
+  if (allow_empty) {
+    ET_LOG_MSG_AND_RETURN_IF_FALSE(
+        size == 0 || size == 1 || size == length,
+        "Expected %s to have size 0, 1 or %zu but got %zd",
+        name,
+        length,
+        size);
+  } else {
+    ET_LOG_MSG_AND_RETURN_IF_FALSE(
+        size == 1 || size == length,
+        "Expected %s to have size 1 or %zu but got %zd",
+        name,
+        length,
+        size);
+  }
+  ET_LOG_AND_RETURN_IF_FALSE(int_array_all_ge(array, min_val));
+  return true;
+}
+
+} // namespace
+
 int64_t val_at(IntArrayRef array, size_t i, int64_t default_val) {
   if (array.size() == 1) {
     return array[0];
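For reference, the rule the new param_array_is_valid helper encodes can be sketched standalone as below; the accepts() function, main(), and the literal sizes are illustrative assumptions, not code from this patch. An array passes when its size is 1 or length (or 0 when allow_empty is true) and every element is at least min_val.

// Illustrative sketch only -- mirrors the acceptance rule of
// param_array_is_valid without the ExecuTorch logging macros.
#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <vector>

// Hypothetical stand-in for the real helper: true when the size is
// 0 (if allowed), 1, or `length`, and all elements are >= min_val.
static bool accepts(
    const std::vector<int64_t>& array,
    int64_t min_val,
    size_t length,
    bool allow_empty) {
  size_t size = array.size();
  bool size_ok = (allow_empty && size == 0) || size == 1 || size == length;
  if (!size_ok) {
    return false;
  }
  for (int64_t v : array) {
    if (v < min_val) {
      return false;
    }
  }
  return true;
}

int main() {
  // stride for a 2-D kernel: an empty array is allowed
  printf("%d\n", accepts({}, 1, 2, /*allow_empty=*/true)); // 1
  // kernel_size for a 2-D kernel: an empty array is rejected
  printf("%d\n", accepts({}, 1, 2, /*allow_empty=*/false)); // 0
  // padding may contain zeros (min_val = 0) but not negatives
  printf("%d\n", accepts({0, 1}, 0, 2, /*allow_empty=*/false)); // 1
  printf("%d\n", accepts({-1, 1}, 0, 2, /*allow_empty=*/false)); // 0
}

This single contract is what lets kernel_size_is_valid, stride_is_valid, padding_is_valid, and dilation_is_valid in the hunks below collapse into one-line calls that differ only in min_val and allow_empty.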
@@ -41,38 +71,29 @@ bool int_array_all_ge(IntArrayRef array, int64_t val) {
 }
 
 bool kernel_size_is_valid(IntArrayRef kernel_size, size_t kernel_ndim) {
-  ET_LOG_MSG_AND_RETURN_IF_FALSE(
-      kernel_size.size() == kernel_ndim,
-      "Expected kernel_size to have size %zu but got %zd",
+  return param_array_is_valid(
+      "kernel_size",
+      kernel_size,
+      /*min_val=*/1,
       kernel_ndim,
-      kernel_size.size());
-  ET_LOG_AND_RETURN_IF_FALSE(int_array_all_ge(kernel_size, 1));
-  return true;
+      /*allow_empty=*/false);
 }
 
-bool stride_is_valid(IntArrayRef stride, size_t kernel_ndim) {
-  ET_LOG_MSG_AND_RETURN_IF_FALSE(
-      stride.size() > 0 && stride.size() <= kernel_ndim,
-      "Expected stride to have size between 1 and %zu inclusive "
-      "but got %zd",
-      kernel_ndim,
-      stride.size());
-  ET_LOG_AND_RETURN_IF_FALSE(int_array_all_ge(stride, 1));
-  return true;
+bool stride_is_valid(IntArrayRef stride, size_t kernel_ndim, bool allow_empty) {
+  return param_array_is_valid(
+      "stride", stride, /*min_val=*/1, kernel_ndim, allow_empty);
 }
 
 bool padding_is_valid(
     IntArrayRef padding,
     IntArrayRef kernel_size,
     size_t kernel_ndim,
     bool enforce_half_kernel) {
-  ET_LOG_MSG_AND_RETURN_IF_FALSE(
-      padding.size() > 0 && padding.size() <= kernel_ndim,
-      "Expected padding to have size between 1 and %zu inclusive "
-      "but got %zd",
-      kernel_ndim,
-      padding.size());
-  ET_LOG_AND_RETURN_IF_FALSE(int_array_all_ge(padding, 0));
+  bool valid = param_array_is_valid(
+      "padding", padding, /*min_val=*/0, kernel_ndim, /*allow_empty=*/false);
+  if (!valid) {
+    return false;
+  }
 
   if (enforce_half_kernel) {
     // Padding must be at most half of kernel size.
@@ -94,20 +115,21 @@ bool padding_is_valid(
 }
 
 bool dilation_is_valid(IntArrayRef dilation, size_t kernel_ndim) {
-  ET_LOG_MSG_AND_RETURN_IF_FALSE(
-      dilation.size() > 0 && dilation.size() <= kernel_ndim,
-      "Expected dilation to have size between 1 and %zu inclusive "
-      "but got %zd",
-      kernel_ndim,
-      dilation.size());
-  ET_LOG_AND_RETURN_IF_FALSE(int_array_all_ge(dilation, 1));
-  return true;
+  return param_array_is_valid(
+      "dilation", dilation, /*min_val=*/1, kernel_ndim, /*allow_empty=*/false);
 }
 
 bool output_size_is_valid(
-    exec_aten::ArrayRef<exec_aten::SizesType> output_size) {
+    exec_aten::ArrayRef<exec_aten::SizesType> output_size,
+    size_t kernel_ndim) {
   bool valid = true;
-  for (size_t i = 0; i < output_size.size(); i++) {
+  size_t out_dim = output_size.size();
+  for (size_t i = 0; i < out_dim - kernel_ndim; i++) {
+    if (output_size[i] < 0) {
+      valid = false;
+    }
+  }
+  for (size_t i = out_dim - kernel_ndim; i < out_dim; i++) {
     if (output_size[i] <= 0) {
       valid = false;
     }
@@ -158,37 +180,44 @@ void get_unsqueezed_dim_order(
   return;
 }
 
+int64_t _kernel_output_size_helper(
+    size_t inputSize,
+    int64_t kernelSize,
+    int64_t pad,
+    int64_t stride,
+    int64_t dilation,
+    bool ceil_mode) {
+  int64_t numerator = inputSize + 2 * pad - dilation * (kernelSize - 1) - 1 +
+      (ceil_mode ? stride - 1 : 0);
+  int64_t outputSize = numerator / stride + 1;
+  if (ceil_mode) {
+    // ensure that the last pooling starts inside the image
+    // needed to avoid problems in ceil mode
+    if ((outputSize - 1) * stride >= inputSize + pad) {
+      --outputSize;
+    }
+  }
+  return outputSize;
+}
+
 void calculate_kernel_output_sizes(
     const Tensor& in,
+    size_t kernel_ndim,
     IntArrayRef kernel_size,
     IntArrayRef stride,
     IntArrayRef padding,
     IntArrayRef dilation,
     exec_aten::SizesType* out_sizes,
     bool ceil_mode) {
-  size_t dim_offset = in.dim() - kernel_size.size();
-  for (size_t d = 0; d < kernel_size.size(); ++d) {
-    int64_t dilation_val = 1;
-    if (dilation.size() > 1) {
-      dilation_val = val_at(dilation, d);
-    }
-    int64_t padding_val = val_at(padding, d, /*default=*/0);
-    int64_t stride_val = val_at(stride, d);
-
-    int64_t kernel_len = dilation_val * (val_at(kernel_size, d) - 1) + 1;
-    if (ceil_mode) {
-      out_sizes[d + dim_offset] =
-          std::ceil(
-              static_cast<float>(
-                  in.size(d + dim_offset) + (2 * padding_val) - kernel_len) /
-              static_cast<float>(stride_val)) +
-          1;
-    } else {
-      out_sizes[d + dim_offset] =
-          (in.size(d + dim_offset) + (2 * padding_val) - kernel_len) /
-              stride_val +
-          1;
-    }
+  for (size_t i = 0; i < kernel_ndim; ++i) {
+    auto dim = in.dim() - (kernel_ndim - i);
+    int64_t k = val_at(kernel_size, i);
+    int64_t s = val_at(stride, i, /*default_value=*/k);
+    int64_t d = val_at(dilation, i, /*default_value=*/1);
+    int64_t p = val_at(padding, i, /*default_value=*/0);
+
+    out_sizes[dim] =
+        _kernel_output_size_helper(in.size(dim), k, p, s, d, ceil_mode);
   }
 }
 
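As a rough cross-check of the shared arithmetic, the sketch below re-evaluates the same floor/ceil formula that _kernel_output_size_helper applies per spatial dimension; the expected_output_size name and the sample shapes are illustrative assumptions, not part of the patch.

// Illustrative sketch only -- same arithmetic as the helper above.
// Effective kernel extent is dilation * (kernelSize - 1) + 1; ceil_mode
// biases the division upward and then clamps so the last window still
// starts inside the (left-padded) input.
#include <cstdint>
#include <cstdio>

static int64_t expected_output_size(
    int64_t inputSize,
    int64_t kernelSize,
    int64_t pad,
    int64_t stride,
    int64_t dilation,
    bool ceil_mode) {
  int64_t numerator = inputSize + 2 * pad - dilation * (kernelSize - 1) - 1 +
      (ceil_mode ? stride - 1 : 0);
  int64_t outputSize = numerator / stride + 1;
  if (ceil_mode && (outputSize - 1) * stride >= inputSize + pad) {
    --outputSize;
  }
  return outputSize;
}

int main() {
  // 7-wide input, kernel 3, stride 2, padding 1, dilation 1, floor mode:
  // (7 + 2*1 - 1*(3-1) - 1) / 2 + 1 = 6 / 2 + 1 = 4
  printf("%lld\n", (long long)expected_output_size(7, 3, 1, 2, 1, false)); // 4
  // 5-wide input, kernel 2, stride 2, no padding: floor mode gives 2,
  // ceil mode adds stride-1 to the numerator and keeps the partial window.
  printf("%lld\n", (long long)expected_output_size(5, 2, 0, 2, 1, false)); // 2
  printf("%lld\n", (long long)expected_output_size(5, 2, 0, 2, 1, true));  // 3
}

Note that in calculate_kernel_output_sizes an empty stride array falls back to the kernel size through val_at's default value, which is why the pooling argument checks in the hunks below can pass allow_empty=true for stride.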
@@ -206,16 +235,22 @@ bool check_avg_pool2d_args(
   ET_LOG_AND_RETURN_IF_FALSE(tensor_is_default_or_channels_last_dim_order(in));
   ET_LOG_AND_RETURN_IF_FALSE(tensor_is_default_or_channels_last_dim_order(out));
 
-  ET_LOG_AND_RETURN_IF_FALSE(kernel_size_is_valid(kernel_size, 2));
-  if (stride.size() > 0) {
-    ET_LOG_AND_RETURN_IF_FALSE(stride_is_valid(kernel_size, 2));
-  }
-  ET_LOG_AND_RETURN_IF_FALSE(padding_is_valid(padding, kernel_size, 2, true));
+  ET_LOG_MSG_AND_RETURN_IF_FALSE(
+      (in.dim() == 3 && in.size(0) > 0 && in.size(1) > 0 && in.size(2) > 0) ||
+          (in.dim() == 4 && in.size(1) > 0 && in.size(2) > 0 && in.size(3) > 0),
+      "Expected 3D or 4D (batch mode) tensor with optional 0 dim batch size for input");
+
+  ET_LOG_AND_RETURN_IF_FALSE(
+      kernel_size_is_valid(kernel_size, /*kernel_ndim=*/2));
+  ET_LOG_AND_RETURN_IF_FALSE(
+      stride_is_valid(kernel_size, /*kernel_ndim=*/2, /*allow_empty=*/true));
+  ET_LOG_AND_RETURN_IF_FALSE(padding_is_valid(
+      padding, kernel_size, /*kernel_ndim=*/2, /*enforce_half_kernel=*/true));
 
   if (divisor_override.has_value()) {
     ET_LOG_MSG_AND_RETURN_IF_FALSE(
-        divisor_override.value() > 0,
-        "divisor_override must be > 0, but found %" PRId64,
+        divisor_override.value() != 0,
+        "divisor_override must be non-zero, but found %" PRId64,
         divisor_override.value());
   }
 
@@ -241,7 +276,7 @@ void get_avg_pool2d_out_target_size(
   }
 
   calculate_kernel_output_sizes(
-      in, kernel_size, stride, padding, {}, out_sizes, ceil_mode);
+      in, 2, kernel_size, stride, padding, {}, out_sizes, ceil_mode);
 }
 
 bool check_convolution_args(
@@ -284,12 +319,11 @@ bool check_convolution_args(
     kernel_size[0] = weight.size(2);
     kernel_size[1] = weight.size(3);
   }
-  ET_LOG_AND_RETURN_IF_FALSE(stride_is_valid(stride, kernel_ndim));
+  ET_LOG_AND_RETURN_IF_FALSE(
+      stride_is_valid(stride, kernel_ndim, /*allow_empty=*/false));
   ET_LOG_AND_RETURN_IF_FALSE(
       padding_is_valid(padding, {kernel_size, kernel_ndim}, kernel_ndim));
-  if (dilation.size() > 0) {
-    ET_LOG_AND_RETURN_IF_FALSE(dilation_is_valid(dilation, kernel_ndim));
-  }
+  ET_LOG_AND_RETURN_IF_FALSE(dilation_is_valid(dilation, kernel_ndim));
 
   ET_LOG_MSG_AND_RETURN_IF_FALSE(
       in.size(1) % groups == 0,
@@ -314,7 +348,7 @@ void get_convolution_out_target_size(
   *out_ndim = in.dim();
 
   out_sizes[0] = in.size(0);
-  out_sizes[1] = weight.size(0);
+  out_sizes[1] = in.size(1) == 0 ? 0 : weight.size(0);
 
   int64_t kernel_size[2];
   size_t kernel_ndim = 2;
@@ -326,7 +360,14 @@ void get_convolution_out_target_size(
     kernel_size[1] = weight.size(3);
   }
   calculate_kernel_output_sizes(
-      in, {kernel_size, kernel_ndim}, stride, padding, dilation, out_sizes);
+      in,
+      kernel_ndim,
+      {kernel_size, kernel_ndim},
+      stride,
+      padding,
+      dilation,
+      out_sizes,
+      false);
 }
 
 bool check_max_pool2d_with_indices_args(
@@ -347,14 +388,18 @@ bool check_max_pool2d_with_indices_args(
   ET_LOG_AND_RETURN_IF_FALSE(tensor_is_default_or_channels_last_dim_order(in));
   ET_LOG_AND_RETURN_IF_FALSE(tensor_is_default_or_channels_last_dim_order(out));
 
-  ET_LOG_AND_RETURN_IF_FALSE(kernel_size_is_valid(kernel_size, 2));
-  if (stride.size() > 0) {
-    ET_LOG_AND_RETURN_IF_FALSE(stride_is_valid(kernel_size, 2));
-  }
-  ET_LOG_AND_RETURN_IF_FALSE(padding_is_valid(padding, kernel_size, 2, true));
-  if (dilation.size() > 0) {
-    ET_LOG_AND_RETURN_IF_FALSE(dilation_is_valid(dilation, 2));
-  }
+  ET_LOG_MSG_AND_RETURN_IF_FALSE(
+      (in.dim() == 3 && in.size(0) > 0 && in.size(1) > 0 && in.size(2) > 0) ||
+          (in.dim() == 4 && in.size(1) > 0 && in.size(2) > 0 && in.size(3) > 0),
+      "Expected 3D or 4D (batch mode) tensor with optional 0 dim batch size for input");
+
+  ET_LOG_AND_RETURN_IF_FALSE(
+      kernel_size_is_valid(kernel_size, /*kernel_ndim=*/2));
+  ET_LOG_AND_RETURN_IF_FALSE(
+      stride_is_valid(kernel_size, /*kernel_ndim=*/2, /*allow_empty=*/true));
+  ET_LOG_AND_RETURN_IF_FALSE(padding_is_valid(
+      padding, kernel_size, /*kernel_ndim=*/2, /*enforce_half_kernel=*/true));
+  ET_LOG_AND_RETURN_IF_FALSE(dilation_is_valid(kernel_size, /*kernel_ndim=*/2));
 
   return true;
 }
@@ -379,7 +424,7 @@ void get_max_pool2d_with_indices_out_target_size(
   }
 
   calculate_kernel_output_sizes(
-      in, kernel_size, stride, padding, dilation, out_sizes, ceil_mode);
+      in, 2, kernel_size, stride, padding, dilation, out_sizes, ceil_mode);
 }
 
 } // namespace executor