
Commit 91a8ee6

huydt84 and huydt-bti authored
add geglu activation function (#14074)
Co-authored-by: dinhhuy <[email protected]>
1 parent 056eb74 commit 91a8ee6

2 files changed: +23 −0 lines changed


src/llama-graph.cpp

Lines changed: 22 additions & 0 deletions
@@ -659,6 +659,28 @@ ggml_tensor * llm_graph_context::build_ffn(
                 cur = ggml_mul(ctx0, x0, x1);
                 cb(cur, "ffn_mul", il);
             } break;
+        case LLM_FFN_GEGLU:
+            {
+                // Split into two equal parts
+                int64_t split_point = cur->ne[0] / 2;
+                ggml_tensor * output_ffn_up = ggml_cont(ctx0, ggml_view_2d(
+                    ctx0, cur, split_point,
+                    cur->ne[1], cur->nb[1], 0
+                ));
+                ggml_tensor * output_ffn_gate = ggml_cont(ctx0, ggml_view_2d(
+                    ctx0, cur, split_point,
+                    cur->ne[1], cur->nb[1],
+                    split_point * ggml_element_size(cur)
+                ));
+
+                // Apply GELU activation function to the first part
+                output_ffn_up = ggml_gelu(ctx0, output_ffn_up);
+                cb(output_ffn_up, "ffn_gelu", il);
+
+                // Element-wise multiplication between the activated part and the gate part
+                cur = ggml_mul(ctx0, output_ffn_up, output_ffn_gate);
+                cb(cur, "ffn_geglu", il);
+            } break;
     }

     if (gate && type_gate == LLM_FFN_PAR) {
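For context, GEGLU is the GELU member of the GLU family of gated feed-forward activations: the fused up projection produces a tensor twice the FFN width, which is split in half; GELU is applied to the first half and the result is multiplied elementwise by the second (gate) half, which is exactly what the hunk above does with ggml_view_2d, ggml_gelu, and ggml_mul (the ggml_cont calls make the two views contiguous before further ops). Below is a minimal standalone C++ sketch of the same math on a plain float row; it is illustrative only, not ggml code, and the gelu helper uses the tanh approximation that ggml's GELU is based on:

    #include <cmath>
    #include <cstddef>
    #include <cstdio>
    #include <vector>

    // Tanh-approximation GELU (the approximation ggml's GELU is based on).
    static float gelu(float x) {
        const float k = 0.7978845608f; // sqrt(2/pi)
        return 0.5f * x * (1.0f + std::tanh(k * (x + 0.044715f * x * x * x)));
    }

    // GEGLU over one row of the fused up-projection output: a row of width
    // 2*n is split at n; GELU is applied to the first half, which is then
    // multiplied elementwise by the second (gate) half -- the same steps as
    // the diff above.
    static std::vector<float> geglu_row(const std::vector<float> & row) {
        const std::size_t n = row.size() / 2; // split_point
        std::vector<float> out(n);
        for (std::size_t i = 0; i < n; ++i) {
            out[i] = gelu(row[i]) * row[n + i];
        }
        return out;
    }

    int main() {
        // input width 2*n = 6 -> output width n = 3
        const std::vector<float> row = {1.0f, -2.0f, 0.5f, 0.25f, 4.0f, -1.0f};
        for (float v : geglu_row(row)) {
            std::printf("%f\n", v);
        }
    }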

src/llama-graph.h

Lines changed: 1 addition & 0 deletions
@@ -36,6 +36,7 @@ enum llm_ffn_op_type {
     LLM_FFN_RELU,
     LLM_FFN_RELU_SQR,
     LLM_FFN_SWIGLU,
+    LLM_FFN_GEGLU,
 };

 enum llm_ffn_gate_type {
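With the enum value in place, a model implementation selects the new op when building its FFN graph. The following call-site sketch is hypothetical: the parameter layout is assumed from build_ffn's up/gate/down convention, not copied from a real model in the repo:

    // Hypothetical call site: a model whose FFN fuses up and gate into one
    // weight matrix passes it as ffn_up and picks LLM_FFN_GEGLU, which splits
    // the projection internally (so no separate gate tensor is supplied).
    cur = build_ffn(cur,
            model.layers[il].ffn_up,   NULL, NULL, // fused up+gate projection, width 2*n_ff
            NULL,                      NULL, NULL, // no separate gate tensor
            model.layers[il].ffn_down, NULL, NULL,
            NULL,
            LLM_FFN_GEGLU, LLM_FFN_SEQ, il);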
