@@ -301,11 +301,10 @@ bool llama_batch_allocr::init(
         const llama_batch & batch_inp,
         const llama_vocab & vocab,
         const llama_memory_i * memory,
-        bool embd_all) {
+        uint32_t n_embd,
+        bool output_all) {
     clear();
 
-    split_reset();
-
     batch = batch_inp;
 
     GGML_ASSERT(batch.n_tokens > 0);
@@ -382,7 +381,7 @@ bool llama_batch_allocr::init(
     }
 
     if (!batch.logits) {
-        if (embd_all) {
+        if (output_all) {
             // return the output for all tokens
             output.resize(batch.n_tokens, true);
         } else {
@@ -392,7 +391,7 @@ bool llama_batch_allocr::init(
         }
 
         batch.logits = output.data();
-    } else if (embd_all) {
+    } else if (output_all) {
         bool warn = false;
 
         for (int32_t i = 0; i < batch.n_tokens; ++i) {
@@ -417,6 +416,8 @@ bool llama_batch_allocr::init(
         n_outputs += batch.logits[i] != 0;
     }
 
+    this->n_embd = n_embd;
+
     // determine coupled sequences
     // these are pairs of sequences that have at least one token in the input batch that is assigned to both of them
     for (int32_t i = 0; i < batch.n_tokens; ++i) {
@@ -572,6 +573,8 @@ bool llama_batch_allocr::init(
 
     // TODO: check that positions are increasing
 
+    split_reset();
+
     return true;
 }
 
@@ -580,7 +583,7 @@ const llama_batch & llama_batch_allocr::get_batch() const {
 }
 
 uint32_t llama_batch_allocr::get_n_tokens() const {
-    return pos.size();
+    return batch.n_tokens;
 }
 
 uint32_t llama_batch_allocr::get_n_outputs() const {
@@ -609,41 +612,20 @@ void llama_batch_allocr::split_reset() {
 }
 
 llama_ubatch llama_batch_allocr::split_simple(uint32_t n_ubatch) {
-    llama_ubatch res {
-        /*.equal_seqs   =*/ false,
-        /*.n_tokens     =*/ 0,
-        /*.n_seq_tokens =*/ 1,
-        /*.n_seqs       =*/ 0,
-
-        /*.token        =*/ nullptr,
-        /*.embd         =*/ nullptr,
-        /*.pos          =*/ nullptr,
-        /*.n_seq_id     =*/ nullptr,
-        /*.seq_id       =*/ nullptr,
-        /*.output       =*/ nullptr
-    };
-
     uint32_t cur_idx = 0;
     while (cur_idx < used.size() && used[cur_idx]) {
         ++cur_idx;
     }
 
     if (cur_idx >= used.size()) {
-        return res;
+        return {};
     }
 
     std::vector<int32_t> idxs;
 
     while (true) {
-        res.n_tokens++;
-        res.n_seqs++;
-
         idxs.push_back(cur_idx);
 
-        if (output[cur_idx] != 0) {
-            out_ids.push_back(cur_idx);
-        }
-
         used[cur_idx] = true;
 
         ++cur_idx;
@@ -652,31 +634,15 @@ llama_ubatch llama_batch_allocr::split_simple(uint32_t n_ubatch) {
             break;
         }
 
-        if (res.n_tokens >= n_ubatch) {
+        if (idxs.size() >= n_ubatch) {
             break;
         }
     }
 
-    add_ubatch(res, idxs);
-
-    return res;
+    return add_ubatch(idxs, idxs.size(), false);
 }
 
 llama_ubatch llama_batch_allocr::split_equal(uint32_t n_ubatch) {
-    llama_ubatch res {
-        /*.equal_seqs   =*/ true,
-        /*.n_tokens     =*/ 0,
-        /*.n_seq_tokens =*/ 0,
-        /*.n_seqs       =*/ 0,
-
-        /*.token        =*/ nullptr,
-        /*.embd         =*/ nullptr,
-        /*.pos          =*/ nullptr,
-        /*.n_seq_id     =*/ nullptr,
-        /*.seq_id       =*/ nullptr,
-        /*.output       =*/ nullptr
-    };
-
     std::vector<seq_set_t> cur_seq_set;
 
     // determine the sequence sets participating in this ubatch
@@ -685,35 +651,45 @@ llama_ubatch llama_batch_allocr::split_equal(uint32_t n_ubatch) {
             continue;
         }
 
-        for (size_t s = 0; s < cur_seq_set.size(); ++s) {
+        bool add = true;
+
+        for (uint32_t s = 0; s < cur_seq_set.size(); ++s) {
             // no overlap with existing sequence sets:
-            if ((cur_seq_set[s] & seq_set[i]).none()) {
-                cur_seq_set.push_back(seq_set[i]);
+            if (!(cur_seq_set[s] & seq_set[i]).none()) {
+                add = false;
+                break;
+            }
+        }
 
-                if (cur_seq_set.size() > (size_t) n_ubatch) {
-                    break;
-                }
+        if (add) {
+            cur_seq_set.push_back(seq_set[i]);
+
+            if (cur_seq_set.size() > n_ubatch) {
+                break;
             }
         }
     }
 
-    res.n_seqs = cur_seq_set.size();
+    const uint32_t n_seqs = cur_seq_set.size();
+
+    if (n_seqs == 0) {
+        return {};
+    }
 
-    std::vector<int32_t> cur_idx(cur_seq_set.size(), 0);
+    std::vector<int32_t> cur_idx(n_seqs, 0);
 
-    for (size_t s = 0; s < cur_seq_set.size(); ++s) {
+    for (uint32_t s = 0; s < n_seqs; ++s) {
         while (used[seq_set_map[cur_seq_set[s]][cur_idx[s]]]) {
             ++cur_idx[s];
         }
     }
 
-    std::vector<int32_t> idxs;
+    std::vector<idx_vec_t> idxs_per_seq(n_seqs);
 
-    // TODO: reorder from 012301230123..., to 000...111...222...333...
     while (true) {
         bool can_expand = true;
 
-        for (size_t s = 0; s < cur_seq_set.size(); ++s) {
+        for (uint32_t s = 0; s < n_seqs; ++s) {
             if (cur_idx[s] >= (int32_t) seq_set_map[cur_seq_set[s]].size()) {
                 can_expand = false;
                 break;
@@ -724,71 +700,49 @@ llama_ubatch llama_batch_allocr::split_equal(uint32_t n_ubatch) {
             break;
         }
 
-        res.n_tokens += res.n_seqs;
-
-        for (size_t s = 0; s < cur_seq_set.size(); ++s) {
+        for (uint32_t s = 0; s < n_seqs; ++s) {
             const int32_t idx = seq_set_map[cur_seq_set[s]][cur_idx[s]];
-            idxs.push_back(idx);
-
-            if (output[idx] != 0) {
-                out_ids.push_back(idx);
-            }
+            idxs_per_seq[s].push_back(idx);
 
             used[idx] = true;
 
             ++cur_idx[s];
         }
 
-        if (res.n_tokens + res.n_seqs > n_ubatch) {
+        if ((idxs_per_seq[0].size() + 1)*n_seqs > n_ubatch) {
             break;
         }
     }
 
-    add_ubatch(res, idxs);
+    std::vector<int32_t> idxs;
 
-    return res;
+    for (uint32_t s = 0; s < n_seqs; ++s) {
+        idxs.insert(idxs.end(), idxs_per_seq[s].begin(), idxs_per_seq[s].end());
+    }
+
+    return add_ubatch(idxs, n_seqs, true);
 }
 
 llama_ubatch llama_batch_allocr::split_seq(uint32_t n_ubatch) {
-    llama_ubatch res {
-        /*.equal_seqs   =*/ true,
-        /*.n_tokens     =*/ 0,
-        /*.n_seq_tokens =*/ 0,
-        /*.n_seqs       =*/ 1,
-
-        /*.token        =*/ nullptr,
-        /*.embd         =*/ nullptr,
-        /*.pos          =*/ nullptr,
-        /*.n_seq_id     =*/ nullptr,
-        /*.seq_id       =*/ nullptr,
-        /*.output       =*/ nullptr,
-    };
-
     uint32_t cur_idx = 0;
     while (cur_idx < used.size() && used[cur_idx]) {
         ++cur_idx;
     }
 
     if (cur_idx >= used.size()) {
-        return res;
+        return {};
     }
 
     auto cur_seq_set = seq_set[cur_idx];
 
     std::vector<int32_t> idxs;
 
     while (true) {
-        res.n_tokens++;
-
         idxs.push_back(cur_idx);
 
-        if (output[cur_idx] != 0) {
-            out_ids.push_back(cur_idx);
-        }
-
         used[cur_idx] = true;
 
-        if (res.n_tokens >= n_ubatch) {
+        if (idxs.size() >= n_ubatch) {
             break;
         }
 
@@ -803,9 +757,7 @@ llama_ubatch llama_batch_allocr::split_seq(uint32_t n_ubatch) {
         cur_seq_set = seq_set[cur_idx];
     }
 
-    add_ubatch(res, idxs);
-
-    return res;
+    return add_ubatch(idxs, 1, true);
 }
 
 void llama_batch_allocr::clear() {
@@ -834,37 +786,60 @@ void llama_batch_allocr::clear() {
     seq_set_map.clear();
 }
 
-void llama_batch_allocr::add_ubatch(llama_ubatch & res, const std::vector<int32_t> & idxs) {
-    ubatches.emplace_back();
+llama_ubatch llama_batch_allocr::add_ubatch(const std::vector<int32_t> & idxs, uint32_t n_seqs, bool equal_seqs) {
+    const uint32_t n_tokens = idxs.size();
 
-    auto & ubatch = ubatches.back();
+    LLAMA_LOG_DEBUG("add_ubatch: n_tokens = %d, n_seqs = %d, equal_seqs = %d", n_tokens, n_seqs, equal_seqs);
 
-    assert(res.n_tokens == idxs.size());
+    assert(n_tokens%n_seqs == 0);
 
-    const auto n_tokens = res.n_tokens;
+    ubatches.emplace_back();
+
+    auto & ubatch = ubatches.back();
 
     ubatch.token.resize(n_tokens);
-    // ubatch.embd.resize(0); // TODO
+    ubatch.embd.resize((int64_t) n_tokens*n_embd);
     ubatch.pos.resize(n_tokens);
     ubatch.n_seq_id.resize(n_tokens);
     ubatch.seq_id.resize(n_tokens);
     ubatch.output.resize(n_tokens);
 
     for (size_t i = 0; i < idxs.size(); ++i) {
-        ubatch.token[i] = batch.token[idxs[i]];
-        // ubatch.embd[i] = batch.embd[idxs[i]]; // TODO
+        if (batch.token) {
+            ubatch.token[i] = batch.token[idxs[i]];
+        }
+
+        if (batch.embd) {
+            memcpy(ubatch.embd.data() + i*n_embd, batch.embd + (int64_t) idxs[i]*n_embd, n_embd*sizeof(float));
+        }
+
         ubatch.pos[i] = batch.pos[idxs[i]];
         ubatch.n_seq_id[i] = batch.n_seq_id[idxs[i]];
         ubatch.seq_id[i] = batch.seq_id[idxs[i]];
         ubatch.output[i] = batch.logits[idxs[i]];
+
+        if (ubatch.output[i]) {
+            out_ids.push_back(idxs[i]);
+        }
     }
 
-    res.token    = ubatch.token.data();
-    // res.embd     = ubatch.embd.data(); // TODO
-    res.pos      = ubatch.pos.data();
-    res.n_seq_id = ubatch.n_seq_id.data();
-    res.seq_id   = ubatch.seq_id.data();
-    res.output   = ubatch.output.data();
+    llama_ubatch res {
+        /*.equal_seqs   =*/ equal_seqs,
+        /*.n_tokens     =*/ n_tokens,
+        /*.n_seq_tokens =*/ n_tokens/n_seqs,
+        /*.n_seqs       =*/ n_seqs,
+
+        /*.token        =*/ batch.token ? ubatch.token.data() : nullptr,
+        /*.embd         =*/ batch.embd ? ubatch.embd.data() : nullptr,
+        /*.pos          =*/ ubatch.pos.data(),
+        /*.n_seq_id     =*/ ubatch.n_seq_id.data(),
+        /*.seq_id       =*/ ubatch.seq_id.data(),
+        /*.output       =*/ ubatch.output.data(),
+    };
+
+    LLAMA_LOG_DEBUG("%s: added ubatch of size %d\n", __func__, res.n_tokens);
+
+    return res;
 }
 
 //
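For context, a minimal usage sketch (not part of the diff) of how a caller might drive the allocator under the signatures shown above: init() now receives n_embd and output_all and finishes by calling split_reset() itself, and the split_*() helpers return a value-constructed llama_ubatch (so n_tokens == 0) once every token has been consumed. The function name process_batch, the include, and the concrete memory/n_embd/n_ubatch values are illustrative assumptions, not code from this commit.

#include "llama-batch.h" // in-tree internal header, assumed available

// Hypothetical driver (illustration only): validate a llama_batch and consume
// it as simple ubatches of at most n_ubatch tokens each.
static bool process_batch(const llama_batch & batch, const llama_vocab & vocab,
                          uint32_t n_embd, uint32_t n_ubatch) {
    llama_batch_allocr balloc;

    // init() checks tokens/positions/sequences and prepares the split state
    if (!balloc.init(batch, vocab, /*memory =*/ nullptr, n_embd, /*output_all =*/ false)) {
        return false;
    }

    while (true) {
        // split_equal() or split_seq() could be used instead, depending on the memory module
        llama_ubatch ubatch = balloc.split_simple(n_ubatch);

        if (ubatch.n_tokens == 0) {
            break; // all tokens of the input batch have been placed into ubatches
        }

        // ... run the model on ubatch.token / ubatch.embd / ubatch.pos ...
    }

    return true;
}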