|
17 | 17 |
|
18 | 18 | #include <executorch/backends/vulkan/runtime/graph/ops/utils/ShaderNameUtils.h>
|
19 | 19 |
|
20 |
| -#include <iostream> |
21 |
| - |
22 | 20 | namespace vkcompute {
|
23 | 21 |
|
24 | 22 | void resize_conv2d_node(
|
@@ -56,6 +54,29 @@ void resize_conv2d_node(
|
56 | 54 | out->virtual_resize(new_out_sizes);
|
57 | 55 | }
|
58 | 56 |
|
| 57 | +void resize_conv1d_node( |
| 58 | + ComputeGraph* graph, |
| 59 | + const std::vector<ArgGroup>& args, |
| 60 | + const std::vector<ValueRef>& extra_args) { |
| 61 | + vTensorPtr out = graph->get_tensor(args[0].refs[0]); |
| 62 | + vTensorPtr self = graph->get_tensor(args[1].refs[0]); |
| 63 | + TensorRefPtr weight_ref = graph->get_tref(extra_args[0]); |
| 64 | + const std::vector<int64_t>& weight_sizes = weight_ref->sizes; |
| 65 | + |
| 66 | + const std::vector<int64_t>& in_sizes = self->sizes(); |
| 67 | + size_t ndim = in_sizes.size(); |
| 68 | + std::vector<int64_t> new_out_sizes(ndim); |
| 69 | + |
| 70 | + int64_t kernel_size = weight_sizes.at(2); |
| 71 | + int64_t in_length = in_sizes.at(2); |
| 72 | + |
| 73 | + new_out_sizes.at(0) = in_sizes.at(0); |
| 74 | + new_out_sizes.at(1) = in_sizes.at(1); |
| 75 | + new_out_sizes.at(2) = in_length - kernel_size + 1; |
| 76 | + |
| 77 | + out->virtual_resize(new_out_sizes); |
| 78 | +} |
| 79 | + |
59 | 80 | ValueRef prepack_biases(
|
60 | 81 | ComputeGraph& graph,
|
61 | 82 | const ValueRef vref,
|
@@ -219,7 +240,7 @@ ValueRef prepack_weights(
|
219 | 240 | return v;
|
220 | 241 | }
|
221 | 242 |
|
222 |
| -void check_conv2d_args(const vTensor& in, const vTensor& out) { |
| 243 | +void check_conv_args(const vTensor& in, const vTensor& out) { |
223 | 244 | if (in.sizes().at(0) > 1) {
|
224 | 245 | VK_THROW(
|
225 | 246 | "aten.convolution.default: input batch size > 1 is not supported yet!");
|
@@ -312,7 +333,7 @@ void add_conv2d_node(
|
312 | 333 |
|
313 | 334 | vTensorPtr t_in = graph.get_tensor(arg_in);
|
314 | 335 | vTensorPtr t_out = graph.get_tensor(out);
|
315 |
| - check_conv2d_args(*t_in, *t_out); |
| 336 | + check_conv_args(*t_in, *t_out); |
316 | 337 |
|
317 | 338 | api::utils::uvec3 global_size = t_out->extents();
|
318 | 339 | api::utils::uvec3 local_size = adaptive_work_group_size(global_size);
|
@@ -352,23 +373,121 @@ void add_conv2d_node(
|
352 | 373 | {weight, stride, padding, dilation, transposed, output_padding}));
|
353 | 374 | }
|
354 | 375 |
|
355 |
| -void conv2d(ComputeGraph& graph, const std::vector<ValueRef>& args) { |
356 |
| - return add_conv2d_node( |
| 376 | +void add_conv1d_node( |
| 377 | + ComputeGraph& graph, |
| 378 | + const ValueRef in, |
| 379 | + const ValueRef weight, |
| 380 | + const ValueRef bias, |
| 381 | + const ValueRef stride, |
| 382 | + const ValueRef padding, |
| 383 | + const ValueRef dilation, |
| 384 | + const ValueRef groups, |
| 385 | + const ValueRef out) { |
| 386 | + if (graph.val_is_none(bias)) { |
| 387 | + VK_THROW("conv1d: Null bias is not supported yet!"); |
| 388 | + } |
| 389 | + |
| 390 | + ValueRef arg_in = prepack_if_tensor_ref(graph, in); |
| 391 | + ValueRef arg_weight = |
| 392 | + prepack_if_tensor_ref(graph, weight, graph.memory_layout_of(arg_in)); |
| 393 | + ValueRef arg_bias = |
| 394 | + prepack_if_tensor_ref(graph, bias, graph.memory_layout_of(arg_in)); |
| 395 | + |
| 396 | + vTensorPtr t_in = graph.get_tensor(arg_in); |
| 397 | + vTensorPtr t_weight = graph.get_tensor(arg_weight); |
| 398 | + vTensorPtr t_bias = graph.get_tensor(arg_bias); |
| 399 | + vTensorPtr t_out = graph.get_tensor(out); |
| 400 | + const int64_t groups_val = graph.get_int(groups); |
| 401 | + |
| 402 | + std::vector<int64_t> in_sizes = t_in->sizes(); |
| 403 | + std::vector<int64_t> weight_sizes = t_weight->sizes(); |
| 404 | + std::vector<int64_t> out_sizes = t_out->sizes(); |
| 405 | + IntListPtr stride_sizes = graph.get_int_list(stride); |
| 406 | + IntListPtr padding_sizes = graph.get_int_list(padding); |
| 407 | + IntListPtr dilation_sizes = graph.get_int_list(dilation); |
| 408 | + int64_t weight_out_channels = weight_sizes.at(0); |
| 409 | + int64_t kernel_size = weight_sizes.at(2); |
| 410 | + int64_t in_length = in_sizes.at(2); |
| 411 | + |
| 412 | + VK_CHECK_COND(in_sizes.size() == 3, "input must be a 3-dim tensor"); |
| 413 | + VK_CHECK_COND(weight_sizes.size() == 3, "weight must be a 3-dim tensor"); |
| 414 | + VK_CHECK_COND( |
| 415 | + stride_sizes->size() == 1 && stride_sizes->at(0) == 1, |
| 416 | + "stride must be 1"); |
| 417 | + VK_CHECK_COND( |
| 418 | + padding_sizes->size() == 1 && padding_sizes->at(0) == 0, |
| 419 | + "padding must be 0"); |
| 420 | + VK_CHECK_COND( |
| 421 | + dilation_sizes->size() == 1 && dilation_sizes->at(0) == 1, |
| 422 | + "dilation must be 1"); |
| 423 | + VK_CHECK_COND( |
| 424 | + groups_val == in_sizes.at(1), "groups must be equal to in_channels"); |
| 425 | + VK_CHECK_COND( |
| 426 | + groups_val == weight_sizes.at(0), |
| 427 | + "groups must be equal to weight_sizes.at(0)"); |
| 428 | + VK_CHECK_COND(weight_sizes.at(1) == 1, "weight_sizes.at(1) must be 1"); |
| 429 | + |
| 430 | + check_conv_args(*t_in, *t_out); |
| 431 | + |
| 432 | + api::utils::uvec3 global_size = { |
| 433 | + 1, static_cast<uint32_t>(weight_out_channels), 1}; |
| 434 | + api::utils::uvec3 local_size = {1, 1, 1}; |
| 435 | + |
| 436 | + std::string kernel_name("conv1d"); |
| 437 | + kernel_name.reserve(kShaderNameReserve); |
| 438 | + |
| 439 | + add_dtype_suffix(kernel_name, *t_out); |
| 440 | + |
| 441 | + graph.execute_nodes().emplace_back(new ExecuteNode( |
357 | 442 | graph,
|
358 |
| - args[0], |
359 |
| - args[1], |
360 |
| - args[2], |
361 |
| - args[3], |
362 |
| - args[4], |
363 |
| - args[5], |
364 |
| - args[6], |
365 |
| - args[7], |
366 |
| - args[8], |
367 |
| - args[9]); |
| 443 | + VK_KERNEL_FROM_STR(kernel_name), |
| 444 | + global_size, |
| 445 | + local_size, |
| 446 | + // Inputs and Outputs |
| 447 | + {{out, api::MemoryAccessType::WRITE}, |
| 448 | + {{arg_in, arg_weight, arg_bias}, api::MemoryAccessType::READ}}, |
| 449 | + // Shader params buffers |
| 450 | + { |
| 451 | + graph.create_params_buffer(weight_out_channels), |
| 452 | + graph.create_params_buffer(in_length), |
| 453 | + graph.create_params_buffer(kernel_size), |
| 454 | + }, |
| 455 | + // Resizing |
| 456 | + resize_conv1d_node, |
| 457 | + {weight})); |
| 458 | +} |
| 459 | + |
| 460 | +void conv(ComputeGraph& graph, const std::vector<ValueRef>& args) { |
| 461 | + int64_t in_ndim = graph.get_tensor(args[0])->sizes().size(); |
| 462 | + if (in_ndim == 4) { |
| 463 | + return add_conv2d_node( |
| 464 | + graph, |
| 465 | + args[0], |
| 466 | + args[1], |
| 467 | + args[2], |
| 468 | + args[3], |
| 469 | + args[4], |
| 470 | + args[5], |
| 471 | + args[6], |
| 472 | + args[7], |
| 473 | + args[8], |
| 474 | + args[9]); |
| 475 | + } else { |
| 476 | + return add_conv1d_node( |
| 477 | + graph, |
| 478 | + args[0], |
| 479 | + args[1], |
| 480 | + args[2], |
| 481 | + args[3], |
| 482 | + args[4], |
| 483 | + args[5], |
| 484 | + args[8], |
| 485 | + args[9]); |
| 486 | + } |
368 | 487 | }
|
369 | 488 |
|
// Registers the unified dispatcher so aten.convolution.default covers both
// the conv1d and conv2d paths.
REGISTER_OPERATORS {
  VK_REGISTER_OP(aten.convolution.default, conv);
}
|
373 | 492 |
|
374 | 493 | } // namespace vkcompute
|
0 commit comments