@@ -5439,8 +5439,8 @@ kernel void kernel_mul_mm(
5439
5439
ushort tiitg[[thread_index_in_threadgroup]],
5440
5440
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
5441
5441
5442
- threadgroup T * sa = (threadgroup T *)(shmem);
5443
- threadgroup float * sb = (threadgroup float *)(shmem + 4096 );
5442
+ threadgroup T * sa = (threadgroup T *)(shmem);
5443
+ threadgroup half * sb = (threadgroup half *)(shmem + 4096 );
5444
5444
5445
5445
const int r0 = tgpig.y ;
5446
5446
const int r1 = tgpig.x ;
@@ -5454,12 +5454,12 @@ kernel void kernel_mul_mm(
5454
5454
const short thread_row = ((short )tiitg/THREAD_PER_ROW) < n_rows ? ((short )tiitg/THREAD_PER_ROW) : n_rows - 1 ;
5455
5455
const short thread_col = ((short )tiitg/THREAD_PER_COL) < n_cols ? ((short )tiitg/THREAD_PER_COL) : n_cols - 1 ;
5456
5456
5457
- simdgroup_T8x8 ma[4 ];
5458
- simdgroup_float8x8 mb[2 ];
5459
- simdgroup_float8x8 mc[8 ];
5457
+ simdgroup_T8x8 ma[4 ];
5458
+ simdgroup_half8x8 mb[2 ];
5459
+ simdgroup_half8x8 mc[8 ];
5460
5460
5461
5461
for (short i = 0 ; i < 8 ; i++){
5462
- mc[i] = make_filled_simdgroup_matrix<float , 8 >(0 .f );
5462
+ mc[i] = make_filled_simdgroup_matrix<half , 8 >(0 .h );
5463
5463
}
5464
5464
5465
5465
short il = (tiitg % THREAD_PER_ROW);
@@ -5493,7 +5493,7 @@ kernel void kernel_mul_mm(
5493
5493
+ (tiitg/THREAD_PER_ROW)%8 + (i&7 )*8 ) = temp_a[i/4 ][i%4 ];
5494
5494
}
5495
5495
5496
- *(threadgroup float2x4 *)(sb + 32 *8 *(tiitg%THREAD_PER_COL) + 8 *(tiitg/THREAD_PER_COL)) = *((device float2x4 *) y );
5496
+ *(threadgroup half2x4 *)(sb + 32 *8 *(tiitg%THREAD_PER_COL) + 8 *(tiitg/THREAD_PER_COL)) = (half2x4)( *((device float2x4 *)y) );
5497
5497
5498
5498
il = (il + 2 < nl) ? il + 2 : il % 2 ;
5499
5499
x = (il < 2 ) ? x + (2 + nl - 1 )/nl : x;
@@ -5502,8 +5502,8 @@ kernel void kernel_mul_mm(
5502
5502
threadgroup_barrier (mem_flags::mem_threadgroup);
5503
5503
5504
5504
// load matrices from threadgroup memory and conduct outer products
5505
- threadgroup const T * lsma = (sa + THREAD_MAT_M*SG_MAT_SIZE*(sgitg%2 ));
5506
- threadgroup const float * lsmb = (sb + THREAD_MAT_N*SG_MAT_SIZE*(sgitg/2 ));
5505
+ threadgroup const T * lsma = (sa + THREAD_MAT_M*SG_MAT_SIZE*(sgitg%2 ));
5506
+ threadgroup const half * lsmb = (sb + THREAD_MAT_N*SG_MAT_SIZE*(sgitg/2 ));
5507
5507
5508
5508
#pragma unroll(4)
5509
5509
for (short ik = 0 ; ik < BLOCK_SIZE_K/8 ; ik++) {
@@ -5535,15 +5535,22 @@ kernel void kernel_mul_mm(
5535
5535
(BLOCK_SIZE_N * r1 + 16 *(sgitg >> 1 )) * args.ne0 + im*args.ne1 *args.ne0 ;
5536
5536
5537
5537
for (short i = 0 ; i < 8 ; i++) {
5538
- simdgroup_store (mc[i], C + 8 * (i%4 ) + 8 * args.ne0 * (i/4 ), args.ne0 );
5538
+ // cast to f32
5539
+ simdgroup_float8x8 mc_f32 (1 .0f );
5540
+ simdgroup_multiply (mc_f32, mc[i], mc_f32);
5541
+ simdgroup_store (mc_f32, C + 8 * (i%4 ) + 8 * args.ne0 * (i/4 ), args.ne0 );
5542
+ // simdgroup_store(mc[i], C + 8 * (i%4) + 8 * args.ne0 * (i/4), args.ne0);
5539
5543
}
5540
5544
} else {
5541
5545
// block is smaller than 64x32, we should avoid writing data outside of the matrix
5542
5546
threadgroup_barrier (mem_flags::mem_threadgroup);
5543
5547
threadgroup float * temp_str = ((threadgroup float *) shmem) \
5544
5548
+ 32 *(sgitg&1 ) + (16 *(sgitg >> 1 ))*BLOCK_SIZE_M;
5545
5549
for (short i = 0 ; i < 8 ; i++) {
5546
- simdgroup_store (mc[i], temp_str + 8 *(i%4 ) + 8 *BLOCK_SIZE_M*(i/4 ), BLOCK_SIZE_M);
5550
+ simdgroup_float8x8 mc_f32 (1 .0f );
5551
+ simdgroup_multiply (mc_f32, mc[i], mc_f32);
5552
+ simdgroup_store (mc_f32, temp_str + 8 * (i%4 ) + 8 * BLOCK_SIZE_M * (i/4 ), BLOCK_SIZE_M);
5553
+ // simdgroup_store(mc[i], temp_str + 8 * (i%4) + 8 * BLOCK_SIZE_M * (i/4), BLOCK_SIZE_M);
5547
5554
}
5548
5555
5549
5556
threadgroup_barrier (mem_flags::mem_threadgroup);
0 commit comments