20
20
21
21
#define BATCH_SIZE_Y ${BATCH_SIZE_Y}
22
22
23
+ #define LOCAL_WG_SIZE 64
24
+
23
25
#define op(X, A, B) ${OPERATOR}
24
26
25
27
#include "indexing_utils.h"
@@ -38,6 +40,11 @@ ${layout_declare_ubo(8, "float", "out_min", "float", "out_max")}
38
40
39
41
layout (local_size_x_id = 0 , local_size_y_id = 1 , local_size_z_id = 2 ) in ;
40
42
43
+ // For performance improvement, reduce register usage by caching positions in shared memory.
44
+ // Offset index by 1 every 16 points to avoid bank access conflict.
45
+ #define offset_pos_index(index) (index + ((index) >> 4 ))
46
+ shared ivec3 pos_shared[offset_pos_index(LOCAL_WG_SIZE)];
47
+
41
48
/*
42
49
* Computes a depthwise convolution. Each shader invocation calculates the
43
50
* output at a single output location.
@@ -63,6 +70,8 @@ void main() {
63
70
return ;
64
71
}
65
72
73
+ pos_shared[offset_pos_index(gl_LocalInvocationIndex)] = pos;
74
+
66
75
// Compute the index of the top-left element of the overlay region. Negative
67
76
// indices indicate that the top-left element is in a region added by padding.
68
77
const ivec2 ipos = pos.xy * stride - padding;
@@ -109,18 +118,19 @@ void main() {
109
118
for (int j = 0 ; j < TILE_SIZE; j++ , kx++ ) {
110
119
prev_kernel_line[j] = texelFetch(t_kernel, ivec2 (kx, pos.z), 0 );
111
120
for (int s = 0 ; s < BATCH_SIZE_X; s++ ) {
112
- sum[0 ][s] = fma(in_texels[j + s], prev_kernel_line[j], sum[0 ][s]);
121
+ sum[0 ][s] = fma(in_texels[j + s], prev_kernel_line[j], sum[0 ][s]);
113
122
}
114
123
}
115
124
}
116
125
}
117
126
127
+ const ivec3 out_pos = pos_shared[offset_pos_index(gl_LocalInvocationIndex)];
118
128
for (int y = 0 ; y < BATCH_SIZE_Y; y++ ) {
119
129
for (int x = 0 ; x < BATCH_SIZE_X; x++ ) {
120
- if (any (greaterThanEqual (ivec3 (pos .x + x, pos .y + y, pos .z), out_limits))) {
130
+ if (any (greaterThanEqual (ivec3 (out_pos .x + x, out_pos .y + y, out_pos .z), out_limits))) {
121
131
continue ;
122
132
}
123
- imageStore(t_out, ivec3 (pos .x + x, pos .y + y, pos .z), op(sum[y][x], out_min, out_max));
133
+ imageStore(t_out, ivec3 (out_pos .x + x, out_pos .y + y, out_pos .z), op(sum[y][x], out_min, out_max));
124
134
}
125
135
}
126
136
}
0 commit comments