Commit 1d467d0

copyrightly authored and facebook-github-bot committed
conv1d, special case
Summary: We follow D50914117 to implement a specific case of conv1d for our needs. Specifically, we require:

- the input tensor to have a single batch
- groups == in_channels == out_channels
- weight_sizes.at(1) == 1
- stride == 1
- padding == 0
- dilation == 1

We assume `bias == True`; the `bias == False` case is handled in the next diff. General cases and optimizations will be enabled later.

Reviewed By: jorgep31415

Differential Revision: D56220143

fbshipit-source-id: a18de3a463875b9617cb7930febf7622fe866536
1 parent 70baafe

File tree: 5 files changed, +288 −19 lines
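To make the configuration in the summary concrete: in eager mode, this special case corresponds to a depthwise `torch.nn.Conv1d`. A minimal sketch (sizes borrowed from the new test below; this snippet is illustrative, not part of the diff):

```python
import torch

# Depthwise conv1d: groups == in_channels == out_channels, single batch,
# stride 1, padding 0, dilation 1, bias enabled.
C, L, K = 6, 7, 3
conv = torch.nn.Conv1d(
    in_channels=C, out_channels=C, kernel_size=K, groups=C, bias=True
)
assert conv.weight.shape == (C, 1, K)  # weight_sizes.at(1) == 1

x = torch.randn(1, C, L)             # single batch
y = conv(x)
assert y.shape == (1, C, L - K + 1)  # "valid" convolution output length
```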
Lines changed: 98 additions & 0 deletions

```glsl
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#version 450 core

#define PRECISION ${PRECISION}

#define VEC4_T ${texel_type(DTYPE)}

#include "indexing_utils.h"

layout(std430) buffer;

layout(set = 0, binding = 0, ${IMAGE_FORMAT[DTYPE]}) uniform PRECISION restrict writeonly ${IMAGE_T[NDIM][DTYPE]} image_out;
layout(set = 0, binding = 1) uniform PRECISION sampler3D image_in;
layout(set = 0, binding = 2) uniform PRECISION sampler3D kernel_in;
layout(set = 0, binding = 3) uniform PRECISION sampler3D bias_in;

layout(set = 0, binding = 4) uniform PRECISION restrict Out_channels {
  int data;
}
out_channels;

layout(set = 0, binding = 5) uniform PRECISION restrict In_length {
  int data;
}
in_length;

layout(set = 0, binding = 6) uniform PRECISION restrict Kernel_size {
  int data;
}
kernel_size;

layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;

/*
 * This implementation optimizes for simplicity (and partially for
 * performance) on a (1, C, L) input where C == groups. Hence we only focus
 * on calculating the rolling kernel along the L dimension.
 */
void main() {
  const ivec3 pos = ivec3(gl_GlobalInvocationID);

  // The global workgroup should have taken care of this already. We only
  // perform one work item for each 1d tensor along the length dimension.
  if (pos.x >= 1) {
    return;
  }

  int c = pos.y;
  if (c >= out_channels.data) {
    return;
  }

  // Assume n == 1; the n > 1 case is not handled for now.
  int n = pos.z;
  if (n >= 1) {
    return;
  }

  vec4 bias = texelFetch(bias_in, ivec3(c, 0, 0), 0);

  for (int i = 0; i < in_length.data - kernel_size.data + 1; ++i) {
    vec4 v = vec4(0);
    for (int k = 0; k < kernel_size.data; ++k) {
      const ivec3 in_pos = ivec3(i + k, c, 0);
      const vec4 input_value = texelFetch(image_in, in_pos, 0);

      // Note that we read the weight in the inner loop. This could be
      // improved by moving the read before the outer loop, since the weight
      // vector is constant for the entire call.

      // Weight in input space: (c, 0, k). Notice that c is 4-packed: we
      // divide by 4 to locate the texel and mod by 4 to select the actual
      // weight component.
      const ivec3 w_pos = ivec3(k, 0, c / 4);
      const vec4 weight = texelFetch(kernel_in, w_pos, 0);

      float w = weight.x;
      if (c % 4 == 1) {
        w = weight.y;
      } else if (c % 4 == 2) {
        w = weight.z;
      } else if (c % 4 == 3) {
        w = weight.w;
      }

      v += w * input_value.x;
    }

    ivec3 out_pos = ivec3(i, c, 0);
    imageStore(image_out, out_pos, vec4(v.x + bias.x, 0, 0, 0));
  }
}
```
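As a cross-check of the shader's logic, here is a plain-Python restatement of the per-channel rolling dot product (a sketch, not code from this diff; the packing comment mirrors the `c / 4`, `c % 4` indexing above):

```python
import torch

def conv1d_depthwise_reference(x, w, b):
    # x: (1, C, L), w: (C, 1, K), b: (C,) -- the shapes this shader assumes.
    _, C, L = x.shape
    K = w.shape[2]
    out = torch.empty(1, C, L - K + 1)
    for c in range(C):                  # one shader invocation per channel (pos.y)
        for i in range(L - K + 1):      # rolling window along the length dim
            acc = 0.0
            for k in range(K):
                # The shader fetches the weight texel at (k, 0, c / 4) and
                # selects the component with c % 4, since channels are 4-packed.
                acc += w[c, 0, k] * x[0, c, i + k]
            out[0, c, i] = acc + b[c]
    return out

conv = torch.nn.Conv1d(6, 6, 3, groups=6, bias=True)
x = torch.randn(1, 6, 7)
assert torch.allclose(
    conv1d_depthwise_reference(x, conv.weight.detach(), conv.bias.detach()),
    conv(x).detach(),
    atol=1e-6,
)
```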
Lines changed: 17 additions & 0 deletions

```yaml
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

conv1d:
  parameter_names_with_default_values:
    NDIM: 3
    DTYPE: float
    PACKING: C_packed
  generate_variant_forall:
    DTYPE:
      - VALUE: half
      - VALUE: float
  shader_variants:
    - NAME: conv1d
```
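Assuming the usual codegen convention (one shader variant per `DTYPE` entry, with a dtype suffix appended to the variant name, which is what the `add_dtype_suffix` call in Convolution.cpp below relies on), the YAML above should expand to names like these; a small sketch under that assumption:

```python
# Assumed naming convention: "<shader_variant NAME>_<dtype>".
dtypes = ["half", "float"]
variants = [f"conv1d_{dtype}" for dtype in dtypes]
print(variants)  # ['conv1d_half', 'conv1d_float']
```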

backends/vulkan/runtime/graph/ops/impl/Conv2d.cpp renamed to backends/vulkan/runtime/graph/ops/impl/Convolution.cpp

Lines changed: 136 additions & 17 deletions

```diff
@@ -17,8 +17,6 @@
 
 #include <executorch/backends/vulkan/runtime/graph/ops/utils/ShaderNameUtils.h>
 
-#include <iostream>
-
 namespace vkcompute {
 
 void resize_conv2d_node(
@@ -56,6 +54,29 @@ void resize_conv2d_node(
   out->virtual_resize(new_out_sizes);
 }
 
+void resize_conv1d_node(
+    ComputeGraph* graph,
+    const std::vector<ArgGroup>& args,
+    const std::vector<ValueRef>& extra_args) {
+  vTensorPtr out = graph->get_tensor(args[0].refs[0]);
+  vTensorPtr self = graph->get_tensor(args[1].refs[0]);
+  TensorRefPtr weight_ref = graph->get_tref(extra_args[0]);
+  const std::vector<int64_t>& weight_sizes = weight_ref->sizes;
+
+  const std::vector<int64_t>& in_sizes = self->sizes();
+  size_t ndim = in_sizes.size();
+  std::vector<int64_t> new_out_sizes(ndim);
+
+  int64_t kernel_size = weight_sizes.at(2);
+  int64_t in_length = in_sizes.at(2);
+
+  new_out_sizes.at(0) = in_sizes.at(0);
+  new_out_sizes.at(1) = in_sizes.at(1);
+  new_out_sizes.at(2) = in_length - kernel_size + 1;
+
+  out->virtual_resize(new_out_sizes);
+}
+
 ValueRef prepack_biases(
     ComputeGraph& graph,
     const ValueRef vref,
@@ -219,7 +240,7 @@ ValueRef prepack_weights(
   return v;
 }
 
-void check_conv2d_args(const vTensor& in, const vTensor& out) {
+void check_conv_args(const vTensor& in, const vTensor& out) {
   if (in.sizes().at(0) > 1) {
     VK_THROW(
         "aten.convolution.default: input batch size > 1 is not supported yet!");
@@ -312,7 +333,7 @@ void add_conv2d_node(
 
   vTensorPtr t_in = graph.get_tensor(arg_in);
   vTensorPtr t_out = graph.get_tensor(out);
-  check_conv2d_args(*t_in, *t_out);
+  check_conv_args(*t_in, *t_out);
 
   api::utils::uvec3 global_size = t_out->extents();
   api::utils::uvec3 local_size = adaptive_work_group_size(global_size);
@@ -352,23 +373,121 @@
       {weight, stride, padding, dilation, transposed, output_padding}));
 }
 
-void conv2d(ComputeGraph& graph, const std::vector<ValueRef>& args) {
-  return add_conv2d_node(
+void add_conv1d_node(
+    ComputeGraph& graph,
+    const ValueRef in,
+    const ValueRef weight,
+    const ValueRef bias,
+    const ValueRef stride,
+    const ValueRef padding,
+    const ValueRef dilation,
+    const ValueRef groups,
+    const ValueRef out) {
+  if (graph.val_is_none(bias)) {
+    VK_THROW("conv1d: Null bias is not supported yet!");
+  }
+
+  ValueRef arg_in = prepack_if_tensor_ref(graph, in);
+  ValueRef arg_weight =
+      prepack_if_tensor_ref(graph, weight, graph.memory_layout_of(arg_in));
+  ValueRef arg_bias =
+      prepack_if_tensor_ref(graph, bias, graph.memory_layout_of(arg_in));
+
+  vTensorPtr t_in = graph.get_tensor(arg_in);
+  vTensorPtr t_weight = graph.get_tensor(arg_weight);
+  vTensorPtr t_bias = graph.get_tensor(arg_bias);
+  vTensorPtr t_out = graph.get_tensor(out);
+  const int64_t groups_val = graph.get_int(groups);
+
+  std::vector<int64_t> in_sizes = t_in->sizes();
+  std::vector<int64_t> weight_sizes = t_weight->sizes();
+  std::vector<int64_t> out_sizes = t_out->sizes();
+  IntListPtr stride_sizes = graph.get_int_list(stride);
+  IntListPtr padding_sizes = graph.get_int_list(padding);
+  IntListPtr dilation_sizes = graph.get_int_list(dilation);
+  int64_t weight_out_channels = weight_sizes.at(0);
+  int64_t kernel_size = weight_sizes.at(2);
+  int64_t in_length = in_sizes.at(2);
+
+  VK_CHECK_COND(in_sizes.size() == 3, "input must be a 3-dim tensor");
+  VK_CHECK_COND(weight_sizes.size() == 3, "weight must be a 3-dim tensor");
+  VK_CHECK_COND(
+      stride_sizes->size() == 1 && stride_sizes->at(0) == 1,
+      "stride must be 1");
+  VK_CHECK_COND(
+      padding_sizes->size() == 1 && padding_sizes->at(0) == 0,
+      "padding must be 0");
+  VK_CHECK_COND(
+      dilation_sizes->size() == 1 && dilation_sizes->at(0) == 1,
+      "dilation must be 1");
+  VK_CHECK_COND(
+      groups_val == in_sizes.at(1), "groups must be equal to in_channels");
+  VK_CHECK_COND(
+      groups_val == weight_sizes.at(0),
+      "groups must be equal to weight_sizes.at(0)");
+  VK_CHECK_COND(weight_sizes.at(1) == 1, "weight_sizes.at(1) must be 1");
+
+  check_conv_args(*t_in, *t_out);
+
+  api::utils::uvec3 global_size = {
+      1, static_cast<uint32_t>(weight_out_channels), 1};
+  api::utils::uvec3 local_size = {1, 1, 1};
+
+  std::string kernel_name("conv1d");
+  kernel_name.reserve(kShaderNameReserve);
+
+  add_dtype_suffix(kernel_name, *t_out);
+
+  graph.execute_nodes().emplace_back(new ExecuteNode(
       graph,
-      args[0],
-      args[1],
-      args[2],
-      args[3],
-      args[4],
-      args[5],
-      args[6],
-      args[7],
-      args[8],
-      args[9]);
+      VK_KERNEL_FROM_STR(kernel_name),
+      global_size,
+      local_size,
+      // Inputs and Outputs
+      {{out, api::MemoryAccessType::WRITE},
+       {{arg_in, arg_weight, arg_bias}, api::MemoryAccessType::READ}},
+      // Shader params buffers
+      {
+          graph.create_params_buffer(weight_out_channels),
+          graph.create_params_buffer(in_length),
+          graph.create_params_buffer(kernel_size),
+      },
+      // Resizing
+      resize_conv1d_node,
+      {weight}));
+}
+
+void conv(ComputeGraph& graph, const std::vector<ValueRef>& args) {
+  int64_t in_ndim = graph.get_tensor(args[0])->sizes().size();
+  if (in_ndim == 4) {
+    return add_conv2d_node(
+        graph,
+        args[0],
+        args[1],
+        args[2],
+        args[3],
+        args[4],
+        args[5],
+        args[6],
+        args[7],
+        args[8],
+        args[9]);
+  } else {
+    return add_conv1d_node(
+        graph,
+        args[0],
+        args[1],
+        args[2],
+        args[3],
+        args[4],
+        args[5],
+        args[8],
+        args[9]);
+  }
 }
 
 REGISTER_OPERATORS {
-  VK_REGISTER_OP(aten.convolution.default, conv2d);
+  VK_REGISTER_OP(aten.convolution.default, conv);
 }
 
 } // namespace vkcompute
```
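The shape math in `resize_conv1d_node` is the stride-1, zero-padding, dilation-1 special case of the general convolution output-length formula; a quick sketch of the reduction:

```python
def conv1d_out_length(L, K, stride=1, padding=0, dilation=1):
    # General output length for aten.convolution semantics:
    #   floor((L + 2*padding - dilation*(K - 1) - 1) / stride) + 1
    return (L + 2 * padding - dilation * (K - 1) - 1) // stride + 1

# With the restrictions enforced by add_conv1d_node (stride == 1,
# padding == 0, dilation == 1), this collapses to L - K + 1,
# which is exactly what resize_conv1d_node computes.
assert conv1d_out_length(7, 3) == 7 - 3 + 1 == 5
```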

backends/vulkan/test/op_tests/cases.py

Lines changed: 13 additions & 2 deletions

```diff
@@ -66,7 +66,7 @@ def get_pool2d_inputs():
     return test_suite
 
 
-def get_conv2d_inputs():
+def get_conv_inputs():
     test_suite = VkTestSuite(
         [
             (
@@ -124,6 +124,17 @@ def get_conv2d_inputs():
                 [0, 0],
                 1,
             ),
+            (
+                (1, 6, 7),
+                (6, 1, 3),
+                (6,),
+                [1],
+                [0],
+                [1],
+                False,
+                [0],
+                6,
+            ),
         ]
     )
     return test_suite
@@ -297,7 +308,7 @@ def get_slice_inputs():
     "aten.mul.Tensor": get_binary_elementwise_inputs(),
     "aten.mm.default": get_mm_inputs(),
     "aten.max_pool2d_with_indices.default": get_pool2d_inputs(),
-    "aten.convolution.default": get_conv2d_inputs(),
+    "aten.convolution.default": get_conv_inputs(),
     "aten.native_layer_norm.default": get_native_layer_norm_inputs(),
     "aten.full.default": get_full_inputs(),
     "aten.select.int": get_select_int_inputs(),
```

backends/vulkan/test/test_vulkan_delegate.py

Lines changed: 24 additions & 0 deletions

```diff
@@ -648,6 +648,30 @@ def forward(self, x):
             memory_layouts=[vk_graph_schema.VkMemoryLayout.TENSOR_CHANNELS_PACKED],
         )
 
+    def test_vulkan_backend_conv1d(self):
+        class Conv1dModule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.conv = torch.nn.Conv1d(
+                    in_channels=6,
+                    out_channels=6,
+                    kernel_size=3,
+                    groups=6,
+                    bias=True,
+                )
+
+            def forward(self, x):
+                return self.conv(x)
+
+        conv1d_module = Conv1dModule()
+        sample_inputs = (torch.randn(size=(1, 6, 7), dtype=torch.float32),)
+
+        self.lower_module_and_test_output(
+            conv1d_module,
+            sample_inputs,
+            memory_layouts=[vk_graph_schema.VkMemoryLayout.TENSOR_CHANNELS_PACKED],
+        )
+
     def test_vulkan_backend_native_layer_norm(self):
         class NativeLayerNormModule(torch.nn.Module):
             def __init__(self):
```
