@@ -60,9 +60,9 @@ const lowp int out_packed_dim = unhash_packed_dim(out_layout);
60
60
// First iteration of reduce will have 32 threads sum up 64 elements.
61
61
// Second iteration will have 32 threads sum up 16 elements from previous iteration and so on.
62
62
// Thus thread utilization starts at 100%.
// SHARED_MEMORY_FACTOR scales the number of logical slots in shared_input:
// the array below is sized from MAX_WORKGROUP_SIZE * SHARED_MEMORY_FACTOR.
63
- #define SHARED_MEMORY_FACTOR 2
63
+ #define SHARED_MEMORY_FACTOR 1
64
64
// offset_pos_index(index) maps a logical index to the slot actually used in
// shared_input by adding (index >> N). This leaves one unused slot after every
// 2^N consecutive logical indices (N = 2 on the removed line, N = 3 on the
// added line).
// NOTE(review): presumably this padding exists to spread accesses across
// shared-memory banks -- confirm against the access pattern in reduce_input.
65
- #define offset_pos_index(index) ((index) + ((index) >> 2 ))
65
+ #define offset_pos_index(index) ((index) + ((index) >> 3 ))
66
66
// Scratch array declared with the `shared` qualifier. Its length is the padded
// position of (MAX_WORKGROUP_SIZE * SHARED_MEMORY_FACTOR), and every access to
// it visible in this diff goes through offset_pos_index().
67
67
shared VEC4_T shared_input[offset_pos_index(MAX_WORKGROUP_SIZE * SHARED_MEMORY_FACTOR)];
68
68
@@ -154,14 +154,13 @@ void reduce_non_packed_dim() {
154
154
if (all (lessThan (in_pos, out_limits))) {
155
155
in_val = load_texel(t_in, in_pos);
156
156
}
157
- shared_input[offset_pos_index(shared_idx + si * gl_WorkGroupSize.x)] = in_val;
157
+ mean + = in_val;
158
158
}
159
-
160
- reduce_input(width_stride, shared_idx_offset);
161
- mean += shared_input[offset_pos_index(shared_idx_offset)];
162
159
}
163
160
164
- mean /= width;
161
+ shared_input[offset_pos_index(shared_idx)] = mean;
162
+ reduce_input(width_stride, shared_idx_offset);
163
+ mean = shared_input[offset_pos_index(shared_idx_offset)] / width;
165
164
166
165
memoryBarrierShared();
167
166
barrier();
@@ -178,14 +177,13 @@ void reduce_non_packed_dim() {
178
177
}
179
178
180
179
const VEC4_T delta = in_val - mean;
181
- shared_input[offset_pos_index(shared_idx + si * gl_WorkGroupSize.x)] = delta * delta;
180
+ var + = delta * delta;
182
181
}
183
-
184
- reduce_input(width_stride, shared_idx_offset);
185
- var += shared_input[offset_pos_index(shared_idx_offset)];
186
182
}
187
183
188
- var /= width;
184
+ shared_input[offset_pos_index(shared_idx)] = var;
185
+ reduce_input(width_stride, shared_idx_offset);
186
+ var = shared_input[offset_pos_index(shared_idx_offset)] / width;
189
187
190
188
VEC4_T rstd = pow (var + epsilon, VEC4_T(- 0.5 ));
191
189
VEC4_T offset = - rstd * mean;
@@ -226,6 +224,7 @@ void reduce_packed_dim() {
226
224
227
225
const int in_pos_x_limit = out_limits[in_axis_map.x];
228
226
227
+ VEC4_T accum = VEC4_T(0 );
229
228
// Loop over the width in stride increments
230
229
for (int width_offset = 0 ; width_offset <= last_packed_width_index; width_offset += width_stride) {
231
230
// Read input and add it to accum; shared memory is written once, after the loop
@@ -244,20 +243,20 @@ void reduce_packed_dim() {
244
243
in_val.z = mix (in_val.z, T(0 ), remain_inv > 1 );
245
244
in_val.w = mix (in_val.w, T(0 ), remain_inv > 0 );
246
245
}
247
-
248
- shared_input[offset_pos_index(shared_idx + si * gl_WorkGroupSize.x)] = in_val;
246
+ accum += in_val;
249
247
}
250
-
251
- reduce_input(width_stride, shared_idx_offset);
252
- const VEC4_T val = shared_input[offset_pos_index(shared_idx_offset)];
253
- mean += val.x + val.y + val.z + val.w;
254
248
}
255
249
256
- mean /= width;
250
+ shared_input[offset_pos_index(shared_idx)] = accum;
251
+ reduce_input(width_stride, shared_idx_offset);
252
+ VEC4_T val = shared_input[offset_pos_index(shared_idx_offset)];
253
+ mean = (val.x + val.y + val.z + val.w) / width;
257
254
258
255
memoryBarrierShared();
259
256
barrier();
260
257
258
+ VEC4_T delta2 = VEC4_T(0 );
259
+
261
260
// Loop over the width in stride increments
262
261
for (int width_offset = 0 ; width_offset <= last_packed_width_index; width_offset += width_stride) {
263
262
// Read input and add its squared deviation from mean to delta2; shared memory is written once, after the loop
@@ -278,16 +277,14 @@ void reduce_packed_dim() {
278
277
}
279
278
280
279
const VEC4_T delta = in_val - mean;
281
- const VEC4_T delta2 = delta * delta;
282
- shared_input[offset_pos_index(shared_idx + si * gl_WorkGroupSize.x)] = delta2;
280
+ delta2 += delta * delta;
283
281
}
284
-
285
- reduce_input(width_stride, shared_idx_offset);
286
- const VEC4_T val = shared_input[offset_pos_index(shared_idx_offset)];
287
- var += val.x + val.y + val.z + val.w;
288
282
}
289
283
290
- var /= width;
284
+ shared_input[offset_pos_index(shared_idx)] = delta2;
285
+ reduce_input(width_stride, shared_idx_offset);
286
+ val = shared_input[offset_pos_index(shared_idx_offset)];
287
+ var = (val.x + val.y + val.z + val.w) / width;
291
288
292
289
T rstd = pow (var + epsilon, T(- 0.5 ));
293
290
T offset = - rstd * mean;
0 commit comments