Skip to content

Commit dd0d9ed

Browse files
committed
metal : clean-up
1 parent 13b87f2 commit dd0d9ed

File tree

2 files changed

+25
-59
lines changed

2 files changed

+25
-59
lines changed

ggml/src/ggml-metal.m

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3139,7 +3139,6 @@ static void ggml_metal_encode_node(
31393139
if (!use_vec_kernel) {
31403140
// half8x8 kernel
31413141
const int64_t nqptg = 8; // queries per threadgroup !! sync with kernel template arguments !!
3142-
const int64_t nkpsg = 8; // keys per simdgroup
31433142
const int64_t ncpsg = 32; // cache values per simdgroup !! sync with kernel template arguments !!
31443143

31453144
GGML_ASSERT(nqptg <= 32);
@@ -3149,7 +3148,9 @@ static void ggml_metal_encode_node(
31493148
int64_t nsgmax = 2;
31503149

31513150
while (true) {
3152-
const size_t smem = (nqptg*(ne00 + 2*nsgmax*(ncpsg + nqptg)) + 4*16*nkpsg*nsgmax)*(sizeof(float)/2);
3151+
// 16*32*nsgmax - the shared memory needed for the simdgroups to load the KV cache
3152+
// each thread loads (dequantizes) 16 head elements, there are 32 threads in th SG
3153+
const size_t smem = (nqptg*(ne00 + 2*nsgmax*(ncpsg + nqptg)) + 16*32*nsgmax)*(sizeof(float)/2);
31533154
if (smem > device.maxThreadgroupMemoryLength) {
31543155
break;
31553156
}
@@ -3160,12 +3161,12 @@ static void ggml_metal_encode_node(
31603161
// simdgroups per threadgroup (a.k.a. warps)
31613162
const int64_t nsg = ne01 <= nqptg ? MAX(4, MIN(nsgmax, MIN(ne11/ncpsg, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32))) : 4;
31623163

3163-
const size_t smem = (nqptg*(ne00 + 2*nsg*(ncpsg + nqptg)) + 4*16*nkpsg*nsg)*(sizeof(float)/2);
3164+
const size_t smem = (nqptg*(ne00 + 2*nsg*(ncpsg + nqptg)) + 16*32*nsg)*(sizeof(float)/2);
31643165

31653166
//printf("smem: %zu, max: %zu, nsg = %d\n", smem, device.maxThreadgroupMemoryLength, (int) nsg);
31663167
GGML_ASSERT(smem <= device.maxThreadgroupMemoryLength);
31673168

3168-
[encoder setThreadgroupMemoryLength:GGML_PAD(smem, 128) atIndex:0];
3169+
[encoder setThreadgroupMemoryLength:GGML_PAD(smem, 16) atIndex:0];
31693170

31703171
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + nqptg - 1)/nqptg, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)];
31713172
} else {

ggml/src/ggml-metal.metal

Lines changed: 20 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -2723,45 +2723,9 @@ kernel void kernel_leaky_relu_f32(
27232723
dst[tpig] = src0[tpig] > 0.0f ? src0[tpig] : src0[tpig] * slope;
27242724
}
27252725

2726-
typedef void (flash_attn_ext_t)(
2727-
device const char * q,
2728-
device const char * k,
2729-
device const char * v,
2730-
device const char * mask,
2731-
device float * dst,
2732-
constant int64_t & ne01,
2733-
constant int64_t & ne02,
2734-
constant int64_t & ne03,
2735-
constant uint64_t & nb01,
2736-
constant uint64_t & nb02,
2737-
constant uint64_t & nb03,
2738-
constant int64_t & ne11,
2739-
constant int64_t & ne12,
2740-
constant int64_t & ne13,
2741-
constant uint64_t & nb11,
2742-
constant uint64_t & nb12,
2743-
constant uint64_t & nb13,
2744-
constant uint64_t & nb21,
2745-
constant uint64_t & nb22,
2746-
constant uint64_t & nb23,
2747-
constant uint64_t & nb31,
2748-
constant int64_t & ne1,
2749-
constant int64_t & ne2,
2750-
constant float & scale,
2751-
constant float & max_bias,
2752-
constant float & m0,
2753-
constant float & m1,
2754-
constant uint32_t & n_head_log2,
2755-
constant float & logit_softcap,
2756-
threadgroup half * shared,
2757-
uint3 tgpig[[threadgroup_position_in_grid]],
2758-
uint3 tpitg[[thread_position_in_threadgroup]],
2759-
uint3 ntg[[threads_per_threadgroup]],
2760-
ushort tiisg[[thread_index_in_simdgroup]],
2761-
ushort sgitg[[simdgroup_index_in_threadgroup]]);
2762-
27632726
// ref: https://arxiv.org/pdf/2307.08691.pdf
2764-
template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread half4x4 &), short D, short Q = 8, short K = 8, short C = 32> // head size, queries per threadgroup, cache items per threadgroup
2727+
// D - head size, Q - queries per threadgroup, KV - key/value processed per each simdgroup, C - cache items per threadgroup
2728+
template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread half4x4 &), short D, short Q = 8, short KV = 8, short C = 32>
27652729
kernel void kernel_flash_attn_ext(
27662730
device const char * q,
27672731
device const char * k,
@@ -2818,8 +2782,8 @@ kernel void kernel_flash_attn_ext(
28182782
threadgroup half4 * sq4 = (threadgroup half4 *) (shared + 0*D); // same as above but in half4
28192783
threadgroup float * ss = (threadgroup float *) (shared + 2*sgitg*SH + 1*D); // scratch buffer for attention and diagonal matrix
28202784

2821-
threadgroup half * skv = (threadgroup half *) (shared + sgitg*(4*16*K) + Q*T); // scratch buffer to load K and V in shared memory
2822-
threadgroup half4x4 * skv4 = (threadgroup half4x4 *) (shared + sgitg*(4*16*K) + Q*T); // same as above but in half4x4
2785+
threadgroup half * skv = (threadgroup half *) (shared + sgitg*(4*16*KV) + Q*T); // scratch buffer to load K and V in shared memory
2786+
threadgroup half4x4 * skv4 = (threadgroup half4x4 *) (shared + sgitg*(4*16*KV) + Q*T); // same as above but in half4x4
28232787

28242788
// store the result for all queries in local memory in 8x8 matrices (the O matrix from the paper)
28252789
simdgroup_half8x8 lo[D8];
@@ -3179,6 +3143,8 @@ kernel void kernel_flash_attn_ext(
31793143
}
31803144
}
31813145

3146+
typedef decltype(kernel_flash_attn_ext<half4x4, 1, dequantize_f16, 64>) flash_attn_ext_t;
3147+
31823148
template [[host_name("kernel_flash_attn_ext_f16_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<half4x4, 1, dequantize_f16, 64>;
31833149
template [[host_name("kernel_flash_attn_ext_f16_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<half4x4, 1, dequantize_f16, 80>;
31843150
template [[host_name("kernel_flash_attn_ext_f16_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<half4x4, 1, dequantize_f16, 96>;
@@ -3223,7 +3189,7 @@ template [[host_name("kernel_flash_attn_ext_q8_0_h256")]] kernel flash_attn_ext_
32233189

32243190
// D - head size, Q - queries per threadgroup, C - cache items per threadgroup
32253191
template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread float4x4 &), short D, short Q = 1, short C = 32>
3226-
kernel void flash_attn_ext_vec(
3192+
kernel void kernel_flash_attn_ext_vec(
32273193
device const char * q,
32283194
device const char * k,
32293195
device const char * v,
@@ -3548,22 +3514,21 @@ kernel void flash_attn_ext_vec(
35483514
}
35493515
}
35503516

3551-
//template [[host_name("kernel_flash_attn_ext_vec_f16_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext_vec_f16<128>;
3552-
//template [[host_name("kernel_flash_attn_ext_vec_f16_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext_vec_f16<256>;
3517+
typedef decltype(kernel_flash_attn_ext_vec<half4x4, 1, dequantize_f16, 64>) flash_attn_ext_vec_t;
35533518

3554-
template [[host_name("kernel_flash_attn_ext_vec_f16_h128")]] kernel flash_attn_ext_t flash_attn_ext_vec<half4x4, 1, dequantize_f16, 128>;
3555-
template [[host_name("kernel_flash_attn_ext_vec_q4_0_h128")]] kernel flash_attn_ext_t flash_attn_ext_vec<block_q4_0, 2, dequantize_q4_0, 128>;
3556-
template [[host_name("kernel_flash_attn_ext_vec_q4_1_h128")]] kernel flash_attn_ext_t flash_attn_ext_vec<block_q4_1, 2, dequantize_q4_1, 128>;
3557-
template [[host_name("kernel_flash_attn_ext_vec_q5_0_h128")]] kernel flash_attn_ext_t flash_attn_ext_vec<block_q5_0, 2, dequantize_q5_0, 128>;
3558-
template [[host_name("kernel_flash_attn_ext_vec_q5_1_h128")]] kernel flash_attn_ext_t flash_attn_ext_vec<block_q5_1, 2, dequantize_q5_1, 128>;
3559-
template [[host_name("kernel_flash_attn_ext_vec_q8_0_h128")]] kernel flash_attn_ext_t flash_attn_ext_vec<block_q8_0, 2, dequantize_q8_0, 128>;
3519+
template [[host_name("kernel_flash_attn_ext_vec_f16_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<half4x4, 1, dequantize_f16, 128>;
3520+
template [[host_name("kernel_flash_attn_ext_vec_q4_0_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<block_q4_0, 2, dequantize_q4_0, 128>;
3521+
template [[host_name("kernel_flash_attn_ext_vec_q4_1_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<block_q4_1, 2, dequantize_q4_1, 128>;
3522+
template [[host_name("kernel_flash_attn_ext_vec_q5_0_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<block_q5_0, 2, dequantize_q5_0, 128>;
3523+
template [[host_name("kernel_flash_attn_ext_vec_q5_1_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<block_q5_1, 2, dequantize_q5_1, 128>;
3524+
template [[host_name("kernel_flash_attn_ext_vec_q8_0_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<block_q8_0, 2, dequantize_q8_0, 128>;
35603525

3561-
template [[host_name("kernel_flash_attn_ext_vec_f16_h256")]] kernel flash_attn_ext_t flash_attn_ext_vec<half4x4, 1, dequantize_f16, 256>;
3562-
template [[host_name("kernel_flash_attn_ext_vec_q4_0_h256")]] kernel flash_attn_ext_t flash_attn_ext_vec<block_q4_0, 2, dequantize_q4_0, 256>;
3563-
template [[host_name("kernel_flash_attn_ext_vec_q4_1_h256")]] kernel flash_attn_ext_t flash_attn_ext_vec<block_q4_1, 2, dequantize_q4_1, 256>;
3564-
template [[host_name("kernel_flash_attn_ext_vec_q5_0_h256")]] kernel flash_attn_ext_t flash_attn_ext_vec<block_q5_0, 2, dequantize_q5_0, 256>;
3565-
template [[host_name("kernel_flash_attn_ext_vec_q5_1_h256")]] kernel flash_attn_ext_t flash_attn_ext_vec<block_q5_1, 2, dequantize_q5_1, 256>;
3566-
template [[host_name("kernel_flash_attn_ext_vec_q8_0_h256")]] kernel flash_attn_ext_t flash_attn_ext_vec<block_q8_0, 2, dequantize_q8_0, 256>;
3526+
template [[host_name("kernel_flash_attn_ext_vec_f16_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<half4x4, 1, dequantize_f16, 256>;
3527+
template [[host_name("kernel_flash_attn_ext_vec_q4_0_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<block_q4_0, 2, dequantize_q4_0, 256>;
3528+
template [[host_name("kernel_flash_attn_ext_vec_q4_1_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<block_q4_1, 2, dequantize_q4_1, 256>;
3529+
template [[host_name("kernel_flash_attn_ext_vec_q5_0_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<block_q5_0, 2, dequantize_q5_0, 256>;
3530+
template [[host_name("kernel_flash_attn_ext_vec_q5_1_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<block_q5_1, 2, dequantize_q5_1, 256>;
3531+
template [[host_name("kernel_flash_attn_ext_vec_q8_0_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<block_q8_0, 2, dequantize_q8_0, 256>;
35673532

35683533
template<typename T0, typename T1>
35693534
kernel void kernel_cpy(

0 commit comments

Comments
 (0)