@@ -477,7 +477,7 @@ void mul_vec_q_n_f32(device const void * src0, device const float * src1, device
     }
 }
 
-kernel void kernel_mul_mat_q4_0_f32(
+kernel void kernel_mul_mv_q4_0_f32(
         device const  void * src0,
         device const float * src1,
         device       float * dst,
@@ -495,7 +495,7 @@ kernel void kernel_mul_mat_q4_0_f32(
     mul_vec_q_n_f32<block_q4_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,gqa,tgpig,tiisg,sgitg);
 }
 
-kernel void kernel_mul_mat_q4_1_f32(
+kernel void kernel_mul_mv_q4_1_f32(
         device const  void * src0,
         device const float * src1,
         device       float * dst,
@@ -515,7 +515,7 @@ kernel void kernel_mul_mat_q4_1_f32(
 
 #define NB_Q8_0 8
 
-kernel void kernel_mul_mat_q8_0_f32(
+kernel void kernel_mul_mv_q8_0_f32(
         device const  void * src0,
         device const float * src1,
         device       float * dst,
@@ -579,7 +579,7 @@ kernel void kernel_mul_mat_q8_0_f32(
 
 #define N_F32_F32 4
 
-kernel void kernel_mul_mat_f32_f32(
+kernel void kernel_mul_mv_f32_f32(
         device const  char * src0,
         device const  char * src1,
         device       float * dst,
@@ -650,7 +650,7 @@ kernel void kernel_mul_mat_f32_f32(
     }
 }
 
-kernel void kernel_mul_mat_f16_f32_1row(
+kernel void kernel_mul_mv_f16_f32_1row(
         device const  char * src0,
         device const  char * src1,
         device       float * dst,
@@ -704,7 +704,7 @@ kernel void kernel_mul_mat_f16_f32_1row(
 
 #define N_F16_F32 4
 
-kernel void kernel_mul_mat_f16_f32(
+kernel void kernel_mul_mv_f16_f32(
         device const  char * src0,
         device const  char * src1,
         device       float * dst,
@@ -776,7 +776,7 @@ kernel void kernel_mul_mat_f16_f32(
 }
 
 // Assumes row size (ne00) is a multiple of 4
-kernel void kernel_mul_mat_f16_f32_l4(
+kernel void kernel_mul_mv_f16_f32_l4(
         device const  char * src0,
         device const  char * src1,
         device       float * dst,
@@ -1253,7 +1253,7 @@ static inline uchar4 get_scale_min_k4(int j, device const uint8_t * q) {
 
 // ====================================== dot products =========================
 
-kernel void kernel_mul_mat_q2_K_f32(
+kernel void kernel_mul_mv_q2_K_f32(
         device const  void * src0,
         device const float * src1,
         device       float * dst,
@@ -1397,7 +1397,7 @@ kernel void kernel_mul_mat_q2_K_f32(
 }
 
 #if QK_K == 256
-kernel void kernel_mul_mat_q3_K_f32(
+kernel void kernel_mul_mv_q3_K_f32(
         device const  void * src0,
         device const float * src1,
         device       float * dst,
@@ -1549,7 +1549,7 @@ kernel void kernel_mul_mat_q3_K_f32(
     }
 }
 #else
-kernel void kernel_mul_mat_q3_K_f32(
+kernel void kernel_mul_mv_q3_K_f32(
         device const  void * src0,
         device const float * src1,
         device       float * dst,
@@ -1620,7 +1620,7 @@ kernel void kernel_mul_mat_q3_K_f32(
 #endif
 
 #if QK_K == 256
-kernel void kernel_mul_mat_q4_K_f32(
+kernel void kernel_mul_mv_q4_K_f32(
         device const  void * src0,
         device const float * src1,
         device       float * dst,
@@ -1726,7 +1726,7 @@ kernel void kernel_mul_mat_q4_K_f32(
     }
 }
 #else
-kernel void kernel_mul_mat_q4_K_f32(
+kernel void kernel_mul_mv_q4_K_f32(
         device const  void * src0,
         device const float * src1,
         device       float * dst,
@@ -1815,7 +1815,7 @@ kernel void kernel_mul_mat_q4_K_f32(
 }
 #endif
 
-kernel void kernel_mul_mat_q5_K_f32(
+kernel void kernel_mul_mv_q5_K_f32(
         device const  void * src0,
         device const float * src1,
         device       float * dst,
@@ -1988,7 +1988,7 @@ kernel void kernel_mul_mat_q5_K_f32(
 
 }
 
-kernel void kernel_mul_mat_q6_K_f32(
+kernel void kernel_mul_mv_q6_K_f32(
         device const  void * src0,
         device const float * src1,
         device       float * dst,
@@ -2363,9 +2363,11 @@ kernel void kernel_mul_mm(device const uchar * src0,
     const uint r0 = tgpig.y;
     const uint r1 = tgpig.x;
     const uint im = tgpig.z;
+
     // if this block is of 64x32 shape or smaller
     short n_rows = (ne0 - r0 * BLOCK_SIZE_M < BLOCK_SIZE_M) ? (ne0 - r0 * BLOCK_SIZE_M) : BLOCK_SIZE_M;
     short n_cols = (ne1 - r1 * BLOCK_SIZE_N < BLOCK_SIZE_N) ? (ne1 - r1 * BLOCK_SIZE_N) : BLOCK_SIZE_N;
+
     // a thread shouldn't load data outside of the matrix
     short thread_row = ((short)tiitg/THREAD_PER_ROW) < n_rows ? ((short)tiitg/THREAD_PER_ROW) : n_rows - 1;
     short thread_col = ((short)tiitg/THREAD_PER_COL) < n_cols ? ((short)tiitg/THREAD_PER_COL) : n_cols - 1;
@@ -2393,22 +2395,26 @@ kernel void kernel_mul_mm(device const uchar * src0,
         half4x4 temp_a;
         dequantize_func(x, il, temp_a);
         threadgroup_barrier(mem_flags::mem_threadgroup);
+
         #pragma unroll(16)
         for (int i = 0; i < 16; i++) {
             *(sa + SG_MAT_SIZE * ((tiitg / THREAD_PER_ROW / 8) \
-            + 16 * (tiitg % THREAD_PER_ROW) + 8 * (i / 8)) \
-            + (tiitg / THREAD_PER_ROW) % 8 + (i & 7) * 8) = temp_a[i/4][i%4];
+            + (tiitg % THREAD_PER_ROW) * 16 + (i / 8) * 8) \
+            + (tiitg / THREAD_PER_ROW) % 8 + (i & 7) * 8) = temp_a[i/4][i%4];
         }
-        *(threadgroup float2x4 *)(sb + (tiitg % THREAD_PER_COL) * 8 * 32 + 8 * (tiitg / THREAD_PER_COL)) \
-        = *((device float2x4 *)y);
+
+        *(threadgroup float2x4 *)(sb + (tiitg % THREAD_PER_COL) * 8 * 32 + 8 * (tiitg / THREAD_PER_COL)) = *((device float2x4 *)y);
+
         il = (il + 2 < nl) ? il + 2 : il % 2;
         x  = (il < 2) ? x + (2+nl-1)/nl : x;
         y += BLOCK_SIZE_K;
 
         threadgroup_barrier(mem_flags::mem_threadgroup);
+
         // load matrices from threadgroup memory and conduct outer products
         threadgroup half  * lsma = (sa + THREAD_MAT_M * SG_MAT_SIZE * (sgitg % 2));
         threadgroup float * lsmb = (sb + THREAD_MAT_N * SG_MAT_SIZE * (sgitg / 2));
+
         #pragma unroll(4)
         for (int ik = 0; ik < BLOCK_SIZE_K / 8; ik++) {
             #pragma unroll(4)
@@ -2423,6 +2429,7 @@ kernel void kernel_mul_mm(device const uchar * src0,
 
             lsma += BLOCK_SIZE_M / SG_MAT_ROW * SG_MAT_SIZE;
             lsmb += BLOCK_SIZE_N / SG_MAT_ROW * SG_MAT_SIZE;
+
             #pragma unroll(8)
             for (int i = 0; i < 8; i++){
                 simdgroup_multiply_accumulate(c_res[i], mb[i/4], ma[i%4], c_res[i]);
@@ -2431,8 +2438,8 @@ kernel void kernel_mul_mm(device const uchar * src0,
     }
 
     if ((r0 + 1) * BLOCK_SIZE_M <= ne0 && (r1 + 1) * BLOCK_SIZE_N <= ne1) {
-        device float *C = dst + BLOCK_SIZE_M * r0 + 32 * (sgitg & 1) \
-                        + (BLOCK_SIZE_N * r1 + 16 * (sgitg >> 1)) * ne0 + im*ne1*ne0;
+        device float *C = dst + (BLOCK_SIZE_M * r0 + 32 * (sgitg & 1)) \
+                        + (BLOCK_SIZE_N * r1 + 16 * (sgitg >> 1)) * ne0 + im*ne1*ne0;
         for (int i = 0; i < 8; i++) {
             simdgroup_store(c_res[i], C + 8 * (i%4) + 8 * ne0 * (i/4), ne0);
         }
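Aside from the kernel_mul_mat_* to kernel_mul_mv_* renames (presumably marking these as matrix-vector kernels, in contrast to the matrix-matrix kernel_mul_mm), the kernel_mul_mm hunks above are behavior-preserving: blank lines are added, the output pointer C gains clarifying parentheses (+ is left-associative, so the computed address is identical), and the threadgroup-store index swaps commutative factors (16 * x becomes x * 16, 8 * y becomes y * 8). A minimal standalone C sketch checking that last equivalence; the THREAD_PER_ROW and SG_MAT_SIZE values and the 128-thread loop bound are assumptions for illustration only, not taken from this diff:

#include <assert.h>
#include <stdio.h>

// Assumed values for illustration; the real constants are defined
// elsewhere in ggml-metal.metal and do not appear in this diff.
#define THREAD_PER_ROW 4
#define SG_MAT_SIZE    64

int main(void) {
    // Compare the old and reordered store-index expressions over a
    // plausible range of thread indices (tiitg) and unroll steps (i).
    for (int tiitg = 0; tiitg < 128; tiitg++) {
        for (int i = 0; i < 16; i++) {
            int old_idx = SG_MAT_SIZE * ((tiitg / THREAD_PER_ROW / 8)
                        + 16 * (tiitg % THREAD_PER_ROW) + 8 * (i / 8))
                        + (tiitg / THREAD_PER_ROW) % 8 + (i & 7) * 8;
            int new_idx = SG_MAT_SIZE * ((tiitg / THREAD_PER_ROW / 8)
                        + (tiitg % THREAD_PER_ROW) * 16 + (i / 8) * 8)
                        + (tiitg / THREAD_PER_ROW) % 8 + (i & 7) * 8;
            assert(old_idx == new_idx);
        }
    }
    printf("old and new store indices match\n");
    return 0;
}

Compiled with any C compiler (cc check.c && ./a.out), the assert holds for every (tiitg, i) pair, so the reorder only improves readability.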