18
18
19
19
layout (std430) buffer ;
20
20
21
- layout (set = 0 , binding = 0 , ${IMAGE_FORMAT[DTYPE]}) uniform PRECISION restrict writeonly ${IMAGE_T[NDIM][DTYPE]} image_out;
22
- layout (set = 0 , binding = 1 ) uniform PRECISION sampler3D image_in;
23
- layout (set = 0 , binding = 2 ) uniform PRECISION sampler3D kernel_in;
24
- layout (set = 0 , binding = 3 ) uniform PRECISION sampler3D bias_in;
25
-
26
- layout (set = 0 , binding = 4 ) uniform PRECISION restrict OutLimits {
27
- ivec3 out_limits;
28
- };
29
-
30
- layout (set = 0 , binding = 5 ) uniform PRECISION restrict InSizes {
31
- ivec4 in_sizes;
32
- };
33
-
34
- layout (set = 0 , binding = 6 ) uniform PRECISION restrict Params {
35
- int kernel_size;
36
- int stride;
37
- int padding;
38
- int dilation;
39
- int in_group_size;
40
- int out_group_size;
41
- };
42
-
43
- layout (set = 0 , binding = 7 ) uniform PRECISION restrict OutputParams {
44
- float out_min;
45
- float out_max;
46
- };
21
+ ${layout_declare_tensor(B, "w", "t_out", DTYPE, STORAGE)}
22
+ ${layout_declare_tensor(B, "r", "t_in", DTYPE, STORAGE)}
23
+ ${layout_declare_tensor(B, "r", "kernel_in", DTYPE, STORAGE)}
24
+ ${layout_declare_tensor(B, "r", "bias_in", DTYPE, STORAGE)}
25
+
26
+ ${layout_declare_ubo(B, "ivec3 ", "out_limits")}
27
+ ${layout_declare_ubo(B, "ivec4 ", "in_sizes")}
28
+
29
+ ${layout_declare_ubo(B, "ivec4 ", "out_axis_map")}
30
+ ${layout_declare_ubo(B, "ivec4 ", "in_axis_map")}
31
+ ${layout_declare_ubo(B, "ivec4 ", "kernel_axis_map")}
32
+ ${layout_declare_ubo(B, "ivec4 ", "bias_axis_map")}
33
+
34
+ ${layout_declare_ubo(B,"int ", "kernel_size", "int ", "stride", "int ", "padding", "int ", "dilation", "int ", "in_group_size", "int ", "out_group_size")}
35
+
36
+ ${layout_declare_ubo(B, "float ", "out_min", "float ", "out_max")}
47
37
48
38
layout (local_size_x_id = 0 , local_size_y_id = 1 , local_size_z_id = 2 ) in ;
49
39
@@ -67,9 +57,9 @@ layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
67
57
// shader invocations, where each invocation computes 1 result. But that
68
58
// performs worse.
69
59
void main() {
70
- const ivec3 pos = ivec3 (gl_GlobalInvocationID);
60
+ const ivec3 lpos = ivec3 (gl_GlobalInvocationID);
71
61
72
- if (any (greaterThanEqual (pos , out_limits))) {
62
+ if (any (greaterThanEqual (lpos , out_limits))) {
73
63
return ;
74
64
}
75
65
@@ -78,8 +68,8 @@ void main() {
78
68
79
69
// "out_c" is the output's channel index where we write our result.
80
70
// Across shader invocations, this is the only value that varies.
81
- int out_c = pos .y;
82
- vec4 bias = texelFetch (bias_in, ivec3 (out_c, 0 , 0 ), 0 );
71
+ int out_c = lpos .y;
72
+ VEC4_T bias = load_texel_lpos (bias_in, ivec3 (out_c, 0 , 0 ), bias_axis_map );
83
73
84
74
// "in_c" tracks the input's channel start index.
85
75
// We iterate over the input group that corresponds to the output group.
@@ -98,7 +88,7 @@ void main() {
98
88
int out_l = 0 ;
99
89
100
90
for (int in_l = l_start; in_l < l_end; in_l += stride, ++ out_l) {
101
- vec4 sum = vec4 (0 );
91
+ VEC4_T sum = VEC4_T (0 );
102
92
103
93
for (int in_c = c_start; in_c < c_end; ++ in_c) {
104
94
// "k" tracks the kernel's index for our input-kernel computation.
@@ -107,25 +97,25 @@ void main() {
107
97
for (int k = 0 ; k < kernel_size; k += 4 ) {
108
98
// Since the weight tensor is width-packed, which is along the length
109
99
// dimension, we can batch-read four elements at a time.
110
- const ivec3 w_pos = ivec3 (k / 4 , in_c % in_group_size, out_c);
111
- const vec4 weight = texelFetch (kernel_in, w_pos, 0 );
100
+ const ivec3 w_lpos = ivec3 (k / 4 , in_c % in_group_size, out_c);
101
+ const VEC4_T weight = load_texel_lpos (kernel_in, w_lpos, kernel_axis_map );
112
102
113
- const ivec3 in_pos_0 = ivec3 (in_l + k * dilation, in_c, n / 4 );
114
- sum = fma(weight.xxxx, texelFetch(image_in, in_pos_0, 0 ), sum);
103
+ ivec3 in_pos = lpos_to_pos( ivec3 (in_l + k * dilation, in_c, n / 4 ), in_axis_map );
104
+ sum = fma(weight.xxxx, load_texel(t_in, in_pos ), sum);
115
105
116
- const ivec3 in_pos_1 = ivec3 (in_l + (k + 1 ) * dilation, in_c, n / 4 ) ;
117
- sum = fma(weight.yyyy, texelFetch(image_in, in_pos_1, 0 ), sum);
106
+ in_pos[in_axis_map.x] += dilation;
107
+ sum = fma(weight.yyyy, load_texel(t_in, in_pos ), sum);
118
108
119
- const ivec3 in_pos_2 = ivec3 (in_l + (k + 2 ) * dilation, in_c, n / 4 ) ;
120
- sum = fma(weight.zzzz, texelFetch(image_in, in_pos_2, 0 ), sum);
109
+ in_pos[in_axis_map.x] += dilation;
110
+ sum = fma(weight.zzzz, load_texel(t_in, in_pos ), sum);
121
111
122
- const ivec3 in_pos_3 = ivec3 (in_l + (k + 3 ) * dilation, in_c, n / 4 ) ;
123
- sum = fma(weight.wwww, texelFetch(image_in, in_pos_3, 0 ), sum);
112
+ in_pos[in_axis_map.x] += dilation;
113
+ sum = fma(weight.wwww, load_texel(t_in, in_pos ), sum);
124
114
}
125
115
}
126
116
127
- ivec3 out_pos = ivec3 (out_l, out_c, n / 4 );
128
- imageStore(image_out, out_pos , op(sum + bias.x, out_min, out_max));
117
+ const ivec3 out_lpos = ivec3 (out_l, out_c, n / 4 );
118
+ write_texel_lpos(t_out, out_lpos , op(sum + bias.x, out_min, out_max), out_axis_map );
129
119
}
130
120
}
131
121
}
0 commit comments