Skip to content

Commit c4dff1e

Browse files
committed
metal : reduce registers
1 parent e51778d commit c4dff1e

File tree

2 files changed

+10
-22
lines changed

2 files changed

+10
-22
lines changed

ggml-metal.m

Lines changed: 1 addition & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -179,10 +179,6 @@
179179
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112,
180180
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128,
181181
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256,
182-
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H64,
183-
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H80,
184-
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H96,
185-
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H112,
186182
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128,
187183
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256,
188184
GGML_METAL_KERNEL_TYPE_CPY_F32_F16,
@@ -625,10 +621,6 @@ static void ggml_metal_log(enum ggml_log_level level, const char * format, ...){
625621
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112, flash_attn_ext_f16_h112, true);
626622
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128, flash_attn_ext_f16_h128, true);
627623
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256, flash_attn_ext_f16_h256, true);
628-
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H64, flash_attn_ext_vec_f16_h64, true);
629-
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H80, flash_attn_ext_vec_f16_h80, true);
630-
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H96, flash_attn_ext_vec_f16_h96, true);
631-
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H112, flash_attn_ext_vec_f16_h112, true);
632624
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128, flash_attn_ext_vec_f16_h128, true);
633625
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256, flash_attn_ext_vec_f16_h256, true);
634626
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F16, cpy_f32_f16, true);
@@ -2521,7 +2513,7 @@ static enum ggml_status ggml_metal_graph_compute(
25212513

25222514
id<MTLComputePipelineState> pipeline = nil;
25232515

2524-
if (ne01 > 1) {
2516+
if (ne01 > 1 || (ne00%128 != 0)) {
25252517
switch (ne00) {
25262518
case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64 ].pipeline; break;
25272519
case 80: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80 ].pipeline; break;
@@ -2538,10 +2530,6 @@ static enum ggml_status ggml_metal_graph_compute(
25382530
}
25392531
} else {
25402532
switch (ne00) {
2541-
case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H64 ].pipeline; break;
2542-
case 80: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H80 ].pipeline; break;
2543-
case 96: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H96 ].pipeline; break;
2544-
case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H112].pipeline; break;
25452533
case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128].pipeline; break;
25462534
case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256].pipeline; break;
25472535
default:

ggml-metal.metal

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2516,7 +2516,7 @@ kernel void kernel_flash_attn_ext_vec_f16(
25162516
threadgroup half4 * sr4 = (threadgroup half4 *) (shared + sgitg*D + Q*T); // scratch buffer for the results
25172517

25182518
// store the result for all queries in local memory in 8x8 matrices (the O matrix from the paper)
2519-
half4 lo[Q][D4];
2519+
half4 lo[Q][D4/NW];
25202520

25212521
// load heads from Q to shared memory
25222522
for (short j = sgitg; j < Q; j += nsg) {
@@ -2534,7 +2534,7 @@ kernel void kernel_flash_attn_ext_vec_f16(
25342534
// zero out lo
25352535
for (short j = 0; j < Q; ++j) {
25362536
for (short i = tiisg; i < D4; i += NW) {
2537-
lo[j][i] = 0.0h;
2537+
lo[j][i/NW] = 0.0h;
25382538
}
25392539
}
25402540

@@ -2711,7 +2711,7 @@ kernel void kernel_flash_attn_ext_vec_f16(
27112711

27122712
for (short i = tiisg; i < D4; i += NW) {
27132713
//simdgroup_multiply(lo[j][i], mm, lo[j][i]);
2714-
lo[j][i] = lo[j][i]*mm;
2714+
lo[j][i/NW] = lo[j][i/NW]*mm;
27152715
}
27162716
}
27172717

@@ -2722,7 +2722,7 @@ kernel void kernel_flash_attn_ext_vec_f16(
27222722

27232723
for (short i = tiisg; i < D4; i += NW) {
27242724
for (short j = 0; j < Q; ++j) {
2725-
lo[j][i] += pv4[i]*ss[j*T + cc];
2725+
lo[j][i/NW] += pv4[i]*ss[j*T + cc];
27262726
}
27272727
}
27282728
}
@@ -2743,7 +2743,7 @@ kernel void kernel_flash_attn_ext_vec_f16(
27432743
// store results to shared memory
27442744
for (short j = 0; j < Q; ++j) {
27452745
for (short i = tiisg; i < D4; i += NW) {
2746-
sr4[i] = lo[j][i];
2746+
sr4[i] = lo[j][i/NW];
27472747
}
27482748
}
27492749

@@ -2805,10 +2805,10 @@ kernel void kernel_flash_attn_ext_vec_f16(
28052805
}
28062806
}
28072807

2808-
template [[host_name("kernel_flash_attn_ext_vec_f16_h64" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<64, 1, 32>;
2809-
template [[host_name("kernel_flash_attn_ext_vec_f16_h80" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<80, 1, 32>;
2810-
template [[host_name("kernel_flash_attn_ext_vec_f16_h96" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<96, 1, 32>;
2811-
template [[host_name("kernel_flash_attn_ext_vec_f16_h112")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<112, 1, 32>;
2808+
template [[host_name("kernel_flash_attn_ext_vec_f16_h64" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<128, 2, 32>;
2809+
template [[host_name("kernel_flash_attn_ext_vec_f16_h80" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<128, 3, 32>;
2810+
template [[host_name("kernel_flash_attn_ext_vec_f16_h96" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<128, 4, 32>;
2811+
template [[host_name("kernel_flash_attn_ext_vec_f16_h112")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<128, 5, 32>;
28122812
template [[host_name("kernel_flash_attn_ext_vec_f16_h128")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<128, 1, 32>;
28132813
template [[host_name("kernel_flash_attn_ext_vec_f16_h256")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<256, 1, 32>;
28142814

0 commit comments

Comments
 (0)