@@ -1236,53 +1236,68 @@ kernel void kernel_ssm_scan_f32(
1236
1236
}
1237
1237
1238
1238
kernel void kernel_norm (
1239
- device const void * src0,
1240
- device float * dst,
1241
- constant int64_t & ne00,
1242
- constant uint64_t & nb01,
1243
- constant float & eps,
1244
- threadgroup float * sum [[threadgroup(0 )]],
1245
- uint tgpig[[threadgroup_position_in_grid]],
1246
- uint tpitg[[thread_position_in_threadgroup]],
1247
- uint ntg[[threads_per_threadgroup]]) {
1248
- device const float * x = (device const float *) ((device const char *) src0 + tgpig*nb01);
1249
- // MEAN
1250
- // parallel sum
1251
- sum[tpitg] = 0 .0f ;
1252
- for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
1253
- sum[tpitg] += x[i00];
1239
+ constant ggml_metal_kargs_norm & args,
1240
+ device const char * src0,
1241
+ device char * dst,
1242
+ threadgroup float * shmem_f32 [[threadgroup(0 )]],
1243
+ uint tgpig[[threadgroup_position_in_grid]],
1244
+ ushort tpitg[[thread_position_in_threadgroup]],
1245
+ ushort sgitg[[simdgroup_index_in_threadgroup]],
1246
+ ushort tiisg[[thread_index_in_simdgroup]],
1247
+ ushort ntg[[threads_per_threadgroup]]) {
1248
+ if (sgitg == 0 ) {
1249
+ shmem_f32[tiisg] = 0 .0f ;
1254
1250
}
1255
- // reduce
1251
+
1252
+ device const float4 * x = (device const float4 *) (src0 + tgpig*args.nb01 );
1253
+
1254
+ float4 sumf4 (0 .0f );
1255
+
1256
+ float sumf = 0 .0f ;
1257
+
1258
+ for (int i00 = tpitg; i00 < args.ne00_4 ; i00 += ntg) {
1259
+ sumf4 += x[i00];
1260
+ }
1261
+ sumf = sumf4[0 ] + sumf4[1 ] + sumf4[2 ] + sumf4[3 ];
1262
+ sumf = simd_sum (sumf);
1263
+
1256
1264
threadgroup_barrier (mem_flags::mem_threadgroup);
1257
- for (uint i = ntg/2 ; i > 0 ; i /= 2 ) {
1258
- if (tpitg < i) {
1259
- sum[tpitg] += sum[tpitg + i];
1260
- }
1261
- threadgroup_barrier (mem_flags::mem_threadgroup);
1265
+
1266
+ if (tiisg == 0 ) {
1267
+ shmem_f32[sgitg] = sumf;
1262
1268
}
1263
- const float mean = sum[0 ] / ne00;
1264
1269
1265
- // recenter and VARIANCE
1266
1270
threadgroup_barrier (mem_flags::mem_threadgroup);
1267
- device float * y = dst + tgpig*ne00;
1268
- sum[tpitg] = 0 .0f ;
1269
- for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
1271
+
1272
+ sumf = shmem_f32[tiisg];
1273
+ sumf = simd_sum (sumf);
1274
+
1275
+ const float mean = sumf/args.ne00 ;
1276
+
1277
+ device float4 * y = (device float4 *) dst + tgpig*args.ne00_4 ;
1278
+
1279
+ sumf = 0 .0f ;
1280
+ for (int i00 = tpitg; i00 < args.ne00_4 ; i00 += ntg) {
1270
1281
y[i00] = x[i00] - mean;
1271
- sum[tpitg] += y[i00] * y[i00];
1282
+ sumf += dot ( y[i00], y[i00]) ;
1272
1283
}
1284
+ sumf = simd_sum (sumf);
1273
1285
1274
- // reduce
1275
1286
threadgroup_barrier (mem_flags::mem_threadgroup);
1276
- for (uint i = ntg/2 ; i > 0 ; i /= 2 ) {
1277
- if (tpitg < i) {
1278
- sum[tpitg] += sum[tpitg + i];
1279
- }
1280
- threadgroup_barrier (mem_flags::mem_threadgroup);
1287
+
1288
+ if (tiisg == 0 ) {
1289
+ shmem_f32[sgitg] = sumf;
1281
1290
}
1282
- const float variance = sum[0 ] / ne00;
1283
1291
1284
- const float scale = 1 .0f /sqrt (variance + eps);
1285
- for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
1292
+ threadgroup_barrier (mem_flags::mem_threadgroup);
1293
+
1294
+ sumf = shmem_f32[tiisg];
1295
+ sumf = simd_sum (sumf);
1296
+
1297
+ const float variance = sumf/args.ne00 ;
1298
+
1299
+ const float scale = 1 .0f /sqrt (variance + args.eps );
1300
+ for (int i00 = tpitg; i00 < args.ne00_4 ; i00 += ntg) {
1286
1301
y[i00] = y[i00] * scale;
1287
1302
}
1288
1303
}
0 commit comments