
Commit c3ecff5

llama : add llama_set_attn_type API
1 parent 5dec47d commit c3ecff5

File tree

6 files changed: +39 -21 lines changed


examples/llava/gemma3-cli.cpp

Lines changed: 2 additions & 2 deletions
@@ -178,14 +178,14 @@ static int eval_image(gemma3_context & ctx, std::string & fname) {
     // decode image embeddings
     int64_t t1 = ggml_time_ms();
     eval_text(ctx, "<start_of_image>");
-    llama_set_causal_attn(ctx.lctx, false);
+    llama_set_attn_type(ctx.lctx, LLAMA_ATTENTION_TYPE_CAUSAL_FULL);
     decode_embd_batch batch_img(image_embd_v.data(), n_tokens, ctx.n_past, 0);
     if (llama_decode(ctx.lctx, batch_img.batch)) {
         LOG_ERR("failed to decode image\n");
         return 1;
     }
     ctx.n_past += n_tokens;
-    llama_set_causal_attn(ctx.lctx, true);
+    llama_set_attn_type(ctx.lctx, LLAMA_ATTENTION_TYPE_CAUSAL);
     eval_text(ctx, "<end_of_image>");
     LOG("Image decoded in %" PRId64 " ms\n", ggml_time_ms() - t1);
     return 0;
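For orientation, a rough picture of why the example switches to LLAMA_ATTENTION_TYPE_CAUSAL_FULL instead of simply disabling causal attention (this sketch is illustrative and not part of the commit): suppose the KV cache already holds text tokens t0..t2 and the batch being decoded holds image embeddings i0..i2.

    columns = KV cells (t0 t1 t2 i0 i1 i2), rows = image tokens in the current batch
    CAUSAL:       i1 sees t0 t1 t2 i0 i1         (the ordering check masks later cells)
    CAUSAL_FULL:  i1 sees t0 t1 t2 i0 i1 i2      (the image attends bidirectionally to itself and to the cached text)

The type is set back to LLAMA_ATTENTION_TYPE_CAUSAL right after the image batch, so the following text tokens decode autoregressively as before.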

include/llama.h

Lines changed: 6 additions & 1 deletion
@@ -208,6 +208,7 @@ extern "C" {
         LLAMA_ATTENTION_TYPE_UNSPECIFIED = -1,
         LLAMA_ATTENTION_TYPE_CAUSAL      = 0,
         LLAMA_ATTENTION_TYPE_NON_CAUSAL  = 1,
+        LLAMA_ATTENTION_TYPE_CAUSAL_FULL = 2, // used by gemma 3: allows image tokens to attend to past tokens
     };

     enum llama_split_mode {
@@ -942,8 +943,12 @@ extern "C" {
     // If true, embeddings will be returned but logits will not
     LLAMA_API void llama_set_embeddings(struct llama_context * ctx, bool embeddings);

+    // Set the attention type
+    LLAMA_API void llama_set_attn_type(struct llama_context * ctx, llama_attention_type type);
+
     // Set whether to use causal attention or not
-    // If set to true, the model will only attend to the past tokens
+    // - true:  the model will only attend to the past tokens, alias of LLAMA_ATTENTION_TYPE_CAUSAL
+    // - false: alias of LLAMA_ATTENTION_TYPE_NON_CAUSAL
     LLAMA_API void llama_set_causal_attn(struct llama_context * ctx, bool causal_attn);

     // Set whether the model is in warmup mode or not
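A minimal usage sketch of the new setter (illustrative only: lctx is an already-created llama_context, the decode call is elided, and no error handling is shown):

    // decode a segment bidirectionally against the cached prompt, then return to normal decoding
    llama_set_attn_type(lctx, LLAMA_ATTENTION_TYPE_CAUSAL_FULL);
    // ... llama_decode(lctx, embd_batch) for the image embeddings ...
    llama_set_attn_type(lctx, LLAMA_ATTENTION_TYPE_CAUSAL);

    // the existing boolean setter keeps working and now forwards to the enum:
    llama_set_causal_attn(lctx, true);   // equivalent to LLAMA_ATTENTION_TYPE_CAUSAL
    llama_set_causal_attn(lctx, false);  // equivalent to LLAMA_ATTENTION_TYPE_NON_CAUSAL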

src/llama-context.cpp

Lines changed: 16 additions & 12 deletions
@@ -76,13 +76,13 @@ llama_context::llama_context(
     }

     if (params.attention_type == LLAMA_ATTENTION_TYPE_UNSPECIFIED) {
-        cparams.causal_attn = hparams.causal_attn;
+        cparams.attn_type = hparams.causal_attn ? LLAMA_ATTENTION_TYPE_CAUSAL : LLAMA_ATTENTION_TYPE_NON_CAUSAL;
     } else {
-        cparams.causal_attn = params.attention_type == LLAMA_ATTENTION_TYPE_CAUSAL;
+        cparams.attn_type = params.attention_type;
     }

     // with causal attention, the batch size is limited by the context size
-    cparams.n_batch = cparams.causal_attn ? std::min(cparams.n_ctx, params.n_batch) : params.n_batch;
+    cparams.n_batch = cparams.use_past_tokens() ? std::min(cparams.n_ctx, params.n_batch) : params.n_batch;

     // the batch has to be at least GGML_KQ_MASK_PAD because we will be padding the KQ_mask
     // this is required by GPU kernels in order to avoid out-of-bounds accesses (e.g. ggml_flash_attn_ext)
@@ -102,7 +102,7 @@ llama_context::llama_context(
     LLAMA_LOG_INFO("%s: n_ctx_per_seq = %u\n", __func__, n_ctx_per_seq);
     LLAMA_LOG_INFO("%s: n_batch       = %u\n", __func__, cparams.n_batch);
     LLAMA_LOG_INFO("%s: n_ubatch      = %u\n", __func__, cparams.n_ubatch);
-    LLAMA_LOG_INFO("%s: causal_attn   = %d\n", __func__, cparams.causal_attn);
+    LLAMA_LOG_INFO("%s: attn_type     = %d\n", __func__, cparams.attn_type);
     LLAMA_LOG_INFO("%s: flash_attn    = %d\n", __func__, cparams.flash_attn);
     LLAMA_LOG_INFO("%s: freq_base     = %.1f\n", __func__, cparams.rope_freq_base);
     LLAMA_LOG_INFO("%s: freq_scale    = %g\n", __func__, cparams.rope_freq_scale);
@@ -966,10 +966,10 @@ void llama_context::set_embeddings(bool value) {
     cparams.embeddings = value;
 }

-void llama_context::set_causal_attn(bool value) {
+void llama_context::set_attn_type(enum llama_attention_type value) {
     LLAMA_LOG_DEBUG("%s: value = %d\n", __func__, value);

-    cparams.causal_attn = value;
+    cparams.attn_type = value;
 }

 void llama_context::set_warmup(bool value) {
@@ -1074,12 +1074,12 @@ int llama_context::encode(llama_batch & inp_batch) {
     ggml_backend_sched_reset(sched.get());
     ggml_backend_sched_set_eval_callback(sched.get(), cparams.cb_eval, cparams.cb_eval_user_data);

-    const auto causal_attn_org = cparams.causal_attn;
+    const auto attn_type_org = cparams.attn_type;

     // always use non-causal attention for encoder graphs
     // TODO: this is a tmp solution until we have a proper way to support enc-dec models
     // ref: https://github.com/ggml-org/llama.cpp/pull/12181#issuecomment-2730451223
-    cparams.causal_attn = false;
+    cparams.attn_type = LLAMA_ATTENTION_TYPE_NON_CAUSAL;

     auto * gf = graph_init();
     auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_ENCODER);
@@ -1088,7 +1088,7 @@ int llama_context::encode(llama_batch & inp_batch) {
     res->set_inputs(&ubatch);

-    cparams.causal_attn = causal_attn_org;
+    cparams.attn_type = attn_type_org;

     const auto compute_status = graph_compute(gf, n_tokens > 1);
     switch (compute_status) {
@@ -1242,7 +1242,7 @@ int llama_context::decode(llama_batch & inp_batch) {
     GGML_ASSERT(n_tokens_all <= cparams.n_batch);

-    GGML_ASSERT((cparams.causal_attn || cparams.n_ubatch >= n_tokens_all) && "non-causal attention requires n_ubatch >= n_tokens");
+    GGML_ASSERT((!cparams.use_past_tokens() || cparams.n_ubatch >= n_tokens_all) && "non-causal attention requires n_ubatch >= n_tokens");

     if (t_compute_start_us == 0) {
         t_compute_start_us = ggml_time_us();
@@ -1495,7 +1495,7 @@ int llama_context::decode(llama_batch & inp_batch) {
     //synchronize();

     // decide if we need to defrag the kv cache
-    if (cparams.causal_attn && cparams.defrag_thold > 0.0f) {
+    if (cparams.use_past_tokens() && cparams.defrag_thold > 0.0f) {
         // - do not defrag small contexts (i.e. < 2048 tokens)
         // - count the padding towards the number of used tokens
         const float fragmentation = kv_self->n >= 2048 ? std::max(0.0f, 1.0f - float(kv_self->used + kv_self->get_padding(cparams))/float(kv_self->n)) : 0.0f;
@@ -2410,8 +2410,12 @@ void llama_set_embeddings(llama_context * ctx, bool embeddings) {
     ctx->set_embeddings(embeddings);
 }

+void llama_set_attn_type(llama_context * ctx, llama_attention_type type) {
+    ctx->set_attn_type(type);
+}
+
 void llama_set_causal_attn(llama_context * ctx, bool causal_attn) {
-    ctx->set_causal_attn(causal_attn);
+    ctx->set_attn_type(causal_attn ? LLAMA_ATTENTION_TYPE_CAUSAL : LLAMA_ATTENTION_TYPE_NON_CAUSAL);
 }

 void llama_set_warmup(llama_context * ctx, bool warmup) {

src/llama-context.h

Lines changed: 2 additions & 2 deletions
@@ -62,8 +62,8 @@ struct llama_context {

     void set_abort_callback(bool (*abort_callback)(void * data), void * abort_callback_data);

-    void set_embeddings (bool value);
-    void set_causal_attn(bool value);
+    void set_embeddings(bool value);
+    void set_attn_type(enum llama_attention_type value);
     void set_warmup(bool value);

     void set_adapter_lora(

src/llama-cparams.h

Lines changed: 6 additions & 1 deletion
@@ -25,14 +25,19 @@ struct llama_cparams {
     float defrag_thold;

     bool embeddings;
-    bool causal_attn;
     bool offload_kqv;
     bool flash_attn;
     bool no_perf;
     bool warmup;

+    enum llama_attention_type attn_type;
+
     enum llama_pooling_type pooling_type;

+    bool use_past_tokens() const {
+        return attn_type == LLAMA_ATTENTION_TYPE_CAUSAL || attn_type == LLAMA_ATTENTION_TYPE_CAUSAL_FULL;
+    }
+
     ggml_backend_sched_eval_callback cb_eval;
     void * cb_eval_user_data;
 };
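Informally, use_past_tokens() separates "does this batch attend to the KV cache" from "is the ordering check enforced". A summary of how the three types behave under the new flag (derived from this commit, not a table in the source):

    attn_type                          use_past_tokens()   pos-ordering check
    LLAMA_ATTENTION_TYPE_CAUSAL        true                enforced
    LLAMA_ATTENTION_TYPE_CAUSAL_FULL   true                skipped (full_mask in llama-graph.cpp)
    LLAMA_ATTENTION_TYPE_NON_CAUSAL    false               not applicable (non-causal mask path)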

src/llama-graph.cpp

Lines changed: 7 additions & 3 deletions
@@ -315,7 +315,7 @@ void llm_graph_input_cross_embd::set_input(const llama_ubatch * ubatch) {

 void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) {
     if (kq_mask) {
-        if (cparams.causal_attn) {
+        if (cparams.use_past_tokens()) {
             const int64_t n_kv         = ubatch->n_tokens;
             const int64_t n_tokens     = ubatch->n_tokens;
             const int64_t n_seq_tokens = ubatch->n_seq_tokens;
@@ -403,12 +403,14 @@ void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) {
 void llm_graph_input_attn_kv_unified::set_input(const llama_ubatch * ubatch) {
     if (self_kq_mask || self_kq_mask_swa) {
         // NOTE: hparams.causal_attn indicates the model is capable of generation and uses the kv cache.
-        if (cparams.causal_attn) {
+        if (cparams.use_past_tokens()) {
             const int64_t n_kv          = kv_self->n;
             const int64_t n_tokens      = ubatch->n_tokens;
             const int64_t n_seq_tokens  = ubatch->n_seq_tokens;
             const int64_t n_seqs        = ubatch->n_seqs;

+            bool full_mask = cparams.attn_type == LLAMA_ATTENTION_TYPE_CAUSAL_FULL;
+
             float * data     = nullptr;
             float * data_swa = nullptr;

@@ -434,7 +436,9 @@ void llm_graph_input_attn_kv_unified::set_input(const llama_ubatch * ubatch) {

             for (int i = 0; i < n_kv; ++i) {
                 float f;
-                if (!kv_self->cells[i].has_seq_id(seq_id) || kv_self->cells[i].pos > pos) {
+                // If bidirectional masking is enabled, ignore the ordering check
+                if (!kv_self->cells[i].has_seq_id(seq_id) ||
+                    (!full_mask && kv_self->cells[i].pos > pos)) {
                     f = -INFINITY;
                 } else {
                     if (hparams.use_alibi) {
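Pulled out of the loop for readability, the per-cell visibility test now amounts to the following (a hypothetical helper that paraphrases the condition above; the commit keeps the check inline, and the cell type name here is an assumption):

    // a KV cell is masked out (f = -INFINITY) when it belongs to a different sequence,
    // or, unless full bidirectional masking was requested, when it lies after the current position
    static bool cell_masked(const llama_kv_cell & cell, llama_seq_id seq_id, llama_pos pos, bool full_mask) {
        return !cell.has_seq_id(seq_id) || (!full_mask && cell.pos > pos);
    }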
