
Commit 1b6dfc5

cont : use batch allocr for state restore

ggml-ci

1 parent f9fa0e6 commit 1b6dfc5

File tree

3 files changed, +59 -24 lines changed

src/llama-batch.cpp

Lines changed: 55 additions & 22 deletions

@@ -571,6 +571,38 @@ bool llama_batch_allocr::init(
     return true;
 }
 
+llama_ubatch llama_batch_allocr::reserve_one(uint32_t n_tokens) {
+    clear();
+    split_reset();
+
+    ubatches.emplace_back();
+
+    auto & ubatch = ubatches.back();
+
+    ubatch.token   .resize(n_tokens);
+    ubatch.embd    .clear();
+    ubatch.pos     .resize(n_tokens);
+    ubatch.n_seq_id.resize(n_tokens);
+    ubatch.seq_id  .resize(n_tokens);
+    ubatch.output  .resize(n_tokens);
+
+    llama_ubatch res {
+        /*.equal_seqs   =*/ true,
+        /*.n_tokens     =*/ n_tokens,
+        /*.n_seq_tokens =*/ n_tokens,
+        /*.n_seqs       =*/ 1,
+
+        /*.token        =*/ ubatch.token.data(),
+        /*.embd         =*/ nullptr,
+        /*.pos          =*/ ubatch.pos.data(),
+        /*.n_seq_id     =*/ ubatch.n_seq_id.data(),
+        /*.seq_id       =*/ ubatch.seq_id.data(),
+        /*.output       =*/ ubatch.output.data(),
+    };
+
+    return res;
+}
+
 const llama_batch & llama_batch_allocr::get_batch() const {
     return batch;
 }

@@ -757,10 +789,11 @@ void llama_batch_allocr::clear() {
     n_outputs = 0;
 
     batch = {};
-    pos.clear();
+
+    pos     .clear();
     n_seq_id.clear();
-    seq_id.clear();
-    output.clear();
+    seq_id  .clear();
+    output  .clear();
 
     for (auto & cur : seq_pos) {
         cur.clear();

@@ -786,12 +819,12 @@ llama_ubatch llama_batch_allocr::add_ubatch(const std::vector<int32_t> & idxs, u
 
     auto & ubatch = ubatches.back();
 
-    ubatch.token.resize(n_tokens);
-    ubatch.embd.resize((int64_t) n_tokens*n_embd);
-    ubatch.pos.resize(n_tokens);
+    ubatch.token   .resize(n_tokens);
+    ubatch.embd    .resize((int64_t) n_tokens*n_embd);
+    ubatch.pos     .resize(n_tokens);
     ubatch.n_seq_id.resize(n_tokens);
-    ubatch.seq_id.resize(n_tokens);
-    ubatch.output.resize(n_tokens);
+    ubatch.seq_id  .resize(n_tokens);
+    ubatch.output  .resize(n_tokens);
 
     for (size_t i = 0; i < idxs.size(); ++i) {
         if (batch.token) {
@@ -839,25 +872,25 @@ struct llama_batch llama_batch_get_one(
             llama_token * tokens,
                 int32_t   n_tokens) {
     return {
-        /*n_tokens =*/ n_tokens,
-        /*tokens =*/ tokens,
-        /*embd =*/ nullptr,
-        /*pos =*/ nullptr,
-        /*n_seq_id =*/ nullptr,
-        /*seq_id =*/ nullptr,
-        /*logits =*/ nullptr,
+        /*n_tokens       =*/ n_tokens,
+        /*tokens         =*/ tokens,
+        /*embd           =*/ nullptr,
+        /*pos            =*/ nullptr,
+        /*n_seq_id       =*/ nullptr,
+        /*seq_id         =*/ nullptr,
+        /*logits         =*/ nullptr,
     };
 }
 
 struct llama_batch llama_batch_init(int32_t n_tokens_alloc, int32_t embd, int32_t n_seq_max) {
     llama_batch batch = {
-        /*n_tokens =*/ 0,
-        /*tokens =*/ nullptr,
-        /*embd =*/ nullptr,
-        /*pos =*/ nullptr,
-        /*n_seq_id =*/ nullptr,
-        /*seq_id =*/ nullptr,
-        /*logits =*/ nullptr,
+        /*n_tokens       =*/ 0,
+        /*tokens         =*/ nullptr,
+        /*embd           =*/ nullptr,
+        /*pos            =*/ nullptr,
+        /*n_seq_id       =*/ nullptr,
+        /*seq_id         =*/ nullptr,
+        /*logits         =*/ nullptr,
     };
 
     if (embd) {
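
For context, the new reserve_one follows the same ownership pattern as add_ubatch above: the allocator keeps the per-token storage in one of its ubatches entries and hands out a llama_ubatch that is just a bundle of raw pointers into that storage. Below is a small, self-contained sketch of that pattern; the batch_owner / batch_view names and the simplified fields are stand-ins for illustration, not the library's own types.

#include <cstdint>
#include <cstdio>
#include <vector>

// stand-in for llama_ubatch: a non-owning view of raw pointers
struct batch_view {
    uint32_t  n_tokens;
    int32_t * token;
    int32_t * pos;
    int8_t  * output;
};

// stand-in for llama_batch_allocr: owns the backing vectors
struct batch_owner {
    std::vector<int32_t> token;
    std::vector<int32_t> pos;
    std::vector<int8_t>  output;

    // reserve storage for n_tokens and return a view over it,
    // mirroring reserve_one() in the diff above
    batch_view reserve_one(uint32_t n_tokens) {
        token .resize(n_tokens);
        pos   .resize(n_tokens);
        output.resize(n_tokens);

        return { n_tokens, token.data(), pos.data(), output.data() };
    }
};

int main() {
    batch_owner allocr;

    // reserve a single 4-token batch and fill it through the view
    batch_view ub = allocr.reserve_one(4);
    for (uint32_t i = 0; i < ub.n_tokens; ++i) {
        ub.token [i] = 100 + (int32_t) i;
        ub.pos   [i] = (int32_t) i;
        ub.output[i] = (i + 1 == ub.n_tokens); // only the last token produces output
    }

    printf("last token %d at pos %d\n", ub.token[3], ub.pos[3]);
    return 0;
}

The view is only valid while the owning allocator is alive and its storage has not been cleared or reallocated, which is why the state-restore change below keeps a llama_batch_allocr on the stack for the duration of the read.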

src/llama-batch.h

Lines changed: 1 addition & 0 deletions

@@ -119,6 +119,7 @@ class llama_batch_allocr {
     // sequence-wise split - each ubatch contains a single sequence
     llama_ubatch split_seq(uint32_t n_ubatch);
 
+    llama_ubatch reserve_one(uint32_t n_tokens);
 private:
     void clear();
 
src/llama-kv-cache-unified.cpp

Lines changed: 3 additions & 2 deletions

@@ -1505,8 +1505,9 @@ bool llama_kv_cache_unified::state_read_meta(llama_io_read_i & io, uint32_t cell
 
     seq_rm(dest_seq_id, -1, -1);
 
-    llama_sbatch sbatch;
-    llama_ubatch ubatch = sbatch.reserve_ubatch(cell_count, /* has_embd */ false);
+    llama_batch_allocr batch_allocr;
+
+    llama_ubatch ubatch = batch_allocr.reserve_one(cell_count);
 
     ubatch.n_tokens     = cell_count;
     ubatch.n_seq_tokens = cell_count;