
Commit 7286558

cont : fix Qwen VL multi-pos input

ggml-ci

1 parent 711d195 commit 7286558

9 files changed: +55 -25 lines

src/llama-batch.cpp

Lines changed: 36 additions & 10 deletions

@@ -9,7 +9,7 @@
 #include <algorithm>
 #include <sstream>

-llama_batch_allocr::llama_batch_allocr() {
+llama_batch_allocr::llama_batch_allocr(uint32_t n_pos_per_embd) : n_pos_per_embd(n_pos_per_embd) {
     const char * LLAMA_BATCH_DEBUG = getenv("LLAMA_BATCH_DEBUG");

     debug = LLAMA_BATCH_DEBUG ? atoi(LLAMA_BATCH_DEBUG) : 0;

@@ -244,9 +244,22 @@ bool llama_batch_allocr::init(
             continue;
         }

-        if (memory && seq_pos_min(s) != memory->seq_pos_max(s) + 1) {
-            LLAMA_LOG_ERROR("%s: sequence %d does not start from the last position stored in the memory\n", __func__, s);
-            return false;
+        if (memory) {
+            if (batch.token) {
+                if (seq_pos_min(s) != memory->seq_pos_max(s) + 1) {
+                    LLAMA_LOG_ERROR("%s: sequence %d does not start from the last position stored in the memory\n", __func__, s);
+                    return false;
+                }
+            } else {
+                assert(batch.embd);
+
+                // for embeddings (typically used as vision input), we allow them to have repeating positions
+                // ref: https://github.com/ggml-org/llama.cpp/issues/13694#issuecomment-2983871762
+                if (seq_pos_min(s) != memory->seq_pos_max(s) && seq_pos_min(s) != memory->seq_pos_max(s) + 1) {
+                    LLAMA_LOG_ERROR("%s: sequence %d does not start from the last position stored in the memory\n", __func__, s);
+                    return false;
+                }
+            }
         }

         if (seq_pos_max(s) - seq_pos_min(s) + 1 > (int) seq_pos[s].size()) {
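The relaxed check above can be read as: token batches must continue exactly one position past what the memory already stores, while embedding batches (vision input) may also start at that last stored position. A minimal standalone sketch of that rule, not taken from the commit (the helper name and values are illustrative only):

    #include <cassert>

    // illustrative only: returns true if a chunk starting at seq_min may follow
    // memory whose last stored position is mem_max
    static bool pos_continuation_ok(bool has_tokens, int seq_min, int mem_max) {
        if (has_tokens) {
            return seq_min == mem_max + 1;                      // strict continuation for text
        }
        // embeddings may also repeat the last position (multi-pos vision input)
        return seq_min == mem_max || seq_min == mem_max + 1;
    }

    int main() {
        assert( pos_continuation_ok(true,  10, 9)); // text continues at 10
        assert( pos_continuation_ok(false,  9, 9)); // image embedding repeats position 9
        assert(!pos_continuation_ok(true,   9, 9)); // text may not repeat a position
        return 0;
    }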
@@ -580,9 +593,14 @@ llama_ubatch llama_batch_allocr::ubatch_add(const std::vector<int32_t> & idxs, u

     auto & ubatch = ubatches.back();

+    const int32_t n_pos_cur = batch.embd ? n_pos_per_embd : 1;
+
+    const int64_t n_embd_all = batch.embd ? (int64_t) n_tokens*n_embd : 0;
+    const int64_t n_pos_all  =              (int64_t) n_tokens*n_pos_cur;
+
     ubatch.token     .resize(n_tokens);
-    ubatch.embd      .resize((int64_t) n_tokens*n_embd);
-    ubatch.pos       .resize(n_tokens);
+    ubatch.embd      .resize(n_embd_all);
+    ubatch.pos       .resize(n_pos_all);
     ubatch.n_seq_id  .resize(n_tokens);
     ubatch.seq_id    .resize(n_tokens);
     ubatch.seq_id_unq.resize(0);
@@ -600,7 +618,10 @@ llama_ubatch llama_batch_allocr::ubatch_add(const std::vector<int32_t> & idxs, u
             memcpy(ubatch.embd.data() + i*n_embd, batch.embd + (int64_t) idxs[i]*n_embd, n_embd*sizeof(float));
         }

-        ubatch.pos[i] = batch.pos[idxs[i]];
+        for (int j = 0; j < n_pos_cur; ++j) {
+            ubatch.pos[j*n_tokens + i] = batch.pos[j*batch.n_tokens + idxs[i]];
+        }
+
         ubatch.n_seq_id[i] = batch.n_seq_id[idxs[i]];
         ubatch.seq_id[i]   = batch.seq_id[idxs[i]];
         ubatch.output[i]   = batch.logits[idxs[i]];
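The loop above gathers n_pos_cur position planes per selected token, keeping the plane-major layout pos[j*n_tokens + i]. A self-contained sketch of that gather (sizes and values invented for illustration, assuming the same layout):

    #include <cstdio>
    #include <vector>

    int main() {
        const int n_pos   = 4;  // e.g. M-RoPE: 4 position values per embedding
        const int n_batch = 6;  // tokens in the full input batch

        // batch positions, plane-major: plane j for token i lives at [j*n_batch + i]
        std::vector<int> batch_pos(n_pos*n_batch);
        for (int j = 0; j < n_pos; ++j) {
            for (int i = 0; i < n_batch; ++i) {
                batch_pos[j*n_batch + i] = 100*j + i;
            }
        }

        // tokens selected for this ubatch
        const std::vector<int> idxs = {1, 3, 4};
        const int n_tokens = (int) idxs.size();

        // the same gather as in ubatch_add: copy every plane, re-indexed to n_tokens
        std::vector<int> ubatch_pos(n_pos*n_tokens);
        for (int i = 0; i < n_tokens; ++i) {
            for (int j = 0; j < n_pos; ++j) {
                ubatch_pos[j*n_tokens + i] = batch_pos[j*n_batch + idxs[i]];
            }
        }

        // print one row per position plane
        for (int j = 0; j < n_pos; ++j) {
            for (int i = 0; i < n_tokens; ++i) {
                printf("%4d ", ubatch_pos[j*n_tokens + i]);
            }
            printf("\n");
        }
        return 0;
    }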
@@ -714,9 +735,14 @@ void llama_batch_allocr::ubatch_print(const llama_ubatch & ubatch, int debug) {
             }
         }

-        LLAMA_LOG_DEBUG("%s: %4d: id = %6d (%16s), pos = %4d, n_seq_id = %2d, seq_id = [%s], output = %d\n",
-                __func__, i, ubatch.token[i], vocab->token_to_piece(ubatch.token[i]).c_str(),
-                ubatch.pos[i], ubatch.n_seq_id[i], ss.str().c_str(), ubatch.output[i]);
+        if (ubatch.token) {
+            LLAMA_LOG_DEBUG("%s: %4d: id = %6d (%16s), pos = %4d, n_seq_id = %2d, seq_id = [%s], output = %d\n",
+                    __func__, i, ubatch.token[i], vocab->token_to_piece(ubatch.token[i]).c_str(),
+                    ubatch.pos[i], ubatch.n_seq_id[i], ss.str().c_str(), ubatch.output[i]);
+        } else {
+            LLAMA_LOG_DEBUG("%s: %4d: [embd], pos = %4d, n_seq_id = %2d, seq_id = [%s], output = %d\n",
+                    __func__, i, ubatch.pos[i], ubatch.n_seq_id[i], ss.str().c_str(), ubatch.output[i]);
+        }
     }
     LLAMA_LOG_DEBUG("%s: ]\n", __func__);
 }

src/llama-batch.h

Lines changed: 5 additions & 1 deletion

@@ -39,7 +39,7 @@ struct llama_ubatch {
 // a helper for sanitizing, fulfilling and splitting a batch
 class llama_batch_allocr {
 public:
-    llama_batch_allocr();
+    llama_batch_allocr(uint32_t n_pos_per_embd);

     // sanitize and auto-gen missing data in the input batch
     // memory is optional. if provided will be used to check for sequence continuity and to determine the positions
@@ -93,6 +93,10 @@ class llama_batch_allocr {
     // only for debugging purposes
     const llama_vocab * vocab;

+    // TODO: this is more of a temporary solution until we have a better way to handle multiple positions per token/embd
+    // ref: https://github.com/ggml-org/llama.cpp/issues/13694#issuecomment-2983871762
+    const uint32_t n_pos_per_embd;
+
     uint32_t n_embd;
     uint32_t n_outputs;

src/llama-context.cpp

Lines changed: 2 additions & 2 deletions

@@ -20,7 +20,7 @@ llama_context::llama_context(
         const llama_model & model,
               llama_context_params params) :
     model(model),
-    balloc(std::make_unique<llama_batch_allocr>()) {
+    balloc(std::make_unique<llama_batch_allocr>(model.hparams.n_pos_per_embd())) {
     LLAMA_LOG_INFO("%s: constructing llama_context\n", __func__);

     t_start_us = model.t_start_us;
@@ -1308,7 +1308,7 @@ ggml_cgraph * llama_context::graph_reserve(uint32_t n_tokens, uint32_t n_seqs, u

     this->n_outputs = n_outputs;

-    llama_batch_allocr balloc;
+    llama_batch_allocr balloc(model.hparams.n_pos_per_embd());
     llama_ubatch ubatch = balloc.ubatch_reserve(n_tokens/n_seqs, n_seqs);

     auto * gf = graph_init();

src/llama-graph.cpp

Lines changed: 2 additions & 6 deletions

@@ -384,10 +384,6 @@ llm_graph_context::llm_graph_context(const llm_graph_params & params) :
     res (std::make_unique<llm_graph_result>()) {
 }

-int64_t llm_graph_context::n_pos_per_embd() const {
-    return hparams.rope_type == LLAMA_ROPE_TYPE_MROPE ? 4 : 1;
-}
-
 void llm_graph_context::cb(ggml_tensor * cur, const char * name, int il) const {
     if (cb_func) {
         cb_func(ubatch, cur, name, il);
@@ -832,11 +828,11 @@ ggml_tensor * llm_graph_context::build_inp_embd(ggml_tensor * tok_embd) const {
 }

 ggml_tensor * llm_graph_context::build_inp_pos() const {
-    auto inp = std::make_unique<llm_graph_input_pos>(n_pos_per_embd());
+    auto inp = std::make_unique<llm_graph_input_pos>(hparams.n_pos_per_embd());

     auto & cur = inp->pos;

-    cur = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens*n_pos_per_embd());
+    cur = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, (int64_t)n_tokens*hparams.n_pos_per_embd());
     ggml_set_input(cur);

     res->add_input(std::move(inp));
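With the hparams helper, build_inp_pos() simply reserves one flat I32 tensor holding all position planes back to back. A rough standalone sketch of that allocation (assuming ggml.h is available and linked; the sizes here are illustrative, not the library's defaults):

    #include "ggml.h"
    #include <cstdio>

    int main() {
        struct ggml_init_params params = {
            /*.mem_size   =*/ 16*1024*1024,
            /*.mem_buffer =*/ NULL,
            /*.no_alloc   =*/ false,
        };
        struct ggml_context * ctx = ggml_init(params);

        const int64_t n_tokens       = 8;
        const int64_t n_pos_per_embd = 4; // 4 for M-RoPE models, 1 otherwise

        // one flat I32 input: n_pos_per_embd planes of n_tokens positions each
        struct ggml_tensor * pos = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, n_tokens*n_pos_per_embd);
        ggml_set_input(pos);

        printf("pos elements: %lld\n", (long long) ggml_nelements(pos));

        ggml_free(ctx);
        return 0;
    }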

src/llama-graph.h

Lines changed: 2 additions & 4 deletions

@@ -94,14 +94,14 @@ class llm_graph_input_embd : public llm_graph_input_i {

 class llm_graph_input_pos : public llm_graph_input_i {
 public:
-    llm_graph_input_pos(int64_t n_pos_per_embd) : n_pos_per_embd(n_pos_per_embd) {}
+    llm_graph_input_pos(uint32_t n_pos_per_embd) : n_pos_per_embd(n_pos_per_embd) {}
     virtual ~llm_graph_input_pos() = default;

     void set_input(const llama_ubatch * ubatch) override;

     ggml_tensor * pos = nullptr; // I32 [n_batch]

-    const int64_t n_pos_per_embd = 1;
+    const uint32_t n_pos_per_embd = 1;
 };

 // temperature tuning, used by llama4
@@ -436,8 +436,6 @@

     llm_graph_context(const llm_graph_params & params);

-    int64_t n_pos_per_embd() const;
-
     void cb(ggml_tensor * cur, const char * name, int il) const;

     //

src/llama-hparams.cpp

Lines changed: 4 additions & 0 deletions

@@ -86,6 +86,10 @@ uint32_t llama_hparams::n_embd_v_s() const {
     return ssm_d_state * ssm_d_inner;
 }

+uint32_t llama_hparams::n_pos_per_embd() const {
+    return rope_type == LLAMA_ROPE_TYPE_MROPE ? 4 : 1;
+}
+
 bool llama_hparams::is_swa(uint32_t il) const {
     if (il < n_layer) {
         return swa_layers[il];
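The new helper is the single place that decides how many position values each embedding carries: 4 when the model uses M-RoPE (as Qwen2-VL does), otherwise 1. A tiny standalone mirror of that logic (the enum and function names here are illustrative, not the library's API):

    #include <cstdio>

    enum rope_type { ROPE_TYPE_NORM, ROPE_TYPE_MROPE };

    // mirrors llama_hparams::n_pos_per_embd() from the hunk above
    static unsigned n_pos_per_embd(rope_type rt) {
        return rt == ROPE_TYPE_MROPE ? 4 : 1;
    }

    int main() {
        printf("mrope: %u, default: %u\n",
               n_pos_per_embd(ROPE_TYPE_MROPE), n_pos_per_embd(ROPE_TYPE_NORM));
        return 0;
    }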

src/llama-hparams.h

Lines changed: 2 additions & 0 deletions

@@ -186,6 +186,8 @@ struct llama_hparams {
     // dimension of the recurrent state embeddings
     uint32_t n_embd_v_s() const;

+    uint32_t n_pos_per_embd() const;
+
     bool is_swa(uint32_t il) const;
 };
src/llama-kv-cache-recurrent.cpp

Lines changed: 1 addition & 1 deletion

@@ -829,7 +829,7 @@ bool llama_kv_cache_recurrent::state_read_meta(llama_io_read_i & io, uint32_t ce

     seq_rm(dest_seq_id, -1, -1);

-    llama_batch_allocr balloc;
+    llama_batch_allocr balloc(hparams.n_pos_per_embd());

     llama_ubatch ubatch = balloc.ubatch_reserve(cell_count, 1);

src/llama-kv-cache-unified.cpp

Lines changed: 1 addition & 1 deletion

@@ -1499,7 +1499,7 @@ bool llama_kv_cache_unified::state_read_meta(llama_io_read_i & io, uint32_t cell

     seq_rm(dest_seq_id, -1, -1);

-    llama_batch_allocr balloc;
+    llama_batch_allocr balloc(hparams.n_pos_per_embd());

     llama_ubatch ubatch = balloc.ubatch_reserve(cell_count, 1);
