3 files changed, +16 −2 lines

File 1 of 3:

@@ -474,9 +474,17 @@ void llm_graph_input_attn_kv_unified::set_input(const llama_ubatch * ubatch) {
                 }
 
                 // may need to cut off old tokens for sliding window
+                // TODO @ngxson : the check for n_attn_chunk is temporary, need to optimize it
                 if (data_swa) {
-                    if (pos - kv_self->cells[i].pos >= (int32_t) hparams.n_swa) {
-                        f = -INFINITY;
+                    if (hparams.n_attn_chunk) {
+                        llama_pos pos_chunk_start = (pos / hparams.n_attn_chunk) * hparams.n_attn_chunk;
+                        if (kv_self->cells[i].pos < pos_chunk_start || pos < pos_chunk_start) {
+                            f = -INFINITY;
+                        }
+                    } else {
+                        if (pos - kv_self->cells[i].pos >= (int32_t) hparams.n_swa) {
+                            f = -INFINITY;
+                        }
                     }
                     data_swa[h*(n_kv*n_tokens) + s*(n_kv*n_seq_tokens) + j*n_kv + i] = f;
                 }
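For reference, here is a minimal, self-contained sketch of the masking rule the hunk above implements: with chunked attention, a query at position `pos` may only attend to keys whose position lies in the same fixed-size chunk, i.e. at or after `pos_chunk_start = (pos / n_attn_chunk) * n_attn_chunk`; when `n_attn_chunk == 0` the regular sliding-window rule applies. The function name `swa_or_chunk_mask` and the standalone setup are illustrative only, not part of the patch, and the real code folds the result into an existing mask value `f` rather than returning it.

```cpp
#include <cmath>
#include <cstdint>
#include <cstdio>

// Illustrative sketch (not the patch itself): returns the additive mask value
// for a key at position kv_pos attended from a query at position pos.
//   n_attn_chunk > 0  -> chunked attention: the key must lie in the query's chunk
//   n_attn_chunk == 0 -> sliding window: the key must be newer than pos - n_swa
static float swa_or_chunk_mask(int32_t pos, int32_t kv_pos,
                               uint32_t n_attn_chunk, uint32_t n_swa) {
    if (n_attn_chunk) {
        const int32_t pos_chunk_start = (pos / (int32_t) n_attn_chunk) * (int32_t) n_attn_chunk;
        if (kv_pos < pos_chunk_start || pos < pos_chunk_start) {
            return -INFINITY;
        }
    } else {
        if (pos - kv_pos >= (int32_t) n_swa) {
            return -INFINITY;
        }
    }
    return 0.0f;
}

int main() {
    // With n_attn_chunk = 4: positions 0..3 form chunk 0, 4..7 form chunk 1, ...
    // A query at pos = 5 can see the key at 4 (same chunk) but not at 3 (previous chunk).
    printf("%f\n", swa_or_chunk_mask(5, 4, 4, 0)); // 0.0
    printf("%f\n", swa_or_chunk_mask(5, 3, 4, 0)); // -inf
    // With n_attn_chunk = 0 and n_swa = 4: plain sliding window of width 4.
    printf("%f\n", swa_or_chunk_mask(5, 1, 0, 4)); // -inf
    return 0;
}
```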
File 2 of 3:

@@ -114,6 +114,7 @@ struct llama_hparams {
 
     uint32_t n_moe_layer_step = 0;
     bool     use_kq_norm = true;
+    uint32_t n_attn_chunk = 0;
     // values below seems to be fixed on llama4
     uint32_t n_no_rope_layer_step = 4;
     uint32_t n_attn_temp_floor_scale = 8192;
File 3 of 3:

@@ -557,6 +557,11 @@ void llama_model::load_hparams(llama_model_loader & ml) {
                 ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
                 ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp);
                 ml.get_key(LLM_KV_INTERLEAVE_MOE_LAYER_STEP, hparams.n_moe_layer_step);
+                // hack: we use SWA to store the chunked attn mask
+                // luckily, the n_swa_pattern is the same as chunked layer pattern: 3 chunked - 1 full
+                hparams.n_swa_pattern = 4;
+                hparams.n_attn_chunk = 8192; // should this be a gguf kv? currently it's the same for Scout and Maverick
+                hparams.n_swa = 1; // unused, added to trigger the SWA
 
                 switch (hparams.n_expert) {
                     case 16: type = LLM_TYPE_17B_16E; break;
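As the patch comments note, `n_swa_pattern = 4` reuses the SWA layer pattern to mark every 4th layer as full attention while the other three use the chunked mask, and `n_swa = 1` exists only to trigger the SWA mask path (its value is ignored whenever `n_attn_chunk` is non-zero). Below is a minimal sketch of how such a pattern could be interpreted per layer; the helper name `layer_uses_chunked_attn` is hypothetical, and the actual per-layer selection in llama.cpp may differ in detail.

```cpp
#include <cstdint>
#include <cstdio>

// Illustrative sketch, not the actual llama.cpp helper: with a pattern of N,
// layers whose index satisfies (il % N) < (N - 1) use the chunked/SWA mask,
// and every N-th layer (il % N == N - 1) keeps full attention.
static bool layer_uses_chunked_attn(uint32_t il, uint32_t n_swa_pattern) {
    if (n_swa_pattern == 0) {
        return false; // treated here as "no chunked layers"
    }
    return (il % n_swa_pattern) < (n_swa_pattern - 1);
}

int main() {
    // With n_swa_pattern = 4 (as set for Llama 4 above):
    // layers 0,1,2 -> chunked, layer 3 -> full, layers 4,5,6 -> chunked, ...
    for (uint32_t il = 0; il < 8; ++il) {
        printf("layer %u: %s\n", il, layer_uses_chunked_attn(il, 4) ? "chunked" : "full");
    }
    return 0;
}
```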