Commit c1c6d84

ggml: implement quantized KV cache for FA
1 parent 059031b commit c1c6d84


ggml.c

Lines changed: 130 additions & 41 deletions
@@ -15883,8 +15883,6 @@ static void ggml_compute_forward_flash_attn_ext_f16(
     GGML_ASSERT(ne2 == N);
 
     GGML_ASSERT(nbq0 == sizeof(float));
-    GGML_ASSERT(nbk0 == sizeof(ggml_fp16_t));
-    GGML_ASSERT(nbv0 == sizeof(ggml_fp16_t));
 
     GGML_ASSERT(neq0 == D);
     GGML_ASSERT(nek0 == D);
@@ -15945,17 +15943,47 @@ static void ggml_compute_forward_flash_attn_ext_f16(
         const int iq2 = (ir - iq3*neq2*neq1)/neq1;
         const int iq1 = (ir - iq3*neq2*neq1 - iq2*neq1);
 
-        const uint32_t h = iq2; // head
+        const uint32_t h = iq2; // head index
         const float slope = (max_bias > 0.0f) ? h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2*(h - n_head_log2) + 1) : 1.0f;
 
-        float S = 0.0f;
-        float M = -INFINITY;
+        float S = 0.0f;      // sum
+        float M = -INFINITY; // maximum KQ value
+
+        float       * VKQ32 = (float       *) params->wdata + ith*(3*D + CACHE_LINE_SIZE_F32); // FP32 VKQ accumulator
+        float       * V32   =                 (VKQ32 + 1*D); // (temporary) FP32 V buffer
+        ggml_fp16_t * VKQ16 = (ggml_fp16_t *) (VKQ32 + 1*D); // (temporary) FP16 VKQ accumulator
+        ggml_fp16_t * Q16   = (ggml_fp16_t *) (VKQ32 + 2*D); // (temporary) Q buffer
 
-        float * V32 = (float *) params->wdata + ith*(2*D + CACHE_LINE_SIZE_F32);
-        ggml_fp16_t * Q16 = (ggml_fp16_t *) (V32); // reuse memory
-        ggml_fp16_t * V16 = (ggml_fp16_t *) (V32 + D);
+        ggml_to_float_t v_to_float = NULL;
 
-        memset(V16, 0, D*sizeof(ggml_fp16_t));
+        switch (v->type) {
+            case GGML_TYPE_F16: {
+                memset(VKQ16, 0, D*sizeof(ggml_fp16_t));
+            } break;
+            case GGML_TYPE_Q8_0: {
+                v_to_float = (ggml_to_float_t) dequantize_row_q8_0;
+                memset(VKQ32, 0, D*sizeof(float));
+            } break;
+            case GGML_TYPE_Q5_1: {
+                v_to_float = (ggml_to_float_t) dequantize_row_q5_1;
+                memset(VKQ32, 0, D*sizeof(float));
+            } break;
+            case GGML_TYPE_Q5_0: {
+                v_to_float = (ggml_to_float_t) dequantize_row_q5_0;
+                memset(VKQ32, 0, D*sizeof(float));
+            } break;
+            case GGML_TYPE_Q4_1: {
+                v_to_float = (ggml_to_float_t) dequantize_row_q4_1;
+                memset(VKQ32, 0, D*sizeof(float));
+            } break;
+            case GGML_TYPE_Q4_0: {
+                v_to_float = (ggml_to_float_t) dequantize_row_q4_0;
+                memset(VKQ32, 0, D*sizeof(float));
+            } break;
+            default: {
+                GGML_ASSERT(false);
+            } break;
+        }
 
         const ggml_fp16_t * mp = mask ? (ggml_fp16_t *)((char *) mask->data + iq1*mask->nb[1]) : NULL;
 
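Note on the scratch buffers introduced above: each thread now gets 3*D floats (plus cache-line padding) of wdata, split into an always-live FP32 accumulator, a region shared by the FP32 V scratch and the FP16 accumulator (only one of the two is used, depending on v->type), and a buffer holding the current Q row converted to F16 or q8_x. Below is a stand-alone sketch of that layout; ggml_fp16_t and CACHE_LINE_SIZE_F32 are replaced by stand-ins to keep it self-contained, so this is an illustration, not the ggml code itself.

#include <stdint.h>

typedef uint16_t fp16_sketch_t;              // stand-in for ggml_fp16_t
enum { CACHE_LINE_SIZE_F32_SKETCH = 16 };    // stand-in for ggml's constant

typedef struct {
    float         * VKQ32; // [0*D, 1*D): FP32 V*KQ accumulator, always used
    float         * V32;   // [1*D, 2*D): one dequantized V row (quantized V only)
    fp16_sketch_t * VKQ16; // [1*D, 2*D): FP16 accumulator, aliases V32 (F16 V only)
    fp16_sketch_t * Q16;   // [2*D, 3*D): Q row converted to F16 or q8_x
} fa_scratch_sketch;

static fa_scratch_sketch fa_scratch_layout(float * wdata, int ith, int64_t D) {
    float * base = wdata + ith*(3*D + CACHE_LINE_SIZE_F32_SKETCH);
    fa_scratch_sketch s;
    s.VKQ32 = base;
    s.V32   = base + D;
    s.VKQ16 = (fp16_sketch_t *)(base + D);   // only one of V32/VKQ16 is live at a time
    s.Q16   = (fp16_sketch_t *)(base + 2*D);
    return s;
}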
@@ -15967,6 +15995,30 @@ static void ggml_compute_forward_flash_attn_ext_f16(
         const int iv3 = iq3 / rv3;
         const int iv2 = iq2 / rv2;
 
+        const float * pq = (const float *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3));
+        switch (k->type) {
+            case GGML_TYPE_F16: {
+                // convert Q to F16 in Q16
+                for (int64_t d = 0; d < D; ++d) {
+                    Q16[d] = GGML_FP32_TO_FP16(pq[d]);
+                }
+            } break;
+            case GGML_TYPE_Q8_0:
+            case GGML_TYPE_Q5_0:
+            case GGML_TYPE_Q4_0: {
+                // convert Q to q8_0 in Q16
+                quantize_row_q8_0(pq, Q16, D);
+            } break;
+            case GGML_TYPE_Q5_1:
+            case GGML_TYPE_Q4_1: {
+                // convert Q to q8_1 in Q16
+                quantize_row_q8_1(pq, Q16, D);
+            } break;
+            default: {
+                GGML_ASSERT(false && "Unsupported k type.");
+            } break;
+        }
+
         // online softmax / attention
         // loop over n_kv and n_head_kv
         // ref: https://arxiv.org/pdf/2112.05682.pdf
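The reason Q is converted per K type: the quantized dot kernels used further down (ggml_vec_dot_q*_q8_*) expect both operands in block form, so the FP32 Q row is quantized once per query row into q8_0 or q8_1 and then reused against every K row. As a rough, hypothetical sketch of what the symmetric q8_0 x q8_0 case computes (the real ggml blocks store the scale as ggml_fp16_t and the kernels are SIMD-optimized):

#include <stdint.h>

#define QK8_0 32

typedef struct {
    float  d;           // per-block scale (ggml stores this as ggml_fp16_t)
    int8_t qs[QK8_0];   // 32 signed 8-bit quants
} block_q8_0_sketch;

// Dot product over n values (n a multiple of 32): integer multiply-accumulate
// within each block, then scale by the product of the two block scales.
static float vec_dot_q8_0_q8_0_sketch(int n, const block_q8_0_sketch * x, const block_q8_0_sketch * y) {
    float sum = 0.0f;
    for (int i = 0; i < n/QK8_0; ++i) {
        int isum = 0;
        for (int j = 0; j < QK8_0; ++j) {
            isum += x[i].qs[j] * y[i].qs[j];
        }
        sum += x[i].d * y[i].d * (float) isum;
    }
    return sum;
}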
@@ -15976,52 +16028,89 @@ static void ggml_compute_forward_flash_attn_ext_f16(
                 continue;
             }
 
-            float s;
+            float s; // KQ value
 
-            // convert Q to F16 in V32
-            {
-                const float * pq = (const float *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3));
-
-                for (int64_t d = 0; d < D; ++d) {
-                    Q16[d] = GGML_FP32_TO_FP16(pq[d]);
-                }
+            char * k_data = (char *) k->data + ( ic*nbk1 + ik2*nbk2 + ik3*nbk3);
+            switch (k->type) {
+                case GGML_TYPE_F16: {
+                    ggml_vec_dot_f16(D, &s, 0, (ggml_fp16_t *) k_data, 0, Q16, 0, 1);
+                } break;
+                case GGML_TYPE_Q8_0: {
+                    ggml_vec_dot_q8_0_q8_0(D, &s, 0, k_data, 0, Q16, 0, 1);
+                } break;
+                case GGML_TYPE_Q5_1: {
+                    ggml_vec_dot_q5_1_q8_1(D, &s, 0, k_data, 0, Q16, 0, 1);
+                } break;
+                case GGML_TYPE_Q5_0: {
+                    ggml_vec_dot_q5_0_q8_0(D, &s, 0, k_data, 0, Q16, 0, 1);
+                } break;
+                case GGML_TYPE_Q4_1: {
+                    ggml_vec_dot_q4_1_q8_1(D, &s, 0, k_data, 0, Q16, 0, 1);
+                } break;
+                case GGML_TYPE_Q4_0: {
+                    ggml_vec_dot_q4_0_q8_0(D, &s, 0, k_data, 0, Q16, 0, 1);
+                } break;
+                default: {
+                    GGML_ASSERT(false && "Unsupported k type.");
+                } break;
             }
 
-            ggml_vec_dot_f16(D,
-                    &s, 0,
-                    (ggml_fp16_t *) ((char *) k->data + ( ic*nbk1 + ik2*nbk2 + ik3*nbk3)), 0,
-                    Q16, 0, 1);
-
-            s = s*scale + mv;
+            s = s*scale + mv; // scale KQ value and apply mask
 
             const float Mold = M;
 
-            float ms = 1.0f;
-            float vs = 1.0f;
+            float ms = 1.0f; // upon new higher max val, scale VKQ and KQ sum with this value
+            float vs = 1.0f; // post-softmax KQ value, expf(s - M)
+
+            const char * v_data = ((const char *) v->data + (ic*nbv1 + iv2*nbv2 + iv3*nbv3));
 
-            if (s > M) {
-                M = s;
-                ms = expf(Mold - M);
+            if (v->type == GGML_TYPE_F16) {
+                if (s > M) {
+                    // s is new maximum, ms < 1.0f, vs == expf(s - s) == 1.0f
+                    M = s;
+                    ms = expf(Mold - M);
 
-                // V = V*expf(Mold - M)
-                ggml_vec_scale_f16(D, V16, ms);
+                    // V = V*expf(Mold - M)
+                    ggml_vec_scale_f16(D, VKQ16, ms);
+                } else {
+                    // no new maximum, ms == 1.0f, vs != 1.0f
+                    vs = expf(s - M);
+                }
+
+                // V += v*expf(s - M)
+                ggml_vec_mad_f16(D, VKQ16, (const ggml_fp16_t *) v_data, vs);
             } else {
-                vs = expf(s - M);
-            }
+                if (s > M) {
+                    // s is new maximum, ms < 1.0f, vs == expf(s - s) == 1.0f
+                    M = s;
+                    ms = expf(Mold - M);
 
-            const ggml_fp16_t * v16 = (const ggml_fp16_t *) ((char *) v->data + (ic*nbv1 + iv2*nbv2 + iv3*nbv3));
+                    // V = V*expf(Mold - M)
+                    ggml_vec_scale_f32(D, VKQ32, ms);
+                } else {
+                    // no new maximum, ms == 1.0f, vs != 1.0f
+                    vs = expf(s - M);
+                }
 
-            // V += v*expf(s - M)
-            ggml_vec_mad_f16(D, V16, v16, vs);
+                v_to_float(v_data, V32, D);
 
-            S = S*ms + vs;
+                // V += v*expf(s - M)
+                ggml_vec_mad_f32(D, VKQ32, V32, vs);
+            }
+
+            S = S*ms + vs; // scale and increment sum with partial sum
         }
 
-        // V /= S
-        for (int64_t d = 0; d < D; ++d) {
-            V32[d] = GGML_FP16_TO_FP32(V16[d])/S;
+        if (v->type == GGML_TYPE_F16) {
+            for (int64_t d = 0; d < D; ++d) {
+                VKQ32[d] = GGML_FP16_TO_FP32(VKQ16[d]);
+            }
         }
 
+        // V /= S
+        const float S_inv = 1.0f/S;
+        ggml_vec_scale_f32(D, VKQ32, S_inv);
+
         // dst indices
         const int i1 = iq1;
         const int i2 = iq2;
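For reference, here is a minimal scalar sketch of the streaming ("online") softmax recurrence this loop implements for a single query row (see the paper linked in the comments above), with masking, ALiBi slopes and the FP16/quantized storage paths stripped out. s holds the scaled K.Q values, V holds n_kv rows of width D, and out receives the softmax-weighted sum of those rows; it is an illustration of the math, not the ggml kernel.

#include <math.h>

static void online_softmax_vkq_sketch(const float * s, const float * V, int n_kv, int D, float * out) {
    float S = 0.0f;        // running sum of exp(s[i] - M)
    float M = -INFINITY;   // running maximum of s[i]

    for (int d = 0; d < D; ++d) {
        out[d] = 0.0f;
    }

    for (int i = 0; i < n_kv; ++i) {
        const float Mold = M;
        float ms = 1.0f;   // rescale factor for the accumulator when M increases
        float vs = 1.0f;   // exp(s[i] - M) for the current row

        if (s[i] > M) {
            M  = s[i];
            ms = expf(Mold - M);
            for (int d = 0; d < D; ++d) {
                out[d] *= ms;            // V = V*exp(Mold - M)
            }
        } else {
            vs = expf(s[i] - M);
        }

        for (int d = 0; d < D; ++d) {
            out[d] += V[i*D + d] * vs;   // V += v*exp(s - M)
        }

        S = S*ms + vs;                   // running normalizer
    }

    for (int d = 0; d < D; ++d) {
        out[d] /= S;                     // final V /= S
    }
}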
@@ -16031,7 +16120,7 @@ static void ggml_compute_forward_flash_attn_ext_f16(
         //memcpy((char *) dst->data + (i1*nb1 + i2*nb2 + i3*nb3), V, nev0*sizeof(float));
 
         // permute(0, 2, 1, 3)
-        memcpy((char *) dst->data + (i3*ne2*ne1 + i2 + i1*ne1)*nb1, V32, nb1);
+        memcpy((char *) dst->data + (i3*ne2*ne1 + i2 + i1*ne1)*nb1, VKQ32, nb1);
     }
 }
 
@@ -19972,7 +20061,7 @@ struct ggml_cplan ggml_graph_plan(const struct ggml_cgraph * cgraph, int n_threa
                 {
                     const int64_t ne00 = node->src[0]->ne[0]; // D
 
-                    cur = 2*sizeof(float)*ne00*n_tasks; // 2x head size
+                    cur = 3*sizeof(float)*ne00*n_tasks; // 3x head size/thread
                 } break;
             case GGML_OP_FLASH_FF:
                 {
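A quick sanity check of the new sizing, assuming a typical head size: with ne00 = D = 128 and n_tasks = 8 threads,

    cur = 3*sizeof(float)*ne00*n_tasks = 3 * 4 * 128 * 8 = 12288 bytes (12 KiB)

versus 8192 bytes with the previous 2*D sizing. The extra D floats per thread are needed because the converted Q row (Q16) must stay live across the whole K/V loop and can no longer share storage with the V scratch and accumulator regions.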
