File tree Expand file tree Collapse file tree 3 files changed +193
-151
lines changed Expand file tree Collapse file tree 3 files changed +193
-151
lines changed Original file line number Diff line number Diff line change @@ -47,7 +47,17 @@ void main() {
47
47
48
48
// Compute the start and end of the input indices to load. Padding is assumed
49
49
// to be constant 0 padding, so reads from the padding region are skipped.
50
- const ivec2 start = max (ivec2 (0 ), ipos);
50
+ ivec2 start = ipos;
51
+ if (start.x < 0 ) {
52
+ // number of "steps" to get to >= zero is div_up(-start, dilation)
53
+ int num_steps = ((- ipos.x) + dilation.x - 1 ) / dilation.x;
54
+ start.x = ipos.x + num_steps * dilation.x;
55
+ }
56
+ if (start.y < 0 ) {
57
+ // number of "steps" to get to >= zero is div_up(-start, dilation)
58
+ int num_steps = ((- ipos.y) + dilation.y - 1 ) / dilation.y;
59
+ start.y = ipos.y + num_steps * dilation.y;
60
+ }
51
61
const ivec2 end = min (ipos + overlay_region.xy, ivec2 (in_sizes.xy));
52
62
// Compute the start of the kernel based on how far we are skipping ahead when
53
63
// reading the input. Note that these are "canonical" indices.
Original file line number Diff line number Diff line change @@ -262,11 +262,6 @@ void check_conv2d_params(const Kernel2dParams& p, const bool transposed) {
262
262
" aten.convolution.default: transposed = true, dilation > 1 is not supported yet!" );
263
263
}
264
264
}
265
- if ((p.padding [0 ] > 0 && p.kernel_size [0 ] > 1 && p.dilation [0 ] > 1 ) ||
266
- (p.padding [1 ] > 0 && p.kernel_size [1 ] > 1 && p.dilation [1 ] > 1 )) {
267
- VK_THROW (
268
- " aten.convolution.default: padding > 0 while dilation, kernel_size > 1 is not supported yet!" );
269
- }
270
265
}
271
266
272
267
Conv2dMethod get_conv2d_method (
Original file line number Diff line number Diff line change @@ -226,153 +226,190 @@ def get_max_pool2d_inputs():
226
226
227
227
@register_test_suite ("aten.convolution.default" )
228
228
def get_conv_inputs ():
229
- test_suite = VkTestSuite (
229
+ Test = namedtuple (
230
+ "ConvTest" ,
230
231
[
231
- (
232
- (1 , 6 , 40 , 50 ),
233
- (8 , 6 , 3 , 3 ),
234
- (8 ,),
235
- [1 , 2 ],
236
- [2 , 3 ],
237
- [1 , 1 ],
238
- False ,
239
- [0 , 0 ],
240
- 1 ,
241
- ),
242
- (
243
- (1 , 6 , 40 , 50 ),
244
- (6 , 8 , 3 , 3 ),
245
- (8 ,),
246
- [1 , 2 ],
247
- [2 , 3 ],
248
- [1 , 1 ],
249
- True ,
250
- [0 , 1 ],
251
- 1 ,
252
- ),
253
- (
254
- (1 , 8 , 72 , 96 ),
255
- (8 , 1 , 3 , 3 ),
256
- (8 ,),
257
- [1 , 1 ],
258
- [1 , 1 ],
259
- [1 , 1 ],
260
- False ,
261
- [0 , 0 ],
262
- 8 ,
263
- ),
264
- (
265
- (1 , 8 , 72 , 96 ),
266
- (8 , 8 , 1 , 1 ),
267
- (8 ,),
268
- [1 , 1 ],
269
- [1 , 1 ],
270
- [1 , 1 ],
271
- False ,
272
- [0 , 0 ],
273
- 1 ,
274
- ),
275
- (
276
- (1 , 6 , 40 , 50 ),
277
- (8 , 6 , 3 , 3 ),
278
- None ,
279
- [1 , 2 ],
280
- [2 , 3 ],
281
- [1 , 1 ],
282
- False ,
283
- [0 , 0 ],
284
- 1 ,
285
- ),
286
- (
287
- (1 , 6 , 7 ),
288
- (6 , 1 , 3 ),
289
- (6 ,),
290
- [1 ],
291
- [0 ],
292
- [1 ],
293
- False ,
294
- [0 ],
295
- 6 ,
296
- ),
297
- (
298
- (2 , 20 , 30 ),
299
- (10 , 4 , 6 ),
300
- (10 ,),
301
- [5 ],
302
- [5 ],
303
- [3 ],
304
- False ,
305
- [0 ],
306
- 5 ,
307
- ),
308
- (
309
- (1 , 9 , 11 ),
310
- (9 , 1 , 3 ),
311
- None ,
312
- [1 ],
313
- [0 ],
314
- [1 ],
315
- False ,
316
- [0 ],
317
- 9 ,
318
- ),
319
- (
320
- (5 , 15 , 30 ),
321
- (20 , 3 , 3 ),
322
- None ,
323
- [3 ],
324
- [5 ],
325
- [7 ],
326
- False ,
327
- [0 ],
328
- 5 ,
329
- ),
330
- (
331
- (1 , 16 , 672 , 512 ),
332
- (64 , 16 , 1 , 1 ),
333
- (64 ,),
334
- [1 , 1 ],
335
- [0 , 0 ],
336
- [1 , 1 ],
337
- False ,
338
- [0 , 0 ],
339
- 1 ,
340
- ),
341
- (
342
- (1 , 4 , 234 , 234 ),
343
- (4 , 1 , 3 , 3 ),
344
- (4 ,),
345
- [2 , 1 ],
346
- [1 , 1 ],
347
- [1 , 1 ],
348
- False ,
349
- [0 , 0 ],
350
- 4 ,
351
- ),
352
- (
353
- (1 , 4 , 234 , 234 ),
354
- (4 , 1 , 3 , 3 ),
355
- (4 ,),
356
- [1 , 2 ],
357
- [1 , 1 ],
358
- [1 , 1 ],
359
- False ,
360
- [0 , 0 ],
361
- 4 ,
362
- ),
363
- (
364
- (1 , 4 , 234 , 234 ),
365
- (4 , 1 , 3 , 3 ),
366
- (4 ,),
367
- [2 , 2 ],
368
- [1 , 1 ],
369
- [1 , 1 ],
370
- False ,
371
- [0 , 0 ],
372
- 4 ,
373
- ),
374
- ]
232
+ "self" ,
233
+ "weight" ,
234
+ "bias" ,
235
+ "stride" ,
236
+ "padding" ,
237
+ "dilation" ,
238
+ "transposed" ,
239
+ "output_padding" ,
240
+ "groups" ,
241
+ ],
242
+ )
243
+ Test .__new__ .__defaults__ = (
244
+ None ,
245
+ None ,
246
+ None ,
247
+ [1 , 1 ],
248
+ [0 , 0 ],
249
+ [1 , 1 ],
250
+ False ,
251
+ [9 , 0 ],
252
+ 1 ,
375
253
)
254
+ test_cases = []
255
+ test_cases = [
256
+ Test (
257
+ self = (1 , 6 , 40 , 50 ),
258
+ weight = (8 , 6 , 3 , 3 ),
259
+ bias = (8 ,),
260
+ stride = [1 , 2 ],
261
+ padding = [2 , 3 ],
262
+ dilation = [1 , 1 ],
263
+ transposed = False ,
264
+ output_padding = [0 , 0 ],
265
+ groups = 1 ,
266
+ ),
267
+ Test (
268
+ self = (1 , 6 , 40 , 50 ),
269
+ weight = (6 , 8 , 3 , 3 ),
270
+ bias = (8 ,),
271
+ stride = [1 , 2 ],
272
+ padding = [2 , 3 ],
273
+ dilation = [1 , 1 ],
274
+ transposed = True ,
275
+ output_padding = [0 , 1 ],
276
+ groups = 1 ,
277
+ ),
278
+ Test (
279
+ self = (1 , 8 , 72 , 96 ),
280
+ weight = (8 , 1 , 3 , 3 ),
281
+ bias = (8 ,),
282
+ stride = [1 , 1 ],
283
+ padding = [1 , 1 ],
284
+ dilation = [1 , 1 ],
285
+ transposed = False ,
286
+ output_padding = [0 , 0 ],
287
+ groups = 8 ,
288
+ ),
289
+ Test (
290
+ self = (1 , 8 , 72 , 96 ),
291
+ weight = (8 , 8 , 1 , 1 ),
292
+ bias = (8 ,),
293
+ stride = [1 , 1 ],
294
+ padding = [1 , 1 ],
295
+ dilation = [1 , 1 ],
296
+ transposed = False ,
297
+ output_padding = [0 , 0 ],
298
+ groups = 1 ,
299
+ ),
300
+ Test (
301
+ self = (1 , 6 , 40 , 50 ),
302
+ weight = (8 , 6 , 3 , 3 ),
303
+ bias = None ,
304
+ stride = [1 , 2 ],
305
+ padding = [2 , 3 ],
306
+ dilation = [1 , 1 ],
307
+ transposed = False ,
308
+ output_padding = [0 , 0 ],
309
+ groups = 1 ,
310
+ ),
311
+ Test (
312
+ self = (1 , 6 , 7 ),
313
+ weight = (6 , 1 , 3 ),
314
+ bias = (6 ,),
315
+ stride = [1 ],
316
+ padding = [0 ],
317
+ dilation = [1 ],
318
+ transposed = False ,
319
+ output_padding = [0 ],
320
+ groups = 6 ,
321
+ ),
322
+ Test (
323
+ self = (2 , 20 , 30 ),
324
+ weight = (10 , 4 , 6 ),
325
+ bias = (10 ,),
326
+ stride = [5 ],
327
+ padding = [5 ],
328
+ dilation = [3 ],
329
+ transposed = False ,
330
+ output_padding = [0 ],
331
+ groups = 5 ,
332
+ ),
333
+ Test (
334
+ self = (1 , 9 , 11 ),
335
+ weight = (9 , 1 , 3 ),
336
+ bias = None ,
337
+ stride = [1 ],
338
+ padding = [0 ],
339
+ dilation = [1 ],
340
+ transposed = False ,
341
+ output_padding = [0 ],
342
+ groups = 9 ,
343
+ ),
344
+ Test (
345
+ self = (5 , 15 , 30 ),
346
+ weight = (20 , 3 , 3 ),
347
+ bias = None ,
348
+ stride = [3 ],
349
+ padding = [5 ],
350
+ dilation = [7 ],
351
+ transposed = False ,
352
+ output_padding = [0 ],
353
+ groups = 5 ,
354
+ ),
355
+ Test (
356
+ self = (1 , 16 , 672 , 512 ),
357
+ weight = (64 , 16 , 1 , 1 ),
358
+ bias = (64 ,),
359
+ stride = [1 , 1 ],
360
+ padding = [0 , 0 ],
361
+ dilation = [1 , 1 ],
362
+ transposed = False ,
363
+ output_padding = [0 , 0 ],
364
+ groups = 1 ,
365
+ ),
366
+ Test (
367
+ self = (1 , 4 , 234 , 234 ),
368
+ weight = (4 , 1 , 3 , 3 ),
369
+ bias = (4 ,),
370
+ stride = [2 , 1 ],
371
+ padding = [1 , 1 ],
372
+ dilation = [1 , 1 ],
373
+ transposed = False ,
374
+ output_padding = [0 , 0 ],
375
+ groups = 4 ,
376
+ ),
377
+ Test (
378
+ self = (1 , 4 , 234 , 234 ),
379
+ weight = (4 , 1 , 3 , 3 ),
380
+ bias = (4 ,),
381
+ stride = [1 , 2 ],
382
+ padding = [1 , 1 ],
383
+ dilation = [1 , 1 ],
384
+ transposed = False ,
385
+ output_padding = [0 , 0 ],
386
+ groups = 4 ,
387
+ ),
388
+ Test (
389
+ self = (1 , 4 , 234 , 234 ),
390
+ weight = (4 , 1 , 3 , 3 ),
391
+ bias = (4 ,),
392
+ stride = [2 , 2 ],
393
+ padding = [1 , 1 ],
394
+ dilation = [1 , 1 ],
395
+ transposed = False ,
396
+ output_padding = [0 , 0 ],
397
+ groups = 4 ,
398
+ ),
399
+ Test (
400
+ self = (1 , 8 , 90 , 77 ),
401
+ weight = (1 , 8 , 3 , 3 ),
402
+ bias = (1 ,),
403
+ stride = [1 , 1 ],
404
+ padding = [2 , 2 ],
405
+ dilation = [2 , 2 ],
406
+ transposed = False ,
407
+ output_padding = [0 , 0 ],
408
+ groups = 1 ,
409
+ ),
410
+ ]
411
+
412
+ test_suite = VkTestSuite (test_cases )
376
413
return test_suite
377
414
378
415
You can’t perform that action at this time.
0 commit comments