Commit a5df71e

Removed MQA optimisation from build_attn_mha() as no gains now
1 parent 638b092 commit a5df71e

File tree

1 file changed: +3 -19 lines

src/llama-graph.cpp

Lines changed: 3 additions & 19 deletions
```diff
@@ -1200,18 +1200,12 @@ ggml_tensor * llm_graph_context::build_attn_mha(
     //const auto & n_embd_head_k = hparams.n_embd_head_k;
     //const auto & n_embd_head_v = hparams.n_embd_head_v;

-    const auto n_embd     = q->ne[0];
-    const auto n_tokens   = q->ne[1];
-    const auto n_head     = q->ne[2];
-
-    const auto n_kv       = k->ne[1];
-    const auto n_head_kv  = k->ne[2];
-
     // note: for MLA with the absorption optimization, the final embedding size will be changed via v_mla
     const auto n_embd_head_v = v_mla == nullptr ? v_trans ? v->ne[1] : v->ne[0] : v_mla->ne[1];

-    GGML_ASSERT(k->ne[0] == q->ne[0] && "K and Q embedding size mismatch");
-    GGML_ASSERT(k->ne[2] == v->ne[2] && "K and V number of heads mismatch");
+    const auto n_tokens = q->ne[1];
+    const auto n_head   = q->ne[2];
+    const auto n_kv     = k->ne[1];

     ggml_tensor * cur;

```
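The one line this hunk keeps, the `n_embd_head_v` selection, packs two decisions into a nested ternary. Unrolled purely for readability (a sketch, not code from the file; it assumes the same meanings of `v`, `v_mla` and `v_trans` as in `build_attn_mha()`):

```cpp
#include "ggml.h"

// Equivalent, unrolled form of the kept line:
//   n_embd_head_v = v_mla == nullptr ? v_trans ? v->ne[1] : v->ne[0] : v_mla->ne[1];
static int64_t select_n_embd_head_v(const ggml_tensor * v, const ggml_tensor * v_mla, bool v_trans) {
    if (v_mla != nullptr) {
        // MLA with the absorption optimization: the final per-head value size
        // comes from the v_mla projection
        return v_mla->ne[1];
    }
    if (v_trans) {
        // V stored transposed: the head size sits in the second dimension
        return v->ne[1];
    }
    // regular V layout: the head size is the first (fastest-varying) dimension
    return v->ne[0];
}
```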
```diff
@@ -1239,22 +1233,12 @@ ggml_tensor * llm_graph_context::build_attn_mha(

         cur = ggml_reshape_2d(ctx0, cur, n_embd_head_v*n_head, n_tokens);
     } else {
-
-        // for MQA (ie: GQA with 1 group) we don't need to use a batched matrix multiply
-        if (n_head_kv == 1) {
-            q = ggml_reshape_2d(ctx0, q, n_embd, n_tokens*n_head);
-        }
-
         ggml_tensor * kq = ggml_mul_mat(ctx0, k, q);

         // note: this op tends to require high floating point range
         // while for some models F16 is enough, for others it is not, so we default to F32 here
         ggml_mul_mat_set_prec(kq, GGML_PREC_F32);

-        if (n_head_kv == 1) {
-            kq = ggml_reshape_3d(ctx0, kq, n_kv, n_tokens, n_head);
-        }
-
         if (arch == LLM_ARCH_GROK) {
             // need to do the following:
             // multiply by attn_output_multiplyer of 0.08838834764831845
```
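What the deleted branch was doing, in isolation: with MQA (`n_head_kv == 1`) every query head attends to the same single K head, so Q was flattened to 2-D, K·Q was computed as one plain matrix multiply, and the scores were reshaped back to per-head form afterwards. Without the special case, `ggml_mul_mat` broadcasts the single K head across all Q heads, which per the commit title no longer leaves any measurable gain. Below is a self-contained sketch of the two paths with hypothetical sizes; it illustrates the idea only and is not code from this file:

```cpp
#include "ggml.h"
#include <cstdio>

// Print the first three dimensions of a tensor (ne[0] is the fastest-varying one in ggml).
static void print_shape(const char * name, const ggml_tensor * t) {
    printf("%s: [%lld, %lld, %lld]\n", name,
           (long long) t->ne[0], (long long) t->ne[1], (long long) t->ne[2]);
}

int main() {
    // metadata-only context (no_alloc = true): we only build tensor shapes, no data
    ggml_init_params params = { /*mem_size*/ 16*1024*1024, /*mem_buffer*/ nullptr, /*no_alloc*/ true };
    ggml_context * ctx = ggml_init(params);

    // hypothetical MQA sizes: many query heads, a single shared K head
    const int64_t n_embd = 128, n_tokens = 8, n_head = 32, n_kv = 256, n_head_kv = 1;

    ggml_tensor * q = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, n_embd, n_tokens, n_head);
    ggml_tensor * k = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, n_embd, n_kv,     n_head_kv);

    // path kept by this commit: batched matmul, broadcasting K's single head over all Q heads
    ggml_tensor * kq_batched = ggml_mul_mat(ctx, k, q);

    // removed MQA fast path: flatten Q to 2-D, do one plain matmul, reshape the scores back
    ggml_tensor * q_2d   = ggml_reshape_2d(ctx, q, n_embd, n_tokens*n_head);
    ggml_tensor * kq_2d  = ggml_mul_mat(ctx, k, q_2d);
    ggml_tensor * kq_mqa = ggml_reshape_3d(ctx, kq_2d, n_kv, n_tokens, n_head);

    print_shape("kq (batched)",  kq_batched); // [256, 8, 32] = [n_kv, n_tokens, n_head]
    print_shape("kq (MQA path)", kq_mqa);     // same shape

    ggml_free(ctx);
    return 0;
}
```

Both paths yield a `kq` of shape [n_kv, n_tokens, n_head]; the only difference is whether the matmul is expressed as one 2-D multiply or as a broadcasted batched multiply, which is why dropping the branch also lets the previous hunk drop the `n_embd` and `n_head_kv` locals.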
