Skip to content

Commit 08d5ccb

Browse files
ggerganovarthw
authored andcommitted
metal : fix build and some more comments (ggml-org#10229)
1 parent 5cabf58 commit 08d5ccb

File tree

2 files changed

+6
-4
lines changed

2 files changed

+6
-4
lines changed

ggml/src/ggml-metal.m

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3041,6 +3041,8 @@ static void ggml_metal_encode_node(
30413041

30423042
bool use_vec_kernel = false;
30433043

3044+
// TODO: add vec kernels for (ne00%64 == 0) and maybe also for (ne00%32 == 0)
3045+
// for now avoiding mainly to keep the number of templates/kernels a bit lower
30443046
if (ne01 >= 4 || (ne00%128 != 0)) {
30453047
switch (src1->type) {
30463048
case GGML_TYPE_F16:

ggml/src/ggml-metal.metal

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3356,8 +3356,8 @@ kernel void kernel_flash_attn_ext_vec(
33563356
const short D4 = D/4;
33573357
const short D16 = D/16;
33583358
const short NW = N_SIMDWIDTH;
3359-
const short NL = NW/4;
3360-
const short SH = 2*C; // shared memory per simdgroup
3359+
const short NL = NW/4; // note: this can be adjusted to support D%64 == 0 and D%32 == 0
3360+
const short SH = 2*C; // shared memory per simdgroup
33613361

33623362
const short T = D + nsg*SH; // shared memory size per query in (half)
33633363

@@ -3448,7 +3448,7 @@ kernel void kernel_flash_attn_ext_vec(
34483448

34493449
// Q*K^T
34503450
{
3451-
// each simdgroup processes 1 query and 4 keys
3451+
// each simdgroup processes 1 query and 4 (NW/NL) keys
34523452
for (short cc = 0; cc < C/4; ++cc) {
34533453
qk_t mqka[4] = { 0.0, 0.0, 0.0, 0.0 };
34543454

@@ -3646,7 +3646,7 @@ kernel void kernel_flash_attn_ext_vec(
36463646
half, half4, half4x4, \
36473647
half4x4
36483648

3649-
typedef decltype(kernel_flash_attn_ext_vec<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 64>) flash_attn_ext_vec_t;
3649+
typedef decltype(kernel_flash_attn_ext_vec<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 128>) flash_attn_ext_vec_t;
36503650

36513651
template [[host_name("kernel_flash_attn_ext_vec_f16_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 128>;
36523652
#if defined(GGML_METAL_USE_BF16)

0 commit comments

Comments
 (0)