Skip to content

Commit d5fdbd4

Browse files
nathanaelseefacebook-github-bot
authored andcommitted
update conv1d to new layout specifier gen, axis mapping, and use non-singlethreaded local workgroup (#5504)
Summary: Pull Request resolved: #5504 Using new load_texel_lpos for simpler updating Reviewed By: SS-JIA Differential Revision: D62990822 fbshipit-source-id: 9163b807d9095ebdb089f08aa6ea20fbbb563d02
1 parent 0eee42a commit d5fdbd4

File tree

3 files changed

+39
-46
lines changed

3 files changed

+39
-46
lines changed

backends/vulkan/runtime/graph/ops/glsl/conv1d.glsl

Lines changed: 33 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -18,32 +18,22 @@
1818

1919
layout(std430) buffer;
2020

21-
layout(set = 0, binding = 0, ${IMAGE_FORMAT[DTYPE]}) uniform PRECISION restrict writeonly ${IMAGE_T[NDIM][DTYPE]} image_out;
22-
layout(set = 0, binding = 1) uniform PRECISION sampler3D image_in;
23-
layout(set = 0, binding = 2) uniform PRECISION sampler3D kernel_in;
24-
layout(set = 0, binding = 3) uniform PRECISION sampler3D bias_in;
25-
26-
layout(set = 0, binding = 4) uniform PRECISION restrict OutLimits {
27-
ivec3 out_limits;
28-
};
29-
30-
layout(set = 0, binding = 5) uniform PRECISION restrict InSizes {
31-
ivec4 in_sizes;
32-
};
33-
34-
layout(set = 0, binding = 6) uniform PRECISION restrict Params {
35-
int kernel_size;
36-
int stride;
37-
int padding;
38-
int dilation;
39-
int in_group_size;
40-
int out_group_size;
41-
};
42-
43-
layout(set = 0, binding = 7) uniform PRECISION restrict OutputParams {
44-
float out_min;
45-
float out_max;
46-
};
21+
${layout_declare_tensor(B, "w", "t_out", DTYPE, STORAGE)}
22+
${layout_declare_tensor(B, "r", "t_in", DTYPE, STORAGE)}
23+
${layout_declare_tensor(B, "r", "kernel_in", DTYPE, STORAGE)}
24+
${layout_declare_tensor(B, "r", "bias_in", DTYPE, STORAGE)}
25+
26+
${layout_declare_ubo(B, "ivec3", "out_limits")}
27+
${layout_declare_ubo(B, "ivec4", "in_sizes")}
28+
29+
${layout_declare_ubo(B, "ivec4", "out_axis_map")}
30+
${layout_declare_ubo(B, "ivec4", "in_axis_map")}
31+
${layout_declare_ubo(B, "ivec4", "kernel_axis_map")}
32+
${layout_declare_ubo(B, "ivec4", "bias_axis_map")}
33+
34+
${layout_declare_ubo(B,"int", "kernel_size", "int", "stride", "int", "padding", "int", "dilation", "int", "in_group_size", "int", "out_group_size")}
35+
36+
${layout_declare_ubo(B, "float", "out_min", "float", "out_max")}
4737

4838
layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
4939

@@ -67,9 +57,9 @@ layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
6757
// shader invocations, where each invocation computes 1 result. But that
6858
// performs worse.
6959
void main() {
70-
const ivec3 pos = ivec3(gl_GlobalInvocationID);
60+
const ivec3 lpos = ivec3(gl_GlobalInvocationID);
7161

72-
if (any(greaterThanEqual(pos, out_limits))) {
62+
if (any(greaterThanEqual(lpos, out_limits))) {
7363
return;
7464
}
7565

@@ -78,8 +68,8 @@ void main() {
7868

7969
// "out_c" is the output's channel index where we write our result.
8070
// Across shader invocations, this is the only value that varies.
81-
int out_c = pos.y;
82-
vec4 bias = texelFetch(bias_in, ivec3(out_c, 0, 0), 0);
71+
int out_c = lpos.y;
72+
VEC4_T bias = load_texel_lpos(bias_in, ivec3(out_c, 0, 0), bias_axis_map);
8373

8474
// "in_c" tracks the input's channel start index.
8575
// We iterate over the input group that corresponds to the output group.
@@ -98,7 +88,7 @@ void main() {
9888
int out_l = 0;
9989

10090
for (int in_l = l_start; in_l < l_end; in_l += stride, ++out_l) {
101-
vec4 sum = vec4(0);
91+
VEC4_T sum = VEC4_T(0);
10292

10393
for (int in_c = c_start; in_c < c_end; ++in_c) {
10494
// "k" tracks the kernel's index for our input-kernel computation.
@@ -107,25 +97,25 @@ void main() {
10797
for (int k = 0; k < kernel_size; k += 4) {
10898
// Since the weight tensor is width-packed, which is along the length
10999
// dimension, we can batch-read four elements at a time.
110-
const ivec3 w_pos = ivec3(k / 4, in_c % in_group_size, out_c);
111-
const vec4 weight = texelFetch(kernel_in, w_pos, 0);
100+
const ivec3 w_lpos = ivec3(k / 4, in_c % in_group_size, out_c);
101+
const VEC4_T weight = load_texel_lpos(kernel_in, w_lpos, kernel_axis_map);
112102

113-
const ivec3 in_pos_0 = ivec3(in_l + k * dilation, in_c, n / 4);
114-
sum = fma(weight.xxxx, texelFetch(image_in, in_pos_0, 0), sum);
103+
ivec3 in_pos = lpos_to_pos(ivec3(in_l + k * dilation, in_c, n / 4), in_axis_map);
104+
sum = fma(weight.xxxx, load_texel(t_in, in_pos), sum);
115105

116-
const ivec3 in_pos_1 = ivec3(in_l + (k+1) * dilation, in_c, n / 4);
117-
sum = fma(weight.yyyy, texelFetch(image_in, in_pos_1, 0), sum);
106+
in_pos[in_axis_map.x] += dilation;
107+
sum = fma(weight.yyyy, load_texel(t_in, in_pos), sum);
118108

119-
const ivec3 in_pos_2 = ivec3(in_l + (k+2) * dilation, in_c, n / 4);
120-
sum = fma(weight.zzzz, texelFetch(image_in, in_pos_2, 0), sum);
109+
in_pos[in_axis_map.x] += dilation;
110+
sum = fma(weight.zzzz, load_texel(t_in, in_pos), sum);
121111

122-
const ivec3 in_pos_3 = ivec3(in_l + (k+3) * dilation, in_c, n / 4);
123-
sum = fma(weight.wwww, texelFetch(image_in, in_pos_3, 0), sum);
112+
in_pos[in_axis_map.x] += dilation;
113+
sum = fma(weight.wwww, load_texel(t_in, in_pos), sum);
124114
}
125115
}
126116

127-
ivec3 out_pos = ivec3(out_l, out_c, n / 4);
128-
imageStore(image_out, out_pos, op(sum + bias.x, out_min, out_max));
117+
const ivec3 out_lpos = ivec3(out_l, out_c, n / 4);
118+
write_texel_lpos(t_out, out_lpos, op(sum + bias.x, out_min, out_max), out_axis_map);
129119
}
130120
}
131121
}

backends/vulkan/runtime/graph/ops/glsl/conv1d.yaml

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,8 @@
77
conv1d:
88
parameter_names_with_default_values:
99
OPERATOR: X
10-
NDIM: 3
1110
DTYPE: float
12-
PACKING: C_packed
11+
STORAGE: texture3d
1312
generate_variant_forall:
1413
DTYPE:
1514
- VALUE: half

backends/vulkan/runtime/graph/ops/impl/Convolution.cpp

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -444,7 +444,7 @@ void add_conv1d_node(
444444
int32_t out_group_size = static_cast<int64_t>(out_channels / groups_val);
445445

446446
utils::uvec3 global_size = {1, static_cast<uint32_t>(out_channels), 1};
447-
utils::uvec3 local_size = {1, 1, 1};
447+
utils::uvec3 local_size = {1, 64, 1};
448448

449449
Kernel1dParams kernel_params = {
450450
kernel_size,
@@ -476,6 +476,10 @@ void add_conv1d_node(
476476
{
477477
t_out->logical_limits_ubo(),
478478
t_in->sizes_ubo(),
479+
t_out->axis_map_ubo(),
480+
t_in->axis_map_ubo(),
481+
t_weight->axis_map_ubo(),
482+
t_bias->axis_map_ubo(),
479483
graph.create_params_buffer(kernel_params),
480484
graph.create_params_buffer(out_params),
481485
},

0 commit comments

Comments
 (0)