13
13
14
14
layout (std430) buffer ;
15
15
16
- layout (set = 0 , binding = 0 , ${IMAGE_FORMAT[DTYPE]}) uniform PRECISION restrict writeonly ${IMAGE_T[NDIM][DTYPE]} image_out;
17
- layout (set = 0 , binding = 1 ) uniform PRECISION sampler3D image_in;
18
- layout (set = 0 , binding = 2 ) uniform PRECISION sampler3D weight_in;
19
- layout (set = 0 , binding = 3 ) uniform PRECISION sampler3D bias_in;
20
- layout (set = 0 , binding = 4 ) uniform PRECISION sampler3D mean_in;
21
- layout (set = 0 , binding = 5 ) uniform PRECISION sampler3D var_in;
16
+ #include "indexing_utils.h"
22
17
23
- layout (set = 0 , binding = 6 ) uniform PRECISION restrict OutLimits {
24
- ivec3 out_limits;
25
- };
18
+ ${layout_declare_tensor(B, "w", "t_out", DTYPE, STORAGE)}
19
+ ${layout_declare_tensor(B, "r", "t_in", DTYPE, STORAGE)}
20
+ ${layout_declare_tensor(B, "r", "weight_in", DTYPE, STORAGE)}
21
+ ${layout_declare_tensor(B, "r", "bias_in", DTYPE, STORAGE)}
22
+ ${layout_declare_tensor(B, "r", "mean_in", DTYPE, STORAGE)}
23
+ ${layout_declare_tensor(B, "r", "var_in", DTYPE, STORAGE)}
26
24
27
- layout (set = 0 , binding = 7 ) uniform PRECISION restrict Params {
28
- float eps;
29
- };
30
-
31
- layout (set = 0 , binding = 8 ) uniform PRECISION restrict Params2 {
32
- int num_texel_per_batch;
33
- };
25
+ ${layout_declare_ubo(B, "ivec3 ", "out_limits")}
26
+ ${layout_declare_ubo(B, "float ", "eps")}
27
+ ${layout_declare_ubo(B, "int ", "num_texel_per_batch")}
34
28
35
29
layout (local_size_x_id = 0 , local_size_y_id = 1 , local_size_z_id = 2 ) in ;
36
30
@@ -40,16 +34,16 @@ void main() {
40
34
return ;
41
35
}
42
36
43
- VEC4_T v = VEC4_T(texelFetch(image_in , pos, 0 ));
37
+ VEC4_T v = VEC4_T(load_texel(t_in , pos));
44
38
45
39
ivec3 param_pos = ivec3 (pos.z % num_texel_per_batch, 0 , 0 );
46
40
47
- VEC4_T weight = VEC4_T(texelFetch (weight_in, param_pos, 0 ));
48
- VEC4_T bias = VEC4_T(texelFetch (bias_in, param_pos, 0 ));
49
- VEC4_T mean = VEC4_T(texelFetch (mean_in, param_pos, 0 ));
50
- VEC4_T var = VEC4_T(texelFetch (var_in, param_pos, 0 ));
41
+ VEC4_T weight = VEC4_T(load_texel (weight_in, param_pos));
42
+ VEC4_T bias = VEC4_T(load_texel (bias_in, param_pos));
43
+ VEC4_T mean = VEC4_T(load_texel (mean_in, param_pos));
44
+ VEC4_T var = VEC4_T(load_texel (var_in, param_pos));
51
45
52
46
v = ((v - mean) / sqrt (var + eps)) * weight + bias;
53
47
54
- imageStore(image_out , pos, v);
48
+ write_texel(t_out , pos, v);
55
49
}
0 commit comments