Skip to content

Commit c137349

Browse files
committed
rwkv: skip computing output for unused tokens for hybrid models
Signed-off-by: Molly Sophia <[email protected]>
1 parent e2044f6 commit c137349

File tree

1 file changed

+40
-17
lines changed

1 file changed

+40
-17
lines changed

src/llama.cpp

Lines changed: 40 additions & 17 deletions
Original file line number | Diff line number | Diff line change
@@ -7760,7 +7760,18 @@ struct llm_build_context {
77607760
ggml_view_3d(ctx0, x_norm_ffn, n_embd, n_seq_tokens - 1, n_seqs, x_norm_ffn->nb[1], x_norm_ffn->nb[2], 0),
77617761
1
77627762
);
7763-
cur = ggml_add(ctx0, cur, llm_build_rwkv6_channel_mix(lctx, ctx0, layer, x_norm_ffn, x_prev));
7763+
7764+
struct ggml_tensor * inp_ffn = x_norm_ffn;
7765+
7766+
if (il == n_layer - 1) {
7767+
// skip computing output for unused tokens
7768+
struct ggml_tensor * inp_out_ids = build_inp_out_ids();
7769+
inp_ffn = ggml_get_rows(ctx0, x_norm_ffn, inp_out_ids);
7770+
x_prev = ggml_get_rows(ctx0, x_prev, inp_out_ids);
7771+
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
7772+
}
7773+
7774+
cur = ggml_add(ctx0, cur, llm_build_rwkv6_channel_mix(lctx, ctx0, layer, inp_ffn, x_prev));
77647775
ggml_build_forward_expand(gf, cur);
77657776

77667777
struct ggml_tensor * last_norm_att = ggml_view_3d(ctx0, x_norm_att, n_embd, 1, n_seqs, x_norm_att->nb[1], x_norm_att->nb[2], (n_seq_tokens-1)*n_embd*ggml_element_size(x_norm_att));
@@ -7789,9 +7800,8 @@ struct llm_build_context {
77897800
}
77907801

77917802
cur = inpL;
7792-
struct ggml_tensor * inp_out_ids = build_inp_out_ids();
7793-
cur = ggml_reshape_2d(ctx0, cur, n_embd, n_tokens);
7794-
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
7803+
// struct ggml_tensor * inp_out_ids = build_inp_out_ids();
7804+
// cur = ggml_get_rows(ctx0, cur, inp_out_ids);
77957805

77967806
cur = llm_build_norm(ctx0, cur, hparams, model.output_norm, model.output_norm_b, LLM_NORM, cb, -1);
77977807
cb(cur, "result_norm", -1);
@@ -7874,6 +7884,13 @@ struct llm_build_context {
78747884

78757885
cb(ffn_inp, "ffn_inp", il);
78767886

7887+
if (il == n_layer - 1) {
7888+
// skip computing output for unused tokens
7889+
struct ggml_tensor * inp_out_ids = build_inp_out_ids();
7890+
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
7891+
ffn_inp = ggml_get_rows(ctx0, ffn_inp, inp_out_ids);
7892+
}
7893+
78777894
// feed-forward network
78787895
cur = llm_build_norm(ctx0, ffn_inp, hparams,
78797896
model.layers[il].ffn_norm, NULL,
@@ -7897,10 +7914,6 @@ struct llm_build_context {
78977914
}
78987915

78997916
cur = inpL;
7900-
struct ggml_tensor * inp_out_ids = build_inp_out_ids();
7901-
cur = ggml_reshape_2d(ctx0, cur, n_embd, n_tokens);
7902-
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
7903-
79047917
cur = llm_build_norm(ctx0, cur, hparams, model.output_norm, model.output_norm_b, LLM_NORM_RMS, cb, -1);
79057918
cb(cur, "result_norm", -1);
79067919

@@ -7981,7 +7994,18 @@ struct llm_build_context {
79817994
ggml_view_3d(ctx0, x_norm_ffn, n_embd, n_seq_tokens - 1, n_seqs, x_norm_ffn->nb[1], x_norm_ffn->nb[2], 0),
79827995
1
79837996
);
7984-
cur = ggml_add(ctx0, cur, llm_build_rwkv7_channel_mix(lctx, ctx0, layer, x_norm_ffn, x_prev));
7997+
7998+
struct ggml_tensor * inp_ffn = x_norm_ffn;
7999+
8000+
if (il == n_layer - 1) {
8001+
// skip computing output for unused tokens
8002+
struct ggml_tensor * inp_out_ids = build_inp_out_ids();
8003+
inp_ffn = ggml_get_rows(ctx0, x_norm_ffn, inp_out_ids);
8004+
x_prev = ggml_get_rows(ctx0, x_prev, inp_out_ids);
8005+
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
8006+
}
8007+
8008+
cur = ggml_add(ctx0, cur, llm_build_rwkv7_channel_mix(lctx, ctx0, layer, inp_ffn, x_prev));
79858009
ggml_build_forward_expand(gf, cur);
79868010

79878011
struct ggml_tensor * last_norm_att = ggml_view_3d(ctx0, x_norm_att, n_embd, 1, n_seqs, x_norm_att->nb[1], x_norm_att->nb[2], (n_seq_tokens-1)*n_embd*ggml_element_size(x_norm_att));
@@ -8010,10 +8034,6 @@ struct llm_build_context {
80108034
}
80118035

80128036
cur = inpL;
8013-
struct ggml_tensor * inp_out_ids = build_inp_out_ids();
8014-
cur = ggml_reshape_2d(ctx0, cur, n_embd, n_tokens);
8015-
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
8016-
80178037
cur = llm_build_norm(ctx0, cur, hparams, model.output_norm, model.output_norm_b, LLM_NORM, cb, -1);
80188038
cb(cur, "result_norm", -1);
80198039

@@ -8095,6 +8115,13 @@ struct llm_build_context {
80958115

80968116
cb(ffn_inp, "ffn_inp", il);
80978117

8118+
if (il == n_layer - 1) {
8119+
// skip computing output for unused tokens
8120+
struct ggml_tensor * inp_out_ids = build_inp_out_ids();
8121+
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
8122+
ffn_inp = ggml_get_rows(ctx0, ffn_inp, inp_out_ids);
8123+
}
8124+
80988125
// feed-forward network
80998126
cur = llm_build_norm(ctx0, ffn_inp, hparams,
81008127
model.layers[il].ffn_norm, NULL,
@@ -8118,10 +8145,6 @@ struct llm_build_context {
81188145
}
81198146

81208147
cur = inpL;
8121-
struct ggml_tensor * inp_out_ids = build_inp_out_ids();
8122-
cur = ggml_reshape_2d(ctx0, cur, n_embd, n_tokens);
8123-
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
8124-
81258148
cur = llm_build_norm(ctx0, cur, hparams, model.output_norm, model.output_norm_b, LLM_NORM_RMS, cb, -1);
81268149
cb(cur, "result_norm", -1);
81278150

0 commit comments

Comments (0)