Skip to content

metal : use F16 math in mul_mat kernels #10220

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 1 commit into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion ggml/src/ggml-metal/ggml-metal.m
Original file line number Diff line number Diff line change
Expand Up @@ -2021,7 +2021,8 @@ static void ggml_metal_encode_node(
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:2];
[encoder setBuffer:id_dst offset:offs_dst atIndex:3];

[encoder setThreadgroupMemoryLength:8192 atIndex:0];
[encoder setThreadgroupMemoryLength:8192 atIndex:0];
//[encoder setThreadgroupMemoryLength:4096 + 2048 atIndex:0];
[encoder dispatchThreadgroups:MTLSizeMake( (ne11 + 31)/32, (ne01 + 63)/64, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
} else {
int nth0 = 32;
Expand Down
29 changes: 18 additions & 11 deletions ggml/src/ggml-metal/ggml-metal.metal
Original file line number Diff line number Diff line change
Expand Up @@ -5439,8 +5439,8 @@ kernel void kernel_mul_mm(
ushort tiitg[[thread_index_in_threadgroup]],
ushort sgitg[[simdgroup_index_in_threadgroup]]) {

threadgroup T * sa = (threadgroup T *)(shmem);
threadgroup float * sb = (threadgroup float *)(shmem + 4096);
threadgroup T * sa = (threadgroup T *)(shmem);
threadgroup half * sb = (threadgroup half *)(shmem + 4096);

const int r0 = tgpig.y;
const int r1 = tgpig.x;
Expand All @@ -5454,12 +5454,12 @@ kernel void kernel_mul_mm(
const short thread_row = ((short)tiitg/THREAD_PER_ROW) < n_rows ? ((short)tiitg/THREAD_PER_ROW) : n_rows - 1;
const short thread_col = ((short)tiitg/THREAD_PER_COL) < n_cols ? ((short)tiitg/THREAD_PER_COL) : n_cols - 1;

simdgroup_T8x8 ma[4];
simdgroup_float8x8 mb[2];
simdgroup_float8x8 mc[8];
simdgroup_T8x8 ma[4];
simdgroup_half8x8 mb[2];
simdgroup_half8x8 mc[8];

for (short i = 0; i < 8; i++){
mc[i] = make_filled_simdgroup_matrix<float, 8>(0.f);
mc[i] = make_filled_simdgroup_matrix<half, 8>(0.h);
}

short il = (tiitg % THREAD_PER_ROW);
Expand Down Expand Up @@ -5493,7 +5493,7 @@ kernel void kernel_mul_mm(
+ (tiitg/THREAD_PER_ROW)%8 + (i&7)*8) = temp_a[i/4][i%4];
}

*(threadgroup float2x4 *)(sb + 32*8*(tiitg%THREAD_PER_COL) + 8*(tiitg/THREAD_PER_COL)) = *((device float2x4 *) y);
*(threadgroup half2x4 *)(sb + 32*8*(tiitg%THREAD_PER_COL) + 8*(tiitg/THREAD_PER_COL)) = (half2x4)(*((device float2x4 *)y));

il = (il + 2 < nl) ? il + 2 : il % 2;
x = (il < 2) ? x + (2 + nl - 1)/nl : x;
Expand All @@ -5502,8 +5502,8 @@ kernel void kernel_mul_mm(
threadgroup_barrier(mem_flags::mem_threadgroup);

// load matrices from threadgroup memory and conduct outer products
threadgroup const T * lsma = (sa + THREAD_MAT_M*SG_MAT_SIZE*(sgitg%2));
threadgroup const float * lsmb = (sb + THREAD_MAT_N*SG_MAT_SIZE*(sgitg/2));
threadgroup const T * lsma = (sa + THREAD_MAT_M*SG_MAT_SIZE*(sgitg%2));
threadgroup const half * lsmb = (sb + THREAD_MAT_N*SG_MAT_SIZE*(sgitg/2));

#pragma unroll(4)
for (short ik = 0; ik < BLOCK_SIZE_K/8; ik++) {
Expand Down Expand Up @@ -5535,15 +5535,22 @@ kernel void kernel_mul_mm(
(BLOCK_SIZE_N * r1 + 16*(sgitg >> 1)) * args.ne0 + im*args.ne1*args.ne0;

for (short i = 0; i < 8; i++) {
simdgroup_store(mc[i], C + 8 * (i%4) + 8 * args.ne0 * (i/4), args.ne0);
// cast to f32
simdgroup_float8x8 mc_f32(1.0f);
simdgroup_multiply(mc_f32, mc[i], mc_f32);
simdgroup_store(mc_f32, C + 8 * (i%4) + 8 * args.ne0 * (i/4), args.ne0);
//simdgroup_store(mc[i], C + 8 * (i%4) + 8 * args.ne0 * (i/4), args.ne0);
}
} else {
// block is smaller than 64x32, we should avoid writing data outside of the matrix
threadgroup_barrier(mem_flags::mem_threadgroup);
threadgroup float * temp_str = ((threadgroup float *) shmem) \
+ 32*(sgitg&1) + (16*(sgitg >> 1))*BLOCK_SIZE_M;
for (short i = 0; i < 8; i++) {
simdgroup_store(mc[i], temp_str + 8*(i%4) + 8*BLOCK_SIZE_M*(i/4), BLOCK_SIZE_M);
simdgroup_float8x8 mc_f32(1.0f);
simdgroup_multiply(mc_f32, mc[i], mc_f32);
simdgroup_store(mc_f32, temp_str + 8 * (i%4) + 8 * BLOCK_SIZE_M * (i/4), BLOCK_SIZE_M);
//simdgroup_store(mc[i], temp_str + 8 * (i%4) + 8 * BLOCK_SIZE_M * (i/4), BLOCK_SIZE_M);
}

threadgroup_barrier(mem_flags::mem_threadgroup);
Expand Down
Loading