@@ -111,7 +111,7 @@ kernel void kernel_soft_max(
111
111
uint3 tgpig[[threadgroup_position_in_grid]],
112
112
uint3 tpitg[[thread_position_in_threadgroup]],
113
113
uint3 ntg[[threads_per_threadgroup]]) {
114
- const int64_t i03 = tgpig[2 ];
114
+ const int64_t i03 = tgpig[2 ];
115
115
const int64_t i02 = tgpig[1 ];
116
116
const int64_t i01 = tgpig[0 ];
117
117
@@ -220,27 +220,26 @@ kernel void kernel_norm(
220
220
}
221
221
threadgroup_barrier (mem_flags::mem_threadgroup);
222
222
}
223
- // // broadcast
224
- // if (tpitg == 0) {
225
- // sum[0] /= ne00;
226
- // }
227
- // threadgroup_barrier(mem_flags::mem_threadgroup);
223
+ // broadcast
224
+ if (tpitg == 0 ) {
225
+ sum[0 ] /= ne00;
226
+ }
227
+ threadgroup_barrier (mem_flags::mem_threadgroup);
228
228
const float mean = sum[0 ];
229
229
230
- // recenter and VARIANCE
230
+ // recenter
231
231
device float * y = dst + tgpig*ne00;
232
- sum[tpitg] = 0 .0f ;
233
232
for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
234
233
y[i00] = x[i00] - mean;
234
+ }
235
+
236
+ // VARIANCE
237
+ // parallel sum
238
+ sum[tpitg] = 0 .0f ;
239
+ for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
235
240
sum[tpitg] += y[i00] * y[i00];
236
241
}
237
242
238
- // // VARIANCE
239
- // // parallel sum
240
- // sum[tpitg] = 0.0f;
241
- // for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
242
- // sum[tpitg] += y[i00] * y[i00];
243
- // }
244
243
// reduce
245
244
threadgroup_barrier (mem_flags::mem_threadgroup);
246
245
for (uint i = ntg/2 ; i > 0 ; i /= 2 ) {
@@ -249,11 +248,11 @@ kernel void kernel_norm(
249
248
}
250
249
threadgroup_barrier (mem_flags::mem_threadgroup);
251
250
}
252
- // // broadcast
253
- // if (tpitg == 0) {
254
- // sum[0] /= ne00;
255
- // }
256
- // threadgroup_barrier(mem_flags::mem_threadgroup);
251
+ // broadcast
252
+ if (tpitg == 0 ) {
253
+ sum[0 ] /= ne00;
254
+ }
255
+ threadgroup_barrier (mem_flags::mem_threadgroup);
257
256
const float variance = sum[0 ];
258
257
259
258
const float scale = 1 .0f /sqrt (variance + eps);
@@ -262,7 +261,6 @@ kernel void kernel_norm(
262
261
}
263
262
}
264
263
265
-
266
264
kernel void kernel_rms_norm (
267
265
device const void * src0,
268
266
device float * dst,
@@ -630,7 +628,6 @@ kernel void kernel_mul_mat_f16_f32(
630
628
}
631
629
}
632
630
}
633
-
634
631
}
635
632
636
633
kernel void kernel_alibi_f32 (
0 commit comments