Skip to content

Commit 201b31d

Browse files
authored
graph : fix geglu (#14077)
ggml-ci
1 parent e21d2d4 commit 201b31d

File tree

1 file changed

+8
-16
lines changed

1 file changed

+8
-16
lines changed

src/llama-graph.cpp

Lines changed: 8 additions & 16 deletions
Original file line number | Diff line number | Diff line change
@@ -663,22 +663,14 @@ ggml_tensor * llm_graph_context::build_ffn(
                 {
                     // Split into two equal parts
                     int64_t split_point = cur->ne[0] / 2;
-                    ggml_tensor * output_ffn_up = ggml_cont(ctx0, ggml_view_2d(
-                                ctx0, cur, split_point,
-                                cur->ne[1], cur->nb[1], 0
-                            ));
-                    ggml_tensor * output_ffn_gate = ggml_cont(ctx0, ggml_view_2d(
-                                ctx0, cur, split_point,
-                                cur->ne[1], cur->nb[1],
-                                split_point * ggml_element_size(cur)
-                            ));
-
-                    // Apply GELU activation function to the first part
-                    output_ffn_up = ggml_gelu(ctx0, output_ffn_up);
-                    cb(output_ffn_up, "ffn_gelu", il);
-
-                    // Element-wise multiplication between the activated part and the gate part
-                    cur = ggml_mul(ctx0, output_ffn_up, output_ffn_gate);
+                    // TODO: these conts should not be needed
+                    ggml_tensor * x0 = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, split_point, cur->ne[1], cur->nb[1], 0));
+                    ggml_tensor * x1 = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, split_point, cur->ne[1], cur->nb[1], split_point * ggml_element_size(cur)));
+
+                    x0 = ggml_gelu(ctx0, x0);
+                    cb(x0, "ffn_gelu", il);
+
+                    cur = ggml_mul(ctx0, x0, x1);
                     cb(cur, "ffn_geglu", il);
                 } break;
         }

0 commit comments

Comments
 (0)