@@ -534,14 +534,27 @@ kernel void kernel_mul_mat_f16_f32_1row(
534
534
device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12);
535
535
536
536
float sumf = 0 ;
537
- for (int i = tiisg; i < ne00; i += 32 ) {
538
- sumf += (float ) x[i] * (float ) y[i];
537
+ if (ne00 < 128 ) {
538
+ for (int i = tiisg; i < ne00; i += 32 ) {
539
+ sumf += (float ) x[i] * (float ) y[i];
540
+ }
541
+ float all_sum = simd_sum (sumf);
542
+ if (tiisg == 0 ) {
543
+ dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
544
+ }
545
+ } else {
546
+ device const half4 * x4 = (device const half4 *) x;
547
+ device const float4 * y4 = (device const float4 *) y;
548
+ for (int i = tiisg; i < ne00/4 ; i += 32 ) {
549
+ for (int k = 0 ; k < 4 ; ++k) sumf += (float )x4[i][k] * y4[i][k];
550
+ }
551
+ float all_sum = simd_sum (sumf);
552
+ if (tiisg == 0 ) {
553
+ for (int i = 4 *(ne00/4 ); i < ne00; ++i) all_sum += (float ) x[i] * y[i]; // tail (ne00 % 4 leftovers) must be added to all_sum: sumf was already reduced by simd_sum above and is not read again
554
+ dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
555
+ }
539
556
}
540
557
541
- float all_sum = simd_sum (sumf);
542
- if (tiisg == 0 ) {
543
- dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
544
- }
545
558
}
546
559
547
560
#define N_F16_F32 4
@@ -573,22 +586,46 @@ kernel void kernel_mul_mat_f16_f32(
573
586
574
587
device const half * x = (device const half *) (src0 + r0*nb01 + im/(ne12/ne02)*nb02);
575
588
576
- for (int row = 0 ; row < N_F16_F32; ++row) {
577
- int r1 = rb + row;
578
- if (r1 >= ne11) {
579
- break ;
580
- }
589
+ if (ne00 < 128 ) {
590
+ for (int row = 0 ; row < N_F16_F32; ++row) {
591
+ int r1 = rb + row;
592
+ if (r1 >= ne11) {
593
+ break ;
594
+ }
581
595
582
- device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12);
596
+ device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12);
583
597
584
- float sumf = 0 ;
585
- for (int i = tiisg; i < ne00; i += 32 ) {
586
- sumf += (float ) x[i] * (float ) y[i];
598
+ float sumf = 0 ;
599
+ for (int i = tiisg; i < ne00; i += 32 ) {
600
+ sumf += (float ) x[i] * (float ) y[i];
601
+ }
602
+
603
+ float all_sum = simd_sum (sumf);
604
+ if (tiisg == 0 ) {
605
+ dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
606
+ }
587
607
}
608
+ } else {
609
+ device const half4 * x4 = (device const half4 *)x;
610
+ for (int row = 0 ; row < N_F16_F32; ++row) {
611
+ int r1 = rb + row;
612
+ if (r1 >= ne11) {
613
+ break ;
614
+ }
588
615
589
- float all_sum = simd_sum (sumf);
590
- if (tiisg == 0 ) {
591
- dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
616
+ device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12);
617
+ device const float4 * y4 = (device const float4 *) y;
618
+
619
+ float sumf = 0 ;
620
+ for (int i = tiisg; i < ne00/4 ; i += 32 ) {
621
+ for (int k = 0 ; k < 4 ; ++k) sumf += (float ) x4[i][k] * y4[i][k];
622
+ }
623
+
624
+ float all_sum = simd_sum (sumf);
625
+ if (tiisg == 0 ) {
626
+ for (int i = 4 *(ne00/4 ); i < ne00; ++i) all_sum += (float ) x[i] * y[i]; // tail (ne00 % 4 leftovers) must be added to all_sum: sumf was already reduced by simd_sum above and is not read again
627
+ dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
628
+ }
592
629
}
593
630
}
594
631
0 commit comments