Skip to content

Commit 1f8210b

Browse files
copyrightly facebook-github-bot
authored and committed
conv1d general case
Summary: We port jorgep31415's conv1d work for the lite interpreter into ET. The current implementation supports general batch_size, weight_size, stride, padding, dilation, and groups. Differential Revision: D56380147
1 parent 9d2af4c commit 1f8210b

File tree

4 files changed

+157
-101
lines changed

4 files changed

+157
-101
lines changed

backends/vulkan/runtime/graph/ops/glsl/conv1d.glsl

Lines changed: 96 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -21,78 +21,112 @@ layout(set = 0, binding = 1) uniform PRECISION sampler3D image_in;
2121
layout(set = 0, binding = 2) uniform PRECISION sampler3D kernel_in;
2222
layout(set = 0, binding = 3) uniform PRECISION sampler3D bias_in;
2323

24-
layout(set = 0, binding = 4) uniform PRECISION restrict Out_channels {
25-
int data;
26-
}
27-
out_channels;
24+
layout(set = 0, binding = 4) uniform PRECISION restrict In_length {
25+
int in_length;
26+
};
2827

29-
layout(set = 0, binding = 5) uniform PRECISION restrict In_length {
30-
int data;
31-
}
32-
in_length;
28+
layout(set = 0, binding = 5) uniform PRECISION restrict Kernel_size {
29+
int kernel_size;
30+
};
3331

34-
layout(set = 0, binding = 6) uniform PRECISION restrict Kernel_size {
35-
int data;
36-
}
37-
kernel_size;
32+
layout(set = 0, binding = 6) uniform PRECISION restrict Stride {
33+
int stride;
34+
};
3835

39-
layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
36+
layout(set = 0, binding = 7) uniform PRECISION restrict Padding {
37+
int padding;
38+
};
4039

41-
/*
42-
* This implementation optimize for simplicity (and partially performance) for a
43-
* (1, C, L) where C == groups. Hence we only focus on calculating the rolling
44-
* kernel of the L dimension.
45-
*/
46-
void main() {
47-
const ivec3 pos = ivec3(gl_GlobalInvocationID);
40+
layout(set = 0, binding = 8) uniform PRECISION restrict Dilation {
41+
int dilation;
42+
};
4843

49-
// The global workgroup should have taken care of it. We only perform one
50-
// work item for each 1d tensor on lengths
51-
if (pos.x >= 1) {
52-
return;
53-
}
44+
layout(set = 0, binding = 9) uniform PRECISION restrict In_group_size {
45+
int in_group_size;
46+
};
5447

55-
int c = pos.y;
56-
if (c >= out_channels.data) {
57-
return;
58-
}
48+
layout(set = 0, binding = 10) uniform PRECISION restrict Out_group_size {
49+
int out_group_size;
50+
};
5951

60-
// Assume n = 1, do not handle n > 1 case for now.
61-
int n = pos.z;
62-
if (n >= 1) {
63-
return;
64-
}
52+
layout(set = 0, binding = 11) uniform PRECISION restrict Batch_size {
53+
int batch_size;
54+
};
55+
56+
layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
6557

66-
vec4 bias = texelFetch(bias_in, ivec3(c, 0, 0), 0);
67-
68-
for (int i = 0; i < in_length.data - kernel_size.data + 1; ++i) {
69-
vec4 v = vec4(0);
70-
for (int k = 0; k < kernel_size.data; ++k) {
71-
const ivec3 in_pos = ivec3(i+k, c, 0);
72-
const vec4 input_value = texelFetch(image_in, in_pos, 0);
73-
74-
// Note that we are reading weight in the inner loop, this could be
75-
// improved by moving it before the outer loop. Since the weight vector is
76-
// contant for the entire call.
77-
78-
// weight in input-space: (c, 0, k);
79-
// notice that c is 4-packed. We need to mod 4 to get the actual weight.
80-
const ivec3 w_pos = ivec3(k, 0, c / 4);
81-
const vec4 weight = texelFetch(kernel_in, w_pos, 0);
82-
83-
float w = weight.x;
84-
if (c % 4 == 1) {
85-
w = weight.y;
86-
} else if (c % 4 == 2) {
87-
w = weight.z;
88-
} else if (c % 4 == 3) {
89-
w = weight.w;
58+
// Let us define
59+
//
60+
// input = (N, in_C, in_L),
61+
// output = (N, out_C, out_L),
62+
// groups = G,
63+
// kernel = K,
64+
//
65+
// which results in shapes
66+
//
67+
// weight = (out_C, in_C / G, K),
68+
// bias = (out_C,).
69+
//
70+
// This implementation performs out_C shader invocations, where each invocation
71+
// calculates the rolling kernel of the length dimension for each batch, i.e.,
72+
// computes out_L * N results.
73+
//
74+
// Note that we can rewrite this implementation as out_L * out_C * ceil(N / 4)
75+
// shader invocations, where each invocation computes 1 result. But that
76+
// performs worse.
77+
void main() {
78+
const ivec3 pos = ivec3(gl_GlobalInvocationID);
79+
80+
// "out_c" is the output's channel index where we write our result.
81+
// Across shader invocations, this is the only value that varies.
82+
int out_c = pos.y;
83+
vec4 bias = texelFetch(bias_in, ivec3(out_c, 0, 0), 0);
84+
85+
// "in_c" iterates over input channels; "c_start"/"c_end" bound its range.
86+
// We iterate over the input group that corresponds to the output group.
87+
int c_start = (out_c / out_group_size) * in_group_size;
88+
int c_end = c_start + in_group_size;
89+
90+
// "in_l" tracks the input's length start index for our input-kernel overlay
91+
// region.
92+
int l_start = -padding;
93+
int l_end = in_length + padding - dilation * (kernel_size - 1);
94+
95+
// The input/output tensors are channel-packed, which here means packed along
96+
// the batch dimension, so each texel read/write covers four batch elements.
97+
for (int n = 0; n < batch_size; n += 4) {
98+
// "out_l" tracks the output's length index where we write our result.
99+
int out_l = 0;
100+
101+
for (int in_l = l_start; in_l < l_end; in_l += stride, ++out_l) {
102+
vec4 sum = vec4(0);
103+
104+
for (int in_c = c_start; in_c < c_end; ++in_c) {
105+
// "k" tracks the kernel's index for our input-kernel computation.
106+
// It reads out-of-bound zeros, but trying to avoid them complicates
107+
// for-loop conditions, which results in worse performance.
108+
for (int k = 0; k < kernel_size; k += 4) {
109+
// Since the weight tensor is width-packed, which is along the length
110+
// dimension, we can batch-read four elements at a time.
111+
const ivec3 w_pos = ivec3(k / 4, in_c % in_group_size, out_c);
112+
const vec4 weight = texelFetch(kernel_in, w_pos, 0);
113+
114+
const ivec3 in_pos_0 = ivec3(in_l + k * dilation, in_c, n / 4);
115+
sum = fma(weight.xxxx, texelFetch(image_in, in_pos_0, 0), sum);
116+
117+
const ivec3 in_pos_1 = ivec3(in_l + (k+1) * dilation, in_c, n / 4);
118+
sum = fma(weight.yyyy, texelFetch(image_in, in_pos_1, 0), sum);
119+
120+
const ivec3 in_pos_2 = ivec3(in_l + (k+2) * dilation, in_c, n / 4);
121+
sum = fma(weight.zzzz, texelFetch(image_in, in_pos_2, 0), sum);
122+
123+
const ivec3 in_pos_3 = ivec3(in_l + (k+3) * dilation, in_c, n / 4);
124+
sum = fma(weight.wwww, texelFetch(image_in, in_pos_3, 0), sum);
125+
}
90126
}
91127

92-
v += w * input_value.x;
128+
ivec3 out_pos = ivec3(out_l, out_c, n / 4);
129+
imageStore(image_out, out_pos, sum + bias.x);
93130
}
94-
95-
ivec3 out_pos = ivec3(i, c, 0);
96-
imageStore(image_out, out_pos, vec4(v.x + bias.x, 0, 0, 0));
97131
}
98132
}

backends/vulkan/runtime/graph/ops/impl/Convolution.cpp

Lines changed: 31 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,11 @@ void resize_conv1d_node(
6161
vTensorPtr out = graph->get_tensor(args[0].refs[0]);
6262
vTensorPtr self = graph->get_tensor(args[1].refs[0]);
6363
TensorRefPtr weight_ref = graph->get_tref(extra_args[0]);
64+
65+
int64_t stride_size = graph->get_int_list(extra_args[1])->at(0);
66+
int64_t padding_size = graph->get_int_list(extra_args[2])->at(0);
67+
int64_t dilation_size = graph->get_int_list(extra_args[3])->at(0);
68+
6469
const std::vector<int64_t>& weight_sizes = weight_ref->sizes;
6570

6671
const std::vector<int64_t>& in_sizes = self->sizes();
@@ -71,8 +76,11 @@ void resize_conv1d_node(
7176
int64_t in_length = in_sizes.at(2);
7277

7378
new_out_sizes.at(0) = in_sizes.at(0);
74-
new_out_sizes.at(1) = in_sizes.at(1);
75-
new_out_sizes.at(2) = in_length - kernel_size + 1;
79+
new_out_sizes.at(1) = weight_sizes.at(0);
80+
new_out_sizes.at(2) =
81+
(in_length + 2 * padding_size - dilation_size * (kernel_size - 1) - 1) /
82+
stride_size +
83+
1;
7684

7785
out->virtual_resize(new_out_sizes);
7886
}
@@ -244,10 +252,6 @@ ValueRef prepack_weights(
244252
}
245253

246254
void check_conv_args(const vTensor& in, const vTensor& out) {
247-
if (in.sizes().at(0) > 1) {
248-
VK_THROW(
249-
"aten.convolution.default: input batch size > 1 is not supported yet!");
250-
}
251255
VK_CHECK_COND(check_memory_layout_is(in, api::kChannelsPacked));
252256
VK_CHECK_COND(check_memory_layout_is(out, api::kChannelsPacked));
253257
}
@@ -342,6 +346,9 @@ void add_conv2d_node(
342346

343347
vTensorPtr t_in = graph.get_tensor(arg_in);
344348
vTensorPtr t_out = graph.get_tensor(out);
349+
if (t_in->sizes().at(0) > 1) {
350+
VK_THROW("conv2d: input batch size > 1 is not supported yet!");
351+
}
345352
check_conv_args(*t_in, *t_out);
346353

347354
api::utils::uvec3 global_size = t_out->extents();
@@ -395,8 +402,7 @@ void add_conv1d_node(
395402
const ValueRef groups,
396403
const ValueRef out) {
397404
ValueRef arg_in = prepack_if_tensor_ref(graph, in);
398-
ValueRef arg_weight =
399-
prepack_if_tensor_ref(graph, weight, graph.memory_layout_of(arg_in));
405+
ValueRef arg_weight = prepack_if_tensor_ref(graph, weight, api::kWidthPacked);
400406
ValueRef arg_bias = prepack_biases(
401407
graph,
402408
bias,
@@ -414,35 +420,21 @@ void add_conv1d_node(
414420
std::vector<int64_t> in_sizes = t_in->sizes();
415421
std::vector<int64_t> weight_sizes = t_weight->sizes();
416422
std::vector<int64_t> out_sizes = t_out->sizes();
417-
IntListPtr stride_sizes = graph.get_int_list(stride);
418-
IntListPtr padding_sizes = graph.get_int_list(padding);
419-
IntListPtr dilation_sizes = graph.get_int_list(dilation);
420-
int64_t weight_out_channels = weight_sizes.at(0);
423+
424+
int64_t in_channels = in_sizes.at(1);
425+
int64_t out_channels = weight_sizes.at(0);
421426
int64_t kernel_size = weight_sizes.at(2);
422427
int64_t in_length = in_sizes.at(2);
423-
424-
VK_CHECK_COND(in_sizes.size() == 3, "input must be a 3-dim tensor");
425-
VK_CHECK_COND(weight_sizes.size() == 3, "weight must be a 3-dim tensor");
426-
VK_CHECK_COND(
427-
stride_sizes->size() == 1 && stride_sizes->at(0) == 1,
428-
"stride must be 1");
429-
VK_CHECK_COND(
430-
padding_sizes->size() == 1 && padding_sizes->at(0) == 0,
431-
"padding must be 0");
432-
VK_CHECK_COND(
433-
dilation_sizes->size() == 1 && dilation_sizes->at(0) == 1,
434-
"dilation must be 1");
435-
VK_CHECK_COND(
436-
groups_val == in_sizes.at(1), "groups must be equal to in_channels");
437-
VK_CHECK_COND(
438-
groups_val == weight_sizes.at(0),
439-
"groups must be equal to weight_sizes.at(0)");
440-
VK_CHECK_COND(weight_sizes.at(1) == 1, "weight_sizes.at(1) must be 1");
428+
int64_t stride_size = graph.get_int_list(stride)->at(0);
429+
int64_t padding_size = graph.get_int_list(padding)->at(0);
430+
int64_t dilation_size = graph.get_int_list(dilation)->at(0);
431+
int64_t in_group_size = static_cast<int64_t>(in_channels / groups_val);
432+
int64_t out_group_size = static_cast<int64_t>(out_channels / groups_val);
433+
int64_t batch_size = in_sizes.at(0);
441434

442435
check_conv_args(*t_in, *t_out);
443436

444-
api::utils::uvec3 global_size = {
445-
1, static_cast<uint32_t>(weight_out_channels), 1};
437+
api::utils::uvec3 global_size = {1, static_cast<uint32_t>(out_channels), 1};
446438
api::utils::uvec3 local_size = {1, 1, 1};
447439

448440
std::string kernel_name("conv1d");
@@ -460,15 +452,20 @@ void add_conv1d_node(
460452
{{arg_in, arg_weight, arg_bias}, api::MemoryAccessType::READ}},
461453
// Shader params buffers
462454
{
463-
graph.create_params_buffer(weight_out_channels),
464455
graph.create_params_buffer(in_length),
465456
graph.create_params_buffer(kernel_size),
457+
graph.create_params_buffer(stride_size),
458+
graph.create_params_buffer(padding_size),
459+
graph.create_params_buffer(dilation_size),
460+
graph.create_params_buffer(in_group_size),
461+
graph.create_params_buffer(out_group_size),
462+
graph.create_params_buffer(batch_size),
466463
},
467464
// Specialization Constants
468465
{},
469466
// Resizing Logic
470467
resize_conv1d_node,
471-
{weight}));
468+
{weight, stride, padding, dilation}));
472469
}
473470

474471
void conv(ComputeGraph& graph, const std::vector<ValueRef>& args) {

backends/vulkan/test/op_tests/cases.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -135,6 +135,17 @@ def get_conv_inputs():
135135
[0],
136136
6,
137137
),
138+
(
139+
(2, 20, 30),
140+
(10, 4, 6),
141+
(10,),
142+
[5],
143+
[5],
144+
[3],
145+
False,
146+
[0],
147+
5,
148+
),
138149
(
139150
(1, 9, 11),
140151
(9, 1, 3),
@@ -146,6 +157,17 @@ def get_conv_inputs():
146157
[0],
147158
9,
148159
),
160+
(
161+
(5, 15, 30),
162+
(20, 3, 3),
163+
None,
164+
[3],
165+
[5],
166+
[7],
167+
False,
168+
[0],
169+
5,
170+
),
149171
]
150172
)
151173
return test_suite

backends/vulkan/test/test_vulkan_delegate.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -653,18 +653,21 @@ class Conv1dModule(torch.nn.Module):
653653
def __init__(self):
654654
super().__init__()
655655
self.conv = torch.nn.Conv1d(
656-
in_channels=6,
657-
out_channels=6,
658-
kernel_size=3,
659-
groups=6,
656+
in_channels=20,
657+
out_channels=10,
658+
kernel_size=6,
659+
stride=5,
660+
padding=5,
661+
dilation=3,
662+
groups=5,
660663
bias=True,
661664
)
662665

663666
def forward(self, x):
664667
return self.conv(x)
665668

666669
conv1d_module = Conv1dModule()
667-
sample_inputs = (torch.randn(size=(1, 6, 7), dtype=torch.float32),)
670+
sample_inputs = (torch.randn(size=(3, 20, 30), dtype=torch.float32),)
668671

669672
self.lower_module_and_test_output(
670673
conv1d_module,

0 commit comments

Comments
 (0)