Skip to content

Commit 735f16e

Browse files
authored
update batch norm to use layout gen
Differential Revision: D69937208 Pull Request resolved: #8600
1 parent a454be5 commit 735f16e

File tree

2 files changed

+17
-22
lines changed

2 files changed

+17
-22
lines changed

backends/vulkan/runtime/graph/ops/glsl/batchnorm.glsl

Lines changed: 16 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -13,24 +13,18 @@
1313

1414
layout(std430) buffer;
1515

16-
layout(set = 0, binding = 0, ${IMAGE_FORMAT[DTYPE]}) uniform PRECISION restrict writeonly ${IMAGE_T[NDIM][DTYPE]} image_out;
17-
layout(set = 0, binding = 1) uniform PRECISION sampler3D image_in;
18-
layout(set = 0, binding = 2) uniform PRECISION sampler3D weight_in;
19-
layout(set = 0, binding = 3) uniform PRECISION sampler3D bias_in;
20-
layout(set = 0, binding = 4) uniform PRECISION sampler3D mean_in;
21-
layout(set = 0, binding = 5) uniform PRECISION sampler3D var_in;
16+
#include "indexing_utils.h"
2217

23-
layout(set = 0, binding = 6) uniform PRECISION restrict OutLimits {
24-
ivec3 out_limits;
25-
};
18+
${layout_declare_tensor(B, "w", "t_out", DTYPE, STORAGE)}
19+
${layout_declare_tensor(B, "r", "t_in", DTYPE, STORAGE)}
20+
${layout_declare_tensor(B, "r", "weight_in", DTYPE, STORAGE)}
21+
${layout_declare_tensor(B, "r", "bias_in", DTYPE, STORAGE)}
22+
${layout_declare_tensor(B, "r", "mean_in", DTYPE, STORAGE)}
23+
${layout_declare_tensor(B, "r", "var_in", DTYPE, STORAGE)}
2624

27-
layout(set = 0, binding = 7) uniform PRECISION restrict Params {
28-
float eps;
29-
};
30-
31-
layout(set = 0, binding = 8) uniform PRECISION restrict Params2 {
32-
int num_texel_per_batch;
33-
};
25+
${layout_declare_ubo(B, "ivec3", "out_limits")}
26+
${layout_declare_ubo(B, "float", "eps")}
27+
${layout_declare_ubo(B, "int", "num_texel_per_batch")}
3428

3529
layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
3630

@@ -40,16 +34,16 @@ void main() {
4034
return;
4135
}
4236

43-
VEC4_T v = VEC4_T(texelFetch(image_in, pos, 0));
37+
VEC4_T v = VEC4_T(load_texel(t_in, pos));
4438

4539
ivec3 param_pos = ivec3(pos.z % num_texel_per_batch, 0, 0);
4640

47-
VEC4_T weight = VEC4_T(texelFetch(weight_in, param_pos, 0));
48-
VEC4_T bias = VEC4_T(texelFetch(bias_in, param_pos, 0));
49-
VEC4_T mean = VEC4_T(texelFetch(mean_in, param_pos, 0));
50-
VEC4_T var = VEC4_T(texelFetch(var_in, param_pos, 0));
41+
VEC4_T weight = VEC4_T(load_texel(weight_in, param_pos));
42+
VEC4_T bias = VEC4_T(load_texel(bias_in, param_pos));
43+
VEC4_T mean = VEC4_T(load_texel(mean_in, param_pos));
44+
VEC4_T var = VEC4_T(load_texel(var_in, param_pos));
5145

5246
v = ((v - mean) / sqrt(var + eps)) * weight + bias;
5347

54-
imageStore(image_out, pos, v);
48+
write_texel(t_out, pos, v);
5549
}

backends/vulkan/runtime/graph/ops/glsl/batchnorm.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ batchnorm:
22
parameter_names_with_default_values:
33
DTYPE: float
44
NDIM: 3
5+
STORAGE: texture3d
56
generate_variant_forall:
67
DTYPE:
78
- VALUE: half

0 commit comments

Comments
 (0)