@@ -6313,7 +6313,7 @@ kernel void kernel_mul_mm(device const uchar * src0,
6313
6313
simdgroup_T8x8 ma[4 ];
6314
6314
simdgroup_half8x8 mb[2 ];
6315
6315
simdgroup_half8x8 mc[8 ];
6316
- for (int i = 0 ; i < 8 ; i++){
6316
+ for (short i = 0 ; i < 8 ; i++){
6317
6317
mc[i] = make_filled_simdgroup_matrix<half, 8 >(0 .h );
6318
6318
}
6319
6319
@@ -6339,7 +6339,7 @@ kernel void kernel_mul_mm(device const uchar * src0,
6339
6339
threadgroup_barrier (mem_flags::mem_threadgroup);
6340
6340
6341
6341
#pragma unroll(16)
6342
- for (int i = 0 ; i < 16 ; i++) {
6342
+ for (short i = 0 ; i < 16 ; i++) {
6343
6343
*(sa + SG_MAT_SIZE * ((tiitg / THREAD_PER_ROW / 8 ) \
6344
6344
+ (tiitg % THREAD_PER_ROW) * 16 + (i / 8 ) * 8 ) \
6345
6345
+ (tiitg / THREAD_PER_ROW) % 8 + (i & 7 ) * 8 ) = temp_a[i/4 ][i%4 ];
@@ -6358,22 +6358,22 @@ kernel void kernel_mul_mm(device const uchar * src0,
6358
6358
threadgroup half * lsmb = (sb + THREAD_MAT_N * SG_MAT_SIZE * (sgitg / 2 ));
6359
6359
6360
6360
#pragma unroll(4)
6361
- for (int ik = 0 ; ik < BLOCK_SIZE_K / 8 ; ik++) {
6361
+ for (short ik = 0 ; ik < BLOCK_SIZE_K / 8 ; ik++) {
6362
6362
#pragma unroll(4)
6363
- for (int i = 0 ; i < 4 ; i++) {
6363
+ for (short i = 0 ; i < 4 ; i++) {
6364
6364
simdgroup_load (ma[i],lsma + SG_MAT_SIZE * i);
6365
6365
}
6366
6366
simdgroup_barrier (mem_flags::mem_none);
6367
6367
#pragma unroll(2)
6368
- for (int i = 0 ; i < 2 ; i++) {
6368
+ for (short i = 0 ; i < 2 ; i++) {
6369
6369
simdgroup_load (mb[i],lsmb + SG_MAT_SIZE * i);
6370
6370
}
6371
6371
6372
6372
lsma += BLOCK_SIZE_M / SG_MAT_ROW * SG_MAT_SIZE;
6373
6373
lsmb += BLOCK_SIZE_N / SG_MAT_ROW * SG_MAT_SIZE;
6374
6374
6375
6375
#pragma unroll(8)
6376
- for (int i = 0 ; i < 8 ; i++){
6376
+ for (short i = 0 ; i < 8 ; i++){
6377
6377
simdgroup_multiply_accumulate (mc[i], mb[i/4 ], ma[i%4 ], mc[i]);
6378
6378
}
6379
6379
}
@@ -6382,7 +6382,7 @@ kernel void kernel_mul_mm(device const uchar * src0,
6382
6382
if ((r0 + 1 ) * BLOCK_SIZE_M <= ne0 && (r1 + 1 ) * BLOCK_SIZE_N <= ne1) {
6383
6383
device float * C = dst + (BLOCK_SIZE_M * r0 + 32 * (sgitg & 1 )) \
6384
6384
+ (BLOCK_SIZE_N * r1 + 16 * (sgitg >> 1 )) * ne0 + im*ne1*ne0;
6385
- for (int i = 0 ; i < 8 ; i++) {
6385
+ for (short i = 0 ; i < 8 ; i++) {
6386
6386
// cast to f32
6387
6387
simdgroup_float8x8 mc_f32 (1 .0f );
6388
6388
simdgroup_multiply (mc_f32, mc[i], mc_f32);
@@ -6394,7 +6394,7 @@ kernel void kernel_mul_mm(device const uchar * src0,
6394
6394
threadgroup_barrier (mem_flags::mem_threadgroup);
6395
6395
threadgroup float * temp_str = ((threadgroup float *)shared_memory) \
6396
6396
+ 32 * (sgitg&1 ) + (16 * (sgitg>>1 )) * BLOCK_SIZE_M;
6397
- for (int i = 0 ; i < 8 ; i++) {
6397
+ for (short i = 0 ; i < 8 ; i++) {
6398
6398
simdgroup_float8x8 mc_f32 (1 .0f );
6399
6399
simdgroup_multiply (mc_f32, mc[i], mc_f32);
6400
6400
simdgroup_store (mc_f32, temp_str + 8 * (i%4 ) + 8 * BLOCK_SIZE_M * (i/4 ), BLOCK_SIZE_M);
0 commit comments