@@ -15883,8 +15883,6 @@ static void ggml_compute_forward_flash_attn_ext_f16(
GGML_ASSERT(ne2 == N);

GGML_ASSERT(nbq0 == sizeof(float));
- GGML_ASSERT(nbk0 == sizeof(ggml_fp16_t));
- GGML_ASSERT(nbv0 == sizeof(ggml_fp16_t));

GGML_ASSERT(neq0 == D);
GGML_ASSERT(nek0 == D);
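The two dropped asserts are what previously pinned K and V to plain FP16 rows. With quantized K/V the leading byte stride is the quantized block size rather than sizeof(ggml_fp16_t), so those checks no longer hold. A minimal sketch of the row sizes involved, assuming ggml's usual convention that a row of D elements of a 32-wide block type occupies (D/32)*sizeof(block) bytes, and assuming a 34-byte q8_0 block (2-byte scale + 32 int8):

    #include <stdio.h>

    int main(void) {
        const size_t D = 128;                  // head size, example value

        const size_t f16_row  = D * 2;         // F16 K/V row: nb[0] == sizeof(ggml_fp16_t)
        const size_t q8_0_row = (D / 32) * 34; // Q8_0 K/V row: 34-byte blocks of 32 values (assumed layout)

        printf("F16 row: %zu bytes, Q8_0 row: %zu bytes\n", f16_row, q8_0_row);
        return 0;
    }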
@@ -15945,17 +15943,47 @@ static void ggml_compute_forward_flash_attn_ext_f16(
const int iq2 = (ir - iq3*neq2*neq1)/neq1;
const int iq1 = (ir - iq3*neq2*neq1 - iq2*neq1);

- const uint32_t h = iq2; // head
+ const uint32_t h = iq2; // head index

const float slope = (max_bias > 0.0f) ? h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2*(h - n_head_log2) + 1) : 1.0f;

- float S = 0.0f;
- float M = -INFINITY;
+ float S = 0.0f; // sum
+ float M = -INFINITY; // maximum KQ value
+
+ float * VKQ32 = (float *) params->wdata + ith*(3*D + CACHE_LINE_SIZE_F32); // FP32 VKQ accumulator
+ float * V32 = (VKQ32 + 1*D); // (temporary) FP32 V buffer
+ ggml_fp16_t * VKQ16 = (ggml_fp16_t *) (VKQ32 + 1*D); // (temporary) FP16 VKQ accumulator
+ ggml_fp16_t * Q16 = (ggml_fp16_t *) (VKQ32 + 2*D); // (temporary) Q buffer

- float * V32 = (float *) params->wdata + ith*(2*D + CACHE_LINE_SIZE_F32);
- ggml_fp16_t * Q16 = (ggml_fp16_t *) (V32); // reuse memory
- ggml_fp16_t * V16 = (ggml_fp16_t *) (V32 + D);
+ ggml_to_float_t v_to_float = NULL;

- memset(V16, 0, D*sizeof(ggml_fp16_t));
+ switch (v->type) {
+     case GGML_TYPE_F16: {
+         memset(VKQ16, 0, D*sizeof(ggml_fp16_t));
+     } break;
+     case GGML_TYPE_Q8_0: {
+         v_to_float = (ggml_to_float_t) dequantize_row_q8_0;
+         memset(VKQ32, 0, D*sizeof(float));
+     } break;
+     case GGML_TYPE_Q5_1: {
+         v_to_float = (ggml_to_float_t) dequantize_row_q5_1;
+         memset(VKQ32, 0, D*sizeof(float));
+     } break;
+     case GGML_TYPE_Q5_0: {
+         v_to_float = (ggml_to_float_t) dequantize_row_q5_0;
+         memset(VKQ32, 0, D*sizeof(float));
+     } break;
+     case GGML_TYPE_Q4_1: {
+         v_to_float = (ggml_to_float_t) dequantize_row_q4_1;
+         memset(VKQ32, 0, D*sizeof(float));
+     } break;
+     case GGML_TYPE_Q4_0: {
+         v_to_float = (ggml_to_float_t) dequantize_row_q4_0;
+         memset(VKQ32, 0, D*sizeof(float));
+     } break;
+     default: {
+         GGML_ASSERT(false);
+     } break;
+ }

const ggml_fp16_t * mp = mask ? (ggml_fp16_t *)((char *) mask->data + iq1*mask->nb[1]) : NULL;
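The new pointers carve one thread's slice of params->wdata into three head-sized regions: the FP32 accumulator VKQ32, a scratch row that holds either the dequantized V row (V32) or, on the F16-only path, the FP16 accumulator VKQ16 aliasing the same memory, and the converted/quantized query Q16. A rough standalone sketch of that layout, with uint16_t standing in for ggml_fp16_t and CACHE_LINE_SIZE_F32 assumed to be 16 as in ggml.c:

    #include <stdint.h>
    #include <stdio.h>

    #define CACHE_LINE_SIZE_F32 16  // assumed: 64-byte cache line / sizeof(float)

    int main(void) {
        enum { D = 128, ith = 2 };                          // head size and thread index, example values
        static float wdata[4*(3*D + CACHE_LINE_SIZE_F32)];  // stand-in for params->wdata (4 threads)

        float    * VKQ32 = wdata + ith*(3*D + CACHE_LINE_SIZE_F32); // [0,  D): FP32 V*KQ accumulator
        float    * V32   = VKQ32 + 1*D;                             // [D, 2D): dequantized V row (quantized-V path)
        uint16_t * VKQ16 = (uint16_t *) (VKQ32 + 1*D);              // aliases V32: FP16 accumulator (F16-V path)
        uint16_t * Q16   = (uint16_t *) (VKQ32 + 2*D);              // [2D, 3D): converted/quantized Q row

        printf("VKQ32 offset: %td floats\n", VKQ32 - wdata);
        printf("V32 == VKQ16 offset: %td floats\n", V32 - wdata);
        printf("Q16 offset: %td floats\n", (float *) Q16 - wdata);
        return 0;
    }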
@@ -15967,6 +15995,30 @@ static void ggml_compute_forward_flash_attn_ext_f16(
const int iv3 = iq3 / rv3;
const int iv2 = iq2 / rv2;

+ const float * pq = (const float *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3));
+ switch (k->type) {
+     case GGML_TYPE_F16: {
+         // convert Q to F16 in V32
+         for (int64_t d = 0; d < D; ++d) {
+             Q16[d] = GGML_FP32_TO_FP16(pq[d]);
+         }
+     } break;
+     case GGML_TYPE_Q8_0:
+     case GGML_TYPE_Q5_0:
+     case GGML_TYPE_Q4_0: {
+         // convert Q to q8_0 in V32
+         quantize_row_q8_0(pq, Q16, D);
+     } break;
+     case GGML_TYPE_Q5_1:
+     case GGML_TYPE_Q4_1: {
+         // convert Q to q8_0 in V32
+         quantize_row_q8_1(pq, Q16, D);
+     } break;
+     default: {
+         GGML_ASSERT(false && "Unsupported k type.");
+     } break;
+ }
+
// online softmax / attention
// loop over n_kv and n_head_kv
// ref: https://arxiv.org/pdf/2112.05682.pdf
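Q is now converted once per output row into whatever the dot-product kernel for K's type consumes: F16 when K is F16, q8_0 for the _0 quantized types (Q4_0, Q5_0, Q8_0) and q8_1 for the _1 types (Q4_1, Q5_1), following ggml's vec_dot pairing. The Q16 buffer was sized as D halves (2*D bytes), which appears to be enough room for the re-quantized row; a back-of-the-envelope check, assuming ggml's usual 32-value blocks of 34 bytes for q8_0 and at most 40 bytes for q8_1:

    #include <assert.h>
    #include <stdio.h>

    int main(void) {
        const size_t D = 128;                     // head size, example value

        const size_t q16_bytes  = D * 2;          // Q16: D ggml_fp16_t
        const size_t q8_0_bytes = (D / 32) * 34;  // assumed block_q8_0: 2-byte scale + 32 int8
        const size_t q8_1_bytes = (D / 32) * 40;  // assumed upper bound for block_q8_1 (scale + sum + 32 int8)

        assert(q8_0_bytes <= q16_bytes && q8_1_bytes <= q16_bytes);
        printf("Q16: %zu B, q8_0 row: %zu B, q8_1 row: %zu B\n", q16_bytes, q8_0_bytes, q8_1_bytes);
        return 0;
    }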
@@ -15976,52 +16028,89 @@ static void ggml_compute_forward_flash_attn_ext_f16(
continue;
}

- float s;
+ float s; // KQ value

- // convert Q to F16 in V32
- {
-     const float * pq = (const float *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3));
-
-     for (int64_t d = 0; d < D; ++d) {
-         Q16[d] = GGML_FP32_TO_FP16(pq[d]);
-     }
+ char * k_data = (char *) k->data + ( ic*nbk1 + ik2*nbk2 + ik3*nbk3);
+ switch (k->type) {
+     case GGML_TYPE_F16: {
+         ggml_vec_dot_f16(D, &s, 0, (ggml_fp16_t *) k_data, 0, Q16, 0, 1);
+     } break;
+     case GGML_TYPE_Q8_0: {
+         ggml_vec_dot_q8_0_q8_0(D, &s, 0, k_data, 0, Q16, 0, 1);
+     } break;
+     case GGML_TYPE_Q5_1: {
+         ggml_vec_dot_q5_1_q8_1(D, &s, 0, k_data, 0, Q16, 0, 1);
+     } break;
+     case GGML_TYPE_Q5_0: {
+         ggml_vec_dot_q5_0_q8_0(D, &s, 0, k_data, 0, Q16, 0, 1);
+     } break;
+     case GGML_TYPE_Q4_1: {
+         ggml_vec_dot_q4_1_q8_1(D, &s, 0, k_data, 0, Q16, 0, 1);
+     } break;
+     case GGML_TYPE_Q4_0: {
+         ggml_vec_dot_q4_0_q8_0(D, &s, 0, k_data, 0, Q16, 0, 1);
+     } break;
+     default: {
+         GGML_ASSERT(false && "Unsupported k type.");
+     } break;
}

- ggml_vec_dot_f16(D,
-         &s, 0,
-         (ggml_fp16_t *) ((char *) k->data + ( ic*nbk1 + ik2*nbk2 + ik3*nbk3)), 0,
-         Q16, 0, 1);
-
- s = s*scale + mv;
+ s = s*scale + mv; // scale KQ value and apply mask

const float Mold = M;

- float ms = 1.0f;
- float vs = 1.0f;
+ float ms = 1.0f; // upon new higher max val, scale VKQ and KQ sum with this value
+ float vs = 1.0f; // post-softmax KQ value, expf(s - M)
+
+ const char * v_data = ((const char *) v->data + (ic*nbv1 + iv2*nbv2 + iv3*nbv3));

- if (s > M) {
-     M = s;
-     ms = expf(Mold - M);
+ if (v->type == GGML_TYPE_F16) {
+     if (s > M) {
+         // s is new maximum, ms < 1.0f, vs == expf(s - s) == 1.0f
+         M = s;
+         ms = expf(Mold - M);

-     // V = V*expf(Mold - M)
-     ggml_vec_scale_f16(D, V16, ms);
+         // V = V*expf(Mold - M)
+         ggml_vec_scale_f16(D, VKQ16, ms);
+     } else {
+         // no new maximum, ms == 1.0f, vs != 1.0f
+         vs = expf(s - M);
+     }
+
+     // V += v*expf(s - M)
+     ggml_vec_mad_f16(D, VKQ16, (const ggml_fp16_t *) v_data, vs);
} else {
-     vs = expf(s - M);
- }
+     if (s > M) {
+         // s is new maximum, ms < 1.0f, vs == expf(s - s) == 1.0f
+         M = s;
+         ms = expf(Mold - M);

- const ggml_fp16_t * v16 = (const ggml_fp16_t *) ((char *) v->data + (ic*nbv1 + iv2*nbv2 + iv3*nbv3));
+         // V = V*expf(Mold - M)
+         ggml_vec_scale_f32(D, VKQ32, ms);
+     } else {
+         // no new maximum, ms == 1.0f, vs != 1.0f
+         vs = expf(s - M);
+     }

- // V += v*expf(s - M)
- ggml_vec_mad_f16(D, V16, v16, vs);
+     v_to_float(v_data, V32, D);

- S = S*ms + vs;
+     // V += v*expf(s - M)
+     ggml_vec_mad_f32(D, VKQ32, V32, vs);
+ }
+
+ S = S*ms + vs; // scale and increment sum with partial sum
}

- // V /= S
- for (int64_t d = 0; d < D; ++d) {
-     V32[d] = GGML_FP16_TO_FP32(V16[d])/S;
+ if (v->type == GGML_TYPE_F16) {
+     for (int64_t d = 0; d < D; ++d) {
+         VKQ32[d] = GGML_FP16_TO_FP32(VKQ16[d]);
+     }
}

+ // V /= S
+ const float S_inv = 1.0f/S;
+ ggml_vec_scale_f32(D, VKQ32, S_inv);
+
// dst indices
const int i1 = iq1;
const int i2 = iq2;
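The inner loop is the one-pass ("online") softmax referenced in the comment above: M tracks the running maximum of the KQ scores, S the running sum of expf(s - M), and whenever a new maximum appears both the accumulator and S are rescaled by expf(Mold - M); the final division by S then yields the softmax-weighted sum of V rows. A self-contained sketch of the same recurrence on plain floats, checked against the naive two-pass computation (nothing below is ggml code; N, D and the sample values are made up):

    #include <math.h>
    #include <stdio.h>

    #define N 4   // number of KV positions
    #define D 3   // head size

    int main(void) {
        const float s[N] = { 0.5f, 2.0f, -1.0f, 1.5f };               // KQ values, already scaled/masked
        const float V[N][D] = { {1,0,0}, {0,1,0}, {0,0,1}, {1,1,1} }; // value rows

        // one pass: mirrors S, M, ms, vs in the kernel
        float acc[D] = {0}, S = 0.0f, M = -INFINITY;
        for (int i = 0; i < N; ++i) {
            float ms = 1.0f, vs = 1.0f;
            if (s[i] > M) {
                const float Mold = M;
                M  = s[i];
                ms = expf(Mold - M);                      // rescale what was accumulated so far
                for (int d = 0; d < D; ++d) acc[d] *= ms;
            } else {
                vs = expf(s[i] - M);                      // weight of the current value row
            }
            for (int d = 0; d < D; ++d) acc[d] += vs*V[i][d];
            S = S*ms + vs;
        }
        for (int d = 0; d < D; ++d) acc[d] /= S;          // V /= S

        // naive two-pass reference
        float Mref = -INFINITY, Z = 0.0f, ref[D] = {0};
        for (int i = 0; i < N; ++i) if (s[i] > Mref) Mref = s[i];
        for (int i = 0; i < N; ++i) Z += expf(s[i] - Mref);
        for (int i = 0; i < N; ++i) {
            for (int d = 0; d < D; ++d) ref[d] += expf(s[i] - Mref)/Z*V[i][d];
        }

        for (int d = 0; d < D; ++d) printf("online %.6f  naive %.6f\n", acc[d], ref[d]);
        return 0;
    }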
@@ -16031,7 +16120,7 @@ static void ggml_compute_forward_flash_attn_ext_f16(
//memcpy((char *) dst->data + (i1*nb1 + i2*nb2 + i3*nb3), V, nev0*sizeof(float));

// permute(0, 2, 1, 3)
- memcpy((char *) dst->data + (i3*ne2*ne1 + i2 + i1*ne1)*nb1, V32 , nb1);
+ memcpy((char *) dst->data + (i3*ne2*ne1 + i2 + i1*ne1)*nb1, VKQ32 , nb1);
}
}
@@ -19972,7 +20061,7 @@ struct ggml_cplan ggml_graph_plan(const struct ggml_cgraph * cgraph, int n_threa
{
const int64_t ne00 = node->src[0]->ne[0]; // D

- cur = 2 *sizeof(float)*ne00*n_tasks; // 2x head size
+ cur = 3 *sizeof(float)*ne00*n_tasks; // 3x head size/thread
} break;
case GGML_OP_FLASH_FF:
{
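The work-size estimate for the flash-attention node grows from two to three head-sized FP32 rows per thread, matching the three regions carved out of params->wdata above (VKQ32, the shared V32/VKQ16 slot, and Q16). A tiny sketch of the scratch this reserves under example values; any extra per-thread cache-line padding is assumed to be accounted for elsewhere in ggml_graph_plan:

    #include <stdio.h>

    int main(void) {
        const size_t ne00    = 128; // D, head size (example)
        const size_t n_tasks = 8;   // threads assigned to the node (example)

        const size_t cur = 3*sizeof(float)*ne00*n_tasks; // mirrors the updated estimate
        printf("flash_attn_ext wsize: %zu bytes (%zu per thread)\n", cur, cur/n_tasks);
        return 0;
    }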