
Commit 36ce235

Removed build_attn_mla and added nullptr to all build_attn calls
1 parent 925af99 commit 36ce235

3 files changed: +77 -204 lines changed

src/llama-graph.cpp

Lines changed: 6 additions & 120 deletions
@@ -1311,6 +1311,7 @@ ggml_tensor * llm_graph_context::build_attn(
         ggml_tensor * k_cur,
         ggml_tensor * v_cur,
         ggml_tensor * kq_b,
+        ggml_tensor * v_mla,
         float kq_scale,
         int il) const {
     GGML_UNUSED(n_tokens);
@@ -1332,7 +1333,7 @@
     ggml_tensor * v = ggml_permute(ctx0, v_cur, 0, 2, 1, 3);
     //cb(k, "v", il);

-    ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, nullptr, false, kq_scale);
+    ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, false, kq_scale);

     cb(cur, "kqv_out", il);

@@ -1386,6 +1387,7 @@
         ggml_tensor * k_cur,
         ggml_tensor * v_cur,
         ggml_tensor * kq_b,
+        ggml_tensor * v_mla,
         float kq_scale,
         int il) const {
     // these nodes are added to the graph together so that they are not reordered
@@ -1471,7 +1473,7 @@
                 ggml_element_size(kv_self->v_l[il])*n_ctx*n_embd_head_v,
                 0);

-    ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, nullptr, v_trans, kq_scale);
+    ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, v_trans, kq_scale);
     cb(cur, "kqv_out", il);

     if (wo) {
@@ -1511,6 +1513,7 @@
         ggml_tensor * k_cur,
         ggml_tensor * v_cur,
         ggml_tensor * kq_b,
+        ggml_tensor * v_mla,
         float kq_scale,
         int il) const {
     // these nodes are added to the graph together so that they are not reordered
@@ -1530,7 +1533,7 @@
     ggml_tensor * v = ggml_permute(ctx0, v_cur, 0, 2, 1, 3);
     //cb(k, "v", il);

-    ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, nullptr, false, kq_scale);
+    ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, false, kq_scale);

     cb(cur, "kqv_out", il);

@@ -1549,123 +1552,6 @@
     return cur;
 }

-// ****************************************************************************************************************
-// *** THIS WILL BE REMOVED AFTER CODE REVIEW IS ACCPETED AND READY TO MERGE - IT'S JUST A COPY OF build_attn() ***
-// ****************************************************************************************************************
-ggml_tensor * llm_graph_context::build_attn_mla(
-        llm_graph_input_attn_kv_unified * inp,
-        ggml_cgraph * gf,
-        ggml_tensor * wo,
-        ggml_tensor * wo_b,
-        ggml_tensor * q_cur,
-        ggml_tensor * k_cur,
-        ggml_tensor * v_cur,
-        ggml_tensor * kq_b,
-        ggml_tensor * v_mla,
-        float kq_scale,
-        int il) const {
-    // these nodes are added to the graph together so that they are not reordered
-    // by doing so, the number of splits in the graph is reduced
-    ggml_build_forward_expand(gf, q_cur);
-    ggml_build_forward_expand(gf, k_cur);
-    ggml_build_forward_expand(gf, v_cur);
-
-    const llama_kv_cache_unified * kv_self = static_cast<const llama_kv_cache_unified *>(memory);
-    const auto & n_ctx = cparams.n_ctx;
-
-    const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
-    const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
-
-    const auto n_tokens = q_cur->ne[2];
-
-    const bool v_trans = !cparams.flash_attn;
-
-    // store to KV cache
-    {
-        GGML_ASSERT(!kv_self->recurrent);
-
-        const auto kv_head = kv_self->head;
-
-        GGML_ASSERT(kv_self->size == n_ctx);
-
-        ggml_tensor * k_cache_view = ggml_view_1d(ctx0, kv_self->k_l[il], n_tokens*n_embd_k_gqa, ggml_row_size(kv_self->k_l[il]->type, n_embd_k_gqa)*kv_head);
-        //cb(k_cache_view, "k_cache_view", il);
-
-        // note: storing RoPE-ed version of K in the KV cache
-        ggml_build_forward_expand(gf, ggml_cpy(ctx0, k_cur, k_cache_view));
-
-        v_cur = ggml_reshape_2d(ctx0, v_cur, n_embd_v_gqa, n_tokens);
-
-        ggml_tensor * v_cache_view = nullptr;
-
-        if (!v_trans) {
-            v_cache_view = ggml_view_1d(ctx0, kv_self->v_l[il], n_tokens*n_embd_v_gqa, ggml_row_size(kv_self->v_l[il]->type, n_embd_v_gqa)*kv_head);
-        } else {
-            // note: the V cache is transposed when not using flash attention
-            v_cache_view = ggml_view_2d(ctx0, kv_self->v_l[il], n_tokens, n_embd_v_gqa,
-                    (  n_ctx)*ggml_element_size(kv_self->v_l[il]),
-                    (kv_head)*ggml_element_size(kv_self->v_l[il]));
-
-            v_cur = ggml_transpose(ctx0, v_cur);
-        }
-        //cb(v_cache_view, "v_cache_view", il);
-
-        ggml_build_forward_expand(gf, ggml_cpy(ctx0, v_cur, v_cache_view));
-    }
-
-    const bool is_swa = hparams.is_swa(il);
-
-    const auto & kq_mask = is_swa ? inp->get_kq_mask_swa() : inp->get_kq_mask();
-
-    const auto n_kv = kv_self->n;
-
-    const int64_t n_head_kv = hparams.n_head_kv(il);
-
-    const auto & n_embd_head_k = hparams.n_embd_head_k;
-    const auto & n_embd_head_v = hparams.n_embd_head_v;
-
-    ggml_tensor * q = ggml_permute(ctx0, q_cur, 0, 2, 1, 3);
-    //cb(q, "q", il);
-
-    ggml_tensor * k =
-        ggml_view_3d(ctx0, kv_self->k_l[il],
-                n_embd_head_k, n_kv, n_head_kv,
-                ggml_row_size(kv_self->k_l[il]->type, n_embd_k_gqa),
-                ggml_row_size(kv_self->k_l[il]->type, n_embd_head_k),
-                0);
-    //cb(k, "k", il);
-
-    ggml_tensor * v = !v_trans ?
-        ggml_view_3d(ctx0, kv_self->v_l[il],
-                n_embd_head_v, n_kv, n_head_kv,
-                ggml_row_size(kv_self->v_l[il]->type, n_embd_v_gqa),
-                ggml_row_size(kv_self->v_l[il]->type, n_embd_head_v),
-                0) :
-        ggml_view_3d(ctx0, kv_self->v_l[il],
-                n_kv, n_embd_head_v, n_head_kv,
-                ggml_element_size(kv_self->v_l[il])*n_ctx,
-                ggml_element_size(kv_self->v_l[il])*n_ctx*n_embd_head_v,
-                0);
-
-    ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, v_trans, kq_scale);
-    cb(cur, "kqv_out", il);
-
-    if (wo) {
-        cur = build_lora_mm(wo, cur);
-    }
-
-    if (wo_b) {
-        //cb(cur, "kqv_wo", il);
-    }
-
-    if (wo_b) {
-        cur = ggml_add(ctx0, cur, wo_b);
-    }
-
-    return cur;
-
-}
-
 ggml_tensor * llm_graph_context::build_copy_mask_state(
         ggml_cgraph * gf,
         ggml_tensor * s,
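
Note: with this change every build_attn overload simply forwards its new v_mla argument to build_attn_mha instead of the previous hard-coded nullptr. The updated call sites in src/llama-model.cpp are not shown in this excerpt, so the snippet below is only a sketch of what a non-MLA call is expected to look like after this commit, per the commit message; the surrounding identifiers (inp_attn, Qcur/Kcur/Vcur, model.layers[il].wo/bo, n_embd_head) are assumed from typical llama.cpp layer-building code, not taken from this diff:

    // hypothetical call site in a model's layer-build function after this commit:
    // kq_b and the new v_mla argument are both nullptr for models without MLA,
    // so the attention computation is unchanged for them
    ggml_tensor * cur = build_attn(inp_attn, gf,
            model.layers[il].wo, model.layers[il].bo,
            Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);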

src/llama-graph.h

Lines changed: 2 additions & 15 deletions
@@ -525,6 +525,7 @@ struct llm_graph_context {
             ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens]
             ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens]
             ggml_tensor * kq_b,
+            ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
             float kq_scale,
             int il) const;

@@ -539,6 +540,7 @@
             ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens]
             ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens]
             ggml_tensor * kq_b,
+            ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
             float kq_scale,
             int il) const;

@@ -552,21 +554,6 @@
             ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens]
             ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens]
             ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens]
-            ggml_tensor * kq_b,
-            float kq_scale,
-            int il) const;
-
-    // ****************************************************************************************************************
-    // *** THIS WILL BE REMOVED AFTER CODE REVIEW IS ACCPETED AND READY TO MERGE - IT'S JUST A COPY OF build_attn() ***
-    // ****************************************************************************************************************
-    ggml_tensor * build_attn_mla(
-            llm_graph_input_attn_kv_unified * inp,
-            ggml_cgraph * gf,
-            ggml_tensor * wo,
-            ggml_tensor * wo_b,
-            ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens]
-            ggml_tensor * k_cur, // [n_embd_head_k, 1, n_tokens]
-            ggml_tensor * v_cur, // [n_embd_head_v, 1, n_tokens]
             ggml_tensor * kq_b,
             ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
             float kq_scale,
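
Note: reassembled from the hunks above, each build_attn overload in llama-graph.h now takes v_mla between kq_b and kq_scale. For reference, a sketch of how the unified-KV-cache overload reads after the patch; the inp type and the wo/wo_b/q_cur prefix are assumed from the deleted build_attn_mla declaration, which the removed comment describes as a copy of build_attn:

    // declaration sketch after this commit (indentation and parameter prefix approximated)
    ggml_tensor * build_attn(
            llm_graph_input_attn_kv_unified * inp,
            ggml_cgraph * gf,
            ggml_tensor * wo,
            ggml_tensor * wo_b,
            ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens]
            ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens]
            ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens]
            ggml_tensor * kq_b,
            ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
            float kq_scale,
            int il) const;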

0 commit comments
