@@ -528,24 +528,42 @@ kernel void kernel_mul_mat_f16_f32(
528
528
device const half * x = (device const half *) (src0 + r0*nb01 + im/(ne12/ne02)*nb02);
529
529
device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12);
530
530
531
- sum[tpitg.x ] = 0 .0f ;
531
+ uint ith = tpitg.x ;
532
+ uint nth = tptg.x ;
532
533
533
- for (int i = tpitg.x ; i < ne00; i += tptg.x ) {
534
- sum[tpitg.x ] += (float ) x[i] * (float ) y[i];
534
+ sum[ith] = 0 .0f ;
535
+
536
+ for (int i = ith; i < ne00; i += nth) {
537
+ sum[ith] += (float ) x[i] * (float ) y[i];
535
538
}
536
539
537
540
// accumulate the sum from all threads in the threadgroup
538
541
threadgroup_barrier (mem_flags::mem_threadgroup);
539
- for (uint i = tptg.x /2 ; i > 0 ; i /= 2 ) {
540
- if (tpitg.x < i) {
541
- sum[tpitg.x ] += sum[tpitg.x + i];
542
- }
543
- threadgroup_barrier (mem_flags::mem_threadgroup);
542
+ if (ith%4 == 0 ) {
543
+ for (int i = 1 ; i < 4 ; ++i) sum[ith] += sum[ith + i];
544
544
}
545
-
546
- if (tpitg.x == 0 ) {
545
+ threadgroup_barrier (mem_flags::mem_threadgroup);
546
+ if (ith%16 == 0 ) {
547
+ for (int i = 4 ; i < 16 ; i += 4 ) sum[ith] += sum[ith + i];
548
+ }
549
+ threadgroup_barrier (mem_flags::mem_threadgroup);
550
+ if (ith == 0 ) {
551
+ for (int i = 16 ; i < nth; i += 16 ) sum[0 ] += sum[i];
547
552
dst[im*ne1*ne0 + r1*ne0 + r0] = sum[0 ];
548
553
}
554
+
555
+ // Original implementation. Left behind commented out for now
556
+ // threadgroup_barrier(mem_flags::mem_threadgroup);
557
+ // for (uint i = tptg.x/2; i > 0; i /= 2) {
558
+ // if (tpitg.x < i) {
559
+ // sum[tpitg.x] += sum[tpitg.x + i];
560
+ // }
561
+ // threadgroup_barrier(mem_flags::mem_threadgroup);
562
+ // }
563
+ //
564
+ // if (tpitg.x == 0) {
565
+ // dst[im*ne1*ne0 + r1*ne0 + r0] = sum[0];
566
+ // }
549
567
}
550
568
551
569
kernel void kernel_alibi_f32 (
0 commit comments