@@ -423,8 +423,8 @@ inline float block_q_n_dot_y(device const block_q4_1 * qb_curr, float sumy, thre
423
423
}
424
424
425
425
// putting them in the kernel cause a significant performance penalty
426
- #define N_DST 4 // each SIMD group works on 4 rows
427
- #define N_SIMDGROUP 2 // number of SIMD groups in a thread group
426
+ #define N_DST 4 // each SIMD group works on 4 rows
427
+ #define N_SIMDGROUP 2 // number of SIMD groups in a thread group
428
428
#define N_SIMDWIDTH 32 // assuming SIMD group size is 32
429
429
// Note: This is a template, but strictly speaking it only applies to
430
430
// quantizations where the block size is 32. It also does not
@@ -435,18 +435,23 @@ void mul_vec_q_n_f32(device const void * src0, device const float * src1, device
435
435
int64_t ne00, int64_t ne01, int64_t ne02, int64_t ne10, int64_t ne12, int64_t ne0, int64_t ne1, uint gqa,
436
436
uint3 tgpig, uint tiisg, uint sgitg) {
437
437
const int nb = ne00/QK4_0;
438
+
438
439
const int r0 = tgpig.x ;
439
440
const int r1 = tgpig.y ;
440
441
const int im = tgpig.z ;
442
+
441
443
const int first_row = (r0 * nsg + sgitg) * nr;
444
+
442
445
const uint offset0 = first_row * nb + im/gqa*(nb*ne0);
446
+
443
447
device const block_q_type * x = (device const block_q_type *) src0 + offset0;
444
448
device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
445
- float yl[16 ]; // src1 vector cache
446
- float sumf[nr]={0 .f };
447
449
448
- const int ix = tiisg/2 ;
449
- const int il = 8 *(tiisg%2 );
450
+ float yl[16 ]; // src1 vector cache
451
+ float sumf[nr] = {0 .f };
452
+
453
+ const int ix = (tiisg/2 );
454
+ const int il = (tiisg%2 )*8 ;
450
455
451
456
device const float * yb = y + ix * QK4_0 + il;
452
457
@@ -457,6 +462,7 @@ void mul_vec_q_n_f32(device const void * src0, device const float * src1, device
457
462
sumy += yb[i] + yb[i+1 ];
458
463
yl[i+0 ] = yb[i+ 0 ];
459
464
yl[i+1 ] = yb[i+ 1 ]/256 .f ;
465
+
460
466
sumy += yb[i+16 ] + yb[i+17 ];
461
467
yl[i+8 ] = yb[i+16 ]/16 .f ;
462
468
yl[i+9 ] = yb[i+17 ]/4096 .f ;
@@ -472,7 +478,7 @@ void mul_vec_q_n_f32(device const void * src0, device const float * src1, device
472
478
for (int row = 0 ; row < nr; ++row) {
473
479
const float tot = simd_sum (sumf[row]);
474
480
if (tiisg == 0 && first_row + row < ne01) {
475
- dst[r1 *ne0 + im *ne0*ne1 + first_row + row] = tot;
481
+ dst[im *ne0*ne1 + r1 *ne0 + first_row + row] = tot;
476
482
}
477
483
}
478
484
}
@@ -490,8 +496,8 @@ kernel void kernel_mul_mv_q4_0_f32(
490
496
constant int64_t & ne1[[buffer(16 )]],
491
497
constant uint & gqa[[buffer(17 )]],
492
498
uint3 tgpig[[threadgroup_position_in_grid]],
493
- uint tiisg[[thread_index_in_simdgroup]],
494
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
499
+ uint tiisg[[thread_index_in_simdgroup]],
500
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
495
501
mul_vec_q_n_f32<block_q4_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,gqa,tgpig,tiisg,sgitg);
496
502
}
497
503
@@ -669,7 +675,7 @@ kernel void kernel_mul_mv_f16_f32_1row(
669
675
constant int64_t & ne0,
670
676
constant int64_t & ne1,
671
677
uint3 tgpig[[threadgroup_position_in_grid]],
672
- uint tiisg[[thread_index_in_simdgroup]]) {
678
+ uint tiisg[[thread_index_in_simdgroup]]) {
673
679
674
680
const int64_t r0 = tgpig.x ;
675
681
const int64_t r1 = tgpig.y ;
0 commit comments