@@ -76,13 +76,13 @@ llama_context::llama_context(
     }
 
     if (params.attention_type == LLAMA_ATTENTION_TYPE_UNSPECIFIED) {
-        cparams.causal_attn = hparams.causal_attn;
+        cparams.attn_type = hparams.causal_attn ? LLAMA_ATTENTION_TYPE_CAUSAL : LLAMA_ATTENTION_TYPE_NON_CAUSAL;
     } else {
-        cparams.causal_attn = params.attention_type == LLAMA_ATTENTION_TYPE_CAUSAL;
+        cparams.attn_type = params.attention_type;
     }
 
     // with causal attention, the batch size is limited by the context size
-    cparams.n_batch = cparams.causal_attn ? std::min(cparams.n_ctx, params.n_batch) : params.n_batch;
+    cparams.n_batch = cparams.use_past_tokens() ? std::min(cparams.n_ctx, params.n_batch) : params.n_batch;
 
     // the batch has to be at least GGML_KQ_MASK_PAD because we will be padding the KQ_mask
     // this is required by GPU kernels in order to avoid out-of-bounds accesses (e.g. ggml_flash_attn_ext)
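For context: the `cparams.attn_type` field and the `cparams.use_past_tokens()` helper used above are defined outside this hunk. Below is a minimal sketch of how they presumably fit together; the `llama_attention_type` values come from llama.h, while the `llama_cparams` member and the helper's body are assumptions for illustration, not the patch's actual definitions.

```cpp
// from llama.h: the public attention-type enum referenced by this patch
enum llama_attention_type {
    LLAMA_ATTENTION_TYPE_UNSPECIFIED = -1,
    LLAMA_ATTENTION_TYPE_CAUSAL      = 0,
    LLAMA_ATTENTION_TYPE_NON_CAUSAL  = 1,
};

// hypothetical sketch of the context-params side (real definitions live
// elsewhere in the patch, e.g. llama-cparams.h, and are not shown here)
struct llama_cparams {
    // ... other fields omitted ...
    enum llama_attention_type attn_type;

    // assumption: "uses past tokens" means causal attention backed by the
    // KV cache, which is why it gates the n_batch clamp above
    bool use_past_tokens() const {
        return attn_type == LLAMA_ATTENTION_TYPE_CAUSAL;
    }
};
```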
@@ -102,7 +102,7 @@ llama_context::llama_context(
     LLAMA_LOG_INFO("%s: n_ctx_per_seq = %u\n",   __func__, n_ctx_per_seq);
     LLAMA_LOG_INFO("%s: n_batch       = %u\n",   __func__, cparams.n_batch);
     LLAMA_LOG_INFO("%s: n_ubatch      = %u\n",   __func__, cparams.n_ubatch);
-    LLAMA_LOG_INFO("%s: causal_attn   = %d\n",   __func__, cparams.causal_attn);
+    LLAMA_LOG_INFO("%s: attn_type     = %d\n",   __func__, cparams.attn_type);
     LLAMA_LOG_INFO("%s: flash_attn    = %d\n",   __func__, cparams.flash_attn);
     LLAMA_LOG_INFO("%s: freq_base     = %.1f\n", __func__, cparams.rope_freq_base);
     LLAMA_LOG_INFO("%s: freq_scale    = %g\n",   __func__, cparams.rope_freq_scale);
@@ -966,10 +966,10 @@ void llama_context::set_embeddings(bool value) {
     cparams.embeddings = value;
 }
 
-void llama_context::set_causal_attn(bool value) {
+void llama_context::set_attn_type(enum llama_attention_type value) {
     LLAMA_LOG_DEBUG("%s: value = %d\n", __func__, value);
 
-    cparams.causal_attn = value;
+    cparams.attn_type = value;
 }
 
 void llama_context::set_warmup(bool value) {
@@ -1074,12 +1074,12 @@ int llama_context::encode(llama_batch & inp_batch) {
     ggml_backend_sched_reset(sched.get());
     ggml_backend_sched_set_eval_callback(sched.get(), cparams.cb_eval, cparams.cb_eval_user_data);
 
-    const auto causal_attn_org = cparams.causal_attn;
+    const auto attn_type_org = cparams.attn_type;
 
     // always use non-causal attention for encoder graphs
     // TODO: this is a tmp solution until we have a proper way to support enc-dec models
     // ref: https://github.com/ggml-org/llama.cpp/pull/12181#issuecomment-2730451223
-    cparams.causal_attn = false;
+    cparams.attn_type = LLAMA_ATTENTION_TYPE_NON_CAUSAL;
 
     auto * gf = graph_init();
     auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_ENCODER);
@@ -1088,7 +1088,7 @@ int llama_context::encode(llama_batch & inp_batch) {
 
     res->set_inputs(&ubatch);
 
-    cparams.causal_attn = causal_attn_org;
+    cparams.attn_type = attn_type_org;
 
     const auto compute_status = graph_compute(gf, n_tokens > 1);
     switch (compute_status) {
@@ -1242,7 +1242,7 @@ int llama_context::decode(llama_batch & inp_batch) {
 
     GGML_ASSERT(n_tokens_all <= cparams.n_batch);
 
-    GGML_ASSERT((cparams.causal_attn || cparams.n_ubatch >= n_tokens_all) && "non-causal attention requires n_ubatch >= n_tokens");
+    GGML_ASSERT((!cparams.use_past_tokens() || cparams.n_ubatch >= n_tokens_all) && "non-causal attention requires n_ubatch >= n_tokens");
 
     if (t_compute_start_us == 0) {
         t_compute_start_us = ggml_time_us();
@@ -1495,7 +1495,7 @@ int llama_context::decode(llama_batch & inp_batch) {
     // synchronize();
 
     // decide if we need to defrag the kv cache
-    if (cparams.causal_attn && cparams.defrag_thold > 0.0f) {
+    if (cparams.use_past_tokens() && cparams.defrag_thold > 0.0f) {
        // - do not defrag small contexts (i.e. < 2048 tokens)
        // - count the padding towards the number of used tokens
        const float fragmentation = kv_self->n >= 2048 ? std::max(0.0f, 1.0f - float(kv_self->used + kv_self->get_padding(cparams))/float(kv_self->n)) : 0.0f;
@@ -2410,8 +2410,12 @@ void llama_set_embeddings(llama_context * ctx, bool embeddings) {
     ctx->set_embeddings(embeddings);
 }
 
+void llama_set_attn_type(llama_context * ctx, llama_attention_type type) {
+    ctx->set_attn_type(type);
+}
+
 void llama_set_causal_attn(llama_context * ctx, bool causal_attn) {
-    ctx->set_causal_attn(causal_attn);
+    ctx->set_attn_type(causal_attn ? LLAMA_ATTENTION_TYPE_CAUSAL : LLAMA_ATTENTION_TYPE_NON_CAUSAL);
 }
 
 void llama_set_warmup(llama_context * ctx, bool warmup) {
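Illustrative only: a sketch of how a caller might use the new `llama_set_attn_type` entry point added in this hunk instead of the old boolean toggle. The surrounding function and the encode/restore flow are assumptions for the example; only the two API calls and the enum values come from the patch and llama.h.

```cpp
#include "llama.h"

// Sketch: switch a context to non-causal attention for an embedding-style
// pass, then restore causal attention for normal autoregressive decoding.
static void run_embedding_pass(llama_context * ctx) {
    // previously: llama_set_causal_attn(ctx, false);
    llama_set_attn_type(ctx, LLAMA_ATTENTION_TYPE_NON_CAUSAL);

    // ... build and submit the batch of tokens to embed ...

    // the old bool-based call keeps working: it now forwards to
    // set_attn_type() with the corresponding enum value.
    llama_set_attn_type(ctx, LLAMA_ATTENTION_TYPE_CAUSAL);
}
```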