Skip to content

Commit fa5f91d

Browse files
committed
metal : extend FA to support different K and V head sizes
ggml-ci
1 parent b3dbc32 commit fa5f91d

File tree

5 files changed

+626
-558
lines changed

5 files changed

+626
-558
lines changed

ggml/include/ggml.h

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1791,11 +1791,11 @@ extern "C" {
17911791

17921792
#define GGML_KQ_MASK_PAD 64
17931793

1794-
// q: [n_embd, n_batch, n_head, 1]
1795-
// k: [n_embd, n_kv, n_head_kv, 1]
1796-
// v: [n_embd, n_kv, n_head_kv, 1] !! not transposed !!
1797-
// mask: [n_kv, n_batch_pad, 1, 1] !! n_batch_pad = GGML_PAD(n_batch, GGML_KQ_MASK_PAD) !!
1798-
// res: [n_embd, n_head, n_batch, 1] !! permuted !!
1794+
// q: [n_embd_k, n_batch, n_head, 1]
1795+
// k: [n_embd_k, n_kv, n_head_kv, 1]
1796+
// v: [n_embd_v, n_kv, n_head_kv, 1] !! not transposed !!
1797+
// mask: [n_kv, n_batch_pad, 1, 1] !! n_batch_pad = GGML_PAD(n_batch, GGML_KQ_MASK_PAD) !!
1798+
// res: [n_embd_v, n_head, n_batch, 1] !! permuted !!
17991799
GGML_API struct ggml_tensor * ggml_flash_attn_ext(
18001800
struct ggml_context * ctx,
18011801
struct ggml_tensor * q,

ggml/src/ggml-metal/ggml-metal-impl.h

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -219,9 +219,12 @@ typedef struct {
219219
int32_t ne11;
220220
int32_t ne_12_2; // assume K and V share dims 2 and 3 (only the head size, dim 0, may differ)
221221
int32_t ne_12_3;
222-
uint64_t nb_12_1;
223-
uint64_t nb_12_2;
224-
uint64_t nb_12_3;
222+
uint64_t nb11;
223+
uint64_t nb12;
224+
uint64_t nb13;
225+
uint64_t nb21;
226+
uint64_t nb22;
227+
uint64_t nb23;
225228
uint64_t nb31;
226229
int32_t ne1;
227230
int32_t ne2;

0 commit comments

Comments
 (0)