@@ -1200,18 +1200,12 @@ ggml_tensor * llm_graph_context::build_attn_mha(
     // const auto & n_embd_head_k = hparams.n_embd_head_k;
     // const auto & n_embd_head_v = hparams.n_embd_head_v;

-    const auto n_embd    = q->ne[0];
-    const auto n_tokens  = q->ne[1];
-    const auto n_head    = q->ne[2];
-
-    const auto n_kv      = k->ne[1];
-    const auto n_head_kv = k->ne[2];
-
     // note: for MLA with the absorption optimization, the final embedding size will be changed via v_mla
     const auto n_embd_head_v = v_mla == nullptr ? v_trans ? v->ne[1] : v->ne[0] : v_mla->ne[1];

-    GGML_ASSERT(k->ne[0] == q->ne[0] && "K and Q embedding size mismatch");
-    GGML_ASSERT(k->ne[2] == v->ne[2] && "K and V number of heads mismatch");
+    const auto n_tokens = q->ne[1];
+    const auto n_head   = q->ne[2];
+    const auto n_kv     = k->ne[1];

     ggml_tensor * cur;

@@ -1239,22 +1233,12 @@ ggml_tensor * llm_graph_context::build_attn_mha(

         cur = ggml_reshape_2d(ctx0, cur, n_embd_head_v*n_head, n_tokens);
     } else {
-
-        // for MQA (ie: GQA with 1 group) we don't need to use a batched matrix multiply
-        if (n_head_kv == 1) {
-            q = ggml_reshape_2d(ctx0, q, n_embd, n_tokens*n_head);
-        }
-
         ggml_tensor * kq = ggml_mul_mat(ctx0, k, q);

         // note: this op tends to require high floating point range
         //       while for some models F16 is enough, for others it is not, so we default to F32 here
         ggml_mul_mat_set_prec(kq, GGML_PREC_F32);

-        if (n_head_kv == 1) {
-            kq = ggml_reshape_3d(ctx0, kq, n_kv, n_tokens, n_head);
-        }
-
         if (arch == LLM_ARCH_GROK) {
             // need to do the following:
             // multiply by attn_output_multiplyer of 0.08838834764831845
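
Note on the removed branch (not part of the patch): the deleted code special-cased MQA by flattening q to 2D so a single non-batched matmul covered all query heads, then reshaping kq back to 3D. ggml_mul_mat already broadcasts its first argument across the head dimension whenever n_head is a multiple of n_head_kv, so the batched path kept above handles MQA (n_head_kv == 1) the same way it handles GQA, which is presumably why the special case could be dropped. Below is a minimal sketch illustrating that broadcast; the tensor sizes are made up and only stock ggml calls are used.

    // sketch: k with a single KV head is broadcast across all query heads by ggml_mul_mat
    #include "ggml.h"
    #include <cstdio>

    int main() {
        ggml_init_params params = { /*mem_size*/ 16*1024*1024, /*mem_buffer*/ nullptr, /*no_alloc*/ false };
        ggml_context * ctx = ggml_init(params);

        // hypothetical sizes for illustration only
        const int64_t n_embd_head = 128, n_tokens = 4, n_head = 8, n_kv = 16, n_head_kv = 1; // MQA

        // q: [n_embd_head, n_tokens, n_head], k: [n_embd_head, n_kv, n_head_kv]
        ggml_tensor * q = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, n_embd_head, n_tokens, n_head);
        ggml_tensor * k = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, n_embd_head, n_kv,     n_head_kv);

        // k is broadcast over dim 2, so kq comes out as [n_kv, n_tokens, n_head]
        // without any per-head reshuffling of q
        ggml_tensor * kq = ggml_mul_mat(ctx, k, q);
        printf("kq: %lld x %lld x %lld\n",
               (long long) kq->ne[0], (long long) kq->ne[1], (long long) kq->ne[2]);

        ggml_free(ctx);
        return 0;
    }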