
Commit e6a2809

add chunk attn mask
1 parent f8f1bd4 commit e6a2809

File tree

3 files changed: +16 -2 lines changed


src/llama-graph.cpp

Lines changed: 10 additions & 2 deletions
@@ -474,9 +474,17 @@ void llm_graph_input_attn_kv_unified::set_input(const llama_ubatch * ubatch) {
         }

         // may need to cut off old tokens for sliding window
+        // TODO @ngxson : the check for n_attn_chunk is temporary, need to optimize it
         if (data_swa) {
-            if (pos - kv_self->cells[i].pos >= (int32_t)hparams.n_swa) {
-                f = -INFINITY;
+            if (hparams.n_attn_chunk) {
+                llama_pos pos_chunk_start = (pos / hparams.n_attn_chunk) * hparams.n_attn_chunk;
+                if (kv_self->cells[i].pos < pos_chunk_start || pos < pos_chunk_start) {
+                    f = -INFINITY;
+                }
+            } else {
+                if (pos - kv_self->cells[i].pos >= (int32_t)hparams.n_swa) {
+                    f = -INFINITY;
+                }
             }
             data_swa[h*(n_kv*n_tokens) + s*(n_kv*n_seq_tokens) + j*n_kv + i] = f;
         }
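
For context on what this hunk does: when hparams.n_attn_chunk is non-zero, the SWA mask slot is reused for chunked attention, and any cached token whose position lies before the start of the query's chunk, i.e. before (pos / n_attn_chunk) * n_attn_chunk, is masked with -INFINITY; otherwise the original sliding-window check still applies. The standalone sketch below only illustrates that chunk predicate; the helper name and the main driver are illustrative, not part of this commit, and causal masking of future tokens is assumed to be handled by the regular KQ mask as before.

#include <cstdint>
#include <cstdio>

// Illustrative helper (not from the patch): is a cached token at kv_pos visible
// to a query at pos under chunked attention with chunk size n_attn_chunk?
static bool chunk_mask_allows(int32_t pos, int32_t kv_pos, uint32_t n_attn_chunk) {
    // same arithmetic as the diff: start of the chunk that contains the query
    const int32_t pos_chunk_start = (pos / (int32_t) n_attn_chunk) * (int32_t) n_attn_chunk;
    // tokens before the chunk start would get f = -INFINITY in the mask
    return kv_pos >= pos_chunk_start && pos >= pos_chunk_start;
}

int main() {
    // with n_attn_chunk = 8192, a query at position 8200 can see 8195 but not 100
    printf("%d\n", (int) chunk_mask_allows(8200, 8195, 8192)); // prints 1
    printf("%d\n", (int) chunk_mask_allows(8200,  100, 8192)); // prints 0
    return 0;
}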

src/llama-hparams.h

Lines changed: 1 addition & 0 deletions
@@ -114,6 +114,7 @@ struct llama_hparams {
 
     uint32_t n_moe_layer_step = 0;
     bool use_kq_norm = true;
+    uint32_t n_attn_chunk = 0;
     // values below seems to be fixed on llama4
     uint32_t n_no_rope_layer_step = 4;
     uint32_t n_attn_temp_floor_scale = 8192;

src/llama-model.cpp

Lines changed: 5 additions & 0 deletions
@@ -557,6 +557,11 @@ void llama_model::load_hparams(llama_model_loader & ml) {
             ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
             ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp);
             ml.get_key(LLM_KV_INTERLEAVE_MOE_LAYER_STEP, hparams.n_moe_layer_step);
+            // hack: we use SWA to store the chunked attn mask
+            // luckily, the n_swa_pattern is the same as chunked layer pattern: 3 chunked - 1 full
+            hparams.n_swa_pattern = 4;
+            hparams.n_attn_chunk = 8192; // should this be a gguf kv? currently it's the same for Scout and Maverick
+            hparams.n_swa = 1; // unused, added to trigger the SWA
 
             switch (hparams.n_expert) {
                 case 16: type = LLM_TYPE_17B_16E; break;
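
On the n_swa_pattern = 4 hack: the in-diff comment says the chunked-attention layers follow the same 3 chunked / 1 full interleave that the SWA pattern machinery already encodes. How n_swa_pattern is consumed is outside this commit, so the sketch below is only a hypothetical illustration of that 3:1 layer ratio; the helper name and the phase (which layer in each group of four gets full attention) are assumptions, not taken from llama.cpp.

#include <cstdint>
#include <cstdio>

// Hypothetical illustration (not from this commit): with n_swa_pattern = 4,
// three layers in every group of four use the chunked/SWA mask and the
// remaining one uses full attention. The phase chosen here is an assumption.
static bool layer_uses_chunked_attn(uint32_t il, uint32_t n_swa_pattern) {
    return (il + 1) % n_swa_pattern != 0; // layers 0,1,2 chunked; layer 3 full; repeat
}

int main() {
    for (uint32_t il = 0; il < 8; ++il) {
        printf("layer %u: %s\n", il, layer_uses_chunked_attn(il, 4) ? "chunked" : "full");
    }
    return 0;
}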
