Commit 2ed9fb4

fusion POC

1 parent 1534edd commit 2ed9fb4
File tree

4 files changed: +128 -34 lines changed

ggml/include/ggml.h

Lines changed: 5 additions & 1 deletion
@@ -614,7 +614,11 @@ extern "C" {
 
         void * extra; // extra things e.g. for ggml-cuda.cu
 
-        char padding[8];
+        // number of operations that use this tensor as a src
+        int32_t use_count;
+
+        // add padding if needed to make a multiple of GGML_MEM_ALIGN
+        char padding[4];
     };
 
     static const size_t GGML_TENSOR_SIZE = sizeof(struct ggml_tensor);
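
The struct edit is size-neutral: a 4-byte int32_t use_count plus char padding[4] occupies the same 8 bytes as the old char padding[8], so GGML_TENSOR_SIZE stays a multiple of GGML_MEM_ALIGN. A minimal sketch of that invariant on common 32/64-bit ABIs (the before/after struct names are illustrative, not from the commit):

    #include <cstdint>

    struct before { void * extra; char padding[8]; };
    struct after  { void * extra; int32_t use_count; char padding[4]; };

    // 8 bytes of padding swapped for a 4-byte counter plus 4 bytes of padding
    static_assert(sizeof(before) == sizeof(after), "tensor struct size unchanged");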

ggml/src/ggml-backend.cpp

Lines changed: 90 additions & 2 deletions
@@ -817,8 +817,8 @@ static void ggml_backend_sched_print_assignments(ggml_backend_sched_t sched, str
         }
         if (sched->debug > 1) {
             ggml_backend_t tensor_backend = ggml_backend_sched_get_tensor_backend(sched, node);
-            GGML_LOG_DEBUG("node #%3d (%10.10s): %20.20s (%5.5s) [%5.5s %8.8s]:", i, ggml_op_name(node->op), node->name,
-                fmt_size(ggml_nbytes(node)), tensor_backend ? ggml_backend_name(tensor_backend) : "NULL", GET_CAUSE(node));
+            GGML_LOG_DEBUG("node #%3d (%10.10s): %20.20s (%5.5s) [%5.5s %8.8s] use=%d:", i, ggml_op_name(node->op), node->name,
+                fmt_size(ggml_nbytes(node)), tensor_backend ? ggml_backend_name(tensor_backend) : "NULL", GET_CAUSE(node), node->use_count);
             for (int j = 0; j < GGML_MAX_SRC; j++) {
                 struct ggml_tensor * src = node->src[j];
                 if (src == NULL) {
@@ -1562,11 +1562,99 @@ bool ggml_backend_sched_reserve(ggml_backend_sched_t sched, struct ggml_cgraph *
     return true;
 }
 
+// check that `view` is a 2D view of `src` with the given shape, row stride and offset
+static bool is_view_2d(ggml_tensor * view,
+                       ggml_tensor * src,
+                       int64_t ne0,
+                       int64_t ne1,
+                       size_t nb1,
+                       size_t offset) {
+    if (view->op != GGML_OP_VIEW || view->view_src != src) {
+        return false;
+    }
+
+    if (view->nb[0] != src->nb[0] ||
+        view->nb[1] != nb1 ||
+        view->nb[2] != view->nb[1] * ne1 ||
+        view->nb[3] != view->nb[2]) {
+        return false;
+    }
+    if (view->ne[0] != ne0 ||
+        view->ne[1] != ne1 ||
+        view->ne[2] != 1 ||
+        view->ne[3] != 1) {
+        return false;
+    }
+    if (view->view_offs != view->view_src->view_offs + offset) {
+        return false;
+    }
+    return true;
+}
+
 bool ggml_backend_sched_alloc_graph(ggml_backend_sched_t sched, struct ggml_cgraph * graph) {
     GGML_ASSERT((int)sched->hash_set.size >= graph->n_nodes + graph->n_leafs);
 
     ggml_backend_sched_split_graph(sched, graph);
 
+    for (int s = 0; s < sched->n_splits; s++) {
+        for (int i = sched->splits[s].graph.n_nodes - 1; i >= 0; i--) {
+            struct ggml_tensor * node = sched->splits[s].graph.nodes[i];
+
+            // peephole to find swiglu:
+            //   x0 = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, split_point, cur->ne[1], cur->nb[1], 0));
+            //   x1 = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, split_point, cur->ne[1], cur->nb[1], split_point * ggml_element_size(cur)));
+            //   x0 = ggml_silu(ctx0, x0);
+            //   ggml_mul(ctx0, x0, x1);
+            do {
+                if (node->op != GGML_OP_MUL) {
+                    break;
+                }
+                ggml_tensor * src0 = node->src[0];
+                ggml_tensor * src1 = node->src[1];
+                if (src0->op != GGML_OP_UNARY || ggml_get_op_params_i32(src0, 0) != GGML_UNARY_OP_SILU) {
+                    break;
+                }
+                src0 = src0->src[0];
+                if (src0->op != GGML_OP_CONT || src1->op != GGML_OP_CONT) {
+                    break;
+                }
+                if (src0->use_count != 1 || src1->use_count != 1) {
+                    break;
+                }
+                src0 = src0->src[0];
+                src1 = src1->src[0];
+
+                if (src0->use_count != 1 || src1->use_count != 1) {
+                    break;
+                }
+
+                ggml_tensor * input = src0->src[0];
+                if (!input || input->use_count != 2) {
+                    break;
+                }
+
+                uint32_t split_point = input->ne[0] / 2;
+
+                // bail out of the peephole (do not fail the whole allocation)
+                // if the views do not match the expected split geometry
+                if (!is_view_2d(src0, input, split_point, input->ne[1], input->nb[1], 0)) {
+                    break;
+                }
+                if (!is_view_2d(src1, input, split_point, input->ne[1], input->nb[1], split_point * ggml_element_size(input))) {
+                    break;
+                }
+                //printf("detected swiglu\n");
+
+                // neutralize the matched nodes and rewrite the MUL into a fused GLU
+                node->src[0]->op = GGML_OP_NONE;
+                node->src[0]->src[0]->op = GGML_OP_NONE;
+                node->src[1]->op = GGML_OP_NONE;
+
+                node->op = GGML_OP_GLU;
+                node->src[0] = input;
+                node->src[1] = NULL;
+                ggml_set_op_params_i32(node, 0, (int32_t) GGML_GLU_OP_SWIGLU);
+                ggml_set_op_params_i32(node, 1, (int32_t) false); // swapped = false
+            } while (0);
+        }
+    }
+
     if (!ggml_backend_sched_alloc_splits(sched)) {
         return false;
     }
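
For reference, the subgraph shape this peephole matches can be built with the public ggml API as below (a hedged sketch, not part of the commit; the helper name unfused_swiglu is illustrative). Once ggml_backend_sched_alloc_graph runs, the final MUL node is rewritten in place into a single GGML_OP_GLU with op param GGML_GLU_OP_SWIGLU reading cur directly, and the matched CONT and SILU nodes are downgraded to GGML_OP_NONE no-ops:

    // mirrors the pattern in the peephole comment above
    static ggml_tensor * unfused_swiglu(ggml_context * ctx0, ggml_tensor * cur) {
        const int64_t split_point = cur->ne[0] / 2;
        // left and right halves of the doubled-width up projection, made contiguous
        ggml_tensor * x0 = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, split_point, cur->ne[1], cur->nb[1], 0));
        ggml_tensor * x1 = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, split_point, cur->ne[1], cur->nb[1], split_point * ggml_element_size(cur)));
        x0 = ggml_silu(ctx0, x0);
        return ggml_mul(ctx0, x0, x1); // the node the peephole rewrites to GGML_OP_GLU
    }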

ggml/src/ggml.c

Lines changed: 2 additions & 0 deletions
@@ -1640,6 +1640,7 @@ static struct ggml_tensor * ggml_new_tensor_impl(
         /*.data      =*/ obj_alloc_size > 0 ? (void *)(result + 1) : data,
         /*.name      =*/ { 0 },
         /*.extra     =*/ NULL,
+        /*.use_count =*/ 0,
         /*.padding   =*/ { 0 },
     };
 
@@ -5962,6 +5963,7 @@ static void ggml_visit_parents(struct ggml_cgraph * cgraph, struct ggml_tensor *
             /* unknown order, just fall back to using i*/ i;
         if (node->src[k]) {
             ggml_visit_parents(cgraph, node->src[k]);
+            node->src[k]->use_count++;
         }
     }
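
These two lines make use_count equal, after graph construction, to the number of nodes that read a tensor as a src. A small sketch of the resulting counts (assumes one ggml_build_forward_expand over a fresh graph; use_count_example is an illustrative name):

    static void use_count_example(ggml_context * ctx0, ggml_cgraph * gf) {
        ggml_tensor * a = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 16);
        ggml_tensor * b = ggml_silu(ctx0, a);   // reads a once...
        ggml_tensor * c = ggml_mul(ctx0, b, a); // ...and a second time here
        ggml_build_forward_expand(gf, c);
        // now: a->use_count == 2, b->use_count == 1, c->use_count == 0 (graph output)
    }

This is what lets the scheduler's peephole demand use_count == 1 on the intermediate CONT/VIEW tensors and use_count == 2 on the shared input: fusing is only safe when nothing else in the graph reads those intermediates.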

src/llama-graph.cpp

Lines changed: 31 additions & 31 deletions
@@ -554,20 +554,12 @@ ggml_tensor * llm_graph_context::build_ffn(
 
     switch (type_op) {
         case LLM_FFN_SILU:
-            if (gate && type_gate == LLM_FFN_PAR) {
-                cur = ggml_swiglu_split(ctx0, cur, tmp);
-                cb(cur, "ffn_swiglu", il);
-                type_gate = LLM_FFN_SEQ;
-            } else {
+            {
                 cur = ggml_silu(ctx0, cur);
                 cb(cur, "ffn_silu", il);
             } break;
         case LLM_FFN_GELU:
-            if (gate && type_gate == LLM_FFN_PAR) {
-                cur = ggml_geglu_split(ctx0, cur, tmp);
-                cb(cur, "ffn_geglu", il);
-                type_gate = LLM_FFN_SEQ;
-            } else {
+            {
                 cur = ggml_gelu(ctx0, cur);
                 cb(cur, "ffn_gelu", il);
                 if (act_scales != NULL) {
@@ -576,11 +568,7 @@ ggml_tensor * llm_graph_context::build_ffn(
                 }
             } break;
         case LLM_FFN_RELU:
-            if (gate && type_gate == LLM_FFN_PAR) {
-                cur = ggml_reglu_split(ctx0, cur, tmp);
-                cb(cur, "ffn_reglu", il);
-                type_gate = LLM_FFN_SEQ;
-            } else {
+            {
                 cur = ggml_relu(ctx0, cur);
                 cb(cur, "ffn_relu", il);
             } break;
@@ -594,19 +582,32 @@ ggml_tensor * llm_graph_context::build_ffn(
             } break;
         case LLM_FFN_SWIGLU:
             {
-                cur = ggml_swiglu(ctx0, cur);
-                cb(cur, "ffn_swiglu", il);
+                // Project to 4h. If using swiglu, double the output width; see https://arxiv.org/pdf/2002.05202.pdf
+                int64_t split_point = cur->ne[0] / 2;
+                // TODO: these conts should not be needed, see https://github.com/ggml-org/llama.cpp/pull/14090#discussion_r2137437217
+                ggml_tensor * x0 = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, split_point, cur->ne[1], cur->nb[1], 0));
+                ggml_tensor * x1 = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, split_point, cur->ne[1], cur->nb[1], split_point * ggml_element_size(cur)));
+
+                x0 = ggml_silu(ctx0, x0);
+                cb(x0, "ffn_silu", il);
+
+                cur = ggml_mul(ctx0, x0, x1);
+                cb(cur, "ffn_mul", il);
             } break;
         case LLM_FFN_GEGLU:
             {
-                cur = ggml_geglu(ctx0, cur);
+                // Split into two equal parts
+                int64_t split_point = cur->ne[0] / 2;
+                // TODO: these conts should not be needed, see https://github.com/ggml-org/llama.cpp/pull/14090#discussion_r2137437217
+                ggml_tensor * x0 = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, split_point, cur->ne[1], cur->nb[1], 0));
+                ggml_tensor * x1 = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, split_point, cur->ne[1], cur->nb[1], split_point * ggml_element_size(cur)));
+
+                x0 = ggml_gelu(ctx0, x0);
+                cb(x0, "ffn_gelu", il);
+
+                cur = ggml_mul(ctx0, x0, x1);
                 cb(cur, "ffn_geglu", il);
             } break;
-        case LLM_FFN_REGLU:
-            {
-                cur = ggml_reglu(ctx0, cur);
-                cb(cur, "ffn_reglu", il);
-            } break;
     }
 
     if (gate && type_gate == LLM_FFN_PAR) {
@@ -736,25 +737,24 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
 
     switch (type_op) {
         case LLM_FFN_SILU:
-            if (gate_exps) {
-                cur = ggml_swiglu_split(ctx0, cur, up);
-                cb(cur, "ffn_moe_swiglu", il);
-            } else {
+            {
                 cur = ggml_silu(ctx0, cur);
                 cb(cur, "ffn_moe_silu", il);
             } break;
         case LLM_FFN_GELU:
-            if (gate_exps) {
-                cur = ggml_geglu_split(ctx0, cur, up);
-                cb(cur, "ffn_moe_geglu", il);
-            } else {
+            {
                 cur = ggml_gelu(ctx0, cur);
                 cb(cur, "ffn_moe_gelu", il);
             } break;
         default:
             GGML_ABORT("fatal error");
     }
 
+    if (gate_exps) {
+        cur = ggml_mul(ctx0, cur, up); // [n_ff, n_expert_used, n_tokens]
+        cb(cur, "ffn_moe_gate_par", il);
+    }
+
     experts = build_lora_mm_id(down_exps, cur, selected_experts); // [n_embd, n_expert_used, n_tokens]
     cb(experts, "ffn_moe_down", il);
