@@ -6295,8 +6295,8 @@ kernel void kernel_mul_mm(device const uchar * src0,
6295
6295
uint tiitg[[thread_index_in_threadgroup]],
6296
6296
uint sgitg[[simdgroup_index_in_threadgroup]]) {
6297
6297
6298
- threadgroup T * sa = (threadgroup T *)(shared_memory);
6299
- threadgroup float * sb = (threadgroup float *)(shared_memory + 4096 );
6298
+ threadgroup T * sa = (threadgroup T *)(shared_memory);
6299
+ threadgroup half * sb = (threadgroup half *)(shared_memory + 4096 );
6300
6300
6301
6301
const uint r0 = tgpig.y ;
6302
6302
const uint r1 = tgpig.x ;
@@ -6310,11 +6310,11 @@ kernel void kernel_mul_mm(device const uchar * src0,
6310
6310
short thread_row = ((short )tiitg/THREAD_PER_ROW) < n_rows ? ((short )tiitg/THREAD_PER_ROW) : n_rows - 1 ;
6311
6311
short thread_col = ((short )tiitg/THREAD_PER_COL) < n_cols ? ((short )tiitg/THREAD_PER_COL) : n_cols - 1 ;
6312
6312
6313
- simdgroup_T8x8 ma[4 ];
6314
- simdgroup_float8x8 mb[2 ];
6315
- simdgroup_float8x8 c_res [8 ];
6313
+ simdgroup_T8x8 ma[4 ];
6314
+ simdgroup_half8x8 mb[2 ];
6315
+ simdgroup_half8x8 mc [8 ];
6316
6316
for (int i = 0 ; i < 8 ; i++){
6317
- c_res [i] = make_filled_simdgroup_matrix<float , 8 >(0 .f );
6317
+ mc [i] = make_filled_simdgroup_matrix<half , 8 >(0 .h );
6318
6318
}
6319
6319
6320
6320
short il = (tiitg % THREAD_PER_ROW);
@@ -6345,17 +6345,17 @@ kernel void kernel_mul_mm(device const uchar * src0,
6345
6345
+ (tiitg / THREAD_PER_ROW) % 8 + (i & 7 ) * 8 ) = temp_a[i/4 ][i%4 ];
6346
6346
}
6347
6347
6348
- *(threadgroup float2x4 *)(sb + (tiitg % THREAD_PER_COL) * 8 * 32 + 8 * (tiitg / THREAD_PER_COL)) = *((device float2x4 *)y);
6348
+ *(threadgroup half2x4 *)(sb + (tiitg % THREAD_PER_COL) * 8 * 32 + 8 * (tiitg / THREAD_PER_COL)) = (half2x4)( *((device float2x4 *)y) );
6349
6349
6350
6350
il = (il + 2 < nl) ? il + 2 : il % 2 ;
6351
- x = (il < 2 ) ? x + (2 +nl- 1 )/nl : x;
6351
+ x = (il < 2 ) ? x + (2 + nl - 1 )/nl : x;
6352
6352
y += BLOCK_SIZE_K;
6353
6353
6354
6354
threadgroup_barrier (mem_flags::mem_threadgroup);
6355
6355
6356
6356
// load matrices from threadgroup memory and conduct outer products
6357
- threadgroup T * lsma = (sa + THREAD_MAT_M * SG_MAT_SIZE * (sgitg % 2 ));
6358
- threadgroup float * lsmb = (sb + THREAD_MAT_N * SG_MAT_SIZE * (sgitg / 2 ));
6357
+ threadgroup T * lsma = (sa + THREAD_MAT_M * SG_MAT_SIZE * (sgitg % 2 ));
6358
+ threadgroup half * lsmb = (sb + THREAD_MAT_N * SG_MAT_SIZE * (sgitg / 2 ));
6359
6359
6360
6360
#pragma unroll(4)
6361
6361
for (int ik = 0 ; ik < BLOCK_SIZE_K / 8 ; ik++) {
@@ -6374,7 +6374,7 @@ kernel void kernel_mul_mm(device const uchar * src0,
6374
6374
6375
6375
#pragma unroll(8)
6376
6376
for (int i = 0 ; i < 8 ; i++){
6377
- simdgroup_multiply_accumulate (c_res [i], mb[i/4 ], ma[i%4 ], c_res [i]);
6377
+ simdgroup_multiply_accumulate (mc [i], mb[i/4 ], ma[i%4 ], mc [i]);
6378
6378
}
6379
6379
}
6380
6380
}
@@ -6383,15 +6383,22 @@ kernel void kernel_mul_mm(device const uchar * src0,
6383
6383
device float * C = dst + (BLOCK_SIZE_M * r0 + 32 * (sgitg & 1 )) \
6384
6384
+ (BLOCK_SIZE_N * r1 + 16 * (sgitg >> 1 )) * ne0 + im*ne1*ne0;
6385
6385
for (int i = 0 ; i < 8 ; i++) {
6386
- simdgroup_store (c_res[i], C + 8 * (i%4 ) + 8 * ne0 * (i/4 ), ne0);
6386
+ // cast to f32
6387
+ simdgroup_float8x8 mc_f32 (1 .0f );
6388
+ simdgroup_multiply (mc_f32, mc[i], mc_f32);
6389
+ simdgroup_store (mc_f32, C + 8 * (i%4 ) + 8 * ne0 * (i/4 ), ne0);
6390
+ // simdgroup_store(mc[i], C + 8 * (i%4) + 8 * ne0 * (i/4), ne0);
6387
6391
}
6388
6392
} else {
6389
6393
// block is smaller than 64x32, we should avoid writing data outside of the matrix
6390
6394
threadgroup_barrier (mem_flags::mem_threadgroup);
6391
6395
threadgroup float * temp_str = ((threadgroup float *)shared_memory) \
6392
- + 32 * (sgitg&1 ) + (16 * (sgitg>>1 )) * BLOCK_SIZE_M;
6396
+ + 32 * (sgitg&1 ) + (16 * (sgitg>>1 )) * BLOCK_SIZE_M;
6393
6397
for (int i = 0 ; i < 8 ; i++) {
6394
- simdgroup_store (c_res[i], temp_str + 8 * (i%4 ) + 8 * BLOCK_SIZE_M * (i/4 ), BLOCK_SIZE_M);
6398
+ simdgroup_float8x8 mc_f32 (1 .0f );
6399
+ simdgroup_multiply (mc_f32, mc[i], mc_f32);
6400
+ simdgroup_store (mc_f32, temp_str + 8 * (i%4 ) + 8 * BLOCK_SIZE_M * (i/4 ), BLOCK_SIZE_M);
6401
+ // simdgroup_store(mc[i], temp_str + 8 * (i%4) + 8 * BLOCK_SIZE_M * (i/4), BLOCK_SIZE_M);
6395
6402
}
6396
6403
6397
6404
threadgroup_barrier (mem_flags::mem_threadgroup);
0 commit comments