Commit 8ad92dc

ggml : switch to padded F16 mask for ggml_soft_max, ggml_flash_attn_ext
1 parent 2ddc9bb commit 8ad92dc

7 files changed: 79 additions, 62 deletions

ggml-cuda.cu

Lines changed: 10 additions & 10 deletions
@@ -5917,7 +5917,7 @@ static __global__ void diag_mask_inf_f32(const float * x, float * dst, const int
 }
 
 template <bool vals_smem, int ncols_template, int block_size_template, bool need_check>
-static __global__ void soft_max_f16(const float * x, const float * y, float * dst, const int ncols_par, const int nrows_y, const float scale) {
+static __global__ void soft_max_f16(const float * x, const half * y, float * dst, const int ncols_par, const int nrows_y, const float scale) {
 #if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL && CUDART_VERSION >= CUDART_HMAX
     const int ncols_data = ncols_template == 0 ? ncols_par : ncols_template;
     const int ncols_smem = GGML_PAD(ncols_data, 2*WARP_SIZE)/2;
@@ -5952,12 +5952,12 @@ static __global__ void soft_max_f16(const float * x, const float * y, float * ds
         if (need_check && col_data + 0 >= ncols_data) {
             val.x = -INFINITY;
         } else {
-            val.x = x[ix + 0]*scale + (y ? y[iy + 0] : 0.0f);
+            val.x = x[ix + 0]*scale + (y ? __half2float(y[iy + 0]) : 0.0f);
         }
         if (need_check && col_data + WARP_SIZE >= ncols_data) {
             val.y = -INFINITY;
         } else {
-            val.y = x[ix + WARP_SIZE]*scale + (y ? y[iy + WARP_SIZE] : 0.0f);
+            val.y = x[ix + WARP_SIZE]*scale + (y ? __half2float(y[iy + WARP_SIZE]) : 0.0f);
         }
         if (!need_check || col_smem < (vals_smem ? ncols_smem : ncols_data)) {
             vals[col_smem] = val;
@@ -6047,7 +6047,7 @@ static __global__ void soft_max_f16(const float * x, const float * y, float * ds
 }
 
 template <bool vals_smem, int ncols_template, int block_size_template>
-static __global__ void soft_max_f32(const float * x, const float * y, float * dst, const int ncols_par, const int nrows_y, const float scale) {
+static __global__ void soft_max_f32(const float * x, const half * y, float * dst, const int ncols_par, const int nrows_y, const float scale) {
     const int ncols = ncols_template == 0 ? ncols_par : ncols_template;
 
     const int tid = threadIdx.x;
@@ -6077,7 +6077,7 @@ static __global__ void soft_max_f32(const float * x, const float * y, float * ds
         const int ix = rowx*ncols + col;
         const int iy = rowy*ncols + col;
 
-        const float val = x[ix]*scale + (y ? y[iy] : 0.0f);
+        const float val = x[ix]*scale + (y ? __half2float(y[iy]) : 0.0f);
         vals[col] = val;
         max_val = max(max_val, val);
     }
@@ -7585,7 +7585,7 @@ static void diag_mask_inf_f32_cuda(const float * x, float * dst, const int ncols
     diag_mask_inf_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols_x, rows_per_channel, n_past);
 }
 
-static void soft_max_f16_cuda(const float * x, const float * y, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const float scale, cudaStream_t stream) {
+static void soft_max_f16_cuda(const float * x, const half * y, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const float scale, cudaStream_t stream) {
     int nth = WARP_SIZE;
     while (nth < ncols_x/2 && nth < CUDA_SOFT_MAX_BLOCK_SIZE) nth *= 2;
     const dim3 block_dims(nth, 1, 1);
@@ -7628,7 +7628,7 @@ static void soft_max_f16_cuda(const float * x, const float * y, float * dst, con
     }
 }
 
-static void soft_max_f32_cuda(const float * x, const float * y, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const float scale, cudaStream_t stream) {
+static void soft_max_f32_cuda(const float * x, const half * y, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const float scale, cudaStream_t stream) {
     int nth = WARP_SIZE;
     while (nth < ncols_x && nth < CUDA_SOFT_MAX_BLOCK_SIZE) nth *= 2;
     const dim3 block_dims(nth, 1, 1);
@@ -9060,7 +9060,7 @@ static void ggml_cuda_op_soft_max(
     GGML_ASSERT(src0->type == GGML_TYPE_F32);
     GGML_ASSERT( dst->type == GGML_TYPE_F32);
 
-    GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F32); // src1 contains mask and it is optional
+    GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F16); // src1 contains mask and it is optional
 
     const int64_t ne00    = src0->ne[0];
     const int64_t nrows_x = ggml_nrows(src0);
@@ -9080,9 +9080,9 @@
 #endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) && CUDART_VERSION >= CUDART_HMAX
 
     if (use_f16_soft_max) {
-        soft_max_f16_cuda(src0_dd, src1 ? src1_dd : nullptr, dst_dd, ne00, nrows_x, nrows_y, scale, main_stream);
+        soft_max_f16_cuda(src0_dd, src1 ? (const half *) src1_dd : nullptr, dst_dd, ne00, nrows_x, nrows_y, scale, main_stream);
     } else {
-        soft_max_f32_cuda(src0_dd, src1 ? src1_dd : nullptr, dst_dd, ne00, nrows_x, nrows_y, scale, main_stream);
+        soft_max_f32_cuda(src0_dd, src1 ? (const half *) src1_dd : nullptr, dst_dd, ne00, nrows_x, nrows_y, scale, main_stream);
     }
 
     (void) dst;
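
With this change the CUDA soft_max kernels take the mask as const half * and widen each element with __half2float, so the mask buffer that reaches the backend must already be stored as F16. Below is a minimal, hypothetical C sketch (not part of this commit) of producing one such mask row on the host with the public ggml_fp32_to_fp16_row helper; the function and buffer names are made up for illustration.

#include <math.h>
#include <stdlib.h>
#include "ggml.h"

// Build one F16 mask row: 0.0f for visible KV positions, -INFINITY for
// masked ones. The kernels above read these values and convert them back
// to F32 when applying the mask to the scaled logits.
static void build_f16_mask_row(ggml_fp16_t * dst_f16, int n_kv, int n_visible) {
    float * tmp = malloc(n_kv*sizeof(float));
    if (!tmp) {
        return;
    }
    for (int i = 0; i < n_kv; ++i) {
        tmp[i] = i < n_visible ? 0.0f : -INFINITY;
    }
    ggml_fp32_to_fp16_row(tmp, dst_f16, n_kv); // bulk F32 -> F16 conversion
    free(tmp);
}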

ggml-metal.m

Lines changed: 6 additions & 0 deletions
@@ -1187,6 +1187,8 @@ static bool ggml_metal_graph_compute(
                     } break;
                 case GGML_OP_SOFT_MAX:
                     {
+                        GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F16);
+
                         int nth = 32; // SIMD width
 
                         id<MTLComputePipelineState> pipeline = nil;
@@ -2213,6 +2215,10 @@
 
                         id<MTLBuffer> id_src3 = src3 ? ggml_metal_get_buffer(src3, &offs_src3) : nil;
 
+                        GGML_ASSERT(!src3 || src3->type == GGML_TYPE_F16);
+                        GGML_ASSERT(!src3 || src3->ne[1] >= GGML_PAD(src0->ne[1], 8) &&
+                                "the Flash-Attention Metal kernel requires the mask to be padded to 8 and at least n_queries big");
+
                         const int64_t ne30 = src3 ? src3->ne[0] : 0; GGML_UNUSED(ne30);
                         const int64_t ne31 = src3 ? src3->ne[1] : 0;
                         const int64_t ne32 = src3 ? src3->ne[2] : 0; GGML_UNUSED(ne32);
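
The new assertions state the padding contract on the Metal path: the mask must be F16 and its second dimension must cover the number of query rows rounded up to a multiple of 8, the height of the simdgroup tiles the flash-attention kernel loads. A small hedged C sketch of the same check, reusing GGML_PAD from ggml.h; the variable names are illustrative.

#include <assert.h>
#include <stdint.h>
#include "ggml.h"

// GGML_PAD(x, n) rounds x up to the next multiple of n, so for
// n_batch = 5 query rows the mask must provide at least 8 rows.
static void check_mask_padding(int64_t n_batch, int64_t mask_rows) {
    const int64_t n_batch_pad = GGML_PAD(n_batch, 8); // e.g. GGML_PAD(5, 8) == 8
    assert(mask_rows >= n_batch_pad &&
           "mask must be padded to 8 and at least n_queries big");
    (void) n_batch_pad; (void) mask_rows; // silence unused warnings under NDEBUG
}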

ggml-metal.metal

Lines changed: 19 additions & 21 deletions
@@ -349,9 +349,9 @@ kernel void kernel_sum_rows(
 }
 
 kernel void kernel_soft_max(
-        device const float * src0,
-        device const float * src1,
-        device       float * dst,
+        device const  char * src0,
+        device const  char * src1,
+        device        char * dst,
         constant   int64_t & ne00,
         constant   int64_t & ne01,
         constant   int64_t & ne02,
@@ -366,9 +366,9 @@ kernel void kernel_soft_max(
     const int64_t i02 = (tgpig - i03*ne02*ne01) / ne01;
     const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01);
 
-    device const float * psrc0 = src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
-    device const float * pmask = src1 != src0 ? src1 + i01*ne00 : nullptr;
-    device       float * pdst  = dst  + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
+    device const float * psrc0 = (device const float *) src0 + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
+    device const half  * pmask = src1 != src0 ? (device const half *) src1 + i01*ne00 : nullptr;
+    device       float * pdst  = (device float *) dst + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
 
     // parallel max
     float lmax = -INFINITY;
@@ -435,14 +435,14 @@ kernel void kernel_soft_max(
 }
 
 kernel void kernel_soft_max_4(
-        device const float * src0,
-        device const float * src1,
-        device       float * dst,
+        device const  char * src0,
+        device const  char * src1,
+        device        char * dst,
         constant   int64_t & ne00,
         constant   int64_t & ne01,
         constant   int64_t & ne02,
         constant     float & scale,
-        threadgroup float  * buf [[threadgroup(0)]],
+        threadgroup  float * buf [[threadgroup(0)]],
         uint tgpig[[threadgroup_position_in_grid]],
         uint tpitg[[thread_position_in_threadgroup]],
         uint sgitg[[simdgroup_index_in_threadgroup]],
@@ -452,15 +452,15 @@ kernel void kernel_soft_max_4(
     const int64_t i02 = (tgpig - i03*ne02*ne01) / ne01;
     const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01);
 
-    device const float4 * psrc4 = (device const float4 *)(src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
-    device const float4 * pmask = src1 != src0 ? (device const float4 *)(src1 + i01*ne00) : nullptr;
-    device       float4 * pdst4 = (device float4 *)(dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
+    device const float4 * psrc4 = (device const float4 *) src0 + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00)/4;
+    device const half4  * pmask = src1 != src0 ? (device const half4 *) src1 + i01*ne00/4 : nullptr;
+    device       float4 * pdst4 = (device float4 *) dst + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00)/4;
 
     // parallel max
     float4 lmax4 = -INFINITY;
 
     for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
-        lmax4 = fmax(lmax4, psrc4[i00]*scale + (pmask ? pmask[i00] : 0.0f));
+        lmax4 = fmax(lmax4, psrc4[i00]*scale + (float4) (pmask ? pmask[i00] : 0.0f));
     }
 
     const float lmax = MAX(MAX(lmax4[0], lmax4[1]), MAX(lmax4[2], lmax4[3]));
@@ -486,7 +486,7 @@ kernel void kernel_soft_max_4(
     // parallel sum
     float4 lsum4 = 0.0f;
     for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
-        const float4 exp_psrc4 = exp((psrc4[i00]*scale + (pmask ? pmask[i00] : 0.0f)) - max_val);
+        const float4 exp_psrc4 = exp((psrc4[i00]*scale + (float4) (pmask ? pmask[i00] : 0.0f)) - max_val);
         lsum4 += exp_psrc4;
         pdst4[i00] = exp_psrc4;
     }
@@ -2144,13 +2144,11 @@ kernel void kernel_flash_attn_ext_f16(
         }
     }
 
-    const int64_t ir = iq3*ne02*ne01 + iq2*ne01 + iq1;
-
     // pointer to the mask
-    device const float * mp = (device const float *) (mask + (ir%ne31)*nb31);
+    device const half * mp = (device const half *) (mask + iq1*nb31);
 
     // prepare diagonal scale matrix
-    simdgroup_float8x8 mscale(scale);
+    simdgroup_half8x8 mscale(scale);
 
     // loop over the KV cache
     // each simdgroup handles blocks of Q rows and C columns
@@ -2176,8 +2174,8 @@ kernel void kernel_flash_attn_ext_f16(
 
         // mqk = mqk*scale + mask
         for (int64_t j = 0; j < Q8; ++j) {
-            simdgroup_float8x8 mm;
-            simdgroup_load(mm, mp + 8*j*(nb31/sizeof(float)) + ic + 8*cc, nb31/sizeof(float), 0, false);
+            simdgroup_half8x8 mm;
+            simdgroup_load(mm, mp + 8*j*(nb31/sizeof(half)) + ic + 8*cc, nb31/sizeof(half), 0, false);
             simdgroup_multiply_accumulate(mqk[j], mqk[j], mscale, mm);
 
             simdgroup_store(mqk[j], ss + 8*j*T + 8*cc, T, 0, false);
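
Note how the mask is now addressed: the kernel jumps to the row of query iq1 using the byte stride nb31 and then walks it with an element stride of nb31/sizeof(half), matching the CPU path below, which offsets mask->data by iq1*mask->nb[1] bytes. A tiny standalone C illustration of that stride arithmetic, with made-up sizes and a uint16_t stand-in for half:

#include <stdint.h>
#include <stdio.h>

typedef uint16_t f16_t; // 2-byte stand-in for half / ggml_fp16_t

int main(void) {
    enum { n_kv = 512, n_rows = 8 };            // illustrative mask shape
    static f16_t mask[n_rows*n_kv];
    const size_t nb31 = n_kv*sizeof(f16_t);     // byte stride of one mask row

    const int64_t iq1 = 3;                      // query row handled by this simdgroup
    // byte-offset addressing, as in ggml.c: (char *) data + iq1*nb[1]
    const f16_t * mp = (const f16_t *)((const char *) mask + iq1*nb31);
    // element stride, as in the Metal kernel: nb31/sizeof(half)
    const size_t row_stride = nb31/sizeof(f16_t);

    printf("row %lld starts %zu elements in, element stride %zu\n",
           (long long) iq1, (size_t)(mp - mask), row_stride);
    return 0;
}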

ggml.c

Lines changed: 9 additions & 4 deletions
@@ -5085,6 +5085,7 @@ static struct ggml_tensor * ggml_soft_max_impl(
         bool inplace) {
     GGML_ASSERT(ggml_is_contiguous(a));
     if (mask) {
+        GGML_ASSERT(mask->type == GGML_TYPE_F16);
         GGML_ASSERT(ggml_is_contiguous(mask));
         GGML_ASSERT(mask->ne[2] == 1);
         GGML_ASSERT(mask->ne[3] == 1);
@@ -5854,6 +5855,8 @@ struct ggml_tensor * ggml_flash_attn_ext(
         GGML_ASSERT(ggml_is_contiguous(mask));
         GGML_ASSERT(mask->ne[2] == 1);
         GGML_ASSERT(mask->ne[3] == 1);
+        GGML_ASSERT(mask->ne[1] >= GGML_PAD(q->ne[1], GGML_KQ_MASK_PAD) &&
+                "the Flash-Attention kernel requires the mask to be padded to GGML_KQ_MASK_PAD and at least n_queries big");
         //GGML_ASSERT(ggml_can_repeat_rows(mask, qk));
     }
 
@@ -11552,12 +11555,14 @@ static void ggml_compute_forward_soft_max_f32(
         float * dp = (float *)((char *) dst->data + i1*dst->nb[1]);
 
         // broadcast the mask across rows
-        float * mp = src1 ? (float *)((char *) src1->data + (i1%ne11)*src1->nb[1]) : NULL;
+        ggml_fp16_t * mp = src1 ? (ggml_fp16_t *)((char *) src1->data + (i1%ne11)*src1->nb[1]) : NULL;
 
         ggml_vec_cpy_f32  (nc, wp, sp);
         ggml_vec_scale_f32(nc, wp, scale);
         if (mp) {
-            ggml_vec_acc_f32(nc, wp, mp);
+            for (int i = 0; i < nc; ++i) {
+                wp[i] += GGML_FP16_TO_FP32(mp[i]);
+            }
         }
 
 #ifndef NDEBUG
@@ -13760,7 +13765,7 @@ static void ggml_compute_forward_flash_attn_ext_f16(
 
         memset(V16, 0, D*sizeof(ggml_fp16_t));
 
-        const float * mp = mask ? (float *)((char *) mask->data + (ir%mask->ne[1])*mask->nb[1]) : NULL;
+        const ggml_fp16_t * mp = mask ? (ggml_fp16_t *)((char *) mask->data + iq1*mask->nb[1]) : NULL;
 
         // k indices
        const int ik3 = iq3 / rk3;
@@ -13774,7 +13779,7 @@ static void ggml_compute_forward_flash_attn_ext_f16(
         // loop over n_kv and n_head_kv
         // ref: https://arxiv.org/pdf/2112.05682.pdf
         for (int64_t ic = 0; ic < nek1; ++ic) {
-            const float mv = mp ? mp[ic] : 0.0f;
+            const float mv = mp ? GGML_FP16_TO_FP32(mp[ic]) : 0.0f;
             if (mv == -INFINITY) {
                 continue;
             }
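
On the scalar CPU path the mask is now read as ggml_fp16_t and widened element by element before being added to the scaled logits. Below is a self-contained sketch of that per-row computation, simplified from the code above (it ignores the all-masked edge case); ggml_fp16_to_fp32 is the public counterpart of the internal GGML_FP16_TO_FP32 macro used in the diff.

#include <math.h>
#include "ggml.h"

// Scaled, masked softmax over one row of nc logits, mirroring the updated
// inner loop of ggml_compute_forward_soft_max_f32 (sketch only).
static void soft_max_row_f16_mask(float * wp, const float * sp,
                                  const ggml_fp16_t * mp, int nc, float scale) {
    float max_val = -INFINITY;
    for (int i = 0; i < nc; ++i) {
        wp[i] = sp[i]*scale + (mp ? ggml_fp16_to_fp32(mp[i]) : 0.0f); // F16 mask -> F32
        max_val = fmaxf(max_val, wp[i]);
    }
    float sum = 0.0f;
    for (int i = 0; i < nc; ++i) {
        wp[i] = expf(wp[i] - max_val); // masked positions (-INFINITY) become 0
        sum += wp[i];
    }
    for (int i = 0; i < nc; ++i) {
        wp[i] /= sum;
    }
}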

ggml.h

Lines changed: 7 additions & 5 deletions
@@ -1646,11 +1646,13 @@ extern "C" {
             struct ggml_tensor  * v,
             bool                  masked);
 
-    // q:    [n_embd, n_batch, n_head,    1]
-    // k:    [n_embd, n_kv,    n_head_kv, 1]
-    // v:    [n_embd, n_kv,    n_head_kv, 1] !! not transposed !!
-    // mask: [n_kv,   n_batch, 1,         1]
-    // res:  [n_embd, n_head,  n_batch,   1] !! permuted !!
+#define GGML_KQ_MASK_PAD 32
+
+    // q:    [n_embd, n_batch,     n_head,    1]
+    // k:    [n_embd, n_kv,        n_head_kv, 1]
+    // v:    [n_embd, n_kv,        n_head_kv, 1] !! not transposed !!
+    // mask: [n_kv,   n_batch_pad, 1,         1] !! n_batch_pad = GGML_PAD(n_batch, GGML_KQ_MASK_PAD) !!
+    // res:  [n_embd, n_head,      n_batch,   1] !! permuted !!
     GGML_API struct ggml_tensor * ggml_flash_attn_ext(
             struct ggml_context * ctx,
             struct ggml_tensor  * q,
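
For callers the practical consequence is in the mask tensor itself: it must be GGML_TYPE_F16 and its second dimension padded to GGML_PAD(n_batch, GGML_KQ_MASK_PAD). A hedged C sketch of how a caller might allocate and fill such a mask before calling ggml_flash_attn_ext; the trailing (mask, scale) arguments are assumed from this branch, the context is assumed to allocate tensor data, and the causal fill is illustrative only (padded rows are fully masked and their outputs are discarded).

#include <math.h>
#include "ggml.h"

// Illustrative only: build a padded F16 KQ mask and run flash attention.
static struct ggml_tensor * build_attn(struct ggml_context * ctx,
                                       struct ggml_tensor * q,
                                       struct ggml_tensor * k,
                                       struct ggml_tensor * v,
                                       int64_t n_kv, int64_t n_batch, float scale) {
    // mask: [n_kv, n_batch_pad, 1, 1] with n_batch_pad = GGML_PAD(n_batch, GGML_KQ_MASK_PAD)
    struct ggml_tensor * mask =
        ggml_new_tensor_2d(ctx, GGML_TYPE_F16, n_kv, GGML_PAD(n_batch, GGML_KQ_MASK_PAD));

    ggml_fp16_t * data = (ggml_fp16_t *) mask->data;
    for (int64_t i1 = 0; i1 < mask->ne[1]; ++i1) {
        for (int64_t i0 = 0; i0 < n_kv; ++i0) {
            // real rows get a causal pattern; padded rows stay fully masked
            const float val = (i1 < n_batch && i0 <= i1) ? 0.0f : -INFINITY;
            data[i1*n_kv + i0] = ggml_fp32_to_fp16(val);
        }
    }

    return ggml_flash_attn_ext(ctx, q, k, v, mask, scale);
}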
