Skip to content

Commit 4663bd3

Browse files
authored
metal : use constexpr in FA kernels + fix typedef (#12659)
* metal : use constexpr in FA kernels ggml-ci * cont ggml-ci * cont : fix typedef ggml-ci
1 parent b3de7ca commit 4663bd3

File tree

1 file changed

+15
-14
lines changed

1 file changed

+15
-14
lines changed

ggml/src/ggml-metal/ggml-metal.metal

Lines changed: 15 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -3128,14 +3128,15 @@ kernel void kernel_flash_attn_ext(
31283128
const int iq2 = tgpig[1];
31293129
const int iq1 = tgpig[0]*Q;
31303130

3131-
const short DK4 = DK/4;
3132-
const short DK8 = DK/8;
3133-
const short DK16 = DK/16;
3134-
const short DV4 = DV/4;
3135-
const short DV8 = DV/8;
3136-
const short DV16 = DV/16;
3137-
const short NW = N_SIMDWIDTH;
3138-
const short SH = (2*C + Q); // shared memory per simdgroup (s_t == float)
3131+
constexpr short DK4 = DK/4;
3132+
constexpr short DK8 = DK/8;
3133+
constexpr short DK16 = DK/16;
3134+
constexpr short DV4 = DV/4;
3135+
constexpr short DV8 = DV/8;
3136+
constexpr short DV16 = DV/16;
3137+
3138+
constexpr short NW = N_SIMDWIDTH;
3139+
constexpr short SH = (2*C + Q); // shared memory per simdgroup (s_t == float)
31393140

31403141
const short TS = nsg*SH; // shared memory size per query in (s_t == float)
31413142
const short T = DK + 2*TS; // shared memory size per query in (half)
@@ -3641,11 +3642,11 @@ kernel void kernel_flash_attn_ext_vec(
36413642
const int iq2 = tgpig[1];
36423643
const int iq1 = tgpig[0];
36433644

3644-
const short DK4 = DK/4;
3645-
const short DV4 = DV/4;
3646-
const short NW = N_SIMDWIDTH;
3647-
const short NL = NW/NE; // note: this can be adjusted to support different head sizes and simdgroup work loads
3648-
const short SH = 2*C; // shared memory per simdgroup
3645+
constexpr short DK4 = DK/4;
3646+
constexpr short DV4 = DV/4;
3647+
constexpr short NW = N_SIMDWIDTH;
3648+
constexpr short NL = NW/NE; // note: this can be adjusted to support different head sizes and simdgroup work loads
3649+
constexpr short SH = 2*C; // shared memory per simdgroup
36493650

36503651
const short T = DK + nsg*SH; // shared memory size per query in (half)
36513652

@@ -3956,7 +3957,7 @@ kernel void kernel_flash_attn_ext_vec(
39563957
half, half4, \
39573958
half4
39583959

3959-
typedef decltype(kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 128, 128, 128>) flash_attn_ext_vec_t;
3960+
typedef decltype(kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 128, 128, 4>) flash_attn_ext_vec_t;
39603961

39613962
template [[host_name("kernel_flash_attn_ext_vec_f16_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 128, 128, 4>;
39623963
#if defined(GGML_METAL_USE_BF16)

0 commit comments

Comments
 (0)