@@ -61,6 +61,11 @@ void resize_conv1d_node(
61
61
vTensorPtr out = graph->get_tensor (args[0 ].refs [0 ]);
62
62
vTensorPtr self = graph->get_tensor (args[1 ].refs [0 ]);
63
63
TensorRefPtr weight_ref = graph->get_tref (extra_args[0 ]);
64
+
65
+ int64_t stride_size = graph->get_int_list (extra_args[1 ])->at (0 );
66
+ int64_t padding_size = graph->get_int_list (extra_args[2 ])->at (0 );
67
+ int64_t dilation_size = graph->get_int_list (extra_args[3 ])->at (0 );
68
+
64
69
const std::vector<int64_t >& weight_sizes = weight_ref->sizes ;
65
70
66
71
const std::vector<int64_t >& in_sizes = self->sizes ();
@@ -71,8 +76,9 @@ void resize_conv1d_node(
71
76
int64_t in_length = in_sizes.at (2 );
72
77
73
78
new_out_sizes.at (0 ) = in_sizes.at (0 );
74
- new_out_sizes.at (1 ) = in_sizes.at (1 );
75
- new_out_sizes.at (2 ) = in_length - kernel_size + 1 ;
79
+ new_out_sizes.at (1 ) = weight_sizes.at (0 );
80
+ new_out_sizes.at (2 ) = calc_out_size (
81
+ in_length, kernel_size, stride_size, padding_size, dilation_size, false );
76
82
77
83
out->virtual_resize (new_out_sizes);
78
84
}
@@ -244,10 +250,6 @@ ValueRef prepack_weights(
244
250
}
245
251
246
252
void check_conv_args (const vTensor& in, const vTensor& out) {
247
- if (in.sizes ().at (0 ) > 1 ) {
248
- VK_THROW (
249
- " aten.convolution.default: input batch size > 1 is not supported yet!" );
250
- }
251
253
VK_CHECK_COND (check_memory_layout_is (in, api::kChannelsPacked ));
252
254
VK_CHECK_COND (check_memory_layout_is (out, api::kChannelsPacked ));
253
255
}
@@ -260,7 +262,7 @@ struct Conv2dParams final {
260
262
Conv2dParams create_conv2d_params (
261
263
ComputeGraph& graph,
262
264
const ValueRef weight,
263
- const KernelParams & p,
265
+ const KernelParams2D & p,
264
266
const bool transposed) {
265
267
const auto & overlay_region = api::utils::make_ivec2 ({
266
268
p.kernel_size .data [0 ] +
@@ -275,7 +277,7 @@ Conv2dParams create_conv2d_params(
275
277
return {overlay_region, in_group_size};
276
278
}
277
279
278
- void check_conv2d_params (const KernelParams & p, const bool transposed) {
280
+ void check_conv2d_params (const KernelParams2D & p, const bool transposed) {
279
281
if (transposed) {
280
282
if (p.dilation .data [0 ] > 1 || p.dilation .data [1 ] > 1 ) {
281
283
VK_THROW (
@@ -342,12 +344,15 @@ void add_conv2d_node(
342
344
343
345
vTensorPtr t_in = graph.get_tensor (arg_in);
344
346
vTensorPtr t_out = graph.get_tensor (out);
347
+ if (t_in->sizes ().at (0 ) > 1 ) {
348
+ VK_THROW (" conv2d: input batch size > 1 is not supported yet!" );
349
+ }
345
350
check_conv_args (*t_in, *t_out);
346
351
347
352
api::utils::uvec3 global_size = t_out->extents ();
348
353
api::utils::uvec3 local_size = adaptive_work_group_size (global_size);
349
354
350
- KernelParams kernel_params = create_kernel_params (
355
+ KernelParams2D kernel_params = create_kernel_params (
351
356
graph,
352
357
weight,
353
358
/* kernel_size_only = */ false ,
@@ -395,8 +400,7 @@ void add_conv1d_node(
395
400
const ValueRef groups,
396
401
const ValueRef out) {
397
402
ValueRef arg_in = prepack_if_tensor_ref (graph, in);
398
- ValueRef arg_weight =
399
- prepack_if_tensor_ref (graph, weight, graph.memory_layout_of (arg_in));
403
+ ValueRef arg_weight = prepack_if_tensor_ref (graph, weight, api::kWidthPacked );
400
404
ValueRef arg_bias = prepack_biases (
401
405
graph,
402
406
bias,
@@ -414,37 +418,33 @@ void add_conv1d_node(
414
418
std::vector<int64_t > in_sizes = t_in->sizes ();
415
419
std::vector<int64_t > weight_sizes = t_weight->sizes ();
416
420
std::vector<int64_t > out_sizes = t_out->sizes ();
417
- IntListPtr stride_sizes = graph.get_int_list (stride);
418
- IntListPtr padding_sizes = graph.get_int_list (padding);
419
- IntListPtr dilation_sizes = graph.get_int_list (dilation);
420
- int64_t weight_out_channels = weight_sizes.at (0 );
421
- int64_t kernel_size = weight_sizes.at (2 );
422
- int64_t in_length = in_sizes.at (2 );
423
421
424
- VK_CHECK_COND (in_sizes.size () == 3 , " input must be a 3-dim tensor" );
425
- VK_CHECK_COND (weight_sizes.size () == 3 , " weight must be a 3-dim tensor" );
426
- VK_CHECK_COND (
427
- stride_sizes->size () == 1 && stride_sizes->at (0 ) == 1 ,
428
- " stride must be 1" );
429
- VK_CHECK_COND (
430
- padding_sizes->size () == 1 && padding_sizes->at (0 ) == 0 ,
431
- " padding must be 0" );
432
- VK_CHECK_COND (
433
- dilation_sizes->size () == 1 && dilation_sizes->at (0 ) == 1 ,
434
- " dilation must be 1" );
435
- VK_CHECK_COND (
436
- groups_val == in_sizes.at (1 ), " groups must be equal to in_channels" );
437
- VK_CHECK_COND (
438
- groups_val == weight_sizes.at (0 ),
439
- " groups must be equal to weight_sizes.at(0)" );
440
- VK_CHECK_COND (weight_sizes.at (1 ) == 1 , " weight_sizes.at(1) must be 1" );
422
+ int32_t in_channels = in_sizes.at (1 );
423
+ int32_t out_channels = weight_sizes.at (0 );
424
+ int32_t kernel_size = weight_sizes.at (2 );
425
+ int32_t in_length = in_sizes.at (2 );
426
+ int32_t stride_size = graph.get_int_list (stride)->at (0 );
427
+ int32_t padding_size = graph.get_int_list (padding)->at (0 );
428
+ int32_t dilation_size = graph.get_int_list (dilation)->at (0 );
429
+ int32_t in_group_size = static_cast <int64_t >(in_channels / groups_val);
430
+ int32_t out_group_size = static_cast <int64_t >(out_channels / groups_val);
431
+ int32_t batch_size = in_sizes.at (0 );
441
432
442
433
check_conv_args (*t_in, *t_out);
443
434
444
- api::utils::uvec3 global_size = {
445
- 1 , static_cast <uint32_t >(weight_out_channels), 1 };
435
+ api::utils::uvec3 global_size = {1 , static_cast <uint32_t >(out_channels), 1 };
446
436
api::utils::uvec3 local_size = {1 , 1 , 1 };
447
437
438
+ KernelParams1D kernel_params = {
439
+ in_length,
440
+ kernel_size,
441
+ stride_size,
442
+ padding_size,
443
+ dilation_size,
444
+ in_group_size,
445
+ out_group_size,
446
+ batch_size};
447
+
448
448
std::string kernel_name (" conv1d" );
449
449
kernel_name.reserve (kShaderNameReserve );
450
450
@@ -460,15 +460,13 @@ void add_conv1d_node(
460
460
{{arg_in, arg_weight, arg_bias}, api::MemoryAccessType::READ}},
461
461
// Shader params buffers
462
462
{
463
- graph.create_params_buffer (weight_out_channels),
464
- graph.create_params_buffer (in_length),
465
- graph.create_params_buffer (kernel_size),
463
+ graph.create_params_buffer (kernel_params),
466
464
},
467
465
// Specialization Constants
468
466
{},
469
467
// Resizing Logic
470
468
resize_conv1d_node,
471
- {weight}));
469
+ {weight, stride, padding, dilation }));
472
470
}
473
471
474
472
void conv (ComputeGraph& graph, const std::vector<ValueRef>& args) {
0 commit comments