
Commit 3bd4519

copyrightly authored and facebook-github-bot committed
conv1d general case (#3223)
Summary: We port jorgep31415's conv1d work for the lite interpreter into ExecuTorch. The current implementation supports general batch_size, weight_size, stride, padding, dilation, and groups.

Differential Revision: D56380147
1 parent 6c30eea commit 3bd4519

File tree

7 files changed: +168 -113 lines changed


backends/vulkan/runtime/graph/ops/glsl/conv1d.glsl

Lines changed: 77 additions & 64 deletions
@@ -21,78 +21,91 @@ layout(set = 0, binding = 1) uniform PRECISION sampler3D image_in;
 layout(set = 0, binding = 2) uniform PRECISION sampler3D kernel_in;
 layout(set = 0, binding = 3) uniform PRECISION sampler3D bias_in;
 
-layout(set = 0, binding = 4) uniform PRECISION restrict Out_channels {
-  int data;
-}
-out_channels;
-
-layout(set = 0, binding = 5) uniform PRECISION restrict In_length {
-  int data;
-}
-in_length;
-
-layout(set = 0, binding = 6) uniform PRECISION restrict Kernel_size {
-  int data;
-}
-kernel_size;
+layout(set = 0, binding = 4) uniform PRECISION restrict Params {
+  int in_length;
+  int kernel_size;
+  int stride;
+  int padding;
+  int dilation;
+  int in_group_size;
+  int out_group_size;
+  int batch_size;
+};
 
 layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
 
-/*
- * This implementation optimize for simplicity (and partially performance) for a
- * (1, C, L) where C == groups. Hence we only focus on calculating the rolling
- * kernel of the L dimension.
- */
+// Let us define
+//
+// input = (N, in_C, in_L),
+// output = (N, out_C, out_L),
+// groups = G,
+// kernel = K,
+//
+// which results in shapes
+//
+// weight = (out_C, in_C / G, K),
+// bias = (out_C,).
+//
+// This implementation performs out_C shader invocations, where each invocation
+// calculates the rolling kernel of the length dimension for each batch, i.e.,
+// computes out_L * N results.
+//
+// Note that we can rewrite this implementation as out_L * out_C * ceil(N / 4)
+// shader invocations, where each invocation computes 1 result. But that
+// performs worse.
 void main() {
   const ivec3 pos = ivec3(gl_GlobalInvocationID);
 
-  // The global workgroup should have taken care of it. We only perform one
-  // work item for each 1d tensor on lengths
-  if (pos.x >= 1) {
-    return;
-  }
-
-  int c = pos.y;
-  if (c >= out_channels.data) {
-    return;
-  }
-
-  // Assume n = 1, do not handle n > 1 case for now.
-  int n = pos.z;
-  if (n >= 1) {
-    return;
-  }
-
-  vec4 bias = texelFetch(bias_in, ivec3(c, 0, 0), 0);
-
-  for (int i = 0; i < in_length.data - kernel_size.data + 1; ++i) {
-    vec4 v = vec4(0);
-    for (int k = 0; k < kernel_size.data; ++k) {
-      const ivec3 in_pos = ivec3(i+k, c, 0);
-      const vec4 input_value = texelFetch(image_in, in_pos, 0);
-
-      // Note that we are reading weight in the inner loop, this could be
-      // improved by moving it before the outer loop. Since the weight vector is
-      // contant for the entire call.
-
-      // weight in input-space: (c, 0, k);
-      // notice that c is 4-packed. We need to mod 4 to get the actual weight.
-      const ivec3 w_pos = ivec3(k, 0, c / 4);
-      const vec4 weight = texelFetch(kernel_in, w_pos, 0);
-
-      float w = weight.x;
-      if (c % 4 == 1) {
-        w = weight.y;
-      } else if (c % 4 == 2) {
-        w = weight.z;
-      } else if (c % 4 == 3) {
-        w = weight.w;
+  // "out_c" is the output's channel index where we write our result.
+  // Across shader invocations, this is the only value that varies.
+  int out_c = pos.y;
+  vec4 bias = texelFetch(bias_in, ivec3(out_c, 0, 0), 0);
+
+  // "in_c" tracks the input's channel start index.
+  // We iterate over the input group that corresponds to the output group.
+  int c_start = (out_c / out_group_size) * in_group_size;
+  int c_end = c_start + in_group_size;
+
+  // "in_l" tracks the input's length start index for our input-kernel overlay
+  // region.
+  int l_start = -padding;
+  int l_end = in_length + padding - dilation * (kernel_size - 1);
+
+  // Since the input/output tensors are channel-packed, which is along the
+  // batch dimension, we can batch-read/write four elements at a time.
+  for (int n = 0; n < batch_size; n += 4) {
+    // "out_l" tracks the output's length index where we write our result.
+    int out_l = 0;
+
+    for (int in_l = l_start; in_l < l_end; in_l += stride, ++out_l) {
+      vec4 sum = vec4(0);
+
+      for (int in_c = c_start; in_c < c_end; ++in_c) {
+        // "k" tracks the kernel's index for our input-kernel computation.
+        // It reads out-of-bound zeros, but trying to avoid them complicates
+        // for-loop conditions, which results in worse performance.
+        for (int k = 0; k < kernel_size; k += 4) {
+          // Since the weight tensor is width-packed, which is along the length
+          // dimension, we can batch-read four elements at a time.
+          const ivec3 w_pos = ivec3(k / 4, in_c % in_group_size, out_c);
+          const vec4 weight = texelFetch(kernel_in, w_pos, 0);
+
+          const ivec3 in_pos_0 = ivec3(in_l + k * dilation, in_c, n / 4);
+          sum = fma(weight.xxxx, texelFetch(image_in, in_pos_0, 0), sum);
+
+          const ivec3 in_pos_1 = ivec3(in_l + (k+1) * dilation, in_c, n / 4);
+          sum = fma(weight.yyyy, texelFetch(image_in, in_pos_1, 0), sum);
+
+          const ivec3 in_pos_2 = ivec3(in_l + (k+2) * dilation, in_c, n / 4);
+          sum = fma(weight.zzzz, texelFetch(image_in, in_pos_2, 0), sum);
+
+          const ivec3 in_pos_3 = ivec3(in_l + (k+3) * dilation, in_c, n / 4);
+          sum = fma(weight.wwww, texelFetch(image_in, in_pos_3, 0), sum);
+        }
       }
 
-      v += w * input_value.x;
+      ivec3 out_pos = ivec3(out_l, out_c, n / 4);
+      imageStore(image_out, out_pos, sum + bias.x);
     }
-
-    ivec3 out_pos = ivec3(i, c, 0);
-    imageStore(image_out, out_pos, vec4(v.x + bias.x, 0, 0, 0));
   }
 }
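
For reference, here is a minimal CPU sketch of the grouped conv1d the shader's header comment describes (floor-mode output length, zero padding). It is not part of this commit, and all names are illustrative.

#include <cstdint>
#include <vector>

// Reference grouped conv1d over contiguous row-major buffers:
// input (N, in_C, in_L), weight (out_C, in_C / G, K), bias (out_C,).
std::vector<float> conv1d_reference(
    const std::vector<float>& input,
    const std::vector<float>& weight,
    const std::vector<float>& bias,
    int64_t N, int64_t in_C, int64_t in_L,
    int64_t out_C, int64_t K, int64_t G,
    int64_t stride, int64_t padding, int64_t dilation) {
  const int64_t in_group_size = in_C / G;
  const int64_t out_group_size = out_C / G;
  const int64_t out_L =
      (in_L + 2 * padding - dilation * (K - 1) - 1) / stride + 1;
  std::vector<float> output(N * out_C * out_L, 0.0f);

  for (int64_t n = 0; n < N; ++n) {
    for (int64_t out_c = 0; out_c < out_C; ++out_c) {
      // Each output channel only reads the input channels of its own group,
      // matching c_start/c_end in the shader.
      const int64_t c_start = (out_c / out_group_size) * in_group_size;
      for (int64_t out_l = 0; out_l < out_L; ++out_l) {
        float sum = bias[out_c];
        for (int64_t ic = 0; ic < in_group_size; ++ic) {
          for (int64_t k = 0; k < K; ++k) {
            const int64_t in_l = out_l * stride - padding + k * dilation;
            if (in_l < 0 || in_l >= in_L) {
              continue;  // zero padding
            }
            sum += input[(n * in_C + c_start + ic) * in_L + in_l] *
                weight[(out_c * in_group_size + ic) * K + k];
          }
        }
        output[(n * out_C + out_c) * out_L + out_l] = sum;
      }
    }
  }
  return output;
}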

backends/vulkan/runtime/graph/ops/impl/Convolution.cpp

Lines changed: 38 additions & 40 deletions
@@ -61,6 +61,11 @@ void resize_conv1d_node(
   vTensorPtr out = graph->get_tensor(args[0].refs[0]);
   vTensorPtr self = graph->get_tensor(args[1].refs[0]);
   TensorRefPtr weight_ref = graph->get_tref(extra_args[0]);
+
+  int64_t stride_size = graph->get_int_list(extra_args[1])->at(0);
+  int64_t padding_size = graph->get_int_list(extra_args[2])->at(0);
+  int64_t dilation_size = graph->get_int_list(extra_args[3])->at(0);
+
   const std::vector<int64_t>& weight_sizes = weight_ref->sizes;
 
   const std::vector<int64_t>& in_sizes = self->sizes();
@@ -71,8 +76,9 @@ void resize_conv1d_node(
   int64_t in_length = in_sizes.at(2);
 
   new_out_sizes.at(0) = in_sizes.at(0);
-  new_out_sizes.at(1) = in_sizes.at(1);
-  new_out_sizes.at(2) = in_length - kernel_size + 1;
+  new_out_sizes.at(1) = weight_sizes.at(0);
+  new_out_sizes.at(2) = calc_out_size(
+      in_length, kernel_size, stride_size, padding_size, dilation_size, false);
 
   out->virtual_resize(new_out_sizes);
 }
@@ -244,10 +250,6 @@ ValueRef prepack_weights(
 }
 
 void check_conv_args(const vTensor& in, const vTensor& out) {
-  if (in.sizes().at(0) > 1) {
-    VK_THROW(
-        "aten.convolution.default: input batch size > 1 is not supported yet!");
-  }
   VK_CHECK_COND(check_memory_layout_is(in, api::kChannelsPacked));
   VK_CHECK_COND(check_memory_layout_is(out, api::kChannelsPacked));
 }
@@ -260,7 +262,7 @@ struct Conv2dParams final {
 Conv2dParams create_conv2d_params(
     ComputeGraph& graph,
     const ValueRef weight,
-    const KernelParams& p,
+    const KernelParams2D& p,
     const bool transposed) {
   const auto& overlay_region = api::utils::make_ivec2({
       p.kernel_size.data[0] +
@@ -275,7 +277,7 @@ Conv2dParams create_conv2d_params(
   return {overlay_region, in_group_size};
 }
 
-void check_conv2d_params(const KernelParams& p, const bool transposed) {
+void check_conv2d_params(const KernelParams2D& p, const bool transposed) {
   if (transposed) {
     if (p.dilation.data[0] > 1 || p.dilation.data[1] > 1) {
       VK_THROW(
@@ -342,12 +344,15 @@ void add_conv2d_node(
 
   vTensorPtr t_in = graph.get_tensor(arg_in);
   vTensorPtr t_out = graph.get_tensor(out);
+  if (t_in->sizes().at(0) > 1) {
+    VK_THROW("conv2d: input batch size > 1 is not supported yet!");
+  }
   check_conv_args(*t_in, *t_out);
 
   api::utils::uvec3 global_size = t_out->extents();
   api::utils::uvec3 local_size = adaptive_work_group_size(global_size);
 
-  KernelParams kernel_params = create_kernel_params(
+  KernelParams2D kernel_params = create_kernel_params(
       graph,
       weight,
       /*kernel_size_only = */ false,
@@ -395,8 +400,7 @@ void add_conv1d_node(
     const ValueRef groups,
     const ValueRef out) {
   ValueRef arg_in = prepack_if_tensor_ref(graph, in);
-  ValueRef arg_weight =
-      prepack_if_tensor_ref(graph, weight, graph.memory_layout_of(arg_in));
+  ValueRef arg_weight = prepack_if_tensor_ref(graph, weight, api::kWidthPacked);
   ValueRef arg_bias = prepack_biases(
       graph,
       bias,
@@ -414,37 +418,33 @@ void add_conv1d_node(
   std::vector<int64_t> in_sizes = t_in->sizes();
   std::vector<int64_t> weight_sizes = t_weight->sizes();
   std::vector<int64_t> out_sizes = t_out->sizes();
-  IntListPtr stride_sizes = graph.get_int_list(stride);
-  IntListPtr padding_sizes = graph.get_int_list(padding);
-  IntListPtr dilation_sizes = graph.get_int_list(dilation);
-  int64_t weight_out_channels = weight_sizes.at(0);
-  int64_t kernel_size = weight_sizes.at(2);
-  int64_t in_length = in_sizes.at(2);
 
-  VK_CHECK_COND(in_sizes.size() == 3, "input must be a 3-dim tensor");
-  VK_CHECK_COND(weight_sizes.size() == 3, "weight must be a 3-dim tensor");
-  VK_CHECK_COND(
-      stride_sizes->size() == 1 && stride_sizes->at(0) == 1,
-      "stride must be 1");
-  VK_CHECK_COND(
-      padding_sizes->size() == 1 && padding_sizes->at(0) == 0,
-      "padding must be 0");
-  VK_CHECK_COND(
-      dilation_sizes->size() == 1 && dilation_sizes->at(0) == 1,
-      "dilation must be 1");
-  VK_CHECK_COND(
-      groups_val == in_sizes.at(1), "groups must be equal to in_channels");
-  VK_CHECK_COND(
-      groups_val == weight_sizes.at(0),
-      "groups must be equal to weight_sizes.at(0)");
-  VK_CHECK_COND(weight_sizes.at(1) == 1, "weight_sizes.at(1) must be 1");
+  int32_t in_channels = in_sizes.at(1);
+  int32_t out_channels = weight_sizes.at(0);
+  int32_t kernel_size = weight_sizes.at(2);
+  int32_t in_length = in_sizes.at(2);
+  int32_t stride_size = graph.get_int_list(stride)->at(0);
+  int32_t padding_size = graph.get_int_list(padding)->at(0);
+  int32_t dilation_size = graph.get_int_list(dilation)->at(0);
+  int32_t in_group_size = static_cast<int64_t>(in_channels / groups_val);
+  int32_t out_group_size = static_cast<int64_t>(out_channels / groups_val);
+  int32_t batch_size = in_sizes.at(0);
 
   check_conv_args(*t_in, *t_out);
 
-  api::utils::uvec3 global_size = {
-      1, static_cast<uint32_t>(weight_out_channels), 1};
+  api::utils::uvec3 global_size = {1, static_cast<uint32_t>(out_channels), 1};
   api::utils::uvec3 local_size = {1, 1, 1};
 
+  KernelParams1D kernel_params = {
+      in_length,
+      kernel_size,
+      stride_size,
+      padding_size,
+      dilation_size,
+      in_group_size,
+      out_group_size,
+      batch_size};
+
   std::string kernel_name("conv1d");
   kernel_name.reserve(kShaderNameReserve);
 
@@ -460,15 +460,13 @@ void add_conv1d_node(
       {{arg_in, arg_weight, arg_bias}, api::MemoryAccessType::READ}},
       // Shader params buffers
       {
-          graph.create_params_buffer(weight_out_channels),
-          graph.create_params_buffer(in_length),
-          graph.create_params_buffer(kernel_size),
+          graph.create_params_buffer(kernel_params),
       },
       // Specialization Constants
       {},
      // Resizing Logic
      resize_conv1d_node,
-      {weight}));
+      {weight, stride, padding, dilation}));
 }
 
 void conv(ComputeGraph& graph, const std::vector<ValueRef>& args) {
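
The weight is now prepacked as api::kWidthPacked while the input and output images stay channel-packed, and the shader's texel addressing follows directly from that choice. The sketch below is not part of the commit (the helper names are hypothetical); it only restates the index mapping the shader already uses.

#include <cstdint>

struct TexelCoord {
  int32_t x, y, z, lane;  // lane selects .x/.y/.z/.w inside the fetched texel
};

// Input/output images are channel-packed along the batch dimension:
// element (n, c, l) lives at texel (l, c, n / 4), component n % 4,
// matching in_pos = ivec3(in_l + k * dilation, in_c, n / 4) in the shader.
TexelCoord in_out_texel(int32_t n, int32_t c, int32_t l) {
  return {l, c, n / 4, n % 4};
}

// The weight image is width-packed along the kernel dimension:
// element (out_c, ic, k) of the (out_C, in_C / G, K) weight lives at
// texel (k / 4, ic, out_c), component k % 4, where ic is the channel index
// within the group (in_c % in_group_size in the shader).
TexelCoord weight_texel(int32_t out_c, int32_t ic, int32_t k) {
  return {k / 4, ic, out_c, k % 4};
}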

backends/vulkan/runtime/graph/ops/impl/Pool.cpp

Lines changed: 1 addition & 1 deletion
@@ -76,7 +76,7 @@ void add_max_pool2d_node(
   std::string kernel_name("max_pool2d");
   add_dtype_suffix(kernel_name, *t_out);
 
-  KernelParams kernel_params = create_kernel_params(
+  KernelParams2D kernel_params = create_kernel_params(
       graph,
       kernel_size,
       /*kernel_size_only = */ true,

backends/vulkan/runtime/graph/ops/impl/utils/KernelUtils.cpp

Lines changed: 1 addition & 1 deletion
@@ -26,7 +26,7 @@ api::utils::ivec2 make_ivec2_kernel_size(
   }
 }
 
-KernelParams create_kernel_params(
+KernelParams2D create_kernel_params(
     ComputeGraph& graph,
     const ValueRef weight,
     const bool kernel_size_only,

backends/vulkan/runtime/graph/ops/impl/utils/KernelUtils.h

Lines changed: 21 additions & 2 deletions
@@ -16,21 +16,40 @@
 
 namespace vkcompute {
 
-struct KernelParams final {
+struct KernelParams2D final {
   api::utils::ivec2 kernel_size;
   api::utils::ivec2 stride;
   api::utils::ivec2 padding;
   api::utils::ivec2 dilation;
 };
 
-KernelParams create_kernel_params(
+struct KernelParams1D final {
+  int in_length;
+  int kernel_size;
+  int stride;
+  int padding;
+  int dilation;
+  int in_group_size;
+  int out_group_size;
+  int batch_size;
+};
+
+KernelParams2D create_kernel_params(
     ComputeGraph& graph,
     const ValueRef weight,
    const bool kernel_size_only,
    const ValueRef stride,
    const ValueRef padding,
    const ValueRef dilation);
 
+int64_t calc_out_size(
+    const int64_t in_size,
+    const int64_t kernel_size,
+    const int64_t stride,
+    const int64_t padding,
+    const int64_t dilation,
+    const bool ceil_mode);
+
 std::vector<int64_t> calc_out_sizes_hw(
     ComputeGraph& graph,
     const std::vector<int64_t>& in_sizes,
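
calc_out_size is only declared in this header; its definition lives in KernelUtils.cpp and is not shown in this diff. For orientation, a hedged sketch of the standard convolution output-size formula it presumably implements (conv1d passes ceil_mode = false in resize_conv1d_node above):

#include <cstdint>

// Presumed behavior of calc_out_size (standard PyTorch-style formula); the
// actual definition is in KernelUtils.cpp and may differ in edge cases.
int64_t calc_out_size_sketch(
    const int64_t in_size,
    const int64_t kernel_size,
    const int64_t stride,
    const int64_t padding,
    const int64_t dilation,
    const bool ceil_mode) {
  const int64_t numerator =
      in_size + 2 * padding - dilation * (kernel_size - 1) - 1;
  // ceil_mode rounds the division up (used by pooling); PyTorch additionally
  // drops a window that would start entirely inside the padding, omitted here.
  const int64_t quotient =
      ceil_mode ? (numerator + stride - 1) / stride : numerator / stride;
  return quotient + 1;
}

// Example: in_size = 10, kernel = 3, stride = 2, padding = 1, dilation = 1
// gives (10 + 2 - 2 - 1) / 2 + 1 = 5, which is the length resize_conv1d_node
// would assign to new_out_sizes.at(2).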
