
Commit f48c27d

rwkv: skip computing output for unused tokens for hybrid models
Signed-off-by: Molly Sophia <[email protected]>
1 parent 01c784a

1 file changed, 40 insertions(+), 17 deletions(-)


src/llama.cpp

Lines changed: 40 additions & 17 deletions
@@ -7745,7 +7745,18 @@ struct llm_build_context {
                 ggml_view_3d(ctx0, x_norm_ffn, n_embd, n_seq_tokens - 1, n_seqs, x_norm_ffn->nb[1], x_norm_ffn->nb[2], 0),
                 1
             );
-            cur = ggml_add(ctx0, cur, llm_build_rwkv6_channel_mix(lctx, ctx0, layer, x_norm_ffn, x_prev));
+
+            struct ggml_tensor * inp_ffn = x_norm_ffn;
+
+            if (il == n_layer - 1) {
+                // skip computing output for unused tokens
+                struct ggml_tensor * inp_out_ids = build_inp_out_ids();
+                inp_ffn = ggml_get_rows(ctx0, x_norm_ffn, inp_out_ids);
+                x_prev  = ggml_get_rows(ctx0, x_prev,    inp_out_ids);
+                cur     = ggml_get_rows(ctx0, cur,       inp_out_ids);
+            }
+
+            cur = ggml_add(ctx0, cur, llm_build_rwkv6_channel_mix(lctx, ctx0, layer, inp_ffn, x_prev));
             ggml_build_forward_expand(gf, cur);
 
             struct ggml_tensor * last_norm_att = ggml_view_3d(ctx0, x_norm_att, n_embd, 1, n_seqs, x_norm_att->nb[1], x_norm_att->nb[2], (n_seq_tokens-1)*n_embd*ggml_element_size(x_norm_att));
@@ -7774,9 +7785,8 @@ struct llm_build_context {
         }
 
         cur = inpL;
-        struct ggml_tensor * inp_out_ids = build_inp_out_ids();
-        cur = ggml_reshape_2d(ctx0, cur, n_embd, n_tokens);
-        cur = ggml_get_rows(ctx0, cur, inp_out_ids);
+        // struct ggml_tensor * inp_out_ids = build_inp_out_ids();
+        // cur = ggml_get_rows(ctx0, cur, inp_out_ids);
 
         cur = llm_build_norm(ctx0, cur, hparams, model.output_norm, model.output_norm_b, LLM_NORM, cb, -1);
         cb(cur, "result_norm", -1);
@@ -7863,6 +7873,13 @@ struct llm_build_context {
 
             cb(ffn_inp, "ffn_inp", il);
 
+            if (il == n_layer - 1) {
+                // skip computing output for unused tokens
+                struct ggml_tensor * inp_out_ids = build_inp_out_ids();
+                cur     = ggml_get_rows(ctx0, cur,     inp_out_ids);
+                ffn_inp = ggml_get_rows(ctx0, ffn_inp, inp_out_ids);
+            }
+
             // feed-forward network
             cur = llm_build_norm(ctx0, ffn_inp, hparams,
                     model.layers[il].ffn_norm, NULL,
@@ -7886,10 +7903,6 @@ struct llm_build_context {
         }
 
         cur = inpL;
-        struct ggml_tensor * inp_out_ids = build_inp_out_ids();
-        cur = ggml_reshape_2d(ctx0, cur, n_embd, n_tokens);
-        cur = ggml_get_rows(ctx0, cur, inp_out_ids);
-
         cur = llm_build_norm(ctx0, cur, hparams, model.output_norm, model.output_norm_b, LLM_NORM_RMS, cb, -1);
         cb(cur, "result_norm", -1);
 
@@ -7970,7 +7983,18 @@ struct llm_build_context {
                 ggml_view_3d(ctx0, x_norm_ffn, n_embd, n_seq_tokens - 1, n_seqs, x_norm_ffn->nb[1], x_norm_ffn->nb[2], 0),
                 1
             );
-            cur = ggml_add(ctx0, cur, llm_build_rwkv7_channel_mix(lctx, ctx0, layer, x_norm_ffn, x_prev));
+
+            struct ggml_tensor * inp_ffn = x_norm_ffn;
+
+            if (il == n_layer - 1) {
+                // skip computing output for unused tokens
+                struct ggml_tensor * inp_out_ids = build_inp_out_ids();
+                inp_ffn = ggml_get_rows(ctx0, x_norm_ffn, inp_out_ids);
+                x_prev  = ggml_get_rows(ctx0, x_prev,    inp_out_ids);
+                cur     = ggml_get_rows(ctx0, cur,       inp_out_ids);
+            }
+
+            cur = ggml_add(ctx0, cur, llm_build_rwkv7_channel_mix(lctx, ctx0, layer, inp_ffn, x_prev));
             ggml_build_forward_expand(gf, cur);
 
             struct ggml_tensor * last_norm_att = ggml_view_3d(ctx0, x_norm_att, n_embd, 1, n_seqs, x_norm_att->nb[1], x_norm_att->nb[2], (n_seq_tokens-1)*n_embd*ggml_element_size(x_norm_att));
@@ -7999,10 +8023,6 @@ struct llm_build_context {
         }
 
         cur = inpL;
-        struct ggml_tensor * inp_out_ids = build_inp_out_ids();
-        cur = ggml_reshape_2d(ctx0, cur, n_embd, n_tokens);
-        cur = ggml_get_rows(ctx0, cur, inp_out_ids);
-
         cur = llm_build_norm(ctx0, cur, hparams, model.output_norm, model.output_norm_b, LLM_NORM, cb, -1);
         cb(cur, "result_norm", -1);
 
@@ -8084,6 +8104,13 @@ struct llm_build_context {
 
             cb(ffn_inp, "ffn_inp", il);
 
+            if (il == n_layer - 1) {
+                // skip computing output for unused tokens
+                struct ggml_tensor * inp_out_ids = build_inp_out_ids();
+                cur     = ggml_get_rows(ctx0, cur,     inp_out_ids);
+                ffn_inp = ggml_get_rows(ctx0, ffn_inp, inp_out_ids);
+            }
+
             // feed-forward network
             cur = llm_build_norm(ctx0, ffn_inp, hparams,
                     model.layers[il].ffn_norm, NULL,
@@ -8107,10 +8134,6 @@ struct llm_build_context {
         }
 
         cur = inpL;
-        struct ggml_tensor * inp_out_ids = build_inp_out_ids();
-        cur = ggml_reshape_2d(ctx0, cur, n_embd, n_tokens);
-        cur = ggml_get_rows(ctx0, cur, inp_out_ids);
-
         cur = llm_build_norm(ctx0, cur, hparams, model.output_norm, model.output_norm_b, LLM_NORM_RMS, cb, -1);
         cb(cur, "result_norm", -1);
 
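All four additions follow the same pattern: on the last layer (il == n_layer - 1), build_inp_out_ids() supplies the indices of the tokens whose logits are actually requested, and ggml_get_rows() narrows the tensors that still feed the rest of the graph (cur, plus either the channel-mix inputs inp_ffn/x_prev or ffn_inp) to just those rows, so the final layer work and the output head run on a smaller tensor. The get_rows that used to run once after the layer loop is dropped (or left commented out) because the rows are already narrowed inside the last layer. As a rough standalone illustration of the underlying operation (a sketch against the public ggml C API, not code from this commit; the sizes and indices below are made up):

    // Sketch: use ggml_get_rows to keep only the token rows whose output is needed.
    #include "ggml.h"

    #include <stdint.h>
    #include <stdio.h>

    int main(void) {
        struct ggml_init_params params = {
            /*.mem_size   =*/ 16 * 1024 * 1024,
            /*.mem_buffer =*/ NULL,
            /*.no_alloc   =*/ false,
        };
        struct ggml_context * ctx = ggml_init(params);

        const int64_t n_embd   = 4; // embedding size (toy value)
        const int64_t n_tokens = 8; // tokens in the batch
        const int64_t n_out    = 2; // tokens whose logits are actually needed

        // hidden states for the whole batch: one row per token
        struct ggml_tensor * x = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, n_tokens);

        // indices of the rows to keep, analogous to inp_out_ids in the diff
        struct ggml_tensor * ids = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, n_out);
        ((int32_t *) ids->data)[0] = 3;
        ((int32_t *) ids->data)[1] = 7;

        // the result has only n_out rows; every op built on top of it is cheaper
        struct ggml_tensor * x_out = ggml_get_rows(ctx, x, ids);

        printf("token rows before: %lld, after: %lld\n",
               (long long) x->ne[1], (long long) x_out->ne[1]);

        ggml_free(ctx);
        return 0;
    }

Built this way, x_out carries 2 rows instead of 8, and anything the graph expands on top of it scales with the number of requested outputs rather than the full batch; in the commit, the selection indices come from build_inp_out_ids(), which reflects which tokens the caller asked logits for.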
