Skip to content

Commit e40b85a

Browse files
committed
metal : use F16 math in mul_mat kernels
ggml-ci
1 parent 841f27a commit e40b85a

File tree

2 files changed

+22
-15
lines changed

2 files changed

+22
-15
lines changed

ggml/src/ggml-metal.m

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1965,7 +1965,7 @@ static void ggml_metal_encode_node(
19651965
[encoder setBytes:&ne1 length:sizeof(ne1) atIndex:14];
19661966
[encoder setBytes:&r2 length:sizeof(r2) atIndex:15];
19671967
[encoder setBytes:&r3 length:sizeof(r3) atIndex:16];
1968-
[encoder setThreadgroupMemoryLength:8192 atIndex:0];
1968+
[encoder setThreadgroupMemoryLength:(4096 + 2048) atIndex:0];
19691969
[encoder dispatchThreadgroups:MTLSizeMake( (ne11 + 31)/32, (ne01 + 63)/64, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
19701970
} else {
19711971
int nth0 = 32;

ggml/src/ggml-metal.metal

Lines changed: 21 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -6295,8 +6295,8 @@ kernel void kernel_mul_mm(device const uchar * src0,
62956295
uint tiitg[[thread_index_in_threadgroup]],
62966296
uint sgitg[[simdgroup_index_in_threadgroup]]) {
62976297

6298-
threadgroup T * sa = (threadgroup T *)(shared_memory);
6299-
threadgroup float * sb = (threadgroup float *)(shared_memory + 4096);
6298+
threadgroup T * sa = (threadgroup T *)(shared_memory);
6299+
threadgroup half * sb = (threadgroup half *)(shared_memory + 4096);
63006300

63016301
const uint r0 = tgpig.y;
63026302
const uint r1 = tgpig.x;
@@ -6310,11 +6310,11 @@ kernel void kernel_mul_mm(device const uchar * src0,
63106310
short thread_row = ((short)tiitg/THREAD_PER_ROW) < n_rows ? ((short)tiitg/THREAD_PER_ROW) : n_rows - 1;
63116311
short thread_col = ((short)tiitg/THREAD_PER_COL) < n_cols ? ((short)tiitg/THREAD_PER_COL) : n_cols - 1;
63126312

6313-
simdgroup_T8x8 ma[4];
6314-
simdgroup_float8x8 mb[2];
6315-
simdgroup_float8x8 c_res[8];
6313+
simdgroup_T8x8 ma[4];
6314+
simdgroup_half8x8 mb[2];
6315+
simdgroup_half8x8 mc[8];
63166316
for (int i = 0; i < 8; i++){
6317-
c_res[i] = make_filled_simdgroup_matrix<float, 8>(0.f);
6317+
mc[i] = make_filled_simdgroup_matrix<half, 8>(0.h);
63186318
}
63196319

63206320
short il = (tiitg % THREAD_PER_ROW);
@@ -6345,17 +6345,17 @@ kernel void kernel_mul_mm(device const uchar * src0,
63456345
+ (tiitg / THREAD_PER_ROW) % 8 + (i & 7) * 8) = temp_a[i/4][i%4];
63466346
}
63476347

6348-
*(threadgroup float2x4 *)(sb + (tiitg % THREAD_PER_COL) * 8 * 32 + 8 * (tiitg / THREAD_PER_COL)) = *((device float2x4 *)y);
6348+
*(threadgroup half2x4 *)(sb + (tiitg % THREAD_PER_COL) * 8 * 32 + 8 * (tiitg / THREAD_PER_COL)) = (half2x4)(*((device float2x4 *)y));
63496349

63506350
il = (il + 2 < nl) ? il + 2 : il % 2;
6351-
x = (il < 2) ? x + (2+nl-1)/nl : x;
6351+
x = (il < 2) ? x + (2 + nl - 1)/nl : x;
63526352
y += BLOCK_SIZE_K;
63536353

63546354
threadgroup_barrier(mem_flags::mem_threadgroup);
63556355

63566356
// load matrices from threadgroup memory and conduct outer products
6357-
threadgroup T * lsma = (sa + THREAD_MAT_M * SG_MAT_SIZE * (sgitg % 2));
6358-
threadgroup float * lsmb = (sb + THREAD_MAT_N * SG_MAT_SIZE * (sgitg / 2));
6357+
threadgroup T * lsma = (sa + THREAD_MAT_M * SG_MAT_SIZE * (sgitg % 2));
6358+
threadgroup half * lsmb = (sb + THREAD_MAT_N * SG_MAT_SIZE * (sgitg / 2));
63596359

63606360
#pragma unroll(4)
63616361
for (int ik = 0; ik < BLOCK_SIZE_K / 8; ik++) {
@@ -6374,7 +6374,7 @@ kernel void kernel_mul_mm(device const uchar * src0,
63746374

63756375
#pragma unroll(8)
63766376
for (int i = 0; i < 8; i++){
6377-
simdgroup_multiply_accumulate(c_res[i], mb[i/4], ma[i%4], c_res[i]);
6377+
simdgroup_multiply_accumulate(mc[i], mb[i/4], ma[i%4], mc[i]);
63786378
}
63796379
}
63806380
}
@@ -6383,15 +6383,22 @@ kernel void kernel_mul_mm(device const uchar * src0,
63836383
device float * C = dst + (BLOCK_SIZE_M * r0 + 32 * (sgitg & 1)) \
63846384
+ (BLOCK_SIZE_N * r1 + 16 * (sgitg >> 1)) * ne0 + im*ne1*ne0;
63856385
for (int i = 0; i < 8; i++) {
6386-
simdgroup_store(c_res[i], C + 8 * (i%4) + 8 * ne0 * (i/4), ne0);
6386+
// cast to f32
6387+
simdgroup_float8x8 mc_f32(1.0f);
6388+
simdgroup_multiply(mc_f32, mc[i], mc_f32);
6389+
simdgroup_store(mc_f32, C + 8 * (i%4) + 8 * ne0 * (i/4), ne0);
6390+
//simdgroup_store(mc[i], C + 8 * (i%4) + 8 * ne0 * (i/4), ne0);
63876391
}
63886392
} else {
63896393
// block is smaller than 64x32, we should avoid writing data outside of the matrix
63906394
threadgroup_barrier(mem_flags::mem_threadgroup);
63916395
threadgroup float * temp_str = ((threadgroup float *)shared_memory) \
6392-
+ 32 * (sgitg&1) + (16 * (sgitg>>1)) * BLOCK_SIZE_M;
6396+
+ 32 * (sgitg&1) + (16 * (sgitg>>1)) * BLOCK_SIZE_M;
63936397
for (int i = 0; i < 8; i++) {
6394-
simdgroup_store(c_res[i], temp_str + 8 * (i%4) + 8 * BLOCK_SIZE_M * (i/4), BLOCK_SIZE_M);
6398+
simdgroup_float8x8 mc_f32(1.0f);
6399+
simdgroup_multiply(mc_f32, mc[i], mc_f32);
6400+
simdgroup_store(mc_f32, temp_str + 8 * (i%4) + 8 * BLOCK_SIZE_M * (i/4), BLOCK_SIZE_M);
6401+
//simdgroup_store(mc[i], temp_str + 8 * (i%4) + 8 * BLOCK_SIZE_M * (i/4), BLOCK_SIZE_M);
63956402
}
63966403

63976404
threadgroup_barrier(mem_flags::mem_threadgroup);

0 commit comments

Comments
 (0)