Commit 0096a03

cont
ggml-ci
1 parent 7c242f4 commit 0096a03

10 files changed, +143 -202 lines changed

src/llama-batch.cpp

Lines changed: 83 additions & 108 deletions
@@ -301,11 +301,10 @@ bool llama_batch_allocr::init(
         const llama_batch & batch_inp,
         const llama_vocab & vocab,
         const llama_memory_i * memory,
-        bool embd_all) {
+        uint32_t n_embd,
+        bool output_all) {
     clear();
 
-    split_reset();
-
     batch = batch_inp;
 
     GGML_ASSERT(batch.n_tokens > 0);
@@ -382,7 +381,7 @@ bool llama_batch_allocr::init(
     }
 
     if (!batch.logits) {
-        if (embd_all) {
+        if (output_all) {
             // return the output for all tokens
             output.resize(batch.n_tokens, true);
         } else {
@@ -392,7 +391,7 @@ bool llama_batch_allocr::init(
         }
 
         batch.logits = output.data();
-    } else if (embd_all) {
+    } else if (output_all) {
         bool warn = false;
 
         for (int32_t i = 0; i < batch.n_tokens; ++i) {
@@ -417,6 +416,8 @@ bool llama_batch_allocr::init(
         n_outputs += batch.logits[i] != 0;
     }
 
+    this->n_embd = n_embd;
+
     // determine coupled sequences
     // these are pairs of sequences that have at least one token in the input batch that is assigned to both of them
     for (int32_t i = 0; i < batch.n_tokens; ++i) {
@@ -572,6 +573,8 @@ bool llama_batch_allocr::init(
 
     // TODO: check that positions are increasing
 
+    split_reset();
+
     return true;
 }
 
@@ -580,7 +583,7 @@ const llama_batch & llama_batch_allocr::get_batch() const {
 }
 
 uint32_t llama_batch_allocr::get_n_tokens() const {
-    return pos.size();
+    return batch.n_tokens;
 }
 
 uint32_t llama_batch_allocr::get_n_outputs() const {
@@ -609,41 +612,20 @@ void llama_batch_allocr::split_reset() {
 }
 
 llama_ubatch llama_batch_allocr::split_simple(uint32_t n_ubatch) {
-    llama_ubatch res {
-        /*.equal_seqs =*/ false,
-        /*.n_tokens =*/ 0,
-        /*.n_seq_tokens =*/ 1,
-        /*.n_seqs =*/ 0,
-
-        /*.token =*/ nullptr,
-        /*.embd =*/ nullptr,
-        /*.pos =*/ nullptr,
-        /*.n_seq_id =*/ nullptr,
-        /*.seq_id =*/ nullptr,
-        /*.output =*/ nullptr
-    };
-
     uint32_t cur_idx = 0;
     while (cur_idx < used.size() && used[cur_idx]) {
         ++cur_idx;
     }
 
     if (cur_idx >= used.size()) {
-        return res;
+        return {};
     }
 
     std::vector<int32_t> idxs;
 
     while (true) {
-        res.n_tokens++;
-        res.n_seqs++;
-
         idxs.push_back(cur_idx);
 
-        if (output[cur_idx] != 0) {
-            out_ids.push_back(cur_idx);
-        }
-
         used[cur_idx] = true;
 
         ++cur_idx;
@@ -652,31 +634,15 @@ llama_ubatch llama_batch_allocr::split_simple(uint32_t n_ubatch) {
             break;
         }
 
-        if (res.n_tokens >= n_ubatch) {
+        if (idxs.size() >= n_ubatch) {
            break;
        }
    }
 
-    add_ubatch(res, idxs);
-
-    return res;
+    return add_ubatch(idxs, idxs.size(), false);
 }
 
 llama_ubatch llama_batch_allocr::split_equal(uint32_t n_ubatch) {
-    llama_ubatch res {
-        /*.equal_seqs =*/ true,
-        /*.n_tokens =*/ 0,
-        /*.n_seq_tokens =*/ 0,
-        /*.n_seqs =*/ 0,
-
-        /*.token =*/ nullptr,
-        /*.embd =*/ nullptr,
-        /*.pos =*/ nullptr,
-        /*.n_seq_id =*/ nullptr,
-        /*.seq_id =*/ nullptr,
-        /*.output =*/ nullptr
-    };
-
     std::vector<seq_set_t> cur_seq_set;
 
     // determine the sequence sets participating in this ubatch
@@ -685,35 +651,45 @@ llama_ubatch llama_batch_allocr::split_equal(uint32_t n_ubatch) {
             continue;
         }
 
-        for (size_t s = 0; s < cur_seq_set.size(); ++s) {
+        bool add = true;
+
+        for (uint32_t s = 0; s < cur_seq_set.size(); ++s) {
             // no overlap with existing sequence sets:
-            if ((cur_seq_set[s] & seq_set[i]).none()) {
-                cur_seq_set.push_back(seq_set[i]);
+            if (!(cur_seq_set[s] & seq_set[i]).none()) {
+                add = false;
+                break;
+            }
+        }
 
-                if (cur_seq_set.size() > (size_t) n_ubatch) {
-                    break;
-                }
+        if (add) {
+            cur_seq_set.push_back(seq_set[i]);
+
+            if (cur_seq_set.size() > n_ubatch) {
+                break;
            }
        }
    }
 
-    res.n_seqs = cur_seq_set.size();
+    const uint32_t n_seqs = cur_seq_set.size();
+
+    if (n_seqs == 0) {
+        return {};
+    }
 
-    std::vector<int32_t> cur_idx(cur_seq_set.size(), 0);
+    std::vector<int32_t> cur_idx(n_seqs, 0);
 
-    for (size_t s = 0; s < cur_seq_set.size(); ++s) {
+    for (uint32_t s = 0; s < n_seqs; ++s) {
         while (used[seq_set_map[cur_seq_set[s]][cur_idx[s]]]) {
             ++cur_idx[s];
         }
    }
 
-    std::vector<int32_t> idxs;
+    std::vector<idx_vec_t> idxs_per_seq(n_seqs);
 
-    // TODO: reorder from 012301230123..., to 000...111...222...333...
     while (true) {
         bool can_expand = true;
 
-        for (size_t s = 0; s < cur_seq_set.size(); ++s) {
+        for (uint32_t s = 0; s < n_seqs; ++s) {
             if (cur_idx[s] >= (int32_t) seq_set_map[cur_seq_set[s]].size()) {
                 can_expand = false;
                 break;
@@ -724,71 +700,49 @@ llama_ubatch llama_batch_allocr::split_equal(uint32_t n_ubatch) {
            break;
        }
 
-        res.n_tokens += res.n_seqs;
-
-        for (size_t s = 0; s < cur_seq_set.size(); ++s) {
+        for (uint32_t s = 0; s < n_seqs; ++s) {
             const int32_t idx = seq_set_map[cur_seq_set[s]][cur_idx[s]];
-            idxs.push_back(idx);
-
-            if (output[idx] != 0) {
-                out_ids.push_back(idx);
-            }
+            idxs_per_seq[s].push_back(idx);
 
             used[idx] = true;
 
             ++cur_idx[s];
        }
 
-        if (res.n_tokens + res.n_seqs > n_ubatch) {
+        if ((idxs_per_seq[0].size() + 1)*n_seqs > n_ubatch) {
            break;
        }
    }
 
-    add_ubatch(res, idxs);
+    std::vector<int32_t> idxs;
 
-    return res;
+    for (uint32_t s = 0; s < n_seqs; ++s) {
+        idxs.insert(idxs.end(), idxs_per_seq[s].begin(), idxs_per_seq[s].end());
+    }
+
+    return add_ubatch(idxs, n_seqs, true);
 }
 
 llama_ubatch llama_batch_allocr::split_seq(uint32_t n_ubatch) {
-    llama_ubatch res {
-        /*.equal_seqs =*/ true,
-        /*.n_tokens =*/ 0,
-        /*.n_seq_tokens =*/ 0,
-        /*.n_seqs =*/ 1,
-
-        /*.token =*/ nullptr,
-        /*.embd =*/ nullptr,
-        /*.pos =*/ nullptr,
-        /*.n_seq_id =*/ nullptr,
-        /*.seq_id =*/ nullptr,
-        /*.output =*/ nullptr,
-    };
-
     uint32_t cur_idx = 0;
     while (cur_idx < used.size() && used[cur_idx]) {
         ++cur_idx;
    }
 
     if (cur_idx >= used.size()) {
-        return res;
+        return {};
    }
 
     auto cur_seq_set = seq_set[cur_idx];
 
     std::vector<int32_t> idxs;
 
     while (true) {
-        res.n_tokens++;
-
         idxs.push_back(cur_idx);
 
-        if (output[cur_idx] != 0) {
-            out_ids.push_back(cur_idx);
-        }
-
         used[cur_idx] = true;
 
-        if (res.n_tokens >= n_ubatch) {
+        if (idxs.size() >= n_ubatch) {
            break;
        }
 
@@ -803,9 +757,7 @@ llama_ubatch llama_batch_allocr::split_seq(uint32_t n_ubatch) {
         cur_seq_set = seq_set[cur_idx];
    }
 
-    add_ubatch(res, idxs);
-
-    return res;
+    return add_ubatch(idxs, 1, true);
 }
 
 void llama_batch_allocr::clear() {
@@ -834,37 +786,60 @@ void llama_batch_allocr::clear() {
     seq_set_map.clear();
 }
 
-void llama_batch_allocr::add_ubatch(llama_ubatch & res, const std::vector<int32_t> & idxs) {
-    ubatches.emplace_back();
+llama_ubatch llama_batch_allocr::add_ubatch(const std::vector<int32_t> & idxs, uint32_t n_seqs, bool equal_seqs) {
+    const uint32_t n_tokens = idxs.size();
 
-    auto & ubatch = ubatches.back();
+    LLAMA_LOG_DEBUG("add_ubatch: n_tokens = %d, n_seqs = %d, equal_seqs = %d", n_tokens, n_seqs, equal_seqs);
 
-    assert(res.n_tokens == idxs.size());
+    assert(n_tokens%n_seqs == 0);
 
-    const auto n_tokens = res.n_tokens;
+    ubatches.emplace_back();
+
+    auto & ubatch = ubatches.back();
 
     ubatch.token.resize(n_tokens);
-    //ubatch.embd.resize(0); // TODO
+    ubatch.embd.resize((int64_t) n_tokens*n_embd);
     ubatch.pos.resize(n_tokens);
     ubatch.n_seq_id.resize(n_tokens);
     ubatch.seq_id.resize(n_tokens);
     ubatch.output.resize(n_tokens);
 
     for (size_t i = 0; i < idxs.size(); ++i) {
-        ubatch.token[i] = batch.token[idxs[i]];
-        //ubatch.embd[i] = batch.embd[idxs[i]]; // TODO
+        if (batch.token) {
+            ubatch.token[i] = batch.token[idxs[i]];
+        }
+
+        if (batch.embd) {
+            memcpy(ubatch.embd.data() + i*n_embd, batch.embd + (int64_t) idxs[i]*n_embd, n_embd*sizeof(float));
+        }
+
         ubatch.pos[i] = batch.pos[idxs[i]];
         ubatch.n_seq_id[i] = batch.n_seq_id[idxs[i]];
         ubatch.seq_id[i] = batch.seq_id[idxs[i]];
         ubatch.output[i] = batch.logits[idxs[i]];
+
+        if (ubatch.output[i]) {
+            out_ids.push_back(idxs[i]);
+        }
    }
 
-    res.token = ubatch.token.data();
-    //res.embd = ubatch.embd.data(); // TODO
-    res.pos = ubatch.pos.data();
-    res.n_seq_id = ubatch.n_seq_id.data();
-    res.seq_id = ubatch.seq_id.data();
-    res.output = ubatch.output.data();
+    llama_ubatch res {
+        /*.equal_seqs =*/ equal_seqs,
+        /*.n_tokens =*/ n_tokens,
+        /*.n_seq_tokens =*/ n_tokens/n_seqs,
+        /*.n_seqs =*/ n_seqs,
+
+        /*.token =*/ batch.token ? ubatch.token.data() : nullptr,
+        /*.embd =*/ batch.embd ? ubatch.embd.data() : nullptr,
+        /*.pos =*/ ubatch.pos.data(),
+        /*.n_seq_id =*/ ubatch.n_seq_id.data(),
+        /*.seq_id =*/ ubatch.seq_id.data(),
+        /*.output =*/ ubatch.output.data(),
    };
+
+    LLAMA_LOG_DEBUG("%s: added ubatch of size %d\n", __func__, res.n_tokens);
+
+    return res;
 }
 
 //
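The reworked split_equal() collects the picked batch indices per sequence and only concatenates them at the end, so the resulting ubatch lays tokens out contiguously per sequence (000...111...) instead of interleaved (0123 0123 ...), which is what the removed TODO asked for. A minimal, self-contained sketch of that reordering with made-up data (not the library code itself):

#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
    // toy input: the batch interleaves two sequences, so sequence 0 owns
    // batch indices {0, 2, 4} and sequence 1 owns {1, 3, 5} (made-up data)
    std::vector<std::vector<int32_t>> idxs_per_seq = {
        { 0, 2, 4 },
        { 1, 3, 5 },
    };

    const uint32_t n_seqs = (uint32_t) idxs_per_seq.size();

    // concatenate the per-sequence lists so the ubatch is laid out
    // 000...111... rather than 010101...
    std::vector<int32_t> idxs;
    for (uint32_t s = 0; s < n_seqs; ++s) {
        idxs.insert(idxs.end(), idxs_per_seq[s].begin(), idxs_per_seq[s].end());
    }

    const uint32_t n_tokens = (uint32_t) idxs.size();

    printf("n_tokens = %u, n_seqs = %u, n_seq_tokens = %u\n", n_tokens, n_seqs, n_tokens/n_seqs);

    for (int32_t idx : idxs) {
        printf("%d ", idx); // prints: 0 2 4 1 3 5
    }
    printf("\n");

    return 0;
}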

src/llama-batch.h

Lines changed: 4 additions & 2 deletions
@@ -95,7 +95,8 @@ class llama_batch_allocr {
             const llama_batch & batch_inp,
             const llama_vocab & vocab,
             const llama_memory_i * memory,
-            bool embd_all);
+            uint32_t n_embd,
+            bool output_all);
 
     const llama_batch & get_batch() const;
 
@@ -121,10 +122,11 @@ class llama_batch_allocr {
 private:
     void clear();
 
-    void add_ubatch(llama_ubatch & res, const std::vector<int32_t> & idxs);
+    llama_ubatch add_ubatch(const std::vector<int32_t> & idxs, uint32_t n_seqs, bool equal_seqs);
 
     llama_batch batch;
 
+    uint32_t n_embd;
     uint32_t n_outputs;
 
     std::array<llama_seq_id, 1> seq_id_0 = { 0 }; // default sequence id
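Since add_ubatch() now owns the embedding data (ubatch.embd is resized to n_tokens*n_embd and filled with one memcpy per token), the allocator has to know n_embd up front, hence the extra init() parameter and member above. A small standalone sketch of that strided copy, using hypothetical sizes and values:

#include <cstdint>
#include <cstdio>
#include <cstring>
#include <vector>

int main() {
    const uint32_t n_embd = 4; // hypothetical embedding size

    // source batch: 3 tokens worth of embeddings, row-major [token][n_embd] (made-up values)
    const std::vector<float> batch_embd = {
        0.0f, 0.1f, 0.2f, 0.3f, // token 0
        1.0f, 1.1f, 1.2f, 1.3f, // token 1
        2.0f, 2.1f, 2.2f, 2.3f, // token 2
    };

    // the ubatch picks tokens 2 and 0, in that order
    const std::vector<int32_t> idxs = { 2, 0 };

    std::vector<float> ubatch_embd((int64_t) idxs.size()*n_embd);

    // copy one row of n_embd floats per selected token, mirroring the memcpy in add_ubatch
    for (size_t i = 0; i < idxs.size(); ++i) {
        memcpy(ubatch_embd.data() + i*n_embd,
               batch_embd.data() + (int64_t) idxs[i]*n_embd,
               n_embd*sizeof(float));
    }

    printf("row 0 starts at %.1f (token 2), row 1 starts at %.1f (token 0)\n",
           ubatch_embd[0], ubatch_embd[n_embd]);

    return 0;
}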
