@@ -256,6 +256,72 @@ kernel void kernel_get_rows_q4_1(
                        (device float *) ((device char *) dst + i*nb1), ne00);
 }
 
+kernel void kernel_norm(
+        device const void * src0,
+        device      float * dst,
+        constant  int64_t & ne00,
+        constant uint64_t & nb01,
+        constant    float & eps,
+        threadgroup float * sum [[threadgroup(0)]],
+        uint tgpig[[threadgroup_position_in_grid]],
+        uint tpitg[[thread_position_in_threadgroup]],
+        uint   ntg[[threads_per_threadgroup]]) {
+    device const float * x = (device const float *) ((device const char *) src0 + tgpig*nb01);
+    // MEAN
+    // parallel sum
+    sum[tpitg] = 0.0f;
+    for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
+        sum[tpitg] += x[i00];
+    }
+    // reduce
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+    for (uint i = ntg/2; i > 0; i /= 2) {
+        if (tpitg < i) {
+            sum[tpitg] += sum[tpitg + i];
+        }
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+    }
+    // broadcast
+    if (tpitg == 0) {
+        sum[0] /= ne00;
+    }
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+    const float mean = sum[0];
+
+    // recenter
+    device float * y = dst + tgpig*ne00;
+    for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
+        y[i00] = x[i00] - mean;
+    }
+
+    // VARIANCE
+    // parallel sum
+    sum[tpitg] = 0.0f;
+    for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
+        sum[tpitg] += y[i00] * y[i00];
+    }
+    // reduce
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+    for (uint i = ntg/2; i > 0; i /= 2) {
+        if (tpitg < i) {
+            sum[tpitg] += sum[tpitg + i];
+        }
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+    }
+    // broadcast
+    if (tpitg == 0) {
+        sum[0] /= ne00;
+    }
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+    const float variance = sum[0];
+
+    const float scale = 1.0f/sqrt(variance + eps);
+    for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
+        y[i00] = y[i00] * scale;
+    }
+}
+
+
 kernel void kernel_rms_norm(
         device const void * src0,
         device      float * dst,
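In the kernel_norm hunk above, each threadgroup normalizes one row: every thread accumulates a strided partial sum, a power-of-two tree reduction in threadgroup memory produces the mean (so ntg is assumed to be a power of two), the row is recentered into dst, the variance comes from the same reduction, and the row is scaled by 1/sqrt(variance + eps). For reference, the per-row math is equivalent to the following scalar sketch (standalone C++, illustrative only; the function name is not part of the diff):

#include <cmath>
#include <cstdint>

// CPU reference for what one threadgroup of kernel_norm computes on a single
// contiguous row of ne00 floats (hypothetical helper).
static void norm_row_ref(const float * x, float * y, int64_t ne00, float eps) {
    float mean = 0.0f;
    for (int64_t i = 0; i < ne00; ++i) mean += x[i];
    mean /= ne00;

    float variance = 0.0f;
    for (int64_t i = 0; i < ne00; ++i) {
        y[i] = x[i] - mean;          // recenter
        variance += y[i] * y[i];
    }
    variance /= ne00;

    const float scale = 1.0f / std::sqrt(variance + eps);
    for (int64_t i = 0; i < ne00; ++i) y[i] *= scale;
}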
@@ -485,6 +551,48 @@ kernel void kernel_mul_mat_f16_f32(
     }
 }
 
+kernel void kernel_alibi_f32(
+        device const float * src0,
+        device       float * dst,
+        constant  int64_t & ne00,
+        constant  int64_t & ne01,
+        constant  int64_t & ne02,
+        constant  int64_t & ne03,
+        constant uint64_t & nb00,
+        constant uint64_t & nb01,
+        constant uint64_t & nb02,
+        constant uint64_t & nb03,
+        constant  int64_t & ne0,
+        constant  int64_t & ne1,
+        constant  int64_t & ne2,
+        constant  int64_t & ne3,
+        constant uint64_t & nb0,
+        constant uint64_t & nb1,
+        constant uint64_t & nb2,
+        constant uint64_t & nb3,
+        constant    float & m0,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint3 tpitg[[thread_position_in_threadgroup]],
+        uint3   ntg[[threads_per_threadgroup]]) {
+    const int64_t i03 = tgpig[2];
+    const int64_t i02 = tgpig[1];
+    const int64_t i01 = tgpig[0];
+
+    const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
+
+    const int64_t i3 = n / (ne2*ne1*ne0);
+    const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
+    const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
+    const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0);
+
+    device float * dst_data = (device float *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
+    float m_k = pow(m0, i2 + 1);
+    for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) {
+        device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
+        dst_data[i00] = src[0] + m_k * (i00 - ne00 + 1);
+    }
+}
+
 kernel void kernel_rope(
         device const void * src0,
         device      float * dst,
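The kernel_alibi_f32 hunk adds the ALiBi bias directly to the scores: m0 is the base slope passed in from the host, m_k = pow(m0, i2 + 1) derives the slope for the head indexed by i2, and each element receives m_k * (i00 - ne00 + 1), which for a positive m_k is zero at the last column and increasingly negative for earlier positions. The bias term in isolation looks roughly like this (standalone C++, illustrative names, not part of the diff):

#include <cmath>
#include <cstdint>

// Bias added by kernel_alibi_f32 for one element, recomputed on the CPU
// (hypothetical helper; 'head' corresponds to the i2 index in the kernel).
static float alibi_bias(float m0, int64_t head, int64_t i00, int64_t ne00) {
    const float m_k = std::pow(m0, (float)(head + 1)); // per-head slope
    return m_k * (float)(i00 - ne00 + 1);              // 0 at the last column, negative before it
}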
@@ -540,6 +648,47 @@ kernel void kernel_rope(
     }
 }
 
+kernel void kernel_cpy_f16_f16(
+        device const half * src0,
+        device       half * dst,
+        constant  int64_t & ne00,
+        constant  int64_t & ne01,
+        constant  int64_t & ne02,
+        constant  int64_t & ne03,
+        constant uint64_t & nb00,
+        constant uint64_t & nb01,
+        constant uint64_t & nb02,
+        constant uint64_t & nb03,
+        constant  int64_t & ne0,
+        constant  int64_t & ne1,
+        constant  int64_t & ne2,
+        constant  int64_t & ne3,
+        constant uint64_t & nb0,
+        constant uint64_t & nb1,
+        constant uint64_t & nb2,
+        constant uint64_t & nb3,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint3 tpitg[[thread_position_in_threadgroup]],
+        uint3   ntg[[threads_per_threadgroup]]) {
+    const int64_t i03 = tgpig[2];
+    const int64_t i02 = tgpig[1];
+    const int64_t i01 = tgpig[0];
+
+    const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
+
+    const int64_t i3 = n / (ne2*ne1*ne0);
+    const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
+    const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
+    const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0);
+
+    device half * dst_data = (device half *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
+
+    for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) {
+        device const half * src = (device half *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
+        dst_data[i00] = src[0];
+    }
+}
+
 kernel void kernel_cpy_f32_f16(
         device const float * src0,
         device       half * dst,
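Like the existing kernel_cpy_f32_f16 shown in the surrounding context, kernel_cpy_f16_f16 flattens the source indices (i03, i02, i01) into a linear element index n using the source shape and then re-expands n against the destination shape ne0..ne3, so the same element lands in the right place even when dst is laid out differently, with the byte strides nb0..nb3 doing the addressing. The re-expansion step on its own (standalone C++, illustrative only):

#include <cstdint>

// Re-expansion of a linear element index n into 4-D coordinates,
// mirroring the arithmetic in the copy kernels (hypothetical helper).
struct Idx4 { int64_t i3, i2, i1, i0; };

static Idx4 unflatten(int64_t n, int64_t ne0, int64_t ne1, int64_t ne2) {
    Idx4 r;
    r.i3 = n / (ne2*ne1*ne0);
    r.i2 = (n - r.i3*ne2*ne1*ne0) / (ne1*ne0);
    r.i1 = (n - r.i3*ne2*ne1*ne0 - r.i2*ne1*ne0) / ne0;
    r.i0 =  n - r.i3*ne2*ne1*ne0 - r.i2*ne1*ne0 - r.i1*ne0;
    return r;
}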