
Commit f9fa0e6

cont : bug fix
ggml-ci
1 parent 0096a03 commit f9fa0e6

3 files changed: +5 -18 lines changed

src/llama-batch.cpp

Lines changed: 1 addition & 12 deletions
@@ -293,8 +293,6 @@ llama_batch_allocr::llama_batch_allocr() {
     for (auto & cur : seq_cpl) {
         cur.resize(LLAMA_MAX_SEQ);
     }
-
-    seq_idx.resize(LLAMA_MAX_SEQ);
 }
 
 bool llama_batch_allocr::init(
@@ -444,11 +442,6 @@ bool llama_batch_allocr::init(
         }
 
         seq_set.push_back(cur);
-
-        for (int32_t s = 0; s < batch.n_seq_id[i]; ++s) {
-            seq_idx[batch.seq_id[i][s]].push_back(i);
-        }
-
         seq_set_map[cur].push_back(i);
     }
 
@@ -561,7 +554,7 @@ bool llama_batch_allocr::init(
         for (int32_t s = 0; s < batch.n_seq_id[i]; ++s) {
             const llama_seq_id seq_id = batch.seq_id[i][s];
 
-            cur_seq_set[seq_id] &= seq_set[seq_id];
+            cur_seq_set[seq_id] &= seq_set[i];
 
             if (cur_seq_set[seq_id].none()) {
                 LLAMA_LOG_ERROR("%s: sequence %d belongs to incompatible sequence sets\n", __func__, seq_id);
@@ -779,10 +772,6 @@ void llama_batch_allocr::clear() {
 
     seq_set.clear();
 
-    for (auto & cur : seq_idx) {
-        cur.clear();
-    }
-
     seq_set_map.clear();
 }
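
The substantive fix is the hunk at old line 564: cur_seq_set[seq_id] accumulates, per sequence, the intersection of the sequence sets of every token that touches that sequence, and seq_set is indexed by token position, so the intersection must use the current token's set seq_set[i], not seq_set[seq_id]. Below is a minimal standalone sketch of that consistency check, not the actual llama.cpp code; the LLAMA_MAX_SEQ value and the hand-rolled batch data are illustrative assumptions.

#include <bitset>
#include <cstdio>
#include <vector>

constexpr int LLAMA_MAX_SEQ = 64; // illustrative; the real value lives in the llama.cpp headers
using seq_set_t = std::bitset<LLAMA_MAX_SEQ>;

int main() {
    // seq_set[i]: the set of sequences that token i belongs to (token-indexed)
    std::vector<seq_set_t> seq_set = {
        seq_set_t{0b011}, // token 0: sequences 0 and 1
        seq_set_t{0b011}, // token 1: sequences 0 and 1
        seq_set_t{0b001}, // token 2: sequence 0 only
    };

    // cur_seq_set[s]: running intersection over all tokens seen in sequence s
    std::vector<seq_set_t> cur_seq_set(LLAMA_MAX_SEQ, seq_set_t{}.set());

    for (size_t i = 0; i < seq_set.size(); ++i) {
        for (int s = 0; s < LLAMA_MAX_SEQ; ++s) {
            if (!seq_set[i][s]) {
                continue; // token i is not part of sequence s
            }
            // the fix: intersect with the current token's set, seq_set[i];
            // the old code intersected with seq_set[seq_id] by mistake
            cur_seq_set[s] &= seq_set[i];

            if (cur_seq_set[s].none()) {
                std::printf("sequence %d belongs to incompatible sequence sets\n", s);
                return 1;
            }
        }
    }

    std::printf("all sequence sets are consistent\n");
    return 0;
}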

src/llama-batch.h

Lines changed: 0 additions & 2 deletions
@@ -143,11 +143,9 @@ class llama_batch_allocr {
     std::vector<seq_cpl_t> seq_cpl; // seq_cpl[s0][s1]: if sequence s0 is coupled to sequence s1
 
     using idx_vec_t = std::vector<int32_t>;
-
     using seq_set_t = std::bitset<LLAMA_MAX_SEQ>;
 
     std::vector<seq_set_t> seq_set;
-    std::vector<idx_vec_t> seq_idx;
 
     std::unordered_map<seq_set_t, idx_vec_t> seq_set_map;
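
With seq_idx removed, the header keeps one bitset per token (seq_set) plus seq_set_map, which groups token indices by their exact sequence set and thus already provides what the per-sequence seq_idx vectors duplicated. A small sketch of that grouping follows, with made-up token data and member names mirrored from llama-batch.h; this is an illustration, not the library's implementation.

#include <bitset>
#include <cstdint>
#include <cstdio>
#include <unordered_map>
#include <vector>

constexpr int LLAMA_MAX_SEQ = 64; // illustrative value
using seq_set_t = std::bitset<LLAMA_MAX_SEQ>;
using idx_vec_t = std::vector<int32_t>;

int main() {
    // per-token sequence sets (made-up example batch of 3 tokens)
    std::vector<seq_set_t> seq_set = {
        seq_set_t{0b01}, // token 0: sequence 0
        seq_set_t{0b11}, // token 1: sequences 0 and 1
        seq_set_t{0b01}, // token 2: sequence 0
    };

    // std::hash<std::bitset<N>> is provided by the standard library,
    // so a bitset can key an unordered_map directly
    std::unordered_map<seq_set_t, idx_vec_t> seq_set_map;
    for (int32_t i = 0; i < (int32_t) seq_set.size(); ++i) {
        seq_set_map[seq_set[i]].push_back(i);
    }

    // tokens 0 and 2 share the set {0}; token 1 has the set {0, 1}
    std::printf("%zu distinct sequence sets\n", seq_set_map.size()); // prints: 2
    return 0;
}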

src/llama-context.cpp

Lines changed: 4 additions & 4 deletions
@@ -901,17 +901,17 @@ int llama_context::decode(const llama_batch & batch_inp) {
     const int64_t n_embd = hparams.n_embd;
 
     // when computing embeddings, all tokens are output
-    const bool embd_all = cparams.embeddings;
+    const bool output_all = cparams.embeddings;
 
-    if (!batch_allocr->init(batch_inp, vocab, memory.get(), n_embd, embd_all)) {
+    if (!batch_allocr->init(batch_inp, vocab, memory.get(), n_embd, output_all)) {
         LLAMA_LOG_ERROR("%s: failed to initialize batch\n", __func__);
         return -1;
     }
 
     const uint32_t n_tokens_all = batch_allocr->get_n_tokens();
     const uint32_t n_outputs_all = batch_allocr->get_n_outputs();
 
-    if (embd_all) {
+    if (output_all) {
         // require that all tokens are output
         if (n_outputs_all != n_tokens_all) {
             LLAMA_LOG_ERROR("%s: pooled embedding requires that all tokens are output (n_outputs_all = %d, n_tokens_all = %d)\n",
@@ -940,7 +940,7 @@ int llama_context::decode(const llama_batch & batch_inp) {
     llama_memory_state_ptr mstate;
 
     while (true) {
-        mstate = memory->init_batch(batch_allocr.get(), cparams.n_ubatch, embd_all);
+        mstate = memory->init_batch(batch_allocr.get(), cparams.n_ubatch, output_all);
         if (!mstate) {
             return -2;
         }
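
The rename from embd_all to output_all tracks what the flag actually enforces: when pooled embeddings are enabled, every token of the batch must produce an output, and the same flag is forwarded to batch_allocr->init() and memory->init_batch(). A hedged sketch of that contract follows, reduced from the visible diff; the free function and the example counts are illustrative, not the real llama_context::decode().

#include <cstdint>
#include <cstdio>

// output_all == true: every token must produce an output (the pooled-embedding
// path); the old name embd_all described why the flag was set, the new name
// describes the requirement it imposes downstream
bool check_outputs(bool output_all, uint32_t n_tokens_all, uint32_t n_outputs_all) {
    if (output_all && n_outputs_all != n_tokens_all) {
        std::printf("pooled embedding requires that all tokens are output (n_outputs_all = %u, n_tokens_all = %u)\n",
                n_outputs_all, n_tokens_all);
        return false;
    }
    return true;
}

int main() {
    std::printf("%s\n", check_outputs(true, 4, 4) ? "ok" : "error"); // ok
    std::printf("%s\n", check_outputs(true, 4, 1) ? "ok" : "error"); // error
    return 0;
}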
