@@ -571,6 +571,38 @@ bool llama_batch_allocr::init(
     return true;
 }
 
+llama_ubatch llama_batch_allocr::reserve_one(uint32_t n_tokens) {
+    clear();
+    split_reset();
+
+    ubatches.emplace_back();
+
+    auto & ubatch = ubatches.back();
+
+    ubatch.token   .resize(n_tokens);
+    ubatch.embd    .clear();
+    ubatch.pos     .resize(n_tokens);
+    ubatch.n_seq_id.resize(n_tokens);
+    ubatch.seq_id  .resize(n_tokens);
+    ubatch.output  .resize(n_tokens);
+
+    llama_ubatch res {
+        /*.equal_seqs   =*/ true,
+        /*.n_tokens     =*/ n_tokens,
+        /*.n_seq_tokens =*/ n_tokens,
+        /*.n_seqs       =*/ 1,
+
+        /*.token        =*/ ubatch.token.data(),
+        /*.embd         =*/ nullptr,
+        /*.pos          =*/ ubatch.pos.data(),
+        /*.n_seq_id     =*/ ubatch.n_seq_id.data(),
+        /*.seq_id       =*/ ubatch.seq_id.data(),
+        /*.output       =*/ ubatch.output.data(),
+    };
+
+    return res;
+}
+
 const llama_batch & llama_batch_allocr::get_batch() const {
     return batch;
 }
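Note (not part of the diff): the new reserve_one() returns a token-only ubatch covering a single sequence; the resize() calls only establish capacity, so the buffers come back uninitialized, and every pointer in the returned struct aliases vectors owned by the allocator. A minimal caller sketch under those assumptions, where `balloc` and `n_ubatch` are hypothetical names:

    // Hypothetical names: `balloc` is an existing llama_batch_allocr,
    // `n_ubatch` the worst-case micro-batch size to reserve for.
    const llama_ubatch ubatch = balloc.reserve_one(n_ubatch);
    // ubatch.token, ubatch.pos, etc. point into storage owned by `balloc`,
    // so the ubatch stays valid only until the allocator is cleared or
    // reused (reserve_one itself begins by calling clear()).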
@@ -757,10 +789,11 @@ void llama_batch_allocr::clear() {
     n_outputs = 0;
 
     batch = {};
-    pos.clear();
+
+    pos     .clear();
     n_seq_id.clear();
-    seq_id.clear();
-    output.clear();
+    seq_id  .clear();
+    output  .clear();
 
     for (auto & cur : seq_pos) {
         cur.clear();
@@ -786,12 +819,12 @@ llama_ubatch llama_batch_allocr::add_ubatch(const std::vector<int32_t> & idxs, u
 
     auto & ubatch = ubatches.back();
 
-    ubatch.token.resize(n_tokens);
-    ubatch.embd.resize((int64_t) n_tokens*n_embd);
-    ubatch.pos.resize(n_tokens);
+    ubatch.token   .resize(n_tokens);
+    ubatch.embd    .resize((int64_t) n_tokens*n_embd);
+    ubatch.pos     .resize(n_tokens);
     ubatch.n_seq_id.resize(n_tokens);
-    ubatch.seq_id.resize(n_tokens);
-    ubatch.output.resize(n_tokens);
+    ubatch.seq_id  .resize(n_tokens);
+    ubatch.output  .resize(n_tokens);
 
     for (size_t i = 0; i < idxs.size(); ++i) {
         if (batch.token) {
@@ -839,25 +872,25 @@ struct llama_batch llama_batch_get_one(
              llama_token * tokens,
                  int32_t   n_tokens) {
     return {
-        /*n_tokens =*/ n_tokens,
-        /*tokens   =*/ tokens,
-        /*embd     =*/ nullptr,
-        /*pos      =*/ nullptr,
-        /*n_seq_id =*/ nullptr,
-        /*seq_id   =*/ nullptr,
-        /*logits   =*/ nullptr,
+        /*n_tokens       =*/ n_tokens,
+        /*tokens         =*/ tokens,
+        /*embd           =*/ nullptr,
+        /*pos            =*/ nullptr,
+        /*n_seq_id       =*/ nullptr,
+        /*seq_id         =*/ nullptr,
+        /*logits         =*/ nullptr,
     };
 }
 
 struct llama_batch llama_batch_init(int32_t n_tokens_alloc, int32_t embd, int32_t n_seq_max) {
     llama_batch batch = {
-        /*n_tokens =*/ 0,
-        /*tokens   =*/ nullptr,
-        /*embd     =*/ nullptr,
-        /*pos      =*/ nullptr,
-        /*n_seq_id =*/ nullptr,
-        /*seq_id   =*/ nullptr,
-        /*logits   =*/ nullptr,
+        /*n_tokens       =*/ 0,
+        /*tokens         =*/ nullptr,
+        /*embd           =*/ nullptr,
+        /*pos            =*/ nullptr,
+        /*n_seq_id       =*/ nullptr,
+        /*seq_id         =*/ nullptr,
+        /*logits         =*/ nullptr,
     };
 
     if (embd) {
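For context on the two public constructors touched above (a hedged sketch, not part of the diff): llama_batch_get_one() merely wraps a caller-owned token array, as the nullptr fields show, while llama_batch_init() heap-allocates its fields and must be paired with llama_batch_free():

    #include "llama.h"
    #include <vector>
    
    std::vector<llama_token> prompt = {1, 2, 3}; // placeholder token ids
    // Borrows `prompt`; the batch owns no memory, so nothing to free.
    llama_batch one = llama_batch_get_one(prompt.data(), (int32_t) prompt.size());
    
    // Allocates token/pos/n_seq_id/seq_id/logits arrays for up to 512 tokens.
    llama_batch batch = llama_batch_init(/*n_tokens_alloc=*/ 512, /*embd=*/ 0, /*n_seq_max=*/ 1);
    // ... fill batch.token[i], batch.pos[i], batch.seq_id[i][0], set batch.n_tokens ...
    llama_batch_free(batch);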