Skip to content

Commit 5202b43

Browse files
authored
Minor changes to native layer norm shader op to improve perf. (#10585)
Summary: This diff improves perf by changing native layer norm shader to accumulate result in local variable instead of shared memory, and do a shared memory pass later. Differential Revision: D73864950
1 parent 8ffdea1 commit 5202b43

File tree

1 file changed

+23
-26
lines changed

1 file changed

+23
-26
lines changed

backends/vulkan/runtime/graph/ops/glsl/native_layer_norm.glsl

Lines changed: 23 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -60,9 +60,9 @@ const lowp int out_packed_dim = unhash_packed_dim(out_layout);
6060
// First iteration of reduce will have 32 threads sum up 64 elements.
6161
// Second iteration will have 32 threads sum up 16 elements from previous iteration and so on.
6262
// Thus thread utilization starts at 100%.
63-
#define SHARED_MEMORY_FACTOR 2
63+
#define SHARED_MEMORY_FACTOR 1
6464

65-
#define offset_pos_index(index) ((index) + ((index) >> 2))
65+
#define offset_pos_index(index) ((index) + ((index) >> 3))
6666

6767
shared VEC4_T shared_input[offset_pos_index(MAX_WORKGROUP_SIZE * SHARED_MEMORY_FACTOR)];
6868

@@ -154,14 +154,13 @@ void reduce_non_packed_dim() {
154154
if (all(lessThan(in_pos, out_limits))) {
155155
in_val = load_texel(t_in, in_pos);
156156
}
157-
shared_input[offset_pos_index(shared_idx + si * gl_WorkGroupSize.x)] = in_val;
157+
mean += in_val;
158158
}
159-
160-
reduce_input(width_stride, shared_idx_offset);
161-
mean += shared_input[offset_pos_index(shared_idx_offset)];
162159
}
163160

164-
mean /= width;
161+
shared_input[offset_pos_index(shared_idx)] = mean;
162+
reduce_input(width_stride, shared_idx_offset);
163+
mean = shared_input[offset_pos_index(shared_idx_offset)] / width;
165164

166165
memoryBarrierShared();
167166
barrier();
@@ -178,14 +177,13 @@ void reduce_non_packed_dim() {
178177
}
179178

180179
const VEC4_T delta = in_val - mean;
181-
shared_input[offset_pos_index(shared_idx + si * gl_WorkGroupSize.x)] = delta * delta;
180+
var += delta * delta;
182181
}
183-
184-
reduce_input(width_stride, shared_idx_offset);
185-
var += shared_input[offset_pos_index(shared_idx_offset)];
186182
}
187183

188-
var /= width;
184+
shared_input[offset_pos_index(shared_idx)] = var;
185+
reduce_input(width_stride, shared_idx_offset);
186+
var = shared_input[offset_pos_index(shared_idx_offset)] / width;
189187

190188
VEC4_T rstd = pow(var + epsilon, VEC4_T(-0.5));
191189
VEC4_T offset = -rstd * mean;
@@ -226,6 +224,7 @@ void reduce_packed_dim() {
226224

227225
const int in_pos_x_limit = out_limits[in_axis_map.x];
228226

227+
VEC4_T accum = VEC4_T(0);
229228
// Loop over the width in stride increments
230229
for (int width_offset = 0; width_offset <= last_packed_width_index; width_offset += width_stride) {
231230
// Read input in shared memory
@@ -244,20 +243,20 @@ void reduce_packed_dim() {
244243
in_val.z = mix(in_val.z, T(0), remain_inv > 1);
245244
in_val.w = mix(in_val.w, T(0), remain_inv > 0);
246245
}
247-
248-
shared_input[offset_pos_index(shared_idx + si * gl_WorkGroupSize.x)] = in_val;
246+
accum += in_val;
249247
}
250-
251-
reduce_input(width_stride, shared_idx_offset);
252-
const VEC4_T val = shared_input[offset_pos_index(shared_idx_offset)];
253-
mean += val.x + val.y + val.z + val.w;
254248
}
255249

256-
mean /= width;
250+
shared_input[offset_pos_index(shared_idx)] = accum;
251+
reduce_input(width_stride, shared_idx_offset);
252+
VEC4_T val = shared_input[offset_pos_index(shared_idx_offset)];
253+
mean = (val.x + val.y + val.z + val.w) / width;
257254

258255
memoryBarrierShared();
259256
barrier();
260257

258+
VEC4_T delta2 = VEC4_T(0);
259+
261260
// Loop over the width in stride increments
262261
for (int width_offset = 0; width_offset <= last_packed_width_index; width_offset += width_stride) {
263262
// Read input in shared memory
@@ -278,16 +277,14 @@ void reduce_packed_dim() {
278277
}
279278

280279
const VEC4_T delta = in_val - mean;
281-
const VEC4_T delta2 = delta * delta;
282-
shared_input[offset_pos_index(shared_idx + si * gl_WorkGroupSize.x)] = delta2;
280+
delta2 += delta * delta;
283281
}
284-
285-
reduce_input(width_stride, shared_idx_offset);
286-
const VEC4_T val = shared_input[offset_pos_index(shared_idx_offset)];
287-
var += val.x + val.y + val.z + val.w;
288282
}
289283

290-
var /= width;
284+
shared_input[offset_pos_index(shared_idx)] = delta2;
285+
reduce_input(width_stride, shared_idx_offset);
286+
val = shared_input[offset_pos_index(shared_idx_offset)];
287+
var = (val.x + val.y + val.z + val.w) / width;
291288

292289
T rstd = pow(var + epsilon, T(-0.5));
293290
T offset = -rstd * mean;

0 commit comments

Comments
 (0)