@@ -1288,50 +1288,45 @@ kernel void kernel_norm(
1288
1288
}
1289
1289
1290
1290
kernel void kernel_rms_norm (
1291
- device const void * src0 ,
1292
- device float * dst ,
1293
- constant int64_t & ne00 ,
1294
- constant uint64_t & nb01 ,
1295
- constant float & eps ,
1296
- threadgroup float * buf [[threadgroup( 0 ) ]],
1297
- uint tgpig[[threadgroup_position_in_grid ]],
1298
- uint tpitg[[thread_position_in_threadgroup ]],
1299
- uint sgitg[[simdgroup_index_in_threadgroup]],
1300
- uint tiisg[[thread_index_in_simdgroup]],
1301
- uint ntg[[threads_per_threadgroup]]) {
1302
- device const float4 * x = (device const float4 *) ((device const char *) src0 + tgpig*nb01);
1291
+ constant ggml_metal_kargs_rms_norm & args ,
1292
+ device const char * src0 ,
1293
+ device char * dst ,
1294
+ threadgroup float * shmem_f32 [[threadgroup( 0 )]] ,
1295
+ uint tgpig[[threadgroup_position_in_grid]] ,
1296
+ ushort tpitg[[thread_position_in_threadgroup ]],
1297
+ ushort sgitg[[simdgroup_index_in_threadgroup ]],
1298
+ ushort tiisg[[thread_index_in_simdgroup ]],
1299
+ ushort ntg[[threads_per_threadgroup]]) {
1300
+ if (sgitg == 0 ) {
1301
+ shmem_f32[tiisg] = 0 . 0f ;
1302
+ }
1303
1303
1304
- float4 sumf = 0 ;
1305
- float all_sum = 0 ;
1304
+ device const float4 * x = (device const float4 *) (src0 + tgpig*args.nb01 );
1305
+
1306
+ float sumf = 0 .0f ;
1306
1307
1307
1308
// parallel sum
1308
- for (int i00 = tpitg; i00 < ne00/ 4 ; i00 += ntg) {
1309
- sumf += x[i00] * x[i00];
1309
+ for (int i00 = tpitg; i00 < args. ne00_4 ; i00 += ntg) {
1310
+ sumf += dot ( x[i00], x[i00]) ;
1310
1311
}
1311
- all_sum = sumf[0 ] + sumf[1 ] + sumf[2 ] + sumf[3 ];
1312
- all_sum = simd_sum (all_sum);
1313
- if (ntg > N_SIMDWIDTH) {
1314
- if (sgitg == 0 ) {
1315
- buf[tiisg] = 0 .0f ;
1316
- }
1312
+ sumf = simd_sum (sumf);
1317
1313
1318
- threadgroup_barrier (mem_flags::mem_threadgroup);
1314
+ threadgroup_barrier (mem_flags::mem_threadgroup);
1319
1315
1320
- if (tiisg == 0 ) {
1321
- buf [sgitg] = all_sum ;
1322
- }
1316
+ if (tiisg == 0 ) {
1317
+ shmem_f32 [sgitg] = sumf ;
1318
+ }
1323
1319
1324
- threadgroup_barrier (mem_flags::mem_threadgroup);
1320
+ threadgroup_barrier (mem_flags::mem_threadgroup);
1325
1321
1326
- all_sum = buf[tiisg];
1327
- all_sum = simd_sum (all_sum);
1328
- }
1322
+ sumf = shmem_f32[tiisg];
1323
+ sumf = simd_sum (sumf);
1329
1324
1330
- const float mean = all_sum/ ne00;
1331
- const float scale = 1 .0f /sqrt (mean + eps);
1325
+ const float mean = sumf/args. ne00 ;
1326
+ const float scale = 1 .0f /sqrt (mean + args. eps );
1332
1327
1333
- device float4 * y = (device float4 *) ( dst + tgpig*ne00) ;
1334
- for (int i00 = tpitg; i00 < ne00/ 4 ; i00 += ntg) {
1328
+ device float4 * y = (device float4 *) dst + tgpig*args. ne00_4 ;
1329
+ for (int i00 = tpitg; i00 < args. ne00_4 ; i00 += ntg) {
1335
1330
y[i00] = x[i00] * scale;
1336
1331
}
1337
1332
}
0 commit comments