
Commit 7621985

rwkv: skip computing output for unused tokens for hybrid models
Signed-off-by: Molly Sophia <[email protected]>
Parent: 01c784a

1 file changed, 14 insertions(+), 10 deletions(-)

src/llama.cpp

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -7775,7 +7775,6 @@ struct llm_build_context {
 
         cur = inpL;
         struct ggml_tensor * inp_out_ids = build_inp_out_ids();
-        cur = ggml_reshape_2d(ctx0, cur, n_embd, n_tokens);
         cur = ggml_get_rows(ctx0, cur, inp_out_ids);
 
         cur = llm_build_norm(ctx0, cur, hparams, model.output_norm, model.output_norm_b, LLM_NORM, cb, -1);
@@ -7863,6 +7862,13 @@ struct llm_build_context {
 
         cb(ffn_inp, "ffn_inp", il);
 
+        if (il == n_layer - 1) {
+            // skip computing output for unused tokens
+            struct ggml_tensor * inp_out_ids = build_inp_out_ids();
+            cur = ggml_get_rows(ctx0, cur, inp_out_ids);
+            ffn_inp = ggml_get_rows(ctx0, ffn_inp, inp_out_ids);
+        }
+
         // feed-forward network
         cur = llm_build_norm(ctx0, ffn_inp, hparams,
                 model.layers[il].ffn_norm, NULL,
@@ -7886,10 +7892,6 @@ struct llm_build_context {
         }
 
         cur = inpL;
-        struct ggml_tensor * inp_out_ids = build_inp_out_ids();
-        cur = ggml_reshape_2d(ctx0, cur, n_embd, n_tokens);
-        cur = ggml_get_rows(ctx0, cur, inp_out_ids);
-
         cur = llm_build_norm(ctx0, cur, hparams, model.output_norm, model.output_norm_b, LLM_NORM_RMS, cb, -1);
         cb(cur, "result_norm", -1);
 
@@ -8000,7 +8002,6 @@ struct llm_build_context {
 
         cur = inpL;
         struct ggml_tensor * inp_out_ids = build_inp_out_ids();
-        cur = ggml_reshape_2d(ctx0, cur, n_embd, n_tokens);
         cur = ggml_get_rows(ctx0, cur, inp_out_ids);
 
         cur = llm_build_norm(ctx0, cur, hparams, model.output_norm, model.output_norm_b, LLM_NORM, cb, -1);
@@ -8084,6 +8085,13 @@ struct llm_build_context {
 
         cb(ffn_inp, "ffn_inp", il);
 
+        if (il == n_layer - 1) {
+            // skip computing output for unused tokens
+            struct ggml_tensor * inp_out_ids = build_inp_out_ids();
+            cur = ggml_get_rows(ctx0, cur, inp_out_ids);
+            ffn_inp = ggml_get_rows(ctx0, ffn_inp, inp_out_ids);
+        }
+
         // feed-forward network
         cur = llm_build_norm(ctx0, ffn_inp, hparams,
                 model.layers[il].ffn_norm, NULL,
@@ -8107,10 +8115,6 @@ struct llm_build_context {
         }
 
         cur = inpL;
-        struct ggml_tensor * inp_out_ids = build_inp_out_ids();
-        cur = ggml_reshape_2d(ctx0, cur, n_embd, n_tokens);
-        cur = ggml_get_rows(ctx0, cur, inp_out_ids);
-
         cur = llm_build_norm(ctx0, cur, hparams, model.output_norm, model.output_norm_b, LLM_NORM_RMS, cb, -1);
         cb(cur, "result_norm", -1);

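Both added hunks apply the same trick the other graph builders in llm_build_context already use: at the last layer, ggml_get_rows() with the inp_out_ids index tensor keeps only the hidden-state rows of tokens whose logits are actually consumed, so the final norm and output head run over far fewer rows (typically a single row during token-by-token generation). The sketch below illustrates that row gather in plain standalone C++; it is not the ggml/llama.cpp API, and n_embd, n_tokens and out_ids are made-up values.

// Minimal standalone sketch of the row-gather idea (plain C++, not ggml).
#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
    const int n_embd   = 4;   // hypothetical embedding width
    const int n_tokens = 6;   // hypothetical number of tokens in the batch

    // hidden states, one row of n_embd floats per token (row-major)
    std::vector<float> cur(static_cast<size_t>(n_tokens) * n_embd);
    for (int t = 0; t < n_tokens; ++t) {
        for (int e = 0; e < n_embd; ++e) {
            cur[static_cast<size_t>(t) * n_embd + e] = t + 0.1f * e;
        }
    }

    // indices of tokens whose outputs are actually used -- the role played
    // by inp_out_ids in the diff (e.g. only the last token when generating
    // one token at a time)
    const std::vector<int32_t> out_ids = {n_tokens - 1};

    // gather those rows; everything downstream (final norm, output head)
    // would then operate on out_ids.size() rows instead of n_tokens rows
    std::vector<float> out;
    out.reserve(out_ids.size() * n_embd);
    for (int32_t id : out_ids) {
        const float * row = &cur[static_cast<size_t>(id) * n_embd];
        out.insert(out.end(), row, row + n_embd);
    }

    std::printf("kept %zu of %d token rows\n", out_ids.size(), n_tokens);
    return 0;
}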