|
14 | 14 |
|
15 | 15 | #define TILE_SIZE ${TILE_SIZE}
|
16 | 16 |
|
| 17 | +#define BATCH_SIZE_X ${BATCH_SIZE_X} |
| 18 | + |
17 | 19 | #define BATCH_SIZE_Y ${BATCH_SIZE_Y}
|
18 | 20 |
|
19 | 21 | #define op(X, A, B) ${OPERATOR}
|
@@ -41,70 +43,79 @@ layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
|
41 | 43 | * output at a single output location.
|
42 | 44 | */
|
43 | 45 | void main() {
|
44 |
| - // y divided up by batch size is used to determine 3d position |
| 46 | + // x and y are divided by batch size to determine 3d position |
45 | 47 | // since work size is calculated by x * ((y + B_Y - 1) / B_Y) * z
|
46 |
| - const int out_limits_y_scaled = (out_limits.y + BATCH_SIZE_Y - 1) / BATCH_SIZE_Y; |
| 48 | + const ivec2 out_limits_xy_scaled = (out_limits.xy + ivec2(BATCH_SIZE_X, BATCH_SIZE_Y) - 1) / ivec2(BATCH_SIZE_X, BATCH_SIZE_Y); |
47 | 49 |
|
48 |
| - u16vec3 pos = idx_to_u16pos_x_wise(gl_GlobalInvocationID.x, out_limits.x, out_limits_y_scaled); |
| 50 | + ivec3 pos = idx_to_ipos_x_wise(gl_GlobalInvocationID.x, out_limits_xy_scaled.x, out_limits_xy_scaled.y); |
49 | 51 |
|
50 |
| - // scale pos.y by batch size, because that's the top pixel to be processed |
51 |
| - pos.y *= uint16_t(BATCH_SIZE_Y); |
| 52 | + // scale pos.xy by batch sizes, because that's the top pixel to be processed |
| 53 | + pos.x *= BATCH_SIZE_X; |
| 54 | + pos.y *= BATCH_SIZE_Y; |
52 | 55 |
|
53 | 56 | // do not process if top pixel does not fit within the output range
|
54 |
| - if (any(greaterThanEqual(u16vec3(pos.x, pos.y, pos.z), out_limits))) { |
| 57 | + if (any(greaterThanEqual(pos, out_limits))) { |
55 | 58 | return;
|
56 | 59 | }
|
57 | 60 |
|
58 | 61 | // Compute the index of the top-left element of the overlay region. Negative
|
59 | 62 | // indices indicate that the top-left element is in a region added by padding.
|
60 |
| - const u16vec2 ipos = pos.xy * u16vec2(stride) - u16vec2(padding); |
| 63 | + const ivec2 ipos = pos.xy * stride - padding; |
61 | 64 |
|
62 | 65 | // Compute the start and end of the input indices to load. Padding is assumed
|
63 | 66 | // to be constant 0 padding, so any reads from the padding region is skipped.
|
64 |
| - const u16vec2 start = ipos; |
65 |
| - const u16vec2 end = ipos + u16vec2(overlay_region.xy); |
| 67 | + const ivec2 start = ipos; |
| 68 | + const ivec2 end = ipos + overlay_region.xy; |
66 | 69 |
|
67 | 70 | // sum outputs
|
68 |
| - VEC4_T sum[BATCH_SIZE_Y]; |
| 71 | + VEC4_T sum[BATCH_SIZE_Y][BATCH_SIZE_X]; |
69 | 72 |
|
70 |
| - sum[0] = texelFetch(t_bias, u16vec2(pos.z, 0), 0); |
71 |
| - for (int i = 1; i < BATCH_SIZE_Y; i++) { |
72 |
| - sum[i] = sum[0]; |
| 73 | + sum[0][0] = texelFetch(t_bias, ivec2(pos.z, 0), 0); |
| 74 | + for (int y = 0; y < BATCH_SIZE_Y; y++) { |
| 75 | + for (int x = 0; x < BATCH_SIZE_X; x++) { |
| 76 | + sum[y][x] = sum[0][0]; |
| 77 | + } |
73 | 78 | }
|
74 | 79 |
|
75 | 80 | // array to store input texels
|
76 |
| - VEC4_T in_texels[TILE_SIZE]; |
| 81 | + VEC4_T in_texels[TILE_SIZE + BATCH_SIZE_X - 1]; |
77 | 82 |
|
78 | 83 | // array to store kernel data of previous y
|
79 | 84 | VEC4_T prev_kernel_line[TILE_SIZE];
|
80 | 85 |
|
81 |
| - uint16_t kx = uint16_t(0); |
82 |
| - for (uint16_t y = start.y, i = uint16_t(0); i < uint16_t(TILE_SIZE + BATCH_SIZE_Y - 1); y += uint16_t(dilation.y), i++) { |
83 |
| - for (uint16_t x = start.x, j = uint16_t(0); j < uint16_t(TILE_SIZE); x += uint16_t(dilation.x), j++) { |
84 |
| - in_texels[int(j)] = texelFetch(t_in, u16vec3(x, y, pos.z), 0); |
| 86 | + int kx = 0; |
| 87 | + for (int y = start.y, i = 0; i < TILE_SIZE + BATCH_SIZE_Y - 1; y += dilation.y, i++) { |
| 88 | + for (int x = start.x, j = 0; j < TILE_SIZE + BATCH_SIZE_X - 1; x += dilation.x, j++) { |
| 89 | + in_texels[j] = texelFetch(t_in, ivec3(x, y, pos.z), 0); |
85 | 90 | }
|
86 | 91 |
|
87 | 92 | // from 2nd iteration onwards accumulate dot product in 2nd sum
|
88 | 93 | // based on kernel line data fetched in previous iteration and input texel from this iteration
|
89 |
| - if (i > uint16_t(0)) { |
90 |
| - for (uint16_t j = uint16_t(0); j < uint16_t(TILE_SIZE); j++) { |
91 |
| - sum[1] = fma(in_texels[int(j)], prev_kernel_line[int(j)], sum[1]); |
| 94 | + if (i > 0) { |
| 95 | + for (int j = 0; j < TILE_SIZE; j++) { |
| 96 | + for (int s = 0; s < BATCH_SIZE_X; s++) { |
| 97 | + sum[1][s] = fma(in_texels[j + s], prev_kernel_line[j], sum[1][s]); |
| 98 | + } |
92 | 99 | }
|
93 | 100 | }
|
94 | 101 |
|
95 | 102 | // accumulate dot product in 1st sum only until tile size
|
96 |
| - if (i < uint16_t(TILE_SIZE)) { |
97 |
| - for (uint16_t j = uint16_t(0); j < uint16_t(TILE_SIZE); j++, kx++) { |
98 |
| - prev_kernel_line[int(j)] = texelFetch(t_kernel, u16vec2(kx, pos.z), 0); |
99 |
| - sum[0] = fma(in_texels[int(j)], prev_kernel_line[int(j)], sum[0]); |
| 103 | + if (i < TILE_SIZE) { |
| 104 | + for (int j = 0; j < TILE_SIZE; j++, kx++) { |
| 105 | + prev_kernel_line[j] = texelFetch(t_kernel, ivec2(kx, pos.z), 0); |
| 106 | + for (int s = 0; s < BATCH_SIZE_X; s++) { |
| 107 | + sum[0][s] = fma(in_texels[j + s], prev_kernel_line[j], sum[0][s]); |
| 108 | + } |
100 | 109 | }
|
101 | 110 | }
|
102 | 111 | }
|
103 | 112 |
|
104 |
| - for (int i = 0; i < BATCH_SIZE_Y; i++) { |
105 |
| - if (any(greaterThanEqual(u16vec3(pos.x, pos.y + i, pos.z), out_limits))) { |
106 |
| - continue; |
| 113 | + for (int y = 0; y < BATCH_SIZE_Y; y++) { |
| 114 | + for (int x = 0; x < BATCH_SIZE_X; x++) { |
| 115 | + if (any(greaterThanEqual(ivec3(pos.x + x, pos.y + y, pos.z), out_limits))) { |
| 116 | + continue; |
| 117 | + } |
| 118 | + imageStore(t_out, ivec3(pos.x + x, pos.y + y, pos.z), op(sum[y][x], out_min, out_max)); |
107 | 119 | }
|
108 |
| - imageStore(t_out, u16vec3(pos.x, pos.y + i, pos.z), op(sum[i], out_min, out_max)); |
109 | 120 | }
|
110 | 121 | }
|
0 commit comments