@@ -6310,8 +6310,8 @@ kernel void kernel_mul_mm(device const uchar * src0,
6310
6310
uint tiitg[[thread_index_in_threadgroup]],
6311
6311
uint sgitg[[simdgroup_index_in_threadgroup]]) {
6312
6312
6313
- threadgroup T * sa = (threadgroup T *)(shared_memory);
6314
- threadgroup float * sb = (threadgroup float *)(shared_memory + 4096 );
6313
+ threadgroup T * sa = (threadgroup T *)(shared_memory);
6314
+ threadgroup half * sb = (threadgroup half *)(shared_memory + 4096 );
6315
6315
6316
6316
const uint r0 = tgpig.y ;
6317
6317
const uint r1 = tgpig.x ;
@@ -6325,12 +6325,12 @@ kernel void kernel_mul_mm(device const uchar * src0,
6325
6325
short thread_row = ((short )tiitg/THREAD_PER_ROW) < n_rows ? ((short )tiitg/THREAD_PER_ROW) : n_rows - 1 ;
6326
6326
short thread_col = ((short )tiitg/THREAD_PER_COL) < n_cols ? ((short )tiitg/THREAD_PER_COL) : n_cols - 1 ;
6327
6327
6328
- simdgroup_T8x8 ma[4 ];
6329
- simdgroup_float8x8 mb[2 ];
6330
- simdgroup_float8x8 mc[8 ];
6328
+ simdgroup_T8x8 ma[4 ];
6329
+ simdgroup_half8x8 mb[2 ];
6330
+ simdgroup_half8x8 mc[8 ];
6331
6331
6332
6332
for (short i = 0 ; i < 8 ; i++){
6333
- mc[i] = make_filled_simdgroup_matrix<float , 8 >(0 .f );
6333
+ mc[i] = make_filled_simdgroup_matrix<half , 8 >(0 .h );
6334
6334
}
6335
6335
6336
6336
short il = (tiitg % THREAD_PER_ROW);
@@ -6361,17 +6361,17 @@ kernel void kernel_mul_mm(device const uchar * src0,
6361
6361
+ (tiitg/THREAD_PER_ROW)%8 + (i&7 )*8 ) = temp_a[i/4 ][i%4 ];
6362
6362
}
6363
6363
6364
- *(threadgroup float2x4 *)(sb + (tiitg % THREAD_PER_COL)*8 *32 + 8 *(tiitg/THREAD_PER_COL)) = *((device float2x4 *) y );
6364
+ *(threadgroup half2x4 *)(sb + (tiitg% THREAD_PER_COL)*8 *32 + 8 *(tiitg/THREAD_PER_COL)) = (half2x4)( *((device float2x4 *)y) );
6365
6365
6366
6366
il = (il + 2 < nl) ? il + 2 : il % 2 ;
6367
- x = (il < 2 ) ? x + (2 +nl- 1 )/nl : x;
6367
+ x = (il < 2 ) ? x + (2 + nl - 1 )/nl : x;
6368
6368
y += BLOCK_SIZE_K;
6369
6369
6370
6370
threadgroup_barrier (mem_flags::mem_threadgroup);
6371
6371
6372
6372
// load matrices from threadgroup memory and conduct outer products
6373
- threadgroup T * lsma = (sa + THREAD_MAT_M*SG_MAT_SIZE*(sgitg%2 ));
6374
- threadgroup float * lsmb = (sb + THREAD_MAT_N*SG_MAT_SIZE*(sgitg/2 ));
6373
+ threadgroup T * lsma = (sa + THREAD_MAT_M*SG_MAT_SIZE*(sgitg%2 ));
6374
+ threadgroup half * lsmb = (sb + THREAD_MAT_N*SG_MAT_SIZE*(sgitg/2 ));
6375
6375
6376
6376
#pragma unroll(4)
6377
6377
for (short ik = 0 ; ik < BLOCK_SIZE_K / 8 ; ik++) {
@@ -6399,15 +6399,22 @@ kernel void kernel_mul_mm(device const uchar * src0,
6399
6399
device float * C = dst + (BLOCK_SIZE_M * r0 + 32 * (sgitg & 1 )) \
6400
6400
+ (BLOCK_SIZE_N * r1 + 16 * (sgitg >> 1 )) * ne0 + im*ne1*ne0;
6401
6401
for (short i = 0 ; i < 8 ; i++) {
6402
- simdgroup_store (mc[i], C + 8 * (i%4 ) + 8 * ne0 * (i/4 ), ne0);
6402
+ // cast to f32: mc_f32(1.0f) is an f32 identity (diagonal) matrix, so mc[i] * mc_f32 yields mc[i] widened to f32
6403
+ simdgroup_float8x8 mc_f32 (1 .0f );
6404
+ simdgroup_multiply (mc_f32, mc[i], mc_f32);
6405
+ simdgroup_store (mc_f32, C + 8 * (i%4 ) + 8 * ne0 * (i/4 ), ne0);
6406
+ // simdgroup_store(mc[i], C + 8 * (i%4) + 8 * ne0 * (i/4), ne0);
6403
6407
}
6404
6408
} else {
6405
6409
// block is smaller than 64x32, we should avoid writing data outside of the matrix
6406
6410
threadgroup_barrier (mem_flags::mem_threadgroup);
6407
6411
threadgroup float * temp_str = ((threadgroup float *) shared_memory) \
6408
- + 32 * (sgitg&1 ) + (16 * (sgitg>> 1 ))*BLOCK_SIZE_M;
6412
+ + 32 * (sgitg&1 ) + (16 * (sgitg >> 1 ))*BLOCK_SIZE_M;
6409
6413
for (short i = 0 ; i < 8 ; i++) {
6410
- simdgroup_store (mc[i], temp_str + 8 *(i%4 ) + 8 *BLOCK_SIZE_M*(i/4 ), BLOCK_SIZE_M);
6414
+ simdgroup_float8x8 mc_f32 (1 .0f );
6415
+ simdgroup_multiply (mc_f32, mc[i], mc_f32);
6416
+ simdgroup_store (mc_f32, temp_str + 8 * (i%4 ) + 8 * BLOCK_SIZE_M * (i/4 ), BLOCK_SIZE_M);
6417
+ // simdgroup_store(mc[i], temp_str + 8 * (i%4) + 8 * BLOCK_SIZE_M * (i/4), BLOCK_SIZE_M);
6411
6418
}
6412
6419
6413
6420
threadgroup_barrier (mem_flags::mem_threadgroup);
0 commit comments