@@ -21,78 +21,112 @@ layout(set = 0, binding = 1) uniform PRECISION sampler3D image_in;
 layout(set = 0, binding = 2) uniform PRECISION sampler3D kernel_in;
 layout(set = 0, binding = 3) uniform PRECISION sampler3D bias_in;
 
-layout(set = 0, binding = 4) uniform PRECISION restrict Out_channels {
-  int data;
-}
-out_channels;
+layout(set = 0, binding = 4) uniform PRECISION restrict In_length {
+  int in_length;
+};
 
-layout(set = 0, binding = 5) uniform PRECISION restrict In_length {
-  int data;
-}
-in_length;
+layout(set = 0, binding = 5) uniform PRECISION restrict Kernel_size {
+  int kernel_size;
+};
 
-layout(set = 0, binding = 6) uniform PRECISION restrict Kernel_size {
-  int data;
-}
-kernel_size;
+layout(set = 0, binding = 6) uniform PRECISION restrict Stride {
+  int stride;
+};
 
-layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
+layout(set = 0, binding = 7) uniform PRECISION restrict Padding {
+  int padding;
+};
 
-/*
- * This implementation optimize for simplicity (and partially performance) for a
- * (1, C, L) where C == groups. Hence we only focus on calculating the rolling
- * kernel of the L dimension.
- */
-void main() {
-  const ivec3 pos = ivec3(gl_GlobalInvocationID);
+layout(set = 0, binding = 8) uniform PRECISION restrict Dilation {
+  int dilation;
+};
 
-  // The global workgroup should have taken care of it. We only perform one
-  // work item for each 1d tensor on lengths
-  if (pos.x >= 1) {
-    return;
-  }
+layout(set = 0, binding = 9) uniform PRECISION restrict In_group_size {
+  int in_group_size;
+};
 
-  int c = pos.y;
-  if (c >= out_channels.data) {
-    return;
-  }
+layout(set = 0, binding = 10) uniform PRECISION restrict Out_group_size {
+  int out_group_size;
+};
 
-  // Assume n = 1, do not handle n > 1 case for now.
-  int n = pos.z;
-  if (n >= 1) {
-    return;
-  }
+layout(set = 0, binding = 11) uniform PRECISION restrict Batch_size {
+  int batch_size;
+};
+
+layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
 
-  vec4 bias = texelFetch(bias_in, ivec3(c, 0, 0), 0);
-
-  for (int i = 0; i < in_length.data - kernel_size.data + 1; ++i) {
-    vec4 v = vec4(0);
-    for (int k = 0; k < kernel_size.data; ++k) {
-      const ivec3 in_pos = ivec3(i+k, c, 0);
-      const vec4 input_value = texelFetch(image_in, in_pos, 0);
-
-      // Note that we are reading weight in the inner loop, this could be
-      // improved by moving it before the outer loop. Since the weight vector is
-      // contant for the entire call.
-
-      // weight in input-space: (c, 0, k);
-      // notice that c is 4-packed. We need to mod 4 to get the actual weight.
-      const ivec3 w_pos = ivec3(k, 0, c / 4);
-      const vec4 weight = texelFetch(kernel_in, w_pos, 0);
-
-      float w = weight.x;
-      if (c % 4 == 1) {
-        w = weight.y;
-      } else if (c % 4 == 2) {
-        w = weight.z;
-      } else if (c % 4 == 3) {
-        w = weight.w;
+// Let us define
+//
+// input = (N, in_C, in_L),
+// output = (N, out_C, out_L),
+// groups = G,
+// kernel = K,
+//
+// which results in shapes
+//
+// weight = (out_C, in_C / G, K),
+// bias = (out_C,).
+//
+// This implementation performs out_C shader invocations, where each invocation
+// calculates the rolling kernel of the length dimension for each batch, i.e.,
+// computes out_L * N results.
+//
+// Note that we can rewrite this implementation as out_L * out_C * ceil(N / 4)
+// shader invocations, where each invocation computes 1 result. But that
+// performs worse.
+void main() {
+  const ivec3 pos = ivec3(gl_GlobalInvocationID);
+
+  // "out_c" is the output's channel index where we write our result.
+  // Across shader invocations, this is the only value that varies.
+  int out_c = pos.y;
+  vec4 bias = texelFetch(bias_in, ivec3(out_c, 0, 0), 0);
+
+  // "in_c" tracks the input's channel start index.
+  // We iterate over the input group that corresponds to the output group.
+  int c_start = (out_c / out_group_size) * in_group_size;
+  int c_end = c_start + in_group_size;
+
+  // "in_l" tracks the input's length start index for our input-kernel overlay
+  // region.
+  int l_start = -padding;
+  int l_end = in_length + padding - dilation * (kernel_size - 1);
+
+  // Since the input/output tensors are channel-packed, which is along the
+  // batch dimension, we can batch-read/write four elements at a time.
+  for (int n = 0; n < batch_size; n += 4) {
+    // "out_l" tracks the output's length index where we write our result.
+    int out_l = 0;
+
+    for (int in_l = l_start; in_l < l_end; in_l += stride, ++out_l) {
+      vec4 sum = vec4(0);
+
+      for (int in_c = c_start; in_c < c_end; ++in_c) {
+        // "k" tracks the kernel's index for our input-kernel computation.
+        // It reads out-of-bound zeros, but trying to avoid them complicates
+        // for-loop conditions, which results in worse performance.
+        for (int k = 0; k < kernel_size; k += 4) {
+          // Since the weight tensor is width-packed, which is along the length
+          // dimension, we can batch-read four elements at a time.
+          const ivec3 w_pos = ivec3(k / 4, in_c % in_group_size, out_c);
+          const vec4 weight = texelFetch(kernel_in, w_pos, 0);
+
+          const ivec3 in_pos_0 = ivec3(in_l + k * dilation, in_c, n / 4);
+          sum = fma(weight.xxxx, texelFetch(image_in, in_pos_0, 0), sum);
+
+          const ivec3 in_pos_1 = ivec3(in_l + (k+1) * dilation, in_c, n / 4);
+          sum = fma(weight.yyyy, texelFetch(image_in, in_pos_1, 0), sum);
+
+          const ivec3 in_pos_2 = ivec3(in_l + (k+2) * dilation, in_c, n / 4);
+          sum = fma(weight.zzzz, texelFetch(image_in, in_pos_2, 0), sum);
+
+          const ivec3 in_pos_3 = ivec3(in_l + (k+3) * dilation, in_c, n / 4);
+          sum = fma(weight.wwww, texelFetch(image_in, in_pos_3, 0), sum);
+        }
       }
 
-      v += w * input_value.x;
+      ivec3 out_pos = ivec3(out_l, out_c, n / 4);
+      imageStore(image_out, out_pos, sum + bias.x);
     }
-
-    ivec3 out_pos = ivec3(i, c, 0);
-    imageStore(image_out, out_pos, vec4(v.x + bias.x, 0, 0, 0));
   }
 }
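
For reference, the following is a minimal CPU sketch in C of the grouped, dilated conv1d that the new shader computes. It is not part of the change: conv1d_ref, in_at, and the concrete sizes are illustrative assumptions, and it produces one scalar per output element instead of the shader's 4-wide texel reads. The zero-padded in_at read mirrors the shader's reliance on out-of-bound fetches returning zero, and OUT_L equals the iteration count of the shader's length loop (in_l from -padding while in_l < in_length + padding - dilation * (kernel_size - 1), stepping by stride).

#include <stdio.h>

/* Illustrative sizes (assumptions, not values used by the shader). */
enum { N = 1, IN_C = 4, OUT_C = 4, G = 2, IN_L = 8, K = 3,
       STRIDE = 1, PADDING = 1, DILATION = 1 };

/* Standard conv1d output length; equals the number of iterations of the
 * shader's length loop. */
enum { OUT_L = (IN_L + 2 * PADDING - DILATION * (K - 1) - 1) / STRIDE + 1 };

/* Zero-padded read of input[n][c][l]; the shader instead relies on
 * out-of-bound texel fetches returning zero. */
static float in_at(float in[N][IN_C][IN_L], int n, int c, int l) {
  return (l < 0 || l >= IN_L) ? 0.0f : in[n][c][l];
}

static void conv1d_ref(float in[N][IN_C][IN_L],
                       float w[OUT_C][IN_C / G][K],
                       float b[OUT_C],
                       float out[N][OUT_C][OUT_L]) {
  const int in_group = IN_C / G, out_group = OUT_C / G;
  for (int n = 0; n < N; ++n) {
    for (int out_c = 0; out_c < OUT_C; ++out_c) {
      /* Group mapping, as in the shader:
       * c_start = (out_c / out_group_size) * in_group_size. */
      const int c_start = (out_c / out_group) * in_group;
      for (int out_l = 0; out_l < OUT_L; ++out_l) {
        /* in_l = l_start + out_l * stride, with l_start = -padding. */
        const int in_l = out_l * STRIDE - PADDING;
        float sum = b[out_c];
        for (int in_c = 0; in_c < in_group; ++in_c) {
          for (int k = 0; k < K; ++k) {
            sum += w[out_c][in_c][k] *
                   in_at(in, n, c_start + in_c, in_l + k * DILATION);
          }
        }
        out[n][out_c][out_l] = sum;
      }
    }
  }
}

int main(void) {
  static float in[N][IN_C][IN_L], w[OUT_C][IN_C / G][K], b[OUT_C];
  static float out[N][OUT_C][OUT_L];
  in[0][0][1] = 1.0f;
  w[0][0][0] = 2.0f;
  b[0] = 0.5f;
  conv1d_ref(in, w, b, out);
  /* With PADDING = 1, output position 2 overlays kernel tap 0 on input
   * position 1: expect 2 * 1 + 0.5 = 2.5. */
  printf("out_L = %d, out[0][0][2] = %.1f\n", OUT_L, out[0][0][2]);
  return 0;
}

A scalar reference like this checks the group mapping and the padding/stride/dilation indexing without involving the texture packing, which is where most of the shader's remaining complexity lives.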