@@ -1311,6 +1311,7 @@ ggml_tensor * llm_graph_context::build_attn(
         ggml_tensor * k_cur,
         ggml_tensor * v_cur,
         ggml_tensor * kq_b,
+        ggml_tensor * v_mla,
             float   kq_scale,
               int   il) const {
     GGML_UNUSED(n_tokens);
@@ -1332,7 +1333,7 @@ ggml_tensor * llm_graph_context::build_attn(
     ggml_tensor * v = ggml_permute(ctx0, v_cur, 0, 2, 1, 3);
     //cb(k, "v", il);

-    ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, nullptr, false, kq_scale);
+    ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, false, kq_scale);

     cb(cur, "kqv_out", il);

@@ -1386,6 +1387,7 @@ ggml_tensor * llm_graph_context::build_attn(
         ggml_tensor * k_cur,
         ggml_tensor * v_cur,
         ggml_tensor * kq_b,
+        ggml_tensor * v_mla,
             float   kq_scale,
               int   il) const {
     // these nodes are added to the graph together so that they are not reordered
@@ -1471,7 +1473,7 @@ ggml_tensor * llm_graph_context::build_attn(
                 ggml_element_size(kv_self->v_l[il])*n_ctx*n_embd_head_v,
                 0);

-    ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, nullptr, v_trans, kq_scale);
+    ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, v_trans, kq_scale);
     cb(cur, "kqv_out", il);

     if (wo) {
@@ -1511,6 +1513,7 @@ ggml_tensor * llm_graph_context::build_attn(
         ggml_tensor * k_cur,
         ggml_tensor * v_cur,
         ggml_tensor * kq_b,
+        ggml_tensor * v_mla,
             float   kq_scale,
               int   il) const {
     // these nodes are added to the graph together so that they are not reordered
@@ -1530,7 +1533,7 @@ ggml_tensor * llm_graph_context::build_attn(
     ggml_tensor * v = ggml_permute(ctx0, v_cur, 0, 2, 1, 3);
     //cb(k, "v", il);

-    ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, nullptr, false, kq_scale);
+    ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, false, kq_scale);

     cb(cur, "kqv_out", il);

@@ -1549,123 +1552,6 @@ ggml_tensor * llm_graph_context::build_attn(
     return cur;
 }

-// ****************************************************************************************************************
-// *** THIS WILL BE REMOVED AFTER CODE REVIEW IS ACCPETED AND READY TO MERGE - IT'S JUST A COPY OF build_attn() ***
-// ****************************************************************************************************************
-ggml_tensor * llm_graph_context::build_attn_mla(
-        llm_graph_input_attn_kv_unified * inp,
-        ggml_cgraph * gf,
-        ggml_tensor * wo,
-        ggml_tensor * wo_b,
-        ggml_tensor * q_cur,
-        ggml_tensor * k_cur,
-        ggml_tensor * v_cur,
-        ggml_tensor * kq_b,
-        ggml_tensor * v_mla,
-            float   kq_scale,
-              int   il) const {
-    // these nodes are added to the graph together so that they are not reordered
-    // by doing so, the number of splits in the graph is reduced
-    ggml_build_forward_expand(gf, q_cur);
-    ggml_build_forward_expand(gf, k_cur);
-    ggml_build_forward_expand(gf, v_cur);
-
-    const llama_kv_cache_unified * kv_self = static_cast<const llama_kv_cache_unified *>(memory);
-    const auto & n_ctx = cparams.n_ctx;
-
-    const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
-    const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
-
-    const auto n_tokens = q_cur->ne[2];
-
-    const bool v_trans = !cparams.flash_attn;
-
-    // store to KV cache
-    {
-        GGML_ASSERT(!kv_self->recurrent);
-
-        const auto kv_head = kv_self->head;
-
-        GGML_ASSERT(kv_self->size == n_ctx);
-
-        ggml_tensor * k_cache_view = ggml_view_1d(ctx0, kv_self->k_l[il], n_tokens*n_embd_k_gqa, ggml_row_size(kv_self->k_l[il]->type, n_embd_k_gqa)*kv_head);
-        //cb(k_cache_view, "k_cache_view", il);
-
-        // note: storing RoPE-ed version of K in the KV cache
-        ggml_build_forward_expand(gf, ggml_cpy(ctx0, k_cur, k_cache_view));
-
-        v_cur = ggml_reshape_2d(ctx0, v_cur, n_embd_v_gqa, n_tokens);
-
-        ggml_tensor * v_cache_view = nullptr;
-
-        if (!v_trans) {
-            v_cache_view = ggml_view_1d(ctx0, kv_self->v_l[il], n_tokens*n_embd_v_gqa, ggml_row_size(kv_self->v_l[il]->type, n_embd_v_gqa)*kv_head);
-        } else {
-            // note: the V cache is transposed when not using flash attention
-            v_cache_view = ggml_view_2d(ctx0, kv_self->v_l[il], n_tokens, n_embd_v_gqa,
-                    (  n_ctx)*ggml_element_size(kv_self->v_l[il]),
-                    (kv_head)*ggml_element_size(kv_self->v_l[il]));
-
-            v_cur = ggml_transpose(ctx0, v_cur);
-        }
-        //cb(v_cache_view, "v_cache_view", il);
-
-        ggml_build_forward_expand(gf, ggml_cpy(ctx0, v_cur, v_cache_view));
-    }
-
-    const bool is_swa = hparams.is_swa(il);
-
-    const auto & kq_mask = is_swa ? inp->get_kq_mask_swa() : inp->get_kq_mask();
-
-    const auto n_kv = kv_self->n;
-
-    const int64_t n_head_kv = hparams.n_head_kv(il);
-
-    const auto & n_embd_head_k = hparams.n_embd_head_k;
-    const auto & n_embd_head_v = hparams.n_embd_head_v;
-
-    ggml_tensor * q = ggml_permute(ctx0, q_cur, 0, 2, 1, 3);
-    //cb(q, "q", il);
-
-    ggml_tensor * k =
-        ggml_view_3d(ctx0, kv_self->k_l[il],
-                n_embd_head_k, n_kv, n_head_kv,
-                ggml_row_size(kv_self->k_l[il]->type, n_embd_k_gqa),
-                ggml_row_size(kv_self->k_l[il]->type, n_embd_head_k),
-                0);
-    //cb(k, "k", il);
-
-    ggml_tensor * v = !v_trans ?
-        ggml_view_3d(ctx0, kv_self->v_l[il],
-                n_embd_head_v, n_kv, n_head_kv,
-                ggml_row_size(kv_self->v_l[il]->type, n_embd_v_gqa),
-                ggml_row_size(kv_self->v_l[il]->type, n_embd_head_v),
-                0) :
-        ggml_view_3d(ctx0, kv_self->v_l[il],
-                n_kv, n_embd_head_v, n_head_kv,
-                ggml_element_size(kv_self->v_l[il])*n_ctx,
-                ggml_element_size(kv_self->v_l[il])*n_ctx*n_embd_head_v,
-                0);
-
-    ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, v_trans, kq_scale);
-    cb(cur, "kqv_out", il);
-
-    if (wo) {
-        cur = build_lora_mm(wo, cur);
-    }
-
-    if (wo_b) {
-        //cb(cur, "kqv_wo", il);
-    }
-
-    if (wo_b) {
-        cur = ggml_add(ctx0, cur, wo_b);
-    }
-
-    return cur;
-
-}
-
 ggml_tensor * llm_graph_context::build_copy_mask_state(
         ggml_cgraph * gf,
         ggml_tensor * s,
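
For reference, a call site for the updated build_attn() overload would thread the MLA value up-projection through the new v_mla slot, while non-MLA layers keep their previous behaviour by passing nullptr. The sketch below is illustrative only: the layer tensor names (wo, bo, wv_b) and the intermediate tensors (inp_attn, Qcur, Kcur, Vcur) are placeholders, not taken from this diff.

    // Hypothetical call inside a model's graph-build function (assumed context:
    // llm_graph_context member scope with gf, model, kq_scale and il in scope).

    // MLA path: pass the per-head latent-to-value decompression matrix as v_mla
    // so build_attn_mha() can expand the compressed values after the attention.
    ggml_tensor * cur = build_attn(inp_attn, gf,
            model.layers[il].wo, model.layers[il].bo,
            Qcur, Kcur, Vcur,
            /*kq_b =*/ nullptr,
            /*v_mla=*/ model.layers[il].wv_b,
            kq_scale, il);

    // Regular attention path: same call shape, v_mla is simply nullptr.
    // cur = build_attn(inp_attn, gf,
    //         model.layers[il].wo, model.layers[il].bo,
    //         Qcur, Kcur, Vcur,
    //         /*kq_b =*/ nullptr,
    //         /*v_mla=*/ nullptr,
    //         kq_scale, il);

Folding the extra parameter into the existing overloads is what makes the temporary build_attn_mla() copy above redundant, which is why the whole function is removed in the last hunk.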